ML Katas

The JAX approach to PRNG

easy (<10 mins) · jax.random · state
this year by E

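JAX handles pseudo-random number generation (PRNG) differently from NumPy. NumPy's legacy np.random functions draw from a hidden global state that every call silently advances. JAX instead makes the PRNG state explicit: each random function takes a key, and the same key always produces the same numbers. This design makes results reproducible and keeps functions pure, which JAX transformations such as jit and vmap rely on.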
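One way to see the difference is to call each library twice from the same starting state. The sketch below uses NumPy's legacy np.random.seed API and jax.random.PRNGKey. The seed values and the choice of a normal distribution are arbitrary and only serve the illustration.

```python
import numpy as np
import jax

# NumPy legacy API: each call advances a hidden global state,
# so two identical calls return different numbers.
np.random.seed(0)
a = np.random.normal(size=3)
b = np.random.normal(size=3)

# JAX: the state is an explicit key passed to every call,
# so reusing the same key reproduces exactly the same numbers.
key = jax.random.PRNGKey(0)
c = jax.random.normal(key, (3,))
d = jax.random.normal(key, (3,))

print(np.allclose(a, b))      # False: the global state moved between calls
print(bool((c == d).all()))   # True: same key, same numbers
```

The flip side is that JAX will happily hand you identical "random" numbers if you reuse a key by accident, which is why the exercise below asks you to derive fresh keys explicitly.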

In this exercise, write a function that takes a JAX PRNG key and a shape and returns three values: two random arrays of that shape, x_1 and x_2, and a new PRNG key that differs from the one passed in.

To do so, use jax.random.split. Verify your implementation by checking that all of the following conditions hold:

1. x_1 and x_2 both have the shape passed to your function.
2. x_1 and x_2 are not identical.
3. The new key differs from the old key.
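A possible solution is sketched below. The function name two_random_arrays is hypothetical, and drawing from a standard normal is an assumption (the exercise only says "random arrays"; jax.random.uniform would work the same way). The key idea is to split the incoming key once into three independent keys and never reuse the original.

```python
import jax
import jax.numpy as jnp

def two_random_arrays(key, shape):
    """Hypothetical solution: return (x_1, x_2, new_key) derived from `key`."""
    # Split the incoming key into three independent keys: one to hand back
    # to the caller and one for each array. The incoming key is not used again.
    new_key, key_1, key_2 = jax.random.split(key, 3)
    x_1 = jax.random.normal(key_1, shape)
    x_2 = jax.random.normal(key_2, shape)
    return x_1, x_2, new_key

key = jax.random.PRNGKey(42)
shape = (2, 3)
x_1, x_2, new_key = two_random_arrays(key, shape)

# 1. Both arrays have the requested shape.
assert x_1.shape == shape and x_2.shape == shape
# 2. The two arrays are not identical.
assert not jnp.array_equal(x_1, x_2)
# 3. The returned key differs from the one passed in.
assert not jnp.array_equal(new_key, key)
print("all checks passed")
```

Returning new_key lets the caller keep threading state through later calls (for example, passing it back into two_random_arrays) without ever drawing from the same key twice.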