ML Katas

Implement a Basic Optimizer with Optax

medium (<30 mins) optimization jax optax sgd
yesterday by E

Use Optax, a gradient processing and optimization library for JAX, to build a simple stochastic gradient descent (SGD) optimizer. You'll need to define a model and a loss function, then use optax.sgd to update the model's parameters.
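
A minimal sketch of that setup, assuming a linear-regression model whose parameters are a weight vector "w" and a bias "b" (both names are illustrative, not required by the kata):

import jax
import jax.numpy as jnp
import optax

# Hypothetical linear model; "w" and "b" are illustrative parameter names.
params = {"w": jnp.zeros((3,)), "b": jnp.array(0.0)}

def loss_fn(params, x, y):
    # Mean squared error for linear regression.
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

# optax.sgd builds a plain SGD transformation; 0.1 is an arbitrary learning rate.
optimizer = optax.sgd(learning_rate=0.1)
opt_state = optimizer.init(params)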

Task: Set up a training step function that takes the current parameters, a batch of data, and the optimizer state, and returns the updated parameters and optimizer state.
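
One way to write such a step, reusing the loss_fn and optimizer sketched above (jax.value_and_grad, optimizer.update, and optax.apply_updates are the standard JAX/Optax calls):

@jax.jit
def train_step(params, opt_state, x, y):
    # Compute the loss and its gradients w.r.t. the parameters.
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # Turn raw gradients into updates, then apply them to the parameters.
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss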

Verification: The loss should decrease over a few training steps on a simple problem (e.g., linear regression).
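
One way to check this, assuming the synthetic linear-regression data below (the true weights are arbitrary illustrative values) and the train_step sketched above:

# Synthetic data: y = x @ w_true + b_true, with illustrative values.
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (64, 3))
w_true, b_true = jnp.array([2.0, -1.0, 0.5]), 0.3
y = x @ w_true + b_true

for step in range(5):
    params, opt_state, loss = train_step(params, opt_state, x, y)
    # The printed loss should shrink from one step to the next.
    print(f"step {step}: loss = {float(loss):.4f}")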