Source code for cellmap_segmentation_challenge.utils.eval_utils.distance

"""Distance metrics including Hausdorff distance computation."""

from concurrent.futures import ThreadPoolExecutor

import numpy as np
from cc3d import statistics as cc3d_statistics
from cc3d.types import StatisticsDict, StatisticsSlicesDict
from fastremap import unique
from scipy.ndimage import distance_transform_edt
from tqdm import tqdm

from .config import PER_INSTANCE_THREADS


def compute_max_distance(voxel_size, shape) -> float:
    """
    Compute the maximum distance used for distance-based metrics.

    The value is half of the smallest physical extent of the volume, i.e. the
    largest distance any point of the volume can be from its nearest volume
    boundary.

    Parameters
    ----------
    voxel_size : Sequence[float] | float
        Physical size of a voxel along each axis. A scalar is treated as an
        isotropic voxel size. If it has more entries than ``shape`` (e.g. an
        extra leading entry), only the trailing ``len(shape)`` entries are
        used, matching how ``roi_slices_for_pair`` treats ``voxel_size``.
    shape : Sequence[int]
        Volume shape in voxels.

    Returns
    -------
    float
        ``min(voxel_size[d] * shape[d]) / 2`` over the paired axes.

    Raises
    ------
    ValueError
        If there is no axis to compare (empty ``voxel_size`` or ``shape``).
    """
    extent = np.asarray(shape, dtype=np.int64).reshape(-1)
    spacing = np.asarray(voxel_size, dtype=np.float64)
    if spacing.ndim == 0:
        # A scalar voxel size describes isotropic voxels.
        spacing = np.full(extent.size, float(spacing))
    spacing = spacing.reshape(-1)
    if spacing.size > extent.size:
        # Align the trailing (spatial) axes instead of pairing leading entries.
        spacing = spacing[spacing.size - extent.size :]

    n_axes = min(spacing.size, extent.size)
    if n_axes == 0:
        raise ValueError("voxel_size and shape must describe at least one axis")

    physical_extent = spacing[:n_axes] * extent[:n_axes]
    return float(physical_extent.min() / 2)
def normalize_distance(distance: float, voxel_size) -> float:
    """
    Map a distance onto a similarity score.

    The distance is expressed in units of the voxel diagonal
    (``np.linalg.norm(voxel_size)``) and passed through the exponential decay
    ``1.01 ** (-units)``. A distance of 0 maps to 1.0, larger non-negative
    distances approach 0, and an infinite distance maps directly to 0.0.
    """
    if distance == np.inf:
        return 0.0

    # TODO: Normalize by max possible distance in the volume
    voxel_diagonal = np.linalg.norm(voxel_size)
    exponent = -distance / voxel_diagonal
    return float(1.01**exponent)
def optimized_hausdorff_distances(
    truth_label,
    pred_label,
    voxel_size,
    hausdorff_distance_max,
    method="standard",
    percentile: float | None = None,
):
    """
    Compute one Hausdorff-type distance per ground-truth instance, in parallel.

    Each foreground id of ``truth_label`` is compared with the voxels carrying
    the same id in ``pred_label`` through ``compute_hausdorff_distance_roi``.
    The per-id work is spread over a ``ThreadPoolExecutor`` with
    ``PER_INSTANCE_THREADS`` workers and reported with a tqdm progress bar.

    Parameters
    ----------
    truth_label : np.ndarray
        Ground-truth instance label volume (0 == background).
    pred_label : np.ndarray
        Prediction instance label volume that has already been remapped so
        matched instances share the ground-truth id (0 == background).
    voxel_size : Sequence[float]
        Physical voxel sizes in Z, Y, X (or Y, X) order.
    hausdorff_distance_max : float
        Cap applied to every distance (np.inf for uncapped).
    method : {"standard", "modified", "percentile"}
        "standard"   -> max of the two directed maxima.
        "modified"   -> max of the two directed means.
        "percentile" -> max of the two directed ``percentile``-th percentiles
                        (requires ``percentile``).
    percentile : float | None
        Percentile (0-100) used when method == "percentile".

    Returns
    -------
    np.ndarray
        1D float32 array; entry ``k`` corresponds to the ``k``-th non-zero id
        returned by ``fastremap.unique(truth_label)``. Empty if the truth
        volume has no foreground.
    """
    # Foreground ground-truth ids; label 0 is background and is skipped.
    all_ids = unique(truth_label)
    gt_ids = all_ids[all_ids != 0]
    n_gt = int(gt_ids.size)
    if not n_gt:
        return np.empty((0,), dtype=np.float32)

    spacing = np.asarray(voxel_size, dtype=np.float64)
    # Statistics (bounding boxes) are computed once and shared by all workers.
    gt_stats = cc3d_statistics(truth_label)
    pr_stats = cc3d_statistics(pred_label)

    def _distance_for(position: int):
        label = int(gt_ids[position])
        value = compute_hausdorff_distance_roi(
            truth_label,
            gt_stats,
            pred_label,
            pr_stats,
            label,
            spacing,
            hausdorff_distance_max,
            method=method,
            percentile=percentile,
        )
        return position, float(value)

    result = np.empty((n_gt,), dtype=np.float32)
    with ThreadPoolExecutor(max_workers=PER_INSTANCE_THREADS) as pool:
        progress = tqdm(
            pool.map(_distance_for, range(n_gt)),
            desc="Computing Hausdorff distances",
            total=n_gt,
            dynamic_ncols=True,
        )
        for position, value in progress:
            result[position] = value
    return result
