ML Katas

Working with a NN library: Flax

Difficulty: medium (<1 hr) · Tags: flax, MLP, TrainState
this year by E

While it is possible to build neural networks from scratch in JAX, it is often more convenient to use a library like Flax or Haiku. These libraries provide common neural network layers and utilities for managing model parameters.

In this exercise, you will re-implement the simple MLP from the "A simple MLP" exercise, this time using Flax.

  1. Define your MLP using flax.linen.
  2. Initialize the model's parameters.
  3. Create a TrainState to hold the model's parameters, optimizer state, and any other training-related variables.
  4. Implement a `jax.jit`-compiled training step and a training loop.
  5. Plot the loss as a function of the training step.