MoE Gating: Top-K Selection
Description
In a Mixture of Experts (MoE) model, a gating network is responsible for routing each input token to a subset of 'expert' networks [1, 2]. A common strategy is Top-K gating, where the gating network outputs a score for each expert and the top 'k' experts are chosen for each token. Your task is to implement this Top-K selection mechanism.
Guidance
Given a batch of token embeddings and a gating network, your function should:
1. Compute Gating Scores: Pass the input tokens through the gating network (a simple linear layer in this case) to get the logits (scores) for each expert.
2. Select Top-K Experts: For each token, find the indices and values of the top 'k' scores. torch.topk is the perfect tool for this (see the short demo after this list).
3. Normalize the Weights: Apply a softmax over the selected top-k scores so that each token's routing weights sum to 1.
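For reference, here is a minimal, standalone demonstration of what torch.topk returns. The scores below are made-up values for a single token over four hypothetical experts:

import torch

# Hypothetical gating scores for one token over 4 experts
scores = torch.tensor([[0.1, 2.0, 0.5, 1.2]])

# Top-2 values and their expert indices, sorted in descending order
values, indices = torch.topk(scores, k=2, dim=-1)
print(values)   # tensor([[2.0000, 1.2000]])
print(indices)  # tensor([[1, 3]])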
Starter Code
import torch
import torch.nn.functional as F
def top_k_gating(token_embeddings, gating_network, k=2):
    # token_embeddings shape: (batch_size, sequence_length, embedding_dim)
    # gating_network: a linear layer mapping embedding_dim to num_experts

    # 1. Get the logits from the gating network
    #    The shape should be (batch_size * sequence_length, num_experts)
    # 2. Use torch.topk to get the indices and values of the top k experts for each token
    # 3. Apply softmax to the top-k values to get the weights
    # Return the weights and indices
    pass
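One way to fill in the stub, shown as a sketch that follows the numbered comments above (softmax is applied only over the k selected scores, so each token's weights sum to 1):

import torch
import torch.nn.functional as F

def top_k_gating(token_embeddings, gating_network, k=2):
    # Flatten (batch_size, seq_len, embed_dim) -> (batch_size * seq_len, embed_dim)
    # so each token is routed independently.
    embed_dim = token_embeddings.shape[-1]
    flat_tokens = token_embeddings.reshape(-1, embed_dim)

    # 1. Gating logits: (batch_size * seq_len, num_experts)
    logits = gating_network(flat_tokens)

    # 2. Top-k scores and the indices of the chosen experts for each token
    top_k_values, top_k_indices = torch.topk(logits, k, dim=-1)

    # 3. Softmax over only the k selected scores -> per-token weights summing to 1
    weights = F.softmax(top_k_values, dim=-1)

    return weights, top_k_indices

Note that the softmax here is taken over only the k selected logits, as the starter comments direct; some MoE implementations instead take a softmax over all expert logits and then pick the top k, in which case the selected weights no longer sum to 1 unless they are renormalized.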
Verification
Create a dummy batch of token embeddings and a dummy gating network. For an input of shape (10, 32, 128) (batch_size, seq_len, embed_dim), a gating network with 8 experts, and k=2, the output indices should have shape (320, 2) and the weights should also have shape (320, 2). The weights for each token should sum to 1.
# Verification code
batch_size, seq_len, embed_dim, num_experts = 10, 32, 128, 8
tokens = torch.randn(batch_size, seq_len, embed_dim)
gating = torch.nn.Linear(embed_dim, num_experts)
k = 2
weights, indices = top_k_gating(tokens, gating, k)
assert indices.shape == (batch_size * seq_len, k)
assert weights.shape == (batch_size * seq_len, k)
assert torch.allclose(weights.sum(dim=-1), torch.ones(batch_size * seq_len))
References
[1] Shazeer, N., et al. (2017). Outrageously Large Neural Networks: The Sparsely-Gated Mixture-of-Experts Layer. ICLR 2017.
[2] Fedus, W., et al. (2022). Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity. JMLR 2022.