ML Katas

Updating model parameters

Difficulty: medium (<30 mins) | Tags: jit, performance, grad, autodiff
this year by E

In this exercise, you will implement a full training step for the regression problem that you have been working on.

  1. Instantiate your model parameters, W and b, and your data x and y_true. Remember to use a PRNG key.
  2. Compute the gradients of your loss function with respect to your model parameters.
  3. Update your parameters using gradient descent. That is, $W_{\text{new}} = W_{\text{old}} - \alpha \nabla_W L$ and $b_{\text{new}} = b_{\text{old}} - \alpha \nabla_b L$ for some learning rate $\alpha$. You can set $\alpha = 0.01$.
  4. Wrap all of this in a single function and JIT-compile it. How much of a speedup do you see?
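The four steps can be sketched in JAX as follows. This is a minimal sketch that assumes the regression problem is a linear model y = x @ W + b trained with a mean-squared-error loss. The shapes, the synthetic data generator (including the hypothetical W_true and b_true used to build y_true), and the helper names init, loss_fn, train_step and time_steps are illustrative assumptions, not part of the original exercise.

```python
import time

import jax
import jax.numpy as jnp

# Assumption: linear regression y = x @ W + b with a mean-squared-error loss.
# These shapes are illustrative, not taken from the original exercise.
N_SAMPLES, N_FEATURES, N_OUTPUTS = 1000, 10, 1
LEARNING_RATE = 0.01  # alpha


def init(key):
    """Step 1: create parameters and synthetic data from a single PRNG key."""
    k_w, k_b, k_x, k_noise = jax.random.split(key, 4)
    W = jax.random.normal(k_w, (N_FEATURES, N_OUTPUTS))
    b = jax.random.normal(k_b, (N_OUTPUTS,))
    x = jax.random.normal(k_x, (N_SAMPLES, N_FEATURES))
    # Hypothetical "true" parameters, used only to generate the targets.
    W_true = jnp.arange(N_FEATURES, dtype=jnp.float32).reshape(N_FEATURES, N_OUTPUTS)
    b_true = jnp.ones((N_OUTPUTS,))
    noise = 0.1 * jax.random.normal(k_noise, (N_SAMPLES, N_OUTPUTS))
    y_true = x @ W_true + b_true + noise
    return W, b, x, y_true


def loss_fn(W, b, x, y_true):
    """Mean squared error of the linear model's predictions."""
    y_pred = x @ W + b
    return jnp.mean((y_pred - y_true) ** 2)


def train_step(W, b, x, y_true):
    # Step 2: gradients of the loss w.r.t. W (argument 0) and b (argument 1).
    grad_W, grad_b = jax.grad(loss_fn, argnums=(0, 1))(W, b, x, y_true)
    # Step 3: plain gradient descent update.
    W_new = W - LEARNING_RATE * grad_W
    b_new = b - LEARNING_RATE * grad_b
    return W_new, b_new


# Step 4: the same training step, compiled with XLA via jax.jit.
train_step_jit = jax.jit(train_step)


def time_steps(step, W, b, x, y_true, n_steps=100):
    """Average wall-clock seconds per step; the first call is a warm-up."""
    # Warm-up call, so the one-off compilation cost of the jitted version is excluded.
    W, b = step(W, b, x, y_true)
    jax.block_until_ready((W, b))
    start = time.perf_counter()
    for _ in range(n_steps):
        W, b = step(W, b, x, y_true)
    # JAX dispatches work asynchronously, so wait for the results before stopping the clock.
    jax.block_until_ready((W, b))
    return (time.perf_counter() - start) / n_steps, W, b


if __name__ == "__main__":
    W, b, x, y_true = init(jax.random.PRNGKey(0))
    print("initial loss:", loss_fn(W, b, x, y_true))
    t_eager, _, _ = time_steps(train_step, W, b, x, y_true)
    t_jit, W_fit, b_fit = time_steps(train_step_jit, W, b, x, y_true)
    print(f"eager: {t_eager * 1e6:.1f} us/step, jit: {t_jit * 1e6:.1f} us/step")
    print("loss after the jitted steps:", loss_fn(W_fit, b_fit, x, y_true))
```

The speedup you measure will depend on your hardware and on the problem size. For a model this small, most of the gain from `jax.jit` comes from fusing the gradient and update into one compiled computation instead of dispatching each operation from Python, so it is worth timing it yourself rather than expecting a fixed ratio. The warm-up call and `jax.block_until_ready` keep compilation time and asynchronous dispatch from distorting the comparison.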