- Working with an optimizer library: Optax
  Optax is a popular library for optimization in JAX. It provides a wide range of optimizers and is designed to be highly modular. In this exercise, you will use Optax to train the Flax MLP from the...
- Matrix Multiplication and Efficiency
  Matrix multiplication is a fundamental operation in linear algebra and a cornerstone of deep learning. Given two matrices $A$ (size $m \times k$) and $B$ (size $k \times n$), their product $C =...
- Implement a Simple Linear Regression in JAX
  Your task is to implement a simple linear regression model from scratch using JAX. You'll need to define the model, a loss function (like Mean Squared Error), and a gradient descent update rule....
- Implement a Basic Optimizer with Optax
  Use Optax, an optimization library for JAX, to create a simple Stochastic Gradient Descent (SGD) optimizer. You'll need to define a model, a loss function, and then use `optax.sgd` to update the...
- Manual Gradient Descent Step
  Simulate one step of gradient descent for a simple quadratic loss. Problem: given a scalar parameter $w$ initialized at 5.0, minimize the loss $L(w) = (w - 3)^2$ using PyTorch. **Input:**...
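
For orientation, the single update described in the last item can be sketched in plain Python. This is only an illustration under stated assumptions: the listing does not give a learning rate, so `lr = 0.1` is a hypothetical value, and the gradient $dL/dw = 2(w - 3)$ is written out by hand here rather than obtained from PyTorch's autograd as the exercise asks.

```python
# Minimal sketch of one gradient descent step on L(w) = (w - 3)^2.
# Assumptions (not from the exercise text): lr = 0.1, and the gradient
# is derived analytically instead of using PyTorch autograd.

def loss(w: float) -> float:
    """Quadratic loss with its minimum at w = 3."""
    return (w - 3.0) ** 2


def grad(w: float) -> float:
    """Analytic derivative of the loss: dL/dw = 2 * (w - 3)."""
    return 2.0 * (w - 3.0)


def gd_step(w: float, lr: float) -> float:
    """One gradient descent update: w <- w - lr * dL/dw."""
    return w - lr * grad(w)


w0 = 5.0   # initial parameter, as stated in the exercise
lr = 0.1   # hypothetical learning rate (assumption)
w1 = gd_step(w0, lr)

print("gradient at w0:", grad(w0))
print("w after one step:", w1)
print("loss before and after:", loss(w0), loss(w1))
```

With these assumed values the gradient at $w = 5$ is $4$, so the step moves $w$ from $5.0$ toward the minimum at $3$, to approximately $4.6$, and the loss decreases. A different learning rate would change the size of the step but not its direction.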