ML Katas

Tensor Manipulation: Creating `unfold` with `as_strided`

hard (<10 mins) pytorch tensor advanced
yesterday by E

Description

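Warning: as_strided is an advanced and potentially unsafe operation. It builds a view directly from the sizes and strides you pass, and the only check is that the view stays inside the tensor's underlying storage. Wrong strides silently read the wrong elements, and writing through a view whose elements overlap gives unpredictable results. That said, understanding it gives real insight into how tensors are laid out in memory. Your task is to replicate the behavior of torch.nn.functional.unfold (sliding-window extraction, also known as im2col) using as_strided.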
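To make the stride mechanics concrete, here is a small sketch. The 3x4 tensor and the shapes are arbitrary illustrations. It shows that a contiguous tensor's strides are plain element offsets, that swapping sizes and strides by hand reproduces a transpose, that overlapping views are accepted without complaint, and that a view reaching past the end of the storage is rejected.

```python
import torch

x = torch.arange(12).reshape(3, 4)  # contiguous, row-major
print(x.stride())                    # (4, 1): the next row is 4 elements away

# Swapping sizes and strides by hand gives exactly x.t()
xt = x.as_strided((4, 3), (1, 4))
assert torch.equal(xt, x.t())

# Overlap is allowed: a stride of 0 reads x[0, 0] three times
rep = x.as_strided((3,), (0,))
print(rep)                           # tensor([0, 0, 0])

# A view needing 16 elements of a 12-element storage raises an error
rejected = False
try:
    x.as_strided((4, 4), (4, 1))
except RuntimeError:
    rejected = True
print("out-of-bounds view rejected:", rejected)
```

The overlap case is the dangerous one: nothing stops you from creating it, and an in-place write through `rep` would hit the same memory location three times.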

Guidance

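F.unfold returns a new tensor of shape (B, C * k * k, L) holding every sliding local block, where k is the kernel size and L is the number of window positions. With as_strided you can expose the same blocks as a view of the original tensor, without copying, by choosing the right size and stride for each dimension of the view. For a single H x W channel the view has shape (out_h, out_w, k, k): two dimensions that select the window position, plus two extra dimensions, one per kernel axis, that select the element inside the block.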
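The idea is easiest to see in one dimension first. This sketch (the length 7, window 3, and step 2 are arbitrary choices) builds overlapping windows over a 1-D tensor with as_strided and checks them against torch.Tensor.unfold, which creates the same kind of view.

```python
import torch

x = torch.arange(7.0)                     # [0, 1, ..., 6], stride (1,)
k, step = 3, 2
n_windows = (x.numel() - k) // step + 1   # (7 - 3) // 2 + 1 = 3

# Dim 0 moves between windows (step elements apart),
# dim 1 moves within a window (1 element apart).
windows = x.as_strided((n_windows, k), (step * x.stride(0), x.stride(0)))
print(windows)
# tensor([[0., 1., 2.],
#         [2., 3., 4.],
#         [4., 5., 6.]])

# Tensor.unfold(dimension, size, step) builds the same view
assert torch.equal(windows, x.unfold(0, k, step))
```

Because the stride of dim 0 (2) is smaller than the window length (3), consecutive rows share an element. That sharing is what makes the view free to create, and also why writing into it is risky.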

Starter Code

import torch

def custom_unfold(x, kernel_size, stride=1):
    # x: (B, C, H, W)
    # Note: this is a simplified example for a single image, single channel
    x_single = x[0, 0, :, :]
    h, w = x_single.shape

    # Number of window positions along each axis (no padding, no dilation)
    out_h = (h - kernel_size) // stride + 1
    out_w = (w - kernel_size) // stride + 1

    # Element strides of the 2-D slice; (W, 1) when x is contiguous
    stride_h, stride_w = x_single.stride()

    # Dims 0-1 step between windows (`stride` rows/cols apart),
    # dims 2-3 step within a window (1 row/col apart)
    new_shape = (out_h, out_w, kernel_size, kernel_size)
    new_strides = (stride_h * stride, stride_w * stride, stride_h, stride_w)

    # The view keeps x_single's storage offset, so it starts at x[0, 0, 0, 0]
    return x_single.as_strided(new_shape, new_strides)
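The starter code stops at a single channel and at shape (out_h, out_w, k, k). Below is one possible extension, a sketch rather than a reference implementation: the name unfold_as_strided is hypothetical, and it assumes a square kernel with no padding or dilation. It handles the full (B, C, H, W) input and rearranges the view into F.unfold's (B, C * k * k, L) layout, where dim 1 runs over (channel, kernel row, kernel col) and dim 2 over window positions in row-major order.

```python
import torch
import torch.nn.functional as F

def unfold_as_strided(x, kernel_size, stride=1):
    """Hypothetical im2col for a (B, C, H, W) tensor via as_strided.

    Assumes a square kernel, no padding, no dilation.
    Returns (B, C * k * k, L), the same layout as F.unfold.
    """
    B, C, H, W = x.shape
    k = kernel_size
    out_h = (H - k) // stride + 1
    out_w = (W - k) // stride + 1
    sB, sC, sH, sW = x.stride()

    # View of shape (B, C, out_h, out_w, k, k); no data is copied here
    patches = x.as_strided(
        (B, C, out_h, out_w, k, k),
        (sB, sC, sH * stride, sW * stride, sH, sW),
    )
    # Put the kernel dims before the window-position dims, then flatten:
    # dim 1 indexes (c, i, j), dim 2 indexes (oh, ow), as in F.unfold
    patches = patches.permute(0, 1, 4, 5, 2, 3)
    return patches.reshape(B, C * k * k, out_h * out_w)

x = torch.randn(2, 3, 6, 7)
ours = unfold_as_strided(x, kernel_size=3, stride=2)
ref = F.unfold(x, kernel_size=3, stride=2)
print(ours.shape)  # torch.Size([2, 27, 6])
assert torch.allclose(ours, ref)
```

The as_strided call and the permute only create views; the final reshape generally has to copy, because the permuted, overlapping windows cannot be flattened in place. That copy is also why F.unfold itself returns a new tensor.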

Verification

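Create a 5x5 tensor and give it batch and channel dimensions, so its shape is (1, 1, 5, 5). Unfold it with kernel_size=3. The output should have a shape of (3, 3, 3, 3). Reshape it to (9, 9), which puts window positions along the rows and kernel elements along the columns, then transpose it: F.unfold returns shape (1, 9, 9), i.e. (B, C * k * k, L), with kernel elements along dim 1 and window positions along dim 2. Compare your result to the output of torch.nn.functional.unfold on the same tensor. For this stride-1 case the 9x9 matrix happens to be symmetric, so a missing transpose goes unnoticed; repeat the check with stride=2 to catch it.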
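A possible verification script, assuming the starter custom_unfold is used unchanged (it is repeated here so the snippet runs on its own). The arange input is an arbitrary choice, and the loop also runs stride=2, where the two layouts have different shapes, so a missing transpose would fail.

```python
import torch
import torch.nn.functional as F

def custom_unfold(x, kernel_size, stride=1):
    # Starter code: single image, single channel
    x_single = x[0, 0, :, :]
    h, w = x_single.shape
    out_h = (h - kernel_size) // stride + 1
    out_w = (w - kernel_size) // stride + 1
    stride_h, stride_w = x_single.stride()
    new_shape = (out_h, out_w, kernel_size, kernel_size)
    new_strides = (stride_h * stride, stride_w * stride, stride_h, stride_w)
    return x_single.as_strided(new_shape, new_strides)

x = torch.arange(25.0).reshape(1, 1, 5, 5)  # 5x5 image with B = C = 1
k = 3
print(custom_unfold(x, k).shape)  # torch.Size([3, 3, 3, 3])

for s in (1, 2):
    ours = custom_unfold(x, k, stride=s)
    out_h, out_w = ours.shape[:2]
    # (out_h, out_w, k, k) -> (L, k*k) -> transpose -> (k*k, L)
    cols = ours.reshape(out_h * out_w, k * k).T
    ref = F.unfold(x, kernel_size=k, stride=s)[0]  # (C*k*k, L) with C = 1
    assert torch.equal(cols, ref)
```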