ML Katas

Einops Warm-up: Reshaping Tensors for Expert Batching

easy (<30 mins) pytorch einops tensor-manipulation

In Mixture of Experts (MoE) models, we often need to reshape tensors to efficiently process data across multiple 'experts'. Imagine you have a batch of sequences, and for each token in each sequence, a gating network has picked an expert to send it to (e.g., the argmax of its routing probabilities). To perform a batched computation, it is efficient to group all tokens assigned to the same expert from across the entire batch.

Your task is to take a tensor of token embeddings and an assignment tensor, and use einops (together with standard PyTorch indexing) to create a new tensor where tokens are grouped by the expert they are assigned to.

Given:

  • tokens: A tensor of shape (batch_size, seq_len, embedding_dim) representing token embeddings.
  • expert_assignments: A tensor of shape (batch_size, seq_len) where each element is an integer from 0 to num_experts - 1 indicating which expert the token is assigned to.
  • num_experts: The total number of experts.

Task:

Write a function group_tokens_by_expert(tokens, expert_assignments, num_experts) that returns a tensor of shape (num_experts, batch_size * seq_len / num_experts, embedding_dim). For simplicity, assume the assignments are perfectly balanced, i.e., every expert receives exactly batch_size * seq_len / num_experts tokens (note that the total token count merely being divisible by num_experts is not enough for this shape to work).

Hint: rearrange from einops handles the reshaping, but grouping by a data-dependent key also requires reordering the tokens, e.g., with torch.argsort on the flattened assignments followed by indexing.
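If you want a nudge beyond the hint (spoiler!), here is one possible sketch, not the only valid solution. It assumes the balanced-assignment simplification from the task and uses the stable flag of torch.argsort (available in recent PyTorch versions) so that tokens keep their original order within each expert:

import torch
from einops import rearrange

def group_tokens_by_expert(tokens, expert_assignments, num_experts):
    # Flatten batch and sequence dims so every token is one row
    flat_tokens = rearrange(tokens, 'b s d -> (b s) d')
    flat_assignments = rearrange(expert_assignments, 'b s -> (b s)')
    # Reorder tokens by assigned expert; a stable sort preserves
    # the original order of tokens within each expert
    order = torch.argsort(flat_assignments, stable=True)
    sorted_tokens = flat_tokens[order]
    # With balanced assignments, equal-size consecutive chunks
    # belong to successive experts
    return rearrange(sorted_tokens, '(e n) d -> e n d', e=num_experts)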

Verification:

To check your implementation, you can create sample tensors and assert the shape of the output. A more robust check involves verifying that a specific token ends up in the correct expert's batch.

import torch
from einops import rearrange

# Your function here
def group_tokens_by_expert(tokens, expert_assignments, num_experts):
    # ... your implementation ...
    raise NotImplementedError

# --- Verification ---
batch_size = 2
seq_len = 4
embedding_dim = 8
num_experts = 2

tokens = torch.randn(batch_size, seq_len, embedding_dim)
# Simple, balanced assignment: expert 0 for the first half of each sequence, expert 1 for the second
expert_assignments = torch.tensor([[0, 0, 1, 1], [0, 0, 1, 1]])

# A more realistic, scattered assignment might be:
# expert_assignments = torch.randint(0, num_experts, (batch_size, seq_len))
# (beware: a random assignment is generally not balanced across experts,
# which breaks the fixed-shape assumption of this kata)

# To test your function:
batched_by_expert = group_tokens_by_expert(tokens, expert_assignments, num_experts)

# 1. Check the shape
expected_shape = (num_experts, (batch_size * seq_len) // num_experts, embedding_dim)
assert batched_by_expert.shape == expected_shape, f'Incorrect shape: {batched_by_expert.shape}'

# 2. Check a specific value (more advanced check)
# Manually constructing the expected output for a small example and comparing
# is the most reliable approach; a worked version of this check appears below.

print("Shape verification passed!")