ML Katas

Understanding `jit` and tracing

easy (<10 mins) jit tracing
this year by E

When you `jit` a function, JAX traces it to determine its computational graph, which XLA then compiles for efficient execution. This tracing mechanism has consequences. For example, Python side effects such as printing run only while the function is being traced, not on later calls that reuse the compiled version. Tracing happens on the first call and again whenever the function is called with inputs of a new shape or dtype.
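To make the idea of a traced graph concrete, here is a minimal sketch using `jax.make_jaxpr`, which traces a function and returns the recorded operations. The function name `scale_and_shift` and its arithmetic are made up for illustration. The `print` call runs during tracing but does not become part of the graph.

```python
import jax
import jax.numpy as jnp

def scale_and_shift(x):
    # Python side effect: executed while tracing, not recorded in the graph.
    print("tracing with", x)
    return 2.0 * x + 1.0

# Trace the function on an example input and inspect the recorded graph.
jaxpr = jax.make_jaxpr(scale_and_shift)(jnp.arange(3.0))
print(jaxpr)

# Names of the primitive operations that were recorded during tracing.
primitive_names = [eqn.primitive.name for eqn in jaxpr.jaxpr.eqns]
print(primitive_names)
```

The value printed inside the function is a tracer object rather than a concrete array, and the list of primitives contains only the arithmetic, with no trace of the `print`.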

In this exercise, you will explore this behavior. Create a simple jit'ed function that prints its input and returns it. Call the function twice with the same input. What do you observe? What happens if you call it with an input of a different shape?
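One possible starting point for the exercise is sketched below. The names `echo` and `trace_log` are hypothetical; the list is just a convenient way to record each time the Python body runs, alongside the `print`.

```python
import jax
import jax.numpy as jnp

trace_log = []  # hypothetical helper: records the input shape on every trace

@jax.jit
def echo(x):
    # Both lines below are Python side effects, so they run only during tracing.
    trace_log.append(x.shape)
    print("tracing echo with", x)
    return x

echo(jnp.ones(3))        # first call with this shape
echo(jnp.ones(3))        # same shape and dtype as before
echo(jnp.ones((2, 2)))   # different shape

print(trace_log)
```

Comparing the number of calls with the number of entries in `trace_log` shows how often JAX actually re-ran the Python body.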