MoE Aggregator: Combining Expert Outputs
After tokens have been dispatched to and processed by their respective experts, the outputs need to be combined based on the weights from the gating network. This exercise focuses on this 'aggregation' step.
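In symbols, and assuming the standard top-k MoE formulation (the notation below is introduced here for illustration and does not come from the exercise), the combined output for the token at batch position b and sequence position s can be written as:

```latex
y_{b,s} = \sum_{k=1}^{K} w_{b,s,k} \, o_{b,s,k}
```

Here K is top_k, w_{b,s,k} is top_k_weights[b, s, k], and o_{b,s,k} is the output that the expert top_k_indices[b, s, k] produced for that token.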
Task:
Write a function aggregate_expert_outputs(expert_outputs, top_k_indices, top_k_weights) that correctly combines the outputs from the experts.
Given:
expert_outputs: A tensor of shape (batch_size * seq_len * top_k, embedding_dim). This is the result of processing the dispatched_tokens from the previous exercise through the experts.
top_k_indices: The tensor of expert indices, of shape (batch_size, seq_len, top_k).
top_k_weights: The tensor of expert weights, of shape (batch_size, seq_len, top_k).
Your function should return a final combined tensor of shape (batch_size, seq_len, embedding_dim).
Hint: This is a challenging tensor manipulation problem. You'll likely need to use einops.rearrange and potentially torch.zeros to create an intermediate tensor to scatter the weighted outputs into before summing.
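One possible shape for the solution is sketched below. It rests on an assumption the exercise does not state: that the rows of expert_outputs are in token-major order, so that row (b * seq_len + s) * top_k + k holds the output for slot k of token (b, s). The name aggregate_expert_outputs_sketch is hypothetical. If the dispatch step in the previous exercise grouped rows by expert instead, the rows would first have to be permuted back into this order, which is where top_k_indices and a torch.zeros buffer would come in.

```python
import torch

def aggregate_expert_outputs_sketch(expert_outputs, top_k_indices, top_k_weights):
    """Weighted sum over the top_k axis, ASSUMING token-major row order.

    top_k_indices is kept only for signature compatibility; under the
    token-major assumption it is not needed to locate rows.
    """
    batch_size, seq_len, top_k = top_k_weights.shape
    embedding_dim = expert_outputs.shape[-1]

    # Un-flatten rows into (batch, seq, top_k, dim). This is equivalent to
    # einops.rearrange(expert_outputs, "(b s k) d -> b s k d", b=..., s=..., k=...).
    per_slot = expert_outputs.reshape(batch_size, seq_len, top_k, embedding_dim)

    # Broadcast each weight over the embedding axis, then sum out top_k.
    weighted = per_slot * top_k_weights.unsqueeze(-1)
    return weighted.sum(dim=2)
```

The reshape does the index bookkeeping implicitly, so no explicit loop over tokens is required under this layout.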
Verification:
import torch
from einops import rearrange
# Your function here
def aggregate_expert_outputs(expert_outputs, top_k_indices, top_k_weights):
    # ... your implementation ...
    pass
# --- Verification ---
batch_size = 2
seq_len = 3
embedding_dim = 4
num_experts = 5
top_k = 2
# Dummy data that's easy to trace
expert_outputs = torch.arange(batch_size * seq_len * top_k * embedding_dim, dtype=torch.float32).reshape(batch_size * seq_len * top_k, embedding_dim)
top_k_indices = torch.randint(0, num_experts, (batch_size, seq_len, top_k))
top_k_weights = torch.rand(batch_size, seq_len, top_k)
top_k_weights = top_k_weights / top_k_weights.sum(dim=-1, keepdim=True) # Normalize
final_output = aggregate_expert_outputs(expert_outputs, top_k_indices, top_k_weights)
# 1. Check output shape
assert final_output.shape == (batch_size, seq_len, embedding_dim)
# 2. Manual calculation for a single token
# For a single token (e.g., at batch 0, seq 0), the output should be:
# output = expert_outputs[token_index_for_expert_A] * weight_A + expert_outputs[token_index_for_expert_B] * weight_B
# The main challenge is finding the correct token_index in the flattened expert_outputs.
# A correct implementation will pass this manual check for a few selected indices.
print("Shape verification passed!")
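The manual check described in the comments above could be made concrete along the following lines. This is a sketch under the same unstated assumption of token-major row order in expert_outputs, and expected_token_output is a hypothetical helper name, not part of the exercise.

```python
import torch

def expected_token_output(expert_outputs, top_k_weights, b, s):
    """Loop-based reference for one token, ASSUMING token-major row order."""
    _, seq_len, top_k = top_k_weights.shape
    # First row of expert_outputs that belongs to token (b, s).
    base = (b * seq_len + s) * top_k
    out = torch.zeros(expert_outputs.shape[-1], dtype=expert_outputs.dtype)
    for k in range(top_k):
        out += top_k_weights[b, s, k] * expert_outputs[base + k]
    return out

# Possible use with the verification script's variables:
# assert torch.allclose(
#     final_output[0, 0],
#     expected_token_output(expert_outputs, top_k_weights, 0, 0),
# )
```

Because the reference uses an explicit loop, it is slow but easy to trace by hand, which makes it a reasonable cross-check for a vectorized implementation.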