# Source code for cellmap_segmentation_challenge.utils.eval_utils.scoring

"""Core scoring functions for segmentation evaluation."""

import gc
import logging
from time import time

import cc3d
import numpy as np
import zarr
from fastremap import remap
from scipy.spatial.distance import dice
from sklearn.metrics import jaccard_score
from upath import UPath

from ...config import TRUTH_PATH, INSTANCE_CLASSES
from ..crops import TEST_CROPS_DICT
from ..matched_crop import MatchedCrop
from ..rand_voi import rand_voi
from .config import EvaluationConfig
from .distance import (
    compute_max_distance,
    normalize_distance,
    optimized_hausdorff_distances,
)
from .exceptions import (
    TooManyInstancesError,
    TooManyOverlapEdgesError,
    MatchingFailedError,
)
from .instance_matching import match_instances
from .types import InstanceScoreDict


def _compute_binary_metrics(
    truth_label: np.ndarray, pred_label: np.ndarray
) -> dict[str, float]:
    """Compute binary (foreground vs. background) segmentation metrics.

    Both volumes are binarized with ``> 0`` and flattened before scoring.

    Args:
        truth_label: Ground truth labels
        pred_label: Predicted labels

    Returns:
        Dictionary with iou, dice_score, and binary_accuracy. When neither
        volume contains any foreground, iou and dice_score are both 1.0.
    """
    truth_binary = (truth_label > 0).ravel()
    pred_binary = (pred_label > 0).ravel()

    binary_accuracy = float((truth_binary == pred_binary).mean())

    if not (truth_binary.any() or pred_binary.any()):
        # No foreground on either side: scipy's ``dice`` evaluates 0 / 0 and
        # returns NaN here, while the IoU call below would return 1 through
        # ``zero_division=1``. Report perfect agreement for both so the two
        # metrics stay consistent (score_semantic applies the same rule).
        return {
            "iou": 1.0,
            "dice_score": 1.0,
            "binary_accuracy": binary_accuracy,
        }

    return {
        "iou": jaccard_score(truth_binary, pred_binary, zero_division=1),
        "dice_score": 1 - dice(truth_binary, pred_binary),
        "binary_accuracy": binary_accuracy,
    }


def _create_pathological_scores(
    binary_metrics: dict[str, float],
    voi_metrics: dict[str, float],
    hausdorff_distance_max: float,
    voxel_size: tuple[float, ...],
    status: str,
) -> InstanceScoreDict:
    """Build the fallback score dictionary used when instance matching fails.

    The matching-dependent scores (``mean_accuracy`` and ``combined_score``)
    are set to 0 and the Hausdorff distance is set to the supplied cap; the
    metrics that were computed before matching are passed through unchanged.

    Args:
        binary_metrics: Pre-computed binary metrics; must provide the keys
            ``binary_accuracy``, ``iou`` and ``dice_score``
        voi_metrics: Pre-computed VoI metrics, merged into the result last
        hausdorff_distance_max: Maximum Hausdorff distance
        voxel_size: Voxel size, forwarded to ``normalize_distance``
        status: Status string for the failure

    Returns:
        Dictionary with worst-case scores
    """
    capped_normalized = normalize_distance(hausdorff_distance_max, voxel_size)

    scores = {
        "mean_accuracy": 0,
        "binary_accuracy": binary_metrics["binary_accuracy"],
        "hausdorff_distance": hausdorff_distance_max,
        "normalized_hausdorff_distance": capped_normalized,
        "combined_score": 0,
        "iou": binary_metrics["iou"],
        "dice_score": binary_metrics["dice_score"],
        "status": status,
    }
    # VoI entries go in last so they take precedence on any key collision.
    scores.update(voi_metrics)
    return scores


