Taking gradients with `grad`
One of JAX's most powerful features is its ability to automatically differentiate your code. In this exercise, you will set up a simple regression problem and use `jax.grad` to compute gradients.
- Define a function `loss(W, x, b, y_true)` which computes the mean-squared error of an affine transformation, i.e., $\mathrm{MSE} = \frac{1}{N}\sum_{i=1}^{N}\big((Wx + b)_i - y_{\text{true},i}\big)^2$. Note that you can reuse the function that you created in the *NumPy to JAX: The Basics* exercise.
- Create a function `loss_grad = jax.grad(loss)`.
- Instantiate some random data. You can use the shapes that you were provided in the first exercise. `y_true` should have a shape of `(20,)`.
- Compute the gradient and verify that it has the correct shape. What do you think that shape should be? A sketch of one possible solution follows this list.
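A minimal sketch of the exercise is shown below. The shapes of `W`, `x`, and `b` are assumptions chosen so that the output has shape `(20,)`; substitute whatever shapes the first exercise actually provided.

```python
import jax
import jax.numpy as jnp

def loss(W, x, b, y_true):
    # Affine transformation followed by mean-squared error.
    y_pred = W @ x + b
    return jnp.mean((y_pred - y_true) ** 2)

# jax.grad differentiates with respect to the first argument (W) by default.
loss_grad = jax.grad(loss)

# Random data. The shapes of W, x, and b below are assumptions;
# y_true has shape (20,) as the exercise specifies.
key = jax.random.PRNGKey(0)
k1, k2, k3, k4 = jax.random.split(key, 4)
W = jax.random.normal(k1, (20, 100))
x = jax.random.normal(k2, (100,))
b = jax.random.normal(k3, (20,))
y_true = jax.random.normal(k4, (20,))

dW = loss_grad(W, x, b, y_true)
print(dW.shape)  # (20, 100): the gradient has the same shape as W
```

Since `jax.grad` differentiates with respect to the first positional argument by default, the gradient has the same shape as `W`. Pass `argnums` to `jax.grad` if you want gradients with respect to `b` or the other arguments as well.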