Source code for cellmap_segmentation_challenge.utils.eval_utils.instance_matching

"""Instance matching using min-cost flow optimization."""

import logging
from dataclasses import dataclass

import numpy as np

from .config import EvaluationConfig, ratio_cutoff
from .exceptions import (
    TooManyInstancesError,
    TooManyOverlapEdgesError,
    MatchingFailedError,
)


@dataclass
class InstanceOverlapData:
    """Container for the sparse GT/prediction overlap graph.

    ``rows``, ``cols`` and ``iou_vals`` are parallel arrays: entry ``k``
    describes one overlapping (GT, prediction) pair.

    Attributes:
        nG: Number of ground truth instances
        nP: Number of predicted instances
        rows: GT indices for overlaps
        cols: Pred indices for overlaps
        iou_vals: IoU values for overlaps
    """

    nG: int
    nP: int
    rows: np.ndarray
    cols: np.ndarray
    iou_vals: np.ndarray
def _check_instance_counts(nG: int, nP: int) -> bool:
    """Check if instance counts allow matching.

    Args:
        nG: Number of ground truth instances
        nP: Number of predicted instances

    Returns:
        True if matching should proceed, False if one or both sides have
        no instances (an INFO message describing the case is logged).
    """
    if nG == 0 and nP == 0:
        logging.info("No GT or Pred instances; returning only background match.")
        return False
    if nG == 0 and nP > 0:
        logging.info("No GT instances; returning empty match.")
        return False
    if nP == 0 and nG > 0:
        logging.info("No Pred instances; returning empty match.")
        return False
    return True


def _check_instance_ratio(nP: int, nG: int, config: EvaluationConfig) -> None:
    """Check if predicted/GT ratio is within acceptable bounds.

    Args:
        nP: Number of predicted instances
        nG: Number of ground truth instances (must be > 0)
        config: Evaluation configuration

    Raises:
        ValueError: If nG is not positive
        TooManyInstancesError: If ratio exceeds cutoff
    """
    # Explicit check instead of ``assert``: assertions are removed under
    # ``python -O``, which would let nG == 0 reach the division below.
    if nG <= 0:
        raise ValueError("nG must be > 0 to check instance ratio")

    cutoff = ratio_cutoff(
        nG,
        config.final_instance_ratio_cutoff,
        config.initial_instance_ratio_cutoff,
        config.instance_ratio_factor,
    )
    ratio = nP / nG

    if ratio > cutoff:
        logging.warning(
            f"Instance ratio {ratio:.2f} exceeds cutoff {cutoff:.2f} "
            f"({nP} pred vs {nG} GT)"
        )
        raise TooManyInstancesError(nP, nG, ratio, cutoff)


