ML Katas

MoE Gating: Top-K Selection

medium (<10 mins) pytorch deep learning MoE gating
this year by E

Description

In a Mixture of Experts (MoE) model, a gating network routes each input token to a subset of 'expert' networks [6, 14]. A common strategy is Top-K gating: the gating network outputs a score for each expert, and each token is sent to its k highest-scoring experts. Your task is to implement this Top-K selection mechanism.
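
To make the selection concrete, here is one token scored against four experts. The logit values are made up purely for illustration, and k = 2.

```python
import torch

# Hypothetical gating logits for ONE token over 4 experts (illustrative values only).
logits = torch.tensor([0.5, 2.0, -1.0, 1.0])

# Keep the k = 2 highest-scoring experts; topk returns them in descending order.
top_vals, top_idx = torch.topk(logits, k=2)
print(top_idx)   # tensor([1, 3])

# Normalize only the selected scores so the two routing weights sum to 1.
weights = torch.softmax(top_vals, dim=-1)
print(weights)   # approx. tensor([0.7311, 0.2689])
```

Experts 1 and 3 are selected; experts 0 and 2 get no weight for this token.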

Guidance

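Given a batch of token embeddings and a gating network, your function should:

1. Compute Gating Scores: flatten the tokens to shape (batch_size * sequence_length, embedding_dim) and pass them through the gating network (a simple linear layer in this case) to get one logit (score) per expert for each token.
2. Select Top-K Experts: for each token, find the values and indices of the k highest scores. torch.topk does exactly this.
3. Normalize the Weights: apply a softmax over the k selected scores only, so that each token's routing weights sum to 1.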
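
Step 3 takes the softmax over the k selected logits only. A quick sanity check, not part of the required function: this gives the same weights as a softmax over all experts followed by renormalizing the selected entries, because for the selected set S both reduce to w_i = exp(l_i) / sum over j in S of exp(l_j). The logits below are random stand-ins.

```python
import torch

torch.manual_seed(0)
num_tokens, num_experts, k = 5, 8, 2
logits = torch.randn(num_tokens, num_experts)  # random stand-in gating logits

# Option A: softmax over the k selected logits only (what step 3 describes).
top_vals, top_idx = torch.topk(logits, k, dim=-1)
w_topk = torch.softmax(top_vals, dim=-1)

# Option B: softmax over all experts, gather the selected probabilities, renormalize.
probs = torch.softmax(logits, dim=-1)
selected = probs.gather(-1, top_idx)
w_renorm = selected / selected.sum(dim=-1, keepdim=True)

print(torch.allclose(w_topk, w_renorm))  # True (equal up to float tolerance)
```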

Starter Code

import torch
import torch.nn.functional as F

def top_k_gating(token_embeddings, gating_network, k=2):
    # token_embeddings shape: (batch_size, sequence_length, embedding_dim)
    # gating_network: a linear layer mapping embedding_dim to num_experts

    # 1. Get the logits from the gating network
    # The shape should be (batch_size * sequence_length, num_experts)

    # 2. Use torch.topk to get the indices and values of the top k experts for each token

    # 3. Apply softmax to the top-k values to get the weights

    # Return the weights and indices
    pass
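
One possible way to fill in the stub, offered as a sketch rather than the kata's official solution. It flattens the tokens with reshape before the linear layer; since torch.nn.Linear also accepts extra leading dimensions, applying the gate first and flattening the logits afterwards would work equally well.

```python
import torch
import torch.nn.functional as F


def top_k_gating(token_embeddings, gating_network, k=2):
    """Return (weights, indices), each of shape (batch_size * sequence_length, k)."""
    # Flatten (batch_size, sequence_length, embedding_dim) into one row per token.
    embed_dim = token_embeddings.shape[-1]
    flat_tokens = token_embeddings.reshape(-1, embed_dim)

    # 1. Gating logits: (batch_size * sequence_length, num_experts).
    logits = gating_network(flat_tokens)

    # 2. Top-k scores and the expert indices they came from, per token (row).
    top_k_logits, top_k_indices = torch.topk(logits, k, dim=-1)

    # 3. Softmax over only the k selected logits, so each row sums to 1.
    weights = F.softmax(top_k_logits, dim=-1)

    return weights, top_k_indices


# Usage on the shapes from the Verification section.
tokens = torch.randn(10, 32, 128)
gating = torch.nn.Linear(128, 8)
weights, indices = top_k_gating(tokens, gating, k=2)
print(weights.shape, indices.shape)  # torch.Size([320, 2]) torch.Size([320, 2])
```

torch.topk returns its results sorted in descending order by default (sorted=True), so column 0 of indices holds each token's highest-scoring expert.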

Verification

Create a dummy batch of token embeddings and a dummy gating network. For an input of shape (batch_size, seq_len, embed_dim) = (10, 32, 128), a gating network with 8 experts, and k=2, both the output indices and the weights should have shape (320, 2), i.e. (batch_size * seq_len, k). The weights for each token should sum to 1.

# Verification code
batch_size, seq_len, embed_dim, num_experts = 10, 32, 128, 8
tokens = torch.randn(batch_size, seq_len, embed_dim)
gating = torch.nn.Linear(embed_dim, num_experts)
k = 2

weights, indices = top_k_gating(tokens, gating, k)

assert indices.shape == (batch_size * seq_len, k)
assert weights.shape == (batch_size * seq_len, k)
assert torch.allclose(weights.sum(dim=-1), torch.ones(batch_size * seq_len))

References

[1] Shazeer, N., et al. (2017). Outrageously Large Neural Networks: The Sparsely-Gated Mixture-of-Experts Layer. ICLR 2017.
[2] Fedus, W., et al. (2022). Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity. JMLR 2022.