ML Katas

Batch-wise Matrix Transposition

Difficulty: easy (<10 mins). Tags: linear algebra, einops, tensor manipulation.
this year by E

Description

Given a batch of matrices, transpose each matrix in the batch: swap the last two axes and leave the batch axis in place. The input tensor has shape (B, H, W), and the output should have shape (B, W, H).
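Stated index by index, the operation looks like this. X (input) and Y (output) are names introduced here for illustration; they do not appear in the original kata:

```latex
Y_{b,j,i} = X_{b,i,j}, \qquad 0 \le b < B,\; 0 \le i < H,\; 0 \le j < W
```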

Starter Code

import torch
from einops import rearrange

def batch_transpose(matrices):
    # Your einops code here
    pass

Verification

Create a random tensor of shape (32, 10, 20). The output of your function should have the shape (32, 20, 10).