Selecting RoIs (Regions of Interest) with `index_select`
In object detection, after a region proposal network (RPN) suggests candidate object locations, these regions of interest (RoIs) must be extracted from the feature map of the image they came from before further processing. Each RoI belongs to one specific image in the batch. Your task is to write a function that takes a batch of feature maps of shape (batch_size, num_channels, height, width) and a 1-D tensor `roi_batch_indices` of shape (num_rois,), giving the batch index of the image each RoI belongs to. The function should return the full feature map of that image for every RoI. Use `torch.index_select` along the batch dimension (dim=0) for this. [1, 5]
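A small sketch of how `torch.index_select` behaves along dimension 0. It uses a made-up batch of three 1-channel 2x2 maps in which every value of image j equals j; the sizes, values, and index list are illustrative assumptions. Indices may repeat and appear in any order, and the output follows the order of the index tensor.

```python
import torch

# Toy batch: 3 images, 1 channel, 2x2 spatial; every value in image j equals j.
feature_maps = torch.stack([torch.full((1, 2, 2), float(j)) for j in range(3)])

# Four hypothetical RoIs: two from image 0, one from image 2, one from image 1.
roi_batch_indices = torch.tensor([0, 2, 0, 1])

# index_select along dim 0 gathers whole images; repeated indices copy the same image.
selected = torch.index_select(feature_maps, 0, roi_batch_indices)

print(selected.shape)        # torch.Size([4, 1, 2, 2])
print(selected[:, 0, 0, 0])  # tensor([0., 2., 0., 1.])
```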
Code Skeleton:
```python
import torch


def select_rois(feature_maps, roi_batch_indices):
    # feature_maps: (batch_size, num_channels, height, width)
    # roi_batch_indices: (num_rois) tensor indicating which image in the batch each RoI belongs to.
    # Your implementation here
    pass
```
Verification:
- The output should have the shape (num_rois, num_channels, height, width).
- The i-th feature map in the output should equal the input feature map at batch index roi_batch_indices[i], i.e. output[i] == feature_maps[roi_batch_indices[i]].
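One possible way to fill in the skeleton, together with checks for both verification criteria, is sketched below. It assumes `roi_batch_indices` is a 1-D integer tensor; the cast to int64 and the move to the input's device are conveniences added here, not requirements stated by the exercise. The random shapes used in the check are arbitrary.

```python
import torch


def select_rois(feature_maps: torch.Tensor, roi_batch_indices: torch.Tensor) -> torch.Tensor:
    """Return the feature map of the owning image for each RoI.

    feature_maps: (batch_size, num_channels, height, width)
    roi_batch_indices: (num_rois,) integer tensor of batch indices
    returns: (num_rois, num_channels, height, width)
    """
    if roi_batch_indices.dim() != 1:
        raise ValueError("roi_batch_indices must be a 1-D tensor")
    # Cast to int64 and move to the input's device so index_select accepts the index.
    index = roi_batch_indices.to(device=feature_maps.device, dtype=torch.long)
    # Gather along the batch dimension (dim=0); repeated indices copy the same image.
    return torch.index_select(feature_maps, 0, index)


if __name__ == "__main__":
    torch.manual_seed(0)
    batch_size, num_channels, height, width = 2, 3, 4, 5
    feature_maps = torch.randn(batch_size, num_channels, height, width)
    roi_batch_indices = torch.tensor([1, 0, 1, 1, 0])

    out = select_rois(feature_maps, roi_batch_indices)

    # Check 1: output shape is (num_rois, num_channels, height, width).
    assert out.shape == (len(roi_batch_indices), num_channels, height, width)
    # Check 2: output[i] equals feature_maps[roi_batch_indices[i]].
    for i, j in enumerate(roi_batch_indices.tolist()):
        assert torch.equal(out[i], feature_maps[j])
    print("all checks passed")
```

Advanced indexing, `feature_maps[roi_batch_indices]`, gives the same result; the exercise asks for `torch.index_select` explicitly.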