
Adapters API

This page documents the adapter components of the Segmentation Robustness Framework.

Adapters

segmentation_robustness_framework.adapters.base_protocol

Classes

SegmentationModelProtocol

Bases: Protocol

Define the interface for segmentation model adapters.

All segmentation model adapters must implement this interface, providing methods for obtaining logits and predictions, and exposing the number of output classes.

Attributes:

  • num_classes (int): Number of output classes for segmentation.

Functions

logits(x: torch.Tensor) -> torch.Tensor

Return raw logits for input images.

Parameters:

  • x (Tensor): Input image tensor of shape (B, C, H, W). Required.

Returns:

  • Tensor: Logits tensor of shape (B, num_classes, H, W).

Source code in segmentation_robustness_framework/adapters/base_protocol.py
def logits(self, x: torch.Tensor) -> torch.Tensor:
    """Return raw logits for input images.

    Args:
        x (torch.Tensor): Input image tensor of shape (B, C, H, W).

    Returns:
        torch.Tensor: Logits tensor of shape (B, num_classes, H, W).
    """
    ...

predictions(x: torch.Tensor) -> torch.Tensor

Return predicted class labels for input images.

Parameters:

  • x (Tensor): Input image tensor of shape (B, C, H, W). Required.

Returns:

  • Tensor: Predicted label tensor of shape (B, H, W).

Source code in segmentation_robustness_framework/adapters/base_protocol.py
def predictions(self, x: torch.Tensor) -> torch.Tensor:
    """Return predicted class labels for input images.

    Args:
        x (torch.Tensor): Input image tensor of shape (B, C, H, W).

    Returns:
        torch.Tensor: Predicted label tensor of shape (B, H, W).
    """
    ...
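
Because SegmentationModelProtocol is a typing.Protocol, conformance is structural: a class with a num_classes attribute and logits()/predictions() methods satisfies it for static type checkers without inheriting from it. The sketch below illustrates this with a local mirror of the protocol. The @runtime_checkable decorator, which the isinstance check needs, is an assumption (the framework's own protocol may not carry it), and ToySegmenter is a hypothetical model.

```python
from typing import Protocol, runtime_checkable

import torch
import torch.nn as nn


@runtime_checkable
class SegmentationModelProtocolMirror(Protocol):
    """Local mirror of SegmentationModelProtocol (assumed runtime-checkable)."""

    num_classes: int

    def logits(self, x: torch.Tensor) -> torch.Tensor: ...

    def predictions(self, x: torch.Tensor) -> torch.Tensor: ...


class ToySegmenter(nn.Module):
    """Hypothetical model that satisfies the protocol without inheriting from it."""

    def __init__(self, in_channels: int = 3, num_classes: int = 4):
        super().__init__()
        self.num_classes = num_classes
        # A 1x1 convolution maps each pixel's channels to per-class scores.
        self.head = nn.Conv2d(in_channels, num_classes, kernel_size=1)

    def logits(self, x: torch.Tensor) -> torch.Tensor:
        return self.head(x)  # (B, num_classes, H, W)

    def predictions(self, x: torch.Tensor) -> torch.Tensor:
        return self.logits(x).argmax(dim=1)  # (B, H, W)


model = ToySegmenter()
x = torch.randn(2, 3, 8, 8)
print(isinstance(model, SegmentationModelProtocolMirror))  # True
print(model.logits(x).shape, model.predictions(x).shape)
```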

segmentation_robustness_framework.adapters.torchvision_adapter

Classes

TorchvisionAdapter(model: torch.nn.Module)

Bases: Module, SegmentationModelProtocol

Adapter for Torchvision segmentation models.

This adapter standardizes the interface for Torchvision models that return a dict with an 'out' key.

Attributes:

  • model (Module): The underlying Torchvision model.
  • num_classes (int): Number of output classes. Read from model.classifier.out_channels if that attribute exists, otherwise from model.num_classes, otherwise defaults to 21.

Initialize the adapter.

Parameters:

  • model (Module): Torchvision segmentation model instance. Required.

Source code in segmentation_robustness_framework/adapters/torchvision_adapter.py
def __init__(self, model: torch.nn.Module):
    """Initialize the adapter.

    Args:
        model (torch.nn.Module): Torchvision segmentation model instance.
    """
    super().__init__()
    self.model = model

    if hasattr(model, "classifier") and hasattr(model.classifier, "out_channels"):
        self.num_classes = model.classifier.out_channels
    else:
        self.num_classes = getattr(model, "num_classes", 21)
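
The lookup above only finds out_channels when model.classifier is itself a layer that defines it, such as nn.Conv2d. As far as we know, torchvision's DeepLabV3 and FCN heads are nn.Sequential subclasses, which do not expose out_channels, so for those models num_classes falls back to 21 regardless of the real head size. The sketch mirrors the lookup on two hypothetical toy models and adds a hypothetical helper, last_conv_out_channels, that reads the class count from the final Conv2d instead; that helper is not part of the framework.

```python
from typing import Optional

import torch.nn as nn


def resolve_num_classes(model: nn.Module, default: int = 21) -> int:
    """Mirror of the num_classes lookup in TorchvisionAdapter.__init__."""
    if hasattr(model, "classifier") and hasattr(model.classifier, "out_channels"):
        return model.classifier.out_channels
    return getattr(model, "num_classes", default)


def last_conv_out_channels(module: nn.Module) -> Optional[int]:
    """Hypothetical helper: out_channels of the last nn.Conv2d inside `module`."""
    convs = [m for m in module.modules() if isinstance(m, nn.Conv2d)]
    return convs[-1].out_channels if convs else None


class ConvHeadModel(nn.Module):
    """Toy model whose classifier is a single Conv2d (which has out_channels)."""

    def __init__(self):
        super().__init__()
        self.classifier = nn.Conv2d(16, 5, kernel_size=1)


class SequentialHeadModel(nn.Module):
    """Toy model with an nn.Sequential head, loosely shaped like a DeepLab-style head."""

    def __init__(self):
        super().__init__()
        self.classifier = nn.Sequential(
            nn.Conv2d(64, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 7, kernel_size=1),
        )


print(resolve_num_classes(ConvHeadModel()))        # 5
print(resolve_num_classes(SequentialHeadModel()))  # 21: nn.Sequential has no out_channels
print(last_conv_out_channels(SequentialHeadModel().classifier))  # 7
```

If the fallback value is wrong for your model, one option is to overwrite adapter.num_classes after construction with the value such a helper returns.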

Functions

logits(x: torch.Tensor) -> torch.Tensor

Return raw logits for input images.

Parameters:

  • x (Tensor): Input image tensor of shape (B, C, H, W). Required.

Returns:

  • Tensor: Logits tensor of shape (B, num_classes, H, W).

Source code in segmentation_robustness_framework/adapters/torchvision_adapter.py
def logits(self, x: torch.Tensor) -> torch.Tensor:
    """Return raw logits for input images.

    Args:
        x (torch.Tensor): Input image tensor of shape (B, C, H, W).

    Returns:
        torch.Tensor: Logits tensor of shape (B, num_classes, H, W).
    """
    return self.model(x)["out"]

predictions(x: torch.Tensor) -> torch.Tensor

Return predicted class labels for input images.

Parameters:

  • x (Tensor): Input image tensor of shape (B, C, H, W). Required.

Returns:

  • Tensor: Predicted label tensor of shape (B, H, W).

Source code in segmentation_robustness_framework/adapters/torchvision_adapter.py
def predictions(self, x: torch.Tensor) -> torch.Tensor:
    """Return predicted class labels for input images.

    Args:
        x (torch.Tensor): Input image tensor of shape (B, C, H, W).

    Returns:
        torch.Tensor: Predicted label tensor of shape (B, H, W).
    """
    logits = self.logits(x)
    return torch.argmax(logits, dim=1)

forward(x: torch.Tensor) -> torch.Tensor

Forward pass; delegates to logits().

Parameters:

  • x (Tensor): Input image tensor of shape (B, C, H, W). Required.

Returns:

  • Tensor: Logits tensor of shape (B, num_classes, H, W).

Source code in segmentation_robustness_framework/adapters/torchvision_adapter.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Forward pass.

    Args:
        x (torch.Tensor): Input image tensor of shape (B, C, H, W).

    Returns:
        torch.Tensor: Logits tensor of shape (B, num_classes, H, W).
    """
    return self.logits(x)
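
TorchvisionAdapter relies on the wrapped model returning a dict with an "out" entry; any other entries, such as the auxiliary "aux" output that torchvision segmentation models return when built with aux_loss=True, are ignored. The standalone sketch below mirrors logits(), predictions(), and forward() against a hypothetical dict-returning toy model, so it runs without torchvision; the class names are illustrative only.

```python
import torch
import torch.nn as nn


class DictOutputModel(nn.Module):
    """Hypothetical stand-in for a torchvision segmentation model: returns a dict."""

    def __init__(self, num_classes: int = 3):
        super().__init__()
        self.head = nn.Conv2d(3, num_classes, kernel_size=1)
        self.aux_head = nn.Conv2d(3, num_classes, kernel_size=1)

    def forward(self, x: torch.Tensor) -> dict:
        return {"out": self.head(x), "aux": self.aux_head(x)}


class TorchvisionStyleAdapter(nn.Module):
    """Standalone mirror of TorchvisionAdapter's logits/predictions/forward."""

    def __init__(self, model: nn.Module, num_classes: int):
        super().__init__()
        self.model = model
        self.num_classes = num_classes

    def logits(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)["out"]  # the "aux" entry is dropped

    def predictions(self, x: torch.Tensor) -> torch.Tensor:
        return torch.argmax(self.logits(x), dim=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.logits(x)


adapter = TorchvisionStyleAdapter(DictOutputModel(), num_classes=3).eval()
x = torch.randn(1, 3, 16, 16)
with torch.no_grad():
    out = adapter(x)  # calling the module goes through forward() -> logits()
    labels = adapter.predictions(x)
print(out.shape, labels.shape)  # torch.Size([1, 3, 16, 16]) torch.Size([1, 16, 16])
```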

segmentation_robustness_framework.adapters.smp_adapter

Classes

SMPAdapter(model: torch.nn.Module)

Bases: Module, SegmentationModelProtocol

Adapter for segmentation_models_pytorch (SMP) models.

This adapter standardizes the interface for SMP models that return logits directly.

Attributes:

  • model (Module): The underlying SMP model.
  • num_classes (int): Number of output classes. Read from model.classifier.out_channels if that attribute exists, otherwise from model.num_classes, otherwise defaults to 1.

Initialize the adapter.

Parameters:

  • model (Module): SMP segmentation model instance. Required.

Source code in segmentation_robustness_framework/adapters/smp_adapter.py
def __init__(self, model: torch.nn.Module):
    """Initialize the adapter.

    Args:
        model (torch.nn.Module): SMP segmentation model instance.
    """
    super().__init__()
    self.model = model

    if hasattr(model, "classifier") and hasattr(model.classifier, "out_channels"):
        self.num_classes = model.classifier.out_channels
    else:
        self.num_classes = getattr(model, "num_classes", 1)

Functions

logits(x: torch.Tensor) -> torch.Tensor

Return raw logits for input images.

Parameters:

  • x (Tensor): Input image tensor of shape (B, C, H, W). Required.

Returns:

  • Tensor: Logits tensor of shape (B, num_classes, H, W).

Source code in segmentation_robustness_framework/adapters/smp_adapter.py
def logits(self, x: torch.Tensor) -> torch.Tensor:
    """Return raw logits for input images.

    Args:
        x (torch.Tensor): Input image tensor of shape (B, C, H, W).

    Returns:
        torch.Tensor: Logits tensor of shape (B, num_classes, H, W).
    """
    return self.model(x)

predictions(x: torch.Tensor) -> torch.Tensor

Return predicted class labels for input images.

Parameters:

  • x (Tensor): Input image tensor of shape (B, C, H, W). Required.

Returns:

  • Tensor: Predicted label tensor of shape (B, H, W).

Source code in segmentation_robustness_framework/adapters/smp_adapter.py
def predictions(self, x: torch.Tensor) -> torch.Tensor:
    """Return predicted class labels for input images.

    Args:
        x (torch.Tensor): Input image tensor of shape (B, C, H, W).

    Returns:
        torch.Tensor: Predicted label tensor of shape (B, H, W).
    """
    logits = self.logits(x)
    return torch.argmax(logits, dim=1)

forward(x: torch.Tensor) -> torch.Tensor

Forward pass; delegates to logits().

Parameters:

  • x (Tensor): Input image tensor of shape (B, C, H, W). Required.

Returns:

  • Tensor: Logits tensor of shape (B, num_classes, H, W).

Source code in segmentation_robustness_framework/adapters/smp_adapter.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Forward pass.

    Args:
        x (torch.Tensor): Input image tensor of shape (B, C, H, W).

    Returns:
        torch.Tensor: Logits tensor of shape (B, num_classes, H, W).
    """
    return self.logits(x)
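
Per the __init__ shown above, SMPAdapter reads model.classifier.out_channels or model.num_classes and otherwise falls back to 1. To our understanding, SMP models keep their final layer in an attribute named segmentation_head and define neither of the attributes checked, so a multi-class SMP model may be reported with num_classes == 1. The sketch reproduces that lookup on a hypothetical SMP-like toy model (it does not import segmentation_models_pytorch) and reads the true class count from the channel dimension of a probe forward pass.

```python
import torch
import torch.nn as nn


def smp_num_classes(model: nn.Module) -> int:
    """Mirror of the num_classes lookup in SMPAdapter.__init__ (fallback is 1)."""
    if hasattr(model, "classifier") and hasattr(model.classifier, "out_channels"):
        return model.classifier.out_channels
    return getattr(model, "num_classes", 1)


class SMPLikeModel(nn.Module):
    """Hypothetical stand-in for an SMP model: forward() returns logits directly.

    The final layer is stored as `segmentation_head`, and there is no
    `classifier` or `num_classes` attribute.
    """

    def __init__(self, classes: int = 21):
        super().__init__()
        self.segmentation_head = nn.Conv2d(3, classes, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.segmentation_head(x)


model = SMPLikeModel(classes=21).eval()
x = torch.randn(1, 3, 32, 32)

print(smp_num_classes(model))  # 1, although the model predicts 21 classes
with torch.no_grad():
    probe = model(x)
print(probe.shape[1])  # 21: the class count is the channel dimension of the logits
```

If your pipeline depends on num_classes, setting adapter.num_classes explicitly after wrapping is one simple workaround.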

segmentation_robustness_framework.adapters.huggingface_adapter

Classes

HuggingFaceAdapter(model: torch.nn.Module)

Bases: Module, SegmentationModelProtocol

Adapter for HuggingFace segmentation models.

This adapter standardizes the interface for HuggingFace models that return an object with a 'logits' attribute.

Attributes:

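  • model (Module): The underlying HuggingFace model.
  • num_classes (int): Number of output classes. Read from model.config.num_labels, or from model.config.num_classes if num_labels is absent.

Initialize the adapter.

Parameters:

  • model (Module): HuggingFace segmentation model instance. Required.

Raises:

  • ValueError: If the model has no config, or its config contains neither num_labels nor num_classes.
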
Source code in segmentation_robustness_framework/adapters/huggingface_adapter.py
def __init__(self, model: torch.nn.Module):
    """Initialize the adapter.

    Args:
        model (torch.nn.Module): HuggingFace segmentation model instance.
    """
    super().__init__()
    self.model = model

    if hasattr(model, "config") and hasattr(model.config, "num_labels"):
        self.num_classes = model.config.num_labels
    elif hasattr(model, "config") and hasattr(model.config, "num_classes"):
        self.num_classes = model.config.num_classes
    else:
        raise ValueError("Model config does not contain num_labels or num_classes")
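
The lookup prefers config.num_labels over config.num_classes and raises ValueError when neither exists. Configs loaded through transformers typically expose num_labels, so the first branch is the usual path. The sketch mirrors the lookup with a hypothetical FakeHFModel whose config is a types.SimpleNamespace.

```python
from types import SimpleNamespace

import torch.nn as nn


def hf_num_classes(model: nn.Module) -> int:
    """Mirror of the num_classes lookup in HuggingFaceAdapter.__init__."""
    if hasattr(model, "config") and hasattr(model.config, "num_labels"):
        return model.config.num_labels
    elif hasattr(model, "config") and hasattr(model.config, "num_classes"):
        return model.config.num_classes
    raise ValueError("Model config does not contain num_labels or num_classes")


class FakeHFModel(nn.Module):
    """Hypothetical model carrying a HuggingFace-style `config` object."""

    def __init__(self, **config_fields):
        super().__init__()
        self.config = SimpleNamespace(**config_fields)


print(hf_num_classes(FakeHFModel(num_labels=150)))               # 150
print(hf_num_classes(FakeHFModel(num_classes=19)))               # 19
print(hf_num_classes(FakeHFModel(num_labels=2, num_classes=5)))  # 2: num_labels wins
try:
    hf_num_classes(FakeHFModel())
except ValueError as err:
    print(f"ValueError: {err}")
```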

Functions

logits(x: torch.Tensor) -> torch.Tensor

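Return raw logits for input images.

The input is moved to the model's device if needed. When x.requires_grad is True, the model is switched to train mode for the forward pass and back to eval mode afterwards; otherwise the forward pass runs under torch.no_grad().

Parameters:

  • x (Tensor): Input image tensor of shape (B, C, H, W). Required.

Returns:

  • Tensor: Logits tensor of shape (B, num_classes, H', W'). The adapter does not resize the output, so H' and W' are whatever the model produces (for SegFormer, H/4 and W/4).

Raises:

  • RuntimeError: If the forward pass fails with any exception other than AttributeError (which is re-raised unchanged).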

Source code in segmentation_robustness_framework/adapters/huggingface_adapter.py
def logits(self, x: torch.Tensor) -> torch.Tensor:
    """Return raw logits for input images.

    Args:
        x (torch.Tensor): Input image tensor of shape (B, C, H, W).

    Returns:
        torch.Tensor: Logits tensor of shape (B, num_classes, H, W).
    """
    try:
        device = next(self.model.parameters()).device
    except StopIteration:
        device = x.device

    if x.device != device:
        x = x.to(device, non_blocking=True)

    try:
        if x.requires_grad:
            self.model.train()
            output = self.model(pixel_values=x)
            logits = output.logits
            self.model.eval()
        else:
            with torch.no_grad():
                output = self.model(pixel_values=x)
                logits = output.logits

        if logits.device != device:
            logits = logits.to(device, non_blocking=True)

        return logits
    except AttributeError as e:
        raise e
    except Exception as e:
        if device.type == "cuda":
            torch.cuda.empty_cache()
        raise RuntimeError(f"HuggingFace model forward pass failed: {e}") from e
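
The branch on x.requires_grad decides whether gradients are computed at all. When the input requires gradients, as in gradient-based attacks, the model is put in train mode for the call, which also activates stochastic layers such as dropout and lets batch-norm layers update their running statistics; afterwards the model is left in eval mode, whatever mode it was in before. The simplified mirror below uses a hypothetical HuggingFace-style toy model to show both branches; it omits the device handling and error wrapping of the real method.

```python
from types import SimpleNamespace

import torch
import torch.nn as nn


class FakeHFSegmenter(nn.Module):
    """Hypothetical HuggingFace-style model: takes pixel_values, returns an object with .logits."""

    def __init__(self, num_labels: int = 4):
        super().__init__()
        self.conv = nn.Conv2d(3, num_labels, kernel_size=1)
        self.dropout = nn.Dropout2d(p=0.5)  # only active in train mode

    def forward(self, pixel_values: torch.Tensor):
        return SimpleNamespace(logits=self.dropout(self.conv(pixel_values)))


def hf_style_logits(model: nn.Module, x: torch.Tensor) -> torch.Tensor:
    """Simplified mirror of the gradient/no-grad dispatch in HuggingFaceAdapter.logits."""
    if x.requires_grad:
        model.train()  # as in the source: train mode for the gradient-enabled pass
        logits = model(pixel_values=x).logits
        model.eval()
    else:
        with torch.no_grad():
            logits = model(pixel_values=x).logits
    return logits


model = FakeHFSegmenter().eval()
x = torch.randn(1, 3, 8, 8)

plain = hf_style_logits(model, x)
print(plain.requires_grad)  # False: computed under torch.no_grad()

x_adv = x.clone().requires_grad_(True)
out = hf_style_logits(model, x_adv)
out.sum().backward()
print(x_adv.grad.shape)  # torch.Size([1, 3, 8, 8]): gradients reach the input
print(model.training)    # False: switched back to eval() after the call
```

Autograd does not depend on the module's mode: gradients also flow in eval mode, so train mode is not required just to obtain input gradients.
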

predictions(x: torch.Tensor) -> torch.Tensor

Return predicted class labels for input images. After the argmax, the intermediate logits are released and, if the model is on a CUDA device, torch.cuda.empty_cache() is called.

Parameters:

  • x (Tensor): Input image tensor of shape (B, C, H, W). Required.

Returns:

  • Tensor: Predicted label tensor of shape (B, H', W'), at the spatial resolution of the logits.

Source code in segmentation_robustness_framework/adapters/huggingface_adapter.py
def predictions(self, x: torch.Tensor) -> torch.Tensor:
    """Return predicted class labels for input images.

    Args:
        x (torch.Tensor): Input image tensor of shape (B, C, H, W).

    Returns:
        torch.Tensor: Predicted label tensor of shape (B, H, W).
    """
    logits = self.logits(x)
    predictions = torch.argmax(logits, dim=1)

    del logits
    try:
        if next(self.model.parameters()).device.type == "cuda":
            torch.cuda.empty_cache()
    except StopIteration:
        pass

    return predictions

forward(x: torch.Tensor) -> torch.Tensor

Forward pass; delegates to logits().

Parameters:

  • x (Tensor): Input image tensor of shape (B, C, H, W). Required.

Returns:

  • Tensor: Logits tensor, as returned by logits().

Source code in segmentation_robustness_framework/adapters/huggingface_adapter.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Forward pass.

    Args:
        x (torch.Tensor): Input image tensor of shape (B, C, H, W).

    Returns:
        torch.Tensor: Logits tensor of shape (B, num_classes, H, W).
    """
    return self.logits(x)

segmentation_robustness_framework.adapters.custom_adapter

Classes

CustomAdapter(model: torch.nn.Module, num_classes: int = 1)

Bases: Module, SegmentationModelProtocol

Provide a template adapter for custom user segmentation models.

This class demonstrates how to implement an adapter for a user-defined segmentation model. Users should modify this template to fit their model's output structure and register it using the adapter registry if desired.

Attributes:

  • model (Module): The underlying custom model.
  • num_classes (int): Number of output classes.

Initialize the custom adapter.

Parameters:

  • model (Module): Custom segmentation model instance. Required.
  • num_classes (int): Number of output classes. Default: 1.

Source code in segmentation_robustness_framework/adapters/custom_adapter.py
def __init__(self, model: torch.nn.Module, num_classes: int = 1):
    """Initialize the custom adapter.

    Args:
        model (torch.nn.Module): Custom segmentation model instance.
        num_classes (int): Number of output classes.
    """
    super().__init__()
    self.model = model
    self.num_classes = num_classes

Functions

logits(x: torch.Tensor) -> torch.Tensor

Return raw logits for input images.

Parameters:

  • x (Tensor): Input image tensor of shape (B, C, H, W). Required.

Returns:

  • Tensor: Logits tensor of shape (B, num_classes, H, W).

Source code in segmentation_robustness_framework/adapters/custom_adapter.py
def logits(self, x: torch.Tensor) -> torch.Tensor:
    """Return raw logits for input images.

    Args:
        x (torch.Tensor): Input image tensor of shape (B, C, H, W).

    Returns:
        torch.Tensor: Logits tensor of shape (B, num_classes, H, W).
    """
    # Modify this line to match your model's output
    return self.model(x)

predictions(x: torch.Tensor) -> torch.Tensor

Return predicted class labels for input images.

Parameters:

  • x (Tensor): Input image tensor of shape (B, C, H, W). Required.

Returns:

  • Tensor: Predicted label tensor of shape (B, H, W).

Source code in segmentation_robustness_framework/adapters/custom_adapter.py
def predictions(self, x: torch.Tensor) -> torch.Tensor:
    """Return predicted class labels for input images.

    Args:
        x (torch.Tensor): Input image tensor of shape (B, C, H, W).

    Returns:
        torch.Tensor: Predicted label tensor of shape (B, H, W).
    """
    logits = self.logits(x)
    return torch.argmax(logits, dim=1)

forward(x: torch.Tensor) -> torch.Tensor

Perform a forward pass through the model; delegates to logits().

Parameters:

  • x (Tensor): Input image tensor of shape (B, C, H, W). Required.

Returns:

  • Tensor: Logits tensor of shape (B, num_classes, H, W).

Source code in segmentation_robustness_framework/adapters/custom_adapter.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Perform forward pass through the model.

    Args:
        x (torch.Tensor): Input image tensor of shape (B, C, H, W).

    Returns:
        torch.Tensor: Logits tensor of shape (B, num_classes, H, W).
    """
    return self.logits(x)
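
Two cases where the template usually needs changes: models that return more than logits, and binary models with a single output channel. With the default num_classes=1, the inherited predictions() takes an argmax over one channel, which always yields class 0. The sketch below mirrors CustomAdapter locally so it runs standalone; TupleOutputBinaryModel and BinaryAdapter are hypothetical, and thresholding the sigmoid at 0.5 is one common choice rather than something the framework prescribes.

```python
import torch
import torch.nn as nn


class CustomAdapterMirror(nn.Module):
    """Standalone mirror of CustomAdapter as shown above."""

    def __init__(self, model: nn.Module, num_classes: int = 1):
        super().__init__()
        self.model = model
        self.num_classes = num_classes

    def logits(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)

    def predictions(self, x: torch.Tensor) -> torch.Tensor:
        return torch.argmax(self.logits(x), dim=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.logits(x)


class TupleOutputBinaryModel(nn.Module):
    """Hypothetical binary model returning (logits, features); logits have 1 channel."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.head = nn.Conv2d(8, 1, kernel_size=1)

    def forward(self, x: torch.Tensor):
        features = self.backbone(x)
        return self.head(features), features


class BinaryAdapter(CustomAdapterMirror):
    """Hypothetical adapter: unpacks the tuple and thresholds a single-channel output."""

    def logits(self, x: torch.Tensor) -> torch.Tensor:
        logits, _features = self.model(x)  # keep only the segmentation logits
        return logits  # (B, 1, H, W)

    def predictions(self, x: torch.Tensor) -> torch.Tensor:
        # argmax over one channel would always return 0, so threshold the sigmoid instead
        return (torch.sigmoid(self.logits(x)) > 0.5).squeeze(1).long()


x = torch.randn(2, 3, 16, 16)
adapter = BinaryAdapter(TupleOutputBinaryModel().eval(), num_classes=1)
with torch.no_grad():
    print(adapter.logits(x).shape)       # torch.Size([2, 1, 16, 16])
    print(adapter.predictions(x).shape)  # torch.Size([2, 16, 16])

# With a plain single-channel model, the inherited argmax collapses to class 0.
single = nn.Conv2d(3, 1, kernel_size=1)
print(CustomAdapterMirror(single).predictions(x).unique())  # tensor([0])
```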

Adapter Overview

Adapters provide a standardized interface for different model types, ensuring compatibility with the framework's evaluation pipeline. Each adapter implements the SegmentationModelProtocol interface.

SegmentationModelProtocol

The base protocol that all adapters must implement:

from typing import Protocol
import torch

class SegmentationModelProtocol(Protocol):
    """Standardized interface for segmentation models."""

    num_classes: int

    def logits(self, x: torch.Tensor) -> torch.Tensor:
        """Return raw model outputs [B, num_classes, H, W]"""
        ...

    def predictions(self, x: torch.Tensor) -> torch.Tensor:
        """Return predicted labels [B, H, W]"""
        ...

Available Adapters

TorchvisionAdapter

Adapts torchvision segmentation models (DeepLab, FCN, LRASPP):

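from segmentation_robustness_framework.adapters import TorchvisionAdapter
import torch
import torchvision.models.segmentation as segmentation

# Create a torchvision model (torchvision >= 0.13 uses `weights`; `pretrained` is deprecated)
# eval() matters: in train mode, batch norm fails on a batch of one image
model = segmentation.deeplabv3_resnet50(weights="DEFAULT").eval()

# Wrap with adapter
adapter = TorchvisionAdapter(model)

# Example input batch of shape (B, 3, H, W)
x = torch.randn(1, 3, 520, 520)

# Use in pipeline
logits = adapter.logits(x)  # [B, num_classes, H, W]
predictions = adapter.predictions(x)  # [B, H, W]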

SMPAdapter

Adapts segmentation_models_pytorch models (UNet, LinkNet, FPN, etc.):

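from segmentation_robustness_framework.adapters import SMPAdapter
import segmentation_models_pytorch as smp
import torch

# Create an SMP model
model = smp.Unet(
    encoder_name="resnet34",
    encoder_weights="imagenet",
    classes=21
).eval()

# Wrap with adapter
adapter = SMPAdapter(model)
# SMPAdapter falls back to num_classes=1 if the model exposes neither
# classifier.out_channels nor num_classes; set it explicitly if needed
adapter.num_classes = 21

# Example input batch; Unet's default encoder depth needs H and W divisible by 32
x = torch.randn(1, 3, 512, 512)

# Use in pipeline
logits = adapter.logits(x)  # [B, num_classes, H, W]
predictions = adapter.predictions(x)  # [B, H, W]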

HuggingFaceAdapter

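Adapts HuggingFace Transformers segmentation models that take pixel_values and return an output with a logits attribute (for example, SegFormer):

from segmentation_robustness_framework.adapters import HuggingFaceAdapter
from transformers import SegformerForSemanticSegmentation
import torch

# Create a HuggingFace model
model = SegformerForSemanticSegmentation.from_pretrained(
    "nvidia/segformer-b0-finetuned-ade-512-512"
).eval()

# Wrap with adapter (num_classes is read from model.config.num_labels)
adapter = HuggingFaceAdapter(model)

# Example input batch
x = torch.randn(1, 3, 512, 512)

# Use in pipeline; SegFormer returns logits at 1/4 of the input resolution,
# and the adapter does not upsample them
logits = adapter.logits(x)  # [B, num_classes, H/4, W/4]
predictions = adapter.predictions(x)  # [B, H/4, W/4]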

CustomAdapter

Template for creating custom adapters:

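from segmentation_robustness_framework.adapters import CustomAdapter
import torch
import torch.nn as nn

class MyCustomModel(nn.Module):
    def __init__(self, num_classes=21):
        super().__init__()
        # Your model architecture here (a 1x1 conv head as a placeholder)
        self.head = nn.Conv2d(3, num_classes, kernel_size=1)

    def forward(self, x):
        # Your forward pass; must produce logits of shape (B, num_classes, H, W)
        logits = self.head(x)
        return logits

# Create custom adapter
class MyCustomAdapter(CustomAdapter):
    def __init__(self, model: MyCustomModel, num_classes: int = 21):
        # CustomAdapter defaults to num_classes=1, so pass the real value
        super().__init__(model, num_classes=num_classes)

    def logits(self, x):
        # Adjust this if your model returns a dict, tuple, or other structure
        return self.model(x)

    def predictions(self, x):
        logits = self.logits(x)
        return torch.argmax(logits, dim=1)

# Use custom adapter
model = MyCustomModel()
adapter = MyCustomAdapter(model)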

Adapter Registration

Register custom adapters for automatic discovery:

import torch
from segmentation_robustness_framework.adapters import CustomAdapter, register_adapter
from segmentation_robustness_framework.loaders import UniversalModelLoader

@register_adapter("my_custom")
class MyCustomAdapter(CustomAdapter):
    def __init__(self, model, num_classes: int = 21):
        super().__init__(model, num_classes=num_classes)

    def logits(self, x):
        return self.model(x)

    def predictions(self, x):
        logits = self.logits(x)
        return torch.argmax(logits, dim=1)

# Now you can use it with the registered name
# (MyCustomModel is the model class from the CustomAdapter example above)
model = UniversalModelLoader().load_model(
    model_type="my_custom",  # Uses the registered adapter
    model_config={"model_class": MyCustomModel}
)
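
Decorator-based registration can be pictured as a module-level dict from names to adapter classes. The sketch below is a hypothetical, self-contained version of register_adapter and get_adapter; the framework's real implementation, including whether it rejects duplicate names, may differ. MyCustomAdapter here is a plain nn.Module so the example runs without the framework installed.

```python
from typing import Callable, Dict, Type

import torch.nn as nn

# Hypothetical registry; the framework's real one may be structured differently.
_ADAPTER_REGISTRY: Dict[str, Type[nn.Module]] = {}


def register_adapter(name: str) -> Callable[[Type[nn.Module]], Type[nn.Module]]:
    """Class decorator that records an adapter class under `name`."""

    def decorator(cls: Type[nn.Module]) -> Type[nn.Module]:
        if name in _ADAPTER_REGISTRY:  # duplicate check is an assumption
            raise ValueError(f"Adapter '{name}' is already registered")
        _ADAPTER_REGISTRY[name] = cls
        return cls  # return the class unchanged so it stays usable directly

    return decorator


def get_adapter(name: str) -> Type[nn.Module]:
    """Look up a registered adapter class by name."""
    try:
        return _ADAPTER_REGISTRY[name]
    except KeyError:
        raise KeyError(f"No adapter registered under '{name}'") from None


@register_adapter("my_custom")
class MyCustomAdapter(nn.Module):
    def __init__(self, model: nn.Module, num_classes: int = 21):
        super().__init__()
        self.model = model
        self.num_classes = num_classes

    def logits(self, x):
        return self.model(x)

    def predictions(self, x):
        return self.logits(x).argmax(dim=1)


adapter_cls = get_adapter("my_custom")
adapter = adapter_cls(nn.Conv2d(3, 21, kernel_size=1))
print(adapter_cls is MyCustomAdapter, adapter.num_classes)  # True 21
```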

Adapter Usage in Pipeline

Adapters are automatically used by the model loaders:

from segmentation_robustness_framework.loaders import UniversalModelLoader

# The loader automatically creates the appropriate adapter
model = UniversalModelLoader().load_model(
    model_type="torchvision",
    model_config={"name": "deeplabv3_resnet50"}
)

# The model is already wrapped with the correct adapter
# (x is an input batch of shape (B, C, H, W))
logits = model.logits(x)
predictions = model.predictions(x)

How Adapter Selection Works

  1. Automatic Selection: UniversalModelLoader selects the appropriate adapter based on model_type.
  2. Registry Lookup: It calls get_adapter(model_type) to find the registered adapter class.
  3. Default Mapping: Built-in adapters are pre-registered under their model type names.
  4. Custom Override: Pass the adapter_cls parameter to override the default adapter.
  5. Protocol Check: If the model already implements SegmentationModelProtocol, no adapter is applied.

# The selection process, simplified (raw_model is the unwrapped model):
model_type = "torchvision"
adapter_cls = get_adapter(model_type)  # Returns TorchvisionAdapter
model = adapter_cls(raw_model)  # Wraps the model

Adapter Selection

Adapters are automatically selected based on the model type. The framework uses the following mapping:

  • torchvision → TorchvisionAdapter
  • smp → SMPAdapter
  • huggingface → HuggingFaceAdapter
  • custom_* → CustomAdapter

# The universal loader automatically selects the correct adapter
model = UniversalModelLoader().load_model(
    model_type="torchvision",  # Will use TorchvisionAdapter
    model_config={"name": "deeplabv3_resnet50"}
)

# For custom models, you can override the adapter
# (MyCustomModel and MyCustomAdapter are the user-defined classes from the
# examples above; they are not part of the package)

model = UniversalModelLoader().load_model(
    model_type="custom",
    model_config={"model_class": MyCustomModel},
    adapter_cls=MyCustomAdapter  # Override default adapter
)
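
The selection rules can be combined into one function. This is a hypothetical sketch, not the UniversalModelLoader code: the stub adapters stand in for the real classes, the protocol check is duck-typed, "custom_*" is treated as a prefix match, and an explicit adapter_cls is assumed to take precedence over the other rules, an order this page does not specify.

```python
from typing import Dict, Optional, Type

import torch.nn as nn


class _StubAdapter(nn.Module):
    """Hypothetical placeholder for a real adapter class."""

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model
        self.num_classes = 0  # placeholder value for the sketch

    def logits(self, x):
        return self.model(x)

    def predictions(self, x):
        return self.logits(x).argmax(dim=1)


class TorchvisionStub(_StubAdapter):
    pass


class SMPStub(_StubAdapter):
    pass


class HuggingFaceStub(_StubAdapter):
    pass


class CustomStub(_StubAdapter):
    pass


# Mirrors the model_type -> adapter mapping listed above.
ADAPTERS: Dict[str, Type[nn.Module]] = {
    "torchvision": TorchvisionStub,
    "smp": SMPStub,
    "huggingface": HuggingFaceStub,
    "custom": CustomStub,
}


def implements_protocol(obj) -> bool:
    """Duck-typed stand-in for a SegmentationModelProtocol check."""
    return (
        hasattr(obj, "num_classes")
        and callable(getattr(obj, "logits", None))
        and callable(getattr(obj, "predictions", None))
    )


def wrap_model(
    raw_model: nn.Module, model_type: str, adapter_cls: Optional[Type[nn.Module]] = None
) -> nn.Module:
    if adapter_cls is not None:
        return adapter_cls(raw_model)  # explicit override wins (assumed order)
    if implements_protocol(raw_model):
        return raw_model  # already conforms: no adapter applied
    is_custom = model_type == "custom" or model_type.startswith("custom_")
    key = "custom" if is_custom else model_type
    if key not in ADAPTERS:
        raise KeyError(f"No adapter registered for model_type '{model_type}'")
    return ADAPTERS[key](raw_model)


raw = nn.Conv2d(3, 5, kernel_size=1)
print(type(wrap_model(raw, "torchvision")).__name__)                  # TorchvisionStub
print(type(wrap_model(raw, "custom_unet")).__name__)                  # CustomStub
print(type(wrap_model(raw, "smp", adapter_cls=CustomStub)).__name__)  # CustomStub
```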

Error Handling

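Error handling differs by adapter. HuggingFaceAdapter raises ValueError at construction if the model config has neither num_labels nor num_classes; during the forward pass it re-raises AttributeError unchanged and wraps any other exception in RuntimeError (after calling torch.cuda.empty_cache() when the model is on CUDA). TorchvisionAdapter, SMPAdapter, and CustomAdapter do not catch exceptions, so errors from the wrapped model propagate as-is.

try:
    adapter = HuggingFaceAdapter(model)
    logits = adapter.logits(x)
except ValueError as e:
    # The model config lacks num_labels and num_classes
    print(f"Adapter configuration error: {e}")
except RuntimeError as e:
    # The forward pass failed (wrapped by HuggingFaceAdapter)
    print(f"Adapter forward error: {e}")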

Performance Considerations

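  • Memory Efficiency: Adapters are thin wrappers that hold the model and add no parameters of their own. HuggingFaceAdapter.predictions also calls torch.cuda.empty_cache() after each call on CUDA, which frees cached memory but can slow down repeated calls.
  • GPU Compatibility: Adapters are nn.Module subclasses, so .to(device) moves the wrapped model. HuggingFaceAdapter also moves inputs to the model's device; the other adapters expect x to already be on the model's device.
  • Batch Processing: All adapters operate on batched tensors of shape (B, C, H, W).
  • Gradient Flow: TorchvisionAdapter, SMPAdapter, and CustomAdapter do not detach outputs, so gradients flow back to the input for gradient-based adversarial attacks. HuggingFaceAdapter runs the forward pass under torch.no_grad() unless x.requires_grad is True; in that case it switches the model to train mode for the call and back to eval mode afterwards.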
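
To check the Gradient Flow point in practice, the sketch below runs one step of the Fast Gradient Sign Method (FGSM), a standard gradient-based attack, against a hypothetical adapter-like toy module; it is not the framework's attack API. The budget eps = 8/255 and the [0, 1] pixel range are assumptions. Because the attacked input has requires_grad set, HuggingFaceAdapter would take its gradient-enabled branch for such a call.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyAdapter(nn.Module):
    """Hypothetical adapter-like wrapper exposing logits() and predictions()."""

    def __init__(self, num_classes: int = 3):
        super().__init__()
        self.num_classes = num_classes
        self.model = nn.Conv2d(3, num_classes, kernel_size=3, padding=1)

    def logits(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)

    def predictions(self, x: torch.Tensor) -> torch.Tensor:
        return self.logits(x).argmax(dim=1)


def fgsm(adapter: nn.Module, images: torch.Tensor, labels: torch.Tensor, eps: float = 8 / 255) -> torch.Tensor:
    """One-step FGSM on per-pixel cross-entropy; assumes images lie in [0, 1]."""
    images = images.clone().detach().requires_grad_(True)
    # logits: (B, num_classes, H, W); labels: (B, H, W) of class indices
    loss = F.cross_entropy(adapter.logits(images), labels)
    (grad,) = torch.autograd.grad(loss, images)
    adv = images + eps * grad.sign()  # step in the direction that increases the loss
    return adv.clamp(0.0, 1.0).detach()


torch.manual_seed(0)
adapter = ToyAdapter().eval()
images = torch.rand(2, 3, 16, 16)
labels = adapter.predictions(images)  # clean predictions used as targets
adv = fgsm(adapter, images, labels)
print((adv - images).abs().max().item() <= 8 / 255 + 1e-6)  # True
```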