Extract Diagonal from a Batch of Matrices
Description
Given a batch of square matrices of shape (B, N, N), extract the diagonal of each matrix. The output should be a tensor of shape (B, N). This can be achieved with torch.einsum.
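One possible einsum formulation is sketched below. It assumes a PyTorch version whose torch.einsum accepts a repeated index on a single operand. In the subscript "bii->bi", the index i appears twice on the input, so only entries where the row index equals the column index are read. Both b and i are kept in the output, so each batch element yields one diagonal of length N. The small input is only for illustration.

```python
import torch

def extract_diagonal(matrices: torch.Tensor) -> torch.Tensor:
    # Repeating "i" on the input selects matrices[b, i, i];
    # keeping "b" and "i" in the output gives shape (B, N).
    return torch.einsum("bii->bi", matrices)

# Tiny worked example: two 3x3 matrices holding the values 0..17.
x = torch.arange(2 * 3 * 3, dtype=torch.float32).reshape(2, 3, 3)
# x[0] = [[0, 1, 2], [3, 4, 5], [6, 7, 8]]       -> diagonal [0, 4, 8]
# x[1] = [[9, 10, 11], [12, 13, 14], [15, 16, 17]] -> diagonal [9, 13, 17]
print(extract_diagonal(x))
```

The result matches torch.diagonal(x, dim1=-2, dim2=-1), which reads the same entries without an einsum string.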
Starter Code
import torch
def extract_diagonal(matrices):
    # Your einsum code here
    pass
Verification
For an input of shape (32, 10, 10), the output should have the shape (32, 10).
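The check below is a minimal sketch. It assumes the einsum formulation "bii->bi" and uses random input. Beyond the required shape, it also compares the values against the built-in torch.diagonal taken over the last two dimensions, which serves as an independent reference.

```python
import torch

def extract_diagonal(matrices):
    # Assumed solution: keep entries where row index == column index.
    return torch.einsum("bii->bi", matrices)

x = torch.randn(32, 10, 10)
out = extract_diagonal(x)

# Shape check from the exercise: (B, N, N) -> (B, N).
assert out.shape == (32, 10)
# Value check against PyTorch's built-in diagonal extraction.
assert torch.allclose(out, torch.diagonal(x, dim1=-2, dim2=-1))
print("ok")
```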