Advanced Indexing with `gather` for NLP
In Natural Language Processing, it's common to work with sequences of varying lengths. A frequent task is extracting the activations of the last token in each sequence from a tensor of shape (batch_size, sequence_length, hidden_size). Your exercise is to implement a function that takes such a 3D tensor and a 1D tensor of sequence lengths, and returns a 2D tensor of shape (batch_size, hidden_size) containing the final hidden state of each sequence in the batch. Use `torch.gather` to achieve this.
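If you haven't used `torch.gather` before, a quick illustration of its semantics along `dim=1` may help (the tensors here are made up purely for demonstration):

```python
import torch

# For dim=1, gather follows: out[i][j][k] = input[i][index[i][j][k]][k],
# so the index tensor selects a time step independently per (batch, feature).
x = torch.arange(12).view(2, 3, 2)        # (batch=2, seq=3, hidden=2)
idx = torch.tensor([[[2, 2]],             # batch 0: take time step 2
                    [[0, 0]]])            # batch 1: take time step 0
print(x.gather(dim=1, index=idx))
# tensor([[[ 4,  5]],
#         [[ 6,  7]]])
```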
Code Skeleton:
```python
import torch

def get_last_hidden_states(hidden_states, sequence_lengths):
    # hidden_states: (batch_size, sequence_length, hidden_size)
    # sequence_lengths: (batch_size,), number of valid tokens per sequence
    # Your implementation here
    pass
```
Verification:
- The output tensor should have the shape (batch_size, hidden_size).
- The i-th row of the output tensor should be equal to hidden_states[i, sequence_lengths[i] - 1, :].
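For reference, here is a minimal solution sketch. It assumes `sequence_lengths` holds 1-based token counts (so the last valid token of sequence i sits at index `sequence_lengths[i] - 1`), expands those indices to a (batch_size, 1, hidden_size) index tensor, and gathers along the sequence dimension:

```python
import torch

def get_last_hidden_states(hidden_states, sequence_lengths):
    # hidden_states: (batch_size, sequence_length, hidden_size)
    # sequence_lengths: (batch_size,), 1-based counts of valid tokens
    batch_size, _, hidden_size = hidden_states.shape
    # 0-based index of the last valid token per sequence: (batch_size,)
    last_indices = sequence_lengths - 1
    # Expand to (batch_size, 1, hidden_size) so gather selects one time
    # step per batch element across every hidden dimension.
    index = last_indices.view(batch_size, 1, 1).expand(batch_size, 1, hidden_size)
    # Gather along dim=1 (the sequence dimension), then drop the singleton.
    return hidden_states.gather(dim=1, index=index).squeeze(1)
```

A quick check against the verification criteria above:

```python
hidden_states = torch.randn(3, 5, 4)
sequence_lengths = torch.tensor([5, 2, 3])
out = get_last_hidden_states(hidden_states, sequence_lengths)
assert out.shape == (3, 4)
for i in range(3):
    assert torch.equal(out[i], hidden_states[i, sequence_lengths[i] - 1])
```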