ML Katas

Taking gradients with `grad`

easy (<30 mins) · grad · autodiff

One of JAX's most powerful features is its ability to automatically differentiate your code. In this exercise, you will implement a regression problem and use `jax.grad` to compute gradients.
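As a quick warm-up before the exercise itself, here is a minimal sketch of how `jax.grad` works on a scalar function (the function `f` below is just an illustration, not part of the exercise):

```python
import jax

# jax.grad(f) returns a new function that evaluates df/dx at a point.
f = lambda x: x ** 2
df = jax.grad(f)

print(df(3.0))  # 6.0, since d(x^2)/dx = 2x
```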

  1. Define a function `loss(W, x, b, y_true)` that computes the mean-squared error of an affine transformation, i.e., $L(W, x, b, y_{\text{true}}) = \frac{1}{2N} \sum_{i=1}^{N} \bigl(y_{\text{true}} - (Wx + b)\bigr)_i^2$. Note that you can reuse the function you created in the "NumPy to JAX: The Basics" exercise (see the sketch after this list).
  2. Create a function `loss_grad = jax.grad(loss)`.
  3. Instantiate some random data. You can use the shapes you were given in the first exercise; `y_true` should have a shape of `(20,)`.
  4. Compute the gradient and verify that it has the correct shape. What do you think that shape should be?
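One possible sketch of the whole exercise is below. The shapes `W: (20, 10)`, `x: (10,)`, and `b: (20,)` are an assumption chosen only so that the prediction matches `y_true`'s shape of `(20,)`; substitute whatever shapes the first exercise actually used.

```python
import jax
import jax.numpy as jnp

def loss(W, x, b, y_true):
    # Mean-squared error of the affine prediction W @ x + b,
    # with the conventional 1/2 factor: (1 / 2N) * sum_i (y_true - (Wx + b))_i^2.
    y_pred = jnp.dot(W, x) + b
    return jnp.sum((y_true - y_pred) ** 2) / (2 * y_true.shape[0])

# By default, jax.grad differentiates with respect to the first argument (W).
loss_grad = jax.grad(loss)

# Random data. These shapes are an assumption; reuse the ones from the
# first exercise, keeping y_true with shape (20,).
key = jax.random.PRNGKey(0)
k1, k2, k3, k4 = jax.random.split(key, 4)
W = jax.random.normal(k1, (20, 10))
x = jax.random.normal(k2, (10,))
b = jax.random.normal(k3, (20,))
y_true = jax.random.normal(k4, (20,))

dW = loss_grad(W, x, b, y_true)
print(dW.shape)  # (20, 10): the gradient has the same shape as W itself.
```

The last line hints at the answer to step 4: the gradient of a scalar loss with respect to `W` is an array with exactly the same shape as `W`.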