Vectorized Operations with vmap
You have a function that processes a single data point, and you want to apply it to a whole batch with jax.vmap instead of writing an explicit loop. For example, consider a function that multiplies one feature vector by a fixed weight matrix.
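import jax
import jax.numpy as jnp

def process_item(item, weights):
    # item: (num_features,), weights: (num_features, output_features)
    # Returns a vector of shape (output_features,).
    return jnp.dot(item, weights)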
Task:
Vectorize process_item to handle a batch of items with shape (batch_size, num_features) and a single weights matrix of shape (num_features, output_features).
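One possible approach, given here as a minimal sketch rather than the only solution, is to pass in_axes=(0, None) to jax.vmap. That maps over axis 0 of the items and reuses the same weights for every item. The name batched_process and the concrete sizes below are assumptions chosen for illustration.

```python
import jax
import jax.numpy as jnp

def process_item(item, weights):
    # Single-item function from the exercise.
    return jnp.dot(item, weights)

# in_axes=(0, None): map over axis 0 of `item`; pass `weights` through unmapped.
batched_process = jax.vmap(process_item, in_axes=(0, None))

# Illustrative sizes only (not specified by the exercise).
batch_size, num_features, output_features = 4, 3, 2
items = jnp.ones((batch_size, num_features))
weights = jnp.ones((num_features, output_features))

out = batched_process(items, weights)
print(out.shape)  # (4, 2)
```

With the default in_axes=0, vmap would also try to map over the first axis of weights, which is not what the task asks for; None marks an argument as shared across the batch.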
Verification:
- The output of your vectorized function should have the shape (batch_size, output_features).
- The results should match what you would get by iterating through the batch in a Python loop, up to floating-point tolerance.
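The two checks above can be sketched as follows. The helper name check_against_loop, the random test data, and the tolerance are assumptions made for this example. The comparison uses jnp.allclose rather than exact equality, since the batched and looped computations may round slightly differently.

```python
import jax
import jax.numpy as jnp

def process_item(item, weights):
    # Single-item function from the exercise.
    return jnp.dot(item, weights)

def check_against_loop(items, weights):
    # Hypothetical helper: compare the vmapped result with a Python loop.
    batched = jax.vmap(process_item, in_axes=(0, None))(items, weights)
    looped = jnp.stack([process_item(item, weights) for item in items])
    return batched.shape, bool(jnp.allclose(batched, looped, atol=1e-5))

# Random test data with illustrative sizes: batch of 8, 5 features, 3 outputs.
key_items, key_weights = jax.random.split(jax.random.PRNGKey(0))
items = jax.random.normal(key_items, (8, 5))
weights = jax.random.normal(key_weights, (5, 3))

shape, matches = check_against_loop(items, weights)
print(shape, matches)
```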