ML Katas

Vectorized Operations with vmap

easy (<10 mins) jax vmap vectorization batching
yesterday by E

You have a function that processes a single data point. Your goal is to use jax.vmap to apply this function to a whole batch of data without writing an explicit loop. For example, consider a function that multiplies a single feature vector by a fixed weight matrix.

import jax
import jax.numpy as jnp

def process_item(item, weights):
  # item: (num_features,), weights: (num_features, output_features)
  # returns: (output_features,)
  return jnp.dot(item, weights)

Task: Vectorize process_item to handle a batch of items with shape (batch_size, num_features) and a single weights matrix of shape (num_features, output_features).
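One possible approach, as a minimal sketch rather than the reference solution: jax.vmap accepts an in_axes argument that states which axis of each argument to map over. Using in_axes=(0, None) maps over axis 0 of item and passes the same weights to every call. The name batched_process and the concrete shapes (4, 3) and (3, 2) are illustrative assumptions, not part of the kata.

```python
import jax
import jax.numpy as jnp

def process_item(item, weights):
  return jnp.dot(item, weights)

# in_axes=(0, None): map over axis 0 of `item`, share `weights` unchanged.
# `batched_process` is a hypothetical name chosen for this sketch.
batched_process = jax.vmap(process_item, in_axes=(0, None))

# Assumed example sizes: batch_size=4, num_features=3, output_features=2.
items = jnp.ones((4, 3))
weights = jnp.ones((3, 2))

out = batched_process(items, weights)
print(out.shape)  # (4, 2)
```

Passing None for weights avoids copying or tiling the matrix along a batch axis; only the items argument gains a leading batch dimension.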

Verification:
- The output of your vectorized function should have the shape (batch_size, output_features).
- The results should match what you would get by iterating through the batch in a Python loop, up to floating-point rounding.
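Both checks can be scripted. The sketch below assumes a candidate solution built with jax.vmap and in_axes=(0, None), and compares it against a plain Python loop over the batch. The names batched_process, vmapped, and looped, the sizes (8, 5) and (5, 3), and the 1e-5 tolerance are all illustrative assumptions.

```python
import jax
import jax.numpy as jnp

def process_item(item, weights):
  return jnp.dot(item, weights)

# Candidate solution under test (an assumption for this sketch).
batched_process = jax.vmap(process_item, in_axes=(0, None))

# Random inputs: batch_size=8, num_features=5, output_features=3 (assumed sizes).
key_items, key_weights = jax.random.split(jax.random.PRNGKey(0))
items = jax.random.normal(key_items, (8, 5))
weights = jax.random.normal(key_weights, (5, 3))

vmapped = batched_process(items, weights)

# Reference: apply process_item row by row in a Python loop, then stack.
looped = jnp.stack([process_item(item, weights) for item in items])

# Check 1: output shape is (batch_size, output_features).
shape_ok = vmapped.shape == (8, 3)
# Check 2: values agree with the loop within a small tolerance.
values_ok = bool(jnp.allclose(vmapped, looped, atol=1e-5))
print(shape_ok, values_ok)
```

jnp.allclose is used instead of exact equality because the batched and per-item computations may accumulate floating-point sums in a different order.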