ML Katas

Debugging JAX code with `jax.debug.print`

Difficulty: easy (<10 mins) | Tags: debugging, jax, jit
this year by E

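In JAX, a standard Python print inside a jit-compiled function runs only while JAX traces the function, not when the compiled code executes. [11] It therefore shows an abstract tracer (shape and dtype) instead of concrete values, and it stays silent on later calls that reuse the cached compiled function. The solution is to use jax.debug.print, which is staged into the compiled computation and prints the runtime values on every call. [11, 23]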
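A minimal sketch of the difference (the function name `f` and the input are illustrative). On the first call both lines print, and the Python print shows a tracer along the lines of `Traced<ShapedArray(float32[3])>...` rather than numbers (the exact repr varies between JAX versions). The second call reuses the cached compiled function, so only the jax.debug.print line should appear.

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    # Runs only while JAX traces f; x is an abstract tracer here, not concrete numbers.
    print("python print:", x)
    # Staged into the compiled computation; runs on every call with the concrete value.
    jax.debug.print("jax.debug.print: {x}", x=x)
    return x * 2

x = jnp.arange(3.0)
f(x)  # first call: trace + compile, so both lines print
f(x)  # cache hit: only the jax.debug.print line prints
```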

Task: Write a jit-compiled function that contains both a standard print and a jax.debug.print, and observe when each one executes. Then use jax.debug.print to inspect the intermediate values of a simple computation.
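For the second half of the task, here is a minimal sketch that assumes a Euclidean norm as the "simple computation" (`l2_norm` is a hypothetical name; any small function would do). Each jax.debug.print call reports one intermediate value at run time, using str.format-style `{name}` placeholders filled from keyword arguments.

```python
import jax
import jax.numpy as jnp

@jax.jit
def l2_norm(x):
    squares = x ** 2
    jax.debug.print("squares = {s}", s=squares)  # element-wise intermediate
    total = jnp.sum(squares)
    jax.debug.print("sum of squares = {t}", t=total)  # scalar intermediate
    norm = jnp.sqrt(total)
    jax.debug.print("norm = {n}", n=norm)
    return norm

result = l2_norm(jnp.array([3.0, 4.0]))
print(result)
```

Because the printed values come from the compiled computation, calling `l2_norm` again with different numbers (same shape) prints the new intermediates without retracing.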

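Verification:
- The standard print should execute only once, when the function is first traced and compiled. It runs again only if JAX retraces, for example when the function is called with a new input shape or dtype.
- jax.debug.print should execute every time the compiled function is called.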
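One way to check these criteria programmatically instead of reading the console is to replace each print with a counter. This is an assumption rather than part of the kata: a plain Python increment stands in for the standard print (it runs at trace time), and jax.debug.callback, the more general host-callback API behind jax.debug.print, increments a second counter at run time. jax.effects_barrier() waits for pending callbacks before the counts are read. The names `g`, `trace_count`, `runtime_count`, and `_on_run` are hypothetical.

```python
import jax
import jax.numpy as jnp

trace_count = 0    # incremented by plain Python code, i.e. at trace time
runtime_count = 0  # incremented by a host callback, i.e. at run time

def _on_run(value):
    # Called on the host each time the compiled function executes.
    global runtime_count
    runtime_count += 1

@jax.jit
def g(x):
    global trace_count
    trace_count += 1                # behaves like a standard print: trace time only
    jax.debug.callback(_on_run, x)  # behaves like jax.debug.print: every call
    return x + 1

x = jnp.ones(3)
for _ in range(3):
    g(x).block_until_ready()
jax.effects_barrier()  # make sure all pending callbacks have run

print("traces:", trace_count, "runs:", runtime_count)
```

After three calls with the same input shape, `trace_count` should stay at 1 while `runtime_count` reaches 3; calling `g` with a differently shaped array would trigger a retrace and bump `trace_count`.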