- Taking gradients with `grad`
  One of JAX's most powerful features is its ability to automatically differentiate your code. In this exercise, you will implement a regression problem and use `jax.grad` to compute gradients. 1....
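
For the gradient exercise, a minimal sketch of what the setup might look like. The model (`predict`), the loss (`mse_loss`), and the toy data are assumptions made for illustration, not the exercise's reference solution.

```python
import jax
import jax.numpy as jnp

def predict(w, b, x):
    # Linear model: one prediction per row of x.
    return x @ w + b

def mse_loss(w, b, x, y):
    # Mean squared error between predictions and targets.
    return jnp.mean((predict(w, b, x) - y) ** 2)

# Hypothetical toy data where y = 2 * x exactly.
x = jnp.array([[1.0], [2.0], [3.0]])
y = jnp.array([2.0, 4.0, 6.0])
w = jnp.zeros(1)
b = 0.0

# jax.grad returns a new function; by default it differentiates
# with respect to the first argument (here, w).
grad_w = jax.grad(mse_loss)(w, b, x, y)
```

Here `grad_w` has the same shape as `w`, which is what makes it directly usable in a parameter update.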
- Taking gradients with `grad` II
  By default, `jax.grad` will take the gradient with respect to the first argument of the function. However, in many cases, we will want to take gradients with respect to many of the function's...
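
One way to differentiate with respect to more than one argument is the `argnums` parameter of `jax.grad`. The sketch below assumes a linear regression loss with separate `w` and `b` arguments; the names and data are illustrative only.

```python
import jax
import jax.numpy as jnp

def loss(w, b, x, y):
    # Mean squared error of a linear model.
    return jnp.mean((x @ w + b - y) ** 2)

# Hypothetical toy data.
x = jnp.array([[1.0], [2.0], [3.0]])
y = jnp.array([2.0, 4.0, 6.0])
w = jnp.zeros(1)
b = jnp.array(0.0)

# argnums=(0, 1) selects the first two arguments; the result is a
# tuple with one gradient per selected argument.
dw, db = jax.grad(loss, argnums=(0, 1))(w, b, x, y)
```

An alternative design is to bundle all parameters into a single tuple or dict and pass it as the first argument, since `jax.grad` handles such nested containers directly.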
- Updating model parameters
  In this exercise, you will implement a full training step for the regression problem that you have been working on. 1. Instantiate your model parameters, `W` and `b`, and your data `x` and...
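
A training step for this kind of problem could be sketched as follows. The learning rate, the synthetic data, and the choice of plain gradient descent are all assumptions; the exercise itself may specify something different.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # assumed value, chosen for this toy problem

def loss(params, x, y):
    W, b = params
    return jnp.mean((x @ W + b - y) ** 2)

@jax.jit
def train_step(params, x, y):
    # Gradient with respect to the whole (W, b) tuple at once.
    grads = jax.grad(loss)(params, x, y)
    # Plain gradient descent, applied to each parameter in the tuple.
    return jax.tree_util.tree_map(
        lambda p, g: p - LEARNING_RATE * g, params, grads
    )

# Hypothetical synthetic data from a known linear relationship.
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (64, 3))
y = x @ jnp.array([1.0, -2.0, 0.5]) + 0.3

params = (jnp.zeros(3), jnp.array(0.0))
initial_loss = loss(params, x, y)
for _ in range(200):
    params = train_step(params, x, y)
final_loss = loss(params, x, y)
```

Because `train_step` returns new parameters instead of mutating the old ones, it stays a pure function and can be wrapped in `jax.jit`.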
- A simple MLP
  Now that you have all of the basic building blocks, it's time to build a simple multi-layer perceptron (MLP). In this exercise, you will build a 2-layer MLP for a regression problem. 1....
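
As a rough sketch of a 2-layer MLP for regression, the parameters can live in a dict and the forward pass can be a plain function. The layer sizes, the ReLU activation, the initialization scale, and the helper name `init_mlp` are assumptions for illustration.

```python
import jax
import jax.numpy as jnp

def init_mlp(key, in_dim, hidden_dim, out_dim):
    # Small random weights and zero biases (an assumed initialization).
    k1, k2 = jax.random.split(key)
    return {
        "W1": jax.random.normal(k1, (in_dim, hidden_dim)) * 0.1,
        "b1": jnp.zeros(hidden_dim),
        "W2": jax.random.normal(k2, (hidden_dim, out_dim)) * 0.1,
        "b2": jnp.zeros(out_dim),
    }

def mlp(params, x):
    # Hidden layer with ReLU, followed by a linear output layer.
    h = jax.nn.relu(x @ params["W1"] + params["b1"])
    return h @ params["W2"] + params["b2"]

def loss(params, x, y):
    return jnp.mean((mlp(params, x) - y) ** 2)

params = init_mlp(jax.random.PRNGKey(0), in_dim=4, hidden_dim=16, out_dim=1)
x = jnp.ones((8, 4))   # hypothetical batch of 8 inputs
y = jnp.zeros((8, 1))  # hypothetical targets
preds = mlp(params, x)
# The gradient is a dict with the same keys and shapes as params.
grads = jax.grad(loss)(params, x, y)
```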
- A simple CNN
  In this exercise, you will implement a simple convolutional neural network (CNN) for a regression problem. You can use `jax.lax.conv_general_dilated` to implement the convolution. 1. Implement a...
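
A small CNN built on `jax.lax.conv_general_dilated` might look like the sketch below. The NHWC layout, the 3x3 kernel, the global average pooling, and the input size are assumptions, not details taken from the exercise.

```python
import jax
import jax.numpy as jnp

def conv2d(x, kernel):
    # x: (batch, height, width, channels)
    # kernel: (kernel_h, kernel_w, in_channels, out_channels)
    return jax.lax.conv_general_dilated(
        x,
        kernel,
        window_strides=(1, 1),
        padding="SAME",
        dimension_numbers=("NHWC", "HWIO", "NHWC"),
    )

def cnn(params, x):
    h = jax.nn.relu(conv2d(x, params["kernel"]) + params["conv_bias"])
    # Global average pooling over the spatial axes, then a linear readout.
    pooled = jnp.mean(h, axis=(1, 2))
    return pooled @ params["W_out"] + params["b_out"]

k1, k2 = jax.random.split(jax.random.PRNGKey(0))
params = {
    "kernel": jax.random.normal(k1, (3, 3, 1, 8)) * 0.1,
    "conv_bias": jnp.zeros(8),
    "W_out": jax.random.normal(k2, (8, 1)) * 0.1,
    "b_out": jnp.zeros(1),
}
x = jnp.ones((2, 28, 28, 1))  # hypothetical batch of 2 single-channel images
features = conv2d(x, params["kernel"])
preds = cnn(params, x)
```

Spelling out `dimension_numbers` avoids relying on the default layout, which is channels-first.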
- Forward-mode vs. Reverse-mode autodiff
  JAX supports both forward-mode and reverse-mode automatic differentiation. While `grad` uses reverse-mode, you can use `jax.jvp` for forward-mode, which computes Jacobian-vector products....
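
To make the contrast concrete, the sketch below applies both modes to the same function. The function `f` and the vectors `v` and `u` are made up for illustration, and `jax.vjp` is used here as the reverse-mode counterpart because `f` is vector-valued, whereas `grad` expects a scalar output.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Maps a 3-vector to a 2-vector.
    return jnp.array([x[0] * x[1], jnp.sin(x[2])])

x = jnp.array([1.0, 2.0, 0.0])

# Forward mode: push an input-shaped tangent v through f to get J @ v.
v = jnp.array([1.0, 0.0, 1.0])
y, jv = jax.jvp(f, (x,), (v,))

# Reverse mode: pull an output-shaped cotangent u back to get u @ J.
u = jnp.array([1.0, 1.0])
_, vjp_fn = jax.vjp(f, x)
(uj,) = vjp_fn(u)
```

Forward mode costs roughly one pass per input direction and reverse mode one pass per output direction, which is why reverse mode suits scalar losses over many parameters.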
- Custom VJP
  For some functions, you may want to define a custom vector-Jacobian product (VJP). This can be useful for numerical stability or for implementing algorithms that are not easily expressed in terms...
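
As an illustration of the numerical-stability case, the sketch below attaches a custom VJP to `log(1 + exp(x))`. The choice of function and the helper names `log1pexp_fwd` and `log1pexp_bwd` are assumptions; the exercise may use a different example.

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def log1pexp(x):
    return jnp.log1p(jnp.exp(x))

def log1pexp_fwd(x):
    # Return the primal output plus the residuals the backward pass needs.
    return log1pexp(x), x

def log1pexp_bwd(x, g):
    # d/dx log(1 + exp(x)) = sigmoid(x), which does not overflow.
    # The result must be a tuple with one entry per primal input.
    return (g * jax.nn.sigmoid(x),)

log1pexp.defvjp(log1pexp_fwd, log1pexp_bwd)

grad_at_zero = jax.grad(log1pexp)(0.0)
# Without the custom rule, exp(100.0) overflows in float32 and the
# automatically derived gradient is nan.
grad_at_large = jax.grad(log1pexp)(100.0)
```

Note that this sketch only stabilizes the gradient; the forward value itself still overflows for very large inputs.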