ML Katas

Forward-mode vs. Reverse-mode autodiff

hard (>1 hr) vjp autodiff jvp
this year by E

JAX supports both forward-mode and reverse-mode automatic differentiation. jax.grad uses reverse mode. For forward mode, use jax.jvp, which computes Jacobian-vector products. Its reverse-mode counterpart, jax.vjp, computes vector-Jacobian products.
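To make the two primitives concrete, here is a minimal sketch. The function f and the vectors v and u are arbitrary illustrative choices, not part of the exercise. The sketch pushes a tangent forward with jax.jvp, pulls a cotangent back with jax.vjp, and checks both against the full Jacobian.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Illustrative R^3 -> R^2 function (hypothetical, chosen only for this sketch).
    return jnp.array([jnp.sin(x[0]) * x[1], x[1] * x[2] ** 2])

x = jnp.array([1.0, 2.0, 3.0])

# Forward mode: push an input-shaped tangent v through f in one pass.
v = jnp.array([1.0, 0.0, 0.0])
y, jvp_out = jax.jvp(f, (x,), (v,))   # jvp_out = J @ v, shape (2,)

# Reverse mode: pull an output-shaped cotangent u back through f in one pass.
y2, f_vjp = jax.vjp(f, x)
u = jnp.array([1.0, 0.0])
(vjp_out,) = f_vjp(u)                 # vjp_out = u @ J, shape (3,)

# Full 2x3 Jacobian, used here only to check the two products above.
J = jax.jacfwd(f)(x)
```

With v set to the first standard basis vector, jvp_out is the first column of J. With u set to the first basis vector, vjp_out is the first row.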

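Forward mode is generally more efficient for functions with more outputs than inputs, while reverse mode is more efficient for functions with more inputs than outputs. The reason is the number of passes. For f: R^n → R^m, building the full m×n Jacobian takes n forward-mode passes (one JVP per input dimension, each yielding one column) but m reverse-mode passes (one VJP per output dimension, each yielding one row).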
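The pass-count argument can be sketched by building each Jacobian from its primitive. Forward mode maps one JVP over every input basis vector (n passes). Reverse mode maps one VJP over every output basis vector (m passes). The helper names jacfwd_by_hand and jacrev_by_hand are hypothetical. This is a simplified illustration of the idea, assuming a single flat vector input, and not JAX's actual jacfwd/jacrev code, which also handles pytrees and other options.

```python
import jax
import jax.numpy as jnp

def jacfwd_by_hand(f, x):
    # One JVP per input basis vector: n passes for an n-dimensional input.
    basis = jnp.eye(x.size, dtype=x.dtype)
    pushfwd = lambda v: jax.jvp(f, (x,), (v,))[1]
    # vmap stacks the n columns along axis 0, giving (n, m); transpose to (m, n).
    return jax.vmap(pushfwd)(basis).T

def jacrev_by_hand(f, x):
    # One VJP per output basis vector: m passes for an m-dimensional output.
    y, f_vjp = jax.vjp(f, x)
    basis = jnp.eye(y.size, dtype=y.dtype)
    pullback = lambda u: f_vjp(u)[0]
    # Each VJP yields one row of J, so the stacked result is already (m, n).
    return jax.vmap(pullback)(basis)

# Usage on a small hypothetical R^4 -> R^3 function.
W = jnp.arange(12.0).reshape(3, 4) / 10.0
g = lambda x: jnp.tanh(W @ x)
x = jnp.linspace(-1.0, 1.0, 4)
J_fwd = jacfwd_by_hand(g, x)   # built from 4 JVPs
J_rev = jacrev_by_hand(g, x)   # built from 3 VJPs
```

Both helpers return the same 3×4 matrix. Only the number of primitive passes differs, and that is what drives the efficiency difference.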

In this exercise, you will compare the performance of both modes:

1. Create a function that takes a vector of size 100 and returns a vector of size 10.
2. Create another function that does the opposite: it takes a vector of size 10 and returns a vector of size 100.
3. Compute the full Jacobian of both functions using both jax.jacfwd (forward mode) and jax.jacrev (reverse mode).
4. Benchmark the four computations. What do you observe?
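One possible starting point for the benchmark follows. This is a sketch under assumptions, not a reference solution. The function bodies (tanh of a random linear map), the names wide_to_narrow and narrow_to_wide, and the weight matrices are hypothetical choices. Each Jacobian function is jit-compiled and called once before timing, so compilation is excluded. Each timed call ends with block_until_ready(), because JAX dispatches work asynchronously.

```python
import timeit
import jax
import jax.numpy as jnp

key_a, key_b, key_x = jax.random.split(jax.random.PRNGKey(0), 3)
W_a = jax.random.normal(key_a, (10, 100))   # hypothetical weights for R^100 -> R^10
W_b = jax.random.normal(key_b, (100, 10))   # hypothetical weights for R^10 -> R^100

def wide_to_narrow(x):
    # Step 1: 100 inputs -> 10 outputs.
    return jnp.tanh(W_a @ x)

def narrow_to_wide(x):
    # Step 2: 10 inputs -> 100 outputs.
    return jnp.tanh(W_b @ x)

x100 = jax.random.normal(key_x, (100,))
x10 = x100[:10]

# Step 3: the four Jacobian computations, each jit-compiled.
cases = {
    "100->10 jacfwd": (jax.jit(jax.jacfwd(wide_to_narrow)), x100),
    "100->10 jacrev": (jax.jit(jax.jacrev(wide_to_narrow)), x100),
    "10->100 jacfwd": (jax.jit(jax.jacfwd(narrow_to_wide)), x10),
    "10->100 jacrev": (jax.jit(jax.jacrev(narrow_to_wide)), x10),
}

def benchmark(fn, x, number=1000):
    fn(x).block_until_ready()   # warm-up call triggers tracing and compilation
    total = timeit.timeit(lambda: fn(x).block_until_ready(), number=number)
    return total / number       # average seconds per call

# Step 4: time each case and report microseconds per call.
results = {name: benchmark(fn, x) for name, (fn, x) in cases.items()}
for name, sec in results.items():
    print(f"{name}: {sec * 1e6:.1f} us per call")
```

At sizes this small, per-call overhead may hide much of the difference between modes. Widening the gap between input and output dimensions, or making the functions more expensive, should make the n-versus-m pass count easier to see.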