Flatten Leading Dimensions
Description
Given a tensor, flatten its first two (leading) dimensions into a single dimension and leave the remaining dimensions unchanged. For example, transform a tensor of shape (D1, D2, D3, D4) into one of shape (D1*D2, D3, D4).
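To make the index mapping concrete, here is a small sketch that uses plain `torch.reshape` rather than the einops call the exercise asks for. Because PyTorch tensors are laid out in row-major order, the (i, j) slice of the original ends up at row i * D2 + j of the flattened result. The shape (2, 3, 4, 5) is just an illustrative choice.

```python
import torch

# Small tensor of shape (D1, D2, D3, D4) = (2, 3, 4, 5)
x = torch.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)

# Merge the first two dimensions; in row-major order D2 varies fastest
y = x.reshape(2 * 3, 4, 5)

print(y.shape)  # torch.Size([6, 4, 5])

# Slice (i, j) of the original lands at row i * D2 + j of the result
for i in range(2):
    for j in range(3):
        assert torch.equal(y[i * 3 + j], x[i, j])
```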
Starter Code
import torch
from einops import rearrange

def flatten_leading(tensor):
    # Your einops code here
    pass
Verification
For an input of shape (2, 3, 4, 5), the output should have the shape (6, 4, 5).