ML Katas

Sparse Updates with `scatter_add_`

Difficulty: medium (<30 mins) · Tags: scatter_add, sparse
this month by E

In graph neural networks and other sparse-data applications, you often need to update a tensor at a sparse set of indices. Your exercise is to implement a function that takes a tensor of values, a tensor of indices, and an output tensor, and adds each value to the output tensor at its corresponding index. When an index appears more than once, all of its values accumulate. This is akin to a sparse addition. You must use the in-place method `Tensor.scatter_add_` (for example, `output.scatter_add_(0, indices, values)`).
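A minimal sketch of the required function, assuming 1-D tensors. The name `sparse_add` and its argument order are hypothetical, chosen to mirror the task description rather than any library API. It casts `indices` to `int64` and `values` to `output`'s dtype, because `scatter_add_` expects an `int64` index and a source with the same dtype as the tensor being updated.

```python
import torch

def sparse_add(values: torch.Tensor, indices: torch.Tensor, output: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: output[indices[i]] += values[i] for every i, in place."""
    # scatter_add_ requires an int64 index tensor ...
    indices = indices.to(torch.long)
    # ... and a source tensor whose dtype matches the tensor being updated.
    values = values.to(output.dtype)
    # dim=0 on 1-D tensors: repeated indices accumulate instead of overwriting.
    output.scatter_add_(0, indices, values)
    return output  # same object that was passed in

out = torch.zeros(3)
sparse_add(torch.tensor([5.0, 1.0, 2.0]), torch.tensor([2, 0, 2]), out)
print(out)  # tensor([1., 0., 7.])
```

Plain advanced indexing (`output[indices] += values`) is not a substitute: with repeated indices it writes rather than accumulates, so only one of the duplicate contributions survives, whereas `scatter_add_` sums all of them.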

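Example: given `output = torch.zeros(5)`, `indices = torch.tensor([1, 3, 1, 4])`, and `values = torch.tensor([1., 2., 3., 4.])`, the output tensor after the operation should be `tensor([0., 4., 0., 2., 4.])`. Index 1 receives 1 + 3 = 4, index 3 receives 2, index 4 receives 4, and indices 0 and 2 stay at zero. Note that `values` must be floating point here, because `scatter_add_` requires it to have the same dtype as `output` (`torch.zeros` defaults to `float32`).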
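The example can be reproduced directly. This sketch creates `values` as `float32` so that it matches the dtype of `output`; with integer values, `scatter_add_` would raise a dtype-mismatch error instead.

```python
import torch

output = torch.zeros(5)                      # float32 by default
indices = torch.tensor([1, 3, 1, 4])         # int64, as scatter_add_ expects
values = torch.tensor([1.0, 2.0, 3.0, 4.0])  # float32, matching output

# For each i along dim 0: output[indices[i]] += values[i]
output.scatter_add_(0, indices, values)
print(output)  # tensor([0., 4., 0., 2., 4.])
```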

Verification:
- The output tensor should hold the correct accumulated values at the specified indices.
- Entries at indices that do not appear in the indices tensor should remain unchanged.
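Both checks can be automated. The sketch below compares `scatter_add_` against a plain Python loop (a hypothetical reference, `reference_sparse_add`) on random data, then confirms that untouched positions keep their original values. It assumes 1-D `float64` tensors; the indices are drawn only from 0..7, so positions 8 and 9 are guaranteed to be untouched.

```python
import torch

def reference_sparse_add(values, indices, output):
    """Hypothetical loop reference: the semantics scatter_add_ should match."""
    result = output.clone()
    for v, i in zip(values.tolist(), indices.tolist()):
        result[i] += v
    return result

torch.manual_seed(0)
original = torch.randn(10, dtype=torch.float64)
indices = torch.randint(0, 8, (25,))         # never hits positions 8 and 9
values = torch.randn(25, dtype=torch.float64)

output = original.clone()
output.scatter_add_(0, indices, values)

# Check 1: accumulated values match the loop (allclose: summation order may differ).
assert torch.allclose(output, reference_sparse_add(values, indices, original))

# Check 2: positions never named in `indices` are unchanged.
untouched = torch.ones(10, dtype=torch.bool)
untouched[indices] = False
assert torch.equal(output[untouched], original[untouched])
print("all checks passed")
```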