Bilinear Attention Pooling
Description
Some attention mechanisms need pairwise interaction scores between two sets of feature vectors. Given a query tensor of shape (B, N, D) and a key tensor of shape (B, M, D), compute an attention map of shape (B, N, M) whose entry [b, n, m] is the dot product of query[b, n, :] and key[b, m, :]. torch.einsum is well suited to this.
Starter Code
import torch

def bilinear_attention(query, key):
    # Your einsum code here
    pass
Verification
For inputs of shapes (16, 100, 128) and (16, 50, 128), the output should have the shape (16, 100, 50).
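One possible way to run this check is sketched below. The function name bilinear_attention_sketch and the einsum subscripts "bnd,bmd->bnm" are assumptions made for illustration, not the exercise's official solution; the comparison against torch.bmm is an extra sanity check that goes beyond the shape test stated above.

```python
import torch


def bilinear_attention_sketch(query, key):
    # Hypothetical reference solution (an assumption, not the official answer).
    # query: (B, N, D), key: (B, M, D) -> scores: (B, N, M)
    # The subscripts sum over the shared feature axis d for every (n, m) pair.
    return torch.einsum("bnd,bmd->bnm", query, key)


query = torch.randn(16, 100, 128)
key = torch.randn(16, 50, 128)
scores = bilinear_attention_sketch(query, key)

# Shape check from the Verification section.
assert scores.shape == (16, 100, 50)

# Value check: the same result via batched matrix multiplication.
reference = torch.bmm(query, key.transpose(1, 2))
assert torch.allclose(scores, reference, rtol=1e-4, atol=1e-4)
```

A shape check alone would also pass for a function that returns zeros of the right size, so comparing values against an independent formulation catches mistakes in the subscripts.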