ML Katas

Custom Gradient with `jax.custom_vjp`

hard (<30 mins) autograd jax custom-gradient vjp
yesterday by E

Implement a function with a custom gradient using `jax.custom_vjp`. Custom gradients are useful for improving numerical stability and for defining gradients of operations that are not differentiable. A good example is a function that is the identity in the forward pass but clips its gradient in the backward pass.
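As a warm-up for the non-differentiable case, here is a minimal sketch of a "straight-through" rounding function. The names `round_ste`, `round_ste_fwd` and `round_ste_bwd` are hypothetical and chosen for illustration. `jax.custom_vjp` wraps the primal function, and `defvjp` registers a forward rule that returns `(output, residuals)` and a backward rule that maps `(residuals, output_cotangent)` to a tuple with one cotangent per input.

```python
import jax
import jax.numpy as jnp


# Illustrative sketch: round_ste is a hypothetical name, not a JAX API.
@jax.custom_vjp
def round_ste(x):
    # Forward pass: ordinary rounding, whose true derivative is zero almost everywhere.
    return jnp.round(x)


def round_ste_fwd(x):
    # Return the primal output plus residuals for the backward pass.
    # Nothing needs to be saved here, so the residual is None.
    return jnp.round(x), None


def round_ste_bwd(res, g):
    # Backward pass: pass the incoming cotangent straight through,
    # as if the forward pass had been the identity.
    return (g,)


round_ste.defvjp(round_ste_fwd, round_ste_bwd)

print(round_ste(1.3))                               # -> 1.0
print(jax.grad(lambda x: 3.0 * round_ste(x))(1.3))  # -> 3.0 (custom rule)
print(jax.grad(lambda x: 3.0 * jnp.round(x))(1.3))  # -> 0.0 (built-in rule)
```

The last two lines show why a custom rule is needed here: JAX differentiates `jnp.round` as having zero gradient, so no signal would reach `x` without the override.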

Task: Create a function that is the identity in the forward pass but, in the backward pass, applies `jnp.clip` to the incoming gradient so that it stays within a specified range.
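Written as vector-Jacobian products, the intended rule can be sketched as follows, writing $\ell$ and $h$ for the lower and upper bounds (assumed symbols), $\bar{y}$ for the cotangent arriving at the output and $\bar{x}$ for the cotangent sent back to the input, with clipping applied elementwise:

```latex
y = f(x) = x, \qquad
\bar{x}_{\text{true}} = \bar{y}\,\frac{\partial f}{\partial x} = \bar{y}, \qquad
\bar{x}_{\text{custom}} = \operatorname{clip}(\bar{y}, \ell, h) = \min\bigl(\max(\bar{y}, \ell),\, h\bigr)
```

This is not the same as differentiating `jnp.clip(x, lo, hi)` itself: that clips the input value $x$, and its gradient is zero wherever $x$ lies outside the range, whereas here $x$ passes through untouched and only the gradient is bounded.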

Verification: When you take the gradient of a function that uses your custom function, the gradients flowing back through it should be clipped to the specified range.
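One possible solution is sketched below, modeled on the gradient-clipping pattern in the JAX custom-derivatives tutorial but with `x` as the first argument. The names `clip_gradient` and `loss` and the weights `w` are illustrative choices, not part of the kata's specification. The bounds are passed as ordinary arguments and saved as residuals; the backward rule returns `None` for them, which JAX treats as a zero cotangent.

```python
import jax
import jax.numpy as jnp


# Illustrative sketch: clip_gradient is a hypothetical name, not a JAX API.
@jax.custom_vjp
def clip_gradient(x, lo, hi):
    # Forward pass: identity.
    return x


def clip_gradient_fwd(x, lo, hi):
    # Save the bounds as residuals for the backward pass.
    return x, (lo, hi)


def clip_gradient_bwd(res, g):
    lo, hi = res
    # Backward pass: clip the incoming cotangent elementwise.
    # None signals a zero cotangent for the bounds lo and hi.
    return (jnp.clip(g, lo, hi), None, None)


clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd)

# Hypothetical weights: the gradient reaching clip_gradient is exactly w.
w = jnp.array([-3.0, 0.5, 0.25, 4.0])


def loss(x):
    return jnp.sum(w * clip_gradient(x, -1.0, 1.0))


x = jnp.array([-2.0, 0.1, 0.5, 3.0])
print(clip_gradient(x, -1.0, 1.0))  # forward pass: x unchanged
print(jax.grad(loss)(x))            # values [-1.0, 0.5, 0.25, 1.0]

# For contrast, differentiating jnp.clip zeroes the gradient where x is out of range.
print(jax.grad(lambda x: jnp.sum(w * jnp.clip(x, -1.0, 1.0)))(x))  # values [0.0, 0.5, 0.25, 0.0]
```

Using a linear loss keeps the check easy to read: without the custom rule the gradient would simply be `w`, so any entry of `w` outside `[-1, 1]` should come back clipped to the nearest bound.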