Multi-Head Attention: Merging Heads
Description
This is the inverse of splitting heads: after computing attention for each head, the per-head outputs must be merged back into a single tensor. Given a tensor of shape (B, H, N, D//H), merge it back to shape (B, N, D), where D = H * (D//H).
Starter Code
import torch
from einops import rearrange

def merge_heads(tensor):
    # Your einops code here
    pass
Verification
For an input of shape (16, 12, 100, 64), the output should have the shape (16, 100, 768).