Skip to content

Metrics API

This page documents the evaluation metrics components of the Segmentation Robustness Framework.

Metrics Classes

segmentation_robustness_framework.metrics.base_metrics

Classes

MetricsCollection(num_classes: int, ignore_index: int = 255)

Implements metrics to evaluate the quality of multiclass semantic segmentation models.

Attributes:

Name Type Description
targets Tensor

Ground truth segmentation mask [C, H, W], where each pixel value is the true class.

preds Tensor

Predicted segmentation mask [C, H, W], where each pixel value is the predicted class.

num_classes int

The number of classes.

Initialize segmentation metrics.

Parameters:

Name Type Description Default
num_classes int

Number of classes for segmentation.

required
ignore_index int

Index to ignore in evaluation. Defaults to 255.

255

Raises:

Type Description
TypeError

If num_classes is not an integer.

ValueError

If num_classes is less than 2.

Source code in segmentation_robustness_framework/metrics/base_metrics.py
def __init__(self, num_classes: int, ignore_index: int = 255) -> None:
    """Set up the metrics collection for a fixed label space.

    Args:
        num_classes (int): Number of segmentation classes; must be an
            integer that is at least 2.
        ignore_index (int): Label value to skip during evaluation.
            Defaults to 255. Stored as given, without validation.

    Raises:
        TypeError: If `num_classes` is not an integer.
        ValueError: If `num_classes` is less than 2.
    """
    # The type check runs first so that non-numeric input reports a TypeError
    # rather than failing inside the range comparison below.
    is_integer = isinstance(num_classes, int)
    if not is_integer:
        raise TypeError("The number of classes must be integer")

    min_classes = 2
    if num_classes < min_classes:
        raise ValueError("The number of classes must be 2 or more")

    self.ignore_index = ignore_index
    self.num_classes = num_classes

Functions

mean_iou(targets: torch.Tensor, preds: torch.Tensor, average: str = 'macro') -> float

Compute mean Intersection over Union metric.

Parameters:

Name Type Description Default
targets Tensor

Ground-truth segmentation masks.

required
preds Tensor

Predicted segmentation masks.

required
average str

Type of averaging to use: "macro" or "micro". Defaults to "macro".

'macro'

Returns:

Name Type Description
float float

Mean IoU.

Source code in segmentation_robustness_framework/metrics/base_metrics.py
def mean_iou(self, targets: torch.Tensor, preds: torch.Tensor, average: str = "macro") -> float:
    """Compute mean Intersection over Union metric.

    For every class except `ignore_index`, the intersection and union of the
    predicted and ground-truth pixels are counted. With "macro" averaging the
    per-class IoU values are averaged over the classes that occur in either
    mask; with "micro" averaging the intersections and unions are summed over
    all classes before dividing.

    Args:
        targets (torch.Tensor): Ground-truth segmentation masks.
        preds (torch.Tensor): Predicted segmentation masks.
        average (str): Type of averaging to use: "macro" or "micro". Defaults to "macro".

    Returns:
        float: Mean IoU rounded to 3 decimals, or 0.0 when no evaluated class
            occurs in either mask.

    Raises:
        ValueError: If `average` is neither "macro" nor "micro".
    """
    # Explicit check instead of `assert`: assertions are stripped under
    # `python -O`, which would let an invalid value fall through silently.
    if average not in ("macro", "micro"):
        raise ValueError("average must be 'macro' or 'micro'")

    true_mask, pred_mask = self._preprocess_input_data(targets, preds)

    # Sums for micro average
    total_intersection = 0
    total_union = 0

    iou_scores = []
    for cls in range(self.num_classes):
        if cls == self.ignore_index:
            continue

        pred = pred_mask == cls
        true = true_mask == cls

        intersection = int(np.logical_and(pred, true).sum())
        union = int(pred.sum()) + int(true.sum()) - intersection

        # A class absent from both masks has no defined IoU, so it is left
        # out of the macro average. Scores are kept unrounded here; rounding
        # each class first and the mean again would round the result twice.
        if union > 0:
            iou_scores.append(intersection / union)

        total_intersection += intersection
        total_union += union

    if average == "macro":
        return round(float(np.mean(iou_scores)), 3) if iou_scores else 0.0
    return round(total_intersection / total_union, 3) if total_union > 0 else 0.0
pixel_accuracy(targets: torch.Tensor, preds: torch.Tensor) -> float

Compute pixel accuracy metric.

Parameters:

Name Type Description Default
targets Tensor

Ground-truth segmentation masks.

required
preds Tensor

Predicted segmentation masks.

required

Returns:

Name Type Description
float float

Pixel accuracy.

Source code in segmentation_robustness_framework/metrics/base_metrics.py
def pixel_accuracy(self, targets: torch.Tensor, preds: torch.Tensor) -> float:
    """Compute pixel accuracy metric.

    Accuracy is the share of evaluated pixels whose predicted class equals
    the ground-truth class.

    Args:
        targets (torch.Tensor): Ground-truth segmentation masks.
        preds (torch.Tensor): Predicted segmentation masks.

    Returns:
        float: Pixel accuracy rounded to 3 decimals, or 0.0 when there is no
            pixel to evaluate.
    """
    true_mask, pred_mask = self._preprocess_input_data(targets, preds)

    # The micro-averaged metrics of this class skip pixels whose target is -1;
    # the same filter is applied here so those pixels count neither as correct
    # nor as wrong.
    # NOTE(review): -1 is presumably written by `_preprocess_input_data` for
    # `ignore_index` pixels - confirm against that helper.
    valid = true_mask != -1
    total_pixels = int(valid.sum())
    if total_pixels == 0:
        return 0.0

    correct_pixels = int(np.logical_and(pred_mask == true_mask, valid).sum())
    return round(correct_pixels / total_pixels, 3)
precision(targets: torch.Tensor, preds: torch.Tensor, average: str = 'macro') -> float

Compute precision metric.

Parameters:

Name Type Description Default
targets Tensor

Ground-truth segmentation masks.

required
preds Tensor

Predicted segmentation masks.

required
average str

Type of averaging to use: "macro" or "micro". Defaults to "macro"

'macro'

Returns:

Name Type Description
float float

Precision metric.

Source code in segmentation_robustness_framework/metrics/base_metrics.py
def precision(self, targets: torch.Tensor, preds: torch.Tensor, average: str = "macro") -> float:
    """Compute precision metric.

    "macro" averages TP / (TP + FP) over every class except `ignore_index`
    that is predicted at least once. "micro" counts globally: a true positive
    is a pixel where prediction and target agree and the target is not -1; a
    false positive is a pixel where they differ and the prediction is not -1.

    Args:
        targets (torch.Tensor): Ground-truth segmentation masks.
        preds (torch.Tensor): Predicted segmentation masks.
        average (str): Type of averaging to use: "macro" or "micro". Defaults to "macro"

    Returns:
        float: Precision rounded to 3 decimals, or 0.0 when nothing can be scored.

    Raises:
        ValueError: If `average` is neither "macro" nor "micro".
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if average not in ("macro", "micro"):
        raise ValueError("average must be 'macro' or 'micro'")

    true_mask, pred_mask = self._preprocess_input_data(targets, preds)

    if average == "micro":
        # NOTE(review): -1 presumably marks ignored pixels - confirm against
        # `_preprocess_input_data`.
        total_tp = np.logical_and(pred_mask == true_mask, true_mask != -1).sum()
        total_fp = np.logical_and(pred_mask != true_mask, pred_mask != -1).sum()
        predicted_total = total_tp + total_fp
        return round(float(total_tp / predicted_total), 3) if predicted_total > 0 else 0.0

    # Scores are collected in a list rather than a zero-filled array: a skipped
    # `ignore_index` class inside range(num_classes) would otherwise leave a
    # 0.0 entry behind and drag the macro average down.
    per_class = []
    for cls in range(self.num_classes):
        if cls == self.ignore_index:
            continue

        pred_class = pred_mask == cls
        predicted = pred_class.sum()  # equals TP + FP for this class
        if predicted > 0:
            true_positive = np.logical_and(pred_class, true_mask == cls).sum()
            per_class.append(true_positive / predicted)

    return round(float(np.mean(per_class)), 3) if per_class else 0.0
recall(targets: torch.Tensor, preds: torch.Tensor, average: str = 'macro') -> float

Compute recall metric.

Parameters:

Name Type Description Default
targets Tensor

Ground-truth segmentation masks.

required
preds Tensor

Predicted segmentation masks.

required
average str

Type of averaging to use: "macro" or "micro". Defaults to "macro"

'macro'

Returns:

Name Type Description
float float

Recall metric.

Source code in segmentation_robustness_framework/metrics/base_metrics.py
def recall(self, targets: torch.Tensor, preds: torch.Tensor, average: str = "macro") -> float:
    """Compute recall metric.

    "macro" averages TP / (TP + FN) over every class except `ignore_index`
    that occurs in the ground truth. "micro" counts globally over pixels whose
    target is not -1: a true positive is a pixel where prediction and target
    agree, a false negative is a pixel where they differ.

    Args:
        targets (torch.Tensor): Ground-truth segmentation masks.
        preds (torch.Tensor): Predicted segmentation masks.
        average (str): Type of averaging to use: "macro" or "micro". Defaults to "macro"

    Returns:
        float: Recall rounded to 3 decimals, or 0.0 when nothing can be scored.

    Raises:
        ValueError: If `average` is neither "macro" nor "micro".
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if average not in ("macro", "micro"):
        raise ValueError("average must be 'macro' or 'micro'")

    true_mask, pred_mask = self._preprocess_input_data(targets, preds)

    if average == "micro":
        # NOTE(review): -1 presumably marks ignored pixels - confirm against
        # `_preprocess_input_data`.
        evaluated = true_mask != -1
        total_tp = np.logical_and(pred_mask == true_mask, evaluated).sum()
        total_fn = np.logical_and(pred_mask != true_mask, evaluated).sum()
        actual_total = total_tp + total_fn
        return round(float(total_tp / actual_total), 3) if actual_total > 0 else 0.0

    # Scores are collected in a list rather than a zero-filled array: a skipped
    # `ignore_index` class inside range(num_classes) would otherwise leave a
    # 0.0 entry behind and drag the macro average down.
    per_class = []
    for cls in range(self.num_classes):
        if cls == self.ignore_index:
            continue

        true_class = true_mask == cls
        actual = true_class.sum()  # equals TP + FN for this class
        if actual > 0:
            true_positive = np.logical_and(pred_mask == cls, true_class).sum()
            per_class.append(true_positive / actual)

    return round(float(np.mean(per_class)), 3) if per_class else 0.0
dice_score(targets: torch.Tensor, preds: torch.Tensor, average: str = 'macro') -> float

Compute dice score.

Parameters:

Name Type Description Default
targets Tensor

Ground-truth segmentation masks.

required
preds Tensor

Predicted segmentation masks.

required
average str

Type of averaging to use: "macro" or "micro". Defaults to "macro"

'macro'

Returns:

Name Type Description
float float

Dice score.

Source code in segmentation_robustness_framework/metrics/base_metrics.py
def dice_score(self, targets: torch.Tensor, preds: torch.Tensor, average: str = "macro") -> float:
    """Compute dice score.

    "macro" averages 2 * |P & T| / (|P| + |T|) over every class except
    `ignore_index` that occurs in the prediction or the ground truth. "micro"
    uses global counts: the intersection is the number of pixels where
    prediction and target agree and the target is not -1, and the two sizes
    are the numbers of prediction and target pixels that are not -1.

    Args:
        targets (torch.Tensor): Ground-truth segmentation masks.
        preds (torch.Tensor): Predicted segmentation masks.
        average (str): Type of averaging to use: "macro" or "micro". Defaults to "macro"

    Returns:
        float: Dice score rounded to 3 decimals, or 0.0 when nothing can be scored.

    Raises:
        ValueError: If `average` is neither "macro" nor "micro".
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if average not in ("macro", "micro"):
        raise ValueError("average must be 'macro' or 'micro'")

    true_mask, pred_mask = self._preprocess_input_data(targets, preds)

    if average == "micro":
        # NOTE(review): -1 presumably marks ignored pixels - confirm against
        # `_preprocess_input_data`.
        total_intersection = np.logical_and(pred_mask == true_mask, true_mask != -1).sum()
        total_size = (pred_mask != -1).sum() + (true_mask != -1).sum()
        return round(float(2 * total_intersection / total_size), 3) if total_size > 0 else 0.0

    # Scores are collected in a list rather than a zero-filled array: a skipped
    # `ignore_index` class inside range(num_classes) would otherwise leave a
    # 0.0 entry behind and drag the macro average down.
    per_class = []
    for cls in range(self.num_classes):
        if cls == self.ignore_index:
            continue

        pred_class = pred_mask == cls
        true_class = true_mask == cls
        size = pred_class.sum() + true_class.sum()
        if size > 0:
            intersection = np.logical_and(pred_class, true_class).sum()
            per_class.append(2 * intersection / size)

    return round(float(np.mean(per_class)), 3) if per_class else 0.0
get_metric_with_averaging(metric_name: str, average: str = 'macro')

Get a metric function with specified averaging strategy.

Parameters:

Name Type Description Default
metric_name str

Name of the metric ('mean_iou', 'precision', 'recall', 'dice_score' or 'pixel_accuracy'). For 'pixel_accuracy' the averaging strategy is validated but not applied.

required
average str

Averaging strategy ('macro' or 'micro')

'macro'

Returns:

Name Type Description
callable

Metric function with the specified averaging

Raises:

Type Description
ValueError

If metric_name is not supported or average is invalid

Source code in segmentation_robustness_framework/metrics/base_metrics.py
def get_metric_with_averaging(self, metric_name: str, average: str = "macro"):
    """Build a two-argument metric callable bound to an averaging strategy.

    Args:
        metric_name (str): One of 'mean_iou', 'precision', 'recall',
            'dice_score' or 'pixel_accuracy'.
        average (str): Averaging strategy ('macro' or 'micro'). It is
            validated for every metric but has no effect on 'pixel_accuracy'.

    Returns:
        callable: A function taking ``(targets, preds)``. For the averaged
            metrics it forwards ``average``; for 'pixel_accuracy' it is the
            ``pixel_accuracy`` method itself.

    Raises:
        ValueError: If metric_name is not supported or average is invalid
    """
    if average != "macro" and average != "micro":
        raise ValueError("average must be 'macro' or 'micro'")

    averaged_metrics = ("mean_iou", "precision", "recall", "dice_score")
    if metric_name in averaged_metrics:
        # The method is looked up when the wrapper is called, not when it is
        # created, so the wrapper always dispatches to the current attribute.
        return lambda targets, preds: getattr(self, metric_name)(targets, preds, average=average)

    if metric_name == "pixel_accuracy":
        return self.pixel_accuracy

    raise ValueError(f"Unsupported metric: {metric_name}")
get_all_metrics_with_averaging(include_pixel_accuracy: bool = True)

Get all metrics with both macro and micro averaging.

Parameters:

Name Type Description Default
include_pixel_accuracy bool

Whether to include pixel_accuracy (no averaging)

True

Returns:

Name Type Description
tuple

(metrics_list, metric_names_list) with proper naming

Source code in segmentation_robustness_framework/metrics/base_metrics.py
def get_all_metrics_with_averaging(self, include_pixel_accuracy: bool = True):
    """Collect every built-in metric in both its macro and micro variant.

    The result is ordered as: all macro variants, then pixel accuracy (when
    requested), then all micro variants. Names of the averaged metrics carry
    a ``_macro`` or ``_micro`` suffix.

    Args:
        include_pixel_accuracy (bool): Whether to include pixel_accuracy (no averaging)

    Returns:
        tuple: (metrics_list, metric_names_list), two lists of equal length
            where position ``i`` of one describes position ``i`` of the other.
    """
    averaged = ("mean_iou", "precision", "recall", "dice_score")

    def variants(strategy):
        # One (callable, label) pair per averaged metric for this strategy.
        return [(self.get_metric_with_averaging(base, strategy), f"{base}_{strategy}") for base in averaged]

    entries = variants("macro")
    if include_pixel_accuracy:
        entries.append((self.pixel_accuracy, "pixel_accuracy"))
    entries.extend(variants("micro"))

    metrics = [fn for fn, _ in entries]
    metric_names = [label for _, label in entries]
    return metrics, metric_names

segmentation_robustness_framework.metrics.custom_metrics

Functions

register_custom_metric(name: str) -> Callable

Register a custom metric function.

Parameters:

Name Type Description Default
name str

Name to register the metric under.

required

Returns:

Name Type Description
Callable Callable

Decorator function.

Example

@register_custom_metric("my_dice_score")
def my_dice_score(targets, preds):
    # Custom implementation
    return score

Source code in segmentation_robustness_framework/metrics/custom_metrics.py
def register_custom_metric(name: str) -> Callable:
    """Create a decorator that records a metric function in the registry.

    The decorated function is stored in ``CUSTOM_METRICS_REGISTRY`` under
    ``name`` (replacing any entry already stored under that key), an info
    message is logged, and the function is returned unchanged.

    Args:
        name (str): Key under which the metric is stored.

    Returns:
        Callable: Decorator that registers the function it is applied to.

    Example:
        @register_custom_metric("my_dice_score")
        def my_dice_score(targets, preds):
            # Custom implementation
            return score
    """

    def _register(func: Callable) -> Callable:
        # Registration happens at decoration time; the function is handed
        # back as-is so it stays directly callable.
        CUSTOM_METRICS_REGISTRY[name] = func
        logger.info(f"Registered custom metric: {name}")
        return func

    return _register

get_custom_metric(name: str) -> Callable

Get a custom metric function by name.

Parameters:

Name Type Description Default
name str

Name of the registered metric.

required

Returns:

Name Type Description
Callable Callable

The metric function.

Raises:

Type Description
KeyError

If the metric name is not registered.

Source code in segmentation_robustness_framework/metrics/custom_metrics.py
def get_custom_metric(name: str) -> Callable:
    """Look up a previously registered custom metric.

    Args:
        name (str): Key the metric was registered under.

    Returns:
        Callable: The function stored in ``CUSTOM_METRICS_REGISTRY`` for ``name``.

    Raises:
        KeyError: If nothing is registered under ``name``; the message lists
            the names that are currently registered.
    """
    if name in CUSTOM_METRICS_REGISTRY:
        return CUSTOM_METRICS_REGISTRY[name]

    # Unknown name: include the registered names so the caller can spot typos.
    available_metrics = list(CUSTOM_METRICS_REGISTRY.keys())
    raise KeyError(f"Custom metric '{name}' not found. Available metrics: {available_metrics}")

list_custom_metrics() -> list[str]

List all registered custom metrics.

Returns:

Type Description
list[str]

list[str]: List of registered metric names.

Source code in segmentation_robustness_framework/metrics/custom_metrics.py
def list_custom_metrics() -> list[str]:
    """Return the names of all registered custom metrics.

    Returns:
        list[str]: A new list holding the keys of ``CUSTOM_METRICS_REGISTRY``;
            modifying it does not affect the registry.
    """
    registered_names = [*CUSTOM_METRICS_REGISTRY.keys()]
    return registered_names

Metrics Overview

The framework provides comprehensive evaluation metrics for semantic segmentation tasks, including both standard metrics and custom implementations.

MetricsCollection

The main metrics container that provides standardized evaluation functions:

from segmentation_robustness_framework.metrics import MetricsCollection

# Initialize metrics collection
metrics = MetricsCollection(num_classes=21, ignore_index=255)

# Get metric functions for pipeline
metric_functions = [
    metrics.mean_iou,
    metrics.pixel_accuracy,
    metrics.precision,
    metrics.recall,
    metrics.dice_score
]

Available Metrics

Mean IoU (Intersection over Union)

The most commonly used metric for semantic segmentation:

# Calculate mean IoU
iou = metrics.mean_iou(targets, predictions)
print(f"Mean IoU: {iou:.3f}")

Features:

  • Handles class imbalance
  • Robust to different class distributions
  • Standard benchmark metric

Pixel Accuracy

Overall pixel-level accuracy:

# Calculate pixel accuracy
accuracy = metrics.pixel_accuracy(targets, predictions)
print(f"Pixel Accuracy: {accuracy:.3f}")

Features:

  • Simple and intuitive
  • Fast computation
  • Good for balanced datasets

Precision

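Precision averaged over the classes and returned as a single value (macro averaging by default; pass average='micro' for the global score):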

# Calculate precision
precision = metrics.precision(targets, predictions)
print(f"Precision: {precision:.3f}")

Features:

  • Per-class evaluation
  • Useful for imbalanced datasets
  • Detailed performance analysis

Recall

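Recall averaged over the classes and returned as a single value (macro averaging by default; pass average='micro' for the global score):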

# Calculate recall
recall = metrics.recall(targets, predictions)
print(f"Recall: {recall:.3f}")

Features:

  • Per-class evaluation
  • Completeness measure
  • Balanced with precision

Dice Score (F1-Score)

Harmonic mean of precision and recall:

# Calculate dice score
dice = metrics.dice_score(targets, predictions)
print(f"Dice Score: {dice:.3f}")

Features:

  • Balanced metric
  • Good for imbalanced classes
  • Medical imaging standard

Custom Metrics

Create custom metrics by implementing metric functions:

import torch
import torch.nn.functional as F

def custom_metric(targets: torch.Tensor, predictions: torch.Tensor,
                  num_classes: "int | None" = None, ignore_index: int = 255) -> float:
    """Example custom metric: accuracy over the non-ignored pixels.

    Args:
        targets: Ground-truth class index per pixel.
        predictions: Predicted class index per pixel, same shape as `targets`.
        num_classes: Number of classes. Not used by this example; it is
            optional so the function can also be called with just
            `(targets, predictions)`, like the built-in metrics.
        ignore_index: Target value marking pixels to leave out.

    Returns:
        Fraction of non-ignored pixels predicted correctly, or 0.0 when every
        pixel is ignored.
    """

    # Remove ignored pixels
    mask = targets != ignore_index
    targets = targets[mask]
    predictions = predictions[mask]

    # Without this guard the mean of an empty tensor would be NaN.
    if targets.numel() == 0:
        return 0.0

    # Your custom metric calculation
    # Example: plain accuracy (add per-class weights here for a weighted variant)
    correct = (targets == predictions).float()
    accuracy = correct.mean()

    return accuracy.item()

# Use custom metric in pipeline
pipeline = SegmentationRobustnessPipeline(
    model=model,
    dataset=dataset,
    attacks=[FGSM(model, eps=0.1)],
    metrics=[custom_metric],
    batch_size=4,
    device="cuda"
)

Metric Registration

Register custom metrics for automatic discovery:

from segmentation_robustness_framework.metrics import register_custom_metric

@register_custom_metric("weighted_accuracy")
def weighted_accuracy(targets: torch.Tensor, predictions: torch.Tensor,
                      num_classes: "int | None" = None, ignore_index: int = 255) -> float:
    """Example registered metric: accuracy over the non-ignored pixels.

    Despite the name, this placeholder applies no class weights; add them
    where marked below to obtain a weighted variant.

    Args:
        targets: Ground-truth class index per pixel.
        predictions: Predicted class index per pixel, same shape as `targets`.
        num_classes: Number of classes. Not used by this example; it is
            optional so the function can also be called with just
            `(targets, predictions)`, like the built-in metrics.
        ignore_index: Target value marking pixels to leave out.

    Returns:
        Fraction of non-ignored pixels predicted correctly, or 0.0 when every
        pixel is ignored.
    """

    # Remove ignored pixels
    mask = targets != ignore_index
    targets = targets[mask]
    predictions = predictions[mask]

    # Without this guard the mean of an empty tensor would be NaN.
    if targets.numel() == 0:
        return 0.0

    # Calculate accuracy (apply per-class weights here for a weighted variant)
    correct = (targets == predictions).float()
    accuracy = correct.mean()

    return accuracy.item()

# Now you can use it in configuration
# metrics:
#   - weighted_accuracy

Metric Configuration

Configure metrics in YAML configuration files:

metrics:
  ignore_index: 255
  selected_metrics:
    - mean_iou
    - pixel_accuracy
    - precision
    - recall
    - {"name": "dice_score", "average": "micro"}
    - weighted_accuracy  # Custom metric

Metric Usage in Pipeline

Metrics are automatically used by the pipeline:

from segmentation_robustness_framework.pipeline import SegmentationRobustnessPipeline
from segmentation_robustness_framework.metrics import MetricsCollection

# Create metrics collection
metrics = MetricsCollection(num_classes=21, ignore_index=255)

# Create pipeline with metrics
pipeline = SegmentationRobustnessPipeline(
    model=model,
    dataset=dataset,
    attacks=[FGSM(model, eps=0.1)],
    metrics=[
        metrics.mean_iou,
        metrics.pixel_accuracy,
        metrics.precision,
        metrics.recall,
        metrics.dice_score
    ],
    batch_size=4,
    device="cuda"
)

# Run evaluation
results = pipeline.run()

# Access results
clean_iou = results['clean']['mean_iou']
attack_iou = results['attack_fgsm']['mean_iou']

print(f"Clean IoU: {clean_iou:.3f}")
print(f"Attack IoU: {attack_iou:.3f}")

Metric Aggregation

The framework provides different aggregation strategies:

# Micro averaging (global)
micro_precision = metrics.precision(targets, predictions, average='micro')

# Macro averaging (per-class then average)
macro_precision = metrics.precision(targets, predictions, average='macro')

# Note: only 'macro' and 'micro' are accepted. Any other value, such as
# 'weighted' (per-class scores weighted by frequency), is rejected by the
# metric functions.

Performance Considerations

  • Device Handling: The built-in metrics compute their scores with NumPy, i.e. on the CPU, after the inputs have been preprocessed
  • Memory Efficiency: Optimized for large batches
  • Batch Processing: Efficient batch metric computation
  • Numerical Stability: Robust to edge cases

Metric Interpretation

Understanding metric results:

# Good performance indicators
good_iou = 0.8      # 80% IoU is excellent
good_accuracy = 0.9  # 90% accuracy is very good
good_dice = 0.85     # 85% dice score is excellent

# Poor performance indicators
poor_iou = 0.3       # 30% IoU indicates issues
poor_accuracy = 0.5  # 50% accuracy is poor
poor_dice = 0.4      # 40% dice score is poor

# Robustness evaluation
robustness_ratio = attack_iou / clean_iou
if robustness_ratio > 0.8:
    print("Model is robust")
elif robustness_ratio > 0.5:
    print("Model has moderate robustness")
else:
    print("Model is vulnerable to attacks")