-
Taking gradients with `grad`
One of JAX's most powerful features is its ability to automatically differentiate your code. In this exercise, you will implement a regression problem and use `jax.grad` to compute gradients. 1....
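As a minimal sketch of what this exercise asks for, the snippet below differentiates a mean-squared-error loss for a linear model with `jax.grad`. The toy data and the names `loss`, `w`, `x`, and `y` are assumptions made for illustration; they are not taken from the exercise itself.

```python
import jax
import jax.numpy as jnp

def loss(w, x, y):
    # Mean squared error of a bias-free linear model (illustrative choice).
    pred = x @ w
    return jnp.mean((pred - y) ** 2)

# Toy data, made up for this sketch.
x = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
y = jnp.array([1.0, 2.0, 3.0])
w = jnp.zeros(2)

# jax.grad returns a new function that computes d(loss)/d(w),
# i.e. the gradient with respect to the first argument.
grad_loss = jax.grad(loss)
g = grad_loss(w, x, y)
print(g)
```

The gradient has the same shape as `w`, and for this loss it should agree with the closed-form expression `2/N * x.T @ (x @ w - y)`.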
-
Taking gradients with `grad` II
By default, `jax.grad` takes the gradient with respect to the first argument of the function. In many cases, however, we want gradients with respect to several of the function's...
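One way to get gradients for more than one argument is the `argnums` parameter of `jax.grad`. The sketch below assumes a linear model with a weight vector `W` and a scalar bias `b`; the loss function and toy data are illustrative and not part of the exercise text.

```python
import jax
import jax.numpy as jnp

def loss(W, b, x, y):
    # Mean squared error of a linear model with bias (illustrative choice).
    pred = x @ W + b
    return jnp.mean((pred - y) ** 2)

# Toy data, made up for this sketch.
x = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
y = jnp.array([1.0, 2.0, 3.0])
W = jnp.zeros(2)
b = jnp.array(0.0)

# argnums=(0, 1) asks for gradients with respect to both W and b.
# The result is a tuple with one gradient per requested argument.
dW, db = jax.grad(loss, argnums=(0, 1))(W, b, x, y)
print(dW, db)
```

Passing a single integer, such as `argnums=1`, would return only the gradient for that argument rather than a tuple.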
-
Updating model parameters
In this exercise, you will implement a full training step for the regression problem that you have been working on. 1. Instantiate your model parameters, `W` and `b`, and your data `x` and...
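A training step of this kind might look like the sketch below, which assumes plain gradient descent on a mean-squared-error loss. The learning rate, the synthetic data, and the names `loss_fn`, `train_step`, and `LEARNING_RATE` are hypothetical choices for illustration, not the exercise's reference solution.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # assumed value, chosen only for this sketch

def loss_fn(W, b, x, y):
    # Mean squared error of a linear model with bias.
    return jnp.mean((x @ W + b - y) ** 2)

@jax.jit
def train_step(W, b, x, y):
    # value_and_grad returns the loss and the gradients in one pass.
    loss, (dW, db) = jax.value_and_grad(loss_fn, argnums=(0, 1))(W, b, x, y)
    # Plain gradient-descent update; returns new parameters and the loss.
    return W - LEARNING_RATE * dW, b - LEARNING_RATE * db, loss

# Synthetic regression data generated from known parameters.
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (32, 3))
y = x @ jnp.array([1.0, -2.0, 0.5]) + 0.3

# Instantiate the model parameters.
W = jnp.zeros(3)
b = jnp.array(0.0)

initial_loss = loss_fn(W, b, x, y)
for _ in range(100):
    W, b, loss = train_step(W, b, x, y)
final_loss = loss_fn(W, b, x, y)
print(initial_loss, final_loss)
```

Because JAX arrays are immutable, the step returns updated parameters instead of modifying `W` and `b` in place, and the loop rebinds them on each iteration.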
-
A simple MLP
Now that you have all of the basic building blocks, it's time to build a simple multi-layer perceptron (MLP). In this exercise, you will build a 2-layer MLP for a regression problem. 1....
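For orientation, a 2-layer MLP for regression could be sketched as follows. The layer sizes, the ReLU activation, the dictionary layout of the parameters, and the helper names `init_params`, `mlp`, and `loss_fn` are all assumptions made for this example; the exercise may specify different choices.

```python
import jax
import jax.numpy as jnp

def init_params(key, in_dim, hidden_dim, out_dim):
    # Small random weights and zero biases, stored in a dict (a pytree).
    k1, k2 = jax.random.split(key)
    return {
        "W1": jax.random.normal(k1, (in_dim, hidden_dim)) * 0.1,
        "b1": jnp.zeros(hidden_dim),
        "W2": jax.random.normal(k2, (hidden_dim, out_dim)) * 0.1,
        "b2": jnp.zeros(out_dim),
    }

def mlp(params, x):
    # Hidden layer with ReLU, followed by a linear output layer.
    h = jax.nn.relu(x @ params["W1"] + params["b1"])
    return h @ params["W2"] + params["b2"]

def loss_fn(params, x, y):
    # Mean squared error for regression.
    return jnp.mean((mlp(params, x) - y) ** 2)

params = init_params(jax.random.PRNGKey(0), in_dim=3, hidden_dim=16, out_dim=1)

# Synthetic inputs and targets, made up for this sketch.
x = jax.random.normal(jax.random.PRNGKey(1), (8, 3))
y = jnp.sum(x, axis=1, keepdims=True)

# jax.grad differentiates with respect to the whole params dict and
# returns a dict of gradients with the same keys and shapes.
grads = jax.grad(loss_fn)(params, x, y)
print({name: g.shape for name, g in grads.items()})
```

Keeping all parameters in one dictionary means a single `jax.grad` call covers every layer, so `argnums` is not needed here.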