ML Katas

Building a Simple Mixture of Experts (MoE) Layer

Difficulty: hard (>1 hr) · Tags: pytorch, deep-learning, moe
this year by E

Now, let's combine the concepts of dispatching and aggregating into a full, albeit simplified, torch.nn.Module for a Mixture of Experts layer. This layer will replace a standard feed-forward network in a Transformer block.

Task:

Create a SimpleMoELayer class that inherits from torch.nn.Module.

It should have:

  • An __init__ method that initializes:
    • A gating_network (torch.nn.Linear).
    • A torch.nn.ModuleList of experts (a plain Python list would not register the experts' parameters with the module). Each expert can be a simple two-layer MLP.
  • A forward method that:
    1. Takes an input tensor x of shape (batch_size, seq_len, embedding_dim).
    2. Uses the gating network to get top_k expert indices and weights for each token.
    3. Dispatches the tokens to the correct experts. A key challenge is to efficiently map the flattened tokens to their respective experts for computation.
    4. Processes the dispatched tokens through their assigned experts.
    5. Aggregates the expert outputs for each token as a sum weighted by that token's top_k gating weights.
    6. Returns the final tensor of shape (batch_size, seq_len, embedding_dim).
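One common way to handle the "efficient mapping" in step 3 is to sort all token-to-expert assignments by expert id, so that each expert's tokens form one contiguous chunk. The sketch below assumes `topk_idx` is the `(T, top_k)` index tensor returned by `router_probs.topk(top_k, dim=-1)` on the flattened tokens; the helper name `dispatch_by_expert` is hypothetical, not part of the kata or of PyTorch.

```python
import torch

def dispatch_by_expert(topk_idx, num_experts):
    """Hypothetical helper: group (token, slot) routing assignments by expert.

    topk_idx: (T, k) integer tensor of expert ids for T flattened tokens.
    Returns a list with one (token_rows, slots) pair per expert.
    """
    k = topk_idx.shape[1]
    flat_expert = topk_idx.reshape(-1)  # (T*k,) expert id of each assignment
    # A stable sort keeps token order inside each expert's group.
    _, order = torch.sort(flat_expert, stable=True)
    counts = torch.bincount(flat_expert, minlength=num_experts).tolist()
    token_rows = torch.div(order, k, rounding_mode="floor")  # source token of each assignment
    slots = order % k                                        # top-k slot it came from
    # Split the sorted assignments into one contiguous chunk per expert.
    return list(zip(token_rows.split(counts), slots.split(counts)))

# Three tokens, top_k = 2, four experts (expert 3 receives nothing).
topk_idx = torch.tensor([[0, 2], [2, 1], [0, 1]])
groups = dispatch_by_expert(topk_idx, num_experts=4)
rows, slots = groups[0]  # expert 0: rows [0, 2], slots [0, 0]
```

Each expert can then run once on `tokens[token_rows]`, and its outputs, scaled by `topk_weights[token_rows, slots]`, can be scattered back with `index_add`. Sorting once avoids scanning the whole index tensor separately for every expert.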

Bonus: Implement a load balancing loss, a crucial component for training real MoE models. The loss encourages the gating network to distribute tokens evenly across all experts. A common formulation is:

$$L_{lb} = \alpha \cdot \sum_{i=1}^{N} f_i P_i$$

where N is the number of experts, f_i is the fraction of tokens dispatched to expert i, P_i is the average router probability for expert i, and α is a scaling hyperparameter.
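The formula can be computed directly from the router probabilities and the top-k indices. This is a minimal sketch under two assumptions not fixed by the text: with top_k > 1, f_i counts each top-k slot as a separate assignment (so the f_i sum to 1), and α defaults to 0.01. The function name `load_balancing_loss` is hypothetical.

```python
import torch

def load_balancing_loss(router_probs, topk_idx, alpha=0.01):
    """alpha * sum_i f_i * P_i for router_probs (T, N) and topk_idx (T, k)."""
    num_experts = router_probs.shape[-1]
    # f_i: share of all T*k token-to-expert assignments sent to expert i.
    # (Assumption: each top-k slot counts separately, so sum_i f_i == 1.)
    counts = torch.bincount(topk_idx.reshape(-1), minlength=num_experts)
    f = counts.to(router_probs.dtype) / topk_idx.numel()
    # P_i: router probability for expert i, averaged over all T tokens.
    P = router_probs.mean(dim=0)
    # f is built from integer indices and carries no gradient;
    # the gating network is trained through P.
    return alpha * torch.sum(f * P)

# Perfectly balanced: 4 tokens, 4 experts, top_k = 1, uniform router probabilities.
balanced = load_balancing_loss(torch.full((4, 4), 0.25), torch.arange(4).unsqueeze(1))
# Collapsed: every token routed to expert 0 with probability 1.
collapsed_probs = torch.zeros(4, 4)
collapsed_probs[:, 0] = 1.0
collapsed = load_balancing_loss(collapsed_probs, torch.zeros(4, 1, dtype=torch.long))
```

With uniform routing the sum equals 1/N, so the balanced case gives α/4 = 0.0025, while collapsing onto a single expert gives α = 0.01. Some implementations additionally scale the sum by N so that the balanced value equals α.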

Verification:

import torch
import torch.nn as nn

class SimpleMoELayer(nn.Module):
    # ... your implementation ...
    pass

# --- Verification ---
batch_size = 4
seq_len = 16
embedding_dim = 32
num_experts = 8
top_k = 2

moe_layer = SimpleMoELayer(embedding_dim, num_experts, top_k)
x = torch.randn(batch_size, seq_len, embedding_dim)

output = moe_layer(x)

# 1. Check output shape
assert output.shape == (batch_size, seq_len, embedding_dim)

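# 2. Check for gradients
# Ensure that gradients flow to both the gating network and the experts.
output.sum().backward()
assert moe_layer.gating_network.weight.grad is not None
for expert in moe_layer.experts:
    # Experts are two-layer MLPs (e.g. nn.Sequential) with no single .weight,
    # so check every parameter. An expert that received no tokens would have
    # no gradient; with 64 tokens routed to 2 of 8 experts this is very unlikely.
    assert all(p.grad is not None for p in expert.parameters())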

print("Verification passed!")
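For reference, here is one possible minimal SimpleMoELayer that follows the six forward steps: softmax gating, top-k selection with renormalized weights, per-expert dispatch via `torch.where`, and weighted aggregation with `index_add`. It is a sketch, not a canonical solution. The optional `hidden_dim` argument and its 4x default, the GELU activation, and the renormalization of the top-k weights are all assumptions. Because each expert here is an `nn.Sequential`, a gradient check should iterate over `expert.parameters()` rather than read `expert.weight`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleMoELayer(nn.Module):
    """Minimal top-k MoE layer (a sketch, not a production implementation)."""

    def __init__(self, embedding_dim, num_experts, top_k, hidden_dim=None):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        hidden_dim = hidden_dim or 4 * embedding_dim  # assumed 4x expansion
        self.gating_network = nn.Linear(embedding_dim, num_experts)
        # nn.ModuleList (not a plain list) so the experts' parameters are registered.
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(embedding_dim, hidden_dim),
                nn.GELU(),
                nn.Linear(hidden_dim, embedding_dim),
            )
            for _ in range(num_experts)
        ])

    def forward(self, x):
        batch_size, seq_len, dim = x.shape
        tokens = x.reshape(-1, dim)                                      # (T, dim)
        router_probs = F.softmax(self.gating_network(tokens), dim=-1)    # (T, N)
        topk_weights, topk_idx = router_probs.topk(self.top_k, dim=-1)   # (T, k) each
        # Renormalize so each token's k selected weights sum to 1 (assumed choice).
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

        output = torch.zeros_like(tokens)
        for expert_id, expert in enumerate(self.experts):
            # Dispatch: every (token, slot) pair routed to this expert.
            token_rows, slots = torch.where(topk_idx == expert_id)
            if token_rows.numel() == 0:
                continue  # this expert received no tokens in this batch
            expert_out = expert(tokens[token_rows])                      # (n_e, dim)
            # Aggregate: scale by the gate weight and add back into each token's row.
            weighted = expert_out * topk_weights[token_rows, slots].unsqueeze(-1)
            output = output.index_add(0, token_rows, weighted)
        return output.reshape(batch_size, seq_len, dim)
```

Looping over experts keeps the dispatch logic easy to read at the cost of one small matrix multiply per expert. Gradients reach the gating network through `topk_weights`, which are differentiable slices of the softmax output.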