Data Parallelism with `pmap`
Parallelize a training step across multiple devices (e.g., multiple CPU cores if you don't have GPUs/TPUs) using `jax.pmap`. This is a fundamental technique for scaling up training.
Task:
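Take a simple training step function (like the one from the linear regression exercise) and modify it to run in parallel on multiple devices. This involves sharding the data along a leading device axis, so that each device receives its own slice of the batch, and averaging the gradients across devices with `jax.lax.pmean` so that every device applies the same update.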
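As a rough starting point, the task could be sketched as follows. This is a minimal sketch, not a reference solution: the linear model, the synthetic data, the learning rate of 0.1, the axis name `"devices"`, and the choice of 4 logical CPU devices (requested through the `XLA_FLAGS` environment variable) are all assumptions made for illustration.

```python
import os
# Assumption: request 4 logical CPU devices. This must be set before jax is
# imported, and has no effect if XLA_FLAGS is already set in the environment.
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=4")

import functools
import jax
import jax.numpy as jnp

def loss_fn(params, x, y):
    # Mean squared error of a linear model (stand-in for the earlier exercise).
    w, b = params
    return jnp.mean((x @ w + b - y) ** 2)

@functools.partial(jax.pmap, axis_name="devices")
def train_step(params, x, y):
    # Each device computes the loss and gradients on its own shard of the batch.
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # Average across devices so every replica applies the same update.
    grads = jax.lax.pmean(grads, axis_name="devices")
    loss = jax.lax.pmean(loss, axis_name="devices")
    params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
    return params, loss

n_dev = jax.local_device_count()

# Synthetic data (illustrative): y = x @ true_w + 0.3
x = jax.random.normal(jax.random.PRNGKey(0), (n_dev * 32, 3))
true_w = jnp.array([1.0, -2.0, 0.5])
y = x @ true_w + 0.3

# Shard the batch: the leading axis must equal the number of devices.
x_sharded = x.reshape(n_dev, -1, 3)
y_sharded = y.reshape(n_dev, -1)

# Replicate the parameters so each device holds its own copy.
params = (jnp.zeros(3), jnp.zeros(()))
params = jax.tree_util.tree_map(lambda p: jnp.stack([p] * n_dev), params)

losses = []
for step in range(100):
    params, loss = train_step(params, x_sharded, y_sharded)
    losses.append(float(loss[0]))  # identical on every device after pmean

final_w = params[0][0]  # take the copy held by the first device
print("devices:", n_dev, "first loss:", losses[0], "last loss:", losses[-1])
```

The sketch derives the shard count from `jax.local_device_count()` rather than hard-coding it, so it also runs (without any parallelism) when only one device is available.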
Verification:
- Your code should run without errors on a machine with multiple logical devices.
- The training progress should be similar to the single-device version.
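One way to check the second point is to compare gradients directly. The sketch below assumes equally sized shards and a mean-based loss, in which case the `pmean` of the per-shard gradients should match the gradient computed on the full batch on a single device, up to floating-point tolerance. The data, shapes, tolerance, and the helper name `sharded_grad` are hypothetical.

```python
import os
# Assumption: request 4 logical CPU devices; must be set before importing jax.
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=4")

import jax
import jax.numpy as jnp

def loss_fn(w, x, y):
    return jnp.mean((x @ w - y) ** 2)

n_dev = jax.local_device_count()
kx, ky = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(kx, (n_dev * 8, 3))
y = jax.random.normal(ky, (n_dev * 8,))
w = jnp.ones(3)

# Reference: gradient on the full batch, computed on a single device.
single_grad = jax.grad(loss_fn)(w, x, y)

def sharded_grad(w, x, y):
    # Per-shard gradient, averaged across the "devices" axis.
    return jax.lax.pmean(jax.grad(loss_fn)(w, x, y), axis_name="devices")

# in_axes=(None, 0, 0): broadcast w to every device, shard x and y.
parallel_grad = jax.pmap(sharded_grad, axis_name="devices", in_axes=(None, 0, 0))(
    w, x.reshape(n_dev, 8, 3), y.reshape(n_dev, 8)
)

# Every device holds the same averaged gradient, so compare the first copy.
grads_match = bool(jnp.allclose(parallel_grad[0], single_grad, atol=1e-4))
print("gradients match:", grads_match)
```

If the shards were unequal in size, the plain average of per-shard gradients would no longer equal the full-batch gradient, so this check only applies to evenly divided batches.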