ML Katas

Flatten Leading Dimensions

easy (<10 mins) einops reshaping
this year by E

Description

Given a tensor with multiple leading dimensions, flatten its first two dimensions into a single dimension while leaving the remaining dimensions unchanged. For example, transform a tensor of shape (D1, D2, D3, D4) into (D1*D2, D3, D4).
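For reference, the same operation can be written in plain PyTorch without einops, which is handy for cross-checking an einops answer. The sketch below uses `torch.flatten` on a small tensor; the shape (2, 3, 2, 2) is an illustrative choice, not part of the kata. It also shows the element ordering: the merge is row-major, so merged index k corresponds to (i, j) with k = i * D2 + j.

```python
import torch

# Small 4D tensor of shape (D1, D2, D3, D4) = (2, 3, 2, 2) with distinct values.
x = torch.arange(2 * 3 * 2 * 2).reshape(2, 3, 2, 2)

# Merge dims 0 through 1 (inclusive) into one; trailing dims are untouched.
y = torch.flatten(x, start_dim=0, end_dim=1)
print(y.shape)  # torch.Size([6, 2, 2])

# Row-major merge: merged index k = i * D2 + j, here with D2 = 3.
i, j = 1, 1
assert torch.equal(y[i * 3 + j], x[i, j])

# An equivalent explicit reshape.
assert torch.equal(y, x.reshape(2 * 3, 2, 2))
```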

Starter Code

import torch
from einops import rearrange

def flatten_leading(tensor):
    """Merge the first two axes: (D1, D2, D3, D4) -> (D1*D2, D3, D4)."""
    # Your einops code here
    pass

Verification

For an input of shape (2, 3, 4, 5), the output should have the shape (6, 4, 5).