ML Katas

The need for speed: `jit`

easy (<10 mins) jit performance
this year by E

JAX's `jit` function traces your Python function and compiles it with XLA, which can lead to significant speedups. The compiler can fuse operations together, and the compiled function runs without the per-operation overhead of Python's interpreter.

In this exercise, take the function that you wrote in the "NumPy to JAX: The Basics" exercise and `jit`-compile it. You can do so by using `jax.jit` as a decorator or by calling it directly on your function.
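The two ways of applying `jax.jit` can be sketched as follows. The function `affine_tanh` and its inputs are hypothetical stand-ins, since the function from the earlier exercise is not shown here; substitute your own.

```python
import jax
import jax.numpy as jnp


# Hypothetical stand-in for the function from the earlier exercise.
def affine_tanh(x, w, b):
    return jnp.tanh(jnp.dot(x, w) + b)


# Option 1: call jax.jit directly on an existing function.
affine_tanh_jit = jax.jit(affine_tanh)


# Option 2: use jax.jit as a decorator when defining the function.
@jax.jit
def affine_tanh_decorated(x, w, b):
    return jnp.tanh(jnp.dot(x, w) + b)


# Small illustrative inputs (shapes chosen arbitrarily).
x = jnp.ones((4, 3))
w = jnp.ones((3, 2))
b = jnp.zeros(2)

out_plain = affine_tanh(x, w, b)
out_jit = affine_tanh_jit(x, w, b)
out_dec = affine_tanh_decorated(x, w, b)
```

Both forms produce the same compiled function. Calling `jax.jit` directly is convenient here because it keeps the original, uncompiled function around for comparison.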

Then benchmark the `jit`-compiled function against the original, uncompiled function. You can use Python's `timeit` module for this. What kind of speedup do you get?
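One possible way to set up the benchmark is sketched below. It again uses the hypothetical `affine_tanh` function, with arbitrarily chosen input sizes and run count. Two details matter for fair timings: the first call to a jitted function includes compilation, so the sketch does a warm-up call first, and JAX dispatches work asynchronously, so each timed call ends with `.block_until_ready()`.

```python
import timeit

import jax
import jax.numpy as jnp


# Hypothetical stand-in for the function from the earlier exercise.
def affine_tanh(x, w, b):
    return jnp.tanh(jnp.dot(x, w) + b)


affine_tanh_jit = jax.jit(affine_tanh)

# Arbitrary input sizes, chosen only for illustration.
key_x, key_w = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (512, 512))
w = jax.random.normal(key_w, (512, 512))
b = jnp.zeros(512)

# Warm-up call so that compilation time is not included in the measurement.
affine_tanh_jit(x, w, b).block_until_ready()

n_runs = 50  # arbitrary number of repetitions

# block_until_ready() waits for the asynchronous computation to finish.
t_plain = timeit.timeit(
    lambda: affine_tanh(x, w, b).block_until_ready(), number=n_runs
)
t_jit = timeit.timeit(
    lambda: affine_tanh_jit(x, w, b).block_until_ready(), number=n_runs
)

print(f"plain: {t_plain / n_runs * 1e3:.3f} ms per call")
print(f"jit:   {t_jit / n_runs * 1e3:.3f} ms per call")
print(f"speedup: {t_plain / t_jit:.2f}x")
```

The measured speedup depends on the function, the input sizes, and the hardware, so treat any single number as specific to your setup.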