def _compute_hausdorff_scores(
    mapping: dict[int, int],
    truth_label: np.ndarray,
    pred_label: np.ndarray,
    n_pred: int,
    voxel_size: tuple[float, ...],
    hausdorff_distance_max: float,
) -> list[float]:
    """Collect per-instance Hausdorff distances, penalizing unmatched predictions.

    Args:
        mapping: Instance ID mapping (pred -> truth)
        truth_label: Ground truth labels
        pred_label: Predicted labels (remapped to truth IDs)
        n_pred: Number of predicted instances; prediction IDs are taken to be
            ``1..n_pred``
        voxel_size: Voxel size
        hausdorff_distance_max: Maximum distance

    Returns:
        List of Hausdorff distances: ``[hausdorff_distance_max]`` when the
        mapping is empty, ``[0.0]`` when it only contains background, and
        otherwise the distances from ``optimized_hausdorff_distances``
        followed by one ``hausdorff_distance_max`` per unmatched prediction.
    """
    # Nothing was matched at all.
    if not len(mapping):
        return [hausdorff_distance_max]

    # The only entry is background.
    if len(mapping) == 1 and 0 in mapping:
        return [0.0]

    distances = optimized_hausdorff_distances(
        truth_label, pred_label, voxel_size, hausdorff_distance_max
    )

    # Prediction IDs in 1..n_pred that never appear as a key of the mapping.
    # The background ID 0 is outside that range, so it needs no special case.
    n_unmatched = len(set(range(1, n_pred + 1)).difference(mapping))
    if n_unmatched:
        penalties = np.full(n_unmatched, hausdorff_distance_max, dtype=np.float32)
        distances = np.concatenate([distances, penalties])

    return distances.tolist()


def score_instance(
    pred_label,
    truth_label,
    voxel_size,
    hausdorff_distance_max=None,
    config: EvaluationConfig | None = None,
) -> InstanceScoreDict:
    """Score an instance segmentation against its ground truth.

    Predictions are relabeled with connected components, matched to ground
    truth instances, remapped onto the ground truth IDs, and then scored with
    voxel-wise accuracy and Hausdorff distances. Binary and VoI metrics are
    computed before matching and are returned even when matching fails.

    Args:
        pred_label: Predicted instance labels (0 = background)
        truth_label: Ground truth instance labels (0 = background)
        voxel_size: Physical voxel size in (Z, Y, X) order
        hausdorff_distance_max: Maximum Hausdorff distance cap (None = derive
            it from the voxel size and the ground truth shape)
        config: Evaluation configuration (``EvaluationConfig.from_env()`` is
            used if None)

    Returns:
        Dictionary containing all instance segmentation metrics. ``status`` is
        ``"scored"`` on success, or ``"skipped_too_many_instances"`` /
        ``"matching_failed"`` when matching could not be completed.

    Example:
        >>> scores = score_instance(pred, truth, voxel_size=(4.0, 4.0, 4.0))
        >>> print(f"Combined score: {scores['combined_score']:.3f}")
    """
    if config is None:
        config = EvaluationConfig.from_env()

    logging.info("Scoring instance segmentation...")

    # Determine Hausdorff distance cap
    if hausdorff_distance_max is None:
        hausdorff_distance_max = compute_max_distance(voxel_size, truth_label.shape)
        logging.debug(
            f"Using default maximum Hausdorff distance of {hausdorff_distance_max:.2f}"
        )

    # Relabel predictions using connected components
    logging.info("Relabeling predicted instance labels...")
    # TODO: Switch to just renumbering to contiguous IDs, and leave user labels intact
    pred_label, n_pred = cc3d.connected_components(pred_label, return_N=True)

    # Metrics that do not depend on the instance matching
    binary_metrics = _compute_binary_metrics(truth_label, pred_label)
    voi = rand_voi(truth_label.astype(np.uint64), pred_label.astype(np.uint64))
    for dropped_key in ("voi_split_i", "voi_merge_j"):
        del voi[dropped_key]

    # Match instances; remember why it failed instead of returning in place
    try:
        mapping = match_instances(truth_label, pred_label, config)
    except (TooManyInstancesError, TooManyOverlapEdgesError) as err:
        logging.warning(f"Instance matching failed: {err}")
        failure_status = "skipped_too_many_instances"
    except MatchingFailedError as err:
        logging.error(f"Matching optimization failed: {err}")
        failure_status = "matching_failed"
    else:
        failure_status = None

    if failure_status is not None:
        return _create_pathological_scores(
            binary_metrics,
            voi,
            hausdorff_distance_max,
            voxel_size,
            failure_status,
        )

    # Remap predictions to GT IDs unless the mapping is empty or background-only
    if len(mapping) > 1 or (len(mapping) == 1 and 0 not in mapping):
        mapping[0] = 0  # background maps to background
        pred_label = remap(
            pred_label, mapping, in_place=True, preserve_missing_labels=True
        )

    # Per-instance Hausdorff distances, with the cap as a last-resort fallback
    hausdorff_distances = _compute_hausdorff_scores(
        mapping, truth_label, pred_label, n_pred, voxel_size, hausdorff_distance_max
    )
    if not hausdorff_distances:
        hausdorff_distances = [hausdorff_distance_max]

    # Aggregate scores
    logging.info("Computing final scores...")
    mean_accuracy = float((truth_label == pred_label).mean())
    hausdorff_dist = float(np.mean(hausdorff_distances))
    normalized_each = [
        normalize_distance(distance, voxel_size) for distance in hausdorff_distances
    ]
    normalized_hausdorff_dist = float(np.mean(normalized_each))
    # Geometric mean of the accuracy and the normalized Hausdorff score
    combined_score = (mean_accuracy * normalized_hausdorff_dist) ** 0.5

    logging.info(f"Mean Accuracy: {mean_accuracy:.4f}")
    logging.info(f"Hausdorff Distance: {hausdorff_dist:.4f}")
    logging.info(f"Normalized Hausdorff Distance: {normalized_hausdorff_dist:.4f}")
    logging.info(f"Combined Score: {combined_score:.4f}")

    return {
        "mean_accuracy": mean_accuracy,
        "hausdorff_distance": hausdorff_dist,
        "normalized_hausdorff_distance": normalized_hausdorff_dist,
        "combined_score": combined_score,
        "status": "scored",
        **binary_metrics,
        **voi,
    }
