ML Katas

Implementing One-Hot Encoding with `scatter_`

easy (<10 mins) scatter one-hot encoding
this month by E

Your task is to create a function that performs one-hot encoding on a tensor of integer labels. One-hot encoding is a common preprocessing step for categorical data in machine learning. You will be given a 1D tensor of labels and the total number of classes. Your function should return a 2D tensor in which row i has a 1 at column labels[i] and 0s everywhere else. Use the in-place `Tensor.scatter_` method to perform this operation efficiently.
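As a hedged illustration of the `scatter_` semantics this task relies on: with `dim=1`, `Tensor.scatter_(dim, index, value)` writes `value` into `self[i][index[i][j]]` for every position (i, j) of `index`. The toy tensors below are chosen only for illustration. They scatter two positions per row, which shows why an index with exactly one column per row produces one-hot rows.

```python
import torch

# With dim=1, scatter_ writes `value` at target[i][index[i][j]]
# for every (i, j) position in `index`.
target = torch.zeros(2, 5)
index = torch.tensor([[0, 3],
                      [1, 4]])  # int64 index, two column positions per row
target.scatter_(1, index, 1.0)
print(target)
# tensor([[1., 0., 0., 1., 0.],
#         [0., 1., 0., 0., 1.]])
```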

Example: Given labels = torch.tensor([0, 2, 1, 3]) and num_classes = 4, the output should be:

tensor([[1., 0., 0., 0.],
        [0., 0., 1., 0.],
        [0., 1., 0., 0.],
        [0., 0., 0., 1.]])
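A minimal sketch of one possible solution, assuming `labels` is a 1D integer tensor whose values lie in [0, num_classes). The function name `one_hot` is our own choice, not something the kata prescribes. The key step is reshaping the labels from shape (N,) to (N, 1), so that each row scatters a single 1 into the column given by its label; for the example input, label 3 puts a 1 in the last column of the fourth row.

```python
import torch

def one_hot(labels: torch.Tensor, num_classes: int) -> torch.Tensor:
    """One-hot encode a 1D tensor of integer labels using Tensor.scatter_."""
    # Start from an all-zero float matrix of shape (N, num_classes).
    out = torch.zeros(labels.size(0), num_classes, device=labels.device)
    # scatter_ along dim=1 needs a 2D int64 index with the same number of
    # rows as `out`, so reshape the labels from (N,) to (N, 1).
    out.scatter_(1, labels.long().unsqueeze(1), 1.0)
    return out

labels = torch.tensor([0, 2, 1, 3])
print(one_hot(labels, num_classes=4))
# tensor([[1., 0., 0., 0.],
#         [0., 0., 1., 0.],
#         [0., 1., 0., 0.],
#         [0., 0., 0., 1.]])
```

PyTorch also ships `torch.nn.functional.one_hot`, which computes the same encoding but returns an int64 tensor; the `scatter_` version produces float output directly, matching the expected output above.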

Verification:
- The output tensor should have the shape (len(labels), num_classes).
- For each row in the output, the element at the index specified by the corresponding input label should be 1.
- All other elements in each row should be 0.
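The verification conditions can be turned into assertions. The sketch below uses a hypothetical helper, `check_one_hot`, that tests each property in turn; the final cross-check against `torch.nn.functional.one_hot` goes beyond the kata's stated checks and is included only as an extra sanity test.

```python
import torch
import torch.nn.functional as F

def check_one_hot(out: torch.Tensor, labels: torch.Tensor, num_classes: int) -> None:
    """Hypothetical helper: assert the kata's verification properties."""
    n = labels.size(0)
    # Shape must be (len(labels), num_classes).
    assert out.shape == (n, num_classes)
    # The entry at each row's label index must be 1.
    rows = torch.arange(n)
    assert torch.all(out[rows, labels] == 1)
    # Every other entry must be 0: mask out the label positions and check the rest.
    mask = torch.ones_like(out, dtype=torch.bool)
    mask[rows, labels] = False
    assert torch.all(out[mask] == 0)

labels = torch.tensor([0, 2, 1, 3])
# Build the encoding inline with scatter_ (in-place ops return the tensor itself).
out = torch.zeros(len(labels), 4).scatter_(1, labels.unsqueeze(1), 1.0)
check_one_hot(out, labels, 4)
# Extra cross-check: F.one_hot returns int64, so cast before comparing.
assert torch.equal(out, F.one_hot(labels, num_classes=4).float())
print("all checks passed")
```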