ML Katas

Working with an optimizer library: Optax

medium (<30 mins) optimization optax

Optax is a popular library for optimization in JAX. It provides a wide range of optimizers and is designed to be highly modular.
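To make "modular" concrete, here is a minimal sketch of Optax's core pattern: an optimizer is a gradient transformation that you can compose with others via optax.chain, then drive with init, update, and apply_updates. The clipping threshold, learning rate, and the toy parameter tree below are arbitrary placeholders, not values from the exercise.

```python
import jax
import jax.numpy as jnp
import optax

# Compose gradient transformations; this chaining is what makes Optax modular.
# The clip value and learning rate are placeholder choices for illustration.
tx = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.adam(learning_rate=1e-3),
)

# A toy parameter pytree standing in for real model parameters.
params = {"w": jnp.ones((3,)), "b": jnp.zeros((3,))}
opt_state = tx.init(params)

# Dummy gradients; in practice these come from jax.grad on a loss function.
grads = jax.tree_util.tree_map(jnp.ones_like, params)
updates, opt_state = tx.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)
```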

In this exercise, you will use Optax to train the Flax MLP from the previous exercise.

  1. Choose an optimizer from Optax, for example, optax.adam.
  2. In your training step, use the optimizer to compute the updates to the model's parameters.
  3. Update the TrainState with the new parameters and optimizer state (a sketch of a complete training step follows below).
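The sketch below puts the three steps together, assuming a small stand-in MLP, an MSE loss, and toy input shapes since the previous exercise's exact model is not shown here. flax.training.train_state.TrainState bundles the parameters, the apply function, and the optimizer state, and its apply_gradients method performs steps 2 and 3 in one call.

```python
import jax
import jax.numpy as jnp
import optax
import flax.linen as nn
from flax.training import train_state

# Stand-in for the MLP from the previous exercise; layer sizes are assumptions.
class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(32)(x)
        x = nn.relu(x)
        return nn.Dense(1)(x)

model = MLP()
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 8)))["params"]

# Step 1: choose an optimizer from Optax.
tx = optax.adam(learning_rate=1e-3)

# TrainState holds the apply function, parameters, and optimizer state together.
state = train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)

@jax.jit
def train_step(state, batch):
    def loss_fn(params):
        preds = state.apply_fn({"params": params}, batch["x"])
        return jnp.mean((preds - batch["y"]) ** 2)  # MSE loss, chosen for the sketch

    # Step 2: compute gradients; the optimizer turns them into parameter updates.
    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    # Step 3: apply_gradients returns a new TrainState with updated
    # parameters and optimizer state.
    state = state.apply_gradients(grads=grads)
    return state, loss

# Example usage with a toy batch.
batch = {"x": jnp.ones((4, 8)), "y": jnp.zeros((4, 1))}
state, loss = train_step(state, batch)
```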