Vectorizing with `vmap`
Another of JAX's powerful features is its ability to automatically vectorize functions with `vmap`. For example, `vmap` can be used to automatically batch a function that was written for a single example. In this exercise, you will use `vmap` to batch the affine transformation function.
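As a minimal illustration (not part of the exercise), the sketch below takes a function written for one vector and uses `jax.vmap` to apply it to a whole batch. The function `squared_norm` is a hypothetical example. By default, `vmap` maps every argument over its leading axis (axis 0).

```python
import jax
import jax.numpy as jnp

# Hypothetical function written for a single example: the squared norm of one vector.
def squared_norm(v):
    return jnp.dot(v, v)

# vmap maps squared_norm over the leading axis (axis 0) of its input.
batched_squared_norm = jax.vmap(squared_norm)

vs = jnp.arange(12.0).reshape(4, 3)  # a batch of 4 vectors, each of shape (3,)
out = batched_squared_norm(vs)
print(out.shape)  # (4,)
```

No explicit loop over the batch is written. `vmap` produces a function that works on the stacked input directly.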
- Take the affine transformation function that you implemented in *NumPy to JAX: The Basics*.
- Use `vmap` to create a new function that can process a batch of `x`'s, where each `x` is a vector of shape `(10,)`. Your `W` should have shape `(20, 10)` and your `b` should have shape `(20,)`. You will have to use the `in_axes` argument to specify that you only want to batch over `x`.
- Instantiate a batch of 5 `x`'s. This means that your `x` should have shape `(5, 10)`.
- Call your new batched function and verify that the output has the correct shape. What should it be?
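One possible solution sketch follows. It assumes that the affine function from the earlier exercise computes `W @ x + b` and takes its arguments in the order `(W, b, x)`. Both the name `affine` and that argument order are assumptions, so adjust `in_axes` to match your own signature. Passing `in_axes=(None, None, 0)` tells `vmap` not to map `W` and `b` (they are shared by every example) and to map `x` over its leading axis.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for the affine function from "NumPy to JAX: The Basics".
# The argument order (W, b, x) is an assumption.
def affine(W, b, x):
    # W: (20, 10), b: (20,), x: (10,)  ->  result: (20,)
    return jnp.dot(W, x) + b

# Share W and b across the batch; map only over axis 0 of x.
batched_affine = jax.vmap(affine, in_axes=(None, None, 0))

key_W, key_b, key_x = jax.random.split(jax.random.PRNGKey(0), 3)
W = jax.random.normal(key_W, (20, 10))
b = jax.random.normal(key_b, (20,))
x = jax.random.normal(key_x, (5, 10))  # a batch of 5 vectors of shape (10,)

y = batched_affine(W, b, x)
print(y.shape)  # (5, 20)
```

Because `out_axes` defaults to 0, the batch dimension appears first in the result, so each of the 5 inputs yields one output of shape `(20,)`.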