ML Katas

Vectorizing with `vmap`

Difficulty: easy (<30 mins) · Tags: vmap, batching
this year by E

Another of JAX's powerful features is its ability to automatically vectorize functions with `vmap`. For example, `vmap` can turn a function written for a single example into one that processes a whole batch, without you having to rewrite it to handle the extra batch dimension. In this exercise, you will use `vmap` to batch the affine transformation function.
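A minimal sketch of the idea, using a toy function (`squared_norm` is a hypothetical example, not part of the exercise) written for one vector and then batched with `jax.vmap`. By default, `vmap` maps over axis 0 of every argument.

```python
import jax
import jax.numpy as jnp

# Written for a single example: x is one vector of shape (3,).
def squared_norm(x):
    return jnp.sum(x ** 2)

# vmap maps squared_norm over the leading axis (axis 0) of its input.
batched_squared_norm = jax.vmap(squared_norm)

xs = jnp.arange(6.0).reshape(2, 3)  # a batch of 2 vectors, each of shape (3,)
out = batched_squared_norm(xs)      # one scalar per example, so shape (2,)
print(out.shape)                    # (2,)
# The values are 5.0 (0 + 1 + 4) and 50.0 (9 + 16 + 25).
```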

  1. Start from the function y = Wx + b that you implemented in "NumPy to JAX: The Basics".
  2. Use `vmap` to create a new function that can process a batch of x's, where each x is a vector of shape (10,). Your W should have shape (20, 10) and your b should have shape (20,). Use the `in_axes` argument to specify that only x is batched, while W and b are shared across the whole batch.
  3. Instantiate a batch of 5 x's, so that your batched x has shape (5, 10).
  4. Call your new batched function and verify that the output has the correct shape. What should it be?
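One possible sketch of the steps above. It assumes the single-example function takes its arguments in the order `affine(W, b, x)`; your version from the earlier exercise may use a different order, in which case the `in_axes` tuple must be reordered to match. Setting an entry of `in_axes` to `None` tells `vmap` not to map over that argument.

```python
import jax
import jax.numpy as jnp

# Affine transformation for a single example: W (20, 10), b (20,), x (10,).
# The argument order (W, b, x) is an assumption for this sketch.
def affine(W, b, x):
    return W @ x + b

# in_axes=(None, None, 0): W and b are not batched; map over axis 0 of x.
batched_affine = jax.vmap(affine, in_axes=(None, None, 0))

# Random inputs with the shapes from the exercise.
key = jax.random.PRNGKey(0)
k_W, k_b, k_x = jax.random.split(key, 3)
W = jax.random.normal(k_W, (20, 10))
b = jax.random.normal(k_b, (20,))
x = jax.random.normal(k_x, (5, 10))  # a batch of 5 x's

y = batched_affine(W, b, x)
print(y.shape)  # (5, 20)
```

By default `vmap` places the mapped axis first in the output (`out_axes=0`), so each of the 5 inputs yields one row of 20 outputs. For this particular function the result matches `x @ W.T + b`, which is a convenient way to check your batched version.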