ML Katas

Parallelization with `pmap`

Difficulty: hard (>1 hr). Topics: pmap, parallelization.

For large models, it is often necessary to train on multiple devices (e.g., GPUs or TPUs). JAX's `pmap` transformation makes this straightforward: it runs one copy of a function per device, each copy operating on one slice along the leading axis of the inputs.
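As a quick illustration of the programming model (not part of the exercise itself), here is a minimal sketch: a simple function mapped over whatever local devices are available.

```python
import jax
import jax.numpy as jnp

n_devices = jax.local_device_count()

# The leading axis of the input must match the number of devices;
# pmap sends one slice to each device and runs the function there.
xs = jnp.arange(n_devices * 4.0).reshape(n_devices, 4)
squared = jax.pmap(lambda x: x ** 2)(xs)
print(squared.shape)  # (n_devices, 4)
```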

In this exercise, you will use `pmap` to train your MLP on multiple devices (see the sketch after this list):

1. Modify your training script to handle multiple devices. You will need to replicate your model's parameters across devices.
2. Use `pmap` to parallelize your training step.
3. Average the loss and gradients across devices so that every replica applies the same update.
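Below is a hedged sketch of the overall pattern, assuming a placeholder single-layer model and MSE loss in place of the MLP from earlier katas; `init_params`, `loss_fn`, and `LEARNING_RATE` are illustrative names, not part of the original exercise. If you only have a CPU, you can simulate multiple devices by setting `XLA_FLAGS=--xla_force_host_platform_device_count=8` in the environment before importing JAX.

```python
from functools import partial

import jax
import jax.numpy as jnp

LEARNING_RATE = 1e-3  # assumed hyperparameter


def init_params(key):
    # Placeholder one-layer model; substitute your MLP's initializer.
    k1, _ = jax.random.split(key)
    return {
        "w": jax.random.normal(k1, (8, 1)) * 0.1,
        "b": jnp.zeros((1,)),
    }


def loss_fn(params, x, y):
    # Placeholder MSE loss; substitute your own loss.
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


# axis_name ties the collectives below to the pmapped device axis.
@partial(jax.pmap, axis_name="devices")
def train_step(params, x, y):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # Average loss and gradients across devices so all replicas
    # apply the same update and parameters stay in sync.
    loss = jax.lax.pmean(loss, axis_name="devices")
    grads = jax.lax.pmean(grads, axis_name="devices")
    params = jax.tree_util.tree_map(
        lambda p, g: p - LEARNING_RATE * g, params, grads
    )
    return params, loss


n_devices = jax.local_device_count()

# Replicate the parameters: stack identical copies along a leading device axis.
params = init_params(jax.random.PRNGKey(0))
replicated = jax.tree_util.tree_map(lambda p: jnp.stack([p] * n_devices), params)

# Shard the batch: reshape so the leading axis equals the device count.
x = jax.random.normal(jax.random.PRNGKey(1), (n_devices * 16, 8))
y = jnp.ones((n_devices * 16, 1))
x = x.reshape(n_devices, -1, 8)
y = y.reshape(n_devices, -1, 1)

replicated, loss = train_step(replicated, x, y)
print(loss)  # one (identical) loss value per device
```

The key pieces are the `axis_name` shared by `jax.pmap` and `jax.lax.pmean`, and the leading device axis added to both parameters (by replication) and data (by sharding) before calling the pmapped step.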

Note: This exercise requires access to multiple devices. If you are using a Colab notebook, you can select a TPU runtime.