Multi-Head Attention: Splitting Heads
Description
In multi-head attention, the query, key, and value tensors are split into multiple heads. Given a tensor of shape (B, N, D), where B is the batch size, N is the sequence length, and D is the embedding dimension, you need to split it into shape (B, H, N, D//H), where H is the number of heads (H must divide D evenly).
Starter Code
import torch
from einops import rearrange
def split_heads(tensor, num_heads):
    # Your einops code here
    pass
Verification
For an input of shape (16, 100, 768) and num_heads of 12, the output should have the shape (16, 12, 100, 64).