ML Katas

Tensor Manipulation: One-Hot Encoding

easy (<10 mins) pytorch data tensor
today by E

Description

Implement one-hot encoding for a batch of class indices. Given a 1D tensor of integer labels, create a 2D tensor where each row is a vector of zeros except for a 1 at the index of the class.

Guidance

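torch.zeros is a good starting point. To place the 1s, use the in-place tensor method scatter_ (called as y.scatter_(dim, index, value); there is no torch.scatter_ function). It writes a value into a tensor at positions given by an index tensor, which must have dtype torch.int64 and the same number of dimensions as the tensor being updated.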
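
To make the indexing rule concrete, here is a small sketch of scatter_ on its own, separate from the kata solution. The names target and index are illustrative only. Along dimension 1, row i of index gives the column in row i of target that receives the value.

```python
import torch

# A 2x4 tensor of zeros to write into (float32 by default).
target = torch.zeros(2, 4)

# One column index per row; shape (2, 1), dtype int64.
index = torch.tensor([[3], [0]])

# For dim=1: target[i][index[i][j]] = 1.0
target.scatter_(1, index, 1.0)

print(target)
# Row 0 has its 1 in column 3, row 1 has its 1 in column 0:
# [[0., 0., 0., 1.],
#  [1., 0., 0., 0.]]
```

The trailing underscore marks the method as in-place: target itself is modified, and the same tensor is returned.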

Starter Code

import torch

def one_hot_encode(labels, num_classes):
    # labels: (B,) integer class indices in [0, num_classes)
    batch_size = labels.size(0)
    # 1. Create a zero tensor of shape (B, num_classes) on the same device as labels
    y_onehot = torch.zeros(batch_size, num_classes, device=labels.device)

    # 2. Use scatter_ to place 1s
    # The first argument to scatter_ is the dimension to index along
    # The second is the index tensor (int64, same number of dims as y_onehot)
    # The third is the value to scatter
    y_onehot.scatter_(1, labels.long().unsqueeze(1), 1)

    return y_onehot

Verification

Call one_hot_encode(torch.tensor([0, 2, 1]), num_classes=3). The result should be a float tensor of shape (3, 3) equal to [[1, 0, 0], [0, 0, 1], [0, 1, 0]].
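
The check above can be scripted, for example as in the sketch below. It repeats a compact version of the starter function so that it runs on its own, and it additionally compares against torch.nn.functional.one_hot as a reference. That comparison is an extra suggested here, not part of the original kata; F.one_hot returns an int64 tensor, so it is cast to float before comparing.

```python
import torch
import torch.nn.functional as F

def one_hot_encode(labels, num_classes):
    # Compact copy of the starter code, repeated so this script is self-contained.
    y_onehot = torch.zeros(labels.size(0), num_classes, device=labels.device)
    y_onehot.scatter_(1, labels.long().unsqueeze(1), 1)
    return y_onehot

labels = torch.tensor([0, 2, 1])
result = one_hot_encode(labels, num_classes=3)

# Expected output from the kata description.
expected = torch.tensor([[1., 0., 0.],
                         [0., 0., 1.],
                         [0., 1., 0.]])
assert torch.equal(result, expected)

# Cross-check against PyTorch's built-in one-hot helper.
reference = F.one_hot(labels, num_classes=3).float()
assert torch.equal(result, reference)

print(result)
```

Each row should also sum to exactly 1, which is a quick sanity check for larger batches where comparing by eye is impractical.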