ML Katas

Bilinear Attention Pooling

medium (<10 mins) · attention · tensor manipulation · einsum

Description

In some attention mechanisms, you need to compute a bilinear interaction between two sets of features. Given a query tensor of shape (B, N, D) and a key tensor of shape (B, M, D), compute a bilinear attention map of shape (B, N, M) in which entry (b, n, m) is the dot product of the n-th query vector and the m-th key vector: out[b, n, m] = sum_d query[b, n, d] * key[b, m, d]. torch.einsum is well-suited for this.

Starter Code

import torch

def bilinear_attention(query, key):
    """query: (B, N, D), key: (B, M, D) -> (B, N, M) attention map."""
    # Your einsum code here
    pass
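
A minimal sketch of a solution, contracting the shared feature dimension D with einsum:

import torch

def bilinear_attention(query, key):
    # query: (B, N, D), key: (B, M, D)
    # 'bnd,bmd->bnm' sums over d, pairing each query vector with
    # each key vector within the same batch element.
    return torch.einsum('bnd,bmd->bnm', query, key)

This is equivalent to torch.bmm(query, key.transpose(1, 2)); einsum just makes the index bookkeeping explicit.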

Verification

For inputs of shapes (16, 100, 128) and (16, 50, 128), the output should have the shape (16, 100, 50).
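
A quick self-check along those lines, using random inputs (assumes the sketch above is in scope):

import torch

query = torch.randn(16, 100, 128)
key = torch.randn(16, 50, 128)
out = bilinear_attention(query, key)
assert out.shape == (16, 100, 50)

# Spot-check one entry against an explicit dot product.
assert torch.allclose(out[0, 0, 0], query[0, 0] @ key[0, 0])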