Implement a Basic Optimizer with Optax
Use Optax, a gradient-processing and optimization library for JAX, to create a simple Stochastic Gradient Descent (SGD) optimizer. You'll need to define a model and a loss function, then use optax.sgd to build an optimizer that updates the model's parameters.
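A minimal sketch of that setup, assuming a linear model with a mean-squared-error loss; the names model, loss_fn, and params, and the learning rate of 0.1, are illustrative choices, not fixed by the task:

```python
import jax
import jax.numpy as jnp
import optax

# A linear model: predictions are x @ w + b.
def model(params, x):
    return x @ params["w"] + params["b"]

# Mean squared error between predictions and targets.
def loss_fn(params, x, y):
    preds = model(params, x)
    return jnp.mean((preds - y) ** 2)

# Initialize parameters and the SGD optimizer.
params = {"w": jnp.zeros((3,)), "b": jnp.zeros(())}
optimizer = optax.sgd(learning_rate=0.1)
opt_state = optimizer.init(params)
```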
Task: Set up a training step function that takes the current parameters, a batch of data, and the optimizer state, and returns the updated parameters and optimizer state.
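One way to write that step, building on the sketch above; the name train_step and the choice to also return the loss value (handy for the verification below) are assumptions:

```python
@jax.jit
def train_step(params, opt_state, x, y):
    # Compute the loss and its gradients with respect to the parameters.
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # Transform the raw gradients into updates, then apply them.
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss
```

Returning the new opt_state alongside the params matters: Optax transformations are pure functions, so any internal state (e.g., momentum in fancier optimizers) must be threaded through explicitly by the caller.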
Verification: The loss should decrease over a few training steps on a simple problem (e.g., linear regression).
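A quick way to check this, using synthetic linear-regression data; the true weights, data shapes, and step count below are arbitrary:

```python
# Synthetic data: y = x @ w_true + b_true, which SGD should recover.
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (128, 3))
w_true = jnp.array([2.0, -1.0, 0.5])
y = x @ w_true + 3.0

for step in range(100):
    params, opt_state, loss = train_step(params, opt_state, x, y)
    if step % 20 == 0:
        print(f"step {step}: loss = {loss:.4f}")
```

The printed loss should fall steadily toward zero; if it diverges, lowering the learning rate is the usual first fix.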