Combine `vmap` and `pmap`
For more complex parallelism patterns, you can combine `vmap` and `pmap`. For instance, you can use `pmap` for data parallelism across devices and `vmap` for model ensembling on each device.
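Here is a minimal sketch of this pattern. It assumes a toy linear model, and the function `predict`, the ensemble size, and the array shapes are hypothetical choices made only for illustration. Two `vmap` calls handle the per-example batch and the ensemble members on each device. `pmap` then replicates the ensemble parameters (`in_axes=None`) and splits the data along a leading device axis.

```python
import jax
import jax.numpy as jnp

def predict(params, x):
    # Hypothetical toy model: one linear predictor, x has shape (features,) -> scalar.
    w, b = params
    return jnp.dot(x, w) + b

n_dev = jax.local_device_count()  # pmap's mapped axis may not exceed this
n_models = 4                      # illustrative ensemble size
per_device = 8                    # illustrative examples per device
features = 3

k_w, k_b, k_x = jax.random.split(jax.random.PRNGKey(0), 3)
# Ensemble parameters: one (w, b) pair per model, stacked along axis 0.
ensemble_params = (
    jax.random.normal(k_w, (n_models, features)),
    jax.random.normal(k_b, (n_models,)),
)
# Data with a leading device axis: each device receives one (per_device, features) slice.
x = jax.random.normal(k_x, (n_dev, per_device, features))

# vmap #1: map over examples in the local batch, sharing the same params.
batched_predict = jax.vmap(predict, in_axes=(None, 0))
# vmap #2: map over ensemble members, sharing the same local batch.
ensemble_predict = jax.vmap(batched_predict, in_axes=(0, None))
# pmap: replicate params on every device, split data along the device axis.
parallel_ensemble = jax.pmap(ensemble_predict, in_axes=(None, 0))

out = parallel_ensemble(ensemble_params, x)
print(out.shape)  # (n_dev, n_models, per_device), e.g. (1, 4, 8) on a single device
```

The data's leading axis must not exceed the number of available devices, so the sketch derives it from `jax.local_device_count()`. On a CPU-only machine that count is typically 1. Setting `XLA_FLAGS=--xla_force_host_platform_device_count=8` before JAX is imported is a common way to simulate several devices.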
Task:
Write a function, vectorize it with `vmap`, and then parallelize the vectorized function with `pmap`.
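Verification:
- Your code should execute without errors.
- The output shape should reflect both the vectorization and the parallelization. With `pmap(vmap(f))` and the default axes, the leading axis indexes devices and the next axis indexes the vectorized batch.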
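Below is one possible solution to the task. It assumes a squared norm as the per-example function, and `squared_norm` is a hypothetical stand-in for any function of a single example. `vmap` turns it into a batched function, and `pmap` then runs that batched function on each device's slice of the input.

```python
import jax
import jax.numpy as jnp

def squared_norm(x):
    # Hypothetical per-example function: maps a vector of shape (dim,) to a scalar.
    return jnp.sum(x ** 2)

n_dev = jax.local_device_count()  # pmap's mapped axis may not exceed this
batch = 5                         # illustrative examples per device
dim = 3

# Step 1: vectorize over a batch axis: (batch, dim) -> (batch,)
batched = jax.vmap(squared_norm)
# Step 2: parallelize the vectorized function: (n_dev, batch, dim) -> (n_dev, batch)
parallel = jax.pmap(batched)

xs = jnp.arange(n_dev * batch * dim, dtype=jnp.float32).reshape(n_dev, batch, dim)
out = parallel(xs)

# Axis 0 comes from pmap (devices), axis 1 from vmap (batch).
print(out.shape)  # (1, 5) on a single-device machine
```

`squared_norm` maps a vector to a scalar, so both axes of the result come from the transformations: the first from `pmap` and the second from `vmap`. This is the shape evidence the verification step asks for.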