-
Working with an optimizer library: Optax
Optax is a popular library for optimization in JAX. It provides a wide range of optimizers and is designed to be highly modular. In this exercise, you will use Optax to train the Flax MLP from the...
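As a rough sketch of that modularity, Optax builds an optimizer by chaining gradient transformations and then drives it with the usual `init`/`update`/`apply_updates` calls. The snippet below is illustrative only: a plain parameter dict and made-up data stand in for the Flax MLP from the exercise.

```python
import jax
import jax.numpy as jnp
import optax

# Placeholder parameters and data (not the Flax MLP from the exercise).
params = {"w": jnp.zeros((3,)), "b": jnp.array(0.0)}

def loss_fn(params, x, y):
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

# Modularity in action: compose gradient clipping with plain SGD.
tx = optax.chain(optax.clip_by_global_norm(1.0), optax.sgd(learning_rate=0.1))
opt_state = tx.init(params)

x = jnp.ones((4, 3))
y = jnp.ones((4,))
grads = jax.grad(loss_fn)(params, x, y)
updates, opt_state = tx.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)
```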
-
Linear Regression via Gradient Descent
Linear regression is a foundational supervised learning algorithm. Given a dataset of input features $X$ and corresponding target values $y$, the goal is to find a linear relationship $y \approx Xw + b$ (with weight vector $w$ and bias $b$) that best fits the data.
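A minimal sketch of the idea in JAX, using a made-up toy dataset and an assumed learning rate rather than the exercise's actual inputs:

```python
import jax
import jax.numpy as jnp

# Hypothetical toy data: y = 2*x + 1 (not part of the exercise).
key = jax.random.PRNGKey(0)
X = jax.random.normal(key, (100, 1))
y = 2.0 * X[:, 0] + 1.0

def mse(params, X, y):
    w, b = params
    pred = X @ w + b
    return jnp.mean((pred - y) ** 2)

params = (jnp.zeros(1), 0.0)   # (w, b)
lr = 0.1                       # assumed learning rate

# Plain gradient descent: step each parameter against its gradient.
for _ in range(200):
    grads = jax.grad(mse)(params, X, y)
    params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

w, b = params
print(w, b)  # should approach [2.0] and 1.0
```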
-
Matrix Multiplication and Efficiency
Matrix multiplication is a fundamental operation in linear algebra and a cornerstone of deep learning. Given two matrices $A$ (size $m \times k$) and $B$ (size $k \times n$), their product $C = AB$ is an $m \times n$ matrix with entries $C_{ij} = \sum_{p=1}^{k} A_{ip} B_{pj}$.
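For illustration, here is a naive triple-loop implementation next to NumPy's optimized `@` operator; the names and sizes are made up, and the point is only that both compute the same $C = AB$, with the loop costing $m \cdot k \cdot n$ multiply-adds in pure Python:

```python
import numpy as np

def matmul_naive(A, B):
    """Naive triple loop: fine for tiny matrices, far too slow at scale."""
    m, k = A.shape
    k2, n = B.shape
    assert k == k2, "inner dimensions must match"
    C = np.zeros((m, n))
    for i in range(m):
        for j in range(n):
            for p in range(k):
                C[i, j] += A[i, p] * B[p, j]
    return C

A = np.random.rand(8, 5)
B = np.random.rand(5, 3)
C = matmul_naive(A, B)
print(np.allclose(C, A @ B))  # True: same result as the BLAS-backed @ operator
```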
-
Implement a Basic Optimizer with Optax
Use Optax, JAX's optimization library, to create a simple Stochastic Gradient Descent (SGD) optimizer. You'll need to define a model and a loss function, then use `optax.sgd` to update the model's parameters.
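One possible shape of the solution, with a hypothetical tiny linear model and made-up data standing in for whatever the exercise provides:

```python
import jax
import jax.numpy as jnp
import optax

# Hypothetical tiny model and data, just to exercise the update loop.
def model(params, x):
    return x @ params["w"] + params["b"]

def loss_fn(params, x, y):
    return jnp.mean((model(params, x) - y) ** 2)

params = {"w": jnp.zeros((2,)), "b": jnp.array(0.0)}
optimizer = optax.sgd(learning_rate=0.05)
opt_state = optimizer.init(params)

x = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
y = jnp.array([1.0, 2.0, 3.0])

@jax.jit
def train_step(params, opt_state, x, y):
    # Compute loss and gradients, then let the optimizer turn grads into updates.
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss

for step in range(100):
    params, opt_state, loss = train_step(params, opt_state, x, y)
```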
-
Manual Gradient Descent Step
Simulate one step of gradient descent for a simple quadratic loss.

### Problem
Given a scalar parameter $w$ initialized at 5.0, minimize the loss $L(w) = (w - 3)^2$ using PyTorch.
- **Input:**...
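A sketch of that single step in PyTorch; the learning rate of 0.1 is an assumption, since the exercise's input specification is truncated above:

```python
import torch

w = torch.tensor(5.0, requires_grad=True)
lr = 0.1  # assumed learning rate; the exercise's actual value may differ

loss = (w - 3) ** 2   # L(w) = (w - 3)^2, evaluated at w = 5 -> 4.0
loss.backward()       # dL/dw = 2*(w - 3) = 4.0, stored in w.grad

with torch.no_grad():
    w -= lr * w.grad  # one gradient descent step: 5.0 - 0.1 * 4.0 = 4.6

print(w.item())  # 4.6
print(w.grad)    # tensor(4.)
```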