ML Katas

Combine `vmap` and `pmap`

hard (<30 mins) jax vmap pmap parallelism
yesterday by E

For more complex parallelism patterns, you can combine `vmap` and `pmap`. For instance, you can use `pmap` for data parallelism across devices and `vmap` for model ensembling on each device.
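To make the ensembling pattern concrete, here is a minimal sketch, assuming a toy linear model. The names `predict`, `n_models`, `batch`, and `d_in` are illustrative and not part of the kata. The ensemble weights are replicated to every device, and each device receives its own shard of the data.

```python
import jax
import jax.numpy as jnp

n_devices = jax.local_device_count()  # 1 on a plain CPU setup
n_models = 3                          # hypothetical ensemble size
batch, d_in = 4, 5                    # hypothetical per-device batch and feature sizes

def predict(w, x):
    # One linear model: weights (d_in,), inputs (batch, d_in) -> outputs (batch,)
    return x @ w

# vmap over ensemble members: weights vary along axis 0, inputs are shared.
ensemble_predict = jax.vmap(predict, in_axes=(0, None))

# pmap over devices: weights are broadcast (None), data is split along axis 0.
parallel_predict = jax.pmap(ensemble_predict, in_axes=(None, 0))

key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
weights = jax.random.normal(key_w, (n_models, d_in))
data = jax.random.normal(key_x, (n_devices, batch, d_in))

preds = parallel_predict(weights, data)
print(preds.shape)  # (n_devices, n_models, batch)
```

The leading axis of `data` must not exceed the number of available devices, which is why the sketch sizes it with `jax.local_device_count()`.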

Task: Write a function, vectorize it with `vmap`, and then parallelize the vectorized function with `pmap`.

Verification:
- Your code should execute without errors.
- The output shape should reflect both the vectorization and the parallelization: with the default axes, a leading device axis from `pmap` followed by the batch axis from `vmap`.
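One possible way to check both criteria is sketched below. The function `scale_and_sum` and the sizes `per_device_batch` and `features` are assumptions chosen for illustration; any per-example function would work the same way.

```python
import jax
import jax.numpy as jnp

def scale_and_sum(x):
    # Per-example function: (features,) -> scalar
    return jnp.sum(2.0 * x)

# Step 1: vectorize over a batch axis.
vectorized = jax.vmap(scale_and_sum)

# Step 2: parallelize the vectorized function over devices.
parallel = jax.pmap(vectorized)

n_devices = jax.local_device_count()
per_device_batch, features = 8, 3  # hypothetical sizes

x = jnp.arange(n_devices * per_device_batch * features, dtype=jnp.float32)
x = x.reshape(n_devices, per_device_batch, features)

out = parallel(x)

# Leading axis comes from pmap, second axis from vmap; the feature axis is reduced away.
assert out.shape == (n_devices, per_device_batch)
print(out.shape)
```

On a machine with a single device the leading axis has size 1, so the check still passes, but the `pmap` axis only becomes meaningful with several devices.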