ML Katas

Advanced Indexing with `gather` for NLP

Difficulty: medium (<30 mins) · Tags: NLP, gather, indexing

In Natural Language Processing, it is common to work with sequences of varying lengths. A frequent task is to extract the activations of the last token of each sequence from a tensor of shape (batch_size, sequence_length, hidden_size). Your exercise is to implement a function that takes such a 3D tensor and a 1D tensor of sequence lengths, and returns a 2D tensor of shape (batch_size, hidden_size) containing the final hidden state of each sequence in the batch. You should use torch.gather to achieve this.
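If `torch.gather` is new to you, the key fact is that it selects values along one dimension using an index tensor of the same rank as the input: `out[i][j] = x[i][idx[i][j]]` when gathering along `dim=1`. A minimal 2D warm-up (the tensors here are arbitrary illustration values):

```python
import torch

# Gather one element per row from a 2D tensor.
# For dim=1: out[i][j] = x[i][idx[i][j]]
x = torch.tensor([[10, 11, 12],
                  [20, 21, 22]])
idx = torch.tensor([[2], [0]])    # one column index per row, shape (2, 1)
out = torch.gather(x, 1, idx)     # shape (2, 1)
print(out.squeeze(1))             # tensor([12, 20])
```

The exercise is the 3D analogue: you need an index tensor shaped so that gathering along the sequence dimension picks out exactly one time step per batch element.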

Code Skeleton:

import torch

def get_last_hidden_states(hidden_states, sequence_lengths):
    # hidden_states: float tensor of shape (batch_size, sequence_length, hidden_size)
    # sequence_lengths: long tensor of shape (batch_size,), values in [1, sequence_length]
    # returns: tensor of shape (batch_size, hidden_size)
    # Your implementation here
    pass

Verification:
- The output tensor should have shape (batch_size, hidden_size).
- The i-th row of the output should equal hidden_states[i, sequence_lengths[i] - 1, :].
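To sanity-check your work, a small harness like the one below runs both verification conditions against random inputs. It includes one possible gather-based implementation for reference (a sketch, not the only valid answer) — skip reading the function body if you want to solve the kata unaided:

```python
import torch

def get_last_hidden_states(hidden_states, sequence_lengths):
    # One possible gather-based solution: build an index of shape
    # (batch_size, 1, hidden_size) that points at the last valid time step
    # of each sequence, then gather along the sequence dimension (dim=1).
    hidden_size = hidden_states.size(-1)
    idx = (sequence_lengths - 1).view(-1, 1, 1).expand(-1, 1, hidden_size)
    return hidden_states.gather(1, idx).squeeze(1)

torch.manual_seed(0)
batch, seq_len, hidden = 4, 7, 3
hs = torch.randn(batch, seq_len, hidden)
lengths = torch.tensor([7, 2, 5, 1])  # must lie in [1, seq_len]

out = get_last_hidden_states(hs, lengths)
assert out.shape == (batch, hidden)
for i in range(batch):
    assert torch.equal(out[i], hs[i, lengths[i] - 1, :])
print("all checks passed")
```

Note the `- 1`: sequence lengths are 1-based counts, while tensor indices are 0-based, so the last token of a sequence of length L lives at index L - 1.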