import shutil
import sys
from time import time
import numpy as np
import requests
from tqdm import tqdm
from cellmap_segmentation_challenge.utils import get_tested_classes
from cellmap_segmentation_challenge import TRUTH_PATH
import zarr
import git
from upath import UPath
from scipy import ndimage as ndi
[docs]
def iou(a: np.ndarray, b: np.ndarray) -> float:
a = a.astype(bool)
b = b.astype(bool)
inter = np.logical_and(a, b).sum()
union = np.logical_or(a, b).sum()
return float(inter) / float(union) if union else 1.0
def _disk(radius: int) -> np.ndarray:
"""Binary disk structuring element."""
r = int(radius)
y, x = np.ogrid[-r : r + 1, -r : r + 1]
return (x * x + y * y) <= r * r
def _boundary_bands(G: np.ndarray, band: int, connectivity: int = 1):
"""
Returns:
inner_band: pixels in G near boundary (preferred FN removal candidates)
outer_band: pixels outside G near boundary (preferred FP addition candidates)
"""
G = G.astype(bool)
if band <= 0:
inner = G.copy()
outer = (~G).copy()
return inner, outer
# Use disk SE for spatial realism
se = _disk(band)
# Inner band = G minus an eroded version of G
G_er = ndi.binary_erosion(G, structure=se, iterations=1, border_value=0)
inner_band = np.logical_and(G, ~G_er)
# Outer band = dilated G minus G
G_di = ndi.binary_dilation(G, structure=se, iterations=1)
outer_band = np.logical_and(G_di, ~G)
return inner_band, outer_band
def _sample_indices(mask: np.ndarray, k: int, rng: np.random.Generator) -> np.ndarray:
"""Uniformly sample k True locations (flat indices) without replacement."""
idx = np.flatnonzero(mask)
if k <= 0 or idx.size == 0:
return np.array([], dtype=np.int64)
k = min(k, idx.size)
return rng.choice(idx, size=k, replace=False)
def _add_blob(
P: np.ndarray,
candidate: np.ndarray,
add_k: int,
rng: np.random.Generator,
blob_radius: int = 8,
n_seeds: int = 6,
) -> int:
"""
Adds 'blob-like' FP pixels by planting random seeds and dilating them.
Returns how many pixels were actually added.
"""
if add_k <= 0:
return 0
cand_idx = np.flatnonzero(candidate)
if cand_idx.size == 0:
return 0
# Create seed map
seeds = np.zeros_like(P, dtype=bool)
seeds_n = min(n_seeds, cand_idx.size)
seed_idx = rng.choice(cand_idx, size=seeds_n, replace=False)
seeds.flat[seed_idx] = True
# Grow seeds into blobs
se = _disk(max(1, blob_radius))
blob = ndi.binary_dilation(seeds, structure=se, iterations=1)
# Restrict to candidate region + not already in P
blob = np.logical_and(blob, candidate)
blob = np.logical_and(blob, ~P)
# If blob too big, randomly subsample to exact add_k
blob_idx = np.flatnonzero(blob)
if blob_idx.size == 0:
return 0
choose = (
blob_idx
if blob_idx.size <= add_k
else rng.choice(blob_idx, size=add_k, replace=False)
)
before = P.sum()
P.flat[choose] = True
return int(P.sum() - before)
def _remove_blob(
P: np.ndarray,
candidate: np.ndarray,
rem_k: int,
rng: np.random.Generator,
blob_radius: int = 8,
n_seeds: int = 6,
) -> int:
"""
Removes 'blob-like' FN pixels by planting seeds inside candidate and dilating them.
Returns how many pixels were actually removed.
"""
if rem_k <= 0:
return 0
cand_idx = np.flatnonzero(candidate)
if cand_idx.size == 0:
return 0
seeds = np.zeros_like(P, dtype=bool)
seeds_n = min(n_seeds, cand_idx.size)
seed_idx = rng.choice(cand_idx, size=seeds_n, replace=False)
seeds.flat[seed_idx] = True
se = _disk(max(1, blob_radius))
blob = ndi.binary_dilation(seeds, structure=se, iterations=1)
# Restrict to candidate region + currently in P
blob = np.logical_and(blob, candidate)
blob = np.logical_and(blob, P)
blob_idx = np.flatnonzero(blob)
if blob_idx.size == 0:
return 0
choose = (
blob_idx
if blob_idx.size <= rem_k
else rng.choice(blob_idx, size=rem_k, replace=False)
)
before = P.sum()
P.flat[choose] = False
return int(before - P.sum())
[docs]
def perturb_mask_realistic(
G: np.ndarray,
target_iou: float,
p_fn: float = 0.5,
band: int = 3,
style: str = "ring", # "ring" or "blob"
blob_radius: int = 8,
blob_seeds: int = 6,
rng: np.random.Generator | None = None,
max_tries: int = 4000,
):
"""
Make a perturbed mask P from ground-truth G with approx target IoU, using spatially realistic errors.
Construction:
Start P = G
Remove r pixels (FN) mostly from inner boundary band
Add a pixels (FP) mostly from outer boundary band
Exact-count IoU relationship (counts):
IoU = (g - r) / (g + a)
"""
if rng is None:
rng = np.random.default_rng()
G = G.astype(bool)
g = int(G.sum())
if g == 0:
raise ValueError("Ground-truth mask is empty; cannot target IoU reliably.")
if not (0.0 < target_iou <= 1.0):
raise ValueError("target_iou must be in (0, 1].")
if target_iou == 1.0:
return G.copy()
bg = ~G
b = int(bg.sum())
# Choose r then compute a = (g - r)/t - g
def a_from_r(r: int) -> int:
return int(round((g - r) / target_iou - g))
# Prefer both FP and FN: try to find r such that a>0 and feasible
r0 = int(np.clip(round(p_fn * g * (1 - target_iou)), 1, g - 1))
best = None
jitter = max(5, g // 80)
for _ in range(max_tries):
r = int(np.clip(r0 + rng.integers(-jitter, jitter + 1), 0, g))
a = a_from_r(r)
if 0 <= a <= b:
if r > 0 and a > 0:
best = (r, a)
break
if best is None:
best = (r, a)
if best is None:
raise RuntimeError(
"Could not find feasible (FN removals r, FP adds a). Try different target_iou."
)
r, a = best
P = G.copy()
# Candidate bands near boundary
inner_band, outer_band = _boundary_bands(G, band=band)
# If band degenerates (tiny objects), fall back to whole regions
inner_cand = inner_band if inner_band.any() else G
outer_cand = outer_band if outer_band.any() else (~G)
# --- Apply FN removals ---
if style == "ring":
# remove from inner band first
rem_idx = _sample_indices(inner_cand & P, r, rng)
P.flat[rem_idx] = False
# If not enough (rare), remove remaining from anywhere inside
remaining = r - rem_idx.size
if remaining > 0:
rem2 = _sample_indices(G & P, remaining, rng)
P.flat[rem2] = False
elif style == "blob":
removed = _remove_blob(
P, inner_cand, r, rng, blob_radius=blob_radius, n_seeds=blob_seeds
)
remaining = r - removed
if remaining > 0:
# finish by ring-like sampling if blobs didn't hit exact count
rem_idx = _sample_indices((G & P), remaining, rng)
P.flat[rem_idx] = False
else:
raise ValueError("style must be 'ring' or 'blob'")
# --- Apply FP additions ---
if style == "ring":
add_idx = _sample_indices(outer_cand & (~P), a, rng)
P.flat[add_idx] = True
remaining = a - add_idx.size
if remaining > 0:
add2 = _sample_indices((~G) & (~P), remaining, rng)
P.flat[add2] = True
elif style == "blob":
added = _add_blob(
P, outer_cand, a, rng, blob_radius=blob_radius, n_seeds=blob_seeds
)
remaining = a - added
if remaining > 0:
add_idx = _sample_indices(((~G) & (~P)), remaining, rng)
P.flat[add_idx] = True
# (Optional) return counts for debugging
achieved = iou(G, P)
info = {
"target_iou": target_iou,
"achieved_iou": achieved,
"g": g,
"fn_removed_r": r,
"fp_added_a": a,
"band": band,
"style": style,
}
return P, info
[docs]
def construct_test_crop_manifest(
path_root: str,
search_path: str = "{path_root}/{dataset}/groundtruth.zarr/{crop}/{label}",
write_path: str | None = (UPath(__file__).parent / "test_crop_manifest.csv").path,
verbose: bool = False,
) -> None | list[str]:
"""
Construct a manifest file for testing crops from a given path.
Parameters
----------
path_root : str
Path to the directory containing the datasets.
search_path : str, optional
Format string to search for the crops. The default is "{path_root}/{dataset}/groundtruth.zarr/{crop}/{label}". The function assumes that the keys appear in the file tree in the following order: 1) "path_root", 2) "dataset", 3) "crop", 4) "label"
write_path : str, optional
Path to write the manifest file. The default is "test_crop_manifest.csv".
verbose : bool, optional
Print verbose output. The default is False.
"""
# Get the tested classes
tested_classes = set(get_tested_classes())
# Construct the manifest
manifest = [
"crop_name,dataset,class_label,voxel_size,translation,shape",
]
# Get datasets
datasets = [
d.name
for d in UPath(
search_path.split("{dataset}")[0].format(path_root=path_root)
).iterdir()
if d.is_dir()
]
for dataset in datasets:
print(f"Processing dataset: {dataset}")
# Get crops
crops = [
d.name
for d in UPath(
search_path.split("{crop}")[0].format(
path_root=path_root, dataset=dataset
)
).iterdir()
if d.is_dir()
]
for crop in crops:
if verbose:
print(f"\tProcessing crop: {crop}")
# Get labels in crop
had_classes = set(
[
d.name
for d in UPath(
search_path.split("{label}")[0].format(
path_root=path_root, dataset=dataset, crop=crop
)
).iterdir()
if d.is_dir()
]
)
# Filter for tested classes
labels = list(had_classes.intersection(tested_classes))
for label in labels:
if verbose:
print(f"\t\tProcessing label: {label}")
# Get the zarr file
zarr_file = zarr.open(
search_path.format(
path_root=path_root, dataset=dataset, crop=crop, label=label
),
mode="r",
)
# Get the metadata
metadata = zarr_file.attrs.asdict()["multiscales"][0]["datasets"][0][
"coordinateTransformations"
]
for meta in metadata:
if meta["type"] == "translation":
translation = format_coordinates(meta["translation"])
elif meta["type"] == "scale":
voxel_size = format_coordinates(meta["scale"])
shape = format_coordinates(zarr_file["s0"].shape)
manifest.append(
f"{crop.replace('crop', '')},{dataset},{label},{voxel_size},{translation},{shape}"
)
if verbose:
print(
f"\t\t\tScale: {voxel_size}\n\t\t\tTranslation: {translation}\n\t\t\tShape: {shape}"
)
if write_path is None:
return manifest
# Write the manifest
with open(write_path, "w") as f:
f.write("\n".join(manifest))
print(f"Manifest written to: {write_path}")
[docs]
def construct_truth_dataset(
path_root: str,
search_path: str = "{path_root}/{dataset}/groundtruth.zarr/{crop}/{label}",
destination: str = TRUTH_PATH,
write_path: str = "{crop}/{label}",
):
"""
Construct a consolidated Zarr file for the groundtruth datasets, to use for evaluation.
Parameters
----------
path_root : str
Path to the directory containing the datasets.
search_path : str, optional
Format string to search for the crops. The default is "{path_root}/{dataset}/groundtruth.zarr/{crop}/{label}". The function assumes that the keys appear in the file tree in the following order: 1) "path_root", 2) "dataset", 3) "crop", 4) "label"
destination : str, optional
Path to write the consolidated Zarr file. The default is "cellmap-segmentation-challenge/data/ground_truth.zarr".
write_path : str, optional
Format string to write the crops to within the destination Zarr. The default is "{crop}/{label}".
"""
start_time = time()
# Get the test crop manifested
manifest = construct_test_crop_manifest(path_root, search_path, write_path=None)
# Open the destination Zarr folder
if UPath(destination).exists():
print(f"Removing existing ground truth dataset at: {destination}")
shutil.rmtree(destination)
# ground_truth = zarr.open_group(destination, mode="w")
ground_truth = zarr.open_group(destination, mode="a")
# Make a pool for parallel processing
from concurrent.futures import ThreadPoolExecutor, as_completed
pool = ThreadPoolExecutor()
# Copy the ground truth datasets
futures = []
crops_started = set()
for line in tqdm(manifest[1:], desc="Formatting ground truth..."):
crop = line.split(",")[0]
if crop not in crops_started:
crops_started.add(crop)
ground_truth.create_group(f"crop{crop}")
futures.append(
pool.submit(copy_gt, line, search_path, path_root, write_path, ground_truth)
)
for future in tqdm(as_completed(futures), total=len(futures), desc="Copying..."):
future.result()
print(f"Ground truth dataset written to: {destination}")
print(f"Done in {time() - start_time}!")
[docs]
def copy_gt(line, search_path, path_root, write_path, ground_truth):
# Get the metadata from the manifest
crop, dataset, class_label, voxel_size, translation, shape = line.split(",")
crop_name = f"crop{crop}"
voxel_size = eval(voxel_size.replace(";", ","))
translation = eval(translation.replace(";", ","))
shape = eval(shape.replace(";", ","))
# Open the source ground truth zarr file
path = search_path.format(
path_root=path_root, dataset=dataset, crop=crop_name, label=class_label
)
zarr_file = zarr.open(path, mode="r")
# Write the dataset to the destination Zarr
print(f"Writing {write_path.format(crop=crop_name, label=class_label)}")
dataset = ground_truth.create_dataset(
write_path.format(crop=crop_name, label=class_label),
data=zarr_file["s0"],
shape=shape,
dtype=zarr_file["s0"].dtype,
overwrite=True,
# fill_value=0,
dimension_separator="/",
)
dataset.attrs["voxel_size"] = voxel_size
dataset.attrs["translation"] = translation
dataset.attrs["shape"] = shape
[docs]
def download_file(url, dest):
response = requests.get(url)
response.raise_for_status()
with open(dest, "wb") as f:
f.write(response.content)
[docs]
def get_git_hash() -> str:
"""
Get the current git hash of the repository.
Returns
-------
str
The current git hash.
"""
repo = git.Repo(UPath(__file__).parent, search_parent_directories=True)
sha = repo.head.object.hexsha
return sha
if __name__ == "__main__":
if len(sys.argv) < 2:
print("Usage: python utils.py <path_root>")
sys.exit(1)
elif len(sys.argv) == 2 or sys.argv[2] == "dataset":
construct_truth_dataset(
sys.argv[1],
)
elif sys.argv[2] == "manifest":
construct_test_crop_manifest(sys.argv[1], verbose=True)