ML Katas

Einops: Transpose for Attention Output

Difficulty: easy (<10 mins) · Tags: transformer, attention, einops

Description

After the multi-head attention computation, the output tensor typically has shape (B, num_heads, N, head_dim), where B is the batch size and N is the sequence length. Before it can be passed to the next layer (typically the output projection, followed by a feed-forward network), it must be reshaped back to (B, N, D), where D = num_heads * head_dim. Your task is to perform this reshaping.

Guidance

Use einops.rearrange. The pattern should move the num_heads axis next to head_dim and merge the two into a single dimension D at the end. Note that this is not a plain reshape: the heads axis must first be transposed next to head_dim before the two can be combined.
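
For intuition, here is the same operation expressed in plain PyTorch ops. This is a minimal sketch for illustration only (the helper name reformat_attention_output_manual is not part of the kata):

import torch

def reformat_attention_output_manual(x):
    # x: (B, num_heads, N, head_dim)
    B, H, N, Dh = x.shape
    # Transpose the heads axis next to head_dim: (B, N, num_heads, head_dim)
    x = x.permute(0, 2, 1, 3)
    # Merge the last two axes into D = num_heads * head_dim; reshape copies
    # if the permuted tensor is not contiguous in memory
    return x.reshape(B, N, H * Dh)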

Starter Code

import torch
from einops import rearrange

def reformat_attention_output(x):
    # x: (B, num_heads, N, head_dim)
    # Target shape: (B, N, D), with D = num_heads * head_dim
    # 'h' is moved next to 'd', then the two axes are merged via (h d)
    output = rearrange(x, 'b h n d -> b n (h d)')
    return output

Verification

Create a random tensor of shape (10, 12, 196, 64). After passing it through your function, the output shape should be (10, 196, 768), since D = 12 * 64 = 768.
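
In code, the check described above looks like this (a short sketch, reusing the function from the starter code):

x = torch.randn(10, 12, 196, 64)   # (B, num_heads, N, head_dim)
out = reformat_attention_output(x)
assert out.shape == (10, 196, 768)
print(out.shape)  # torch.Size([10, 196, 768])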