ML Katas

Image Patch Extraction with `select` and `narrow`

easy (<30 mins) vision select narrow
this month by E

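In computer vision, a common operation is extracting patches from an image. Your task is to write a function that extracts a patch of a given size, starting at a given location, from an image tensor of shape (channels, height, width). Use a combination of `torch.select` and `torch.narrow` to perform this operation. `torch.narrow(input, dim, start, length)` keeps `length` consecutive entries along `dim`, beginning at `start`. `torch.select(input, dim, index)` picks a single index along `dim` and removes that dimension. Both return views that share memory with the input. [3]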
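The following sketch shows how the two primitives change shapes on a small toy tensor. The tensor and its sizes are illustrative assumptions and are not part of the kata. It also shows that `narrow` returns a view: writing through it changes the original tensor.

```python
import torch

# Toy (channels, height, width) tensor: 2 channels, 4 rows, 5 columns
image = torch.arange(2 * 4 * 5).reshape(2, 4, 5)

# select(input, dim, index): one index along dim, and that dim disappears
channel0 = torch.select(image, 0, 0)   # same values as image[0]
print(channel0.shape)                  # torch.Size([4, 5])

# narrow(input, dim, start, length): `length` entries along dim, dim is kept
rows = torch.narrow(image, 1, 1, 2)    # same values as image[:, 1:3]
print(rows.shape)                      # torch.Size([2, 2, 5])

# narrow returns a view, so writing through it modifies `image` itself
rows[0, 0, 0] = -1                     # rows[0, 0, 0] is image[0, 1, 0]
print(image[0, 1, 0].item())           # -1
```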

Function Signature:

```python
import torch

def extract_patch(image, start_h, start_w, patch_height, patch_width):
    # image: (channels, height, width)
    # Your implementation here
    pass
```

Verification:
- The output tensor should have the shape (channels, patch_height, patch_width).
- The content of the output patch should match the corresponding region of the input image, i.e. image[:, start_h:start_h + patch_height, start_w:start_w + patch_width].
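Here is a minimal sketch of one possible solution. It assumes `start_h` and `start_w` are non-negative and that the patch fits inside the image; `torch.narrow` raises an error if `start + length` exceeds the dimension size. The function uses `torch.select` to take each channel as an (H, W) view, applies `torch.narrow` to its rows and columns, and then restacks the channels. It checks both verification criteria against ordinary slicing. The helper `extract_patch_narrow_only` is a hypothetical alternative added for comparison.

```python
import torch


def extract_patch(image, start_h, start_w, patch_height, patch_width):
    # image: (channels, height, width)
    if image.dim() != 3:
        raise ValueError(f"expected (C, H, W), got {tuple(image.shape)}")
    planes = []
    for c in range(image.size(0)):
        # select drops the channel dim: an (H, W) view of channel c
        plane = torch.select(image, 0, c)
        # keep patch_height rows starting at start_h (still a view)
        plane = torch.narrow(plane, 0, start_h, patch_height)
        # keep patch_width columns starting at start_w (still a view)
        plane = torch.narrow(plane, 1, start_w, patch_width)
        planes.append(plane)
    # stack restores the channel dim -> (C, patch_height, patch_width); this copies
    return torch.stack(planes, dim=0)


def extract_patch_narrow_only(image, start_h, start_w, patch_height, patch_width):
    # Hypothetical alternative: narrow dims 1 and 2 directly; returns a view, no copy
    return image.narrow(1, start_h, patch_height).narrow(2, start_w, patch_width)


# Check the verification criteria, using plain slicing as the reference
image = torch.arange(3 * 5 * 6, dtype=torch.float32).reshape(3, 5, 6)
patch = extract_patch(image, start_h=1, start_w=2, patch_height=3, patch_width=4)

assert patch.shape == (3, 3, 4)
assert torch.equal(patch, image[:, 1:4, 2:6])
assert torch.equal(patch, extract_patch_narrow_only(image, 1, 2, 3, 4))
print(patch.shape)  # torch.Size([3, 3, 4])
```

The per-channel loop is there only to exercise `torch.select` as the exercise asks. Narrowing dimensions 1 and 2 directly yields the same values as a view and avoids the copy that `torch.stack` makes.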