Grouped Matrix Multiplication
Description
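Perform matrix multiplication on groups of matrices within a batch. Given an input tensor of shape (B, G, M, K) and a second tensor of shape (B, G, K, N), the output should have shape (B, G, M, N). Each (batch, group) pair is an independent (M, K) @ (K, N) product: the B and G axes are carried through unchanged, and only the shared K axis is summed over. This pattern is common in grouped convolutions and in other models that split features into independent groups. torch.einsum can express it in a single call.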
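The sketch below spells out the index arithmetic with no dependencies. It uses plain nested lists in place of tensors. The function name `grouped_matmul_reference` is hypothetical, and the function is meant only to pin down the semantics, not to be fast. Each output entry is out[b][g][m][n] = sum over k of lhs[b][g][m][k] * rhs[b][g][k][n].

```python
def grouped_matmul_reference(lhs, rhs):
    """Loop-based grouped matmul on nested lists (hypothetical helper).

    lhs has shape (B, G, M, K), rhs has shape (B, G, K, N); returns (B, G, M, N).
    """
    B, G = len(lhs), len(lhs[0])
    M, K = len(lhs[0][0]), len(lhs[0][0][0])
    N = len(rhs[0][0][0])
    # Allocate a zero-filled (B, G, M, N) nested list with independent rows.
    out = [[[[0] * N for _ in range(M)] for _ in range(G)] for _ in range(B)]
    for bi in range(B):
        for gi in range(G):
            # Ordinary (M, K) @ (K, N) product for one (batch, group) pair.
            for m in range(M):
                for n in range(N):
                    out[bi][gi][m][n] = sum(
                        lhs[bi][gi][m][k] * rhs[bi][gi][k][n] for k in range(K)
                    )
    return out


# One batch, one group, 2x2 @ 2x2:
print(grouped_matmul_reference([[[[1, 2], [3, 4]]]], [[[[5, 6], [7, 8]]]]))
# → [[[[19, 22], [43, 50]]]]
```

An einsum-based solution should agree with this loop version on any inputs of matching shape.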
Starter Code
import torch

def grouped_matmul(tensor1, tensor2):
    # Your einsum code here
    pass
Verification
Create two tensors of shapes (10, 4, 5, 6) and (10, 4, 6, 7) and pass them to grouped_matmul. The output should have shape (10, 4, 5, 7).
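The following is a minimal sketch of one possible solution together with the verification step, assuming PyTorch is installed. The subscript string "bgmk,bgkn->bgmn" keeps the batch (b) and group (g) axes, sums over the shared k axis, and leaves m and n as the rows and columns of each output matrix. The comparison against torch.matmul goes beyond the shape check the task asks for. It works because torch.matmul batches over all leading dimensions, so the two results should match up to floating-point rounding.

```python
import torch


def grouped_matmul(tensor1, tensor2):
    # b = batch, g = group; k is summed over, leaving an (M, N) matrix per (b, g).
    return torch.einsum("bgmk,bgkn->bgmn", tensor1, tensor2)


torch.manual_seed(0)  # fixed seed so the check is reproducible
tensor1 = torch.randn(10, 4, 5, 6)  # (B, G, M, K)
tensor2 = torch.randn(10, 4, 6, 7)  # (B, G, K, N)

output = grouped_matmul(tensor1, tensor2)
print(output.shape)  # torch.Size([10, 4, 5, 7])

# Check the values, not just the shape: matmul batches over the leading dims.
assert output.shape == (10, 4, 5, 7)
assert torch.allclose(output, torch.matmul(tensor1, tensor2), atol=1e-5)
```

The einsum string makes the shared and contracted axes explicit. The `@` operator would give the same result here, but it hides which axis is being reduced.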