ML Katas

Matrix Multiplication Basics

Difficulty: easy (<10 mins) · Tags: linear algebra, tensors, torch.mm
this month by E

Implement a function in PyTorch that multiplies two matrices using torch.mm.

Problem

Write a function matmul(A, B) that takes two 2D tensors A and B and returns their matrix product.

  • Input: Two tensors A of shape (m, n) and B of shape (n, p).
  • Output: A tensor of shape (m, p).
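For reference, here is the standard entrywise definition of the product. Each output entry is the dot product of row i of A with column j of B, which is why the inner dimension n must be shared:

```latex
A \in \mathbb{R}^{m \times n},\quad B \in \mathbb{R}^{n \times p},\quad
C = AB \in \mathbb{R}^{m \times p},\qquad
C_{ij} = \sum_{k=1}^{n} A_{ik}\, B_{kj}
```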

Example

import torch

A = torch.tensor([[1., 2.], [3., 4.]])
B = torch.tensor([[5., 6.], [7., 8.]])
print(matmul(A, B))
# Expected: tensor([[19., 22.], [43., 50.]])
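Working the example by hand, taking row i of A dotted with column j of B, confirms the expected values:

```latex
\begin{aligned}
C_{11} &= 1 \cdot 5 + 2 \cdot 7 = 19, & C_{12} &= 1 \cdot 6 + 2 \cdot 8 = 22,\\
C_{21} &= 3 \cdot 5 + 4 \cdot 7 = 43, & C_{22} &= 3 \cdot 6 + 4 \cdot 8 = 50.
\end{aligned}
```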

Solution Sketch

Call torch.mm(A, B) to compute the product. torch.mm accepts only 2D tensors and does not broadcast, so the inner dimensions must match (A.shape[1] == B.shape[0]); otherwise it raises a RuntimeError.
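One possible solution follows. It is a minimal sketch, not the kata author's reference answer. The explicit dimension checks and the choice to raise ValueError with a shape message are assumptions made for readability. torch.mm would reject bad inputs on its own with a RuntimeError.

```python
import torch


def matmul(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    """Return the matrix product of 2D tensors A (m, n) and B (n, p)."""
    # torch.mm does not broadcast, so both inputs must be exactly 2D.
    # (Assumption: we check up front to give a clearer error message.)
    if A.dim() != 2 or B.dim() != 2:
        raise ValueError(f"expected 2D tensors, got {A.dim()}D and {B.dim()}D")
    # Inner dimensions must agree: A is (m, n), so B must be (n, p).
    if A.shape[1] != B.shape[0]:
        raise ValueError(
            f"inner dimensions differ: A is {tuple(A.shape)}, B is {tuple(B.shape)}"
        )
    return torch.mm(A, B)  # result has shape (m, p)


A = torch.tensor([[1., 2.], [3., 4.]])
B = torch.tensor([[5., 6.], [7., 8.]])
print(matmul(A, B))
```

Checking shapes before calling torch.mm keeps the error close to the caller's intent. If you want batched or broadcasting behavior instead, torch.matmul is the more general operator, but this kata asks specifically for torch.mm.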