Tensor Manipulation: One-Hot Encoding
Description
Implement one-hot encoding for a batch of class indices. Given a 1D tensor of B integer labels, create a 2D tensor of shape (B, num_classes) in which each row is all zeros except for a 1 at the column given by that row's label.
Guidance
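torch.zeros is a good way to start: it creates the all-zero (B, num_classes) output. To place the 1s, use the in-place method Tensor.scatter_, which writes a value into a tensor at positions given by an index tensor. It places every 1 in a single vectorized call instead of a Python loop over the batch.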
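To make the behavior of scatter_ along dim=1 concrete, the sketch below compares it with an explicit loop that applies the rule target[i][index[i][j]] = value to a 2D tensor. The helper scatter_rows_loop is hypothetical and exists only for illustration; it is not part of PyTorch.

```python
import torch

def scatter_rows_loop(target, index, value):
    # Hypothetical helper: loop equivalent of target.scatter_(1, index, value)
    # for a 2D target, i.e. target[i][index[i][j]] = value.
    out = target.clone()
    for i in range(index.size(0)):
        for j in range(index.size(1)):
            out[i, int(index[i, j])] = value
    return out

labels = torch.tensor([0, 2, 1])
index = labels.unsqueeze(1)   # shape (3,) -> (3, 1), dtype int64
base = torch.zeros(3, 3)

vectorized = base.clone().scatter_(1, index, 1.0)
looped = scatter_rows_loop(base, index, 1.0)
print(torch.equal(vectorized, looped))
```

Both versions write the same 1s; scatter_ just does it without the Python-level loop.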
Starter Code
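```python
import torch

def one_hot_encode(labels, num_classes):
    # labels: (B,) integer class indices in [0, num_classes)
    labels = labels.long()  # scatter_ requires an int64 index tensor
    batch_size = labels.size(0)
    # 1. Create a zero tensor of shape (B, num_classes) on the same device as labels
    y_onehot = torch.zeros(batch_size, num_classes, device=labels.device)
    # 2. Use scatter_ to place the 1s
    # The first argument to scatter_ is the dimension to scatter along (1 = columns)
    # The second is the index tensor; it must have as many dimensions as
    # y_onehot, so labels is reshaped from (B,) to (B, 1)
    # The third is the value to write at each indexed position
    y_onehot.scatter_(1, labels.unsqueeze(1), 1.0)
    return y_onehot
```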
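Advanced indexing offers an equivalent way to write the same function: pair each row number with that row's label and assign 1 at those (row, column) positions. The name one_hot_encode_indexing is hypothetical; this is an alternative sketch, not part of the exercise's starter code.

```python
import torch

def one_hot_encode_indexing(labels, num_classes):
    # Alternative sketch: write 1s at (row, labels[row]) via advanced indexing.
    labels = labels.long()
    batch_size = labels.size(0)
    y_onehot = torch.zeros(batch_size, num_classes, device=labels.device)
    rows = torch.arange(batch_size, device=labels.device)
    y_onehot[rows, labels] = 1.0
    return y_onehot

print(one_hot_encode_indexing(torch.tensor([0, 2, 1]), 3))
```

Both forms are vectorized; scatter_ generalizes to other dimensions and multiple indices per row, while the indexing form reads closer to the definition.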
Verification
Create a tensor of labels such as torch.tensor([0, 2, 1]) with num_classes=3. The output should be [[1, 0, 0], [0, 0, 1], [0, 1, 0]], stored as a float tensor because torch.zeros defaults to float32.
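A minimal check of this case could look like the following. It re-defines one_hot_encode with the same scatter_-based approach so it runs on its own, and cross-checks the result against PyTorch's built-in torch.nn.functional.one_hot. The built-in returns an int64 tensor, so it is cast to float before comparing.

```python
import torch
import torch.nn.functional as F

def one_hot_encode(labels, num_classes):
    # Same scatter_-based approach as the starter code.
    labels = labels.long()
    y_onehot = torch.zeros(labels.size(0), num_classes, device=labels.device)
    y_onehot.scatter_(1, labels.unsqueeze(1), 1.0)
    return y_onehot

labels = torch.tensor([0, 2, 1])
result = one_hot_encode(labels, num_classes=3)

expected = torch.tensor([[1., 0., 0.],
                         [0., 0., 1.],
                         [0., 1., 0.]])
assert torch.equal(result, expected)

# Cross-check against the built-in, which returns int64 values.
builtin = F.one_hot(labels, num_classes=3).float()
assert torch.equal(result, builtin)
print("one-hot encoding verified")
```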