-
Data Parallelism with `pmap`
Parallelize a training step across multiple devices (e.g., multiple CPU cores if you don't have GPUs/TPUs) using `jax.pmap`. This is a fundamental technique for scaling up training. **Task:** Take...
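The task statement above is cut off, so the following is only a minimal sketch of what a data-parallel training step with `jax.pmap` might look like. The linear model, the `loss_fn` and `train_step` names, the synthetic data, and the learning rate of 0.1 are all illustrative assumptions, not part of the original exercise.

```python
import functools

import jax
import jax.numpy as jnp


def loss_fn(params, x, y):
    # Hypothetical linear regression model with mean squared error.
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


@functools.partial(jax.pmap, axis_name="devices")
def train_step(params, x, y):
    # Each device computes the loss and gradients on its own data shard.
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # Average across devices so every replica applies the same update.
    grads = jax.lax.pmean(grads, axis_name="devices")
    loss = jax.lax.pmean(loss, axis_name="devices")
    # Plain SGD step; the step size 0.1 is an arbitrary choice.
    params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
    return params, loss


n_dev = jax.local_device_count()

# Synthetic data (assumption): 8 examples per device, 3 features.
x = jax.random.normal(jax.random.PRNGKey(0), (n_dev * 8, 3))
y = x @ jnp.array([1.0, -2.0, 0.5]) + 0.3

# pmap maps over the leading axis, so inputs are reshaped to
# (n_dev, per_device_batch, ...) and parameters get one copy per device.
x_sharded = x.reshape(n_dev, -1, 3)
y_sharded = y.reshape(n_dev, -1)
params = {"w": jnp.zeros(3), "b": jnp.zeros(())}
replicated = jax.tree_util.tree_map(lambda p: jnp.stack([p] * n_dev), params)

losses = []
for _ in range(50):
    replicated, loss = train_step(replicated, x_sharded, y_sharded)
    losses.append(float(loss[0]))  # identical on every device after pmean
```

To try this with several CPU devices, one option is to set the environment variable `XLA_FLAGS=--xla_force_host_platform_device_count=8` before JAX is imported; with a single device the code still runs, just without any parallelism.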
-
Combine `vmap` and `pmap`
For more complex parallelism patterns, you can combine `vmap` and `pmap`. For instance, you can use `pmap` for data parallelism across devices, and `vmap` for model ensembling on each device....
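One way this combination might look is sketched below: `pmap` shards the data across devices, while `vmap` runs several ensemble members on each device. The linear model, the ensemble size of 4, and names such as `single_model_step` and `ensemble_step` are hypothetical choices made for illustration.

```python
import jax
import jax.numpy as jnp


def loss_fn(params, x, y):
    # Hypothetical linear model with mean squared error.
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


def single_model_step(params, x, y):
    # One SGD step for one ensemble member on one device's data shard.
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # Average this member's gradients across devices (data parallelism).
    grads = jax.lax.pmean(grads, axis_name="devices")
    new_params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
    return new_params, loss


# Inner vmap: map over the ensemble axis of params, share the data shard.
# Outer pmap: map over the device axis of every argument.
ensemble_step = jax.pmap(
    jax.vmap(single_model_step, in_axes=(0, None, None)),
    axis_name="devices",
)

n_dev = jax.local_device_count()
n_models = 4  # ensemble size (assumption)

key_w, key_x = jax.random.split(jax.random.PRNGKey(0))

# Each member starts from different weights; shape (n_models, 3).
w = jax.random.normal(key_w, (n_models, 3))
b = jnp.zeros((n_models,))
# Replicate the whole ensemble on every device: (n_dev, n_models, ...).
params = {
    "w": jnp.broadcast_to(w, (n_dev,) + w.shape),
    "b": jnp.broadcast_to(b, (n_dev,) + b.shape),
}

# Synthetic data (assumption), sharded as (n_dev, per_device_batch, 3).
x = jax.random.normal(key_x, (n_dev, 8, 3))
y = x @ jnp.array([1.0, -2.0, 0.5])

params, losses = ensemble_step(params, x, y)
```

Here `losses` has shape `(n_dev, n_models)`: one per-shard loss for each ensemble member on each device. Because the gradients are averaged with `pmean`, the copies of each member stay synchronized across devices while the members themselves remain independent.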