def _compute_instance_overlaps(
    gt: np.ndarray, pred: np.ndarray, nG: int, nP: int, max_edges: int
) -> InstanceOverlapData:
    """Compute IoU overlaps between GT and predicted instances.

    Labels > 0 are treated as instance IDs and mapped to 0-based indices
    (ID - 1); everything else is background.

    Args:
        gt: Ground truth instance labels (1D or flattened view)
        pred: Predicted instance labels (1D or flattened view)
        nG: Number of ground truth instances
        nP: Number of predicted instances
        max_edges: Maximum number of overlap edges allowed

    Returns:
        InstanceOverlapData with overlap information

    Raises:
        TooManyOverlapEdgesError: If number of edges exceeds max_edges
    """
    # 1D views
    g = np.ravel(gt)
    p = np.ravel(pred)

    # Foreground masks
    g_fg = g > 0
    p_fg = p > 0
    fg = g_fg & p_fg

    # Per-object voxel counts, indexed by 0-based instance index
    gt_sizes = np.bincount(g[g_fg].astype(np.int64) - 1, minlength=nG)
    pr_sizes = np.bincount(p[p_fg].astype(np.int64) - 1, minlength=nP)

    # 0-based (GT, pred) index pair for every voxel that is foreground in both
    gi = g[fg].astype(np.int64) - 1
    pj = p[fg].astype(np.int64) - 1

    if gi.size == 0:
        # No overlaps
        return InstanceOverlapData(
            nG=nG,
            nP=nP,
            rows=np.array([], dtype=np.int64),
            cols=np.array([], dtype=np.int64),
            iou_vals=np.array([], dtype=np.float32),
        )

    # Encode each pair as a single integer key (gi * nP + pj) so that one
    # np.unique call yields both the distinct pairs and their voxel counts.
    n_pred = np.uint64(nP)
    key = gi.astype(np.uint64) * n_pred + pj.astype(np.uint64)
    uniq_keys, counts = np.unique(key, return_counts=True)

    if uniq_keys.size > max_edges:
        raise TooManyOverlapEdgesError(uniq_keys.size, max_edges)

    # Decode keys back into (GT index, pred index)
    rows = (uniq_keys // n_pred).astype(np.int64)
    cols = (uniq_keys % n_pred).astype(np.int64)

    # Compute IoU = intersection / (|GT| + |pred| - intersection)
    inter = counts.astype(np.int64)
    union = gt_sizes[rows] + pr_sizes[cols] - inter
    with np.errstate(divide="ignore", invalid="ignore"):
        iou_vals = (inter / union).astype(np.float32)

    # Keep only IoU > 0
    keep = iou_vals > 0.0
    return InstanceOverlapData(
        nG=nG,
        nP=nP,
        rows=rows[keep],
        cols=cols[keep],
        iou_vals=iou_vals[keep],
    )


def _solve_matching_problem(
    overlap_data: InstanceOverlapData, cost_scale: int
) -> dict[int, int]:
    """Solve min-cost flow matching problem.

    The flow network is source -> GT -> (pred | sink) -> sink with unit
    capacities everywhere. Each GT node either sends its unit of flow
    through one overlapping prediction (cost based on 1 - IoU) or
    directly to the sink at ``cost_scale + 1``, i.e. "unmatched".

    Args:
        overlap_data: Instance overlap data
        cost_scale: Scale factor for cost values

    Returns:
        Dictionary mapping predicted ID to ground truth ID (both 1-based)

    Raises:
        MatchingFailedError: If optimization fails
    """
    from ortools.graph.python import min_cost_flow

    nG = overlap_data.nG
    nP = overlap_data.nP

    mcf = min_cost_flow.SimpleMinCostFlow()

    # Node indexing: [source | GT nodes | pred nodes | sink]
    source = 0
    gt0 = 1
    pred0 = gt0 + nG
    sink = pred0 + nP

    # Costs on GT -> Pred arcs never exceed cost_scale, so leaving a GT
    # instance unmatched is always the most expensive option.
    UNMATCH_COST = cost_scale + 1

    gt_nodes = np.arange(gt0, pred0, dtype=np.int32)
    pred_nodes = np.arange(pred0, sink, dtype=np.int32)

    # GT -> Pred edges; float64 arithmetic followed by truncation toward zero.
    edge_tails = (overlap_data.rows + gt0).astype(np.int32)
    edge_heads = (overlap_data.cols + pred0).astype(np.int32)
    edge_costs = (
        (1.0 - overlap_data.iou_vals.astype(np.float64)) * cost_scale
    ).astype(np.int64)

    # Arc groups, in order:
    #   Source -> GT, GT -> Sink (unmatched option), GT -> Pred, Pred -> Sink
    tails = np.concatenate(
        [
            np.full(nG, source, dtype=np.int32),
            gt_nodes,
            edge_tails,
            pred_nodes,
        ]
    )
    heads = np.concatenate(
        [
            gt_nodes,
            np.full(nG, sink, dtype=np.int32),
            edge_heads,
            np.full(nP, sink, dtype=np.int32),
        ]
    )
    costs = np.concatenate(
        [
            np.zeros(nG, dtype=np.int64),
            np.full(nG, UNMATCH_COST, dtype=np.int64),
            edge_costs,
            np.zeros(nP, dtype=np.int64),
        ]
    )
    caps = np.ones(tails.size, dtype=np.int64)

    # Add arcs in bulk
    mcf.add_arcs_with_capacity_and_unit_cost(tails, heads, caps, costs)

    # Set supplies: one unit of flow per GT instance
    mcf.set_node_supply(source, int(nG))
    mcf.set_node_supply(sink, -int(nG))

    # Solve
    status = mcf.solve()
    if status != mcf.OPTIMAL:
        raise MatchingFailedError(status)

    # Extract matches: saturated arcs that go from a GT node to a pred node
    mapping: dict[int, int] = {}
    for a in range(mcf.num_arcs()):
        if mcf.flow(a) != 1:
            continue
        u = mcf.tail(a)
        v = mcf.head(a)
        if gt0 <= u < pred0 and pred0 <= v < sink:
            gt_id = (u - gt0) + 1
            pred_id = (v - pred0) + 1
            mapping[pred_id] = gt_id

    return mapping
def match_instances(
    gt: np.ndarray,
    pred: np.ndarray,
    config: EvaluationConfig | None = None,
) -> dict[int, int]:
    """Match predicted instances to ground truth instances by IoU.

    Instance counts are taken as the maximum label on each side. After
    the count and ratio checks pass, pairwise overlaps are computed and
    a min-cost flow problem is solved to obtain a 1:1 assignment.

    Args:
        gt: Ground truth instance labels (0 = background)
        pred: Predicted instance labels (0 = background)
        config: Evaluation configuration; when None it is loaded with
            ``EvaluationConfig.from_env()``

    Returns:
        Dictionary mapping predicted instance ID to ground truth instance ID.
        Returns {0: 0} if only background present.
        Returns {} if one side has no instances or no overlaps are found.

    Raises:
        TooManyInstancesError: If pred/GT ratio exceeds threshold
        TooManyOverlapEdgesError: If overlap computation is too large
        MatchingFailedError: If optimization fails

    NOTE(review): no shape validation is performed in this function;
    mismatched ``gt``/``pred`` shapes are presumably rejected by the
    caller - TODO confirm.

    Example:
        >>> mapping = match_instances(gt, pred)
        >>> # Remap predictions to match GT IDs
        >>> pred_aligned = remap(pred, mapping, preserve_missing_labels=True)
    """
    if config is None:
        config = EvaluationConfig.from_env()

    # Instance counts: the largest label on each side (0 for empty arrays)
    gt_flat = np.ravel(gt)
    pred_flat = np.ravel(pred)
    nG = int(gt_flat.max()) if gt_flat.size else 0
    nP = int(pred_flat.max()) if pred_flat.size else 0

    # Special cases: nothing to match on at least one side
    if not _check_instance_counts(nG, nP):
        only_background = nG == 0 and nP == 0
        return {0: 0} if only_background else {}

    # Reject predictions with far too many instances relative to GT
    _check_instance_ratio(nP, nG, config)

    overlaps = _compute_instance_overlaps(
        gt, pred, nG, nP, config.max_overlap_edges
    )

    # No overlapping pair means nothing can be matched
    if overlaps.rows.size == 0:
        return {}

    return _solve_matching_problem(overlaps, config.mcmf_cost_scale)