ML Katas

Einops: Squeeze and Unsqueeze

easy (<10 mins) einops tensor
today by E

Description

torch.squeeze and torch.unsqueeze are the usual way to remove or add dimensions of size one. einops.rearrange can do the same job, often more readably, because the pattern names every axis and shows exactly where the size-1 axis appears or disappears.

Guidance

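For unsqueeze, write a literal 1 in the output pattern where the new axis should go, e.g. 'h w -> h 1 w'. For squeeze, write a literal 1 in the input pattern at the position of the size-1 axis and leave it out of the output, e.g. 'h 1 w -> h w'; einops raises an error if that axis does not actually have length 1. Named axes do not work for this in rearrange: every named axis must appear on both sides of the pattern, so you cannot add or drop one even by passing its size (such as c=1) as a keyword argument. (einops.repeat, by contrast, does accept a new named axis on the output side when its size is given.)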

Starter Code

import torch
from einops import rearrange

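def unsqueeze_with_einops(x, dim):
    # x: (10, 20). A literal 1 in the output pattern adds a size-1 axis.
    # For example, unsqueezing at dim=1 gives (10, 1, 20).
    if dim == 0: return rearrange(x, 'h w -> 1 h w')
    if dim == 1: return rearrange(x, 'h w -> h 1 w')
    if dim == 2: return rearrange(x, 'h w -> h w 1')
    raise ValueError(f"dim must be 0, 1 or 2 for a 2-D input, got {dim}")

def squeeze_with_einops(x, dim):
    # x: 3-D with a size-1 axis at `dim`, e.g. (10, 1, 20) for dim=1 -> (10, 20).
    # A literal 1 in the input pattern marks the axis to drop; einops
    # raises an error if that axis does not have length 1.
    if dim == 0: return rearrange(x, '1 h w -> h w')
    if dim == 1: return rearrange(x, 'h 1 w -> h w')
    if dim == 2: return rearrange(x, 'h w 1 -> h w')
    raise ValueError(f"dim must be 0, 1 or 2 for a 3-D input, got {dim}")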

Verification

Create a tensor of shape (10, 20). Use unsqueeze_with_einops to add a dimension at index 1 and verify that the shape is (10, 1, 20). Then apply squeeze_with_einops at index 1 to the result and verify that the shape returns to (10, 20).