def _flat_bbox_bounds(bb, ndim: int):
    """Split a flat ``(start_0, stop_0, start_1, stop_1, ...)`` box into Python-int (mins, maxs)."""
    mins = [int(bb[2 * k]) for k in range(ndim)]
    maxs = [int(bb[2 * k + 1]) for k in range(ndim)]
    return mins, maxs


def bbox_for_label(
    stats: "StatisticsDict | StatisticsSlicesDict",
    ndim: int,
    label_id: int,
):
    """
    Look up the bounding box of ``label_id`` in precomputed cc3d statistics.

    No mask is built; the box is read directly from ``stats``. Two layouts are
    recognised:

    * ``stats["bounding_boxes"]`` indexed by label id, where each entry is
      either ``None``, a tuple/list of slices, or a flat sequence
      ``(start_0, stop_0, start_1, stop_1, ...)``;
    * ``stats[label_id]["bounding_box"]`` as a flat sequence of the same form.

    Flat sequences are read as end-exclusive.
    NOTE(review): confirm this against the cc3d version in use when
    ``no_slice_conversion=True`` is used.

    Parameters
    ----------
    stats : dict
        Output of ``cc3d.statistics`` or a compatible mapping.
    ndim : int
        Number of axes to read from a flat bounding-box sequence.
    label_id : int
        Label whose bounding box is requested.

    Returns
    -------
    tuple[list[int], list[int]] | None
        ``(mins, maxs)`` in voxel indices, or ``None`` if the label has no box.
        Values are plain Python ints, so downstream padding arithmetic cannot
        wrap around in an unsigned NumPy dtype.
    """
    if "bounding_boxes" in stats:
        boxes = stats["bounding_boxes"]
        # Negative ids would otherwise silently index from the end of the table.
        if not 0 <= label_id < len(boxes):
            return None
        bb = boxes[label_id]
        if bb is None:
            return None
        if isinstance(bb, (tuple, list)) and all(isinstance(s, slice) for s in bb):
            return [int(s.start) for s in bb], [int(s.stop) for s in bb]
        return _flat_bbox_bounds(bb, ndim)

    if label_id in stats:
        entry = stats[label_id]
        if "bounding_box" in entry:
            return _flat_bbox_bounds(entry["bounding_box"], ndim)

    return None
def roi_slices_for_pair(
    truth_stats: StatisticsDict | StatisticsSlicesDict,
    pred_stats: StatisticsDict | StatisticsSlicesDict,
    tid: int,
    voxel_size,
    ndim: int,
    shape: tuple[int, ...],
    max_distance: float,
):
    """
    Compute the region of interest used to compare one truth/prediction label.

    The ROI is the union of the bounding boxes of ``truth == tid`` and
    ``pred == tid`` (read from the cc3d statistics), grown on every side by
    ``ceil(max_distance / voxel_size) + 2`` voxels and clipped to ``shape``.
    If the prediction has no box for ``tid``, only the truth box is used.

    Parameters
    ----------
    truth_stats, pred_stats : dict
        cc3d statistics of the truth and prediction volumes.
    tid : int
        Label id looked up in both statistics.
    voxel_size : Sequence[float]
        Physical voxel size. If it has more than ``ndim`` entries, only the
        last ``ndim`` are used.
    ndim : int
        Number of spatial axes.
    shape : tuple[int, ...]
        Shape of the full volume; used to clip the ROI.
    max_distance : float
        Distance cap in physical units. When the per-axis padding is not
        finite (``np.inf`` cap, or a zero voxel size), the ROI spans that whole
        axis instead of overflowing the integer cast.

    Returns
    -------
    tuple[slice, ...]
        One slice per axis, suitable for NumPy indexing.

    Raises
    ------
    ValueError
        If ``tid`` has no bounding box in ``truth_stats``.
    """
    vs = np.asarray(voxel_size, dtype=float)
    if vs.size != ndim:
        # tolerate vs longer (e.g. includes channel), take last ndim
        vs = vs[-ndim:]

    truth_box = bbox_for_label(truth_stats, ndim, tid)
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if truth_box is None:
        raise ValueError(f"Truth ID {tid} not found in truth statistics.")
    pred_box = bbox_for_label(pred_stats, ndim, tid)
    if pred_box is None:
        pred_box = truth_box
    (tmins, tmaxs), (pmins, pmaxs) = truth_box, pred_box

    # Padding reach in voxels per axis; inf/nan for an uncapped distance or a
    # zero voxel size (casting those to int would yield INT_MIN and an empty ROI).
    with np.errstate(divide="ignore", invalid="ignore"):
        reach = max_distance / vs

    out_slices = []
    for d in range(ndim):
        extent = int(shape[d])
        r = float(reach[d])
        pad = int(np.ceil(r)) + 2 if np.isfinite(r) else extent
        # int() guards against unsigned NumPy scalars wrapping below zero.
        lo = min(int(tmins[d]), int(pmins[d]))
        hi = max(int(tmaxs[d]), int(pmaxs[d]))
        out_slices.append(slice(max(0, lo - pad), min(extent, hi + pad)))
    return tuple(out_slices)