def score_semantic(pred_label, truth_label) -> dict[str, float]:
    """
    Score a single semantic label volume against the ground truth semantic label volume.

    Both volumes are binarized with ``> 0.0`` and flattened before scoring.

    Args:
        pred_label (np.ndarray): The predicted semantic label volume.
        truth_label (np.ndarray): The ground truth semantic label volume.

    Returns:
        dict: A dictionary with the float scores ``iou``, ``dice_score`` and
        ``binary_accuracy``, plus ``status`` set to ``"scored"``. When neither
        volume contains any foreground, ``iou`` and ``dice_score`` are 1.0.

    Example usage:
        scores = score_semantic(pred_label, truth_label)
    """
    logging.info("Scoring semantic segmentation...")

    # Flatten the label volumes and convert to binary. ``np.asarray`` lets
    # array-likes (e.g. nested lists) through without copying real arrays.
    pred_binary = (np.asarray(pred_label) > 0.0).ravel()
    truth_binary = (np.asarray(truth_label) > 0.0).ravel()

    # Compute the scores
    if not (truth_binary.any() or pred_binary.any()):
        # If there are no true or false positives, set the scores to 1.
        # ``any()`` avoids materializing a full-size temporary array just to
        # find out whether either volume has foreground.
        logging.debug("No true or false positives found. Setting scores to 1.")
        dice_score = 1.0
        iou_score = 1.0
    else:
        dice_score = float(1 - dice(truth_binary, pred_binary))
        iou_score = float(jaccard_score(truth_binary, pred_binary, zero_division=1))
        if np.isnan(dice_score):
            # Defensive: scipy's dice is NaN for a 0 / 0 ratio.
            dice_score = 1.0

    scores = {
        "iou": iou_score,
        "dice_score": dice_score,
        "binary_accuracy": float((truth_binary == pred_binary).mean()),
        "status": "scored",
    }

    logging.info(f"IoU: {scores['iou']:.4f}")
    logging.info(f"Dice Score: {scores['dice_score']:.4f}")

    return scores
