Einops: Batched Matrix Multiplication
Description
Perform a batched matrix multiplication (B, N, D) @ (B, D, M) -> (B, N, M) using einops.einsum. While torch.bmm is the standard way to do this, reimplementing it with einsum is a good exercise for understanding how einsum notation works.
Guidance
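Use einops.einsum. The pattern string for batched matrix multiplication, 'b n d, b d m -> b n m', mirrors the operation directly. The batch dimension b appears in both inputs and in the output, so it is kept. The shared dimension d appears in both inputs but not in the output, so it is summed over. The remaining dimensions n and m appear in the output alongside b.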
Starter Code
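```python
import torch
from einops import einsum

def batched_matmul_einsum(a, b):
    # a: (B, N, D), b: (B, D, M) -> result: (B, N, M)
    # 'b' is kept as a batch axis; 'd' appears in both inputs but not in
    # the output, so einsum multiplies along it and sums it away.
    return einsum(a, b, 'b n d, b d m -> b n m')
```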
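To see what the pattern asks einsum to compute, the same contraction can be spelled out with plain broadcasting: multiply every pair of entries that share b and d, then sum over d. batched_matmul_broadcast is a hypothetical helper introduced only for this illustration.

```python
import torch

def batched_matmul_broadcast(a, b):
    # Hypothetical helper for illustration only.
    # a: (B, N, D) -> (B, N, D, 1); b: (B, D, M) -> (B, 1, D, M)
    # Broadcasting gives prod[b, n, d, m] = a[b, n, d] * b[b, d, m].
    prod = a.unsqueeze(-1) * b.unsqueeze(1)  # (B, N, D, M)
    # Summing over the shared axis d leaves (B, N, M).
    return prod.sum(dim=2)
```

This version materializes the full (B, N, D, M) product tensor before reducing it, so it uses far more memory than einsum or torch.bmm. It is useful only for making the multiply-then-sum structure explicit.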
Verification
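Create two random tensors with the same batch size and a matching inner dimension, e.g., shapes (10, 128, 64) and (10, 64, 32). Compare the output of your function with torch.bmm(a, b). The results should match up to floating-point tolerance. Check them with torch.allclose rather than exact equality, since the two computations may not accumulate the sums in the same order.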