ML Katas

Replicating `torch.nn.Embedding` with `gather`

Difficulty: medium (<30 mins) · Tags: embeddings, NLP, gather

The `torch.nn.Embedding` layer is fundamental to many deep learning models, especially in NLP. Your task is to replicate its forward pass using `torch.gather`. Write a function that takes an embedding weight matrix (equivalent to `embedding.weight`) and a tensor of indices, and returns the corresponding embeddings. Working through this will help you understand the underlying mechanism of this crucial layer.
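A quick refresher on the key primitive: with `dim=0`, `torch.gather` computes `out[i][j] = input[index[i][j]][j]`, i.e. it selects one source row element-by-element. A minimal illustration (the values here are arbitrary, not part of the task):

import torch

src = torch.tensor([[10., 11.],
                    [20., 21.],
                    [30., 31.]])
idx = torch.tensor([[2, 2],
                    [0, 0]])
# With dim=0: out[i][j] = src[idx[i][j]][j]
print(torch.gather(src, 0, idx))
# tensor([[30., 31.],
#         [10., 11.]])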

Function Signature:

import torch

def embedding_lookup(weights, indices):
    # weights: float tensor of shape (vocab_size, embedding_dim)
    # indices: integer (long) tensor of any shape
    # returns: tensor of shape (*indices.shape, embedding_dim)
    # Your implementation here
    pass

Verification:

- For a given set of weights and indices, your function's output should be identical to `torch.nn.Embedding.from_pretrained(weights)(indices)`.
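Spoiler: if you want to check your work, here is one possible gather-based solution. This is a sketch, not the only correct answer; the helper names (`flat`, `idx`) are illustrative.

import torch

def embedding_lookup(weights, indices):
    # weights: (vocab_size, embedding_dim); indices: integer tensor, any shape
    flat = indices.reshape(-1)  # flatten to (N,)
    # Repeat each index across the embedding dimension so gather can
    # select a full row per index: idx has shape (N, embedding_dim).
    idx = flat.unsqueeze(-1).expand(-1, weights.size(1))
    # gather along dim 0: gathered[i][j] = weights[idx[i][j]][j]
    gathered = torch.gather(weights, 0, idx)
    # Restore the original index shape, with embedding_dim appended.
    return gathered.reshape(*indices.shape, weights.size(1))

A quick self-check against the reference layer:

vocab_size, embedding_dim = 10, 4
weights = torch.randn(vocab_size, embedding_dim)
indices = torch.randint(0, vocab_size, (2, 3))
ref = torch.nn.Embedding.from_pretrained(weights)(indices)
assert torch.equal(embedding_lookup(weights, indices), ref)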