ML Katas

Selecting RoIs (Regions of Interest) with `index_select`

medium (<30 mins) vision index_select
this month by E

In object detection, after a region proposal network (RPN) suggests candidate object locations, the features for these regions of interest (RoIs) must be extracted from the feature map for further processing. Your task is to write a function that takes a batch of feature maps of shape `(batch_size, num_channels, height, width)` and a 1-D tensor `roi_batch_indices` of shape `(num_rois,)`, and returns the feature map of the source image for each RoI. Each RoI belongs to exactly one image in the batch, and several RoIs may belong to the same image. Use `torch.index_select` for this. [1, 5]
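To see what `torch.index_select` does along the batch dimension, here is a small sketch on a toy tensor. The values and indices are arbitrary choices for illustration, not part of the kata. Selecting along `dim=0` copies whole `(C, H, W)` slices, and the same batch index may appear more than once.

```python
import torch

# Toy batch of 3 single-channel 2x2 feature maps, filled with 0..11
x = torch.arange(12, dtype=torch.float32).reshape(3, 1, 2, 2)

# Hypothetical batch indices for 4 RoIs: two from image 0, one each from images 2 and 1
idx = torch.tensor([0, 2, 0, 1])

# Gather along dim 0: each index pulls a full (C, H, W) slice; repeats are allowed
out = torch.index_select(x, 0, idx)

print(out.shape)        # torch.Size([4, 1, 2, 2])
print(out[:, 0, 0, 0])  # tensor([0., 8., 0., 4.])
```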

Code Skeleton:

```python
import torch

def select_rois(feature_maps, roi_batch_indices):
    # feature_maps: (batch_size, num_channels, height, width)
    # roi_batch_indices: (num_rois,) integer tensor indicating which image in the batch each RoI belongs to
    # Your implementation here
    pass
```

Verification:

- The output should have shape `(num_rois, num_channels, height, width)`.
- `output[i]` should equal `feature_maps[roi_batch_indices[i]]`, i.e. the `roi_batch_indices[i]`-th feature map of the input.
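One possible solution, sketched under the assumption that `roi_batch_indices` is an integer tensor. Casting it to `torch.long` and moving it to the input's device are defensive choices of this sketch, not requirements stated by the kata. The checks below exercise both verification conditions on random data with arbitrarily chosen sizes.

```python
import torch

def select_rois(feature_maps: torch.Tensor, roi_batch_indices: torch.Tensor) -> torch.Tensor:
    """Return the source-image feature map for every RoI.

    feature_maps:      (batch_size, num_channels, height, width)
    roi_batch_indices: (num_rois,) integer tensor of batch indices
    returns:           (num_rois, num_channels, height, width)
    """
    # Defensive: index_select expects a 1-D integer index on the same device as the input
    idx = roi_batch_indices.to(device=feature_maps.device, dtype=torch.long)
    # Gather along the batch dimension (dim 0); each RoI gets a full (C, H, W) map
    return torch.index_select(feature_maps, 0, idx)


# Verification on random data (sizes are arbitrary choices for the check)
torch.manual_seed(0)
batch_size, num_channels, height, width = 2, 3, 4, 5
feature_maps = torch.randn(batch_size, num_channels, height, width)
roi_batch_indices = torch.tensor([1, 0, 1, 1, 0])

out = select_rois(feature_maps, roi_batch_indices)

# Check 1: shape is (num_rois, num_channels, height, width)
assert out.shape == (roi_batch_indices.numel(), num_channels, height, width)
# Check 2: output[i] is exactly feature_maps[roi_batch_indices[i]]
for i, b in enumerate(roi_batch_indices.tolist()):
    assert torch.equal(out[i], feature_maps[b])
print("all checks passed")
```

`torch.index_select(feature_maps, 0, idx)` returns a new tensor rather than a view and gives the same result as advanced indexing `feature_maps[idx]`; gradients flow back to `feature_maps` through either form.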