ML Katas

Grouped Matrix Multiplication

Difficulty: medium (<10 mins) | Tags: linear algebra, tensor manipulation, einsum
this year by E

Description

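Perform matrix multiplication on groups of matrices within a batch. Given an input tensor of shape (B, G, M, K) and another of shape (B, G, K, N), multiply each of the B × G pairs of (M, K) and (K, N) matrices independently, producing an output of shape (B, G, M, N). This operation appears in grouped convolutions, where each group of channels is transformed by its own weight matrix, and in other structured models. You can implement it with a single torch.einsum call.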
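The following minimal sketch makes the grouped-convolution connection concrete. It is an illustration added here and is not part of the kata. It shows that a 1×1 convolution with `groups=G` computes one matrix product per group. The sizes and the einsum subscripts are assumptions chosen for the example. Note that the convolution weight is shared across the batch, so its grouped form has no B dimension.

```python
import torch
import torch.nn.functional as F

# Illustrative sizes (assumed): batch, groups, in/out channels per group, spatial dims
B, G, K, N, H, W = 2, 4, 3, 5, 6, 6

torch.manual_seed(0)
x = torch.randn(B, G * K, H, W)        # input with G groups of K channels each
weight = torch.randn(G * N, K, 1, 1)   # 1x1 kernels: each group maps K -> N channels

# Reference: grouped 1x1 convolution (no bias)
conv_out = F.conv2d(x, weight, groups=G)   # shape (B, G*N, H, W)

# Same computation as a grouped matmul over flattened spatial positions P = H*W
x_groups = x.reshape(B, G, K, H * W)       # (B, G, K, P)
w_groups = weight.reshape(G, N, K)         # (G, N, K), shared across the batch
mm_out = torch.einsum("gnk,bgkp->bgnp", w_groups, x_groups)  # (B, G, N, P)
mm_out = mm_out.reshape(B, G * N, H, W)

print(torch.allclose(conv_out, mm_out, atol=1e-5))  # True
```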

Starter Code

import torch

def grouped_matmul(tensor1, tensor2):
    # Your einsum code here
    pass
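One possible solution is shown below as a sketch. Label the dimensions b (batch), g (group), m, k, and n, and let einsum sum over the shared index k. Each output entry is then C[b, g, m, n] = Σ_k tensor1[b, g, m, k] · tensor2[b, g, k, n]. The shape checks are an addition for clearer error messages and are not required by the kata.

```python
import torch

def grouped_matmul(tensor1, tensor2):
    """Multiply matching (M, K) and (K, N) matrices for every batch b and group g.

    tensor1: (B, G, M, K), tensor2: (B, G, K, N) -> returns (B, G, M, N).
    """
    # Optional shape checks (not part of the original kata)
    if tensor1.dim() != 4 or tensor2.dim() != 4:
        raise ValueError("expected two 4-D tensors")
    if tensor1.shape[:2] != tensor2.shape[:2] or tensor1.shape[3] != tensor2.shape[2]:
        raise ValueError(
            f"incompatible shapes {tuple(tensor1.shape)} and {tuple(tensor2.shape)}"
        )
    # b and g appear in both inputs and in the output, so they act as batch indices;
    # k is missing from the output, so einsum sums over it.
    return torch.einsum("bgmk,bgkn->bgmn", tensor1, tensor2)
```

Because `torch.matmul` broadcasts over all leading dimensions, `tensor1 @ tensor2` gives the same result. The einsum form makes it explicit which index is contracted.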

Verification

Create two tensors of shapes (10, 4, 5, 6) and (10, 4, 6, 7). The output should have the shape (10, 4, 5, 7).
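The verification step can be automated with a small harness, sketched below. `check_grouped_matmul` is a hypothetical helper name. The explicit per-group loop is a slow reference implementation added for comparison. The harness accepts any candidate function. `torch.matmul` serves as a known-good candidate here because it broadcasts over the leading (B, G) dimensions.

```python
import torch

def check_grouped_matmul(fn, atol=1e-5):
    """Check fn on the verification shapes against a per-group loop reference."""
    torch.manual_seed(0)
    tensor1 = torch.randn(10, 4, 5, 6)
    tensor2 = torch.randn(10, 4, 6, 7)

    out = fn(tensor1, tensor2)
    assert out.shape == (10, 4, 5, 7), f"unexpected shape {tuple(out.shape)}"

    # Slow reference: multiply each (batch, group) pair of matrices separately
    expected = torch.empty(10, 4, 5, 7)
    for b in range(10):
        for g in range(4):
            expected[b, g] = tensor1[b, g] @ tensor2[b, g]

    assert torch.allclose(out, expected, atol=atol), "values differ from reference"
    return True

# torch.matmul broadcasts over the leading dims, so it is a known-good candidate
print(check_grouped_matmul(torch.matmul))  # True
```

To check your own solution, pass `grouped_matmul` instead of `torch.matmul`.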