Vectorized Operations with vmap
You have a function that processes a single data point. Your goal is to use jax.vmap to apply this function to a whole batch of data without writing an explicit loop. For example, consider a function that multiplies a single feature vector by a fixed weight matrix with jnp.dot:
import jax
import jax.numpy as jnp

def process_item(item, weights):
    # item: (num_features,); weights: (num_features, output_features)
    return jnp.dot(item, weights)
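Before vectorizing, it can help to pin down the per-item contract. The minimal sketch below calls process_item on one item and checks that a vector of shape (num_features,) times a (num_features, output_features) matrix yields a vector of shape (output_features,). The sizes and all-ones data are hypothetical and chosen only for illustration.

```python
import jax.numpy as jnp

def process_item(item, weights):
    # item: (num_features,); weights: (num_features, output_features)
    return jnp.dot(item, weights)

# Hypothetical sizes, for illustration only.
num_features, output_features = 4, 3
item = jnp.ones((num_features,))
weights = jnp.ones((num_features, output_features))

out = process_item(item, weights)
print(out.shape)  # (3,)
```

Each output entry here is the sum of four products 1.0 * 1.0, so every value is 4.0.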
Task:
Vectorize process_item so that it handles a batch of items with shape (batch_size, num_features) while sharing a single weights matrix of shape (num_features, output_features) across the whole batch.
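One way to approach the task is sketched below, assuming the batch dimension is the leading axis of items. The in_axes=(0, None) argument tells jax.vmap to map over axis 0 of the first argument and to pass the second argument unchanged to every call. The name batched_process and the sizes and random data are hypothetical, for illustration only.

```python
import jax
import jax.numpy as jnp

def process_item(item, weights):
    # Operates on a single item of shape (num_features,).
    return jnp.dot(item, weights)

# in_axes=(0, None): slice `items` along axis 0, reuse `weights` for every slice.
# out_axes defaults to 0, so the batch axis comes first in the output.
batched_process = jax.vmap(process_item, in_axes=(0, None))

# Hypothetical sizes and random data, for illustration only.
batch_size, num_features, output_features = 8, 4, 3
key_items, key_weights = jax.random.split(jax.random.PRNGKey(0))
items = jax.random.normal(key_items, (batch_size, num_features))
weights = jax.random.normal(key_weights, (num_features, output_features))

result = batched_process(items, weights)
print(result.shape)  # (8, 3)
```

Leaving in_axes at its default of 0 would make vmap slice weights along its first axis as well, which is not what the task asks for. With these sizes it would also raise an error, because the mapped axes (8 and 4) have different lengths.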
Verification:
- The output of your vectorized function should have the shape (batch_size, output_features).
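- The results should match what you would get by iterating through the batch in a Python loop. Compare with a tolerance (for example jnp.allclose), since a batched computation and per-item computations may differ in the last few bits of floating-point precision.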
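Both checks can be sketched as follows. This is a minimal example, assuming the same in_axes=(0, None) vectorization and hypothetical sizes and random data. The reference result is built with an explicit Python loop over the batch and stacked back into one array. The atol value is an illustrative choice, not a required threshold.

```python
import jax
import jax.numpy as jnp

def process_item(item, weights):
    return jnp.dot(item, weights)

batched_process = jax.vmap(process_item, in_axes=(0, None))

# Hypothetical sizes and random data, for illustration only.
batch_size, num_features, output_features = 8, 4, 3
key_items, key_weights = jax.random.split(jax.random.PRNGKey(42))
items = jax.random.normal(key_items, (batch_size, num_features))
weights = jax.random.normal(key_weights, (num_features, output_features))

vectorized = batched_process(items, weights)

# Reference: explicit Python loop over the batch, stacked into one array.
looped = jnp.stack([process_item(items[i], weights) for i in range(batch_size)])

# Check 1: output shape is (batch_size, output_features).
assert vectorized.shape == (batch_size, output_features)
# Check 2: values agree with the loop, up to a small floating-point tolerance.
assert bool(jnp.allclose(vectorized, looped, atol=1e-5))
print("vmap output matches the Python loop")
```

For this particular function, the vectorized result is also just a matrix product. That means items @ weights (equivalently jnp.dot(items, weights)) gives a second, loop-free reference to compare against.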