Matrix Multiplication Basics
Implement a function in PyTorch that multiplies two matrices using torch.mm.
Problem
Write a function matmul(A, B) that takes two 2D tensors A and B and returns their matrix product.
- Input: Two tensors, A of shape (m, n) and B of shape (n, p).
- Output: A tensor of shape (m, p).
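For reference, the standard definition of the matrix product is below. Each output entry is the dot product of a row of A with a column of B, which is why the output has shape (m, p) and why the inner dimension n must agree:

```latex
A \in \mathbb{R}^{m \times n},\quad B \in \mathbb{R}^{n \times p},\quad C = AB \in \mathbb{R}^{m \times p},
\qquad C_{ij} = \sum_{k=1}^{n} A_{ik}\,B_{kj}, \quad 1 \le i \le m,\; 1 \le j \le p.
```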
Example
```python
import torch

A = torch.tensor([[1., 2.], [3., 4.]])
B = torch.tensor([[5., 6.], [7., 8.]])
print(matmul(A, B))  # matmul is the function you are asked to write
# Expected: tensor([[19., 22.], [43., 50.]])
```
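To see where the expected values come from (for example, 19 = 1·5 + 2·7 and 50 = 3·6 + 4·8), here is a sketch of an explicit-loop reference that applies the entry-by-entry definition directly and compares it with torch.mm. The helper name matmul_reference is hypothetical and meant only for checking; it is far slower than torch.mm and is not a solution to the exercise.

```python
import torch


def matmul_reference(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: explicit-loop matrix product, for checking only."""
    m, n = A.shape
    n2, p = B.shape
    assert n == n2, "inner dimensions must match"
    C = torch.zeros(m, p, dtype=A.dtype)
    for i in range(m):
        for j in range(p):
            # C[i, j] is the dot product of row i of A with column j of B.
            C[i, j] = sum(A[i, k] * B[k, j] for k in range(n))
    return C


A = torch.tensor([[1., 2.], [3., 4.]])
B = torch.tensor([[5., 6.], [7., 8.]])
C = matmul_reference(A, B)
print(C)
print(torch.equal(C, torch.mm(A, B)))  # True
```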
Solution Sketch
Use torch.mm(A, B) to compute the product. torch.mm accepts only 2D tensors and does not broadcast, so the inner dimensions must match (A.shape[1] == B.shape[0]); if they do not, torch.mm raises a RuntimeError.
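A minimal sketch of one possible solution follows. The explicit dimension checks and the ValueError messages are an assumed design choice, not something the problem requires: torch.mm would reject bad inputs on its own with a RuntimeError, but checking up front gives a clearer message. torch.matmul (or the @ operator) would also work for 2D inputs, but it additionally accepts 1D and batched inputs with broadcasting, so torch.mm matches the problem's 2D-only contract more closely.

```python
import torch


def matmul(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    """Return the matrix product of 2D tensors A (m, n) and B (n, p)."""
    # torch.mm only accepts 2D tensors, so reject anything else up front.
    if A.dim() != 2 or B.dim() != 2:
        raise ValueError(f"expected 2D tensors, got {A.dim()}D and {B.dim()}D")
    # The inner dimensions must agree: A.shape[1] == B.shape[0].
    if A.shape[1] != B.shape[0]:
        raise ValueError(
            f"inner dimensions do not match: {tuple(A.shape)} and {tuple(B.shape)}"
        )
    return torch.mm(A, B)  # result has shape (m, p)


A = torch.tensor([[1., 2.], [3., 4.]])
B = torch.tensor([[5., 6.], [7., 8.]])
print(matmul(A, B))
```

Calling matmul(torch.ones(2, 3), torch.ones(2, 3)) with this sketch raises a ValueError, because the inner dimensions 3 and 2 differ.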