ML Katas

Custom Dataset Class

Difficulty: medium (<30 mins). Tags: dataloader, nn, datasets
this month by E

Create a custom PyTorch Dataset for pairs of numbers and their sum.

Problem

Implement a dataset where each sample is (x, y, x+y).

  • Input: A list of tuples (x, y).
  • Output: For index i, return (x, y, x+y).

Example

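from torch.utils.data import Dataset

class SumDataset(Dataset):
    """Map-style dataset: each item is (x, y, x + y)."""

    def __init__(self, data):
        # data: list of (x, y) tuples
        self.data = data

    def __len__(self):
        # Number of samples; DataLoader and samplers rely on this.
        return len(self.data)

    def __getitem__(self, idx):
        # Return the pair at position idx together with its sum.
        x, y = self.data[idx]
        return x, y, x + y

# Test
dataset = SumDataset([(1, 2), (3, 4)])
print(len(dataset))  # Expected: 2
print(dataset[0])    # Expected: (1, 2, 3)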
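Because the kata is tagged with dataloader, here is a minimal sketch of feeding the dataset to a torch.utils.data.DataLoader. It assumes the default collate function, which turns each position of the returned tuple into a 1-D int64 tensor, so a batch unpacks into three tensors. The pairs below are made-up illustration data.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class SumDataset(Dataset):
    """Map-style dataset yielding (x, y, x + y) for each stored pair."""

    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        x, y = self.data[idx]
        return x, y, x + y


pairs = [(1, 2), (3, 4), (5, 6), (7, 8)]  # illustrative data
# shuffle=False keeps batches in index order so the output is easy to check.
loader = DataLoader(SumDataset(pairs), batch_size=2, shuffle=False)

batches = []
for xs, ys, sums in loader:
    # Default collate stacks each tuple position into its own tensor.
    assert torch.equal(xs + ys, sums)
    batches.append(sums.tolist())

print(batches)  # [[3, 7], [11, 15]]
```

The dataset itself never deals with batching; the DataLoader calls __len__ and __getitem__ and leaves the stacking to the collate function.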

Solution Sketch

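Subclass torch.utils.data.Dataset and implement __len__ and __getitem__. This is the standard pattern for map-style datasets: __len__ reports how many samples exist, and __getitem__ returns the sample at a given index, which is all a DataLoader needs to shuffle and batch the data.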
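If the samples are meant to go straight into an nn model, one possible variant (a hypothetical TensorSumDataset, not part of the kata) stores the pairs as a single float32 tensor and returns (features, target) pairs instead of three plain numbers. This is a sketch of one design choice, not the required solution.

```python
import torch
from torch.utils.data import Dataset


class TensorSumDataset(Dataset):
    """Hypothetical variant: stores the pairs as one float tensor and returns
    (features, target) samples that a model can consume directly."""

    def __init__(self, data):
        # Shape (N, 2); float32 is the usual dtype for nn layers.
        self.data = torch.tensor(data, dtype=torch.float32)

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, idx):
        features = self.data[idx]  # tensor([x, y])
        target = features.sum()    # 0-dim tensor holding x + y
        return features, target


ds = TensorSumDataset([(1, 2), (3, 4)])
features, target = ds[1]
print(features.tolist(), target.item())  # [3.0, 4.0] 7.0
```

Converting once in __init__ avoids building a new tensor on every __getitem__ call, at the cost of requiring all pairs to fit in memory as a tensor.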