# Source code for cellmap_segmentation_challenge.utils.eval_utils.scoring

"""Core scoring functions for segmentation evaluation."""

import gc
import logging
from time import time

import cc3d
import numpy as np
import zarr
from fastremap import remap, unique
from upath import UPath

from ...config import TRUTH_PATH, INSTANCE_CLASSES
from ..crops import TEST_CROPS_DICT
from ..matched_crop import MatchedCrop
from .config import EvaluationConfig
from .distance import (
    compute_max_distance,
    normalize_distance,
    optimized_hausdorff_distances,
)
from .exceptions import (
    TooManyInstancesError,
    TooManyOverlapEdgesError,
    MatchingFailedError,
)
from .instance_matching import match_instances
from .types import InstanceScoreDict


def _create_pathological_scores(status: str) -> InstanceScoreDict:
    """Build the placeholder score dict for a crop whose matching failed.

    All counts and both Hausdorff fields are zero, so the crop adds nothing
    to the per-class pools: a failure neither helps nor penalizes the class.

    Args:
        status: Status string describing the failure.

    Returns:
        Score dict with zeroed counts and Hausdorff fields, tagged with
        ``status``.
    """
    # Integer confusion counts first, then the (empty) Hausdorff pool entry.
    scores = dict.fromkeys(("tp", "fp", "fn"), 0)
    scores["hausdorff_norm_sum"] = 0.0
    scores["n_hausdorff"] = 0
    scores["status"] = status
    return scores


def _compute_hausdorff_scores(
    mapping: dict[int, int],
    truth_label: np.ndarray,
    pred_label: np.ndarray,
    n_pred: int,
    voxel_size: tuple[float, ...],
    hausdorff_distance_max: float,
    truth_ids: np.ndarray | None = None,
) -> np.ndarray:
    """Collect the per-instance Hausdorff distances a crop adds to its class pool.

    The result is the array returned by ``optimized_hausdorff_distances``
    (one distance per ground-truth instance: the real distance when matched,
    the maximum distance when unmatched), followed by one maximum-distance
    penalty for every predicted instance that has no entry in ``mapping``
    (a hallucination). When there are no truth instances and no predictions
    the array is empty, so empty crops contribute nothing to the pool.

    Args:
        mapping: Instance ID mapping (pred -> truth).
        truth_label: Ground truth labels.
        pred_label: Predicted labels (remapped to truth IDs).
        n_pred: Number of predicted instances.
        voxel_size: Voxel size.
        hausdorff_distance_max: Maximum distance, also used as the penalty.
        truth_ids: Precomputed non-zero ground-truth ids.

    Returns:
        1D array of per-instance Hausdorff distances (possibly empty).
    """
    per_truth_distances = optimized_hausdorff_distances(
        truth_label,
        pred_label,
        voxel_size,
        hausdorff_distance_max,
        truth_ids=truth_ids,
    )

    # Predictions that were matched to something; background (0) never counts.
    matched_pred_ids = {pred_id for pred_id in mapping if pred_id != 0}
    n_hallucinated = n_pred - len(matched_pred_ids)
    if n_hallucinated <= 0:
        return per_truth_distances

    # Each unmatched prediction is charged the maximum distance.
    penalties = np.full(n_hallucinated, hausdorff_distance_max, dtype=np.float32)
    return np.concatenate([per_truth_distances, penalties])


