Debugging JAX code with `jax.debug.print`
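In JAX, a standard Python `print` inside a `jax.jit`-compiled function runs only while JAX traces the function, not when the compiled code executes. As a result, it fires once per trace and shows abstract tracer objects rather than concrete values. [11] The solution is `jax.debug.print`, which is staged into the compiled computation and prints the actual runtime values every time the function is called. [11, 23]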
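The minimal sketch below shows the calling convention of `jax.debug.print`. The function name `scaled_sum` and its inputs are hypothetical. The template uses `str.format`-style placeholders, both positional `{}` and keyword `{name}`, and JAX fills them with concrete values at run time. An f-string would not work the same way, because Python evaluates it during tracing.

```python
import jax
import jax.numpy as jnp

@jax.jit
def scaled_sum(x, scale):
    subtotal = jnp.sum(x)
    # str.format-style placeholders: "{}" is positional, "{s}" is a keyword.
    # They are filled with concrete values each time the compiled code runs.
    jax.debug.print("subtotal={} scale={s}", subtotal, s=scale)
    total = subtotal * scale
    # Avoid f-strings here: f"{total}" would be evaluated by Python during
    # tracing and bake the tracer's text into the message.
    jax.debug.print("total={total}", total=total)
    return total

result = scaled_sum(jnp.arange(4.0), 2.0)
```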
Task:
Write a jit-compiled function that contains both a standard `print` and a `jax.debug.print`, and observe when each one executes. Then use `jax.debug.print` to inspect the intermediate values of a simple computation.
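One possible solution is sketched below. The function `compute` and the sin/square computation are illustrative choices, not a prescribed exercise solution. On the first call you should see the trace-time `print` (showing a tracer) followed by the `jax.debug.print` lines. On the second call with the same input shape and dtype, only the `jax.debug.print` lines should appear.

```python
import jax
import jax.numpy as jnp

@jax.jit
def compute(x):
    # Plain Python print: runs only while JAX traces the function.
    print("trace-time print, x =", x)
    y = jnp.sin(x)
    # jax.debug.print: staged into the compiled program, runs on every call.
    jax.debug.print("run-time: x = {}, sin(x) = {}", x, y)
    z = y ** 2 + 1.0
    jax.debug.print("run-time: z = {}", z)
    return jnp.sum(z)

x = jnp.array([0.0, 1.0, 2.0])
out1 = compute(x)  # first call: traces and compiles, so both kinds of print appear
out2 = compute(x)  # cached executable: only the jax.debug.print lines appear
jax.block_until_ready(out2)
```

By default, `jax.debug.print` does not guarantee the relative order of separate print calls. Pass `ordered=True` if that order matters.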
Verification:
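- The standard `print` should execute only when the function is traced. That means the first call, and again only if a call with new input shapes or dtypes forces a retrace. It should not run on later calls with the same input types.
- `jax.debug.print` should execute on every call to the compiled function and print the concrete values for that call.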
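To check these points programmatically instead of reading the console, the sketch below swaps the two prints for logs. This is an assumption-laden substitute for the exercise, not part of it. Appending to a Python list inside the jitted function behaves like a plain `print`, because it runs only at trace time. `jax.debug.callback` runs a host function on every execution, using the same debug-callback mechanism that underlies `jax.debug.print`. The names `g`, `record`, `trace_log`, and `runtime_log` are hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np

trace_log = []    # appended to by ordinary Python code: trace time only
runtime_log = []  # appended to by a host callback: every execution

def record(v):
    # Host-side function; receives the concrete runtime value of its argument.
    runtime_log.append(np.asarray(v).tolist())

@jax.jit
def g(x):
    trace_log.append(x.shape)                    # stands in for a plain print
    jax.debug.callback(record, x, ordered=True)  # stands in for jax.debug.print
    return x + 1

for v in [1.0, 2.0, 3.0]:
    g(jnp.float32(v))  # same shape and dtype each time, so a single trace
jax.effects_barrier()  # wait until all pending callbacks have run
traces_same_shape = len(trace_log)

g(jnp.ones(2))         # a new input shape forces a retrace
jax.effects_barrier()
```

After the three scalar calls, `trace_log` should hold one entry while `runtime_log` holds three. The call with shape `(2,)` adds one more entry to each.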