Working with a NN library: Flax
While it is possible to build neural networks from scratch in JAX, it is often more convenient to use a library like Flax or Haiku. These libraries provide common neural network layers and utilities for managing model parameters.
In this exercise, you will re-implement the simple MLP from the "A simple MLP" exercise, this time using Flax.
- Define your MLP using flax.linen.
- Initialize the model's parameters.
- Create a TrainState to hold the model's parameters, optimizer state, and any other training-related variables.
- Implement a jitted training step (compiled with jax.jit) and a training loop.
- Plot the loss as a function of the training step.