def score_instance(
    pred_label,
    truth_label,
    voxel_size,
    hausdorff_distance_max=None,
    config: EvaluationConfig | None = None,
) -> InstanceScoreDict:
    """Score instance segmentation against ground truth.

    Relabels the prediction with connected components, matches predicted
    instances to ground-truth instances, and returns the raw quantities that
    are pooled per class later: TP/FP/FN instance counts and the sum and
    count of normalized per-instance Hausdorff distances.

    Args:
        pred_label: Predicted instance labels (0 = background)
        truth_label: Ground truth instance labels (0 = background)
        voxel_size: Physical voxel size in (Z, Y, X) order
        hausdorff_distance_max: Maximum Hausdorff distance cap (None = auto,
            derived from ``voxel_size`` and ``truth_label.shape``)
        config: Evaluation configuration (``EvaluationConfig.from_env()`` is
            used if None)

    Returns:
        Dict with tp/fp/fn and per-instance Hausdorff sum/count, with
        ``status`` set to ``"scored"``. If instance matching fails, a zeroed
        dict is returned instead, with ``status`` set to
        ``"skipped_too_many_instances"`` or ``"matching_failed"``.
        F1 and combined_score are computed per class in aggregation, not here.

    Example:
        >>> scores = score_instance(pred, truth, voxel_size=(4.0, 4.0, 4.0))
        >>> print(scores["tp"], scores["n_hausdorff"])
    """
    if config is None:
        config = EvaluationConfig.from_env()

    logging.info("Scoring instance segmentation...")

    # Determine Hausdorff distance cap
    if hausdorff_distance_max is None:
        hausdorff_distance_max = compute_max_distance(voxel_size, truth_label.shape)
        logging.debug(
            f"Using default maximum Hausdorff distance of {hausdorff_distance_max:.2f}"
        )

    # Relabel predictions using connected components
    logging.info("Relabeling predicted instance labels...")
    # TODO: Switch to just renumbering to contiguous IDs, and leave user labels intact
    pred_label, n_pred = cc3d.connected_components(pred_label, return_N=True)

    # Match instances; a failed match yields a zeroed, status-tagged result.
    try:
        mapping = match_instances(truth_label, pred_label, config)
    except (TooManyInstancesError, TooManyOverlapEdgesError) as e:
        logging.warning(f"Instance matching failed: {e}")
        return _create_pathological_scores("skipped_too_many_instances")
    except MatchingFailedError as e:
        logging.error(f"Matching optimization failed: {e}")
        return _create_pathological_scores("matching_failed")

    # Remap predictions to match GT IDs, but only when at least one
    # non-background prediction was matched (an empty mapping or a lone
    # background entry leaves nothing to remap).
    if any(pred_id != 0 for pred_id in mapping):
        mapping[0] = 0  # background maps to background
        pred_label = remap(
            pred_label, mapping, in_place=True, preserve_missing_labels=True
        )

    # Non-zero ground-truth ids, computed once and shared with the Hausdorff step.
    truth_ids = unique(truth_label)
    truth_ids = truth_ids[truth_ids != 0]

    # Free matching-stage scratch before the memory-heavy Hausdorff phase.
    gc.collect()

    # Per-instance Hausdorff distances (pooled per class in aggregation).
    hausdorff_distances = _compute_hausdorff_scores(
        mapping,
        truth_label,
        pred_label,
        n_pred,
        voxel_size,
        hausdorff_distance_max,
        truth_ids,
    )

    # F1 counts from matching (pooled per class in aggregation).
    matched_pred_ids = set(mapping) - {0}
    matched_truth_ids = set(mapping.values()) - {0}
    gt_count = int(truth_ids.size)
    tp = len(matched_truth_ids)
    fp = n_pred - len(matched_pred_ids)
    fn = gt_count - tp

    # Normalize per-instance distances; emit sum/count for per-class pooling.
    norm = normalize_distance(hausdorff_distances, voxel_size)
    hausdorff_norm_sum = float(np.sum(norm))
    n_hausdorff = int(np.size(norm))

    logging.debug(f"TP={tp}, FP={fp}, FN={fn}, n_hausdorff={n_hausdorff}")

    return {
        "tp": tp,
        "fp": fp,
        "fn": fn,
        "hausdorff_norm_sum": hausdorff_norm_sum,
        "n_hausdorff": n_hausdorff,
        "status": "scored",
    }


[docs] def score_semantic(pred_label, truth_label) -> dict[str, int | str]: """ Score a single semantic label volume against the ground truth semantic label volume. Args: pred_label (np.ndarray): The predicted semantic label volume. truth_label (np.ndarray): The ground truth semantic label volume. Returns: dict: A dictionary of scores for the semantic label volume. Example usage: scores = score_semantic(pred_label, truth_label) """ logging.info("Scoring semantic segmentation...") # Flatten the label volumes and convert to binary pred_label = (pred_label > 0.0).ravel() truth_label = (truth_label > 0.0).ravel() # Voxel confusion counts; pooled per class in aggregation to compute IoU. tp = int(np.count_nonzero(truth_label & pred_label)) fp = int(pred_label.sum()) - tp fn = int(truth_label.sum()) - tp logging.debug(f"Semantic counts: TP={tp}, FP={fp}, FN={fn}") return { "tp": tp, "fp": fp, "fn": fn, "status": "scored", }
def score_label(
    pred_label_path,
    label_name,
    crop_name,
    truth_path=TRUTH_PATH,
    instance_classes=INSTANCE_CLASSES,
):
    """Score a single label volume against the ground truth label volume.

    When ``pred_label_path`` is None the label is treated as not submitted
    and scored with ``empty_label_score``. Otherwise the ground truth is read
    from ``truth_path / crop_name / label_name``, the prediction is aligned
    to the crop with ``match_crop_space``, both are multiplied by the
    ``<label_name>_mask`` array when one exists, and the result is scored as
    an instance or a semantic class.

    Args:
        pred_label_path (str | None): The path to the predicted label volume,
            or None if the label is absent from the submission.
        label_name (str): Name of the label/class being scored.
        crop_name (str): Crop identifier, e.g. ``"crop123"``; the number after
            the ``"crop"`` prefix is used to look up the crop metadata.
        truth_path (str): The path to the ground truth zarr.
        instance_classes (list): A list of instance classes; labels in it are
            scored with ``score_instance``, all others with ``score_semantic``.

    Returns:
        tuple: ``(crop_name, label_name, results)`` where ``results`` is the
        score dict, extended with ``num_voxels``, ``voxel_size`` and
        ``is_missing``.

    Example usage:
        crop_name, label_name, scores = score_label(
            'pred.zarr/crop123/label1', 'label1', 'crop123'
        )
    """
    if pred_label_path is None:
        logging.info(f"Label {label_name} not found in submission volume {crop_name}.")
        return (
            crop_name,
            label_name,
            empty_label_score(
                label=label_name,
                crop_name=crop_name,
                instance_classes=instance_classes,
                truth_path=truth_path,
            ),
        )

    logging.info(f"Scoring {pred_label_path}...")
    truth_path = UPath(truth_path)

    # Load the predicted and ground truth label volumes
    truth_label_path = (truth_path / crop_name / label_name).path
    truth_label_ds = zarr.open(truth_label_path, mode="r")
    truth_label = truth_label_ds[:]
    crop = TEST_CROPS_DICT[int(crop_name.removeprefix("crop")), label_name]
    pred_label = match_crop_space(
        pred_label_path,
        label_name,
        crop.voxel_size,
        crop.shape,
        crop.translation,
    )

    mask_path = truth_path / crop_name / f"{label_name}_mask"
    if mask_path.exists():
        # Mask out uncertain regions resulting from low-res ground truth annotations
        logging.info(f"Masking {label_name} with {mask_path}...")
        mask = zarr.open(mask_path.path, mode="r")[:]
        pred_label = pred_label * mask
        truth_label = truth_label * mask
        del mask
        gc.collect()

    # Compute the scores
    if label_name in instance_classes:
        logging.info(
            f"Starting an instance evaluation for {label_name} in {crop_name}..."
        )
        timer = time()
        results = score_instance(pred_label, truth_label, crop.voxel_size)
        logging.info(
            f"Finished instance evaluation for {label_name} in {crop_name} in {time() - timer:.2f} seconds..."
        )
    else:
        results = score_semantic(pred_label, truth_label)

    # Explicit 64-bit accumulator so the voxel count does not depend on the
    # platform's default integer width.
    results["num_voxels"] = int(np.prod(truth_label.shape, dtype=np.int64))
    results["voxel_size"] = crop.voxel_size
    results["is_missing"] = False

    # drop big arrays before returning
    del truth_label, pred_label, truth_label_ds
    gc.collect()

    return crop_name, label_name, results


