
Datasets API

This page documents the dataset components of the Segmentation Robustness Framework.

Dataset Classes

segmentation_robustness_framework.datasets.voc

Classes

VOCSegmentation(split: str, root: Optional[Union[Path, str]] = None, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, download: bool = True)

Bases: Dataset

Pascal VOC 2012 dataset for semantic segmentation.

The Pascal VOC 2012 dataset contains 21 classes (20 object classes plus background) in natural scenes. Images are paired with pixel-level segmentation masks for training and evaluation.

Setup Instructions:

The dataset is downloaded and extracted automatically if it is not present.

When download=True (default):
  • If root is provided, the dataset is stored at root/voc/VOCdevkit/VOC2012/.
  • If root is None, the dataset is cached in the default cache directory.

When download=False:
  • root must point directly at an existing VOC2012 directory (the one containing JPEGImages/, SegmentationClass/ and ImageSets/).
  • If root is None, the dataset is looked up directly in the default cache directory.

Supported Splits:
  • train: training images (1,464 samples)
  • val: validation images (1,449 samples)
  • trainval: combined train and validation (2,913 samples)
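The root-resolution rules above can be sketched as a small pure function that mirrors the branches in the __init__ source below. resolve_dataset_dir and its cache_dir argument are hypothetical names used only for illustration (the real code calls get_cache_dir); this is not part of the framework's API.

```python
from pathlib import Path
from typing import Optional, Tuple, Union


def resolve_dataset_dir(
    root: Optional[Union[str, Path]],
    download: bool,
    cache_dir: Path,
    subdir: str = "voc",
    inner: Tuple[str, ...] = ("VOCdevkit", "VOC2012"),
) -> Path:
    """Hypothetical helper mirroring how VOCSegmentation chooses its dataset directory.

    `cache_dir` stands in for get_cache_dir(subdir) in the real code.
    """
    if download:
        # Download target: <root>/<subdir>, or the cache directory when root is None
        base = Path(root) / subdir if root is not None else cache_dir
        # The archive is expected to extract into <base>/<inner...>
        return base.joinpath(*inner)
    # download=False: root (or the cache directory) is taken as the dataset directory itself
    return Path(root) if root is not None else cache_dir
```

Passing subdir="ade20k" and inner=("ADEChallengeData2016",) reproduces the ADE20K layout. Note the last branch: with download=False and root=None the cache directory itself is used, not the VOCdevkit/VOC2012 folder that a default download creates inside it.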

Attributes:

  • root (str | Path | None): Directory for dataset storage or cache location.
  • split (str): Dataset split ('train', 'val', 'trainval').
  • transform (callable): Image transformations.
  • target_transform (callable): Target transformations.
  • download (bool): Whether to download the dataset if not present.
  • num_classes (int): Number of semantic classes (21).

Initialize Pascal VOC 2012 dataset.

Parameters:

  • split (str): Dataset split. Must be one of 'train', 'val', or 'trainval'. Required.
  • root (str | Path | None): Directory for dataset storage. If None, uses the default cache directory. Defaults to None.
  • transform (callable): Transform to apply to images. Defaults to None.
  • target_transform (callable): Transform to apply to masks. Defaults to None.
  • download (bool): Whether to download the dataset if not present. Defaults to True.

Raises:

  • FileNotFoundError: If the dataset is not found (with download=False) or the download fails.
  • ValueError: If split is not valid.

Source code in segmentation_robustness_framework/datasets/voc.py
def __init__(
    self,
    split: str,
    root: Optional[Union[Path, str]] = None,
    transform: Optional[Callable] = None,
    target_transform: Optional[Callable] = None,
    download: bool = True,
) -> None:
    """Initialize Pascal VOC 2012 dataset.

    Args:
        split (str): Dataset split. Must be one of 'train', 'val', or 'trainval'.
        root (str | Path | None, optional): Directory for dataset storage.
            If `None`, uses default cache directory. Defaults to None.
        transform (callable, optional): Transform to apply to images.
            Defaults to None.
        target_transform (callable, optional): Transform to apply to masks.
            Defaults to None.
        download (bool, optional): Whether to download dataset if not present.
            Defaults to True.

    Raises:
        FileNotFoundError: If dataset is not found and download fails.
        ValueError: If split is not valid.
    """
    from segmentation_robustness_framework.utils.dataset_utils import get_cache_dir

    if download:
        root_path = Path(root) / "voc" if root is not None else get_cache_dir("voc")
        dataset_root = root_path / "VOCdevkit" / "VOC2012"
    else:
        dataset_root = Path(root) if root is not None else get_cache_dir("voc")

    if not dataset_root.exists():
        if download:
            downloaded_file = download_dataset(self.URL, root_path, self.MD5)
            extract_dataset(downloaded_file, root_path)
        if not dataset_root.exists():
            raise FileNotFoundError(
                f"Could not find dataset at '{dataset_root}'. If you set `download=False`, "
                "make sure the dataset is present. Otherwise ensure write permissions and try again."
            )

    if split not in self.VALID_SPLITS:
        raise ValueError(f"Invalid split '{split}'. Expected one of {self.VALID_SPLITS}.")

    self.images_dir = dataset_root / "JPEGImages"
    self.masks_dir = dataset_root / "SegmentationClass"
    self.split = split
    self.transform = transform
    self.target_transform = target_transform

    with open(dataset_root / "ImageSets" / "Segmentation" / f"{split}.txt") as f:
        self.images = f.read().splitlines()

    self.num_classes = 21


segmentation_robustness_framework.datasets.ade20k

Classes

ADE20K(split: str, root: Optional[Union[Path, str]] = None, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, download: bool = True)

Bases: Dataset

ADE20K dataset for semantic segmentation.

The ADE20K dataset (ADEChallengeData2016 release) covers 150 semantic categories, with 20,210 training and 2,000 validation images. Images are paired with pixel-level segmentation masks for training and evaluation.

Setup Instructions:

The dataset is downloaded and extracted automatically if it is not present.

When download=True (default):
  • If root is provided, the dataset is stored at root/ade20k/ADEChallengeData2016/.
  • If root is None, the dataset is cached in the default cache directory.

When download=False:
  • root must point directly at an existing ADEChallengeData2016 directory (the one containing images/ and annotations/).
  • If root is None, the dataset is looked up directly in the default cache directory.

Supported Splits:
  • train: training images (20,210 samples)
  • val: validation images (2,000 samples)
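A sketch of how a split maps to the ADEChallengeData2016 folders, following the images_dir/masks_dir logic in __init__ below. Pairing each .jpg image with the .png annotation of the same stem follows the standard ADEChallengeData2016 naming; the framework's __getitem__ is not shown on this page, so that pairing and the ade20k_paths name are assumptions.

```python
from pathlib import Path
from typing import Tuple


def ade20k_paths(dataset_path: Path, split: str, file_name: str) -> Tuple[Path, Path]:
    """Hypothetical helper: return (image_path, mask_path) for one ADE20K sample."""
    if split not in ("train", "val"):
        raise ValueError(f"Invalid split '{split}'. Expected one of ('train', 'val').")
    # 'train' -> training/, 'val' -> validation/ (same rule as images_dir / masks_dir)
    sub = "training" if split == "train" else "validation"
    image_path = dataset_path / "images" / sub / file_name
    # Annotation shares the image stem but is a PNG (assumed standard naming)
    mask_path = dataset_path / "annotations" / sub / (Path(file_name).stem + ".png")
    return image_path, mask_path
```

Since __init__ builds self.images with os.listdir, the sample order is whatever the file system returns; sorting the list is one way to get a deterministic index order.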

Attributes:

  • root (str | Path | None): Directory for dataset storage or cache location.
  • split (str): Dataset split ('train', 'val').
  • transform (callable): Image transformations.
  • target_transform (callable): Target transformations.
  • download (bool): Whether to download the dataset if not present.
  • num_classes (int): Number of semantic classes (150).

Initialize ADE20K dataset.

Parameters:

  • split (str): Dataset split. Must be one of 'train' or 'val'. Required.
  • root (str | Path | None): Directory for dataset storage. If None, uses the default cache directory. Defaults to None.
  • transform (callable): Transform to apply to images. Defaults to None.
  • target_transform (callable): Transform to apply to masks. Defaults to None.
  • download (bool): Whether to download the dataset if not present. Defaults to True.

Raises:

  • FileNotFoundError: If the dataset is not found (with download=False) or the download fails.
  • ValueError: If split is not valid.

Source code in segmentation_robustness_framework/datasets/ade20k.py
def __init__(
    self,
    split: str,
    root: Optional[Union[Path, str]] = None,
    transform: Optional[Callable] = None,
    target_transform: Optional[Callable] = None,
    download: bool = True,
) -> None:
    """Initialize ADE20K dataset.

    Args:
        split (str): Dataset split. Must be one of 'train' or 'val'.
        root (str | Path | None, optional): Directory for dataset storage.
            If `None`, uses default cache directory. Defaults to None.
        transform (callable, optional): Transform to apply to images.
            Defaults to None.
        target_transform (callable, optional): Transform to apply to masks.
            Defaults to None.
        download (bool, optional): Whether to download dataset if not present.
            Defaults to True.

    Raises:
        FileNotFoundError: If dataset is not found and download fails.
        ValueError: If split is not valid.
    """
    from segmentation_robustness_framework.utils.dataset_utils import get_cache_dir

    if download:
        root_path = Path(root) / "ade20k" if root is not None else get_cache_dir("ade20k")
        dataset_path = root_path / "ADEChallengeData2016"
    else:
        dataset_path = Path(root) if root is not None else get_cache_dir("ade20k")

    if not dataset_path.exists():
        if download:
            downloaded_file = download_dataset(self.URL, root_path, self.MD5)
            extract_dataset(downloaded_file, root_path)
        if not dataset_path.exists():
            raise FileNotFoundError(
                f"Could not find dataset at '{dataset_path}'. If you set `download=False`, "
                "make sure the dataset is present. Otherwise ensure write permissions and try again."
            )

    if split not in self.VALID_SPLITS:
        raise ValueError(f"Invalid split '{split}'. Expected one of {self.VALID_SPLITS}.")

    self.split = split
    self.images_dir = (
        dataset_path / "images/training" if self.split == "train" else dataset_path / "images/validation"
    )
    self.masks_dir = (
        dataset_path / "annotations/training" if self.split == "train" else dataset_path / "annotations/validation"
    )
    self.transform = transform
    self.target_transform = target_transform
    self.images = os.listdir(self.images_dir)

    self.num_classes = 150


segmentation_robustness_framework.datasets.cityscapes

Classes

Cityscapes(root: Union[Path, str], split: str = 'train', mode: str = 'fine', target_type: str = 'semantic', transform: Optional[Callable] = None, target_transform: Optional[Callable] = None)

Bases: Dataset

Cityscapes dataset for semantic segmentation.

Cityscapes is a large-scale dataset for semantic understanding of urban street scenes. It contains high-quality pixel-level annotations of 5000 images in 50 cities.

Setup Instructions:

  1. Register at https://www.cityscapes-dataset.com/
  2. Download the dataset files:
      • leftImg8bit_trainvaltest.zip (11GB) - training/validation/test images
      • gtFine_trainvaltest.zip (241MB) - fine annotations for train/val (test labels are withheld)
      • gtCoarse.zip (1.3GB) - coarse annotations for train/val/train_extra
      • leftImg8bit_trainextra.zip (44GB) - extra training images (optional)
  3. Extract all archives to the same root directory.
  4. Ensure the directory structure matches:
    root/
    ├── leftImg8bit/
    │   ├── train/
    │   ├── val/
    │   └── test/
    ├── gtFine/
    │   ├── train/
    │   └── val/
    ├── gtCoarse/
    │   ├── train/
    │   ├── val/
    │   └── train_extra/
    └── leftImg8bit_trainextra/
        └── train_extra/

Supported Splits:
  • train: training images with fine annotations
  • val: validation images with fine annotations
  • test: test images (no annotations available)
  • train_extra: extra training images with coarse annotations

Supported Modes:
  • fine: high-quality pixel-level annotations
  • coarse: coarse polygon annotations

Supported Target Types:
  • semantic: semantic segmentation masks
  • instance: instance segmentation masks
  • color: color-coded visualization masks
  • polygon: polygon annotations (JSON format)
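Target paths are derived from image file names by swapping the _leftImg8bit.png suffix, as in the __init__ loop below. The suffix table here follows the standard Cityscapes file naming; the framework's _get_target_postfix is not shown on this page, so the mapping and the helper name cityscapes_target_name are assumptions.

```python
# Assumed suffixes, following the standard Cityscapes file naming
TARGET_POSTFIX = {
    "semantic": "labelIds.png",
    "instance": "instanceIds.png",
    "color": "color.png",
    "polygon": "polygons.json",
}
# Same mapping as `self.mode = "gtFine" if mode == "fine" else "gtCoarse"`
MODE_DIR = {"fine": "gtFine", "coarse": "gtCoarse"}


def cityscapes_target_name(image_name: str, mode: str, target_type: str) -> str:
    """Hypothetical helper: map a leftImg8bit file name to its annotation file name."""
    return image_name.replace(
        "_leftImg8bit.png", f"_{MODE_DIR[mode]}_{TARGET_POSTFIX[target_type]}"
    )
```

For example, aachen_000000_000019_leftImg8bit.png with mode="fine" and target_type="semantic" maps to aachen_000000_000019_gtFine_labelIds.png, looked up in the gtFine/<split>/<city>/ folder.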

Attributes:

  • root (str | Path): Path to the Cityscapes dataset root directory.
  • split (str): Dataset split ('train', 'val', 'test', 'train_extra').
  • mode (str): Annotation directory derived from the mode argument ('gtFine' for 'fine', 'gtCoarse' for 'coarse').
  • target_type (list): Type(s) of target annotations; a single string argument is wrapped in a list.
  • transform (callable): Image transformations.
  • target_transform (callable): Target transformations.
  • num_classes (int): Number of semantic classes (35).

Initialize Cityscapes dataset.

Parameters:

  • root (str | Path): Path to the Cityscapes dataset root directory. Must contain the extracted dataset files with the proper directory structure. Required.
  • split (str): Dataset split. Must be one of 'train', 'val', 'test', or 'train_extra'. Defaults to "train".
  • mode (str): Annotation mode. Must be 'fine' or 'coarse'. Defaults to "fine".
  • target_type (str | list): Type of target annotations. Can be a single type or a list of types. Must be one or more of 'semantic', 'instance', 'color', 'polygon'. Defaults to "semantic".
  • transform (callable): Transform to apply to images. Defaults to None.
  • target_transform (callable): Transform to apply to targets. Defaults to None.

Raises:

  • ValueError: If the root directory does not exist.
  • ValueError: If split is not valid.
  • ValueError: If mode is not valid.
  • ValueError: If target_type is not valid.
  • ValueError: If the test split is used with coarse mode.
  • ValueError: If the train_extra split is used with fine mode.
  • ValueError: If required dataset files are missing.

Source code in segmentation_robustness_framework/datasets/cityscapes.py
def __init__(
    self,
    root: Union[Path, str],
    split: str = "train",
    mode: str = "fine",
    target_type: str = "semantic",
    transform: Optional[Callable] = None,
    target_transform: Optional[Callable] = None,
) -> None:
    """Initialize Cityscapes dataset.

    Args:
        root (str | Path): Path to the Cityscapes dataset root directory.
            Must contain the extracted dataset files with proper directory structure.
        split (str, optional): Dataset split. Must be one of 'train', 'val', 'test',
            or 'train_extra'. Defaults to "train".
        mode (str, optional): Annotation mode. Must be 'fine' or 'coarse'.
            Defaults to "fine".
        target_type (str | list, optional): Type of target annotations. Can be a
            single type or list of types. Must be one or more of 'semantic',
            'instance', 'color', 'polygon'. Defaults to "semantic".
        transform (callable, optional): Transform to apply to images.
            Defaults to None.
        target_transform (callable, optional): Transform to apply to targets.
            Defaults to None.

    Raises:
        ValueError: If root directory does not exist.
        ValueError: If split is not valid.
        ValueError: If mode is not valid.
        ValueError: If target_type is not valid.
        ValueError: If test split is used with coarse mode.
        ValueError: If train_extra split is used with fine mode.
        ValueError: If required dataset files are missing.
    """
    if not os.path.exists(root):
        raise ValueError(f"Root directory '{root}' does not exist.")

    if split not in self.VALID_SPLITS:
        raise ValueError(f"Invalid split '{split}'. Expected one of {self.VALID_SPLITS}.")

    if mode not in self.VALID_MODES:
        raise ValueError(f"Invalid mode '{mode}'. Expected one of {self.VALID_MODES}.")

    if isinstance(target_type, str):
        target_type = [target_type]
    for t in target_type:
        if t not in self.VALID_TARGET_TYPES:
            raise ValueError(f"Invalid target_type '{t}'. Expected one of {self.VALID_TARGET_TYPES}.")

    if split == "test" and mode == "coarse":
        raise ValueError("The 'test' split is not available for 'coarse' mode. Use 'fine' mode instead.")

    if split == "train_extra" and mode == "fine":
        raise ValueError("The 'train_extra' split is not available for 'fine' mode. Use 'coarse' mode instead.")

    self.root = root
    self.split = split
    self.mode = "gtFine" if mode == "fine" else "gtCoarse"
    self.target_type = target_type
    self.transform = transform
    self.target_transform = target_transform

    if split == "train_extra":
        self.images_dir = os.path.join(root, "leftImg8bit_trainextra", split)
        if not os.path.exists(self.images_dir):
            raise ValueError(
                f"Directory '{self.images_dir}' does not exist. Download 'leftImg8bit_trainextra.zip' and extract to the {root} directory or use another split ('train or 'val)"
            )
    else:
        self.images_dir = os.path.join(root, "leftImg8bit", split)
    self.targets_dir = os.path.join(root, self.mode, split)

    self.images = []
    self.targets = []

    for city in os.listdir(self.images_dir):
        city_image_dir = os.path.join(self.images_dir, city)
        city_target_dir = os.path.join(self.targets_dir, city)

        for file_name in os.listdir(city_image_dir):
            image_path = os.path.join(city_image_dir, file_name)
            self.images.append(image_path)

            target_paths = []
            for t in self.target_type:
                target_file_name = file_name.replace(
                    "_leftImg8bit.png", f"_{self.mode}_{self._get_target_postfix(t)}"
                )

                target_path = os.path.join(city_target_dir, target_file_name)
                target_paths.append(target_path)

            self.targets.append(target_paths)

    self.num_classes = 35


segmentation_robustness_framework.datasets.stanford_background

Classes

StanfordBackground(root: Optional[Union[Path, str]] = None, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, download: bool = True)

Bases: Dataset

Stanford Background dataset for semantic segmentation.

The Stanford Background dataset contains 715 images with 9 semantic categories. Images are paired with pixel-level segmentation masks for training and evaluation.

Setup Instructions:

The dataset is downloaded and extracted automatically if it is not present.

When download=True (default):
  • If root is provided, the dataset is stored at root/stanford_background/stanford_background/.
  • If root is None, the dataset is cached in the default cache directory.

When download=False:
  • root must point directly at an existing dataset directory (the one containing images/ and labels_colored/).
  • If root is None, the dataset is looked up directly in the default cache directory.

Dataset Structure:
  • images/: input RGB images
  • labels_colored/: segmentation masks (color images)

Attributes:

  • root (str | Path | None): Directory for dataset storage or cache location.
  • transform (callable): Image transformations.
  • target_transform (callable): Target transformations.
  • download (bool): Whether to download the dataset if not present.
  • num_classes (int): Number of semantic classes (9).

Initialize Stanford Background dataset.

Parameters:

  • root (str | Path | None): Directory for dataset storage. If None, uses the default cache directory. Defaults to None.
  • transform (callable): Transform to apply to images. Defaults to None.
  • target_transform (callable): Transform to apply to masks. Defaults to None.
  • download (bool): Whether to download the dataset if not present. Defaults to True.

Raises:

  • FileNotFoundError: If the dataset is not found (with download=False) or the download fails.

Source code in segmentation_robustness_framework/datasets/stanford_background.py
def __init__(
    self,
    root: Optional[Union[Path, str]] = None,
    transform: Optional[Callable] = None,
    target_transform: Optional[Callable] = None,
    download: bool = True,
) -> None:
    """Initialize Stanford Background dataset.

    Args:
        root (str | Path | None, optional): Directory for dataset storage.
            If `None`, uses default cache directory. Defaults to None.
        transform (callable, optional): Transform to apply to images.
            Defaults to None.
        target_transform (callable, optional): Transform to apply to masks.
            Defaults to None.
        download (bool, optional): Whether to download dataset if not present.
            Defaults to True.

    Raises:
        FileNotFoundError: If dataset is not found and download fails.
    """
    from segmentation_robustness_framework.utils.dataset_utils import get_cache_dir

    if download:
        root_path = Path(root) / "stanford_background" if root is not None else get_cache_dir("stanford_background")
        dataset_path = root_path / "stanford_background"
    else:
        dataset_path = Path(root) if root is not None else get_cache_dir("stanford_background")

    if not dataset_path.exists():
        if download:
            downloaded_file = download_dataset(self.URL, root_path, self.MD5)
            extract_dataset(downloaded_file, dataset_path)
        if not dataset_path.exists():
            raise FileNotFoundError(
                f"Could not find dataset at '{dataset_path}'. If you set `download=False`, "
                "make sure the dataset is present. Otherwise ensure write permissions and try again."
            )

    self.images_dir = dataset_path / "images"
    self.masks_dir = dataset_path / "labels_colored"
    self.transform = transform
    self.target_transform = target_transform
    self.images = os.listdir(self.images_dir)

    self.num_classes = 9


Dataset Loaders

segmentation_robustness_framework.loaders.dataset_loader

Classes

DatasetLoader(dataset_config: dict[str, Any])

Load and configure datasets for image segmentation tasks.

The DatasetLoader initializes and loads a dataset by name using the provided configuration, and applies preprocessing to input images and segmentation masks.

Attributes:
  • config (dict[str, Any]): Configuration specifying the dataset and its parameters.
  • dataset_name (str): Name of the dataset to be loaded (e.g., VOC, ADE20K).
  • root (str): Root directory where the dataset is located.
  • images_shape (list[int]): Desired image shape for preprocessing [height, width], read from the image_shape config key.
Example:
loader = DatasetLoader({
    "name": "voc",
    "root": "/path/to/voc",
    "image_shape": [256, 256],
    "split": "train",
})
dataset = loader.load_dataset()

Initialize the DatasetLoader with a dataset configuration.

Parameters:

  • dataset_config (dict[str, Any]): Dataset configuration. Required. Keys:
      • name (str): Dataset name.
      • root (str): Root directory of the dataset.
      • image_shape (list[int]): Desired image shape.
      • Additional dataset-specific parameters (e.g. split), forwarded only if the dataset class accepts them.
Source code in segmentation_robustness_framework/loaders/dataset_loader.py
def __init__(self, dataset_config: dict[str, Any]) -> None:
    """Initialize the DatasetLoader with a dataset configuration.

    Args:
        dataset_config (dict[str, Any]):
            - `name` (str): Dataset name.
            - `root` (str): Root directory of the dataset.
            - `image_shape` (list[int]): Desired image shape.
            - Additional dataset-specific parameters.
    """

    self.config = dataset_config
    self.dataset_name = self.config["name"]
    self.root = self.config["root"]
    self.images_shape = self.config["image_shape"]

Functions

load_dataset() -> Dataset

Loads and preprocesses the specified dataset.

Based on the dataset name in the configuration, the corresponding dataset class is initialized with appropriate preprocessing transformations applied.

Returns:

  • Dataset: An instance of the dataset class ready for training or evaluation.

Raises:

  • ValueError: If the specified dataset name is not recognized.

Source code in segmentation_robustness_framework/loaders/dataset_loader.py
def load_dataset(self) -> Dataset:
    """Loads and preprocesses the specified dataset.

    Based on the dataset name in the configuration, the corresponding dataset class
    is initialized with appropriate preprocessing transformations applied.

    Returns:
        Dataset: An instance of the dataset class ready for training or evaluation.

    Raises:
        ValueError: If the specified dataset name is not recognized.
    """
    preprocess, target_preprocess = image_preprocessing.get_preprocessing_fn(self.images_shape, self.dataset_name)

    try:
        ds_cls = DATASET_REGISTRY[self.dataset_name]
    except KeyError:
        raise ValueError(f"Unknown dataset: {self.dataset_name}. Available: {list(DATASET_REGISTRY.keys())}")

    sig = inspect.signature(ds_cls)
    common_kwargs = dict(root=self.root, transform=preprocess, target_transform=target_preprocess)

    extra = {k: v for k, v in self.config.items() if k in sig.parameters}
    common_kwargs.update(extra)

    return ds_cls(**common_kwargs)
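The core of load_dataset is a registry lookup followed by signature-based filtering: only config keys that the dataset constructor accepts are forwarded, so keys such as name and image_shape are dropped for classes that do not take them. The sketch below reproduces that logic with a toy registry; ToyDataset, TOY_REGISTRY and build_dataset are illustrative names, not framework API.

```python
import inspect
from typing import Any, Callable, Dict, Optional


class ToyDataset:
    def __init__(self, split: str, root: Optional[str] = None,
                 transform: Optional[Callable] = None,
                 target_transform: Optional[Callable] = None) -> None:
        self.split, self.root = split, root
        self.transform, self.target_transform = transform, target_transform


TOY_REGISTRY = {"toy": ToyDataset}


def build_dataset(config: Dict[str, Any], transform=None, target_transform=None):
    """Illustrative re-implementation of the lookup-and-filter logic in load_dataset."""
    try:
        ds_cls = TOY_REGISTRY[config["name"]]
    except KeyError:
        raise ValueError(f"Unknown dataset: {config['name']}. Available: {list(TOY_REGISTRY)}") from None
    sig = inspect.signature(ds_cls)  # parameters of __init__, without self
    kwargs = dict(root=config["root"], transform=transform, target_transform=target_transform)
    # Forward only the config keys the constructor accepts
    kwargs.update({k: v for k, v in config.items() if k in sig.parameters})
    return ds_cls(**kwargs)
```

One consequence of this design: a required constructor argument that is missing from the config (for example split for VOCSegmentation) is not caught by the filter and surfaces as a TypeError when the class is instantiated.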


Dataset Overview

The framework provides support for popular semantic segmentation datasets with automatic preprocessing and data loading.

Available Datasets

VOC (PASCAL VOC 2012)

The PASCAL Visual Object Classes dataset:

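from segmentation_robustness_framework.datasets import VOCSegmentation
from segmentation_robustness_framework.utils.image_preprocessing import get_preprocessing_fn

# Image and mask transforms for 512x512 inputs
transform, target_transform = get_preprocessing_fn(image_shape=[512, 512], dataset_name="voc")

# Load VOC dataset (downloaded to ./data/voc/VOCdevkit/VOC2012 if missing)
dataset = VOCSegmentation(
    root="./data",
    split="val",
    transform=transform,
    target_transform=target_transform
)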

print(f"Dataset size: {len(dataset)}")
print(f"Number of classes: {dataset.num_classes}")

Features:
  • 20 object classes + background
  • High-quality pixel-level annotations
  • Standard benchmark dataset
  • Automatic download support

ADE20K (MIT Scene Parsing)

The MIT ADE20K dataset for scene parsing:

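from segmentation_robustness_framework.datasets import ADE20K
from segmentation_robustness_framework.utils.image_preprocessing import get_preprocessing_fn

# Image and mask transforms (the "ade20k" dataset_name string is assumed)
transform, target_transform = get_preprocessing_fn(image_shape=[512, 512], dataset_name="ade20k")

# Load ADE20K dataset (downloaded to ./data/ade20k/ADEChallengeData2016 if missing)
dataset = ADE20K(
    root="./data",
    split="val",
    transform=transform,
    target_transform=target_transform
)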

print(f"Dataset size: {len(dataset)}")
print(f"Number of classes: {dataset.num_classes}")

Features:
  • 150 semantic categories
  • Complex scene understanding
  • High-resolution images
  • Detailed annotations

Cityscapes

Urban scene understanding dataset:

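from segmentation_robustness_framework.datasets import Cityscapes
from segmentation_robustness_framework.utils.image_preprocessing import get_preprocessing_fn

# Image and mask transforms (the "cityscapes" dataset_name string is assumed)
transform, target_transform = get_preprocessing_fn(image_shape=[512, 512], dataset_name="cityscapes")

# Load Cityscapes dataset (manual download; ./data must contain leftImg8bit/ and gtFine/)
dataset = Cityscapes(
    root="./data",
    split="val",
    mode="fine",
    target_type="semantic",
    transform=transform,
    target_transform=target_transform
)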

print(f"Dataset size: {len(dataset)}")
print(f"Number of classes: {dataset.num_classes}")

Features:
  • 35 classes in the full label set (num_classes = 35); the standard Cityscapes benchmark evaluates 19 of them
  • High-resolution urban images
  • Fine and coarse annotations
  • Multiple annotation types

Stanford Background

Natural scene parsing dataset:

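from segmentation_robustness_framework.datasets import StanfordBackground
from segmentation_robustness_framework.utils.image_preprocessing import get_preprocessing_fn

# Image and mask transforms (the "stanford_background" dataset_name string is assumed)
transform, target_transform = get_preprocessing_fn(image_shape=[512, 512], dataset_name="stanford_background")

# Load Stanford Background dataset (no split argument: the whole dataset is loaded)
dataset = StanfordBackground(
    root="./data",
    transform=transform,
    target_transform=target_transform
)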

print(f"Dataset size: {len(dataset)}")
print(f"Number of classes: {dataset.num_classes}")

Features:
  • 8 semantic categories plus an unknown label (num_classes = 9)
  • Natural outdoor scenes
  • High-quality annotations
  • Compact dataset for testing

Dataset Configuration

Configure datasets in YAML configuration files. The framework automatically applies preprocessing based on the image_shape parameter:

dataset:
  name: voc
  split: val
  root: ./data
  image_shape: [512, 512]  # Automatically applies resize, normalize, and mask conversion

Dataset Loading

Use the DatasetLoader for automatic dataset loading with preprocessing:

from segmentation_robustness_framework.loaders import DatasetLoader

# Load dataset with configuration
dataset_config = {
    "name": "voc",
    "split": "val",
    "root": "./data",
    "image_shape": [512, 512]  # Automatically applies preprocessing
}

dataset_loader = DatasetLoader(dataset_config)
dataset = dataset_loader.load_dataset()
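DatasetLoader.__init__ indexes config["name"], config["root"] and config["image_shape"] directly, so a missing key surfaces as a bare KeyError. A small pre-check like the hypothetical check_dataset_config below (not part of the framework) gives a clearer message before the loader is constructed.

```python
from typing import Any, Dict

# Keys DatasetLoader.__init__ reads unconditionally
REQUIRED_KEYS = ("name", "root", "image_shape")


def check_dataset_config(config: Dict[str, Any]) -> None:
    """Hypothetical pre-check: raise ValueError with a readable message on bad configs."""
    missing = [k for k in REQUIRED_KEYS if k not in config]
    if missing:
        raise ValueError(f"Dataset config is missing required keys: {missing}")
    shape = config["image_shape"]
    # image_shape is documented as [height, width]
    if len(shape) != 2 or not all(isinstance(v, int) and v > 0 for v in shape):
        raise ValueError(f"image_shape must be [height, width] of positive ints, got {shape}")
```

Calling check_dataset_config(dataset_config) just before DatasetLoader(dataset_config) leaves the loader itself unchanged.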

Automatic Preprocessing

DatasetLoader builds these transforms with get_preprocessing_fn, which you can also call directly:

from segmentation_robustness_framework.utils.image_preprocessing import get_preprocessing_fn

# Get preprocessing functions (automatically called by DatasetLoader)
image_preprocess, target_preprocess = get_preprocessing_fn(
    image_shape=[512, 512],
    dataset_name="voc"
)

# The preprocessing includes:
# - Image resize to specified shape
# - Image normalization (ImageNet stats)
# - Mask resize to match image
# - RGB to index conversion for masks
# - Stride alignment (ensures dimensions are divisible by 8)

What Gets Applied Automatically

When you specify image_shape in the dataset configuration, the framework automatically applies:

Image Preprocessing

  • Resize: Images are resized to the specified [height, width]
  • Normalization: Images are normalized using ImageNet statistics (mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
  • Tensor Conversion: Images are converted to PyTorch tensors
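Resize aside, the normalization step amounts to scaling pixels to [0, 1], standardizing each channel with the ImageNet statistics listed above, and laying the array out channel-first as PyTorch expects. The NumPy sketch below illustrates the arithmetic only; normalize_image is an illustrative name, and the framework's own implementation is not shown on this page.

```python
import numpy as np

IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def normalize_image(image: np.ndarray) -> np.ndarray:
    """Illustrative: uint8 HxWx3 RGB array -> float32 3xHxW normalized array."""
    x = image.astype(np.float32) / 255.0        # [0, 255] -> [0, 1]
    x = (x - IMAGENET_MEAN) / IMAGENET_STD      # per-channel standardization, broadcast over H and W
    return x.transpose(2, 0, 1)                 # HWC -> CHW
```

A pure white pixel, for instance, becomes (1 - mean) / std per channel, roughly 2.25 in the red channel.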

Mask Preprocessing

  • Resize: Masks are resized to match the image dimensions
  • RGB to Index: RGB masks are converted to class indices using dataset-specific color palettes
  • Stride Alignment: Dimensions are adjusted to be divisible by 8 for model compatibility
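Masks hold class indices, so they need nearest-neighbor resizing: bilinear interpolation would blend neighboring labels into meaningless ids. The sketch below combines that with stride alignment. The page does not say whether dimensions are rounded up or down to a multiple of 8; rounding up is an assumption here, and both helper names are illustrative.

```python
import math
from typing import Tuple

import numpy as np


def align_to_stride(height: int, width: int, stride: int = 8) -> Tuple[int, int]:
    """Round each dimension up to the next multiple of `stride` (assumed policy)."""
    return math.ceil(height / stride) * stride, math.ceil(width / stride) * stride


def resize_mask_nearest(mask: np.ndarray, height: int, width: int) -> np.ndarray:
    """Nearest-neighbor resize of an HxW index mask; never creates new class ids."""
    in_h, in_w = mask.shape
    # Map each output row/column back to a source row/column
    rows = (np.arange(height) * in_h // height).clip(0, in_h - 1)
    cols = (np.arange(width) * in_w // width).clip(0, in_w - 1)
    return mask[rows[:, None], cols[None, :]]
```

Under this policy a 500x375 mask would be aligned to 504x376 before resizing.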

Dataset-Specific Features

  • Color Mapping: Each dataset has its own color palette for mask conversion
  • Ignore Index: Proper handling of ignored pixels (usually index 255)
  • Error Handling: Warnings for unmapped colors in masks
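RGB-to-index conversion can be sketched as a lookup from each palette color to its class id, with unmapped colors sent to the ignore index and reported in a warning. The two-entry palette (VOC's background and aeroplane colors) and the rgb_to_index name are illustrative assumptions; each dataset has its own palette, and the framework's exact handling may differ.

```python
import warnings
from typing import Dict, Tuple

import numpy as np

IGNORE_INDEX = 255


def rgb_to_index(mask_rgb: np.ndarray, palette: Dict[Tuple[int, int, int], int]) -> np.ndarray:
    """Illustrative: convert an HxWx3 color mask to an HxW class-index mask."""
    out = np.full(mask_rgb.shape[:2], IGNORE_INDEX, dtype=np.uint8)
    mapped = np.zeros(mask_rgb.shape[:2], dtype=bool)
    for color, class_id in palette.items():
        # Pixels whose three channels all equal this palette color
        match = np.all(mask_rgb == np.array(color, dtype=mask_rgb.dtype), axis=-1)
        out[match] = class_id
        mapped |= match
    unmapped = int((~mapped).sum())
    if unmapped:
        warnings.warn(f"{unmapped} pixels have colors not in the palette; set to {IGNORE_INDEX}")
    return out
```

For example, with the palette {(0, 0, 0): 0, (128, 0, 0): 1}, any pixel of another color becomes 255 and triggers one warning for the whole mask.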

Custom Datasets

Create custom datasets by inheriting from torch.utils.data.Dataset:

from PIL import Image
from torch.utils.data import Dataset


class MyCustomDataset(Dataset):
    def __init__(self, root, split="train", transform=None, target_transform=None):
        self.root = root
        self.split = split
        self.transform = transform
        self.target_transform = target_transform
        self.num_classes = 21  # number of semantic classes in your label set

        # Load your data: fill both lists with matching paths, in the same order
        self.images = []  # List of image paths
        self.masks = []   # List of mask paths

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # Load image and mask
        image = Image.open(self.images[idx]).convert("RGB")
        mask = Image.open(self.masks[idx])

        # Apply transforms
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            mask = self.target_transform(mask)

        return image, mask

# Use custom dataset
dataset = MyCustomDataset("./data", split="train")
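The placeholder lists in MyCustomDataset still need to be filled with matching image and mask paths. One common approach is to pair files by stem; the sketch below assumes a hypothetical layout of root/images/<split>/ and root/masks/<split>/ with each mask stored as <stem>.png. Both the layout and the find_pairs name are assumptions to adapt to your data.

```python
import os
from typing import List, Tuple


def find_pairs(root: str, split: str) -> Tuple[List[str], List[str]]:
    """Hypothetical: pair root/images/<split>/* with root/masks/<split>/<stem>.png."""
    image_dir = os.path.join(root, "images", split)
    mask_dir = os.path.join(root, "masks", split)
    images, masks = [], []
    for name in sorted(os.listdir(image_dir)):  # sorted for a deterministic index order
        stem, _ = os.path.splitext(name)
        mask_path = os.path.join(mask_dir, stem + ".png")
        if os.path.exists(mask_path):  # skip images without a mask
            images.append(os.path.join(image_dir, name))
            masks.append(mask_path)
    return images, masks
```

Inside __init__, the two placeholder assignments would then become self.images, self.masks = find_pairs(root, split).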

Dataset Registration

Register custom datasets for automatic discovery:

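from torch.utils.data import Dataset

from segmentation_robustness_framework.datasets import register_dataset

@register_dataset("my_custom")
class MyCustomDataset(Dataset):
    def __init__(self, root, split="train", transform=None, target_transform=None):
        # Your dataset implementation: collect image/mask paths here
        self.images = []
        self.masks = []
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # Your data loading logic: return (image, mask)
        raise NotImplementedError

# Now you can use it in configuration
# (DatasetLoader also requires the root and image_shape keys)
# dataset:
#   name: my_custom
#   split: train
#   root: ./data
#   image_shape: [512, 512]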
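How a decorator such as register_dataset can make a class discoverable by name is shown in the minimal sketch below. It illustrates the registry pattern under assumed names (REGISTRY, register) and is not the framework's implementation, which this page does not show; in particular, whether the real decorator rejects duplicate names is unknown.

```python
from typing import Callable, Dict

REGISTRY: Dict[str, type] = {}


def register(name: str) -> Callable[[type], type]:
    """Illustrative registry decorator: store the class under `name`, return it unchanged."""
    def decorator(cls: type) -> type:
        if name in REGISTRY:
            raise KeyError(f"Dataset '{name}' is already registered")
        REGISTRY[name] = cls
        return cls
    return decorator


@register("my_custom")
class MyCustomDataset:
    def __init__(self, root, split="train", transform=None, target_transform=None):
        self.root, self.split = root, split
```

Because the decorator returns the class unchanged, the class can still be imported and instantiated directly; the registry only adds a name-based lookup.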

Dataset Usage in Pipeline

Datasets are automatically used by the pipeline:

from segmentation_robustness_framework.pipeline import SegmentationRobustnessPipeline

# `model`, `FGSM` and `metrics` are assumed to be created/imported beforehand (not shown)
# Create pipeline with dataset
pipeline = SegmentationRobustnessPipeline(
    model=model,
    dataset=dataset,  # Your dataset here
    attacks=[FGSM(model, eps=0.1)],
    metrics=[metrics.mean_iou],
    batch_size=4,
    device="cuda"
)

results = pipeline.run()

Performance Considerations

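  • Memory Efficiency: Datasets store file names or paths at construction time and read images when an item is accessed
  • GPU Compatibility: The pipeline places data on the device passed via its device argument
  • Batch Processing: Inference runs in batches of batch_size samples
  • Data Augmentation: Augmentation can be supplied through the transform and target_transform callables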