ML Katas

Efficiently Updating a Sub-Tensor with `scatter_`

hard (<1 hr) scatter indexing
this month by E

Sometimes you need to update part of a larger tensor with values from a smaller one, where the locations to write are given by an index tensor. This pattern appears in attention mechanisms and sparse updates, among other algorithms. Your exercise is to implement a function that takes a source tensor, a destination tensor, and an indices tensor, and writes the values of the source tensor into the destination tensor at the given indices. Use `scatter_` (the in-place method `Tensor.scatter_`, called as `destination.scatter_(dim, index, src)`) to perform the update. Note that `scatter_` expects `index` to have the same number of dimensions as `src`, so a 1-D list of row indices must first be expanded to the shape of the source tensor. [17]
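A minimal sketch of one possible solution follows. It assumes `indices` is 1-D and holds one position per slice of `source` along `dim`, and that `source` matches `destination` in every other dimension; the helper name `scatter_rows_` is hypothetical. Along `dim=0`, `scatter_` performs `self[index[i][j]][j] = src[i][j]`, which is why the row indices are reshaped and expanded to the shape of `source`. Casting `source` to the destination's dtype is a choice made here so that integer inputs are accepted.

```python
import torch


def scatter_rows_(destination: torch.Tensor,
                  source: torch.Tensor,
                  indices: torch.Tensor,
                  dim: int = 0) -> torch.Tensor:
    """Hypothetical helper: write each slice of `source` along `dim` into
    `destination` at position indices[k], modifying `destination` in place.

    Assumes `indices` is 1-D with one entry per slice of `source` along `dim`.
    """
    # scatter_ needs an index with as many dims as source: put the 1-D
    # indices along `dim` (size 1 elsewhere), then broadcast to source's shape.
    shape = [1] * source.dim()
    shape[dim] = -1
    index = indices.long().reshape(shape).expand_as(source)
    # For dim=0 this performs destination[index[i][j]][j] = source[i][j].
    # scatter_ requires matching dtypes, so cast source to destination's dtype.
    return destination.scatter_(dim, index, source.to(destination.dtype))
```

With `dim=1`, the same reshaping lays `indices` along the columns, so column `j` of `source` lands in column `indices[j]` of `destination`.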

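Example:

destination = torch.zeros(5, 3)
source = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
indices = torch.tensor([1, 4])

Here `source` is written with float literals because `scatter_` requires it to share the dtype of `destination`. After the update, the destination tensor should be: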

tensor([[0., 0., 0.],
        [1., 2., 3.],
        [0., 0., 0.],
        [0., 0., 0.],
        [4., 5., 6.]])
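The example can be reproduced directly. This sketch uses float literals for `source` (an assumption, so its dtype matches `destination`) and, as a cross-check, compares the result against `Tensor.index_copy_`, a different method that copies whole slices for a 1-D index; the exercise itself still asks for `scatter_`.

```python
import torch

# Inputs from the example; source is float so its dtype matches destination.
destination = torch.zeros(5, 3)
source = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
indices = torch.tensor([1, 4])

# Expand the row indices to source's shape: [[1, 1, 1], [4, 4, 4]].
index = indices.unsqueeze(1).expand_as(source)
destination.scatter_(0, index, source)

expected = torch.tensor([[0., 0., 0.],
                         [1., 2., 3.],
                         [0., 0., 0.],
                         [0., 0., 0.],
                         [4., 5., 6.]])
print(torch.equal(destination, expected))  # True

# Cross-check: index_copy_ copies whole rows for a 1-D index and should agree.
alt = torch.zeros(5, 3).index_copy_(0, indices, source)
print(torch.equal(alt, destination))  # True
```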

Verification:
- The destination tensor should be modified in place.
- The rows (or, more generally, the slices along the specified dimension) of the destination tensor selected by indices should equal the corresponding rows (or slices) of the source tensor.
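One way to automate these checks is sketched below, with hypothetical names (`verify_row_update`, `scatter_solution`, `out_of_place_solution`) and limited to `dim=0` on the example tensors. The in-place criterion is tested by confirming that the caller's tensor keeps its storage and ends up holding the new rows; an out-of-place `Tensor.scatter` candidate is included to show the check failing.

```python
import torch


def verify_row_update(update_fn) -> bool:
    """Hypothetical checker: run update_fn(destination, source, indices) on the
    example tensors and test both verification criteria (dim=0 only)."""
    destination = torch.zeros(5, 3)
    source = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
    indices = torch.tensor([1, 4])
    storage_before = destination.data_ptr()

    update_fn(destination, source, indices)  # any return value is ignored

    # In place: the caller's tensor keeps the same storage...
    same_storage = destination.data_ptr() == storage_before
    # ...and the rows named by indices now equal the corresponding source rows.
    rows_match = torch.equal(destination.index_select(0, indices), source)
    # Every other row must be left untouched (still zero here).
    mask = torch.ones(destination.size(0), dtype=torch.bool)
    mask[indices] = False
    others_untouched = bool((destination[mask] == 0).all())
    return same_storage and rows_match and others_untouched


def scatter_solution(destination, source, indices):
    # Candidate that updates in place: expand row indices, then scatter_.
    destination.scatter_(0, indices.unsqueeze(1).expand_as(source), source)


def out_of_place_solution(destination, source, indices):
    # Returns a new tensor and leaves destination unchanged, so it should fail.
    return destination.scatter(0, indices.unsqueeze(1).expand_as(source), source)


print(verify_row_update(scatter_solution))       # True
print(verify_row_update(out_of_place_solution))  # False
```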