Custom Gradient with `jax.custom_vjp`
Implement a function with a custom gradient using `jax.custom_vjp`. This is useful for numerical stability or for defining gradients for non-differentiable operations. A good example is a function that is the identity in the forward pass but has a clipped gradient in the backward pass.
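As an illustration of the numerical-stability use case, here is a minimal sketch of `log(1 + exp(x))` whose backward pass returns the analytic derivative `sigmoid(x)`, which stays in [0, 1] for every input. The name `stable_softplus` is hypothetical, chosen for this example. A naive `jnp.log1p(jnp.exp(x))` can overflow in `float32` for large `x`, and its automatic gradient may then come out as `nan`. The sketch also shows the `custom_vjp` contract: the forward rule returns the output plus residuals, and the backward rule maps the residuals and the incoming cotangent `g` to one cotangent per primal input.

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def stable_softplus(x):
    # log(1 + exp(x)) computed as logaddexp(0, x), which does not overflow.
    return jnp.logaddexp(0.0, x)

def stable_softplus_fwd(x):
    # Return the primal output and save x as the residual for the backward pass.
    return stable_softplus(x), x

def stable_softplus_bwd(x, g):
    # d/dx log(1 + exp(x)) = sigmoid(x), bounded in [0, 1] for any x.
    # Return a tuple with one cotangent per primal input.
    return (g * jax.nn.sigmoid(x),)

stable_softplus.defvjp(stable_softplus_fwd, stable_softplus_bwd)

print(jax.grad(stable_softplus)(100.0))  # finite, close to 1.0
print(jax.grad(stable_softplus)(0.0))    # sigmoid(0) = 0.5
```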
Task:
Create a function that is the identity in the forward pass but applies `jnp.clip` to the incoming gradient in the backward pass.
Verification:
- When you take the gradient of a function that uses your custom function, the gradient flowing back through it should be clipped to the specified range. [46]
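A minimal sketch of the task and its verification, closely following the gradient-clipping pattern in the JAX custom-derivative tutorial. The name `clip_gradient`, the argument order `(lo, hi, x)`, and the weights `w` used in the check are assumptions made for this example. The forward rule returns `x` unchanged and saves the bounds as residuals; the backward rule clips the incoming cotangent and returns `None` (a zero cotangent) for the bounds.

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def clip_gradient(lo, hi, x):
    # Forward pass: identity, values pass through unchanged.
    return x

def clip_gradient_fwd(lo, hi, x):
    # Output is x itself; keep the bounds as residuals for the backward pass.
    return x, (lo, hi)

def clip_gradient_bwd(res, g):
    lo, hi = res
    # Backward pass: clip the incoming cotangent to [lo, hi].
    # None stands for a zero cotangent for the bounds lo and hi.
    return (None, None, jnp.clip(g, lo, hi))

clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd)

# Verification (hypothetical weights): the cotangent reaching clip_gradient is w.
w = jnp.array([3.0, 0.5, -3.0])

def loss(x):
    return jnp.sum(w * clip_gradient(-1.0, 1.0, x))

x = jnp.array([-2.0, 0.5, 4.0])
print(clip_gradient(-1.0, 1.0, x))  # forward pass: same values as x
print(jax.grad(loss)(x))            # clipped gradient, 1.0, 0.5, -1.0
```

Without the custom rule the gradient of `loss` would equal `w`; with it, the entries outside `[-1, 1]` are clipped, giving `[1.0, 0.5, -1.0]`, while the forward values are untouched.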