Matrix Multiplication Basics
Implement a function in PyTorch that multiplies two matrices using torch.mm.
Problem
Write a function matmul(A, B) that takes two 2D tensors A and B and returns their matrix product.
- Input: Two tensors, A of shape (m, n) and B of shape (n, p).
- Output: A tensor of shape (m, p).
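The shapes follow from the definition of the product: entry C[i][j] is the dot product of row i of A with column j of B, so the n columns of A must line up with the n rows of B. Here is a minimal pure-Python sketch of that definition. It uses nested lists rather than tensors, and matmul_reference is a hypothetical name used only for illustration. It is far slower than torch.mm.

```python
def matmul_reference(A, B):
    # Hypothetical helper for illustration: A is an m x n and B an n x p
    # list of lists; the result is an m x p list of lists.
    m, n = len(A), len(A[0])
    if len(B) != n:
        raise ValueError(f"inner dimensions differ: {n} vs {len(B)}")
    p = len(B[0])
    # C[i][j] = sum over k of A[i][k] * B[k][j]
    return [
        [sum(A[i][k] * B[k][j] for k in range(n)) for j in range(p)]
        for i in range(m)
    ]

print(matmul_reference([[1., 2.], [3., 4.]], [[5., 6.], [7., 8.]]))
# [[19.0, 22.0], [43.0, 50.0]]
```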
Example
import torch
A = torch.tensor([[1., 2.], [3., 4.]])
B = torch.tensor([[5., 6.], [7., 8.]])
print(matmul(A, B))
# Expected: tensor([[19., 22.], [43., 50.]])
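Each entry of the expected result is a row-by-column dot product, $C_{ij} = \sum_{k} A_{ik} B_{kj}$. Worked out for this example:

$$
\begin{pmatrix}1 & 2\\ 3 & 4\end{pmatrix}
\begin{pmatrix}5 & 6\\ 7 & 8\end{pmatrix}
=
\begin{pmatrix}1\cdot 5 + 2\cdot 7 & 1\cdot 6 + 2\cdot 8\\ 3\cdot 5 + 4\cdot 7 & 3\cdot 6 + 4\cdot 8\end{pmatrix}
=
\begin{pmatrix}19 & 22\\ 43 & 50\end{pmatrix}
$$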
Solution Sketch
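Use torch.mm(A, B) to perform the multiplication. Ensure the inner dimensions match (A.shape[1] == B.shape[0]). torch.mm accepts only 2D tensors and does not broadcast. If the inner dimensions differ, it raises a RuntimeError.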
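Here is a minimal sketch of the solution. It assumes that validating shapes before calling torch.mm is acceptable. The ValueError checks and their messages are an illustrative choice and are not part of the problem statement. Without them, torch.mm would raise its own RuntimeError.

```python
import torch

def matmul(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    # torch.mm only accepts 2D tensors and does not broadcast,
    # so check the ranks and the inner dimensions up front.
    if A.dim() != 2 or B.dim() != 2:
        raise ValueError(f"expected 2D tensors, got {A.dim()}D and {B.dim()}D")
    if A.shape[1] != B.shape[0]:
        raise ValueError(
            f"inner dimensions differ: A is {tuple(A.shape)}, B is {tuple(B.shape)}"
        )
    # (m, n) x (n, p) -> (m, p)
    return torch.mm(A, B)

A = torch.tensor([[1., 2.], [3., 4.]])
B = torch.tensor([[5., 6.], [7., 8.]])
print(matmul(A, B))
```

Raising early produces a message that names both shapes. Returning torch.mm(A, B) directly, without the checks, also satisfies the problem.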