ML Katas

Conditionals with `jit`

medium (<30 mins) jit control flow
this year by E

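Standard Python control flow, such as an `if` statement, breaks under `jit` when the condition depends on a traced value. `jit` traces the function with abstract values that carry only a shape and dtype, not a concrete value, and then compiles the resulting computation graph. A Python `if` has to pick one branch at trace time, which is impossible when the condition is a tracer, so JAX raises an error instead.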
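A minimal sketch of the failure, assuming a recent JAX release. The function name `step_python_if` is hypothetical and used only for illustration. The same function runs fine when called eagerly on a concrete integer, but fails once it is wrapped in `jax.jit`.

```python
import jax
import jax.numpy as jnp

def step_python_if(x):
    # Ordinary Python control flow: fine for concrete values,
    # but under jit, x is an abstract tracer with no concrete value.
    if x % 2 == 0:
        return x // 2
    return 3 * x + 1

# Called eagerly on a concrete Python int, the if statement works.
print(step_python_if(6))  # 3

jitted_step = jax.jit(step_python_if)
try:
    jitted_step(jnp.int32(6))
    raised = False
except jax.errors.ConcretizationTypeError as err:
    # Recent JAX versions raise TracerBoolConversionError, a subclass of
    # ConcretizationTypeError, when a tracer is converted to a bool.
    raised = True
    print(type(err).__name__)
```

Marking the argument static with `jax.jit(step_python_if, static_argnums=0)` and passing a Python `int` would also make the `if` work, but it treats `x` as a compile-time constant and recompiles for every new value, which is usually not what you want.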

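To handle this, JAX provides structured control flow primitives such as `jax.lax.cond`, which traces both branches into the compiled graph and selects one at run time. In this exercise, write a jitted function that takes an integer `x` and returns `x // 2` if `x` is even and `3 * x + 1` otherwise. Use integer division here: `x / 2` would produce a float, and `lax.cond` requires both branches to return the same shape and dtype.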
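One possible solution, sketched with `jax.lax.cond` using its `cond(pred, true_fun, false_fun, *operands)` form. The name `collatz_step` is hypothetical. Both branches receive the same operand and return an `int32` scalar, which keeps their output types identical.

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def collatz_step(x):
    # Both branches must return the same shape and dtype, so the even
    # branch uses integer division (//) rather than true division (/).
    return lax.cond(
        x % 2 == 0,           # scalar boolean predicate (a tracer under jit)
        lambda v: v // 2,     # even: halve
        lambda v: 3 * v + 1,  # odd: 3x + 1
        x,                    # operand passed to whichever branch runs
    )

print(collatz_step(jnp.int32(6)))  # 3
print(collatz_step(jnp.int32(7)))  # 22
```

If the even branch used `/` instead, it would return a `float32` while the odd branch returns an `int32`, and `lax.cond` would typically reject the mismatched branch outputs with a `TypeError`.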

Verify that your function works as expected. What happens if you try to use a standard Python if statement? Why?
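A sketch of one way to verify the jitted function, assuming the `lax.cond` approach. The names `collatz_step` and `collatz_step_reference` are hypothetical. It compares the jitted step against a plain-Python reference (where an ordinary `if` is fine because the inputs are concrete), first one value at a time and then batched with `jax.vmap`.

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def collatz_step(x):
    # Same lax.cond-based step: halve if even, else 3x + 1.
    return lax.cond(x % 2 == 0, lambda v: v // 2, lambda v: 3 * v + 1, x)

def collatz_step_reference(n: int) -> int:
    # Plain Python reference: concrete ints, so a normal if works.
    if n % 2 == 0:
        return n // 2
    return 3 * n + 1

# Scalar check: same shape and dtype each call, so jit compiles only once.
mismatches = [
    n for n in range(1, 101)
    if int(collatz_step(jnp.int32(n))) != collatz_step_reference(n)
]
print("mismatches:", mismatches)

# Batched check: with a batched predicate, vmap turns cond into a select,
# so both branches are evaluated for every element.
xs = jnp.arange(1, 101, dtype=jnp.int32)
batched = jax.vmap(collatz_step)(xs)
expected = jnp.array(
    [collatz_step_reference(n) for n in range(1, 101)], dtype=jnp.int32
)
print("batched matches:", bool(jnp.all(batched == expected)))
```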