Working with an optimizer library: Optax
Optax is a widely used gradient-processing and optimization library for JAX. It provides a wide range of optimizers, each built from small, composable gradient transformations. In this exercise, you will use Optax to train the Flax MLP from the...
Implement a Basic Optimizer with Optax
Use Optax, an optimization library for JAX, to create a simple stochastic gradient descent (SGD) optimizer. You'll need to define a model and a loss function, then use `optax.sgd` to update the...