import argparse
import gc
import json
import os
import shutil
from time import time
import zipfile
import numpy as np
import zarr
from scipy.spatial.distance import dice
from scipy.ndimage import distance_transform_edt
from fastremap import remap, unique, renumber
import cc3d
from cc3d.types import StatisticsDict, StatisticsSlicesDict
from zarr.errors import PathNotFoundError
from sklearn.metrics import jaccard_score
from tqdm import tqdm
from upath import UPath
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, as_completed
from .config import SUBMISSION_PATH, TRUTH_PATH, INSTANCE_CLASSES
from .utils import TEST_CROPS_DICT, MatchedCrop, rand_voi, get_git_hash
import logging
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
force=True,
)
CAST_TO_NONE = [np.nan, np.inf, -np.inf, float("inf"), float("-inf")]
MAX_INSTANCE_THREADS = int(os.getenv("MAX_INSTANCE_THREADS", 3))
MAX_SEMANTIC_THREADS = int(os.getenv("MAX_SEMANTIC_THREADS", 25))
PER_INSTANCE_THREADS = int(os.getenv("PER_INSTANCE_THREADS", 25))
MAX_DISTANCE_CAP_EPS = float(os.getenv("MAX_DISTANCE_CAP_EPS", "1e-4"))
FINAL_INSTANCE_RATIO_CUTOFF = float(os.getenv("FINAL_INSTANCE_RATIO_CUTOFF", 10))
INITIAL_INSTANCE_RATIO_CUTOFF = float(os.getenv("INITIAL_INSTANCE_RATIO_CUTOFF", 50))
INSTANCE_RATIO_FACTOR = float(os.getenv("INSTANCE_RATIO_FACTOR", 5.0))
MAX_OVERLAP_EDGES = int(os.getenv("MAX_OVERLAP_EDGES", "5000000"))
def ratio_cutoff(
nG: int,
R_base: float = FINAL_INSTANCE_RATIO_CUTOFF,
R_extra: float = INITIAL_INSTANCE_RATIO_CUTOFF,
k: float = INSTANCE_RATIO_FACTOR,
) -> float:
# nG==0 handled upstream (ratio undefined); return max tolerance for completeness
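# Illustrative values with the default settings (R_base=10, R_extra=50, k=5):
# nG = 1  -> 10 + 50 * exp(-0.2) ≈ 50.9 (very lenient for tiny ground truths)
# nG = 20 -> 10 + 50 * exp(-4.0) ≈ 10.9 (close to the final cutoff of 10)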
return float(R_base + R_extra * np.exp(-nG / k))
def normalize_distance(distance: float, voxel_size) -> float:
"""
Map a distance to a score in (0, 1]: zero distance maps to 1 and larger distances decay
exponentially toward 0 at a rate set by the voxel diagonal, via 1.01 ** (-distance / ||voxel_size||).
Infinite distances map to 0.
"""
if distance == np.inf:
return 0.0
return float((1.01 ** (-distance / np.linalg.norm(voxel_size))))
def compute_default_max_distance(voxel_size, eps=MAX_DISTANCE_CAP_EPS) -> float:
v = np.linalg.norm(np.asarray(voxel_size, dtype=float))
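# Solve normalize_distance(d, voxel_size) == eps for d, i.e. 1.01 ** (-d / ||v||) = eps,
# giving d = ||v|| * ln(1 / eps) / ln(1.01); with eps = 1e-4 and an 8 nm isotropic voxel
# (||v|| ≈ 13.86 nm) this caps distances at roughly 1.28e4 nm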
return float(v * (np.log(1.0 / eps) / np.log(1.01)))
def match_instances(gt: np.ndarray, pred: np.ndarray) -> dict | None:
"""
Matches instances between GT and Pred by maximum-total-IoU assignment (min-cost flow).
Assumes IDs range from 1 to max(ID), with 0 as background. Non-sequential IDs (e.g., 1, 2, 5) are tolerated; missing IDs simply act as empty instances.
Returns a dictionary mapping pred IDs to gt IDs ({0: 0} if both volumes are empty, an empty dict if only one side has instances or nothing overlaps, and None if the prediction has too many instances or overlap edges to score).
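Example (illustrative):
gt = np.array([[1, 1, 0], [0, 2, 2]])
pred = np.array([[1, 1, 0], [0, 3, 3]])
match_instances(gt, pred)  # expected: {1: 1, 3: 2}, i.e. pred 1 -> gt 1 and pred 3 -> gt 2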
"""
if gt.shape != pred.shape:
raise ValueError("gt and pred must have the same shape")
# 1D views without copying if possible
g = np.ravel(gt)
p = np.ravel(pred)
# Number of instances (sequential ids -> max id)
nG = int(g.max()) if g.size else 0
nP = int(p.max()) if p.size else 0
# Early exits
if (nG == 0 and nP > 0) or (nP == 0 and nG > 0):
if nG == 0 and nP > 0:
logging.info("No GT instances; returning empty match.")
if nP == 0 and nG > 0:
logging.info("No Pred instances; returning empty match.")
return {}
elif nG == 0 and nP == 0:
logging.info("No GT or Pred instances; returning only background match.")
return {0: 0}
if (nP / nG) > ratio_cutoff(nG):
logging.warning(
f"WARNING: Skipping {nP} instances in submission, {nG} in ground truth, "
f"because there are too many instances in the submission."
)
return None
# Foreground (non-background) mask for each side and for pairwise overlaps
g_fg = g > 0
p_fg = p > 0
fg = g_fg & p_fg
# ---- Per-object areas (sizes) ----
# Cast label values to int64 so the bincounts and the later pair-key arithmetic are overflow-safe.
gt_sizes = np.bincount((g[g_fg].astype(np.int64) - 1), minlength=nG)[:, None]
pr_sizes = np.bincount((p[p_fg].astype(np.int64) - 1), minlength=nP)[None, :]
# ---- Intersections for observed pairs only (sparse counting) ----
gi = g[fg].astype(np.int64) - 1
pj = p[fg].astype(np.int64) - 1
if gi.size == 0:
# No overlaps anywhere -> IoU is all zeros
return {}
# Encode pairs to a single 64-bit key and count only present pairs
# Use unsigned to avoid negative-overflow corner cases.
gi_u = gi.astype(np.uint64)
pj_u = pj.astype(np.uint64)
key = gi_u * np.uint64(nP) + pj_u
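# e.g. with nP = 4, the pair (gi=2, pj=3) encodes to key = 2 * 4 + 3 = 11,
# which decodes back via 11 // 4 = 2 and 11 % 4 = 3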
uniq_keys, counts = np.unique(key, return_counts=True)
if uniq_keys.size > MAX_OVERLAP_EDGES:
logging.warning(
f"WARNING: Too many overlap edges ({uniq_keys.size}) — skipping instance scoring."
)
return None
rows = (uniq_keys // np.uint64(nP)).astype(np.int64)
cols = (uniq_keys % np.uint64(nP)).astype(np.int64)
# ---- IoU only for observed pairs, then scatter into dense matrix ----
inter = counts.astype(np.int64)
union = gt_sizes[rows, 0] + pr_sizes[0, cols] - inter
with np.errstate(divide="ignore", invalid="ignore"):
iou_vals = (inter / union).astype(np.float32)
# Keep only IoU > 0 edges (should already be true, but explicit)
keep = iou_vals > 0.0
rows = rows[keep]
cols = cols[keep]
iou_vals = iou_vals[keep]
if rows.size == 0:
return {}
# ---------------- OR-Tools Min-Cost Flow ----------------
from ortools.graph.python import min_cost_flow
mcf = min_cost_flow.SimpleMinCostFlow()
# Node indexing:
# source(0), GT nodes [1..nG], Pred nodes [1+nG .. nG+nP], sink(last)
source = 0
gt0 = 1
pred0 = gt0 + nG
sink = pred0 + nP
COST_SCALE = int(os.getenv("MCMF_COST_SCALE", "1000000"))
UNMATCH_COST = COST_SCALE + 1 # worse than any match in [0, COST_SCALE]
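# Assignment formulation: maximizing total matched IoU is equivalent to minimizing the total
# (1 - IoU) cost. Each GT node receives one unit of flow from the source and either sends it to
# an overlapping Pred node (cost proportional to 1 - IoU) or straight to the sink via the
# 'unmatched' arc, whose cost exceeds that of any real match.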
# ---- Build arcs into numpy arrays (dtype matters) ----
tails = []
heads = []
caps = []
costs = []
def add_arc(u: int, v: int, cap: int, cost: int) -> None:
tails.append(u)
heads.append(v)
caps.append(cap)
costs.append(cost)
# Source -> GT (each GT emits 1 unit)
for i in range(nG):
add_arc(source, gt0 + i, 1, 0)
# GT -> Sink "unmatched" option
for i in range(nG):
add_arc(gt0 + i, sink, 1, UNMATCH_COST)
# GT -> Pred edges for IoU>0
# (rows, cols are 0-based instance indices; +1 happens later when making label ids)
for r, c, iou in zip(rows, cols, iou_vals):
u = gt0 + int(r)
v = pred0 + int(c)
cost = int((1.0 - float(iou)) * COST_SCALE)
add_arc(u, v, 1, cost)
# Pred -> Sink capacity 1
for j in range(nP):
add_arc(pred0 + j, sink, 1, 0)
# Bulk add (use correct dtypes per API)
mcf.add_arcs_with_capacity_and_unit_cost(
np.asarray(tails, dtype=np.int32),
np.asarray(heads, dtype=np.int32),
np.asarray(caps, dtype=np.int64),
np.asarray(costs, dtype=np.int64),
)
# Supplies: push nG units from source to sink
mcf.set_node_supply(source, int(nG))
mcf.set_node_supply(sink, -int(nG))
status = mcf.solve()
if status != mcf.OPTIMAL:
logging.warning(f"Min-cost flow did not solve optimally (status={status}).")
return {}
# Extract matches: arcs GT->Pred with flow==1
mapping: dict[int, int] = {}
for a in range(mcf.num_arcs()):
if mcf.flow(a) != 1:
continue
u = mcf.tail(a)
v = mcf.head(a)
if gt0 <= u < pred0 and pred0 <= v < sink:
gt_id = (u - gt0) + 1 # back to label IDs (1-based, 0 is background)
pred_id = (v - pred0) + 1
mapping[pred_id] = gt_id
return mapping
def optimized_hausdorff_distances(
truth_label,
pred_label,
voxel_size,
hausdorff_distance_max,
method="standard",
percentile: float | None = None,
):
"""
Compute per-truth-instance Hausdorff-like distances against the (already remapped)
prediction using multithreading. Returns a 1D float32 numpy array whose i-th
entry corresponds to truth_ids[i].
Parameters
----------
truth_label : np.ndarray
Ground-truth instance label volume (0 == background).
pred_label : np.ndarray
Prediction instance label volume that has already been remapped to align
with the GT ids (0 == background).
voxel_size : Sequence[float]
Physical voxel sizes in Z, Y, X (or Y, X) order.
hausdorff_distance_max : float
Cap for distances (use np.inf for uncapped).
method : {"standard", "modified", "percentile"}
"standard" -> classic Hausdorff (max of directed maxima)
"modified" -> mean of directed distances, then max of the two means
"percentile" -> use the given percentile of directed distances (requires
`percentile` to be provided).
percentile : float | None
Percentile (0-100) used when method=="percentile".
"""
# Unique GT ids (exclude background = 0)
truth_ids = unique(truth_label)
truth_ids = truth_ids[truth_ids != 0]
true_num = int(truth_ids.size)
if true_num == 0:
return np.empty((0,), dtype=np.float32)
voxel_size = np.asarray(voxel_size, dtype=np.float64)
truth_stats = cc3d.statistics(truth_label)
pred_stats = cc3d.statistics(pred_label)
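# cc3d.statistics returns per-label summaries (voxel counts, bounding boxes, centroids);
# only the bounding boxes are used here, to restrict each distance computation to a small ROI.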
def get_distance(i: int):
tid = int(truth_ids[i])
h_dist = compute_hausdorff_distance_roi(
truth_label,
truth_stats,
pred_label,
pred_stats,
tid,
voxel_size,
hausdorff_distance_max,
method=method,
percentile=percentile,
)
return i, float(h_dist)
dists = np.empty((true_num,), dtype=np.float32)
with ThreadPoolExecutor(max_workers=PER_INSTANCE_THREADS) as executor:
for idx, h in tqdm(
executor.map(get_distance, range(true_num)),
desc="Computing Hausdorff distances",
total=true_num,
dynamic_ncols=True,
):
dists[idx] = h
return dists
def bbox_for_label(
stats: StatisticsDict | StatisticsSlicesDict,
ndim: int,
label_id: int,
):
"""
Get the bounding box of `label_id` from cc3d statistics without allocating a full boolean mask.
Handles both the slice-based ("bounding_boxes" list) and per-label dict variants of the statistics output.
Returns (mins, maxs) with inclusive starts and exclusive stops in voxel indices, or None if the label is missing.
"""
# stats = cc3d.statistics(label_vol)
# cc3d.statistics usually returns dict-like with keys per label id.
# There are multiple API variants; try common patterns.
if "bounding_boxes" in stats:
# bounding_boxes is a list where index corresponds to label_id
bounding_boxes = stats["bounding_boxes"]
if label_id >= len(bounding_boxes):
return None
bb = bounding_boxes[label_id]
if bb is None:
return None
# bb is a tuple of slices, convert to (mins, maxs)
if isinstance(bb, tuple) and all(isinstance(s, slice) for s in bb):
mins = [s.start for s in bb]
maxs = [s.stop for s in bb]
return mins, maxs
# bb might be (z0,z1,y0,y1,x0,x1) with end exclusive
mins = [bb[2 * k] for k in range(ndim)]
maxs = [bb[2 * k + 1] for k in range(ndim)]
return mins, maxs
if label_id in stats:
s = stats[label_id]
if "bounding_box" in s:
bb = s["bounding_box"]
mins = [bb[2 * k] for k in range(ndim)]
maxs = [bb[2 * k + 1] for k in range(ndim)]
return mins, maxs
def roi_slices_for_pair(
truth_stats: StatisticsDict | StatisticsSlicesDict,
pred_stats: StatisticsDict | StatisticsSlicesDict,
tid: int,
voxel_size,
ndim: int,
shape: tuple[int, ...],
max_distance: float,
):
"""
ROI = union(bbox(truth==tid), bbox(pred==tid)), padded on each axis by a margin derived from max_distance.
Returns tuple of slices suitable for numpy indexing.
"""
vs = np.asarray(voxel_size, dtype=float)
if vs.size != ndim:
# tolerate vs longer (e.g. includes channel), take last ndim
vs = vs[-ndim:]
# padding per axis in voxels
pad = np.ceil(max_distance / vs).astype(int) + 2
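# Per-axis margin (in voxels) large enough that any voxel within max_distance of the
# union bounding box still falls inside the ROI, plus 2 voxels of slack.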
tb = bbox_for_label(truth_stats, ndim, tid)
assert tb is not None, f"Truth ID {tid} not found in truth statistics."
tmins, tmaxs = tb
pb = bbox_for_label(pred_stats, ndim, tid)
if pb is None:
pmins, pmaxs = tmins, tmaxs
else:
pmins, pmaxs = pb
mins = [min(tmins[d], pmins[d]) for d in range(ndim)]
maxs = [max(tmaxs[d], pmaxs[d]) for d in range(ndim)]
# expand and clamp
out_slices = []
for d in range(ndim):
a = max(0, mins[d] - int(pad[d]))
b = min(shape[d], maxs[d] + int(pad[d]))
out_slices.append(slice(a, b))
return tuple(out_slices)
def compute_hausdorff_distance_roi(
truth_label: np.ndarray,
truth_stats: StatisticsDict | StatisticsSlicesDict,
pred_label: np.ndarray,
pred_stats: StatisticsDict | StatisticsSlicesDict,
tid: int,
voxel_size,
max_distance: float,
method: str = "standard",
percentile: float | None = None,
):
"""
Same metric as compute_hausdorff_distance(), but operates on an ROI slice
and builds masks only inside ROI.
"""
ndim = truth_label.ndim
roi = roi_slices_for_pair(
truth_stats,
pred_stats,
tid,
voxel_size,
ndim,
truth_label.shape,
max_distance,
)
t_roi = truth_label[roi]
p_roi = pred_label[roi]
a = t_roi == tid
b = p_roi == tid
a_n = int(a.sum())
b_n = int(b.sum())
if a_n == 0 and b_n == 0:
return 0.0
elif a_n == 0 or b_n == 0:
return max_distance
vs = np.asarray(voxel_size, dtype=np.float64)
if vs.size != ndim:
vs = vs[-ndim:]
dist_to_b = distance_transform_edt(~b, sampling=vs)
dist_to_a = distance_transform_edt(~a, sampling=vs)
fwd = dist_to_b[a]
bwd = dist_to_a[b]
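# fwd: distance from each truth voxel to the nearest predicted voxel (directed GT -> Pred)
# bwd: distance from each predicted voxel to the nearest truth voxel (directed Pred -> GT)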
if method == "standard":
d = max(fwd.max(initial=0.0), bwd.max(initial=0.0))
elif method == "modified":
d = max(fwd.mean() if fwd.size else 0.0, bwd.mean() if bwd.size else 0.0)
elif method == "percentile":
if percentile is None:
raise ValueError("'percentile' must be provided when method='percentile'")
d = max(
float(np.percentile(fwd, percentile)) if fwd.size else 0.0,
float(np.percentile(bwd, percentile)) if bwd.size else 0.0,
)
else:
raise ValueError("method must be one of {'standard', 'modified', 'percentile'}")
return float(min(d, max_distance))
def score_instance(
pred_label,
truth_label,
voxel_size,
hausdorff_distance_max=None,
) -> dict[str, float | str]:
"""
Score a single instance label volume against the ground truth instance label volume.
Args:
pred_label (np.ndarray): The predicted instance label volume.
truth_label (np.ndarray): The ground truth instance label volume.
voxel_size (tuple): The size of a voxel in each dimension.
hausdorff_distance_max (float): The maximum distance to consider for the Hausdorff distance.
Returns:
dict: A dictionary of scores for the instance label volume.
Example usage:
scores = score_instance(pred_label, truth_label, voxel_size)
"""
logging.info("Scoring instance segmentation...")
if hausdorff_distance_max is None:
hausdorff_distance_max = compute_default_max_distance(voxel_size)
logging.debug(
f"Using default maximum Hausdorff distance of {hausdorff_distance_max:.2f} for voxel size {voxel_size}."
)
# Relabel the predicted instances to sequential IDs via connected components;
# alignment with the ground-truth IDs happens later via match_instances + remap
logging.info("Relabeling predicted instance labels...")
pred_label, n_pred = cc3d.connected_components(pred_label, return_N=True)
# pred_label, remapping = renumber(pred_label, in_place=True)
# n_pred = len(remapping) - 1 # exclude background
# Get stats that don't require matched instances
truth_binary = (truth_label > 0).ravel()
pred_binary = (pred_label > 0).ravel()
iou = jaccard_score(truth_binary, pred_binary, zero_division=1)
dice_score = 1 - dice(truth_binary, pred_binary)
binary_accuracy = float((truth_binary == pred_binary).mean())
voi = rand_voi(truth_label.astype(np.uint64), pred_label.astype(np.uint64))
del voi["voi_split_i"], voi["voi_merge_j"]
# Match instances between ground truth and prediction
mapping = match_instances(truth_label, pred_label)
if mapping is None:
# Too many instances in submission, skip scoring
return {
"accuracy": 0,
"binary_accuracy": binary_accuracy,
"hausdorff_distance": hausdorff_distance_max,
"normalized_hausdorff_distance": normalize_distance(
hausdorff_distance_max, voxel_size
),
"combined_score": 0,
"iou": iou,
"dice_score": dice_score,
"status": "skipped_too_many_instances",
**voi,
}
elif len(mapping) == 1 and 0 in mapping:
# Only background present in both ground truth and prediction
hausdorff_distances = [0.0]
elif len(mapping) > 0:
# Construct the volume for the matched instances
mapping[0] = 0 # background maps to background
pred_label = remap(
pred_label, mapping, in_place=True, preserve_missing_labels=True
)
hausdorff_distances = optimized_hausdorff_distances(
truth_label, pred_label, voxel_size, hausdorff_distance_max
)
matched_pred_ids = set(mapping.keys()) - {0}
pred_ids = set(np.arange(1, n_pred + 1)) - {0}
unmatched_pred = pred_ids - matched_pred_ids
if len(unmatched_pred) > 0:
hausdorff_distances = np.concatenate(
[
hausdorff_distances,
np.full(
len(unmatched_pred), hausdorff_distance_max, dtype=np.float32
),
]
)
else:
# Empty mapping: one side has no instances, or nothing overlapped/matched
hausdorff_distances = [hausdorff_distance_max]
if len(hausdorff_distances) == 0:
hausdorff_distances = [hausdorff_distance_max]
# Compute the scores
logging.info("Computing accuracy score...")
accuracy = float((truth_label == pred_label).mean())
hausdorff_dist = np.mean(hausdorff_distances)
normalized_hausdorff_dist = np.mean(
[normalize_distance(hd, voxel_size) for hd in hausdorff_distances]
)
combined_score = (accuracy * normalized_hausdorff_dist) ** 0.5 # geometric mean
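# e.g. accuracy 0.81 and mean normalized Hausdorff 0.64 give sqrt(0.81 * 0.64) = 0.72;
# because this is a geometric mean, the combined score is 0 if either term is 0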
logging.info(f"Accuracy: {accuracy:.4f}")
logging.info(f"Hausdorff Distance: {hausdorff_dist:.4f}")
logging.info(f"Normalized Hausdorff Distance: {normalized_hausdorff_dist:.4f}")
logging.info(f"Combined Score: {combined_score:.4f}")
return {
"accuracy": accuracy,
"hausdorff_distance": hausdorff_dist,
"normalized_hausdorff_distance": normalized_hausdorff_dist,
"combined_score": combined_score,
"status": "scored",
"iou": iou,
"dice_score": dice_score,
"binary_accuracy": binary_accuracy,
**voi,
} # type: ignore
def score_semantic(pred_label, truth_label) -> dict[str, float]:
"""
Score a single semantic label volume against the ground truth semantic label volume.
Args:
pred_label (np.ndarray): The predicted semantic label volume.
truth_label (np.ndarray): The ground truth semantic label volume.
Returns:
dict: A dictionary of scores for the semantic label volume.
Example usage:
scores = score_semantic(pred_label, truth_label)
"""
logging.info("Scoring semantic segmentation...")
# Flatten the label volumes and convert to binary
pred_label = (pred_label > 0.0).ravel()
truth_label = (truth_label > 0.0).ravel()
# Compute the scores
if np.sum(truth_label + pred_label) == 0:
# If neither the truth nor the prediction contains any foreground, define the scores as 1
logging.debug("No foreground found in truth or prediction. Setting scores to 1.")
dice_score = 1
iou_score = 1
else:
dice_score = 1 - dice(truth_label, pred_label)
iou_score = jaccard_score(truth_label, pred_label, zero_division=1)
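# For binary masks Dice and IoU are related by Dice = 2 * IoU / (1 + IoU),
# e.g. IoU 0.5 corresponds to Dice ≈ 0.667; both are reported for convenience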
scores = {
"iou": iou_score,
"dice_score": dice_score if not np.isnan(dice_score) else 1,
"binary_accuracy": float((truth_label == pred_label).mean()),
"status": "scored",
}
logging.info(f"IoU: {scores['iou']:.4f}")
logging.info(f"Dice Score: {scores['dice_score']:.4f}")
return scores
def score_label(
pred_label_path,
label_name,
crop_name,
truth_path=TRUTH_PATH,
instance_classes=INSTANCE_CLASSES,
):
"""
Score a single label volume against the ground truth label volume.
Args:
pred_label_path (str | None): The path to the predicted label volume, or None if the label is missing from the submission.
label_name (str): The name of the label class being scored.
crop_name (str): The name of the ground truth crop (volume) being scored.
truth_path (str): The path to the ground truth label volume.
instance_classes (list): A list of instance classes.
Returns:
dict: A dictionary of scores for the label volume.
Example usage:
crop_name, label_name, scores = score_label('pred.zarr/test_volume/label1', 'label1', 'test_volume')
"""
if pred_label_path is None:
logging.info(f"Label {label_name} not found in submission volume {crop_name}.")
return (
crop_name,
label_name,
empty_label_score(
label=label_name,
crop_name=crop_name,
instance_classes=instance_classes,
truth_path=truth_path,
),
)
logging.info(f"Scoring {pred_label_path}...")
truth_path = UPath(truth_path)
# Load the predicted and ground truth label volumes
truth_label_path = (truth_path / crop_name / label_name).path
truth_label_ds = zarr.open(truth_label_path, mode="r")
truth_label = truth_label_ds[:]
crop = TEST_CROPS_DICT[int(crop_name.removeprefix("crop")), label_name]
pred_label = match_crop_space(
pred_label_path,
label_name,
crop.voxel_size,
crop.shape,
crop.translation,
)
mask_path = truth_path / crop_name / f"{label_name}_mask"
if mask_path.exists():
# Mask out uncertain regions resulting from low-res ground truth annotations
logging.info(f"Masking {label_name} with {mask_path}...")
mask = zarr.open(mask_path.path, mode="r")[:]
pred_label = pred_label * mask
truth_label = truth_label * mask
del mask
# Compute the scores
if label_name in instance_classes:
logging.info(
f"Starting an instance evaluation for {label_name} in {crop_name}..."
)
timer = time()
results = score_instance(pred_label, truth_label, crop.voxel_size)
logging.info(
f"Finished instance evaluation for {label_name} in {crop_name} in {time() - timer:.2f} seconds..."
)
else:
results = score_semantic(pred_label, truth_label)
results["num_voxels"] = int(np.prod(truth_label.shape))
results["voxel_size"] = crop.voxel_size
results["is_missing"] = False
# drop big arrays before returning
del truth_label, pred_label, truth_label_ds
gc.collect()
return crop_name, label_name, results
def empty_label_score(
label, crop_name, instance_classes=INSTANCE_CLASSES, truth_path=TRUTH_PATH
):
truth_path = UPath(truth_path)
ds = zarr.open((truth_path / crop_name / label).path, mode="r")
voxel_size = ds.attrs["voxel_size"]
if label in instance_classes:
return {
"accuracy": 0,
"hausdorff_distance": compute_default_max_distance(voxel_size),
"normalized_hausdorff_distance": 0,
"combined_score": 0,
"num_voxels": int(np.prod(ds.shape)),
"voxel_size": voxel_size,
"is_missing": True,
"status": "missing",
}
else:
return {
"iou": 0,
"dice_score": 0,
"num_voxels": int(np.prod(ds.shape)),
"voxel_size": voxel_size,
"is_missing": True,
"status": "missing",
}
def ensure_zgroup(path: UPath) -> zarr.Group:
"""
Ensure that the given path can be opened as a zarr Group. If a .zgroup is not present, add it.
"""
try:
return zarr.open(path.path, mode="r")
except PathNotFoundError:
if not path.is_dir():
raise ValueError(f"Path {path} is not a directory.")
# Add a .zgroup file to force Zarr-2 format
(path / ".zgroup").write_text('{"zarr_format": 2}')
return zarr.open(path.path, mode="r")
def get_evaluation_args(
volumes,
submission_path,
truth_path=TRUTH_PATH,
instance_classes=INSTANCE_CLASSES,
) -> list[tuple]:
"""
Get the arguments for scoring each label in the submission.
Args:
volumes (list): A list of volumes to score.
submission_path (str): The path to the submission volume.
truth_path (str): The path to the ground truth volume.
instance_classes (list): A list of instance classes.
Returns:
A list of tuples containing the arguments for each label to be scored.
"""
if not isinstance(volumes, (tuple, list)):
volumes = [volumes]
score_label_arglist = []
for volume in volumes:
submission_path = UPath(submission_path)
pred_volume_path = submission_path / volume
logging.info(f"Scoring {pred_volume_path}...")
truth_path = UPath(truth_path)
# Find labels to score
pred_labels = [a for a in ensure_zgroup(pred_volume_path).array_keys()]
crop_name = pred_volume_path.name
truth_labels = [a for a in ensure_zgroup(truth_path / crop_name).array_keys()]
found_labels = list(set(pred_labels) & set(truth_labels))
missing_labels = list(set(truth_labels) - set(pred_labels))
# Score_label arguments for each label
score_label_arglist.extend(
[
(
pred_volume_path / label if label in found_labels else None,
label,
crop_name,
truth_path,
instance_classes,
)
for label in truth_labels
]
)
logging.info(f"Missing labels: {missing_labels}")
return score_label_arglist
def missing_volume_score(
truth_path, volume, instance_classes=INSTANCE_CLASSES
) -> dict:
"""
Score a missing volume as 0's, congruent with the score_volume function.
Args:
truth_path (str): The path to the ground truth volume.
volume (str): The name of the volume.
instance_classes (list): A list of instance classes.
Returns:
dict: A dictionary of scores for the volume.
Example usage:
scores = missing_volume_score('truth.zarr', 'test_volume')
"""
logging.info(f"Scoring missing volume {volume}...")
truth_path = UPath(truth_path)
truth_volume_path = truth_path / volume
# Find labels to score
truth_labels = [a for a in ensure_zgroup(truth_volume_path).array_keys()]
# Score each label
scores = {
label: empty_label_score(label, volume, instance_classes, truth_path)
for label in truth_labels
}
return scores
def combine_scores(
scores,
include_missing=True,
instance_classes=INSTANCE_CLASSES,
cast_to_none=CAST_TO_NONE,
):
"""
Combine scores across volumes, normalizing by the number of voxels.
Args:
scores (dict): A dictionary of scores for each volume, as returned by `score_volume`.
include_missing (bool): Whether to include missing volumes in the combined scores.
instance_classes (list): A list of instance classes.
cast_to_none (list): A list of values to cast to None in the combined scores.
Returns:
dict: A dictionary of combined scores across all volumes.
Example usage:
combined_scores = combine_scores(scores)
"""
# Combine label scores across volumes, normalizing by the number of voxels
logging.info(f"Combining label scores...")
scores = scores.copy()
label_scores = {}
total_voxels = {}
for ds, these_scores in scores.items():
for label, this_score in these_scores.items():
# logging.info(this_score)
if this_score["is_missing"] and not include_missing:
continue
if label in instance_classes:
if label not in label_scores:
label_scores[label] = {
"accuracy": 0,
"hausdorff_distance": 0,
"normalized_hausdorff_distance": 0,
"combined_score": 0,
}
total_voxels[label] = 0
else:
if label not in label_scores:
label_scores[label] = {"iou": 0, "dice_score": 0}
total_voxels[label] = 0
for key in label_scores[label].keys():
if this_score[key] is None:
continue
label_scores[label][key] += this_score[key] * this_score["num_voxels"]
if this_score[key] in cast_to_none:
scores[ds][label][key] = None
total_voxels[label] += this_score["num_voxels"]
# Normalize back to the total number of voxels
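# Each accumulated entry holds sum(score * num_voxels); dividing by the voxel total gives a
# voxel-weighted mean, e.g. IoU 0.8 over 1e6 voxels and 0.5 over 3e6 voxels combine to
# (0.8e6 + 1.5e6) / 4e6 = 0.575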
for label in label_scores:
if label in instance_classes:
label_scores[label]["accuracy"] /= total_voxels[label]
label_scores[label]["hausdorff_distance"] /= total_voxels[label]
label_scores[label]["normalized_hausdorff_distance"] /= total_voxels[label]
label_scores[label]["combined_score"] /= total_voxels[label]
else:
label_scores[label]["iou"] /= total_voxels[label]
label_scores[label]["dice_score"] /= total_voxels[label]
# Cast to None if the value is in `cast_to_none`
for key in label_scores[label]:
if label_scores[label][key] in cast_to_none:
label_scores[label][key] = None
scores["label_scores"] = label_scores
# Compute the overall score
logging.info("Computing overall scores...")
overall_instance_scores = []
overall_semantic_scores = []
instance_total_voxels = sum(
total_voxels[label] for label in label_scores if label in instance_classes
)
semantic_total_voxels = sum(
total_voxels[label] for label in label_scores if label not in instance_classes
)
for label in label_scores:
if label in instance_classes:
overall_instance_scores += [
label_scores[label]["combined_score"] * total_voxels[label]
]
else:
overall_semantic_scores += [
label_scores[label]["iou"] * total_voxels[label]
]
scores["overall_instance_score"] = (
np.nansum(overall_instance_scores) / instance_total_voxels
if overall_instance_scores
else 0
)
scores["overall_semantic_score"] = (
np.nansum(overall_semantic_scores) / semantic_total_voxels
if overall_semantic_scores
else 0
)
scores["overall_score"] = (
scores["overall_instance_score"] * scores["overall_semantic_score"]
) ** 0.5 # geometric mean
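# Note: with a geometric mean, the overall score is 0 whenever either the instance or the
# semantic score is 0, so submitting both task types matters.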
return scores
def ensure_valid_submission(submission_path: UPath):
"""
Ensure that the unzipped submission path is a valid Zarr-2 file.
Args:
submission_path (str): The path to the unzipped submission Zarr-2 file.
Raises:
ValueError: If the submission is not a valid unzipped Zarr-2 file.
"""
# See if a Zarr was incorrectly zipped inside other folder(s)
# If so, move contents from .zarr folder to submission_path and warn user
zarr_folders = list(submission_path.glob("**/*.zarr"))
if len(zarr_folders) == 0:
# Try forcing Zarr-2 format by adding .zgroup if missing
try:
ensure_zgroup(submission_path)
logging.warning(
f"Submission at {submission_path} did not contain a .zgroup file. Added one to force Zarr-2 format."
)
except Exception as e:
raise ValueError(
f"Submission at {submission_path} is not a valid unzipped Zarr-2 file."
) from e
elif len(zarr_folders) == 1:
zarr_folder = zarr_folders[0]
logging.warning(
f"Submission at {submission_path} contains a Zarr folder inside subfolder(s) at {zarr_folder}. Moving contents to the root submission folder."
)
# Move contents of zarr_folder to submission_path
for item in zarr_folder.iterdir():
target = submission_path / item.name
if target.exists():
if target.is_file():
target.unlink()
else:
shutil.rmtree(target)
shutil.move(str(item), str(submission_path))
# Remove empty folders
for parent in zarr_folder.parents:
if parent == submission_path:
break
try:
parent.rmdir()
except OSError as e:
logging.warning(
"Failed to remove directory %s while cleaning nested Zarr submission: %s",
parent,
e,
)
# Try opening again
try:
ensure_zgroup(submission_path)
logging.warning(
f"Submission at {submission_path} did not contain a .zgroup file. Added one to force Zarr-2 format."
)
except Exception as e:
raise ValueError(
f"Submission at {submission_path} is not a valid unzipped Zarr-2 file."
) from e
elif len(zarr_folders) > 1:
raise ValueError(
f"Submission at {submission_path} contains multiple Zarr folders. Please ensure only one Zarr-2 file is submitted."
)
def score_submission(
submission_path=UPath(SUBMISSION_PATH).with_suffix(".zip").path,
result_file=None,
truth_path=TRUTH_PATH,
instance_classes=INSTANCE_CLASSES,
):
"""
Score a submission against the ground truth data.
Args:
submission_path (str): The path to the zipped submission Zarr-2 file.
result_file (str): The path to save the scores. Default is 'results.json'.
truth_path (str): The path to the ground truth Zarr-2 file.
instance_classes (list): A list of instance classes.
Returns:
dict: A dictionary of scores for the submission.
Example usage:
scores = score_submission('submission.zip')
The results json is a dictionary with the following structure:
{
"volume" (the name of the ground truth volume): {
"label" (the name of the predicted class): {
(For semantic segmentation)
"iou": (the intersection over union score),
"dice_score": (the dice score),
OR
(For instance segmentation)
"accuracy": (the accuracy score),
"hausdorff_distance": (the hausdorff distance),
"normalized_hausdorff_distance": (the normalized hausdorff distance),
"combined_score": (the geometric mean of the accuracy and normalized hausdorff distance),
"num_voxels": (the number of voxels in the ground truth volume),
}
},
"label_scores": {
(the name of the predicted class): {
(For semantic segmentation)
"iou": (the mean intersection over union score),
"dice_score": (the mean dice score),
OR
(For instance segmentation)
"accuracy": (the mean accuracy score),
"hausdorff_distance": (the mean hausdorff distance),
"combined_score": (the mean geometric mean of the accuracy and normalized hausdorff distance),
}
},
"overall_instance_score": (the voxel-weighted mean of the instance combined scores),
"overall_semantic_score": (the voxel-weighted mean of the semantic iou scores),
"overall_score": (the geometric mean of the overall instance and semantic scores),
}
"""
logging.info(f"Scoring {submission_path}...")
start_time = time()
# Unzip the submission
submission_path = unzip_file(submission_path)
# Find volumes to score
logging.info(f"Scoring volumes in {submission_path}...")
ensure_valid_submission(UPath(submission_path))
pred_volumes = [d.name for d in UPath(submission_path).glob("*") if d.is_dir()]
truth_path = UPath(truth_path)
logging.info(f"Volumes: {pred_volumes}")
logging.info(f"Truth path: {truth_path}")
truth_volumes = [d.name for d in truth_path.glob("*") if d.is_dir()]
logging.info(f"Truth volumes: {truth_volumes}")
found_volumes = list(set(pred_volumes) & set(truth_volumes))
missing_volumes = list(set(truth_volumes) - set(pred_volumes))
if len(found_volumes) == 0:
# Check if "crop" prefixes are missing
prefixed_pred_volumes = [f"crop{v}" for v in pred_volumes]
found_volumes = list(set(prefixed_pred_volumes) & set(truth_volumes))
if len(found_volumes) == 0:
raise ValueError(
"No volumes found to score. Make sure the submission is formatted correctly."
)
missing_volumes = list(set(truth_volumes) - set(prefixed_pred_volumes))
# Move predicted volumes to have "crop" prefix
for v in pred_volumes:
old_path = UPath(submission_path) / v
new_path = UPath(submission_path) / f"crop{v}"
try:
shutil.move(str(old_path), str(new_path))
except Exception as exc:
msg = (
f"Failed to rename predicted volume directory '{old_path}' to "
f"'{new_path}'. This may be due to missing files, insufficient "
"permissions, or an existing destination directory. Cannot "
"continue evaluation."
)
logging.error(msg)
raise RuntimeError(msg) from exc
logging.info(f"Scoring volumes: {found_volumes}")
if len(missing_volumes) > 0:
logging.info(f"Missing volumes: {missing_volumes}")
logging.info("Scoring missing volumes as 0's")
scores = {
volume: missing_volume_score(
truth_path, volume, instance_classes=instance_classes
)
for volume in missing_volumes
}
# Get all prediction paths to evaluate
evaluation_args = get_evaluation_args(
found_volumes,
submission_path=UPath(submission_path),
truth_path=truth_path,
instance_classes=instance_classes,
)
# Score each volume
logging.info(
f"Scoring volumes in parallel, using {MAX_INSTANCE_THREADS} instance threads and {MAX_SEMANTIC_THREADS} semantic threads..."
)
instance_pool = ProcessPoolExecutor(MAX_INSTANCE_THREADS)
semantic_pool = ProcessPoolExecutor(MAX_SEMANTIC_THREADS)
futures = []
for args in evaluation_args:
if args[1] in instance_classes:
futures.append(instance_pool.submit(score_label, *args))
else:
futures.append(semantic_pool.submit(score_label, *args))
results = []
for future in tqdm(
as_completed(futures),
desc="Scoring volumes",
total=len(futures),
dynamic_ncols=True,
leave=True,
):
results.append(future.result())
all_scores, found_scores = update_scores(
scores, results, result_file, instance_classes=instance_classes
)
logging.info("Scores combined across all test volumes:")
logging.info(
f"\tOverall Instance Score: {all_scores['overall_instance_score']:.4f}"
)
logging.info(
f"\tOverall Semantic Score: {all_scores['overall_semantic_score']:.4f}"
)
logging.info(f"\tOverall Score: {all_scores['overall_score']:.4f}")
logging.info("Scores combined across test volumes with data submitted:")
logging.info(
f"\tOverall Instance Score: {found_scores['overall_instance_score']:.4f}"
)
logging.info(
f"\tOverall Semantic Score: {found_scores['overall_semantic_score']:.4f}"
)
logging.info(f"\tOverall Score: {found_scores['overall_score']:.4f}")
logging.info(f"Submission scored in {time() - start_time:.2f} seconds")
if result_file is None:
logging.info("Final combined scores:")
logging.info(all_scores)
return all_scores
else:
logging.info("Evaluation successful.")
def num_evals_done(all_scores):
num_evals_done = 0
for volume, scores in all_scores.items():
if "crop" in volume:
num_evals_done += len(scores.keys())
return num_evals_done
def sanitize_scores(scores):
"""
Sanitize scores by converting NaN values to None.
Args:
scores (dict): A dictionary of scores.
Returns:
dict: A sanitized dictionary of scores.
"""
for volume, volume_scores in scores.items():
if isinstance(volume_scores, dict):
for label, label_scores in volume_scores.items():
if isinstance(label_scores, dict):
for key, value in label_scores.items():
if value is None:
continue
if isinstance(value, str):
continue
if not np.isscalar(value) and len(value) == 1:
value = value[0]
if np.isscalar(value):
if np.isnan(value) or np.isinf(value) or np.isneginf(value):
scores[volume][label][key] = None
elif isinstance(value, np.floating):
scores[volume][label][key] = float(value)
else:
if any(
[
np.isnan(v) or np.isinf(v) or np.isneginf(v)
for v in value
]
):
scores[volume][label][key] = None
elif any([isinstance(v, np.floating) for v in value]):
scores[volume][label][key] = [float(v) for v in value]
return scores
def update_scores(scores, results, result_file, instance_classes=INSTANCE_CLASSES):
start_time = time()
logging.info(f"Updating scores in {result_file}...")
# Check the types of the inputs
assert isinstance(scores, dict)
assert isinstance(results, list)
# Combine the results into a dictionary
# TODO: This is technically inefficient, but it works for now
for crop_name, label_name, result in results:
if crop_name not in scores:
scores[crop_name] = {}
scores[crop_name][label_name] = result
# Combine label scores across volumes, normalizing by the number of voxels
all_scores = combine_scores(
scores, include_missing=True, instance_classes=instance_classes
)
all_scores["total_evals"] = len(TEST_CROPS_DICT)
all_scores["num_evals_done"] = num_evals_done(all_scores)
all_scores["git_version"] = get_git_hash()
found_scores = combine_scores(
scores, include_missing=False, instance_classes=instance_classes
)
if result_file is not None:
logging.info(f"Saving collected scores to {result_file}...")
with open(result_file, "w") as f:
json.dump(sanitize_scores(all_scores), f, indent=4)
found_result_file = str(result_file).replace(
UPath(result_file).suffix, "_submitted_only" + UPath(result_file).suffix
)
with open(found_result_file, "w") as f:
json.dump(sanitize_scores(found_scores), f, indent=4)
logging.info(
f"Scores updated in {result_file} and {found_result_file} in {time() - start_time:.2f} seconds"
)
else:
logging.info("Final combined scores:")
logging.info(all_scores)
return all_scores, found_scores
def resize_array(arr, target_shape, pad_value=0):
"""
Resize an array to a target shape by padding or cropping as needed.
Parameters:
arr (np.ndarray): Input array to resize.
target_shape (tuple): Desired shape for the output array.
pad_value (int, float, etc.): Value to use for padding if the array is smaller than the target shape.
Returns:
np.ndarray: Resized array with the specified target shape.
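Example:
resize_array(np.zeros((10, 64, 64)), (12, 60, 64)) pads axis 0 by (1, 1),
center-crops axis 1 to [2:62], and leaves axis 2 unchanged.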
"""
arr_shape = arr.shape
resized_arr = arr
# Pad if the array is smaller than the target shape
pad_width = []
for i in range(len(target_shape)):
if arr_shape[i] < target_shape[i]:
# Padding needed: calculate amount for both sides
pad_before = (target_shape[i] - arr_shape[i]) // 2
pad_after = target_shape[i] - arr_shape[i] - pad_before
pad_width.append((pad_before, pad_after))
else:
# No padding needed for this dimension
pad_width.append((0, 0))
if any(pad > 0 for pads in pad_width for pad in pads):
resized_arr = np.pad(
resized_arr, pad_width, mode="constant", constant_values=pad_value
)
# Crop if the array is larger than the target shape
slices = []
for i in range(len(target_shape)):
if arr_shape[i] > target_shape[i]:
# Calculate cropping slices to center the crop
start = (arr_shape[i] - target_shape[i]) // 2
end = start + target_shape[i]
slices.append(slice(start, end))
else:
# No cropping needed for this dimension
slices.append(slice(None))
return resized_arr[tuple(slices)]
def match_crop_space(path, class_label, voxel_size, shape, translation) -> np.ndarray:
mc = MatchedCrop(
path=path,
class_label=class_label,
target_voxel_size=voxel_size,
target_shape=shape,
target_translation=translation,
instance_classes=INSTANCE_CLASSES,
semantic_threshold=0.5,
pad_value=0,
)
return mc.load_aligned()
def unzip_file(zip_path):
"""
Unzip a zip file to a specified directory.
Args:
zip_path (str): The path to the zip file.
Example usage:
unzip_file('submission.zip')
"""
logging.info(f"Unzipping {zip_path}...")
saved_path = UPath(zip_path).with_suffix(".zarr").path
if UPath(saved_path).exists():
logging.info(f"Using existing unzipped path at {saved_path}")
return UPath(saved_path)
with zipfile.ZipFile(zip_path, "r") as zip_ref:
zip_ref.extractall(saved_path)
logging.info(f"Unzipped {zip_path} to {saved_path}")
return UPath(saved_path)
if __name__ == "__main__":
# When called on the commandline, evaluate the submission
# example usage: python evaluate.py submission.zip
argparser = argparse.ArgumentParser()
argparser.add_argument(
"submission_file", help="Path to submission zip file to score"
)
argparser.add_argument(
"result_file",
nargs="?",
help="If provided, store submission results in this file. Else print them to stdout",
)
argparser.add_argument(
"--truth-path", default=TRUTH_PATH, help="Path to zarr containing ground truth"
)
args = argparser.parse_args()
score_submission(args.submission_file, args.result_file, args.truth_path)