ML Katas

Multi-Head Attention: Merging Heads

medium (<10 mins) transformer attention einops

Description

This is the inverse of splitting heads. After computing attention independently for each head, you need to merge the per-head outputs back together. Given a tensor of shape (B, H, N, D//H), merge it back to shape (B, N, D).

Starter Code

import torch
from einops import rearrange

def merge_heads(tensor):
    """Merge a (B, H, N, D//H) tensor of per-head outputs into (B, N, D)."""
    # Your einops code here
    pass

Verification

For an input of shape (16, 12, 100, 64), the output should have the shape (16, 100, 768).