Custom VJP
For some functions, you may want to define a custom vector-Jacobian product (VJP). This can be useful for numerical stability or for implementing algorithms that are not easily expressed in terms of standard JAX primitives.
In this exercise, you will implement a custom VJP for the function . This is a common pattern for avoiding numerical overflow when x is large.
- Define the function and its custom VJP using @jax.custom_vjp. The forward pass should compute the function's value. The backward pass should take the incoming gradient and return the gradient with respect to the input.
- Compare the numerical stability of your custom implementation with the standard JAX implementation for large values of x.
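One possible way to approach both steps is sketched below. The exercise text does not show the target function, so this sketch assumes it is log(1 + exp(x)), a common example where automatic differentiation overflows for large x. The names log1pexp_naive, log1pexp, log1pexp_fwd, and log1pexp_bwd are illustrative, not part of the exercise. The backward rule uses the derivative 1 - 1/(1 + exp(x)), which stays finite even when exp(x) overflows to infinity.

```python
import jax
import jax.numpy as jnp

# ASSUMPTION: the exercise's function is log(1 + exp(x)); the source text
# does not name it, so this choice is only illustrative.

def log1pexp_naive(x):
    # Standard implementation, differentiated by JAX's built-in rules.
    return jnp.log1p(jnp.exp(x))

@jax.custom_vjp
def log1pexp(x):
    # Same primal computation; only the derivative rule is customized.
    return jnp.log1p(jnp.exp(x))

def log1pexp_fwd(x):
    # Forward pass: return the output plus residuals saved for the
    # backward pass (here, just the input x).
    return log1pexp(x), x

def log1pexp_bwd(x, g):
    # Backward pass: receives the residuals and the incoming gradient g,
    # and returns a tuple with one cotangent per primal input.
    # d/dx log(1 + exp(x)) = 1 - 1 / (1 + exp(x)); for large x,
    # exp(x) becomes inf and the expression evaluates to 1 rather than nan.
    return (g * (1.0 - 1.0 / (1.0 + jnp.exp(x))),)

log1pexp.defvjp(log1pexp_fwd, log1pexp_bwd)

# Compare gradients at a large input (float32 by default).
x = 100.0
naive_grad = jax.grad(log1pexp_naive)(x)   # built-in rule multiplies 0 by inf
custom_grad = jax.grad(log1pexp)(x)        # custom rule stays finite

print("naive gradient: ", naive_grad)
print("custom gradient:", custom_grad)
```

Note that this custom rule only stabilizes the gradient. In default float32 precision the forward value itself still overflows at x = 100, so a fully stable version would also need a reworked forward computation.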