ML Katas

Loops with `jit`

medium (<30 mins) jit control flow
this year by E

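As with conditionals, a standard Python `for` or `while` loop causes problems under `jit` when the number of iterations depends on a traced value. The Python loop runs while the function is being traced, so its condition or range must be a concrete value, and inside a `jit`-compiled function a traced value is only known at run time. Loops with static Python bounds are fine because they are simply unrolled during tracing. For the traced case, JAX provides `jax.lax.fori_loop` and `jax.lax.while_loop`.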
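A minimal sketch contrasting the two approaches. The function names `python_count_up`, `fori_sum` and `while_count_up` are made up for illustration. The first one only shows the failing pattern and is never called; the other two trace fine because the loop lives inside `jax.lax`.

```python
import jax
import jax.numpy as jnp


@jax.jit
def python_count_up(n):
    # Failing pattern: `i < n` is a traced boolean, but Python's `while`
    # needs a concrete bool, so calling this raises an error during tracing.
    i = 0
    while i < n:
        i += 1
    return i


@jax.jit
def fori_sum(n):
    # Sum 0 + 1 + ... + (n - 1). The traced upper bound `n` is allowed;
    # body_fun receives (index, carry) and returns the new carry.
    return jax.lax.fori_loop(0, n, lambda i, acc: acc + i, jnp.int32(0))


@jax.jit
def while_count_up(n):
    # cond_fun and body_fun both receive the loop carry (a single int here).
    return jax.lax.while_loop(lambda i: i < n, lambda i: i + 1, jnp.int32(0))


print(fori_sum(5))        # 10
print(while_count_up(5))  # 5
```

Note that the carry must keep the same shape and dtype on every iteration, which is why the initial values are written explicitly as `jnp.int32(0)`.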

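In this exercise, you will compute step counts for the Collatz sequence: starting from `n`, halve the number if it is even, otherwise replace it with 3n + 1, and repeat. Write a `jit`-compiled function that takes an integer `n` and returns the number of steps it takes to reach 1. Use `jax.lax.while_loop`.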
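One possible solution, as a sketch under stated assumptions: `n` is a positive integer (for `n <= 0` the loop never reaches 1) and every intermediate value fits in `int32`. The name `collatz_steps` is hypothetical. The loop carry is a `(value, steps)` tuple, and the even/odd choice uses `jnp.where` rather than a Python `if`, since the parity of a traced value is not known at trace time.

```python
import jax
import jax.numpy as jnp


@jax.jit
def collatz_steps(n):
    """Number of Collatz steps from n down to 1 (assumes n >= 1, int32 range)."""

    def cond_fun(carry):
        value, _ = carry
        # Keep looping until the sequence reaches 1.
        return value != 1

    def body_fun(carry):
        value, steps = carry
        # Collatz rule: halve even numbers, map odd x to 3x + 1.
        next_value = jnp.where(value % 2 == 0, value // 2, 3 * value + 1)
        return next_value, steps + 1

    # Both carry entries are int32 so the carry type is identical each iteration.
    init = (jnp.asarray(n, dtype=jnp.int32), jnp.int32(0))
    _, steps = jax.lax.while_loop(cond_fun, body_fun, init)
    return steps


print(collatz_steps(6))   # 8
print(collatz_steps(27))  # 111
```

`jnp.where` evaluates both branches, which is cheap here; `jax.lax.cond` would be the alternative if the branches were expensive. The function can also be batched with `jax.vmap`, in which case the loop keeps running until every element has reached 1.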

Test your function with a few values of `n`.
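One way to check the `jit` version is to compare it against a plain-Python oracle for the same inputs. The sketch below is such a reference; the name `collatz_steps_reference` is hypothetical. It needs no `jit` because `n` is an ordinary Python `int`, and Python integers do not overflow.

```python
def collatz_steps_reference(n: int) -> int:
    """Plain-Python Collatz step count, intended only as a test oracle."""
    if n < 1:
        raise ValueError("n must be a positive integer")
    steps = 0
    while n != 1:
        # Same rule as the jit version: halve if even, else 3n + 1.
        n = n // 2 if n % 2 == 0 else 3 * n + 1
        steps += 1
    return steps


# 1 needs no steps; 6 -> 3 -> 10 -> 5 -> 16 -> 8 -> 4 -> 2 -> 1 takes 8.
for n in [1, 2, 6, 7, 27]:
    print(n, collatz_steps_reference(n))
```

Small values such as 1 and 2 cover the edge cases, while 27 is a useful stress case: it takes 111 steps and peaks at 9232 along the way.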