Custom Dataset Class
Create a custom PyTorch Dataset for pairs of numbers and their sum.
Problem
Implement a dataset where each sample is (x, y, x+y).
- Input: A list of tuples (x, y).
- Output: For index i, return (x, y, x+y).
Example
from torch.utils.data import Dataset


class SumDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        x, y = self.data[idx]
        return x, y, x + y


# Test
sample = SumDataset([(1, 2), (3, 4)])
print(sample[0])  # Expected: (1, 2, 3)
Solution Sketch
Subclass torch.utils.data.Dataset and implement __len__ (the number of samples) and __getitem__ (the sample at a given index). This is the standard pattern for custom map-style datasets.
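Because the class implements __len__ and __getitem__, it can be passed to torch.utils.data.DataLoader for batching. The following is a minimal sketch, not part of the original exercise: the pairs list, the batch size of 2, and the batches variable are illustrative assumptions, and it relies on the default collate behaviour, which stacks the integer fields of each sample into tensors.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class SumDataset(Dataset):
    """Map-style dataset returning (x, y, x + y) for each input pair."""

    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        x, y = self.data[idx]
        return x, y, x + y


# Illustrative input (an assumption, not from the exercise).
pairs = [(1, 2), (3, 4), (5, 6), (7, 8)]

# shuffle=False keeps the samples in their original order.
loader = DataLoader(SumDataset(pairs), batch_size=2, shuffle=False)

batches = []
for xs, ys, sums in loader:
    # Each field arrives as a 1-D tensor with one entry per sample in the batch.
    batches.append((xs, ys, sums))
    print(xs.tolist(), ys.tolist(), sums.tolist())
```

With four samples and a batch size of 2, the loop runs twice, and in each batch the third tensor equals the element-wise sum of the first two.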