ML Katas

Extract Diagonal from a Batch of Matrices

Difficulty: hard (<10 mins). Tags: linear algebra, einsum
this year by E

Description

Given a batch of square matrices of shape (B, N, N), extract the main diagonal of each matrix. The output should be a tensor of shape (B, N). The task can be solved with a single call to torch.einsum.
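In index notation, entry i of the output for batch element b is the element of matrix b whose row and column indices are both i. The einsum subscript string "bii->bi" transcribes this directly: repeating i on the input selects the entries where row equals column, and keeping b and i on the output leaves one diagonal per matrix.

```latex
\text{out}_{b,i} = M_{b,i,i}, \qquad b = 1,\dots,B, \quad i = 1,\dots,N
```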

Starter Code

import torch

def extract_diagonal(matrices):
    # matrices: tensor of shape (B, N, N); return a tensor of shape (B, N)
    # Your einsum code here
    pass
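Here is a minimal sketch of one possible solution. It is not necessarily the kata author's reference answer and assumes a PyTorch version whose einsum supports repeated input subscripts, which current releases do. It fills in the starter function with the subscript string "bii->bi":

```python
import torch


def extract_diagonal(matrices: torch.Tensor) -> torch.Tensor:
    # "bii->bi": the repeated index i on the input picks out elements whose
    # row index equals their column index; keeping b and i in the output
    # yields one length-N diagonal for each of the B matrices.
    return torch.einsum("bii->bi", matrices)


# Small worked example: two 2x2 matrices [[0, 1], [2, 3]] and [[4, 5], [6, 7]].
x = torch.arange(8.0).reshape(2, 2, 2)
diag = extract_diagonal(x)  # diagonals are [0, 3] and [4, 7]
```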

Verification

For an input of shape (32, 10, 10), the output should have the shape (32, 10).
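The check below is a sketch. It goes further than the shape test above and also compares the values against torch.diagonal, PyTorch's built-in extractor, which with dim1=-2 and dim2=-1 takes the diagonal of the last two dimensions of each batch element. The einsum solution is repeated so the snippet runs on its own.

```python
import torch


def extract_diagonal(matrices: torch.Tensor) -> torch.Tensor:
    # Same einsum solution as the starter function, repeated for self-containment.
    return torch.einsum("bii->bi", matrices)


# Random batch of 32 square 10x10 matrices, matching the verification case.
matrices = torch.randn(32, 10, 10)
out = extract_diagonal(matrices)

# Shape check from the verification section.
assert out.shape == (32, 10)

# Value check against the built-in batched diagonal over the last two dims.
expected = torch.diagonal(matrices, dim1=-2, dim2=-1)
assert torch.allclose(out, expected)
print("ok")
```

torch.diagonal solves the problem just as well without einsum. The einsum version is shown because the kata asks for it.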