import shutil
import sys
from time import time
import numpy as np
import requests
from tqdm import tqdm
from cellmap_segmentation_challenge.utils import get_tested_classes
from cellmap_segmentation_challenge import TRUTH_PATH
import zarr
from skimage.measure import label as relabel
from upath import UPath
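
# NOTE: `format_coordinates` is used below but is not shown in this listing.
# A minimal stand-in, assuming the behavior implied by `copy_gt` (which parses
# the manifest by replacing ";" with "," and calling eval), would be:
def format_coordinates(coordinates):
    """Format a sequence of coordinates as a bracketed, semicolon-separated string."""
    return f"[{';'.join(str(c) for c in coordinates)}]"
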
def construct_test_crop_manifest(
path_root: str,
search_path: str = "{path_root}/{dataset}/groundtruth.zarr/{crop}/{label}",
write_path: str | None = (UPath(__file__).parent / "test_crop_manifest.csv").path,
verbose: bool = False,
) -> None | list[str]:
"""
    Construct a manifest file for testing crops from a given path.

    Parameters
    ----------
path_root : str
Path to the directory containing the datasets.
search_path : str, optional
Format string to search for the crops. The default is "{path_root}/{dataset}/groundtruth.zarr/{crop}/{label}". The function assumes that the keys appear in the file tree in the following order: 1) "path_root", 2) "dataset", 3) "crop", 4) "label"
    write_path : str, optional
        Path to write the manifest file. The default is "test_crop_manifest.csv". If None, the manifest is returned as a list of strings instead of being written to disk.
    verbose : bool, optional
        Print verbose output. The default is False.

    Returns
    -------
    None | list[str]
        None if the manifest was written to ``write_path``, otherwise the manifest rows as a list of strings.
    """
# Get the tested classes
tested_classes = set(get_tested_classes())
# Construct the manifest
manifest = [
"crop_name,dataset,class_label,voxel_size,translation,shape",
]
# Get datasets
datasets = [
d.name
for d in UPath(
search_path.split("{dataset}")[0].format(path_root=path_root)
).iterdir()
if d.is_dir()
]
for dataset in datasets:
print(f"Processing dataset: {dataset}")
# Get crops
crops = [
d.name
for d in UPath(
search_path.split("{crop}")[0].format(
path_root=path_root, dataset=dataset
)
).iterdir()
if d.is_dir()
]
for crop in crops:
if verbose:
print(f"\tProcessing crop: {crop}")
# Get labels in crop
had_classes = set(
[
d.name
for d in UPath(
search_path.split("{label}")[0].format(
path_root=path_root, dataset=dataset, crop=crop
)
).iterdir()
if d.is_dir()
]
)
# Filter for tested classes
labels = list(had_classes.intersection(tested_classes))
for label in labels:
if verbose:
print(f"\t\tProcessing label: {label}")
# Get the zarr file
zarr_file = zarr.open(
search_path.format(
path_root=path_root, dataset=dataset, crop=crop, label=label
),
mode="r",
)
# Get the metadata
metadata = zarr_file.attrs.asdict()["multiscales"][0]["datasets"][0][
"coordinateTransformations"
]
for meta in metadata:
if meta["type"] == "translation":
translation = format_coordinates(meta["translation"])
elif meta["type"] == "scale":
voxel_size = format_coordinates(meta["scale"])
shape = format_coordinates(zarr_file["s0"].shape)
manifest.append(
f"{crop.replace('crop', '')},{dataset},{label},{voxel_size},{translation},{shape}"
)
if verbose:
print(
f"\t\t\tScale: {voxel_size}\n\t\t\tTranslation: {translation}\n\t\t\tShape: {shape}"
)
if write_path is None:
return manifest
# Write the manifest
with open(write_path, "w") as f:
f.write("\n".join(manifest))
print(f"Manifest written to: {write_path}")
def construct_truth_dataset(
path_root: str,
search_path: str = "{path_root}/{dataset}/groundtruth.zarr/{crop}/{label}",
destination: str = TRUTH_PATH,
write_path: str = "{crop}/{label}",
):
"""
    Construct a consolidated Zarr file of the ground truth datasets, to use for evaluation.

    Parameters
    ----------
path_root : str
Path to the directory containing the datasets.
search_path : str, optional
Format string to search for the crops. The default is "{path_root}/{dataset}/groundtruth.zarr/{crop}/{label}". The function assumes that the keys appear in the file tree in the following order: 1) "path_root", 2) "dataset", 3) "crop", 4) "label"
destination : str, optional
Path to write the consolidated Zarr file. The default is "cellmap-segmentation-challenge/data/ground_truth.zarr".
write_path : str, optional
Format string to write the crops to within the destination Zarr. The default is "{crop}/{label}".
"""
start_time = time()
    # Get the test crop manifest
manifest = construct_test_crop_manifest(path_root, search_path, write_path=None)
# Open the destination Zarr folder
if UPath(destination).exists():
print(f"Removing existing ground truth dataset at: {destination}")
shutil.rmtree(destination)
# ground_truth = zarr.open_group(destination, mode="w")
ground_truth = zarr.open_group(destination, mode="a")
# Make a pool for parallel processing
from concurrent.futures import ThreadPoolExecutor, as_completed
pool = ThreadPoolExecutor()
# Copy the ground truth datasets
futures = []
crops_started = set()
for line in tqdm(manifest[1:], desc="Formatting ground truth..."):
crop = line.split(",")[0]
if crop not in crops_started:
crops_started.add(crop)
ground_truth.create_group(f"crop{crop}")
futures.append(
pool.submit(copy_gt, line, search_path, path_root, write_path, ground_truth)
)
for future in tqdm(as_completed(futures), total=len(futures), desc="Copying..."):
future.result()
print(f"Ground truth dataset written to: {destination}")
print(f"Done in {time() - start_time}!")
def copy_gt(line, search_path, path_root, write_path, ground_truth):
    """Copy a single ground truth crop/label volume into the consolidated Zarr.

    ``line`` is one row of the manifest produced by ``construct_test_crop_manifest``.
    """
    # Get the metadata from the manifest
crop, dataset, class_label, voxel_size, translation, shape = line.split(",")
crop_name = f"crop{crop}"
voxel_size = eval(voxel_size.replace(";", ","))
translation = eval(translation.replace(";", ","))
shape = eval(shape.replace(";", ","))
# Open the source ground truth zarr file
path = search_path.format(
path_root=path_root, dataset=dataset, crop=crop_name, label=class_label
)
zarr_file = zarr.open(path, mode="r")
# Write the dataset to the destination Zarr
print(f"Writing {write_path.format(crop=crop_name, label=class_label)}")
    dst_dataset = ground_truth.create_dataset(
        write_path.format(crop=crop_name, label=class_label),
        data=zarr_file["s0"],
        shape=shape,
        dtype=zarr_file["s0"].dtype,
        overwrite=True,
        # fill_value=0,
        dimension_separator="/",
    )
    dst_dataset.attrs["voxel_size"] = voxel_size
    dst_dataset.attrs["translation"] = translation
    dst_dataset.attrs["shape"] = shape
# %%
# Helper functions for simulating predictions
def simulate_predictions_iou_binary(labels, iou):
    """Simulate a binary prediction with approximately the requested IoU by
    randomly dropping foreground voxels (no false positives are added)."""
    # TODO: Add false positives (only makes false negatives currently)
    print(f"Simulating predictions with IOU: {iou}")
shape = labels.shape
labels = labels.flatten() > 0
n_positive = np.sum(labels)
labels[labels > 0] = np.random.choice([1, 0], n_positive, p=[iou, 1 - iou])
labels = labels.reshape(shape)
return labels
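
# Illustrative sanity check (not part of the pipeline): with no false positives,
# the resulting IoU should be close to the requested value.
#   gt = np.random.rand(64, 64, 64) > 0.5
#   pred = simulate_predictions_iou_binary(gt, 0.8)
#   iou = np.logical_and(gt, pred).sum() / np.logical_or(gt, pred).sum()  # ~0.8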
def simulate_predictions_iou(true_labels, iou):
    """Simulate an instance segmentation prediction with approximately the
    requested per-instance IoU by randomly dropping voxels from each instance,
    then relabeling the connected components (no false positives are added)."""
    # TODO: Add false positives (only makes false negatives currently)
pred_labels = np.zeros_like(true_labels)
for i in np.unique(true_labels):
if i == 0:
continue
pred_labels[true_labels == i] = np.random.choice(
[i, 0], np.sum(true_labels == i), p=[iou, 1 - iou]
)
pred_labels = relabel(pred_labels, connectivity=len(pred_labels.shape))
return pred_labels
def simulate_predictions_accuracy(true_labels, accuracy):
    """Simulate a prediction with the requested voxel-wise accuracy by flipping
    a random subset of voxels between 0 and 1, then relabeling the connected
    components. Note that the flip (1 - value) is only meaningful for binary labels."""
    shape = true_labels.shape
    true_labels = true_labels.flatten()
# Get the total number of labels
n = len(true_labels)
# Create an array to store the simulated predictions (copy the true labels initially)
simulated_predictions = np.copy(true_labels)
# Randomly select indices to be incorrect
incorrect_indices = np.random.choice(n, size=n - int(accuracy * n), replace=False)
simulated_predictions[incorrect_indices] = (
1 - simulated_predictions[incorrect_indices]
)
# Reshape and relabel the predictions
simulated_predictions = simulated_predictions.reshape(shape)
simulated_predictions = relabel(simulated_predictions, connectivity=len(shape))
return simulated_predictions
def perturb_instance_mask(true_labels, hd_target=None, accuracy=0.8):
"""
Simulate a predicted instance segmentation mask with an approximate Hausdorff distance.
Parameters:
- true_labels: np.ndarray
Ground-truth instance segmentation mask.
- hd_target: float | None
Desired approximate Hausdorff distance. If None, it will be calculated from the accuracy.
- accuracy: float
Desired accuracy of the perturbed mask.
Returns:
- np.ndarray
Perturbed instance segmentation mask.
"""
if hd_target is None:
hd_target = -(np.log(accuracy) / np.log(1.01))
perturbed = np.copy(true_labels)
unique_instances = np.unique(true_labels)[1:] # Exclude background (0)
for instance in unique_instances:
# print("Shifting...")
# Randomly shift the mask
indices = np.where(perturbed == instance)
perturbed[indices] = 0 # Remove the original instance
indices = list(indices)
for i in range(3):
shift = np.random.randint(-hd_target, hd_target + 1)
shift = np.clip(
shift,
-indices[i].min(),
true_labels.shape[i] - (indices[i].max() + 1),
)
indices[i] += shift
indices = tuple(indices)
perturbed[indices] = instance
perturbed = simulate_predictions_accuracy(perturbed, accuracy)
return perturbed
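
# Illustrative usage on a toy instance mask (not part of the pipeline):
#   gt = np.zeros((64, 64, 64), dtype=int)
#   gt[10:20, 10:20, 10:20] = 1
#   gt[40:50, 40:50, 40:50] = 2
#   pred = perturb_instance_mask(gt, accuracy=0.8)  # shifted, partially dropped, relabeled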
def download_file(url, dest):
    """Download the file at ``url`` and write it to ``dest``."""
    response = requests.get(url)
response.raise_for_status()
with open(dest, "wb") as f:
f.write(response.content)
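
# For large files, a streamed download avoids holding the whole payload in
# memory; a sketch using the same requests API would be:
#   with requests.get(url, stream=True) as response:
#       response.raise_for_status()
#       with open(dest, "wb") as f:
#           for chunk in response.iter_content(chunk_size=1 << 20):
#               f.write(chunk)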
if __name__ == "__main__":
if len(sys.argv) < 2:
print("Usage: python utils.py <path_root>")
sys.exit(1)
elif len(sys.argv) == 2 or sys.argv[2] == "dataset":
construct_truth_dataset(
sys.argv[1],
)
elif sys.argv[2] == "manifest":
construct_test_crop_manifest(sys.argv[1], verbose=True)