def compute_hausdorff_distance_roi(
    truth_label: np.ndarray,
    truth_stats: StatisticsDict | StatisticsSlicesDict,
    pred_label: np.ndarray,
    pred_stats: StatisticsDict | StatisticsSlicesDict,
    tid: int,
    voxel_size,
    max_distance: float,
    method: str = "standard",
    percentile: float | None = None,
):
    """
    Hausdorff-type distance between ``truth_label == tid`` and ``pred_label == tid``.

    Masks and Euclidean distance transforms are built only inside the region
    returned by ``roi_slices_for_pair``, which contains both labels' bounding
    boxes. The whole volume is therefore never turned into a mask.

    Parameters
    ----------
    truth_label, pred_label : np.ndarray
        Full label volumes (same shape).
    truth_stats, pred_stats : dict
        cc3d statistics of the two volumes, used to locate the ROI.
    tid : int
        Label id compared in both volumes.
    voxel_size : Sequence[float]
        Physical voxel size; if it has more than ``truth_label.ndim`` entries,
        the last ``ndim`` are used as the EDT sampling.
    max_distance : float
        Cap on the returned distance; also returned when exactly one side is
        empty inside the ROI.
    method : {"standard", "modified", "percentile"}
        "standard"   -> max of the two directed maxima.
        "modified"   -> max of the two directed means.
        "percentile" -> max of the two directed ``percentile``-th percentiles.
    percentile : float | None
        Percentile (0-100); required when ``method == "percentile"``.

    Returns
    -------
    float
        0.0 if neither mask has voxels in the ROI, ``max_distance`` if exactly
        one is empty, otherwise the selected distance clipped to
        ``max_distance``.

    Raises
    ------
    ValueError
        For an unknown ``method``, or ``method == "percentile"`` without
        ``percentile``. Both are checked before any distance transform runs.
    """
    # Validate up front so bad arguments don't cost two full distance transforms.
    if method not in ("standard", "modified", "percentile"):
        raise ValueError("method must be one of {'standard', 'modified', 'percentile'}")
    if method == "percentile" and percentile is None:
        raise ValueError("'percentile' must be provided when method='percentile'")

    ndim = truth_label.ndim
    roi = roi_slices_for_pair(
        truth_stats,
        pred_stats,
        tid,
        voxel_size,
        ndim,
        truth_label.shape,
        max_distance,
    )
    a = truth_label[roi] == tid
    b = pred_label[roi] == tid

    has_a = bool(a.any())
    has_b = bool(b.any())
    if not (has_a or has_b):
        return 0.0
    if not (has_a and has_b):
        # The distance is undefined with one side empty; report the cap.
        return float(max_distance)

    vs = np.asarray(voxel_size, dtype=np.float64)
    if vs.size != ndim:
        vs = vs[-ndim:]

    # The EDT of a mask's complement gives, at every voxel, the physical
    # distance to the nearest voxel of that mask. Each full-ROI EDT is indexed
    # immediately, so only one is alive at a time.
    fwd = distance_transform_edt(~b, sampling=vs)[a]  # truth voxels -> nearest pred voxel
    bwd = distance_transform_edt(~a, sampling=vs)[b]  # pred voxels -> nearest truth voxel

    # Both directed arrays are non-empty here (has_a and has_b).
    if method == "standard":
        d = max(fwd.max(), bwd.max())
    elif method == "modified":
        d = max(fwd.mean(), bwd.mean())
    else:  # "percentile" (validated above)
        d = max(
            float(np.percentile(fwd, percentile)),
            float(np.percentile(bwd, percentile)),
        )
    return float(min(d, max_distance))