Taking gradients with `grad` II
By default, `jax.grad` takes the gradient with respect to the first argument of the function. In many cases, however, we want gradients with respect to several of the function's arguments.
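The sketch below shows the difference on a toy function, `f(x, y) = x * y**2`, which is made up for illustration and is not part of the exercise. Without `argnums`, `jax.grad` differentiates only with respect to `x`. With `argnums=(0, 1)`, it returns a tuple containing one gradient per listed argument.

```python
import jax

def f(x, y):
    # Toy function for illustration: f(x, y) = x * y**2
    return x * y ** 2

# Default: gradient with respect to the first argument (x) only.
# df/dx = y**2, so at (3.0, 2.0) this is 4.0.
grad_x_only = jax.grad(f)(3.0, 2.0)

# argnums=(0, 1): gradients with respect to x and y, returned as a tuple.
# df/dy = 2 * x * y, so at (3.0, 2.0) this is 12.0.
grad_x, grad_y = jax.grad(f, argnums=(0, 1))(3.0, 2.0)

print(grad_x_only, grad_x, grad_y)  # 4.0 4.0 12.0
```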
In this exercise, you will create a function that takes the gradient of the loss you implemented in "Taking gradients with `grad`" with respect to both `W` and `b`. You can do this with the `argnums` argument to `jax.grad`.
As before, verify that the shapes of the gradients are what you expect.
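Here is a minimal sketch of one possible approach. The original loss is not shown in this section, so the code uses a hypothetical stand-in: mean squared error of a linear model `x @ W + b`, with arguments ordered `(W, b, x, y)`. Both the model and the argument order are assumptions. If your loss orders its arguments differently, change the `argnums` indices to match. Each gradient should have the same shape as the parameter it corresponds to.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for the loss from the earlier exercise (assumption):
# mean squared error of a linear model y_hat = x @ W + b.
def loss(W, b, x, y):
    y_hat = x @ W + b
    return jnp.mean((y_hat - y) ** 2)

# Differentiate with respect to W (argument 0) and b (argument 1).
grad_loss = jax.grad(loss, argnums=(0, 1))

# Random example data: 5 samples, 3 input features, 2 outputs.
key_W, key_x, key_y = jax.random.split(jax.random.PRNGKey(0), 3)
W = jax.random.normal(key_W, (3, 2))
b = jnp.zeros((2,))
x = jax.random.normal(key_x, (5, 3))
y = jax.random.normal(key_y, (5, 2))

dW, db = grad_loss(W, b, x, y)

# Each gradient has the same shape as its parameter.
print(dW.shape, db.shape)  # (3, 2) (2,)
```

Because `argnums` is a tuple, the result is a tuple of gradients in the same order as the indices. Comparing each gradient's shape against its parameter's shape is a quick sanity check.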