ML Katas

Tensor Manipulation: Using `scatter_add_`

medium (<10 mins) pytorch tensor advanced
today by E

Description

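`Tensor.scatter_add_` (the in-place method; `torch.scatter_add` is the out-of-place form) adds values from a source tensor into a target tensor at the positions given by an index tensor. When several source entries share the same index, their values accumulate instead of overwriting each other. This makes it useful for tasks such as converting a graph's edge list into an adjacency matrix or sum-pooling features by group. Your task is to use it to sum all values that share the same index.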
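As an illustration of the edge-list use case, here is a minimal sketch. It assumes the edges arrive as an (E, 2) int64 tensor of (source, target) pairs, and `edge_list_to_adjacency` is a hypothetical helper name. Each (row, col) pair is flattened into a single index into an n*n vector so a 1-D `scatter_add_` can be used. Duplicate edges are counted, not deduplicated.

```python
import torch

def edge_list_to_adjacency(edges: torch.Tensor, num_nodes: int) -> torch.Tensor:
    # Hypothetical helper. edges: (E, 2) int64 tensor of (source, target) pairs.
    src, dst = edges[:, 0], edges[:, 1]
    # Flatten each (row, col) pair into one index into a num_nodes*num_nodes vector.
    flat_index = src * num_nodes + dst
    # One unit of weight per edge; repeated edges accumulate.
    ones = torch.ones(flat_index.shape[0], dtype=torch.long)
    adj = torch.zeros(num_nodes * num_nodes, dtype=torch.long)
    adj.scatter_add_(0, flat_index, ones)
    return adj.view(num_nodes, num_nodes)

edges = torch.tensor([[0, 1], [1, 2], [0, 1], [2, 0]])
print(edge_list_to_adjacency(edges, 3).tolist())
# → [[0, 2, 0], [0, 0, 1], [1, 0, 0]]
```

The edge (0, 1) appears twice, so entry [0][1] ends up as 2.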

Guidance

Imagine you have a tensor of values and a corresponding tensor of indices. You want to create an output tensor where output[i] is the sum of all values whose corresponding index is i.
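To pin down the intended semantics, a plain-Python loop that produces the same result is shown below. It is only a reference for checking answers and is not meant as the solution; `sum_by_index_loop` is a hypothetical name.

```python
import torch

def sum_by_index_loop(values, indices, output_size):
    # Reference semantics: walk the (value, index) pairs and accumulate by index.
    output = [0] * output_size
    for v, i in zip(values.tolist(), indices.tolist()):
        output[i] += v
    return output

values = torch.tensor([1, 2, 3, 4, 5])
indices = torch.tensor([0, 1, 0, 2, 1])
print(sum_by_index_loop(values, indices, 3))  # → [4, 7, 4]
```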

Starter Code

import torch

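def sum_by_index(values, indices, output_size):
    # values: (N,) tensor of numbers to accumulate
    # indices: (N,) int64 tensor; indices[k] is the output slot for values[k]
    # scatter_add_ requires output and values to share a dtype, so match it
    # (torch.zeros alone defaults to float32 and would reject int64 values).
    output = torch.zeros(output_size, dtype=values.dtype, device=values.device)
    # The index tensor for scatter_add_ needs to have the same
    # number of dimensions as both output and values.
    # Along dim 0 this performs output[indices[k]] += values[k] for every k.
    return output.scatter_add_(0, indices, values)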
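The same call extends to the pooling case from the description. This is a minimal sketch that assumes an (N, D) feature matrix and one int64 segment id per row, for example which graph each node belongs to. `segment_sum_rows` is a hypothetical name. The 1-D ids are expanded to the full (N, D) shape of the features, which illustrates the dimension requirement noted in the starter code's comment.

```python
import torch

def segment_sum_rows(features, segment_ids, num_segments):
    # Hypothetical helper. features: (N, D); segment_ids: (N,) int64.
    # Expand ids to (N, D) so index has the same number of dimensions as features.
    index = segment_ids.unsqueeze(1).expand(-1, features.size(1))
    out = torch.zeros(num_segments, features.size(1),
                      dtype=features.dtype, device=features.device)
    # Along dim 0: out[index[n][d]][d] += features[n][d]
    return out.scatter_add_(0, index, features)

features = torch.tensor([[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]])
segment_ids = torch.tensor([0, 1, 0])
print(segment_sum_rows(features, segment_ids, 2).tolist())
# → [[4.0, 40.0], [2.0, 20.0]]
```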

Verification

Create values = torch.tensor([1, 2, 3, 4, 5]) and indices = torch.tensor([0, 1, 0, 2, 1]). With output_size=3, your function should return [1+3, 2+5, 4], which is [4, 7, 4].
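A minimal check of the verification case follows. It restates the starter function with the output dtype matched to `values`, because `scatter_add_` rejects int64 `values` added into the default float32 zeros. The cross-check uses `Tensor.index_add`, a separate PyTorch method that performs the same accumulation for 1-D inputs. It is included only for comparison, not as the kata's intended solution.

```python
import torch

def sum_by_index(values, indices, output_size):
    # Output dtype/device follow values so scatter_add_ accepts the inputs.
    output = torch.zeros(output_size, dtype=values.dtype, device=values.device)
    return output.scatter_add_(0, indices, values)

values = torch.tensor([1, 2, 3, 4, 5])
indices = torch.tensor([0, 1, 0, 2, 1])
result = sum_by_index(values, indices, 3)
assert torch.equal(result, torch.tensor([4, 7, 4]))

# Cross-check against the out-of-place index_add alternative.
alt = torch.zeros(3, dtype=values.dtype).index_add(0, indices, values)
assert torch.equal(result, alt)
print(result.tolist())  # → [4, 7, 4]
```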