Building a Simple Mixture of Experts (MoE) Layer
Now, let's combine the concepts of token dispatching and output aggregation into a full, albeit simplified, torch.nn.Module for a Mixture of Experts layer. Such a layer replaces the standard feed-forward network in a Transformer block.
Task:
Create a SimpleMoELayer class that inherits from torch.nn.Module.
It should have:
- An __init__ method that initializes:
  - A gating_network (torch.nn.Linear).
  - A list or torch.nn.ModuleList of experts. Each expert can be a simple two-layer MLP.
- A forward method that:
  - Takes an input tensor x of shape (batch_size, seq_len, embedding_dim).
  - Uses the gating network to get top_k expert indices and weights for each token.
  - Dispatches the tokens to the correct experts. A key challenge is to efficiently map the flattened tokens to their respective experts for computation (one possible approach is sketched after this list).
  - Processes the dispatched tokens through their assigned experts.
  - Aggregates the outputs from the experts based on the gating weights.
  - Returns the final tensor of shape (batch_size, seq_len, embedding_dim).
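As a starting point, here is a minimal sketch of one way to structure such a layer. It assumes softmax routing with renormalized top-k weights and a plain Python loop over experts for dispatch; the class name, hidden_dim, and other details are illustrative choices, not part of the task specification.

import torch
import torch.nn as nn
import torch.nn.functional as F

class SketchMoELayer(nn.Module):
    """Illustrative sketch only; one possible shape for SimpleMoELayer."""
    def __init__(self, embedding_dim, num_experts, top_k, hidden_dim=None):
        super().__init__()
        hidden_dim = hidden_dim or 4 * embedding_dim
        self.top_k = top_k
        # Router: one logit per expert for each token.
        self.gating_network = nn.Linear(embedding_dim, num_experts)
        # Each expert is a simple two-layer MLP.
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(embedding_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, embedding_dim),
            )
            for _ in range(num_experts)
        ])

    def forward(self, x):
        batch_size, seq_len, embedding_dim = x.shape
        # Flatten to (num_tokens, embedding_dim) so routing is per token.
        tokens = x.reshape(-1, embedding_dim)
        logits = self.gating_network(tokens)              # (num_tokens, num_experts)
        probs = F.softmax(logits, dim=-1)
        top_weights, top_indices = probs.topk(self.top_k, dim=-1)
        # Renormalize so each token's selected weights sum to 1.
        top_weights = top_weights / top_weights.sum(dim=-1, keepdim=True)

        output = torch.zeros_like(tokens)
        # Dispatch: for each expert, gather the tokens routed to it, run them
        # through the expert, and scatter the weighted result back.
        for expert_idx, expert in enumerate(self.experts):
            token_idx, slot = torch.where(top_indices == expert_idx)
            if token_idx.numel() == 0:
                continue  # this expert received no tokens in this batch
            expert_out = expert(tokens[token_idx])
            output.index_add_(
                0, token_idx,
                expert_out * top_weights[token_idx, slot].unsqueeze(-1),
            )

        return output.reshape(batch_size, seq_len, embedding_dim)

Looping over experts is simple and correct but not the fastest option; production implementations typically sort or permute tokens by expert so each expert processes one contiguous batch.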
Bonus: Implement a load balancing loss, a crucial component for training real MoE models. The loss encourages the gating network to distribute tokens evenly across all experts. A common formulation is:
$$\mathcal{L}_{\text{aux}} = \alpha \cdot N \cdot \sum_{i=1}^{N} f_i \cdot P_i$$

where $N$ is the number of experts, $f_i$ is the fraction of tokens dispatched to expert $i$, $P_i$ is the average router probability for expert $i$, and $\alpha$ is a scaling hyperparameter.
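One hedged sketch of computing this auxiliary loss from the router outputs is shown below. It assumes probs is the full softmax distribution over experts for every flattened token, shape (num_tokens, num_experts), and top_indices holds the selected expert ids, shape (num_tokens, top_k); both names are illustrative and would come from inside your forward pass.

import torch
import torch.nn.functional as F

def load_balancing_loss(probs, top_indices, num_experts, alpha=0.01):
    # f_i: fraction of token-to-expert assignments routed to expert i.
    dispatch_mask = F.one_hot(top_indices, num_experts).float()  # (num_tokens, top_k, num_experts)
    f = dispatch_mask.sum(dim=(0, 1)) / dispatch_mask.sum()
    # P_i: average router probability assigned to expert i.
    P = probs.mean(dim=0)
    # alpha * N * sum_i f_i * P_i
    return alpha * num_experts * torch.sum(f * P)

The loss is minimized when both the dispatch fractions and the router probabilities are uniform across experts; a small alpha (on the order of 0.01) is typically used so the auxiliary term does not dominate the main training objective.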
Verification:
import torch
import torch.nn as nn
class SimpleMoELayer(nn.Module):
    # ... your implementation ...
    pass
# --- Verification ---
batch_size = 4
seq_len = 16
embedding_dim = 32
num_experts = 8
top_k = 2
moe_layer = SimpleMoELayer(embedding_dim, num_experts, top_k)
x = torch.randn(batch_size, seq_len, embedding_dim)
output = moe_layer(x)
# 1. Check output shape
assert output.shape == (batch_size, seq_len, embedding_dim)
# 2. Check for gradients
# Ensure that gradients flow to both the gating network and the experts.
output.sum().backward()
assert moe_layer.gating_network.weight.grad is not None
for expert in moe_layer.experts:
    # Each expert is an MLP, so check every parameter rather than a single
    # .weight attribute. An expert that received no tokens would have no
    # gradient, but with 64 tokens, 8 experts, and top_k=2 that is unlikely.
    for param in expert.parameters():
        assert param.grad is not None
print("Verification passed!")