Tensor Manipulation: Creating `unfold` with `as_strided`
Description
Warning: `as_strided` is an advanced and potentially unsafe operation. It builds a view directly from the sizes and strides you supply without checking that they describe the data you intend, so a mistake can silently return the wrong elements or fail at runtime. With that said, understanding it provides deep insight into how tensors are stored in memory. Your task is to replicate the behavior of `torch.nn.functional.unfold` (sliding-window extraction) using `as_strided`.
Guidance
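`unfold` extracts all sliding local blocks from the input tensor. `as_strided` can expose the same blocks as a view of the original memory, without copying. To do this, you need to calculate the correct `size` and `stride` for the new view: moving from one block to the next advances by the window step times the original stride, while moving within a block advances by the original stride. The view gains extra dimensions that index the elements inside each block.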
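The size and stride calculation is easier to see in one dimension first. The following is a minimal sketch, assuming a contiguous 1-D tensor; the variable names (`window`, `step`, `windows`) are illustrative and not part of the exercise. It compares the result against `Tensor.unfold`, which produces the same kind of sliding-window view.

```python
import torch

# A 1-D tensor of 6 elements; it is contiguous, so its stride is (1,).
x = torch.arange(6.0)
window, step = 3, 1
(s,) = x.stride()

# Number of windows that fit in the tensor.
n = (x.numel() - window) // step + 1

# Dimension 0 walks over window start positions (advance by step * s);
# dimension 1 walks over elements inside a window (advance by s).
windows = x.as_strided((n, window), (step * s, s))

# Tensor.unfold(dimension, size, step) builds the same sliding windows.
reference = x.unfold(0, window, step)
print(windows)
print(torch.equal(windows, reference))
```

The 2-D case repeats this pattern once per spatial axis, which is why the view ends up with two "position" dimensions and two "within-block" dimensions.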
Starter Code
import torch

def custom_unfold(x, kernel_size, stride=1):
    # x: (B, C, H, W)
    # Note: This is a simplified example for a single image, single channel
    x_single = x[0, 0, :, :]
    h, w = x_single.shape
    out_h = (h - kernel_size) // stride + 1
    out_w = (w - kernel_size) // stride + 1
    # Get original strides
    stride_h, stride_w = x_single.stride()
    # Calculate new shape and strides
    new_shape = (out_h, out_w, kernel_size, kernel_size)
    new_strides = (stride_h * stride, stride_w * stride, stride_h, stride_w)
    return x_single.as_strided(new_shape, new_strides)
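The starter code handles only the first image and channel. One possible extension to a full (B, C, H, W) input is sketched below; `custom_unfold_batched` is a hypothetical name, and the sketch assumes no padding or dilation. The batch and channel dimensions simply keep their original strides, and a final permute and reshape rearranges the view into the `(B, C * k * k, L)` layout that `F.unfold` returns.

```python
import torch
import torch.nn.functional as F

def custom_unfold_batched(x, kernel_size, stride=1):
    # Hypothetical helper (not part of the starter code): handles (B, C, H, W).
    b, c, h, w = x.shape
    out_h = (h - kernel_size) // stride + 1
    out_w = (w - kernel_size) // stride + 1
    s_b, s_c, s_h, s_w = x.stride()
    # Batch and channel dimensions keep their original strides.
    shape = (b, c, out_h, out_w, kernel_size, kernel_size)
    strides = (s_b, s_c, s_h * stride, s_w * stride, s_h, s_w)
    blocks = x.as_strided(shape, strides)
    # Reorder to (B, C, kH, kW, out_h, out_w), then flatten to (B, C*k*k, L).
    # reshape has to copy here because the strided view is not contiguous.
    blocks = blocks.permute(0, 1, 4, 5, 2, 3)
    return blocks.reshape(b, c * kernel_size ** 2, out_h * out_w)

x = torch.randn(2, 3, 6, 7)
out = custom_unfold_batched(x, kernel_size=3, stride=2)
expected = F.unfold(x, kernel_size=3, stride=2)
print(out.shape, torch.equal(out, expected))
```

Note that the strided view itself costs no extra memory; the copy happens only in the final `reshape`, when the overlapping blocks are written out in `F.unfold`'s layout.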
Verification
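Create a 5x5 tensor and give it shape `(1, 1, 5, 5)` so that it matches the (B, C, H, W) input both functions expect. Unfold it with `kernel_size=3`. The output should have a shape of `(3, 3, 3, 3)`. You can then reshape it to `(9, 9)`, where each row is one flattened block. `F.unfold` returns shape `(1, 9, 9)` with the block elements along dimension 1 and the block positions along dimension 2, so transpose your reshaped result before comparing it to the output of `torch.nn.functional.unfold` on the same tensor.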
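The check described above could be scripted roughly as follows. This is a sketch that repeats the starter `custom_unfold` so it runs on its own; the input values from `torch.arange` are an arbitrary choice. It transposes the reshaped result because `F.unfold` places block elements along dimension 1 and block positions along dimension 2.

```python
import torch
import torch.nn.functional as F

def custom_unfold(x, kernel_size, stride=1):
    # Same single-image, single-channel logic as the starter code.
    x_single = x[0, 0, :, :]
    h, w = x_single.shape
    out_h = (h - kernel_size) // stride + 1
    out_w = (w - kernel_size) // stride + 1
    stride_h, stride_w = x_single.stride()
    new_shape = (out_h, out_w, kernel_size, kernel_size)
    new_strides = (stride_h * stride, stride_w * stride, stride_h, stride_w)
    return x_single.as_strided(new_shape, new_strides)

# 5x5 input with explicit batch and channel dimensions.
x = torch.arange(25.0).reshape(1, 1, 5, 5)
blocks = custom_unfold(x, kernel_size=3)
print(blocks.shape)  # torch.Size([3, 3, 3, 3])

# Each row of `flat` is one flattened 3x3 block.
flat = blocks.reshape(9, 9)
# F.unfold returns (1, 9, 9): block elements on dim 1, block positions on dim 2.
expected = F.unfold(x, kernel_size=3)[0]
print(torch.equal(flat.T, expected))
```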