Hierarchical Patch Merging with Einops
Description
In hierarchical vision transformers such as the Swin Transformer, patch merging downsamples the feature map: each merge cuts the number of tokens by a factor of 4 while doubling their feature dimension [1]. Stacking these stages produces a hierarchical representation. Your task is to implement this operation.
Guidance
This is a classic einops reshaping problem. The function should take a tensor of patches and merge every 2x2 group of neighboring patches into a single, higher-dimensional patch.
1. Input: A tensor with shape (B, H, W, C), where H and W are the height and width of the patch grid. Both H and W must be even so that the grid splits cleanly into 2x2 blocks.
2. Reshape for Merging: Use einops.rearrange to group adjacent patches. The key is to create new axes for the 2x2 blocks. The pattern b (h p1) (w p2) c -> b h w (p1 p2 c) is what you need, where p1 and p2 are both 2.
3. Projection: After rearranging, the last dimension has size 4*C, because each output token concatenates the channels of four input patches. Apply a linear layer to project this 4*C dimension down to 2*C, as the Swin Transformer does [1].
Starter Code
import torch
import torch.nn as nn
from einops import rearrange


class PatchMerging(nn.Module):
    def __init__(self, input_dim):
        super().__init__()
        self.projection = nn.Linear(4 * input_dim, 2 * input_dim, bias=False)

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): Input tensor of shape (B, H, W, C)
        Returns:
            torch.Tensor: Output tensor of shape (B, H/2, W/2, 2*C)
        """
        # Use einops to rearrange the tensor, merging 2x2 patches.
        # The new channel dimension will be 4*C.
        x_merged = rearrange(x, 'b (h p1) (w p2) c -> b h w (p1 p2 c)', p1=2, p2=2)
        # Apply the linear projection
        x_projected = self.projection(x_merged)
        return x_projected
Verification
Instantiate the PatchMerging module and pass a dummy tensor through it. Check if the output shape is correct.
# Parameters
B, H, W, C = 4, 14, 14, 96
# Dummy data
input_tensor = torch.randn(B, H, W, C)
# Create and apply the layer
patch_merger = PatchMerging(input_dim=C)
output_tensor = patch_merger(input_tensor)
# Verification checks
expected_shape = (B, H // 2, W // 2, 2 * C)
assert output_tensor.shape == expected_shape
print(f"Input shape: {input_tensor.shape}")
print(f"Output shape: {output_tensor.shape}")
print(f"Expected shape: {expected_shape}")
print("Verification successful!")
References
[1] Liu, Z., et al. (2021). Swin Transformer: Hierarchical Vision Transformer using Shifted Windows. ICCV 2021.