-
Forward-mode vs. Reverse-mode autodiff
JAX supports both forward-mode and reverse-mode automatic differentiation. `jax.grad` uses reverse mode; for forward mode, use `jax.jvp`, which computes Jacobian-vector products....
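
As a rough illustration of the two modes, the sketch below evaluates the same directional derivative twice: once with `jax.jvp` (forward mode) and once by dotting the `jax.grad` result (reverse mode) with the direction vector. The function `f`, the input `x`, and the direction `v` are made-up examples, not part of the original description.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical scalar-valued test function (an assumption for illustration).
    return jnp.sum(jnp.sin(x) * x ** 2)

x = jnp.array([1.0, 2.0, 3.0])   # point at which to differentiate
v = jnp.array([1.0, 0.0, 0.0])   # tangent (direction) vector

# Forward mode: a single pass returns f(x) and the Jacobian-vector product J(x) @ v.
primal_out, tangent_out = jax.jvp(f, (x,), (v,))

# Reverse mode: jax.grad returns the full gradient of the scalar output.
g = jax.grad(f)(x)

# For a scalar-valued f, the JVP equals the dot product of the gradient with v.
directional_from_grad = jnp.dot(g, v)
```

Forward mode tends to suit functions with few inputs and many outputs, since each `jax.jvp` call yields one Jacobian column; reverse mode tends to suit many inputs and a scalar output, which is the typical loss-function case.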
-
Custom Gradient with `jax.custom_vjp`
Implement a function with a custom gradient using `jax.custom_vjp`. This is useful for numerical stability or for defining gradients for non-differentiable operations. A good example is a function...
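
One possible illustration of the numerical-stability case is sketched below. The choice of `log1pexp` (that is, log(1 + exp(x))) is an assumption made here for the example and may not be the function the description above has in mind; the helper names `log1pexp_fwd` and `log1pexp_bwd` are likewise hypothetical.

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def log1pexp(x):
    # Primal computation: log(1 + exp(x)).
    return jnp.log1p(jnp.exp(x))

def log1pexp_fwd(x):
    # Forward rule: return the primal output plus residuals saved for the backward pass.
    return log1pexp(x), x

def log1pexp_bwd(x, g):
    # Backward rule: d/dx log(1 + exp(x)) = 1 - 1 / (1 + exp(x)).
    # Written this way, the expression stays finite even when exp(x) overflows.
    # The result must be a tuple with one cotangent per primal input.
    return (g * (1.0 - 1.0 / (1.0 + jnp.exp(x))),)

log1pexp.defvjp(log1pexp_fwd, log1pexp_bwd)

# Gradient at a large input, where exp(x) overflows in float32.
stable_grad = jax.grad(log1pexp)(100.0)
```

The forward rule decides what to save (here just `x`), and the backward rule receives those residuals together with the incoming cotangent `g`. Because `jax.custom_vjp` only defines a reverse-mode rule, forward-mode transforms such as `jax.jvp` are not supported on the decorated function.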