
Model Loaders API

This page documents the model loading components of the Segmentation Robustness Framework.

Model Loaders

segmentation_robustness_framework.loaders.models.universal_loader

Classes

UniversalModelLoader()

Universal model loader that handles different model types and wraps them with adapters.

Supported model types
  • 'torchvision'
  • 'smp'
  • 'huggingface'
  • 'custom' (any type string starting with 'custom_' is also routed to the custom loader)
Source code in segmentation_robustness_framework/loaders/models/universal_loader.py
def __init__(self):
    self.loaders = {
        "torchvision": TorchvisionModelLoader() if TORCHVISION_INSTALLED else None,
        "smp": SMPModelLoader() if SMP_INSTALLED else None,
        "huggingface": HuggingFaceModelLoader() if HUGGINGFACE_INSTALLED else None,
        "custom": CustomModelLoader(),
    }
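The TORCHVISION_INSTALLED, SMP_INSTALLED, and HUGGINGFACE_INSTALLED flags are not defined on this page. The minimal sketch below assumes they are plain booleans computed at import time with importlib.util.find_spec. That is an assumption, not the framework's confirmed code. It shows how such a registry stores None in place of a loader whose backend is missing. is_installed and build_registry are hypothetical helpers.

```python
import importlib.util
from typing import Callable, Optional


def is_installed(module_name: str) -> bool:
    """Return True if a top-level module can be found, without importing it."""
    return importlib.util.find_spec(module_name) is not None


# Hypothetical flags; the strings are the usual import names of each backend.
TORCHVISION_INSTALLED = is_installed("torchvision")
SMP_INSTALLED = is_installed("segmentation_models_pytorch")
HUGGINGFACE_INSTALLED = is_installed("transformers")


def build_registry(
    factories: dict[str, tuple[bool, Callable[[], object]]],
) -> dict[str, Optional[object]]:
    """Instantiate each loader only when its dependency flag is set; store None otherwise."""
    return {name: factory() if available else None for name, (available, factory) in factories.items()}


# Placeholder factories stand in for the real loader classes.
registry = build_registry(
    {
        "torchvision": (TORCHVISION_INSTALLED, lambda: "torchvision loader"),
        "smp": (SMP_INSTALLED, lambda: "smp loader"),
        "huggingface": (HUGGINGFACE_INSTALLED, lambda: "huggingface loader"),
        "custom": (True, lambda: "custom loader"),  # no optional dependency
    }
)
print({name: loader is not None for name, loader in registry.items()})
```

Deferring the decision this way means constructing the universal loader never fails because an optional backend is absent. The ImportError surfaces only when that model type is actually requested.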

Functions

load_model(model_type: str, model_config: dict[str, Any], weights_path: Optional[str] = None, weight_type: str = 'full', adapter_cls: Optional[type] = None) -> nn.Module

Load model using appropriate loader and wrap with the correct adapter.

Parameters:

  • model_type (str, required): Model type identifier. Supported values are 'torchvision' (Torchvision segmentation models), 'smp' (segmentation-models-pytorch models), 'huggingface' (HuggingFace Transformers models), and 'custom' or any string starting with 'custom_' (custom user-defined models).
  • model_config (dict[str, Any], required): Configuration passed to the selected loader's load_model.
  • weights_path (Optional[str], default: None): Path to a weights file to load after the model is built.
  • weight_type (str, default: 'full'): Type of weights to load ('full' or 'encoder').
  • adapter_cls (Optional[type], default: None): Adapter class to wrap the model. If provided, it is used instead of the default adapter for the model type. No adapter is applied if the model already implements SegmentationModelProtocol.

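Returns:

  • nn.Module: The loaded model, wrapped in an adapter unless it already implements SegmentationModelProtocol. If the loader returns a bundle (an object with a model attribute, such as HFSegmentationBundle), only bundle.model is kept.

Raises:

  • ImportError: If the optional dependencies for the requested model type are not installed.
  • ValueError: If model_type is not supported.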

Source code in segmentation_robustness_framework/loaders/models/universal_loader.py
def load_model(
    self,
    model_type: str,
    model_config: dict[str, Any],
    weights_path: Optional[str] = None,
    weight_type: str = "full",
    adapter_cls: Optional[type] = None,
) -> nn.Module:
    """Load model using appropriate loader and wrap with the correct adapter.

    Args:
        model_type (str): Model type identifier. Supported values:
            - `'torchvision'`: Torchvision segmentation models.
            - `'smp'`: segmentation-models-pytorch models.
            - `'huggingface'`: HuggingFace Transformers models.
            - `'custom'`, or any string starting with `'custom_'`: Custom user-defined models.
        model_config (dict[str, Any]): Configuration for model loading.
        weights_path (Optional[str]): Path to weights file (optional).
        weight_type (str): Type of weights to load ('full' or 'encoder').
        adapter_cls (Optional[type]): Adapter class to wrap the model. If provided, this
            adapter will be used instead of the default adapter for the model type.

    Returns:
        nn.Module: Loaded and adapted model.
    """
    if model_type in self.loaders:
        loader = self.loaders[model_type]
        if loader is None:
            logger.error(f"Required dependencies for {model_type} not available")
            raise ImportError(f"Required dependencies for {model_type} not available")
    elif model_type.startswith("custom_"):
        loader = self.loaders["custom"]
    else:
        logger.error(f"Unsupported model type: {model_type}")
        raise ValueError(f"Unsupported model type: {model_type}")

    try:
        model = loader.load_model(model_config)  # may return bundle
    except Exception as e:
        logger.exception(f"Failed to load model for type {model_type}: {e}")
        raise

    if hasattr(model, "model"):
        bundle = model
        model = bundle.model

    if weights_path is not None:
        try:
            model = loader.load_weights(model, weights_path, weight_type)
            logger.info(f"Loaded weights for {model_type} model from {weights_path} (type: {weight_type})")
        except Exception as e:
            logger.exception(f"Failed to load weights for {model_type} model: {e}")
            raise

    # Wrap with adapter if not already adapted
    if not isinstance(model, SegmentationModelProtocol):
        AdapterCls = adapter_cls or get_adapter(model_type)
        model = AdapterCls(model)
        logger.info(f"Wrapped model with {AdapterCls.__name__} adapter for type '{model_type}'")
    else:
        logger.info("Model already implements SegmentationModelProtocol; no adapter wrapping needed.")

    return model
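The dispatch rules at the top of load_model can be isolated into a pure function, which makes them testable without installing any backend. This sketch reproduces only the loader selection. Loading, weight handling, and adapter wrapping are omitted. resolve_loader_key is a hypothetical helper, not part of the framework.

```python
from typing import Optional


def resolve_loader_key(model_type: str, loaders: dict[str, Optional[object]]) -> str:
    """Return the registry key that UniversalModelLoader.load_model would use (sketch)."""
    if model_type in loaders:
        if loaders[model_type] is None:
            # Registered backend whose optional dependency is missing
            raise ImportError(f"Required dependencies for {model_type} not available")
        return model_type
    if model_type.startswith("custom_"):
        # Any 'custom_*' alias is routed to the single custom loader
        return "custom"
    raise ValueError(f"Unsupported model type: {model_type}")


# Example registry in which segmentation-models-pytorch is not installed
loaders = {"torchvision": object(), "smp": None, "huggingface": object(), "custom": object()}
for model_type in ["torchvision", "custom_unet", "smp", "keras"]:
    try:
        print(model_type, "->", resolve_loader_key(model_type, loaders))
    except (ImportError, ValueError) as exc:
        print(model_type, "->", type(exc).__name__)
```

Note that get_adapter receives the original model_type string (for example 'custom_unet'), not the resolved key, when no adapter_cls is given.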


segmentation_robustness_framework.loaders.models.torchvision_loader

Classes

TorchvisionModelLoader

Bases: BaseModelLoader

Loader for torchvision segmentation models.

Supports loading models and weights, including encoder-only (backbone) weights. Pretrained weights are selected through the 'weights' argument introduced in torchvision 0.13, which replaces the deprecated 'pretrained' flag.

Supported model_config keys
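  • name (str): Model name, for example 'deeplabv3_resnet50'.
  • num_classes (int): Number of output classes (default: 21).
  • weights (str | TorchvisionWeightsEnum | None): A torchvision weights enum member, the name of one as a string, 'default' (case-insensitive), or None for random initialization (optional). If the key is omitted, the model's default pretrained weights are used.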
Example
loader = TorchvisionModelLoader()
model_config = {"name": "deeplabv3_resnet50", "num_classes": 21}
model = loader.load_model(model_config)
model = loader.load_weights(model, "weights.pth", weight_type="encoder")

Functions

load_model(model_config: dict[str, Any]) -> nn.Module

Load a torchvision segmentation model using the 'weights' argument.

Parameters:

  • model_config (dict, required): Model configuration with the keys:
    - name (str): Model name.
    - num_classes (int): Number of classes (optional, default: 21).
    - weights (str | TorchvisionWeightsEnum | None): Weights enum member, its name as a string, 'default', or None for random initialization (optional; the model's default pretrained weights are used if omitted).

Raises:

  • ValueError: If the model name is not supported, or if a weights string matches neither a weights enum member nor 'default'.

Returns:

  • nn.Module: Instantiated torchvision model.

Source code in segmentation_robustness_framework/loaders/models/torchvision_loader.py
def load_model(self, model_config: dict[str, Any]) -> nn.Module:
    """Load a torchvision segmentation model using the 'weights' argument.

    Args:
        model_config (dict):
            - `name` (str): Model name.
            - `num_classes` (int): Number of classes (optional).
            - `weights` (str | TorchvisionWeightsEnum): Torchvision weights enum, string, or None (optional).

    Raises:
        ValueError: If the model name is not supported.

    Returns:
        `nn.Module`: Instantiated torchvision model.
    """
    try:
        name = model_config.get("name")
        num_classes = model_config.get("num_classes", 21)
        weights = model_config.get("weights", "__not_provided__")  # Sentinel value

        if name not in self.SUPPORTED_MODELS:
            logger.error(f"Unsupported model: {name}")
            raise ValueError(f"Unsupported model: {name}")

        model_fn = self.SUPPORTED_MODELS[name]

        weights_enum_cls = self.TORCHVISION_WEIGHTS_ENUMS.get(name)

        if weights == "__not_provided__":
            if weights_enum_cls is not None and hasattr(weights_enum_cls, "DEFAULT"):
                weights = weights_enum_cls.DEFAULT
                logger.info(f"No weights specified, using default weights for {name}.")
        elif isinstance(weights, str):
            if weights_enum_cls is not None and hasattr(weights_enum_cls, weights):
                weights = getattr(weights_enum_cls, weights)
            elif (
                weights.lower() == "default"
                and weights_enum_cls is not None
                and hasattr(weights_enum_cls, "DEFAULT")
            ):
                weights = weights_enum_cls.DEFAULT
            else:
                logger.error(f"Invalid weights: {weights}")
                raise ValueError(f"Invalid weights: {weights}")

        if weights == "__not_provided__":
            weights = None  # no default weights enum is registered for this model

        # torchvision raises a ValueError when num_classes differs from the pretrained
        # weights' category count, so build with the weights' own head and replace it below.
        if weights is None:
            model = model_fn(weights=None, num_classes=num_classes)
            default_num_classes = num_classes
        else:
            model = model_fn(weights=weights)
            default_num_classes = len(weights.meta["categories"])
        logger.info(f"Loaded torchvision model: {name} with weights={weights}")

        if num_classes != default_num_classes:
            cls = getattr(model, "classifier", None)
            if isinstance(cls, nn.Sequential):
                cls[-1] = nn.Conv2d(cls[-1].in_channels, num_classes, kernel_size=1)
            elif hasattr(cls, "low_classifier") and hasattr(cls, "high_classifier"):
                cls.low_classifier = nn.Conv2d(cls.low_classifier.in_channels, num_classes, kernel_size=1)
                cls.high_classifier = nn.Conv2d(cls.high_classifier.in_channels, num_classes, kernel_size=1)
            else:
                logger.warning(
                    "Unknown classifier type for model %s; num_classes may not be updated correctly.",
                    name,
                )
            # Pretrained DeepLabV3/FCN models also carry an auxiliary head; keep it consistent.
            if getattr(model, "aux_classifier", None) is not None:
                model.aux_classifier[-1] = nn.Conv2d(
                    model.aux_classifier[-1].in_channels, num_classes, kernel_size=1
                )
            logger.info(f"Modified classifier for {name} to output {num_classes} classes.")

        return model

    except Exception as e:
        logger.exception(f"Failed to load torchvision model: {e}")
        raise
load_weights(model: nn.Module, weights_path: str | Path, weight_type: str = 'full') -> nn.Module

Load weights into a torchvision model.

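The checkpoint may be a raw state dict or a dict storing it under a 'state_dict' or 'model' key. Weights are loaded with strict=False, and missing or unexpected keys are logged as warnings.

Parameters:

  • model (nn.Module, required): Model instance.
  • weights_path (str | Path, required): Path to the weights file.
  • weight_type (str, default: 'full'): Which weights to load:
    - 'full': load the entire model's weights.
    - 'encoder': load only the backbone weights (keys prefixed with 'backbone.').
    Any other value logs a warning and loads nothing.

Returns:

  • nn.Module: Model with loaded weights.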

Source code in segmentation_robustness_framework/loaders/models/torchvision_loader.py
def load_weights(self, model: nn.Module, weights_path: str | Path, weight_type: str = "full") -> nn.Module:
    """Load weights into a torchvision model.

    Args:
        model (nn.Module): Model instance.
        weights_path (str | Path): Path to weights file.
        weight_type (str): `'full'` for entire model, `'encoder'` for backbone only.

    Supported weight_type values:
        - `'full'`: Load entire model weights.
        - `'encoder'`: Load encoder weights only.

    Returns:
        `nn.Module`: Model with loaded weights.
    """
    try:
        logger.info(f"Loading weights from {weights_path} (type: {weight_type})")
        checkpoint = torch.load(weights_path, map_location="cpu", weights_only=True)

        if "state_dict" in checkpoint:
            state_dict = checkpoint["state_dict"]
        elif "model" in checkpoint:
            state_dict = checkpoint["model"]
        else:
            state_dict = checkpoint

        if weight_type == "full":
            missing, unexpected = model.load_state_dict(state_dict, strict=False)
            if missing:
                logger.warning(f"Missing keys when loading weights: {missing}")
            if unexpected:
                logger.warning(f"Unexpected keys when loading weights: {unexpected}")
            logger.info(f"Loaded full model weights into torchvision model from {weights_path}")
        elif weight_type == "encoder":
            backbone_state_dict = {
                k.removeprefix("backbone."): v for k, v in state_dict.items() if k.startswith("backbone.")
            }
            missing, unexpected = model.backbone.load_state_dict(backbone_state_dict, strict=False)
            if missing:
                logger.warning(f"Missing keys when loading encoder weights: {missing}")
            if unexpected:
                logger.warning(f"Unexpected keys when loading encoder weights: {unexpected}")
            logger.info(f"Loaded encoder (backbone) weights into torchvision model from {weights_path}")
        else:
            logger.warning(f"Unknown weight_type: {weight_type}. No weights loaded.")
        return model
    except Exception as e:
        logger.exception(f"Failed to load weights into torchvision model: {e}")
        raise
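Every load_weights method on this page first unwraps the checkpoint. For weight_type='encoder', it then keeps only the keys under a prefix: 'backbone.' here, 'encoder.' in the other loaders. Only the torchvision loader also looks under a 'model' key. The stdlib-only sketch below covers both steps, with plain dicts standing in for tensors. unwrap_checkpoint and extract_prefixed are hypothetical helpers. The sketch strips the prefix with str.removeprefix (Python 3.9+), which removes only the leading occurrence. By contrast, str.replace also rewrites repeated prefixes such as 'encoder.encoder.'.

```python
from typing import Any


def unwrap_checkpoint(checkpoint: dict[str, Any]) -> dict[str, Any]:
    """Return the nested state dict if present, otherwise the checkpoint itself."""
    for key in ("state_dict", "model"):
        if key in checkpoint:
            return checkpoint[key]
    return checkpoint


def extract_prefixed(state_dict: dict[str, Any], prefix: str) -> dict[str, Any]:
    """Keep only keys starting with `prefix` and drop that leading prefix once."""
    return {k.removeprefix(prefix): v for k, v in state_dict.items() if k.startswith(prefix)}


checkpoint = {
    "state_dict": {
        "backbone.conv1.weight": "w0",
        "backbone.layer1.0.conv1.weight": "w1",
        "classifier.4.weight": "w2",
    }
}
print(extract_prefixed(unwrap_checkpoint(checkpoint), "backbone."))

# A nested prefix shows why only the leading occurrence should be removed
nested = {"encoder.encoder.layer.0.weight": "w3"}
print(extract_prefixed(nested, "encoder."))
print({k.replace("encoder.", ""): v for k, v in nested.items()})
```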

segmentation_robustness_framework.loaders.models.smp_loader

Classes

SMPModelLoader

Bases: BaseModelLoader

Loader for segmentation_models_pytorch (SMP) models.

Supports loading models and weights, including encoder-only weights.

Supported model_config keys
  • architecture (str): Architecture of the model (default: 'unet').
  • encoder_name (str): Name of the encoder (default: 'resnet34').
  • encoder_weights (str): Weights of the encoder (default: 'imagenet').
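  • classes (int): Number of classes (default: 3).
  • activation (str): Activation function (default: None).
  • checkpoint (str): SMP Hub identifier starting with 'smp-hub/' (default: None). When set, the model is loaded with smp.from_pretrained and the architecture and encoder keys are ignored. Any other value is ignored, and the model is built from the remaining keys.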
Example
# Basic usage
loader = SMPModelLoader()
model_config = {
    "architecture": "unet",
    "encoder_name": "resnet34",
    "classes": 2,
}
model = loader.load_model(model_config)
model = loader.load_weights(
    model, "weights.pth", weight_type="full"
)  # load full model weights if needed
Example
# With checkpoint
loader = SMPModelLoader()
model_config = {"checkpoint": "smp-hub/upernet-convnext-tiny"}
model = loader.load_model(model_config)
Example
# With encoder-only weights
loader = SMPModelLoader()
model_config = {
    "architecture": "unet",
    "encoder_name": "resnet34",
    "classes": 2,
}
model = loader.load_model(model_config)
model = loader.load_weights(model, "weights.pth", weight_type="encoder")

Functions

load_model(model_config: dict[str, Any]) -> nn.Module

Load an SMP model from config or checkpoint.

Parameters:

  • model_config (dict[str, Any], required): Model configuration with the keys:
    - architecture (str): Architecture of the model (optional, default: 'unet').
    - encoder_name (str): Name of the encoder (optional, default: 'resnet34').
    - encoder_weights (str): Weights of the encoder (optional, default: 'imagenet').
    - classes (int): Number of classes (optional, default: 3).
    - activation (str): Activation function (optional, default: None).
    - checkpoint (str): SMP Hub identifier starting with 'smp-hub/' (optional).

Raises:

  • RuntimeError: If an 'smp-hub/' checkpoint cannot be loaded.

Returns:

  • nn.Module: Instantiated SMP model.

Source code in segmentation_robustness_framework/loaders/models/smp_loader.py
def load_model(self, model_config: dict[str, Any]) -> nn.Module:
    """Load an SMP model from config or checkpoint.

    Args:
        model_config (dict[str, Any]):
            - `architecture` (str): Architecture of the model.
            - `encoder_name` (str): Name of the encoder (optional).
            - `encoder_weights` (str): Weights of the encoder (optional).
            - `classes` (int): Number of classes (optional).
            - `activation` (str): Activation function (optional).
            - `checkpoint` (str): Path to the checkpoint (optional).

    Raises:
        RuntimeError: If the checkpoint cannot be loaded.

    Returns:
        `nn.Module`: Instantiated SMP model.
    """
    smp = self._import_smp()
    checkpoint = model_config.get("checkpoint")
    try:
        if checkpoint and checkpoint.startswith("smp-hub/"):
            try:
                model = smp.from_pretrained(checkpoint)
                logger.info(f"Loaded SMP model from checkpoint: {checkpoint}")
            except Exception as e:
                logger.exception(f"Could not load checkpoint {checkpoint}: {e}")
                raise RuntimeError(f"Could not load checkpoint {checkpoint}") from e
        else:
            architecture = model_config.get("architecture", "unet")
            encoder_name = model_config.get("encoder_name", "resnet34")
            encoder_weights = model_config.get("encoder_weights", "imagenet")
            classes = model_config.get("classes", 3)
            activation = model_config.get("activation", None)

            model = smp.create_model(
                arch=architecture,
                encoder_name=encoder_name,
                encoder_weights=encoder_weights,
                classes=classes,
                activation=activation,
            )
            logger.info(f"Loaded SMP model: {architecture} with encoder {encoder_name}")

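        # SMP models expose their output convolution as segmentation_head[0]
        requested_classes = model_config.get("classes")
        head = getattr(model, "segmentation_head", None)
        if requested_classes is not None and isinstance(head, nn.Sequential) and len(head) > 0:
            conv = head[0]
            if isinstance(conv, nn.Conv2d) and conv.out_channels != requested_classes:
                head[0] = nn.Conv2d(conv.in_channels, requested_classes, kernel_size=conv.kernel_size, padding=conv.padding)
                logger.info(f"Adjusted segmentation head out_channels to {requested_classes}")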
        return model
    except Exception as e:
        logger.exception(f"Failed to load SMP model: {e}")
        raise
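The branching in load_model decides between smp.from_pretrained and smp.create_model. The sketch below reproduces that decision and its defaults as a pure function. It returns a description of the call instead of building a model. resolve_smp_call is a hypothetical helper. It also checks isinstance(checkpoint, str), because calling checkpoint.startswith in the loader would fail for a pathlib.Path.

```python
from typing import Any

# Defaults taken from SMPModelLoader.load_model above
SMP_DEFAULTS: dict[str, Any] = {
    "architecture": "unet",
    "encoder_name": "resnet34",
    "encoder_weights": "imagenet",
    "classes": 3,
    "activation": None,
}


def resolve_smp_call(model_config: dict[str, Any]) -> tuple[str, dict[str, Any]]:
    """Return (SMP function name, arguments) for the call the loader would make."""
    checkpoint = model_config.get("checkpoint")
    if isinstance(checkpoint, str) and checkpoint.startswith("smp-hub/"):
        # The loader passes the identifier positionally to smp.from_pretrained
        return "from_pretrained", {"checkpoint": checkpoint}
    params = {key: model_config.get(key, default) for key, default in SMP_DEFAULTS.items()}
    params["arch"] = params.pop("architecture")  # create_model names this argument `arch`
    return "create_model", params


print(resolve_smp_call({"classes": 2}))
print(resolve_smp_call({"checkpoint": "smp-hub/upernet-convnext-tiny"}))
print(resolve_smp_call({"checkpoint": "weights/local.pth"}))
```

The last call shows that a checkpoint string without the 'smp-hub/' prefix, such as a local path, falls through to create_model and is not loaded.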
load_weights(model: nn.Module, weights_path: str, weight_type: str = 'full') -> nn.Module

Load weights into SMP model.

Parameters:

  • model (nn.Module, required): Model instance.
  • weights_path (str | Path, required): Path to the weights file.
  • weight_type (str, default: 'full'): Which weights to load:
    - 'full': load the entire model's weights.
    - 'encoder': load only the encoder weights (keys prefixed with 'encoder.').
    Any other value logs a warning and loads nothing.

Returns:

  • nn.Module: Model with loaded weights.

Source code in segmentation_robustness_framework/loaders/models/smp_loader.py
def load_weights(self, model: nn.Module, weights_path: str, weight_type: str = "full") -> nn.Module:
    """Load weights into SMP model.

    Args:
        model (nn.Module): Model instance.
        weights_path (str | Path): Path to weights file.
        weight_type (str): `'full'` for entire model, `'encoder'` for encoder only.

    Supported weight_type values:
        - `'full'`: Load entire model weights.
        - `'encoder'`: Load encoder weights only.

    Returns:
        `nn.Module`: Model with loaded weights.
    """
    try:
        checkpoint = torch.load(weights_path, map_location="cpu", weights_only=True)

        if "state_dict" in checkpoint:
            state_dict = checkpoint["state_dict"]
        else:
            state_dict = checkpoint

        if weight_type == "full":
            result = model.load_state_dict(state_dict, strict=False)
            if hasattr(result, "missing_keys") and hasattr(result, "unexpected_keys"):
                missing = result.missing_keys
                unexpected = result.unexpected_keys

                if missing:
                    logger.warning(f"Missing keys when loading full model weights: {missing}")
                if unexpected:
                    logger.warning(f"Unexpected keys when loading full model weights: {unexpected}")
            logger.info(f"Loaded full model weights into SMP model from {weights_path}")
        elif weight_type == "encoder":
            encoder_state_dict = {
                k.removeprefix("encoder."): v for k, v in state_dict.items() if k.startswith("encoder.")
            }
            result = model.encoder.load_state_dict(encoder_state_dict, strict=False)

            if hasattr(result, "missing_keys") and hasattr(result, "unexpected_keys"):
                missing = result.missing_keys
                unexpected = result.unexpected_keys

                if missing:
                    logger.warning(f"Missing keys when loading encoder weights: {missing}")
                if unexpected:
                    logger.warning(f"Unexpected keys when loading encoder weights: {unexpected}")
                logger.info(f"Loaded encoder (backbone) weights into SMP model from {weights_path}")
            else:
                logger.info(
                    f"Loaded encoder weights (no missing/unexpected keys info available) into SMP model from {weights_path}"
                )
        else:
            logger.warning(f"Unknown weight_type: {weight_type}. No weights loaded.")
        return model
    except Exception as e:
        logger.exception(f"Failed to load weights into SMP model: {e}")
        raise
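With strict=False, torch.nn.Module.load_state_dict returns a named tuple whose missing_keys and unexpected_keys fields feed the warnings above. Conceptually, these are set differences between the parameter names the module expects and the names in the checkpoint. The stdlib sketch below shows that relationship. It is a simplification: PyTorch still raises on shape mismatches even with strict=False. diff_keys is a hypothetical helper.

```python
from collections.abc import Iterable


def diff_keys(expected: Iterable[str], loaded: Iterable[str]) -> tuple[list[str], list[str]]:
    """Return (missing, unexpected) key lists, sorted for stable log output."""
    expected_set, loaded_set = set(expected), set(loaded)
    missing = sorted(expected_set - loaded_set)  # in the model but absent from the checkpoint
    unexpected = sorted(loaded_set - expected_set)  # in the checkpoint but unknown to the model
    return missing, unexpected


model_keys = ["encoder.conv1.weight", "decoder.block.weight", "segmentation_head.0.weight"]
checkpoint_keys = ["encoder.conv1.weight", "decoder.block.weight", "classification_head.weight"]
missing, unexpected = diff_keys(model_keys, checkpoint_keys)
print("missing:", missing)
print("unexpected:", unexpected)
```

A long missing list after weight_type='full' often means the checkpoint's keys carry an extra prefix, for example 'module.' from DataParallel, or that they belong to a different architecture.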

segmentation_robustness_framework.loaders.models.huggingface_loader

Classes

HuggingFaceModelLoader

Bases: BaseModelLoader

Loader for HuggingFace models.

Supported model_config keys
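  • model_name (str): HuggingFace model id or path (required).
  • model_cls (str): Name of a transformers model class to load with, for example 'SegformerForSemanticSegmentation'; takes precedence over task (optional).
  • num_labels (int): Number of output classes (optional).
  • trust_remote_code (bool): Allow loading custom code from the model repo (default: False).
  • task (str): "semantic_segmentation", "instance_segmentation", "panoptic_segmentation", or "image_segmentation" (optional). If neither task nor model_cls is given, transformers.AutoModel is used, which loads the base model without a segmentation head.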
  • return_processor (bool): Whether to return processor along with model (default: True).
  • config_overrides (dict): Arbitrary config attributes to override (optional).
  • processor_overrides (dict): Arbitrary processor attributes to override (optional).
Example
# Basic usage
loader = HuggingFaceModelLoader()
model_config = {
    "model_name": "nvidia/segformer-b2-finetuned-ade-512-512",
    "task": "semantic_segmentation",
}
bundle = loader.load_model(model_config)

# Access model and processor (optional)
model = bundle.model
processor = bundle.processor
Example
# With config and processor overrides
loader = HuggingFaceModelLoader()
model_config = {
    "model_name": "nvidia/segformer-b2-finetuned-ade-512-512",
    "task": "semantic_segmentation",
    "num_labels": 150,
    "config_overrides": {"output_hidden_states": True},
    "processor_overrides": {"do_resize": False},
}
bundle = loader.load_model(model_config)
Example
# With task override
loader = HuggingFaceModelLoader()
model_config = {
    "model_name": "facebook/maskformer-swin-tiny-coco",
    "task": "instance_segmentation",
}
bundle = loader.load_model(model_config)

Functions

load_model(model_config: dict[str, Any]) -> HFSegmentationBundle | nn.Module

Load HuggingFace model and optionally its processor.

Parameters:

  • model_config (dict[str, Any], required): Model configuration with the keys:
    - model_name (str): Model name or path (required).
    - model_cls (str): Name of a transformers model class to use (optional).
    - num_labels (int): Number of classes (optional).
    - trust_remote_code (bool): Trust remote code (optional, default: False).
    - task (str): Model task (optional).
    - return_processor (bool): Also load and return the image processor (optional, default: True).
    - config_overrides (dict[str, Any]): Config attribute overrides (optional).
    - processor_overrides (dict[str, Any]): Processor attribute overrides (optional).

Returns:

  • HFSegmentationBundle | nn.Module: An HFSegmentationBundle holding the model and processor if return_processor is True, otherwise the model alone.

Source code in segmentation_robustness_framework/loaders/models/huggingface_loader.py
def load_model(self, model_config: dict[str, Any]) -> "HFSegmentationBundle | nn.Module":
    """Load HuggingFace model and optionally its processor.

    Args:
        model_config (dict[str, Any]):
            - `model_name` (str): Model name/path (required).
            - `model_cls` (str): Name of a transformers model class to use (optional).
            - `num_labels` (int): Number of classes (optional).
            - `trust_remote_code` (bool): Trust remote code (optional).
            - `task` (str): Model task (optional).
            - `return_processor` (bool): Return processor (optional).
            - `config_overrides` (dict[str, Any]): Config attribute overrides (optional).
            - `processor_overrides` (dict[str, Any]): Processor attribute overrides (optional).

    Returns:
        `HFSegmentationBundle | nn.Module`: Model (and processor if requested).

    """
    try:
        transformers = self._import_transformers()
        model_name = model_config["model_name"]
        model_cls = model_config.get("model_cls", None)
        task = model_config.get("task", None)
        return_processor = model_config.get("return_processor", True)
        trust_remote_code = model_config.get("trust_remote_code", False)

        config = transformers.AutoConfig.from_pretrained(model_name, trust_remote_code=trust_remote_code)
        if "num_labels" in model_config:
            config.num_labels = model_config["num_labels"]
        if "config_overrides" in model_config:
            for k, v in model_config["config_overrides"].items():
                setattr(config, k, v)

        if model_cls is None:
            if task == "semantic_segmentation":
                model_cls = transformers.AutoModelForSemanticSegmentation
            elif task == "instance_segmentation":
                model_cls = transformers.AutoModelForInstanceSegmentation
            elif task == "panoptic_segmentation":
                model_cls = transformers.AutoModelForPanopticSegmentation
            elif task == "image_segmentation":
                model_cls = transformers.AutoModelForImageSegmentation
            else:
                model_cls = transformers.AutoModel
        else:
            model_cls = getattr(transformers, model_cls)

        model = model_cls.from_pretrained(model_name, config=config, trust_remote_code=trust_remote_code)
        logger.info(f"Loaded HuggingFace model: {model_name} (task: {task})")

        if return_processor:
            processor = transformers.AutoImageProcessor.from_pretrained(
                model_name, trust_remote_code=trust_remote_code
            )
            if "processor_overrides" in model_config:
                for k, v in model_config["processor_overrides"].items():
                    setattr(processor, k, v)
            from segmentation_robustness_framework.loaders.models.hf_bundle import HFSegmentationBundle

            logger.info(f"Loaded processor for HuggingFace model: {model_name}")
            return HFSegmentationBundle(model=model, processor=processor)
        return model
    except Exception as e:
        logger.exception(f"Failed to load HuggingFace model: {e}")
        raise
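config_overrides and processor_overrides are applied with plain setattr, so any attribute name is accepted. That includes names from_pretrained never reads, for example ignore_mismatched_sizes, which is a from_pretrained keyword argument rather than a config attribute. The minimal sketch below covers the override step, with types.SimpleNamespace standing in for a transformers config object. apply_config_overrides is a hypothetical helper. On a real PretrainedConfig, assigning num_labels also resets id2label and label2id when their size no longer matches. The namespace stand-in does not model that.

```python
from types import SimpleNamespace
from typing import Any


def apply_config_overrides(config: Any, model_config: dict[str, Any]) -> Any:
    """Apply num_labels, then each config_overrides entry, in the loader's order."""
    if "num_labels" in model_config:
        config.num_labels = model_config["num_labels"]
    for key, value in model_config.get("config_overrides", {}).items():
        setattr(config, key, value)  # no validation: unknown names are simply added
    return config


config = SimpleNamespace(num_labels=150, output_hidden_states=False)
apply_config_overrides(
    config,
    {"num_labels": 19, "config_overrides": {"output_hidden_states": True, "typo_attr": 1}},
)
print(vars(config))
```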
load_weights(model: nn.Module, weights_path: str, weight_type: str = 'full') -> nn.Module

Load weights into HuggingFace model.

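Parameters:

  • model (nn.Module, required): Model instance.
  • weights_path (str | Path, required): Path to the weights file.
  • weight_type (str, default: 'full'): Which weights to load:
    - 'full': load the entire model's weights.
    - 'encoder': load only the weights prefixed with 'encoder.' into model.encoder (the model must expose an encoder attribute).
    Any other value logs a warning and loads nothing.

Returns:

  • nn.Module: Model with loaded weights.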

Source code in segmentation_robustness_framework/loaders/models/huggingface_loader.py
def load_weights(self, model: nn.Module, weights_path: str, weight_type: str = "full") -> nn.Module:
    """Load weights into HuggingFace model.

    Args:
        model (nn.Module): Model instance.
        weights_path (str | Path): Path to weights file.
        weight_type (str): 'full' for entire model, 'encoder' for encoder only.

    Returns:
        `nn.Module`: Model with loaded weights.
    """
    try:
        checkpoint = torch.load(weights_path, map_location="cpu")

        if "state_dict" in checkpoint:
            state_dict = checkpoint["state_dict"]
        else:
            state_dict = checkpoint

        if weight_type == "full":
            result = model.load_state_dict(state_dict, strict=False)
            if hasattr(result, "missing_keys") and hasattr(result, "unexpected_keys"):
                missing = result.missing_keys
                unexpected = result.unexpected_keys
                if missing:
                    logger.warning(f"Missing keys when loading full model weights: {missing}")
                if unexpected:
                    logger.warning(f"Unexpected keys when loading full model weights: {unexpected}")
            logger.info(f"Loaded full model weights into HuggingFace model from {weights_path}")
        elif weight_type == "encoder":
            encoder_state_dict = {
                k.removeprefix("encoder."): v for k, v in state_dict.items() if k.startswith("encoder.")
            }
            result = model.encoder.load_state_dict(encoder_state_dict, strict=False)
            if hasattr(result, "missing_keys") and hasattr(result, "unexpected_keys"):
                missing = result.missing_keys
                unexpected = result.unexpected_keys
                if missing:
                    logger.warning(f"Missing keys when loading encoder weights: {missing}")
                if unexpected:
                    logger.warning(f"Unexpected keys when loading encoder weights: {unexpected}")
                logger.info(f"Loaded encoder (backbone) weights into HuggingFace model from {weights_path}")
            else:
                logger.info(
                    f"Loaded encoder weights (no missing/unexpected keys info available) into HuggingFace model from {weights_path}"
                )
        else:
            logger.warning(f"Unknown weight_type: {weight_type}. No weights loaded.")
        return model
    except Exception as e:
        logger.exception(f"Failed to load weights into HuggingFace model: {e}")
        raise

segmentation_robustness_framework.loaders.models.custom_loader

Classes

CustomModelLoader

Bases: BaseModelLoader

Loader for custom user models.

Supported model_config keys
  • model_class (str | Callable[..., Any]): Model class or factory function (required).
  • model_args (list[Any]): List of positional arguments for model initialization (optional).
  • model_kwargs (dict[str, Any]): Dict of keyword arguments for model initialization (optional).
Example
loader = CustomModelLoader()
model_config = {
    "model_class": MyCustomSegmentationModel,
    "model_args": [3, 21],
    "model_kwargs": {},
}
model = loader.load_model(model_config)
model = loader.load_weights(model, "weights.pth", weight_type="full")
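The example above references MyCustomSegmentationModel without defining it. The self-contained sketch below shows how the three config keys map onto the constructor call. TinySegmenter is a hypothetical stand-in; a real model would subclass torch.nn.Module. instantiate_from_config mirrors only the callable branch of load_model.

```python
from typing import Any


class TinySegmenter:
    """Hypothetical user model; a real one would subclass torch.nn.Module."""

    def __init__(self, in_channels: int, num_classes: int, dropout: float = 0.0) -> None:
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.dropout = dropout


def instantiate_from_config(model_config: dict[str, Any]) -> Any:
    """Call model_class(*model_args, **model_kwargs), as the custom loader does."""
    model_class = model_config["model_class"]
    if not callable(model_class):
        raise ValueError(f"model_class must be callable in this sketch, got {type(model_class)}")
    args = model_config.get("model_args", [])
    kwargs = model_config.get("model_kwargs", {})
    return model_class(*args, **kwargs)


model = instantiate_from_config(
    {"model_class": TinySegmenter, "model_args": [3, 21], "model_kwargs": {"dropout": 0.1}}
)
print(model.in_channels, model.num_classes, model.dropout)
```

Because any callable is accepted, a factory function works as well as a class. However, the loader logs model_class.__name__, so objects without that attribute, such as functools.partial, would fail at the logging step.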

Functions

load_model(model_config: dict[str, Any]) -> nn.Module

Load custom model.

Parameters:

  • model_config (dict[str, Any], required): Model configuration with the keys:
    - model_class (str | Callable[..., Any]): Model class or factory function, or a string resolved with resolve_model_class.
    - model_args (list[Any]): Positional arguments for model initialization (optional, default: []).
    - model_kwargs (dict[str, Any]): Keyword arguments for model initialization (optional, default: {}).

Raises:

  • ValueError: If model_class is neither a string nor a callable.

Returns:

  • nn.Module: Instantiated model.

Source code in segmentation_robustness_framework/loaders/models/custom_loader.py
def load_model(self, model_config: dict[str, Any]) -> nn.Module:
    """Load custom model.

    Args:
        model_config (dict[str, Any]):
            - `model_class` (str | Callable[..., Any]): Model class or factory function.
            - `model_args` (list[Any]): List of positional arguments for model initialization.
            - `model_kwargs` (dict[str, Any]): Dict of keyword arguments for model initialization.

    Raises:
        ValueError: If `model_class` is neither a string nor a callable.

    Returns:
        `nn.Module`: Instantiated model.
    """
    try:
        model_class = model_config["model_class"]
        model_args = model_config.get("model_args", [])
        model_kwargs = model_config.get("model_kwargs", {})

        if isinstance(model_class, str):
            from segmentation_robustness_framework.utils.loader_utils import resolve_model_class

            model_class = resolve_model_class(model_class)
            model = model_class(*model_args, **model_kwargs)
            logger.info(f"Loaded custom model: {model_class.__name__}")
        elif callable(model_class):
            model = model_class(*model_args, **model_kwargs)
            logger.info(f"Loaded custom model: {model_class.__name__}")
        else:
            raise ValueError(f"model_class must be a string or callable, got {type(model_class)}")
        return model
    except Exception as e:
        logger.exception(f"Failed to load custom model: {e}")
        raise
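When model_class is a string, the loader delegates to resolve_model_class, whose implementation is not shown on this page. The sketch below shows one common approach. It assumes the string is a dotted import path such as 'my_package.models.MyNet', splits off the attribute name, and imports the module with importlib. resolve_dotted_path is hypothetical and may differ from the framework's resolver, which could, for example, consult a registry.

```python
import importlib
from typing import Any


def resolve_dotted_path(path: str) -> Any:
    """Import 'package.module.Name' and return the attribute Name."""
    module_name, _, attr_name = path.rpartition(".")
    if not module_name:
        raise ValueError(f"Expected a dotted path like 'package.module.ClassName', got {path!r}")
    module = importlib.import_module(module_name)
    try:
        return getattr(module, attr_name)
    except AttributeError as exc:
        raise ValueError(f"Module {module_name!r} has no attribute {attr_name!r}") from exc


ordered_dict_cls = resolve_dotted_path("collections.OrderedDict")
print(ordered_dict_cls(a=1))
```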
load_weights(model: nn.Module, weights_path: str, weight_type: str = 'full') -> nn.Module

Load weights into custom model.

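Parameters:

  • model (nn.Module, required): Model instance.
  • weights_path (str | Path, required): Path to the weights file.
  • weight_type (str, default: 'full'): Which weights to load:
    - 'full': load the entire model's weights.
    - 'encoder': load the weights prefixed with 'encoder.' into model.encoder. If the model has no encoder attribute, a warning is logged and the full weights are loaded instead.

Raises:

  • ValueError: If weight_type is neither 'full' nor 'encoder'. Unlike the other loaders, which log a warning, the custom loader raises.

Returns:

  • nn.Module: Model with loaded weights.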

Source code in segmentation_robustness_framework/loaders/models/custom_loader.py
def load_weights(self, model: nn.Module, weights_path: str, weight_type: str = "full") -> nn.Module:
    """Load weights into custom model.

    Args:
        model (nn.Module): Model instance.
        weights_path (str | Path): Path to weights file.
        weight_type (str): `'full'` for entire model, `'encoder'` for encoder only.

    Supported weight_type values:
        - `'full'`: Load entire model weights.
        - `'encoder'`: Load encoder weights only.

    Returns:
        `nn.Module`: Model with loaded weights.
    """
    try:
        checkpoint = torch.load(weights_path, map_location="cpu")

        if "state_dict" in checkpoint:
            state_dict = checkpoint["state_dict"]
        else:
            state_dict = checkpoint

        if weight_type == "full":
            model.load_state_dict(state_dict, strict=False)
            logger.info(f"Loaded full model weights into custom model from {weights_path}")
        elif weight_type == "encoder":
            if hasattr(model, "encoder"):
                encoder_state_dict = {
                    k.replace("encoder.", ""): v for k, v in state_dict.items() if k.startswith("encoder.")
                }
                model.encoder.load_state_dict(encoder_state_dict, strict=False)
                logger.info(f"Loaded encoder (backbone) weights into custom model from {weights_path}")
            else:
                logger.warning("Model has no 'encoder' attribute, loading full model weights")
                model.load_state_dict(state_dict, strict=False)
        else:
            raise ValueError(f"Unknown weight_type: {weight_type}. No weights loaded.")
        return model
    except Exception as e:
        logger.exception(f"Failed to load weights into custom model: {e}")
        raise
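The key handling in `load_weights` can be illustrated without PyTorch. The sketch below mirrors the two steps shown in the source: unwrapping a `{"state_dict": ...}` checkpoint, and, for `weight_type="encoder"`, keeping only `encoder.*` keys with the prefix removed. `unwrap_checkpoint` and `encoder_state` are hypothetical names, and plain integers stand in for tensors. One deliberate difference: the sketch slices off only the leading prefix, whereas `str.replace` in the source also rewrites any later `encoder.` segment inside a key.

```python
from typing import Any


def unwrap_checkpoint(checkpoint: dict[str, Any]) -> dict[str, Any]:
    # Checkpoints saved as {"state_dict": ..., "epoch": ...} are unwrapped;
    # a bare state dict is returned unchanged (same rule as load_weights)
    return checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint


def encoder_state(state_dict: dict[str, Any], prefix: str = "encoder.") -> dict[str, Any]:
    # Keep only encoder keys and strip the leading prefix by slicing,
    # so nested segments such as "layer1.encoder.proj" survive intact
    return {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}


checkpoint = {
    "state_dict": {
        "encoder.conv1.weight": 1,
        "encoder.layer1.encoder.proj": 2,
        "decoder.head.weight": 3,
    },
    "epoch": 10,
}
print(encoder_state(unwrap_checkpoint(checkpoint)))
```

Because the source calls `load_state_dict(..., strict=False)`, mismatched keys are skipped silently. In PyTorch, `load_state_dict` returns a named tuple with `missing_keys` and `unexpected_keys`, which you can inspect to confirm that the weights actually matched.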

Model Loading Overview

The framework provides a specialized loader for each supported model type (torchvision, SMP, HuggingFace, and custom), each handling the configuration and output format of its model family.

Universal Model Loader

The UniversalModelLoader is the main entry point for model loading. It selects the specialized loader that matches model_type and wraps the loaded model in the corresponding adapter (or in adapter_cls, if one is passed).

from segmentation_robustness_framework.loaders import UniversalModelLoader

# Load a torchvision model
model = UniversalModelLoader().load_model(
    model_type="torchvision",
    model_config={"name": "deeplabv3_resnet50", "num_classes": 21}
)

# Load an SMP model
model = UniversalModelLoader().load_model(
    model_type="smp",
    model_config={"architecture": "unet", "encoder_name": "resnet34", "classes": 21}
)

# Load a HuggingFace model
model = UniversalModelLoader().load_model(
    model_type="huggingface",
    model_config={"model_name": "nvidia/segformer-b0-finetuned-ade-512-512"}
)
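The dispatch step can be pictured with a simplified, pure-Python sketch. Everything in it is hypothetical: `StubLoader` stands in for the library-specific loaders, and `resolve_loader` is an illustrative name. It reproduces two behaviours documented above: loaders for uninstalled libraries are stored as `None`, and any `model_type` starting with `custom_` is an alias for the custom loader. The exception types raised for unsupported or unavailable types are assumptions, and the adapter wrapping step is omitted.

```python
from typing import Any, Optional


class StubLoader:
    """Stand-in for a library-specific loader (hypothetical)."""

    def __init__(self, family: str) -> None:
        self.family = family

    def load_model(self, model_config: dict[str, Any]) -> str:
        # A real loader would build an nn.Module; the stub just describes the call
        return f"{self.family}:{model_config.get('name', '?')}"


def resolve_loader(loaders: dict[str, Optional[StubLoader]], model_type: str) -> StubLoader:
    # 'custom_<anything>' is treated as an alias for the 'custom' loader
    key = "custom" if model_type.startswith("custom_") else model_type
    if key not in loaders:
        raise ValueError(f"Unsupported model type: {model_type}")
    loader = loaders[key]
    if loader is None:
        # Mirrors the `Loader() if LIB_INSTALLED else None` pattern in __init__
        raise ImportError(f"Loader for '{key}' is unavailable; install its library")
    return loader


loaders: dict[str, Optional[StubLoader]] = {
    "torchvision": StubLoader("torchvision"),
    "smp": None,  # pretend segmentation-models-pytorch is not installed
    "huggingface": StubLoader("huggingface"),
    "custom": StubLoader("custom"),
}

print(resolve_loader(loaders, "torchvision").load_model({"name": "fcn_resnet50"}))
```

Storing `None` instead of omitting the key lets the loader distinguish "unknown model type" from "known type whose library is missing".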

Supported Model Types

Torchvision Models

# Available models
torchvision_models = [
    "deeplabv3_resnet50",
    "deeplabv3_resnet101",
    "deeplabv3_mobilenet_v3_large",
    "fcn_resnet50",
    "fcn_resnet101",
    "lraspp_mobilenet_v3_large"
]

# Example usage
model = UniversalModelLoader().load_model(
    model_type="torchvision",
    model_config={"name": "deeplabv3_resnet50", "num_classes": 21}
)

# With custom weights
model = UniversalModelLoader().load_model(
    model_type="torchvision",
    model_config={
        "name": "deeplabv3_resnet50", 
        "num_classes": 21,
        "weights": "COCO_WITH_VOC_LABELS_V1"
    }
)

SMP Models

# Available architectures
smp_architectures = [
    "unet", "unetplusplus", "manet", "linknet",
    "fpn", "pspnet", "pan", "deeplabv3", "deeplabv3plus"
]

# Available encoders
smp_encoders = [
    "resnet18", "resnet34", "resnet50", "resnet101", "resnet152",
    "resnext50_32x4d", "resnext101_32x8d",
    "timm-efficientnet-b0", "timm-efficientnet-b1", "timm-efficientnet-b2"
]

# Example usage
model = UniversalModelLoader().load_model(
    model_type="smp",
    model_config={
        "architecture": "unet",
        "encoder_name": "resnet34",
        "encoder_weights": "imagenet",
        "classes": 21
    }
)

HuggingFace Models

# Example usage
model = UniversalModelLoader().load_model(
    model_type="huggingface",
    model_config={
        "model_name": "nvidia/segformer-b0-finetuned-ade-512-512",
        "trust_remote_code": True
    }
)

# With additional configuration
model = UniversalModelLoader().load_model(
    model_type="huggingface",
    model_config={
        "model_name": "nvidia/segformer-b0-finetuned-ade-512-512",
        "num_labels": 150,
        "config_overrides": {"ignore_mismatched_sizes": True},
        "processor_overrides": {"do_resize": False}
    }
)

Custom Models

For custom models, you can use the CustomModelLoader:

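from torch import nn

from segmentation_robustness_framework.loaders import CustomModelLoader


class MyCustomModel(nn.Module):
    def __init__(self, num_classes=21):
        super().__init__()
        # Placeholder architecture: replace with your own layers
        self.classifier = nn.Conv2d(3, num_classes, kernel_size=1)

    def forward(self, x):
        # Return per-pixel class logits, shape [B, num_classes, H, W]
        return self.classifier(x)


# Constructor arguments belong in model_args / model_kwargs, not at the top level
model = CustomModelLoader().load_model(
    {"model_class": MyCustomModel, "model_kwargs": {"num_classes": 21}}
)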

Model Configuration

Each model type accepts different configuration parameters:

Torchvision Configuration

model:
  type: torchvision
  config:
    name: deeplabv3_resnet50
    num_classes: 21
    weights: COCO_WITH_VOC_LABELS_V1  # Optional: specify weights

SMP Configuration

model:
  type: smp
  config:
    architecture: unet
    encoder_name: resnet34
    encoder_weights: imagenet
    classes: 21
    activation: null  # Optional; YAML null, not the string "None"

HuggingFace Configuration

model:
  type: huggingface
  config:
    model_name: nvidia/segformer-b0-finetuned-ade-512-512
    trust_remote_code: true
    num_labels: 150
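The YAML blocks above map onto the two arguments of load_model: `type` becomes `model_type` and `config` becomes `model_config`. The sketch below assumes PyYAML is installed and that configs use this `type`/`config` layout. The actual framework call is left as a comment because it needs the framework and torchvision installed. The sketch also shows why nulls must be written as `null` (or `~`) in YAML: PyYAML parses a bare `None` as the string "None".

```python
import yaml

CONFIG = """
model:
  type: torchvision
  config:
    name: deeplabv3_resnet50
    num_classes: 21
    weights: COCO_WITH_VOC_LABELS_V1
"""

model_cfg = yaml.safe_load(CONFIG)["model"]
model_type = model_cfg["type"]      # -> load_model(model_type=...)
model_config = model_cfg["config"]  # -> load_model(model_config=...)
# model = UniversalModelLoader().load_model(model_type=model_type, model_config=model_config)

# YAML null handling: `null` and `~` become None, a bare `None` stays a string
print(yaml.safe_load("activation: null"), yaml.safe_load("activation: None"))
```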

Error Handling

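Loader failures are logged and re-raised rather than swallowed. In the CustomModelLoader source above, for example, any exception is logged with logger.exception and then re-raised, and a ValueError is raised when model_class is neither a string nor a callable or when weight_type is not 'full' or 'encoder'. Callers can therefore catch these exceptions themselves:

from segmentation_robustness_framework.loaders import UniversalModelLoader

try: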
    model = UniversalModelLoader().load_model(
        model_type="torchvision",
        model_config={"name": "nonexistent_model"}
    )
except ValueError as e:
    print(f"Model loading failed: {e}")
    # Handle error appropriately