diff --git a/datahandler.py b/datahandler.py index 6e06701b74a5b4e02de272bb436372650c4ab75a..525e0bf694d18be39ddd538214a0bc0fbc631585 100644 --- a/datahandler.py +++ b/datahandler.py @@ -229,3 +229,21 @@ class HAM10000(Dataset): return X, y +# Define a pytorch dataloader for this dataset +class HAM10000Seg(Dataset): + def __init__(self, df, x_transform=None, y_transform=None): + self.df = df + self.x_transform = transform + + def __len__(self): + return len(self.df) + + def __getitem__(self, index): + # Load data and get label + X = Image.open(self.df['path'][index]) + y = Image.open(f"{data_dir}/HAM10000_segmentations_lesion_tschandl/{self.df['image_id'][index]}.png") + if self.x_transform: + X = self.x_transform(X) + if self.y_transform: + y = self.y_transform(y) + return X, y