NumPy to JAX: The Basics
This first exercise is a straightforward warm-up to familiarize yourself with the JAX NumPy API.
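As a quick orientation before the exercise, the sketch below illustrates that many familiar NumPy calls (arange, reshape, ones, the @ operator) have direct counterparts under jax.numpy that produce matching results. The specific arrays are arbitrary and chosen only for illustration.

```python
import numpy as np
import jax.numpy as jnp

# The same array-creation calls exist in both namespaces.
a_np = np.arange(6.0).reshape(2, 3)
a_jnp = jnp.arange(6.0).reshape(2, 3)

v_np = np.ones(3)
v_jnp = jnp.ones(3)

# Matrix-vector products agree up to floating-point precision
# (JAX defaults to float32, while NumPy defaults to float64).
out_np = a_np @ v_np
out_jnp = a_jnp @ v_jnp

print(out_np, out_jnp)
print(np.allclose(out_np, np.asarray(out_jnp)))
```

Because of the differing default dtypes, comparing with np.allclose rather than exact equality is the safer habit.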
- Create a function that takes two jnp.ndarray inputs, W and x, as well as a jnp.ndarray b, and implements the affine transformation y = Wx + b.
- Create a random key with jax.random.key.
- Instantiate a matrix W of shape (20, 10) and a vector x of shape (10,) using your random key. Note that you will have to split the key to get two different random arrays.
- Instantiate a vector b of shape (20,) as a vector of ones.
- Verify that the shapes of the inputs and outputs are what you would expect. For example, y should have a shape of (20,).
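One possible solution to the steps above is sketched below. It makes a few assumptions that the exercise leaves open. The function name affine is hypothetical. The seed 0 is arbitrary. W and x are drawn from a standard normal distribution, since the exercise only says "random". jax.random.key requires a reasonably recent JAX release; older versions provide jax.random.PRNGKey instead.

```python
import jax
import jax.numpy as jnp


def affine(W: jnp.ndarray, x: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray:
    """Hypothetical helper: compute the affine transformation y = W x + b."""
    return jnp.dot(W, x) + b


# Create a PRNG key; the seed 0 is an arbitrary choice.
key = jax.random.key(0)

# Split the key so W and x come from independent random streams.
# Reusing the same key for both would correlate the two arrays.
key_W, key_x = jax.random.split(key)

# Assumption: standard-normal entries for W and x.
W = jax.random.normal(key_W, (20, 10))
x = jax.random.normal(key_x, (10,))

# b is a vector of ones, as the exercise specifies.
b = jnp.ones((20,))

y = affine(W, x, b)

# Verify the shapes: (20, 10) @ (10,) -> (20,), plus (20,) -> (20,).
print(W.shape, x.shape, b.shape, y.shape)  # prints: (20, 10) (10,) (20,) (20,)
assert y.shape == (20,)
```

Splitting the key, rather than reusing it, is the idiomatic JAX pattern. JAX's random functions are pure, so the same key always yields the same numbers.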