ML Katas

NumPy to JAX: The Basics

easy (<10 mins) jax.numpy jax.random
this year by E

This first exercise is a straightforward warm-up to familiarize yourself with the JAX NumPy API.

  1. Create a function that takes three jnp.ndarray arguments, W, x, and b, and implements the affine transformation y = Wx + b.
  2. Create a random key with jax.random.key.
  3. Instantiate a W matrix of shape (20, 10) and a vector x of shape (10,) using your random key. Note that you will have to split the key to get two different random arrays.
  4. Instantiate a vector b of shape (20,) as a vector of ones.
  5. Verify that the shapes of the inputs and outputs are what you would expect. For example, y should have a shape of (20,).
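
One possible solution to the steps above is sketched below. The exercise does not specify a seed or a sampling distribution, so the seed 0, the use of jax.random.normal, and the names affine, w_key, and x_key are assumptions made for illustration only.

```python
import jax
import jax.numpy as jnp


def affine(W, x, b):
    """Affine transformation y = W x + b."""
    return W @ x + b


# Step 2: create a random key. The seed 0 is an arbitrary choice.
key = jax.random.key(0)

# Step 3: split the key so W and x come from independent random streams.
# Reusing the same key for both draws would produce correlated values.
w_key, x_key = jax.random.split(key)

# The distribution is not named in the exercise; standard normal is assumed.
W = jax.random.normal(w_key, (20, 10))
x = jax.random.normal(x_key, (10,))

# Step 4: b is a vector of ones.
b = jnp.ones((20,))

# Step 5: check the shapes of the inputs and the output.
y = affine(W, x, b)
assert W.shape == (20, 10)
assert x.shape == (10,)
assert b.shape == (20,)
assert y.shape == (20,)
```

Because W has shape (20, 10) and x has shape (10,), the @ operator contracts over the shared dimension of size 10 and returns a vector of shape (20,), which then adds elementwise to b.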