Sparse MoE Top-K Gating
Description
In a Mixture of Experts (MoE) model, the gating network determines which 'expert' subnetworks process each token [1]. A common strategy is top-k gating: for each token, the gate selects the k experts with the highest scores. Your task is to implement this logic.
Guidance
This exercise focuses on tensor indexing and shaping. Your function will receive a batch of token representations and produce routing weights and indices for the top k experts for each token.
1. Input: A tensor of token embeddings with shape (B, N, D), where B is batch size, N is sequence length, and D is the model dimension.
2. Gating Logits: Use a linear layer to project the input tensor from dimension D to E, the number of experts, giving gating logits of shape (B, N, E).
3. Top-K Selection: For each token, find the k experts with the highest logits. torch.topk is the right tool for this (see the short sketch after this list).
4. Softmax: Apply a softmax over only the k selected logits (not all E logits) to get the final routing weights, so that each token's selected weights sum to 1.
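If torch.topk or taking a softmax over a subset of logits is unfamiliar, here is a tiny standalone sketch of how the two calls combine (the shapes are made up for illustration and are smaller than the exercise's):

import torch

# Tiny demo of torch.topk along the expert dimension
# (illustrative shapes only: B=2, N=3, E=5, k=2).
logits = torch.randn(2, 3, 5)
vals, idx = torch.topk(logits, k=2, dim=-1)   # both outputs have shape (2, 3, 2)
weights = torch.softmax(vals, dim=-1)         # softmax over only the 2 selected logits
print(vals.shape, idx.shape)                  # torch.Size([2, 3, 2]) torch.Size([2, 3, 2])
print(weights.sum(dim=-1))                    # all ones: weights sum to 1 per token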
Starter Code
import torch
import torch.nn as nn
from einops import rearrange
def top_k_gating(x, gate_projection, k):
"""
Args:
x (torch.Tensor): Input tensor of shape (B, N, D)
gate_projection (nn.Linear): A linear layer mapping D to E (num_experts)
k (int): The number of experts to select for each token
Returns:
torch.Tensor: The final routing weights for the selected experts, shape (B, N, k)
torch.Tensor: The indices of the selected experts, shape (B, N, k)
"""
# 1. Project the input to get gating logits
gating_logits = gate_projection(x)
# 2. Select the top k logits and their indices
# Hint: Use torch.topk along the last dimension (the experts)
    top_k_logits, top_k_indices = None, None  # TODO: your code here
# 3. Apply softmax to the selected logits to get weights
# Hint: The weights should sum to 1 for each token's selected experts
    top_k_weights = None  # TODO: your code here
return top_k_weights, top_k_indices
Verification
Create a dummy batch of tokens and a gate projection layer. Pass them through your function and verify the output shapes and values.
# Parameters
batch_size = 4
seq_len = 100
model_dim = 128
num_experts = 16
k = 2
# Dummy data
input_tokens = torch.randn(batch_size, seq_len, model_dim)
gate_proj = nn.Linear(model_dim, num_experts)
# Get routing info
weights, indices = top_k_gating(input_tokens, gate_proj, k)
# Verification checks
assert weights.shape == (batch_size, seq_len, k)
assert indices.shape == (batch_size, seq_len, k)
assert torch.all(indices >= 0) and torch.all(indices < num_experts)
# Check that weights for each token sum to 1
assert torch.allclose(weights.sum(dim=-1), torch.ones(batch_size, seq_len))
print("Verification successful!")
References
[1] Shazeer, N., et al. (2017). Outrageously Large Neural Networks: The Sparsely-Gated Mixture-of-Experts Layer. ICLR 2017.