Einops: Transpose for Attention Output
Description
After the multi-head attention calculation, the output tensor typically has the shape (B, num_heads, N, head_dim). To feed it into the next layer (usually a feed-forward network), it must be reshaped back to (B, N, D), where D = num_heads * head_dim. Your task is to perform this reshaping.
Guidance
Use einops.rearrange. The pattern should move num_heads next to head_dim (swapping it past the sequence axis N) and then merge the two into a single new dimension D at the end.
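For intuition, the same operation can be written in plain PyTorch as a transpose followed by a reshape. A minimal sketch (the function name here is illustrative, not part of the exercise):

import torch

def reformat_attention_output_torch(x):
    # x: (B, num_heads, N, head_dim)
    B, num_heads, N, head_dim = x.shape
    # Swap the num_heads and N axes: (B, N, num_heads, head_dim)
    x = x.transpose(1, 2)
    # Merge the last two axes into D = num_heads * head_dim;
    # .contiguous() is required before .view() after a transpose
    return x.contiguous().view(B, N, num_heads * head_dim)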
Starter Code
import torch
from einops import rearrange

def reformat_attention_output(x):
    # x: (B, num_heads, N, head_dim)
    # Target shape: (B, N, D)
    # 'h' is moved past 'n', then merged with 'd' into the last axis
    output = rearrange(x, 'b h n d -> b n (h d)')
    return output
Verification
Create a random tensor of shape (10, 12, 196, 64). After passing it through your function, the output shape should be (10, 196, 768).
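One way to run this check, reusing the starter function (the expected last dimension follows from 12 * 64 = 768):

x = torch.randn(10, 12, 196, 64)  # (B=10, num_heads=12, N=196, head_dim=64)
out = reformat_attention_output(x)
assert out.shape == (10, 196, 768), out.shape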