ML Katas

Einops: Multi-Head Attention Input Projection

easy (<10 mins) transformer attention einops
today by E

Description

In the Multi-Head Attention mechanism, an input tensor of shape (B, N, D) (batch size, sequence length, embedding dimension) is linearly projected to create the Query, Key, and Value matrices. These are then reshaped so that each has a separate "heads" dimension. Your task is to perform this projection and reshaping with einops, so that every axis in the reshape is named explicitly.
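The three projections can share one matrix multiplication because stacking their weights is equivalent to applying them separately. Below is a minimal sketch, assuming the combined weight is the row-wise stack [W_q; W_k; W_v], which is consistent with the (3*D, D) shape used in the starter code. The names w_q, w_k, w_v and the small sizes are illustrative only.

```python
import torch

torch.manual_seed(0)
B, N, D = 2, 5, 8
x = torch.randn(B, N, D)

# Hypothetical separate projection weights, each of shape (D, D)
w_q, w_k, w_v = torch.randn(D, D), torch.randn(D, D), torch.randn(D, D)

# Assumed layout: stack the three weights row-wise into one (3*D, D) matrix
proj_weight = torch.cat([w_q, w_k, w_v], dim=0)

# One matmul produces Q, K and V side by side along the last axis
x_proj = x @ proj_weight.T         # (B, N, 3*D)
q, k, v = x_proj.chunk(3, dim=-1)  # each (B, N, D)

# Each chunk matches the corresponding separate projection
print(torch.allclose(q, x @ w_q.T, atol=1e-5))
```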

Guidance

  1. Project the input tensor x with a standard matrix multiplication against a combined QKV projection weight of shape (3*D, D).
  2. Use einops.rearrange to split the last dimension of the result into three parts (for Q, K, and V) and, in the same step, create the num_heads dimension.

Starter Code

import torch
from einops import rearrange

def project_qkv(x, proj_weight, num_heads):
    # x shape: (B, N, D)
    # proj_weight shape: (3*D, D)

    # 1. Project the input
    x_proj = x @ proj_weight.T

    # 2. Rearrange into Q, K, V with multiple heads
    # Target shape: (3, B, num_heads, N, head_dim). einops infers
    # head_dim = D // num_heads and raises an error if D is not divisible.
    qkv = rearrange(x_proj, 'b n (three h d) -> three b h n d', three=3, h=num_heads)
    return qkv

Verification

For an input x of shape (10, 196, 768) and num_heads=12, the head_dim will be 768 // 12 = 64. The output tensor from your function should have the shape (3, 10, 12, 196, 64).