ML Katas

Multi-Head Attention: Splitting Heads

medium (<10 mins) · transformer · attention · einops

Description

In multi-head attention, the query, key, and value tensors are split into multiple heads. Given a tensor of shape (B, N, D), where B is the batch size, N is the sequence length, and D is the embedding dimension, split it into shape (B, H, N, D//H), where H is the number of heads. In other words, factor the last dimension into H chunks of size D//H, then move the head axis in front of the sequence axis.

Starter Code

import torch
from einops import rearrange

def split_heads(tensor, num_heads):
    # Rearrange (B, N, D) -> (B, H, N, D//H), where H = num_heads
    # Your einops code here
    pass

Verification

For an input of shape (16, 100, 768) and num_heads of 12, the output should have the shape (16, 12, 100, 64).
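
Reference Solution

One way to solve it, as a sketch: a single rearrange pattern factors the last axis into (h, d) groups and moves the head axis forward in one step.

import torch
from einops import rearrange

def split_heads(tensor, num_heads):
    # (B, N, D) -> (B, H, N, D//H): factor D into num_heads groups of
    # size D // num_heads, then place the head axis before the sequence axis.
    return rearrange(tensor, "b n (h d) -> b h n d", h=num_heads)

# Quick check against the verification case above:
x = torch.randn(16, 100, 768)
assert split_heads(x, 12).shape == (16, 12, 100, 64)

The same result can be achieved with tensor.view followed by transpose, but the einops pattern keeps the axis bookkeeping explicit in one readable string.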