def score_label(
    pred_label_path,
    label_name,
    crop_name,
    truth_path=TRUTH_PATH,
    instance_classes=INSTANCE_CLASSES,
):
    """
    Score a single predicted label volume against the ground truth label volume.

    Args:
        pred_label_path (str): The path to the predicted label volume, or None
            if the label is absent from the submission (in which case the
            result of ``empty_label_score`` is returned).
        label_name (str): The name of the label being scored.
        crop_name (str): The name of the crop, e.g. ``"crop123"``; the number
            after the ``crop`` prefix is used to look up the test crop.
        truth_path (str): The path to the ground truth label volume.
        instance_classes (list): A list of instance classes; labels in this
            list are scored with ``score_instance``, all others with
            ``score_semantic``.

    Returns:
        tuple: ``(crop_name, label_name, results)`` where ``results`` is the
        dictionary of scores extended with ``num_voxels``, ``voxel_size`` and
        ``is_missing``.

    Example usage:
        crop, label, scores = score_label('pred.zarr/crop1/label1', 'label1', 'crop1')
    """
    if pred_label_path is None:
        logging.info(f"Label {label_name} not found in submission volume {crop_name}.")
        missing_scores = empty_label_score(
            label=label_name,
            crop_name=crop_name,
            instance_classes=instance_classes,
            truth_path=truth_path,
        )
        return crop_name, label_name, missing_scores

    logging.info(f"Scoring {pred_label_path}...")
    truth_path = UPath(truth_path)

    # Load the ground truth volume, then the prediction aligned to the crop
    truth_label_ds = zarr.open((truth_path / crop_name / label_name).path, mode="r")
    truth_label = truth_label_ds[:]

    crop_number = int(crop_name.removeprefix("crop"))
    crop = TEST_CROPS_DICT[crop_number, label_name]
    pred_label = match_crop_space(
        pred_label_path,
        label_name,
        crop.voxel_size,
        crop.shape,
        crop.translation,
    )

    mask_path = truth_path / crop_name / f"{label_name}_mask"
    if mask_path.exists():
        # Mask out uncertain regions resulting from low-res ground truth annotations
        logging.info(f"Masking {label_name} with {mask_path}...")
        mask = zarr.open(mask_path.path, mode="r")[:]
        pred_label = pred_label * mask
        truth_label = truth_label * mask
        del mask

    # Compute the scores
    if label_name not in instance_classes:
        results = score_semantic(pred_label, truth_label)
    else:
        logging.info(
            f"Starting an instance evaluation for {label_name} in {crop_name}..."
        )
        started = time()
        results = score_instance(pred_label, truth_label, crop.voxel_size)
        elapsed = time() - started
        logging.info(
            f"Finished instance evaluation for {label_name} in {crop_name} in {elapsed:.2f} seconds..."
        )

    results["num_voxels"] = int(np.prod(truth_label.shape))
    results["voxel_size"] = crop.voxel_size
    results["is_missing"] = False

    # drop big arrays before returning
    del truth_label, pred_label, truth_label_ds
    gc.collect()

    return crop_name, label_name, results
def empty_label_score(
    label, crop_name, instance_classes=INSTANCE_CLASSES, truth_path=TRUTH_PATH
):
    """
    Build the score dictionary for a label that is missing from a submission.

    The ground truth dataset at ``truth_path / crop_name / label`` is opened
    read-only to obtain its shape and its ``voxel_size`` attribute.

    Args:
        label (str): The name of the missing label.
        crop_name (str): The name of the crop the label belongs to.
        instance_classes (list): A list of instance classes; decides whether
            instance-style or semantic-style score keys are returned.
        truth_path (str): The path to the ground truth volumes.

    Returns:
        dict: All-zero scores (with ``hausdorff_distance`` set to the maximum
        distance for instance labels), plus ``num_voxels``, ``voxel_size``,
        ``is_missing=True`` and ``status="missing"``.
    """
    ds = zarr.open((UPath(truth_path) / crop_name / label).path, mode="r")
    voxel_size = ds.attrs["voxel_size"]

    # Fields shared by the instance and the semantic variants.
    common = {
        "num_voxels": int(np.prod(ds.shape)),
        "voxel_size": voxel_size,
        "is_missing": True,
        "status": "missing",
    }

    if label in instance_classes:
        return {
            "mean_accuracy": 0,
            "hausdorff_distance": compute_max_distance(voxel_size, ds.shape),
            "normalized_hausdorff_distance": 0,
            "combined_score": 0,
            **common,
        }

    return {
        "iou": 0,
        "dice_score": 0,
        "binary_accuracy": 0,
        **common,
    }
def match_crop_space(path, class_label, voxel_size, shape, translation) -> np.ndarray:
    """
    Load a label volume through ``MatchedCrop`` for the given target crop geometry.

    Args:
        path: The path to the label volume to load.
        class_label: The name of the label class being loaded.
        voxel_size: The target voxel size.
        shape: The target shape.
        translation: The target translation.

    Returns:
        np.ndarray: The result of ``MatchedCrop.load_aligned()``, built with
        the module's ``INSTANCE_CLASSES``, a semantic threshold of 0.5 and a
        pad value of 0.
    """
    crop_settings = {
        "path": path,
        "class_label": class_label,
        "target_voxel_size": voxel_size,
        "target_shape": shape,
        "target_translation": translation,
        "instance_classes": INSTANCE_CLASSES,
        "semantic_threshold": 0.5,
        "pad_value": 0,
    }
    return MatchedCrop(**crop_settings).load_aligned()