Tensor Manipulation: Using `scatter_add_`
Description
`torch.Tensor.scatter_add_` adds values into a tensor at specified indices, in place. It is useful for tasks such as converting a graph's edge list to an adjacency matrix, or for pooling operations. Your task is to use it to sum values that share the same index.
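As a rough illustration of the edge-list use case mentioned above, the sketch below counts edges into a dense adjacency matrix. The helper name `edge_list_to_adjacency` and the choice to flatten each `(source, target)` pair into one index are assumptions made for this example, not part of the exercise.

```python
import torch

def edge_list_to_adjacency(edges, num_nodes):
    # Hypothetical helper: edges is an (E, 2) int64 tensor of (source, target) pairs.
    src, dst = edges[:, 0], edges[:, 1]
    # Map each (row, col) pair to a position in a flattened num_nodes * num_nodes buffer.
    flat_index = src * num_nodes + dst
    adjacency = torch.zeros(num_nodes * num_nodes, dtype=torch.int64)
    # Add 1 for every edge; repeated edges accumulate instead of overwriting.
    adjacency.scatter_add_(0, flat_index, torch.ones_like(flat_index))
    return adjacency.view(num_nodes, num_nodes)

edges = torch.tensor([[0, 1], [0, 1], [2, 0]])
adj = edge_list_to_adjacency(edges, 3)
print(adj)
```

Because the edge (0, 1) appears twice, `adj[0, 1]` ends up as 2, which is the accumulation behaviour the exercise relies on.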
Guidance
Imagine you have a tensor of values and a corresponding tensor of indices. You want to create an output tensor where `output[i]` is the sum of all values whose corresponding index is `i`.
Starter Code
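import torch

def sum_by_index(values, indices, output_size):
    # values: (N,)
    # indices: (N,), dtype torch.int64
    # scatter_add_ requires output and values to share a dtype, so match it here
    # (a default float32 output would raise an error for integer values).
    output = torch.zeros(output_size, dtype=values.dtype, device=values.device)
    # The index tensor for scatter_add_ needs to have the same
    # number of dimensions as the tensor it's modifying.
    return output.scatter_add_(0, indices, values)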
Verification
Create `values = torch.tensor([1, 2, 3, 4, 5])` and `indices = torch.tensor([0, 1, 0, 2, 1])`. With `output_size=3`, your function should return `[1+3, 2+5, 4]`, which is `[4, 7, 4]`.
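The check described above can be written as a short script. This is a minimal sketch that assumes the output tensor is created with the same dtype as `values`, since `scatter_add_` expects the source and destination dtypes to match; the variable names `result` and `expected` are illustrative.

```python
import torch

def sum_by_index(values, indices, output_size):
    # Assumption: allocate the output with the dtype of values so integer inputs work.
    output = torch.zeros(output_size, dtype=values.dtype)
    return output.scatter_add_(0, indices, values)

values = torch.tensor([1, 2, 3, 4, 5])
indices = torch.tensor([0, 1, 0, 2, 1])

result = sum_by_index(values, indices, output_size=3)
expected = torch.tensor([4, 7, 4])

# torch.equal checks that shape and every element match.
assert torch.equal(result, expected)
print(result)  # tensor([4, 7, 4])
```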