ML Katas

Hierarchical Patch Merging with Einops

easy (<10 mins) pytorch einops vision swin
this year by E

Description

In hierarchical vision transformers like the Swin Transformer, patch merging downsamples the feature map: each 2x2 group of neighboring tokens is merged into one, so the number of tokens drops by a factor of 4 while the feature dimension doubles. [1] Stacking these stages creates a hierarchical representation. Your task is to implement this operation.

Guidance

This is a classic einops reshaping problem. The function should take a tensor of patches and merge every 2x2 group of neighboring patches into a single, higher-dimensional patch.

1. Input: A tensor with shape (B, H, W, C), where H and W are the height and width of the patch grid. Both must be even so that the grid splits cleanly into 2x2 blocks.
2. Reshape for Merging: Use einops.rearrange to group adjacent patches. The key is to split each spatial axis into a coarse index and a position inside the 2x2 block. The pattern b (h p1) (w p2) c -> b h w (p1 p2 c) is what you need, where p1 and p2 are both 2.
3. Projection: After rearranging, the channel dimension is 4*C instead of C. Apply a linear layer to project this 4*C dimension down to 2*C, as is standard practice.

Starter Code

import torch
import torch.nn as nn
from einops import rearrange

class PatchMerging(nn.Module):
    def __init__(self, input_dim):
        super().__init__()
        self.projection = nn.Linear(4 * input_dim, 2 * input_dim, bias=False)

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): Input tensor of shape (B, H, W, C)

        Returns:
            torch.Tensor: Output tensor of shape (B, H/2, W/2, 2*C)
        """
        # Use einops to rearrange the tensor, merging 2x2 patches.
        # The new channel dimension will be 4*C.
        x_merged = rearrange(x, 'b (h p1) (w p2) c -> b h w (p1 p2 c)', p1=2, p2=2)

        # Apply the linear projection
        x_projected = self.projection(x_merged)

        return x_projected

Verification

Instantiate the PatchMerging module and pass a dummy tensor through it. Check if the output shape is correct.

# Parameters
B, H, W, C = 4, 14, 14, 96

# Dummy data
input_tensor = torch.randn(B, H, W, C)

# Create and apply the layer
patch_merger = PatchMerging(input_dim=C)
output_tensor = patch_merger(input_tensor)

# Verification checks
expected_shape = (B, H // 2, W // 2, 2 * C)
assert output_tensor.shape == expected_shape

print(f"Input shape: {input_tensor.shape}")
print(f"Output shape: {output_tensor.shape}")
print(f"Expected shape: {expected_shape}")
print("Verification successful!")

References

[1] Liu, Z., et al. (2021). Swin Transformer: Hierarchical Vision Transformer using Shifted Windows. ICCV 2021.