ML Katas

Masked Autoencoder (MAE) Input Preprocessing

medium (<30 mins) pytorch self-supervised mae vision transformer
yesterday by E

Description

Masked Autoencoders (MAE) are a powerful self-supervised learning technique for vision transformers. The core idea is simple: randomly mask a large portion of the input image patches and train the model to reconstruct the missing patches. [1] Your task is to implement the input preprocessing step: converting an image into a sequence of patches and performing random masking.

Guidance

This is a tensor manipulation challenge. Your function should perform these steps:

1. Patchify: reshape the input image batch (B, C, H, W) into a batch of non-overlapping patch sequences (B, N, P*P*C), where P is the patch size and N = (H/P) * (W/P) is the number of patches.
2. Generate masking indices: for each image in the batch, create a random permutation of the patch indices 0 to N-1.
3. Split indices: split the shuffled indices into two sets based on the mask_ratio: one for the visible patches and one for the masked patches.
4. Gather: use the indices to gather the visible patches from the full patch sequence.
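Before tackling the full function, it can help to see step 1's unfold-permute-reshape dance on a toy tensor. The shapes below (a 1×3×8×8 image with 4×4 patches, so N = 4) are illustrative assumptions, not part of the kata:

```python
import torch

# Toy patchification demo: one 3-channel 8x8 image, 4x4 patches.
B, C, H, W = 1, 3, 8, 8
P = 4
images = torch.arange(B * C * H * W, dtype=torch.float32).reshape(B, C, H, W)

# Unfold height (dim 2), then width (dim 3): (B, C, H/P, W/P, P, P)
patches = images.unfold(2, P, P).unfold(3, P, P)

# Move channels last and flatten each patch: (B, N, P*P*C)
patches = patches.permute(0, 2, 3, 4, 5, 1).reshape(B, -1, P * P * C)
print(patches.shape)  # torch.Size([1, 4, 48])
```

Note the permute puts the channel dimension last so that each flattened patch interleaves pixel values before channels; any consistent ordering works as long as you invert it the same way when reconstructing.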

Starter Code

import torch

def mae_patchify_and_mask(images, patch_size=16, mask_ratio=0.75):
    # 1. Use torch.Tensor.unfold to create non-overlapping patches.
    #    This is a tricky but powerful function. You'll need to unfold
    #    along the height and width dimensions.
    #    After unfolding, you'll need to permute and reshape the dimensions
    #    to get the desired (B, N, P*P*C) shape.

    # 2. For each image in the batch, create a random permutation of its patch indices.
    #    torch.rand and torch.argsort are useful here.

    # 3. Calculate how many patches to keep visible vs. mask.

    # 4. Use the random indices to select the visible patches.
    #    torch.gather is the ideal tool for this.

    # 5. Return the sequence of visible patches, the indices of the masked patches,
    #    and the random permutation indices (to restore order later).
    pass

Verification

Create a dummy batch of images and pass it through your function, then verify that the output shapes are correct. For an input of (10, 3, 224, 224) with patch_size=16 and mask_ratio=0.75, N is (224/16) * (224/16) = 196. The returned visible_patches should have shape (10, 49, 768), where 49 = 196 * (1 - 0.75) and 768 = 16 * 16 * 3. You should also be able to use the returned indices to reconstruct the original image from the visible and (hypothetically reconstructed) masked patches.
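The checks above can be sketched as follows. This is one possible reference implementation, included only to make the shape assertions concrete; the argsort-of-noise shuffling trick and the variable names are assumptions, and other orderings of the same operations are equally valid:

```python
import torch

def mae_patchify_and_mask(images, patch_size=16, mask_ratio=0.75):
    B, C, H, W = images.shape
    P = patch_size

    # 1. Patchify: (B, C, H, W) -> (B, N, P*P*C)
    patches = images.unfold(2, P, P).unfold(3, P, P)
    patches = patches.permute(0, 2, 3, 4, 5, 1).reshape(B, -1, P * P * C)
    N = patches.shape[1]

    # 2. Per-image random permutation via argsort of uniform noise.
    noise = torch.rand(B, N)
    shuffle_idx = torch.argsort(noise, dim=1)  # (B, N)

    # 3. Split into visible and masked index sets.
    num_visible = int(N * (1 - mask_ratio))
    visible_idx = shuffle_idx[:, :num_visible]
    masked_idx = shuffle_idx[:, num_visible:]

    # 4. Gather the visible patches along the sequence dimension.
    visible_patches = torch.gather(
        patches, 1,
        visible_idx.unsqueeze(-1).expand(-1, -1, patches.shape[-1]),
    )
    return visible_patches, masked_idx, shuffle_idx

images = torch.randn(10, 3, 224, 224)
visible, masked_idx, shuffle_idx = mae_patchify_and_mask(images)
print(visible.shape)     # torch.Size([10, 49, 768])
print(masked_idx.shape)  # torch.Size([10, 147])
```

Keeping the full shuffle_idx around is what lets a decoder later scatter the reconstructed masked patches back into their original positions.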

References

[1] He, K., et al. (2022). Masked Autoencoders Are Scalable Vision Learners. CVPR 2022.