Working with an optimizer library: Optax
Optax is a popular optimization library for JAX. It provides a wide range of optimizers and is designed to be highly modular: optimizers are built by composing simple gradient transformations.
In this exercise, you will use Optax to train the Flax MLP from the previous exercise.
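To illustrate the modularity, here is a minimal sketch of composing gradient transformations with `optax.chain`; the learning rate and clipping threshold are illustrative values, not part of the exercise:

```python
import optax

# Illustrative only: Optax transformations compose, e.g. gradient clipping
# chained with the Adam update rule into a single optimizer.
tx = optax.chain(
    optax.clip_by_global_norm(1.0),      # clip gradients before the optimizer update
    optax.adam(learning_rate=1e-3),      # then apply Adam
)
```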
- Choose an optimizer from Optax, for example, `optax.adam`.
- In your training step, use the optimizer to compute the updates to the model's parameters.
- Update the `TrainState` with the new parameters and optimizer state (a complete sketch follows this list).
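The sketch below shows one way to put these steps together, assuming a small MLP with an input dimension of 8, a mean-squared-error loss, and illustrative hyperparameters; the names `MLP`, `create_train_state`, and `train_step` are placeholders, not fixed by the exercise. `TrainState.apply_gradients` calls the optimizer's `update` internally and returns a new state with updated parameters and optimizer state.

```python
import jax
import jax.numpy as jnp
import optax
import flax.linen as nn
from flax.training import train_state


class MLP(nn.Module):
    """A small MLP standing in for the model from the previous exercise."""
    features: int = 32

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(self.features)(x)
        x = nn.relu(x)
        return nn.Dense(1)(x)


def create_train_state(rng, learning_rate=1e-3):
    model = MLP()
    params = model.init(rng, jnp.ones((1, 8)))["params"]  # assumed input dim of 8
    tx = optax.adam(learning_rate)                         # the chosen Optax optimizer
    return train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)


@jax.jit
def train_step(state, batch):
    """One step: compute grads, let Optax produce updates, refresh the state."""
    def loss_fn(params):
        preds = state.apply_fn({"params": params}, batch["x"])
        return jnp.mean((preds - batch["y"]) ** 2)         # example MSE loss

    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    # apply_gradients runs tx.update under the hood and returns a new TrainState
    # carrying both the updated parameters and the updated optimizer state.
    state = state.apply_gradients(grads=grads)
    return state, loss


# Usage: a few steps on random data, just to exercise the loop.
rng = jax.random.PRNGKey(0)
state = create_train_state(rng)
batch = {
    "x": jax.random.normal(rng, (16, 8)),
    "y": jax.random.normal(rng, (16, 1)),
}
for _ in range(5):
    state, loss = train_step(state, batch)
```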