-
Debugging JAX code with `jax.debug.print`
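In JAX, standard Python `print` statements don't behave as expected inside `jit`-compiled functions: they execute only at trace time, so they display abstract tracers rather than concrete values, and they do not run again when the cached compiled function is reused. [11] The solution is to use `jax.debug.print`. [11, 23]...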
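The contrast can be sketched as follows, assuming a recent JAX release; `scaled_sum` is a hypothetical function used only for illustration. The plain `print` fires once while JAX traces the function, whereas `jax.debug.print` is staged into the compiled computation and fires on every call.

```python
import jax
import jax.numpy as jnp

@jax.jit
def scaled_sum(x):
    # Runs only while JAX traces the function: shows an abstract tracer, not values.
    print("python print:", x)
    y = jnp.sum(x * 2.0)
    # Staged into the compiled computation: shows the concrete value on every call.
    jax.debug.print("jax.debug.print: y = {y}", y=y)
    return y

# First call traces and compiles, so both lines print.
out1 = scaled_sum(jnp.arange(3.0))
# Same shape and dtype: the compiled function is reused, so only jax.debug.print fires.
out2 = scaled_sum(jnp.arange(3.0, 6.0))
```

Calling the function twice with inputs of the same shape makes the difference visible: the `python print:` line appears once, the `jax.debug.print:` line twice.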
-
Debug Exploding Gradients
Create a deep feedforward network (20 layers, ReLU activations). Train it on dummy data, track the gradient norm of each layer, and observe whether the gradients explode. Experiment with:
- Smaller learning rate.
- Gradient...
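A minimal sketch of the exercise in plain JAX, without an optimizer library. The width (64), batch size (128), number of steps, the deliberately large initialization scale `INIT_SCALE`, and all function names are hypothetical choices made for illustration. Only the learning-rate comparison is shown; the remaining experiments in the list are not implemented.

```python
import jax
import jax.numpy as jnp

NUM_LAYERS = 20
WIDTH = 64
INIT_SCALE = 2.0  # hypothetical: larger than He init (sqrt(2)) to provoke large gradients

def init_params(key, num_layers=NUM_LAYERS, width=WIDTH, scale=INIT_SCALE):
    params = []
    for k in jax.random.split(key, num_layers):
        # Weight std = scale / sqrt(fan_in); biases start at zero.
        w = jax.random.normal(k, (width, width)) * scale / width ** 0.5
        b = jnp.zeros(width)
        params.append((w, b))
    return params

def forward(params, x):
    # 19 ReLU hidden layers followed by a linear output layer.
    for w, b in params[:-1]:
        x = jax.nn.relu(x @ w + b)
    w, b = params[-1]
    return x @ w + b

def loss_fn(params, x, y):
    return jnp.mean((forward(params, x) - y) ** 2)

def layer_grad_norms(grads):
    # One L2 norm per layer, taken over that layer's weight and bias gradients.
    return jnp.array([jnp.sqrt(jnp.sum(gw ** 2) + jnp.sum(gb ** 2)) for gw, gb in grads])

@jax.jit
def train_step(params, x, y, lr):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # Plain SGD update applied leaf by leaf.
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss, layer_grad_norms(grads)

def run(lr, steps=5, seed=0):
    key_p, key_x, key_y = jax.random.split(jax.random.PRNGKey(seed), 3)
    params = init_params(key_p)
    x = jax.random.normal(key_x, (128, WIDTH))  # dummy inputs
    y = jax.random.normal(key_y, (128, WIDTH))  # dummy targets
    history = []
    for step in range(steps):
        params, loss, norms = train_step(params, x, y, lr)
        history.append(norms)
        print(f"lr={lr} step={step} loss={float(loss):.3e} "
              f"min layer grad norm={float(norms.min()):.3e} "
              f"max layer grad norm={float(norms.max()):.3e}")
    return jnp.stack(history)  # shape: (steps, NUM_LAYERS)

big = run(lr=1e-2)
small = run(lr=1e-4)
```

Both runs start from the same parameters, so their step-0 norms are identical; comparing later rows of `big` and `small` shows how the learning rate changes whether the norms grow or become non-finite. Lowering `INIT_SCALE` toward `sqrt(2)` is another knob worth trying.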