Working with an NN library: Flax
While it is possible to build neural networks from scratch in JAX, it is often more convenient to use a library like Flax or Haiku. These libraries provide common neural network layers and utilities for managing model parameters.
In this exercise, you will re-implement the simple MLP from the "A simple MLP" exercise, but this time using Flax.
- Define your MLP using `flax.linen`.
- Initialize the model's parameters.
- Create a `TrainState` to hold the model's parameters, optimizer state, and any other training-related variables.
- Implement a `jit`'ed training step and a training loop.
- Plot the loss as a function of the training step.
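As a reference, here is a minimal sketch that walks through each of these steps. The exercise does not specify the task or the architecture, so the layer sizes, the toy sin(x) regression data, and the Adam learning rate below are illustrative assumptions.

```python
import jax
import jax.numpy as jnp
import flax.linen as nn
import matplotlib.pyplot as plt
import optax
from flax.training import train_state


class MLP(nn.Module):
    """Two ReLU hidden layers; the sizes are illustrative, not prescribed."""
    hidden: int = 32
    out: int = 1

    @nn.compact
    def __call__(self, x):
        x = nn.relu(nn.Dense(self.hidden)(x))
        x = nn.relu(nn.Dense(self.hidden)(x))
        return nn.Dense(self.out)(x)


# Hypothetical toy regression task: fit y = sin(x) from noisy samples.
key = jax.random.PRNGKey(0)
key, x_key, noise_key, init_key = jax.random.split(key, 4)
xs = jax.random.uniform(x_key, (128, 1), minval=-3.0, maxval=3.0)
ys = jnp.sin(xs) + 0.1 * jax.random.normal(noise_key, xs.shape)

# Initialize the parameters by tracing the model on a sample input.
model = MLP()
params = model.init(init_key, xs[:1])["params"]

# TrainState bundles the apply function, parameters, and optimizer state.
state = train_state.TrainState.create(
    apply_fn=model.apply, params=params, tx=optax.adam(1e-3)
)


@jax.jit
def train_step(state, xs, ys):
    """One jit'ed gradient step on the mean squared error."""
    def loss_fn(params):
        preds = state.apply_fn({"params": params}, xs)
        return jnp.mean((preds - ys) ** 2)

    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    return state.apply_gradients(grads=grads), loss


# Training loop: record the loss at every step.
losses = []
for step in range(1_000):
    state, loss = train_step(state, xs, ys)
    losses.append(float(loss))

plt.plot(losses)
plt.xlabel("training step")
plt.ylabel("MSE loss")
plt.show()
```

Note that the training step takes the `TrainState` as an argument and returns an updated one rather than mutating anything in place; this keeps the function pure, which is what allows `jax.jit` to compile it.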