ML Katas

Einops: Batched Matrix Multiplication

easy (<10 mins) pytorch einops tensor
today by E

Description

Perform a batched matrix multiplication, (B, N, D) @ (B, D, M) -> (B, N, M), using einops.einsum. torch.bmm is the standard way to do this, but rewriting it with einsum is a good exercise for understanding how einsum notation works.
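In index notation, with input X of shape (B, N, D), input Y of shape (B, D, M), and output Z of shape (B, N, M), each output entry is computed as shown below. The names X, Y and Z are introduced here only for illustration; each index corresponds to one axis name in the einsum pattern.

```latex
Z_{bnm} = \sum_{d=1}^{D} X_{bnd} \, Y_{bdm},
\qquad b = 1,\dots,B,\quad n = 1,\dots,N,\quad m = 1,\dots,M
```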

Guidance

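Use einops.einsum, which takes the input tensors first and the pattern string last. The pattern for batched matrix multiplication reads directly as the operation: the batch axis b appears in both inputs and in the output, so it is kept; the shared axis d appears in both inputs but not in the output, so it is summed over; and the remaining axes n and m form the output.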
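To make the pattern concrete, here is a minimal sketch that spells out 'b n d, b d m -> b n m' with explicit loops. It is a slow reference implementation for building intuition only; the function name batched_matmul_loops is hypothetical and not part of the kata.

```python
import torch


def batched_matmul_loops(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Reference version of 'b n d, b d m -> b n m' using explicit loops."""
    B, N, D = a.shape
    B2, D2, M = b.shape
    # Batch sizes must be equal and the inner (summed) dimensions must agree.
    assert B == B2 and D == D2, "incompatible shapes"
    out = torch.zeros(B, N, M, dtype=a.dtype, device=a.device)
    for batch in range(B):      # axis b: present in the output, so it is looped over
        for n in range(N):      # axis n: present in the output
            for m in range(M):  # axis m: present in the output
                # axis d: missing from the output, so its products are summed
                out[batch, n, m] = (a[batch, n, :] * b[batch, :, m]).sum()
    return out


x = torch.randn(2, 3, 4)
y = torch.randn(2, 4, 5)
print(batched_matmul_loops(x, y).shape)  # torch.Size([2, 3, 5])
```

Every axis that survives the arrow becomes a loop over output positions, and every axis that disappears becomes a sum inside the loop body. einsum performs the same computation without Python-level loops.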

Starter Code

import torch
from einops import einsum

def batched_matmul_einsum(a, b):
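    # a: (B, N, D)
    # b: (B, D, M)
    # Axis b is kept (batch), axis d is summed over because it is missing
    # from the output, and axes n and m form the (B, N, M) result.
    # The axis name "b" in the pattern string is unrelated to the argument b.
    return einsum(a, b, 'b n d, b d m -> b n m')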

Verification

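Create two random tensors with the same batch size and matching inner dimensions, e.g., shapes (10, 128, 64) and (10, 64, 32). Compare the output of your function with torch.bmm(a, b). The two results should agree to within floating-point tolerance, so compare them with torch.allclose rather than exact equality.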