def empty_label_score(
    label, crop_name, instance_classes=INSTANCE_CLASSES, truth_path=TRUTH_PATH
):
    """Score a not-submitted label as a penalized "missing" volume.

    Non-submission is penalized harder than a (submitted) empty prediction:
    instance classes count every ground-truth instance as a false negative
    at a worst-case (zero) boundary score; semantic classes count every
    voxel as a false negative. Used when a label is absent from the
    submission.

    Args:
        label: Label/class name.
        crop_name: Crop identifier.
        instance_classes: Names of instance-segmented classes.
        truth_path: Path to the ground-truth zarr.

    Returns:
        A per-volume score dict flagged ``is_missing=True`` with ``status``
        set to ``"missing"``. Instance classes additionally carry
        ``hausdorff_norm_sum`` and ``n_hausdorff``.
    """
    truth_path = UPath(truth_path)
    ds = zarr.open((truth_path / crop_name / label).path, mode="r")
    voxel_size = ds.attrs["voxel_size"]
    # Explicit 64-bit accumulator so the voxel count does not depend on the
    # platform's default integer width.
    n_voxels = int(np.prod(ds.shape, dtype=np.int64))

    if label in instance_classes:
        # Not submitted: each instance an FN at worst-case (0) boundary --
        # harsher than empty submission.
        truth_ids = unique(ds[:])
        n_instances = int(truth_ids[truth_ids != 0].size)
        scores = {
            "tp": 0,
            "fp": 0,
            "fn": n_instances,
            "hausdorff_norm_sum": 0.0,
            "n_hausdorff": n_instances,
        }
    else:
        # Not submitted: count every voxel as a false negative -> IoU 0.
        scores = {"tp": 0, "fp": 0, "fn": n_voxels}

    # Fields shared by both kinds of class, appended last to keep key order.
    scores["num_voxels"] = n_voxels
    scores["voxel_size"] = voxel_size
    scores["is_missing"] = True
    scores["status"] = "missing"
    return scores


def match_crop_space(path, class_label, voxel_size, shape, translation) -> np.ndarray:
    """Load a predicted label volume aligned to a target crop.

    Builds a ``MatchedCrop`` for ``path`` with the given target voxel size,
    shape and translation, and returns the result of its ``load_aligned()``
    call. The module-level ``INSTANCE_CLASSES`` list, a semantic threshold of
    0.5 and a pad value of 0 are always passed to ``MatchedCrop``.

    Args:
        path: Path to the predicted label volume.
        class_label: Name of the label/class stored at ``path``.
        voxel_size: Target voxel size.
        shape: Target shape.
        translation: Target translation.

    Returns:
        The array produced by ``MatchedCrop.load_aligned()``.
    """
    mc = MatchedCrop(
        path=path,
        class_label=class_label,
        target_voxel_size=voxel_size,
        target_shape=shape,
        target_translation=translation,
        instance_classes=INSTANCE_CLASSES,
        semantic_threshold=0.5,
        pad_value=0,
    )
    return mc.load_aligned()