ML Katas

Unflatten a Dimension

easy (<10 mins) einops reshaping
this year by E

Description

This is the inverse of flattening. Given a tensor whose first dimension is the product of two other dimensions, split it back into those two dimensions. For example, transform a tensor of shape (D1*D2, D3, D4) into (D1, D2, D3, D4). You need to know one of the original dimensions (here D1); the other is then determined as the size of the first dimension divided by it.
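To make the shape arithmetic concrete, here is a plain-PyTorch sketch that uses `reshape` rather than einops, so it does not give away the kata. The helper name `unflatten_dim_reshape` is hypothetical. It assumes row-major (C-contiguous) element order, which is what `reshape` uses, and it can serve as a reference to check an einops solution against.

```python
import torch

def unflatten_dim_reshape(tensor: torch.Tensor, d1_shape: int) -> torch.Tensor:
    # Hypothetical helper: a plain-PyTorch reference, not the einops solution.
    n = tensor.shape[0]
    if n % d1_shape != 0:
        raise ValueError(f"first dim {n} is not divisible by d1_shape={d1_shape}")
    d2 = n // d1_shape  # the missing dimension is inferred from the known one
    # Splitting dim 0 into (d1, d2) keeps row-major order:
    # row i * d2 + j of the input becomes entry [i, j] of the output.
    return tensor.reshape(d1_shape, d2, *tensor.shape[1:])

x = torch.arange(6 * 4 * 5).reshape(6, 4, 5)
y = unflatten_dim_reshape(x, 2)
print(y.shape)  # torch.Size([2, 3, 4, 5])
```

Because D2 is derived by integer division, the known dimension must divide the first axis exactly, so the sketch rejects inputs where it does not.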

Starter Code

import torch
from einops import rearrange

def unflatten_dim(tensor, d1_shape):
    """Split dim 0 of `tensor` (size d1_shape * d2) into two dims (d1_shape, d2)."""
    # Your einops code here
    pass

Verification

For an input of shape (6, 4, 5) and d1_shape of 2, the output should have the shape (2, 3, 4, 5).