ML Katas

MoE Gating and Dispatch

medium (<1 hr) pytorch moe tensor-manipulation
this year by E

A core component of a Mixture of Experts (MoE) model is the gating network, which determines which expert(s) each token is sent to. Routing is usually a top-k selection: each token goes to the k experts with the highest gating scores. Your task is to implement this routing logic.
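For a single token, top-k selection is just torch.topk over that token's expert logits. A toy example with made-up logits for four hypothetical experts:

```python
import torch

# Made-up gating logits for one token over 4 experts.
logits = torch.tensor([0.1, 2.0, -1.0, 1.5])

# torch.topk returns the k largest values and their positions, largest first.
values, indices = torch.topk(logits, k=2)

print(indices.tolist())  # [1, 3]
print(values.tolist())   # [2.0, 1.5]
```

In the batched task the same call is applied along the last (expert) dimension with dim=-1.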

Task:

Implement a function route_tokens(tokens, gating_network, num_experts, top_k) that:

1. Takes a batch of token embeddings of shape (batch_size, seq_len, embedding_dim).
2. Uses gating_network (a simple torch.nn.Linear is enough) to produce, for each token, logits over the num_experts experts.
3. Selects the top_k experts for each token based on these logits.
4. Returns a tuple of four tensors:
   * top_k_indices: shape (batch_size, seq_len, top_k), the indices of the experts chosen for each token.
   * top_k_weights: shape (batch_size, seq_len, top_k), the routing weights for the chosen experts, obtained by softmax-normalizing their logits so that each token's weights sum to 1.
   * dispatched_tokens: the token embeddings rearranged so that tokens routed to the same expert are grouped together. Each token is sent to top_k experts and therefore appears top_k times, giving a suggested shape of (batch_size * seq_len * top_k, embedding_dim).
   * expert_indices_for_dispatch: shape (batch_size * seq_len * top_k,), the index of the expert that each row of dispatched_tokens belongs to.
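The dispatch step in item 4 can be sketched on a tiny hypothetical example. This is one possible approach, assuming that grouping rows by a stable sort on expert id is acceptable; other layouts (for example, scattering into per-expert buffers) would also satisfy the task. Each token embedding here is [t, t] so rows are easy to trace.

```python
import torch

# Hypothetical routing result: batch_size=1, seq_len=3, top_k=2, embedding_dim=2.
tokens = torch.tensor([[[0., 0.], [1., 1.], [2., 2.]]])   # (1, 3, 2)
top_k_indices = torch.tensor([[[2, 0], [0, 1], [1, 2]]])  # (1, 3, 2)

batch_size, seq_len, embedding_dim = tokens.shape
top_k = top_k_indices.shape[-1]

# One copy of each token per chosen expert, in (batch, seq, k) order: (B*S*k, D).
flat_tokens = tokens.reshape(-1, embedding_dim).repeat_interleave(top_k, dim=0)
flat_experts = top_k_indices.reshape(-1)  # (B*S*k,), same ordering as flat_tokens

# Stable sort by expert id so each expert's tokens form one contiguous block.
expert_indices_for_dispatch, order = torch.sort(flat_experts, stable=True)
dispatched_tokens = flat_tokens[order]

# Which flattened (batch, seq) token each dispatched row came from.
source_positions = order // top_k

print(expert_indices_for_dispatch.tolist())  # [0, 0, 1, 1, 2, 2]
print(source_positions.tolist())             # [0, 1, 1, 2, 0, 2]
```

Keeping order (or source_positions) around is what later lets expert outputs be scattered back to their original token positions.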

Equations:

The gating logits are typically computed as:

$$\text{logits} = \text{tokens} \cdot W_g$$

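where $W_g$ is the weight matrix of the gating network (a torch.nn.Linear with its default bias=True also adds a bias term). The routing weights are then a softmax over the top_k selected logits only, so that each token's weights sum to 1:

$$w_i = \frac{\exp(\text{logits}_i)}{\sum_{j \in \mathcal{T}} \exp(\text{logits}_j)}, \quad i \in \mathcal{T}$$

where $\mathcal{T}$ is the set of top_k experts chosen for that token.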
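Taking a softmax over all num_experts logits and then renormalizing the top_k chosen probabilities gives exactly the same weights, because the shared softmax denominator cancels. A quick numerical check with random, hypothetical logits:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 8)  # (tokens, num_experts); sizes are arbitrary
top_k = 2

top_logits, top_idx = torch.topk(logits, k=top_k, dim=-1)

# Option A: softmax over only the selected top-k logits.
weights_a = F.softmax(top_logits, dim=-1)

# Option B: softmax over all experts, then keep and renormalize the top-k.
probs = F.softmax(logits, dim=-1)
picked = torch.gather(probs, -1, top_idx)
weights_b = picked / picked.sum(dim=-1, keepdim=True)

# Both options agree, and every row sums to 1.
assert torch.allclose(weights_a, weights_b, atol=1e-6)
assert torch.allclose(weights_a.sum(dim=-1), torch.ones(4))
```

A design that keeps the raw full-softmax probabilities of the chosen experts without renormalizing would not pass the sum-to-1 check in the verification below.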

Verification:

import torch
import torch.nn.functional as F

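# Your function here
def route_tokens(tokens, gating_network, num_experts, top_k):
    # ... your implementation ...
    # Must return (top_k_indices, top_k_weights, dispatched_tokens, expert_indices_for_dispatch)
    raise NotImplementedError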

# --- Verification ---
batch_size = 2
seq_len = 10
embedding_dim = 32
num_experts = 8
top_k = 2

tokens = torch.randn(batch_size, seq_len, embedding_dim)
gating_network = torch.nn.Linear(embedding_dim, num_experts)

top_k_indices, top_k_weights, dispatched_tokens, expert_indices_for_dispatch = route_tokens(tokens, gating_network, num_experts, top_k)

# 1. Check shapes
assert top_k_indices.shape == (batch_size, seq_len, top_k)
assert top_k_weights.shape == (batch_size, seq_len, top_k)
assert dispatched_tokens.shape == (batch_size * seq_len * top_k, embedding_dim)
assert expert_indices_for_dispatch.shape == (batch_size * seq_len * top_k,)

# 2. Check weight properties
# Weights for each token's top-k experts should sum to 1
assert torch.allclose(top_k_weights.sum(dim=-1), torch.ones(batch_size, seq_len))

print("Verification passed!")
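One way to fill in route_tokens so the verification above passes is sketched below. It assumes the weights are a softmax over the top_k selected logits and that grouping by a stable sort on expert id is an acceptable dispatch layout; it is not the only valid solution.

```python
import torch
import torch.nn.functional as F


def route_tokens(tokens, gating_network, num_experts, top_k):
    """Top-k routing sketch. Returns (top_k_indices, top_k_weights,
    dispatched_tokens, expert_indices_for_dispatch)."""
    batch_size, seq_len, embedding_dim = tokens.shape

    # Gating logits over experts: nn.Linear acts on the last dim -> (B, S, E).
    logits = gating_network(tokens)
    if logits.shape[-1] != num_experts:
        raise ValueError("gating_network must output num_experts logits")

    # Top-k experts per token, then softmax over just those k logits.
    top_k_logits, top_k_indices = torch.topk(logits, k=top_k, dim=-1)
    top_k_weights = F.softmax(top_k_logits, dim=-1)

    # Duplicate each token once per chosen expert: (B*S*k, D) and (B*S*k,).
    flat_tokens = tokens.reshape(-1, embedding_dim).repeat_interleave(top_k, dim=0)
    flat_experts = top_k_indices.reshape(-1)

    # Stable sort by expert id so each expert's tokens are contiguous.
    expert_indices_for_dispatch, order = torch.sort(flat_experts, stable=True)
    dispatched_tokens = flat_tokens[order]

    return top_k_indices, top_k_weights, dispatched_tokens, expert_indices_for_dispatch
```

Paste it in place of the stub and rerun the verification. The sketch applies no per-expert capacity limit, so every token copy is kept; dropping tokens for overloaded experts would change the dispatched shape.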