Einops: Multi-Head Attention Input Projection
Description
In the Multi-Head Attention mechanism, the input tensor of shape (B, N, D) is linearly projected to create the Query, Key, and Value matrices. These are then reshaped to have separate "heads". Your task is to perform this projection and reshaping using einops for maximum clarity.
Guidance
- Perform a standard matrix multiplication to project your input tensor x with a combined QKV projection weight.
- Use einops.rearrange to split the last dimension of the result into three parts (for Q, K, and V) and simultaneously create the num_heads dimension (see the sketch after this list).
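If the pattern syntax is new, here is a minimal sketch of the same axis split on a made-up tensor (the sizes are purely illustrative):

import torch
from einops import rearrange

# A last dimension of size 24 = 3 * 2 * 4 is factored into (three, h, d)
# and the "three" axis is moved to the front, just like in the exercise.
t = torch.arange(2 * 5 * 24).float().reshape(2, 5, 24)  # (b=2, n=5, 24)
out = rearrange(t, 'b n (three h d) -> three b h n d', three=3, h=2)
print(out.shape)  # torch.Size([3, 2, 2, 5, 4])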
Starter Code
import torch
from einops import rearrange

def project_qkv(x, proj_weight, num_heads):
    # x shape: (B, N, D)
    # proj_weight shape: (3*D, D)
    # 1. Project the input: (B, N, D) @ (D, 3*D) -> (B, N, 3*D)
    x_proj = x @ proj_weight.T
    # 2. Rearrange into Q, K, V with multiple heads.
    # The target shape is (3, B, num_heads, N, head_dim).
    qkv = rearrange(x_proj, 'b n (three h d) -> three b h n d', three=3, h=num_heads)
    return qkv
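Once the stacked tensor is built, a common follow-up (not required by this exercise) is to unpack it; assuming qkv comes from project_qkv above:

# Split along the leading axis of size 3; each of q, k, v then has
# shape (B, num_heads, N, head_dim), ready for attention.
q, k, v = qkv.unbind(0)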
Verification
For an input x of shape (10, 196, 768) and num_heads=12, the head_dim will be 768 // 12 = 64. The output tensor from your function should have the shape (3, 10, 12, 196, 64).
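A quick self-check along these lines (random values, shapes only; the variable names are just for the test):

import torch

B, N, D, num_heads = 10, 196, 768, 12
x = torch.randn(B, N, D)
proj_weight = torch.randn(3 * D, D)

qkv = project_qkv(x, proj_weight, num_heads)
assert qkv.shape == (3, B, num_heads, N, D // num_heads)
print(qkv.shape)  # torch.Size([3, 10, 12, 196, 64])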