Einops: Squeeze and Unsqueeze
Description
torch.squeeze and torch.unsqueeze are common for removing or adding dimensions of size one. einops.rearrange can do this as well, often with more clarity by explicitly naming the dimensions.
Guidance
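For unsqueeze, add a unit axis to the output pattern by writing 1 (or an empty group, ()) at the position where the new dimension should go. rearrange does not accept a new named axis that appears only in the output pattern, so the literal 1 is the way to express it. For squeeze, write the size-1 dimension as 1 in the input pattern and omit it from the output pattern; no extra size argument is needed.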
Starter Code
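import torch
from einops import rearrange

def unsqueeze_with_einops(x, dim):
    # x: (10, 20)
    # Example: unsqueeze at dim=1 -> (10, 1, 20)
    # A literal 1 in the output pattern inserts a new size-1 axis.
    if dim == 0: return rearrange(x, 'h w -> 1 h w')
    if dim == 1: return rearrange(x, 'h w -> h 1 w')
    if dim == 2: return rearrange(x, 'h w -> h w 1')
    # Fail loudly instead of silently returning None.
    raise ValueError(f"dim must be 0, 1, or 2 for a 2-D input, got {dim}")

def squeeze_with_einops(x, dim):
    # x: (10, 1, 20)
    # Example: squeeze dim=1 -> (10, 20)
    # A literal 1 in the input pattern matches the size-1 axis to drop.
    if dim == 1: return rearrange(x, 'h 1 w -> h w')
    # Only dim=1 is handled in this starter code.
    raise ValueError(f"only dim=1 is supported here, got {dim}")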
Verification
Create a tensor of shape (10, 20). Use unsqueeze_with_einops to add a dimension at index 1, and verify the shape is (10, 1, 20). Then use squeeze_with_einops on the result to verify the shape returns to (10, 20).