-
The need for speed: `jit`
JAX's `jit` function traces your Python function and compiles it with XLA, which can lead to significant speedups. The compiler can fuse operations together, and the compiled function runs without the per-operation overhead of Python's interpreter. In this...
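As a rough illustration of the idea, the sketch below wraps a small elementwise function in `jax.jit` and checks that the compiled version matches the uncompiled one. The function `selu`, its constants, and the input size are illustrative assumptions, not part of the exercise.

```python
import jax
import jax.numpy as jnp

def selu(x, alpha=1.67, lmbda=1.05):
    # A chain of elementwise operations; under jit, XLA is free to fuse them.
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

# jax.jit returns a compiled version of the same function.
selu_jit = jax.jit(selu)

# Hypothetical input, chosen only for illustration.
x = jnp.linspace(-3.0, 3.0, 1_000_000)

# The first call traces and compiles; later calls with the same
# shape and dtype reuse the compiled code.
y_jit = selu_jit(x).block_until_ready()
y_ref = selu(x).block_until_ready()
```

When timing the two versions, calling `block_until_ready()` matters because JAX dispatches work asynchronously; without it a timer may only measure the dispatch. The first `selu_jit` call also includes compilation time, so it is usually excluded from benchmarks.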
-
Updating model parameters
In this exercise, you will implement a full training step for the regression problem that you have been working on. 1. Instantiate your model parameters, `W` and `b`, and your data `x` and...
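A training step of this kind might look like the following minimal sketch. It assumes a linear model `x @ W + b`, a mean-squared-error loss, and plain gradient descent; the shapes, the synthetic data, the names `predict`, `mse_loss`, `train_step`, and the value of `LEARNING_RATE` are all assumptions made for illustration, not the exercise's reference solution.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # assumed value, for illustration only

def predict(params, x):
    W, b = params
    return x @ W + b

def mse_loss(params, x, y):
    # Mean squared error between predictions and targets.
    return jnp.mean((predict(params, x) - y) ** 2)

@jax.jit
def train_step(params, x, y):
    # Compute the loss and its gradient with respect to params in one pass.
    loss, grads = jax.value_and_grad(mse_loss)(params, x, y)
    # Gradient descent update applied to every leaf of the params tuple.
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - LEARNING_RATE * g, params, grads
    )
    return new_params, loss

# Synthetic regression data (hypothetical): y = x @ true_W + 0.3
x = jax.random.normal(jax.random.PRNGKey(0), (128, 3))
true_W = jnp.array([[2.0], [-1.0], [0.5]])
y = x @ true_W + 0.3

# Instantiate the parameters W and b at zero.
params = (jnp.zeros((3, 1)), jnp.zeros((1,)))

_, first_loss = train_step(params, x, y)
for _ in range(200):
    params, loss = train_step(params, x, y)
```

Keeping the parameters in a tuple and returning new values, rather than mutating them in place, matches JAX's functional style and lets the whole step be compiled with `jit`.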
-
Batched Expert Forward Pass with Einops
A naive implementation of an MoE layer might involve a loop over the experts. This is inefficient. A much better approach is to perform a single, batched matrix multiplication for all expert...
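To make the contrast concrete, here is a small sketch comparing a per-expert loop with a single batched contraction. It uses `jnp.einsum` rather than Einops so that it depends only on JAX; the same contraction should be expressible with Einops' `einsum` using a space-separated pattern. It assumes tokens have already been routed into equal-sized per-expert buffers and that each expert is a single linear layer; all shapes and function names are hypothetical.

```python
import jax
import jax.numpy as jnp

# Hypothetical sizes, chosen only for illustration.
num_experts, tokens_per_expert, d_model, d_hidden = 4, 8, 16, 32

k_x, k_w = jax.random.split(jax.random.PRNGKey(0))
# x: tokens already grouped by expert -> (experts, tokens, d_model)
x = jax.random.normal(k_x, (num_experts, tokens_per_expert, d_model))
# w: one weight matrix per expert -> (experts, d_model, d_hidden)
w = jax.random.normal(k_w, (num_experts, d_model, d_hidden))

def experts_loop(x, w):
    # Naive version: one matrix multiplication per expert, in a Python loop.
    return jnp.stack([x[e] @ w[e] for e in range(x.shape[0])])

def experts_batched(x, w):
    # Batched version: the expert axis e is kept, d is contracted.
    return jnp.einsum("etd,edh->eth", x, w)

out_loop = experts_loop(x, w)
out_batched = experts_batched(x, w)
```

Both functions return an array of shape `(num_experts, tokens_per_expert, d_hidden)`; the batched form hands the whole computation to the backend as one operation instead of issuing one matmul per expert.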