ML Katas

Batched Expert Forward Pass with Einops

Difficulty: medium (<1 hr). Tags: pytorch, einops, performance.
this year by E

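A naive implementation of a mixture-of-experts (MoE) layer might loop over the experts and issue one small matrix multiplication per expert. This is inefficient: every iteration adds Python and kernel-launch overhead, and the small multiplications leave the hardware underused. A much better approach is to perform a single batched matrix multiplication that covers all expert computations. einops is a good fit for expressing the reshaping this requires.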
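To make the contrast concrete, here is a minimal sketch on toy tensors. The sizes and the variable names `tokens`, `weights`, `looped`, and `batched` are assumptions chosen for illustration only; they are not part of the kata's starter code.

```python
import torch

torch.manual_seed(0)

# Hypothetical toy sizes, chosen only for illustration.
num_experts, tokens_per_expert, embedding_dim, hidden_dim = 4, 3, 5, 2

tokens = torch.randn(num_experts, tokens_per_expert, embedding_dim)
weights = torch.randn(num_experts, embedding_dim, hidden_dim)

# Loop version: one small matrix multiplication per expert.
looped = torch.stack([tokens[e] @ weights[e] for e in range(num_experts)])

# Batched version: torch.bmm treats the leading dimension as a batch of
# independent matrix products, so all experts are handled in one call.
batched = torch.bmm(tokens, weights)

print(batched.shape)  # torch.Size([4, 3, 2])
print(torch.allclose(looped, batched, atol=1e-5))
```

Both versions compute the same per-expert products; the batched one simply hands the whole stack to a single operation.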

Task:

Let's assume you have already dispatched your tokens and have a tensor of shape (num_experts, tokens_per_expert, embedding_dim). Your experts are simple linear layers.

Write a function batched_expert_forward(dispatched_tokens, expert_weights) that performs the forward pass for all experts simultaneously.

Given:

  • dispatched_tokens: A tensor of shape (num_experts, tokens_per_expert, embedding_dim).
  • expert_weights: A tensor representing the weights of all your linear experts, with a shape like (num_experts, embedding_dim, hidden_dim).

Your function should:

  1. Use einops to reshape the input tensors.
  2. Use torch.einsum or batched torch.matmul to perform the forward pass for all experts in a single operation.
  3. Return a tensor of expert outputs of shape (num_experts, tokens_per_expert, hidden_dim).
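One way the steps above could fit together is sketched below. This is only one possible approach, not the reference solution: the function name `batched_expert_forward_sketch` is hypothetical, and it assumes the inputs already have exactly the shapes listed under "Given", in which case no explicit reshape is needed before the contraction. `einops.einsum` with a pattern such as `"e t d, e d h -> e t h"` should be an equivalent way to write the same contraction.

```python
import torch


def batched_expert_forward_sketch(dispatched_tokens, expert_weights):
    """Hypothetical sketch: apply every expert's linear map in one einsum.

    dispatched_tokens: (num_experts, tokens_per_expert, embedding_dim)
    expert_weights:    (num_experts, embedding_dim, hidden_dim)
    returns:           (num_experts, tokens_per_expert, hidden_dim)
    """
    # "e" (expert) is kept as a shared batch axis, "d" (embedding) is
    # contracted, and "t" (token) and "h" (hidden) pass through to the output.
    return torch.einsum("etd,edh->eth", dispatched_tokens, expert_weights)


tokens = torch.randn(8, 128, 64)
weights = torch.randn(8, 64, 128)
out = batched_expert_forward_sketch(tokens, weights)
print(out.shape)  # torch.Size([8, 128, 128])
```

Naming the axes in the einsum string keeps the shape contract visible at the call site, which is the main reason to prefer it over a bare `torch.bmm` here.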

Verification:

Compare your batched implementation's output and performance against a naive loop-based implementation.
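For the performance half of the comparison, `torch.utils.benchmark.Timer` is an alternative to hand-rolled timing loops, since it takes care of warm-up for you. The sketch below assumes two stand-in functions, `looped` and `batched`, which are hypothetical placeholders for your own naive and batched implementations.

```python
import torch
from torch.utils import benchmark

tokens = torch.randn(8, 128, 64)
weights = torch.randn(8, 64, 128)


def looped(tokens, weights):
    # Placeholder for a naive implementation: one matmul per expert.
    return torch.stack([tokens[e] @ weights[e] for e in range(tokens.shape[0])])


def batched(tokens, weights):
    # Placeholder for a batched implementation: a single bmm call.
    return torch.bmm(tokens, weights)


results = {}
for name, fn in [("looped", looped), ("batched", batched)]:
    timer = benchmark.Timer(
        stmt="fn(tokens, weights)",
        globals={"fn": fn, "tokens": tokens, "weights": weights},
    )
    # timeit returns a Measurement; .mean is the mean time per call in seconds.
    results[name] = timer.timeit(100).mean
    print(f"{name}: {results[name] * 1e6:.1f} us per call")
```

The absolute numbers depend heavily on the device and tensor sizes, so treat them as a relative comparison on your own machine.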

import torch
from einops import rearrange
import time

# Your function here
def batched_expert_forward(dispatched_tokens, expert_weights):
    # ... your implementation ...
    pass

def naive_expert_forward(dispatched_tokens, experts):
    outputs = []
    for i in range(dispatched_tokens.shape[0]):  # iterate over the expert dimension
        outputs.append(experts[i](dispatched_tokens[i]))
    return torch.stack(outputs)

# --- Verification ---
num_experts = 8
tokens_per_expert = 128
embedding_dim = 64
hidden_dim = 128

dispatched_tokens = torch.randn(num_experts, tokens_per_expert, embedding_dim)

# Batched implementation setup
expert_weights = torch.randn(num_experts, embedding_dim, hidden_dim)

# Naive implementation setup
experts = [torch.nn.Linear(embedding_dim, hidden_dim, bias=False) for _ in range(num_experts)]
with torch.no_grad():
    for i in range(num_experts):
        experts[i].weight.copy_(expert_weights[i].T)  # nn.Linear stores weight as (out_features, in_features), hence .T

# 1. Check for correctness
batched_output = batched_expert_forward(dispatched_tokens, expert_weights)
naive_output = naive_expert_forward(dispatched_tokens, experts)

# float32 accumulation order can differ between the two paths, so allow a small absolute tolerance
assert torch.allclose(batched_output, naive_output, atol=1e-4)

# 2. Compare performance (under no_grad so neither path builds an autograd graph)
with torch.no_grad():
    start_time = time.perf_counter()
    for _ in range(100):
        batched_expert_forward(dispatched_tokens, expert_weights)
    print(f"Batched implementation took: {time.perf_counter() - start_time:.4f}s")

    start_time = time.perf_counter()
    for _ in range(100):
        naive_expert_forward(dispatched_tokens, experts)
    print(f"Naive implementation took: {time.perf_counter() - start_time:.4f}s")