ML Katas

Sparse MoE Top-K Gating

easy (<10 mins) pytorch einops gating moe
this year by E

Description

In a Mixture of Experts (MoE) model, the gating network decides which 'expert' subnetworks process each token [1]. A common strategy is top-k gating: for each token, the gate scores every expert and keeps only the k experts with the highest scores. Your task is to implement this logic.
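Written as a formula, top-k gating for a single token x looks roughly like the following. This follows the KeepTopK formulation of Shazeer et al. [1] but omits their tunable noise term, so here the gate logits h are just the linear projection, with weight W_g and bias b_g (the names W_g and b_g are chosen for illustration):

```latex
\begin{aligned}
h &= x W_g + b_g \in \mathbb{R}^{E}, \\
\mathrm{KeepTopK}(h, k)_i &=
  \begin{cases}
    h_i & \text{if } h_i \text{ is among the top } k \text{ entries of } h, \\
    -\infty & \text{otherwise,}
  \end{cases} \\
G(x) &= \mathrm{Softmax}\big(\mathrm{KeepTopK}(h, k)\big).
\end{aligned}
```

Because e^{-\infty} = 0, the only nonzero entries of G(x) are a softmax over the k selected logits. This kata returns exactly those k values, plus their expert indices, in a compact (B, N, k) form.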

Guidance

This exercise focuses on tensor indexing and shaping. Your function receives a batch of token representations and returns the routing weights and indices of the top k experts for each token.

1. Input: a tensor of token embeddings with shape (B, N, D), where B is the batch size, N is the sequence length, and D is the model dimension.
2. Gating logits: use a linear layer to project the input from dimension D to E, the number of experts. The result is the gating logits, with shape (B, N, E).
3. Top-k selection: for each token, find the k experts with the highest logits. torch.topk along the last (expert) dimension does exactly this and returns both the selected values and their indices.
4. Softmax: apply a softmax over the k selected logits (not over all E experts) to get the final routing weights, so that each token's k weights sum to 1.
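To make steps 3 and 4 concrete, here is a small sketch on a hand-written logits tensor (the values are arbitrary and chosen only so the result is easy to check by hand): one sequence of two tokens routed over E = 4 experts with k = 2.

```python
import torch

# Hand-written gating logits, shape (B=1, N=2, E=4). Values are arbitrary.
gating_logits = torch.tensor([[[1.0, 3.0, 0.5, 2.0],
                               [0.0, -1.0, 4.0, 4.5]]])
k = 2

# Step 3: top-k along the expert dimension. By default torch.topk returns
# the values sorted in descending order, together with their indices.
top_k_logits, top_k_indices = torch.topk(gating_logits, k, dim=-1)

# Step 4: softmax over only the k selected logits of each token.
top_k_weights = torch.softmax(top_k_logits, dim=-1)

print(top_k_indices.tolist())  # [[[1, 3], [3, 2]]]
print(top_k_weights)
```

For the first token the selected logits are 3.0 and 2.0, so its weights are softmax([3, 2]), roughly 0.731 and 0.269. For the second token they are 4.5 and 4.0, giving roughly 0.622 and 0.378.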

Starter Code

import torch
import torch.nn as nn
from einops import rearrange

def top_k_gating(x, gate_projection, k):
    """
    Args:
        x (torch.Tensor): Input tensor of shape (B, N, D)
        gate_projection (nn.Linear): A linear layer mapping D to E (num_experts)
        k (int): The number of experts to select for each token

    Returns:
        torch.Tensor: The final routing weights for the selected experts, shape (B, N, k)
        torch.Tensor: The indices of the selected experts, shape (B, N, k)
    """
    # 1. Project the input to get gating logits
    gating_logits = gate_projection(x)

    # 2. Select the top k logits and their indices
    # Hint: Use torch.topk along the last dimension (the experts)
    top_k_logits, top_k_indices = ...  # your code here

    # 3. Apply softmax to the selected logits to get weights
    # Hint: The weights should sum to 1 for each token's selected experts
    top_k_weights = ...  # your code here

    return top_k_weights, top_k_indices
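The hint that the weights "should sum to 1 for each token's selected experts" is the compact form of the dense gate G(x) from Shazeer et al. [1]: keep the top-k logits, set every other expert's logit to -inf, and take a softmax over all E experts. The sketch below checks that the two forms agree on random logits. The shapes and seed are arbitrary choices for illustration, and it works on a raw logits tensor rather than calling your function.

```python
import torch

torch.manual_seed(0)
B, N, E, k = 2, 3, 8, 2
gating_logits = torch.randn(B, N, E)

# Compact form (what the kata returns): softmax over the k selected logits.
top_k_logits, top_k_indices = torch.topk(gating_logits, k, dim=-1)
top_k_weights = torch.softmax(top_k_logits, dim=-1)

# Dense form: start from all -inf, write back only the top-k logits,
# then take a softmax over all E experts. exp(-inf) = 0, so unselected
# experts get exactly zero weight.
masked = torch.full_like(gating_logits, float("-inf"))
masked = masked.scatter(-1, top_k_indices, top_k_logits)
dense_gates = torch.softmax(masked, dim=-1)  # shape (B, N, E)

# Scattering the compact weights into an all-zero (B, N, E) tensor
# reproduces the dense gates.
dense_from_compact = torch.zeros_like(gating_logits).scatter(
    -1, top_k_indices, top_k_weights
)
print(torch.allclose(dense_gates, dense_from_compact))  # True
```

The dense (B, N, E) form can be convenient for inspecting the router, at the cost of storing E values per token instead of k.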

Verification

Create a dummy batch of tokens and a gate projection layer. Pass them through your function and verify the output shapes and values.

# Parameters
batch_size = 4
seq_len = 100
model_dim = 128
num_experts = 16
k = 2

# Dummy data
input_tokens = torch.randn(batch_size, seq_len, model_dim)
gate_proj = nn.Linear(model_dim, num_experts)

# Get routing info
weights, indices = top_k_gating(input_tokens, gate_proj, k)

# Verification checks
assert weights.shape == (batch_size, seq_len, k)
assert indices.shape == (batch_size, seq_len, k)
assert torch.all(indices >= 0) and torch.all(indices < num_experts)
# Check that weights for each token sum to 1
assert torch.allclose(weights.sum(dim=-1), torch.ones(batch_size, seq_len))

print("Verification successful!")
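The checks above confirm shapes, index range, and normalization, but they would also pass for a function that picked the wrong experts. The stricter checks below compare the output against the full gating logits. To keep the block self-contained, it includes one possible completion of the starter code. This completion is an assumption for illustration, not an official solution. Exact ties between logits could make the argmax check ambiguous, but they are practically impossible with random float inputs.

```python
import torch
import torch.nn as nn

def top_k_gating(x, gate_projection, k):
    # One possible completion of the starter code (illustrative assumption).
    gating_logits = gate_projection(x)
    top_k_logits, top_k_indices = torch.topk(gating_logits, k, dim=-1)
    top_k_weights = torch.softmax(top_k_logits, dim=-1)
    return top_k_weights, top_k_indices

torch.manual_seed(0)
batch_size, seq_len, model_dim, num_experts, k = 4, 100, 128, 16, 2
input_tokens = torch.randn(batch_size, seq_len, model_dim)
gate_proj = nn.Linear(model_dim, num_experts)
weights, indices = top_k_gating(input_tokens, gate_proj, k)

# Recompute the full logits to compare against.
with torch.no_grad():
    logits = gate_proj(input_tokens)

# The first selected expert should be the argmax over all experts.
assert torch.equal(indices[..., 0], logits.argmax(dim=-1))

# torch.topk sorts by default, so weights should be non-increasing along k.
assert torch.all(weights[..., :-1] >= weights[..., 1:])

# No token should select the same expert twice.
sorted_idx = indices.sort(dim=-1).values
assert torch.all(sorted_idx[..., 1:] != sorted_idx[..., :-1])

# No unselected logit may exceed the smallest selected logit.
selected = logits.gather(-1, indices)
selected_mask = torch.zeros_like(logits).scatter(-1, indices, 1.0).bool()
unselected_max = logits.masked_fill(selected_mask, float("-inf")).amax(dim=-1)
assert torch.all(selected.min(dim=-1).values >= unselected_max)

print("Extra checks passed!")
```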

References

[1] Shazeer, N., et al. (2017). Outrageously Large Neural Networks: The Sparsely-Gated Mixture-of-Experts Layer. ICLR 2017.