Replicating `torch.nn.Embedding` with `gather`
The `torch.nn.Embedding` layer is fundamental to many deep learning models, especially in NLP. Your task is to replicate its forward pass using `torch.gather`. Write a function that takes an embedding weight matrix (the same tensor as `embedding.weight`) and a tensor of integer indices, and returns the corresponding embeddings. Doing so will help you understand the lookup mechanism underlying this layer. [7, 16]
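As a warm-up, here is a small sketch of how `torch.gather` behaves along dimension 0, where `out[i][j] = input[index[i][j]][j]`. The names `table`, `rows`, `index`, and `picked` are illustrative only and are not part of the exercise. The point is that `gather` selects element by element, so picking a whole row means repeating that row's index across every column.

```python
import torch

# Toy stand-in for a weight matrix: 4 rows (vocab) by 3 columns (embedding dim).
table = torch.arange(12.0).reshape(4, 3)

# Rows we want to pull out, in order.
rows = torch.tensor([2, 0])

# gather needs an index with the same number of dimensions as the input,
# so repeat each row index across all columns: shape (2, 3).
index = rows.unsqueeze(1).expand(-1, table.size(1))

# Along dim 0: picked[i][j] = table[index[i][j]][j]
picked = torch.gather(table, 0, index)
print(picked)
```

In this toy case the result matches plain row indexing, `table[rows]`.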
Function Signature:
import torch
def embedding_lookup(weights, indices):
    # weights: (vocab_size, embedding_dim)
    # indices: can be of any shape
    # Your implementation here
    pass
Verification:
- For a given set of weights and indices, your function's output should be identical to `torch.nn.Embedding.from_pretrained(weights)(indices)`.
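The check above can be turned into a small test harness. The sketch below pairs it with one possible gather-based implementation, assuming a 2-D `weights` tensor and integer `indices` with at least one dimension. It is an illustration, not the only valid solution, and the helper name `check` and the test shapes are made up for this example.

```python
import torch

def embedding_lookup(weights, indices):
    # One possible approach (a sketch): flatten the indices, gather rows, reshape back.
    embedding_dim = weights.size(1)
    flat = indices.reshape(-1)                            # (N,)
    # Repeat each index across the embedding dimension so gather copies a full row.
    index = flat.unsqueeze(1).expand(-1, embedding_dim)   # (N, embedding_dim)
    gathered = torch.gather(weights, 0, index)            # (N, embedding_dim)
    # Restore the original index shape, with embedding_dim appended.
    return gathered.reshape(*indices.shape, embedding_dim)

def check(weights, indices):
    # Hypothetical helper: compare against the reference layer.
    expected = torch.nn.Embedding.from_pretrained(weights)(indices)
    return torch.equal(embedding_lookup(weights, indices), expected)

torch.manual_seed(0)
weights = torch.randn(10, 4)
for shape in [(5,), (2, 3), (2, 3, 4)]:
    indices = torch.randint(0, 10, shape)
    assert check(weights, indices), f"mismatch for indices of shape {shape}"
print("all shapes match")
```

Because the lookup only copies values and does no arithmetic, exact equality via `torch.equal` is a reasonable check here.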