Unflatten a Dimension
Description
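This is the inverse of flattening. Given a tensor whose first dimension is the product of two other dimensions, split that dimension back into the original two. For example, transform a tensor of shape (D1*D2, D3, D4) into one of shape (D1, D2, D3, D4). You need to know one of the original dimensions (D1 here). The other one is then fixed, because it equals the length of the first dimension divided by D1.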
Starter Code
import torch
from einops import rearrange
def unflatten_dim(tensor, d1_shape):
    # Your einops code here
    pass
Verification
For an input of shape (6, 4, 5) and d1_shape of 2, the output should have the shape (2, 3, 4, 5).