Batched Expert Forward Pass with Einops
A naive implementation of an MoE layer might involve a loop over the experts. This is inefficient. A much better approach is to perform a single, batched matrix multiplication for all expert computations. einops is an excellent tool for this kind of reshaping.
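To see why batching helps, here is a small illustration (with hypothetical shapes) of a per-expert loop versus a single `torch.bmm` call over the expert dimension — the two produce the same result, but the batched version launches one fused kernel instead of one per expert:

```python
import torch

# Two equivalent ways to apply a separate weight matrix per expert.
x = torch.randn(8, 128, 64)   # (num_experts, tokens_per_expert, embedding_dim)
w = torch.randn(8, 64, 128)   # (num_experts, embedding_dim, hidden_dim)

# Loop: one matmul per expert, stacked afterwards.
looped = torch.stack([x[i] @ w[i] for i in range(8)])

# Batched: a single bmm over the leading expert dimension.
batched = torch.bmm(x, w)

# Same math, small float32 rounding differences at most.
assert torch.allclose(looped, batched, atol=1e-4)
```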
Task:
Let's assume you have already dispatched your tokens and have a tensor of shape (num_experts, tokens_per_expert, embedding_dim). Your experts are simple linear layers.
Write a function batched_expert_forward(dispatched_tokens, expert_weights) that performs the forward pass for all experts simultaneously.
Given:
- dispatched_tokens: A tensor of shape (num_experts, tokens_per_expert, embedding_dim).
- expert_weights: A tensor representing the weights of all your linear experts, with a shape like (num_experts, embedding_dim, hidden_dim).
Your function should:
- Use einops to reshape the input tensors.
- Use torch.einsum or batched torch.matmul to perform the forward pass for all experts in a single operation.
- Return a tensor of expert outputs of shape (num_experts, tokens_per_expert, hidden_dim).
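If you want something to check your approach against, one possible sketch is a single einsum over the expert dimension. Note that since the dispatched tensor already has the layout (num_experts, tokens_per_expert, embedding_dim), no einops rearrange is strictly required here — einsum alone suffices:

```python
import torch

def batched_expert_forward(dispatched_tokens, expert_weights):
    # e: experts, t: tokens per expert, d: embedding dim, h: hidden dim
    # Contracts over d for all experts in one fused operation.
    return torch.einsum("etd,edh->eth", dispatched_tokens, expert_weights)
```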
Verification:
Compare your batched implementation's output and performance against a naive loop-based implementation.
import torch
from einops import rearrange
import time
# Your function here
def batched_expert_forward(dispatched_tokens, expert_weights):
    # ... your implementation ...
    pass

def naive_expert_forward(dispatched_tokens, experts):
    outputs = []
    for i in range(dispatched_tokens.shape[0]):  # loop over experts
        outputs.append(experts[i](dispatched_tokens[i]))
    return torch.stack(outputs)
# --- Verification ---
num_experts = 8
tokens_per_expert = 128
embedding_dim = 64
hidden_dim = 128
dispatched_tokens = torch.randn(num_experts, tokens_per_expert, embedding_dim)
# Batched implementation setup
expert_weights = torch.randn(num_experts, embedding_dim, hidden_dim)
# Naive implementation setup
experts = [torch.nn.Linear(embedding_dim, hidden_dim, bias=False) for _ in range(num_experts)]
with torch.no_grad():
    for i in range(num_experts):
        experts[i].weight.copy_(expert_weights[i].T)  # .T because Linear stores weight as (out_features, in_features)
# 1. Check for correctness
batched_output = batched_expert_forward(dispatched_tokens, expert_weights)
naive_output = naive_expert_forward(dispatched_tokens, experts)
assert torch.allclose(batched_output, naive_output, atol=1e-5)  # small tolerance for float32 rounding
# 2. Compare performance
start_time = time.time()
for _ in range(100):
    batched_expert_forward(dispatched_tokens, expert_weights)
print(f"Batched implementation took: {time.time() - start_time:.4f}s")
start_time = time.time()
for _ in range(100):
    naive_expert_forward(dispatched_tokens, experts)
print(f"Naive implementation took: {time.time() - start_time:.4f}s")