Einops Warm-up: Reshaping Tensors for Expert Batching
In Mixture of Experts (MoE) models, we often need to reshape tensors to efficiently process data across multiple 'experts'. Imagine you have a batch of sequences, and for each token in each sequence, you have a probability distribution from a gating network indicating which expert to send it to. To perform a batched computation, it's efficient to group all tokens assigned to the same expert from across the entire batch.
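If you have not used einops before, the key operation here is rearrange, which reshapes tensors according to a readable pattern string. A minimal warm-up (the tensor x is just an illustrative example):

import torch
from einops import rearrange

x = torch.arange(24).reshape(2, 3, 4)                 # (batch, seq, dim)
flat = rearrange(x, 'b s d -> (b s) d')               # merge batch and seq: (6, 4)
restored = rearrange(flat, '(b s) d -> b s d', b=2)   # split them back: (2, 3, 4)
assert torch.equal(x, restored)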
Your task is to take a tensor of token embeddings and an assignment tensor, and use einops to create a new tensor where tokens are grouped by the expert they are assigned to.
Given:
- tokens: A tensor of shape (batch_size, seq_len, embedding_dim) representing token embeddings.
- expert_assignments: A tensor of shape (batch_size, seq_len) where each element is an integer from 0 to num_experts - 1 indicating which expert the token is assigned to.
- num_experts: The total number of experts.
Task:
Write a function group_tokens_by_expert(tokens, expert_assignments, num_experts) that returns a tensor of shape (num_experts, (batch_size * seq_len) // num_experts, embedding_dim). For simplicity, assume every expert is assigned exactly the same number of tokens, so the total token count divides evenly among the experts.
Hint: rearrange from einops handles the reshaping, but einops patterns alone cannot express a data-dependent grouping; you will also need an indexing step, such as torch.argsort on the flattened assignments, to bring tokens destined for the same expert together.
Verification:
To check your implementation, you can create sample tensors and assert the shape of the output. A more robust check involves verifying that a specific token ends up in the correct expert's batch.
import torch
from einops import rearrange
# Your function here
def group_tokens_by_expert(tokens, expert_assignments, num_experts):
    # ... your implementation ...
    raise NotImplementedError
# --- Verification ---
batch_size = 2
seq_len = 4
embedding_dim = 8
num_experts = 2
tokens = torch.randn(batch_size, seq_len, embedding_dim)
# Simple assignment: expert 0 for the first half of each sequence, expert 1 for the second half
expert_assignments = torch.tensor([[0, 0, 1, 1], [0, 0, 1, 1]])
# A more realistic, scattered assignment might be:
# expert_assignments = torch.randint(0, num_experts, (batch_size, seq_len))
# To test your function:
batched_by_expert = group_tokens_by_expert(tokens, expert_assignments, num_experts)
# 1. Check the shape
expected_shape = (num_experts, (batch_size * seq_len) // num_experts, embedding_dim)
assert batched_by_expert.shape == expected_shape, f'Incorrect shape: {batched_by_expert.shape}'
# 2. Check a specific value (more advanced check)
# A good approach is to manually construct the expected output for a small
# example and compare; one such check is included in the reference sketch below.
print("Shape verification passed!")