# Source code for dataset

import os
from PIL import Image
from torch.utils.data import Dataset
import numpy as np

class SegmentationDataset(Dataset):
    """Dataset for segmentation tasks.

    The images and masks should be placed in two different directories and
    have the same file names: the mask for ``images_dir/<name>`` is read from
    ``masks_dir/<name>``.

    Attributes
    ----------
    image_dir : str
        Path to the directory containing the images
    mask_dir : str
        Path to the directory containing the masks
    transform : Callable or None
        Transformation function applied to the images and masks in order to
        perform data augmentation. It is called as
        ``transform(image=..., mask=...)`` and must return a mapping with
        ``"image"`` and ``"mask"`` keys.
    mask_threshold : int
        Grayscale value (0-255) a mask pixel must strictly exceed to be
        treated as foreground (1.0); all other pixels become 0.0.
    images : List[str]
        Sorted list of the names of the regular files found in ``image_dir``

    Methods
    -------
    __len__(self)
        Return the length of the dataset
    __getitem__(self, idx)
        Return the image and mask at the given index
    """

    def __init__(self, images_dir, masks_dir, transform=None, *, mask_threshold=128):
        """Initialize the dataset.

        Parameters
        ----------
        images_dir : str
            Path to the directory containing the images
        masks_dir : str
            Path to the directory containing the masks
        transform : Callable, optional
            Transformation function applied to the images and masks in order
            to perform data augmentation. ``None`` disables augmentation.
        mask_threshold : int, optional (keyword-only)
            Grayscale value a mask pixel must strictly exceed to count as
            foreground. Defaults to 128, matching the previous behavior.
        """
        self.image_dir = images_dir
        self.mask_dir = masks_dir
        self.transform = transform
        self.mask_threshold = mask_threshold
        # os.listdir returns entries in arbitrary, filesystem-dependent order;
        # sort so a given index always maps to the same sample everywhere.
        # Sub-directories are skipped because Image.open cannot read them.
        self.images = sorted(
            name
            for name in os.listdir(images_dir)
            if os.path.isfile(os.path.join(images_dir, name))
        )

    def __len__(self):
        """Return the number of samples (image files) in the dataset."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return the image and mask at the given index.

        Parameters
        ----------
        idx : int
            Index of the image and mask to return

        Returns
        -------
        image : np.ndarray
            RGB image as a uint8 array, unless ``transform`` converts it
            to another type.
        mask : np.ndarray
            Binary float32 mask with values 0.0 / 1.0, unless ``transform``
            converts it to another type.
        """
        name = self.images[idx]
        image_path = os.path.join(self.image_dir, name)
        mask_path = os.path.join(self.mask_dir, name)

        # Context managers close the underlying files deterministically
        # instead of leaving them to the garbage collector.
        with Image.open(image_path) as img:
            image = np.array(img.convert("RGB"))
        with Image.open(mask_path) as msk:
            # "L" gives a single-channel 8-bit grayscale image (0-255).
            gray = np.array(msk.convert("L"))
        mask = self._binarize_mask(gray, self.mask_threshold)

        if self.transform is not None:
            augmentations = self.transform(image=image, mask=mask)
            image = augmentations["image"]
            mask = augmentations["mask"]
        return image, mask

    @staticmethod
    def _binarize_mask(mask, threshold):
        """Map pixels strictly greater than ``threshold`` to 1.0, others to 0.0 (float32)."""
        return (mask > threshold).astype(np.float32)