Masked Autoencoder (MAE) Input Preprocessing
Description
Masked Autoencoders (MAE) are a powerful self-supervised learning technique for vision transformers. The core idea is simple: randomly mask a large portion of the input image's patches and train the model to reconstruct the missing ones. [1] Your task is to implement the input preprocessing step: converting an image into a sequence of patches and performing random masking.
Guidance
This is a tensor manipulation challenge. Your function should perform these steps:
1. Patchify: Reshape the input image batch (B, C, H, W) into a batch of non-overlapping patch sequences (B, N, P*P*C), where N is the number of patches and P is the patch size (a minimal sketch follows this list).
2. Generate Masking Indices: For each image in the batch, create a random permutation of the patch indices 0 to N-1.
3. Split Indices: Split the shuffled indices into two sets, one for the visible patches and one for the masked patches, based on the mask_ratio.
4. Gather: Use the indices to gather the visible patches from the full patch sequence.
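Step 1 is the trickiest part, so here is a minimal patchify sketch. It assumes H and W are exact multiples of the patch size and flattens each patch in (P, P, C) order; any consistent flattening order works, as long as reconstruction uses the same one.

import torch

x = torch.randn(2, 3, 224, 224)              # (B, C, H, W)
P = 16
patches = x.unfold(2, P, P).unfold(3, P, P)  # (B, C, H/P, W/P, P, P)
patches = patches.permute(0, 2, 3, 4, 5, 1)  # patch grid first, channels last
patches = patches.reshape(2, -1, P * P * 3)  # (B, N, P*P*C)
print(patches.shape)                         # torch.Size([2, 196, 768])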
Starter Code
import torch
def mae_patchify_and_mask(images, patch_size=16, mask_ratio=0.75):
    """
    Split `images` (B, C, H, W) into non-overlapping patches and randomly
    mask a fraction of them.

    Returns:
        visible_patches: (B, N_visible, P*P*C) patches kept visible.
        masked_idx: indices of the masked patches.
        shuffle_idx: the full random permutation (to restore order later).
    """
    # 1. Use torch.Tensor.unfold to create non-overlapping patches.
    #    This is a tricky but powerful function. You'll need to unfold
    #    along the height and width dimensions. After unfolding, permute
    #    and reshape the dimensions to get the desired (B, N, P*P*C) shape.
    # 2. For each image in the batch, create a random permutation of its
    #    patch indices. torch.rand and torch.argsort are useful here.
    # 3. Calculate how many patches to keep visible vs. mask.
    # 4. Use the random indices to select the visible patches.
    #    torch.gather is the ideal tool for this.
    # 5. Return the sequence of visible patches, the indices of the masked
    #    patches, and the random permutation indices (to restore order later).
    pass
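For reference, here is one possible completion of the starter function — a sketch, not the only valid approach. The rand-plus-argsort trick (argsort of per-image uniform noise) mirrors the reference MAE implementation and produces an independent random permutation per image in a single batched call:

import torch

def mae_patchify_and_mask(images, patch_size=16, mask_ratio=0.75):
    B, C, H, W = images.shape
    P = patch_size

    # 1. Patchify: (B, C, H, W) -> (B, N, P*P*C).
    patches = images.unfold(2, P, P).unfold(3, P, P)  # (B, C, H/P, W/P, P, P)
    patches = patches.permute(0, 2, 3, 4, 5, 1)       # (B, H/P, W/P, P, P, C)
    patches = patches.reshape(B, -1, P * P * C)       # (B, N, P*P*C)
    N = patches.shape[1]

    # 2. One random permutation of patch indices per image: argsort of
    #    uniform noise along the sequence dimension.
    noise = torch.rand(B, N, device=images.device)
    shuffle_idx = torch.argsort(noise, dim=1)         # (B, N)

    # 3. First part of each permutation stays visible, the rest is masked.
    num_visible = int(N * (1 - mask_ratio))
    visible_idx = shuffle_idx[:, :num_visible]
    masked_idx = shuffle_idx[:, num_visible:]

    # 4. Gather the visible patches along the sequence dimension.
    visible_patches = torch.gather(
        patches, dim=1,
        index=visible_idx.unsqueeze(-1).expand(-1, -1, patches.shape[-1]),
    )

    # 5. Return visible patches, masked indices, and the full permutation.
    return visible_patches, masked_idx, shuffle_idx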
Verification
Create a dummy batch of images and pass it through your function, then verify that the output shapes are correct. For an input of (10, 3, 224, 224) with patch_size=16 and mask_ratio=0.75, N is (224/16)^2 = 196. The output visible_patches should have shape (10, 49, 768), where 49 is 196 * (1 - 0.75) and 768 is 16 * 16 * 3. You should also be able to use the returned indices to reconstruct the original image from the visible and (hypothetically reconstructed) masked patches.
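Here is a sketch of such a check, assuming the return convention suggested in the starter code (visible patches, masked indices, full shuffle permutation). The zeros tensor stands in for hypothetically reconstructed masked patches:

images = torch.randn(10, 3, 224, 224)
visible, masked_idx, shuffle_idx = mae_patchify_and_mask(images)
assert visible.shape == (10, 49, 768)
assert masked_idx.shape == (10, 147)     # 196 - 49 patches are masked

# Rebuild the full sequence: visible patches first, placeholders for the
# masked ones, then invert the random shuffle with argsort.
placeholder = torch.zeros(10, 147, 768)
full = torch.cat([visible, placeholder], dim=1)  # still in shuffled order
restore_idx = torch.argsort(shuffle_idx, dim=1)  # inverse permutation
full = torch.gather(full, 1, restore_idx.unsqueeze(-1).expand(-1, -1, 768))
# full is now in original patch order; masked slots hold the placeholder.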
References
[1] He, K., Chen, X., Xie, S., Li, Y., Dollár, P., & Girshick, R. (2022). Masked Autoencoders Are Scalable Vision Learners. CVPR 2022.