MoE Gating and Dispatch
A core component of a Mixture-of-Experts (MoE) model is the gating network, which determines which expert(s) each token is sent to; this is typically a top-k selection. Your task is to implement this routing logic.
Task:
Implement a function route_tokens(tokens, gating_network, num_experts, top_k) that:
1. Takes a batch of token embeddings of shape (batch_size, seq_len, embedding_dim).
2. Uses a gating_network (which can be a simple torch.nn.Linear) to produce logits for each token over the num_experts.
3. Selects the top_k experts for each token based on these logits.
4. Returns a tuple of:
* top_k_indices: A tensor of shape (batch_size, seq_len, top_k) with the indices of the chosen experts for each token.
* top_k_weights: A tensor of shape (batch_size, seq_len, top_k) with the corresponding softmax-normalized weights for the chosen experts (each token's top_k weights should sum to 1).
* dispatched_tokens: A tensor that has been rearranged to group the token copies by expert. A suggested shape is (batch_size * seq_len * top_k, embedding_dim), since each token is duplicated once per chosen expert.
* expert_indices_for_dispatch: An indexing tensor of shape (batch_size * seq_len * top_k,) indicating which expert each row of dispatched_tokens belongs to.
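One possible implementation of the steps above is sketched below. It selects the top-k logits per token, softmaxes over only those k values (so each token's weights sum to 1, as the verification checks), duplicates each token `top_k` times with `repeat_interleave`, and then sorts the flattened copies by expert index to group them per expert. The grouping-by-sort step is one design choice among several; variable names are illustrative.

```python
import torch
import torch.nn.functional as F

def route_tokens(tokens, gating_network, num_experts, top_k):
    batch_size, seq_len, embedding_dim = tokens.shape

    # 1. Gating logits for every token: (batch, seq, num_experts)
    logits = gating_network(tokens)

    # 2. Top-k experts per token: both tensors are (batch, seq, top_k)
    top_k_logits, top_k_indices = logits.topk(top_k, dim=-1)

    # 3. Softmax over only the selected logits, so each token's
    #    top-k weights sum to 1
    top_k_weights = F.softmax(top_k_logits, dim=-1)

    # 4. Duplicate each token top_k times, once per chosen expert:
    #    (batch * seq, dim) -> (batch * seq * top_k, dim)
    flat_tokens = tokens.reshape(-1, embedding_dim)
    dispatched_tokens = flat_tokens.repeat_interleave(top_k, dim=0)

    # Expert assignment for each duplicated row, aligned with the
    # repeat_interleave order above
    expert_indices_for_dispatch = top_k_indices.reshape(-1)

    # 5. Sort the rows by expert so tokens destined for the same
    #    expert are contiguous (stable sort preserves token order
    #    within each expert's group)
    order = torch.argsort(expert_indices_for_dispatch, stable=True)
    dispatched_tokens = dispatched_tokens[order]
    expert_indices_for_dispatch = expert_indices_for_dispatch[order]

    return top_k_indices, top_k_weights, dispatched_tokens, expert_indices_for_dispatch
```

With the grouped layout, each expert's batch can be sliced out with a boolean mask, e.g. `dispatched_tokens[expert_indices_for_dispatch == e]`.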
Equations:
The gating logits for a token embedding x are typically computed as
    h(x) = x · W_g
where W_g is the weight matrix of the gating network. The routing weights are then a softmax over these logits:
    w(x) = softmax(x · W_g)
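In code, with a torch.nn.Linear standing in for W_g (note a Linear also adds a bias term by default), the logits and softmax weights look like this; the batch size of 4 is arbitrary:

```python
import torch
import torch.nn.functional as F

embedding_dim, num_experts = 32, 8
gating_network = torch.nn.Linear(embedding_dim, num_experts)  # W_g (+ bias)
x = torch.randn(4, embedding_dim)                             # token embeddings

h = gating_network(x)        # logits h(x): shape (4, num_experts)
w = F.softmax(h, dim=-1)     # weights over all experts; each row sums to 1
```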
Verification:
import torch
import torch.nn.functional as F
# Your function here
def route_tokens(tokens, gating_network, num_experts, top_k):
# ... your implementation ...
pass
# --- Verification ---
batch_size = 2
seq_len = 10
embedding_dim = 32
num_experts = 8
top_k = 2
tokens = torch.randn(batch_size, seq_len, embedding_dim)
gating_network = torch.nn.Linear(embedding_dim, num_experts)
top_k_indices, top_k_weights, dispatched_tokens, expert_indices_for_dispatch = route_tokens(tokens, gating_network, num_experts, top_k)
# 1. Check shapes
assert top_k_indices.shape == (batch_size, seq_len, top_k)
assert top_k_weights.shape == (batch_size, seq_len, top_k)
assert dispatched_tokens.shape == (batch_size * seq_len * top_k, embedding_dim)
assert expert_indices_for_dispatch.shape == (batch_size * seq_len * top_k,)
# 2. Check weight properties
# Weights for each token's top-k experts should sum to 1
assert torch.allclose(top_k_weights.sum(dim=-1), torch.ones(batch_size, seq_len))
print("Verification passed!")