# %%
import csv
import os
import shutil
import sys
from glob import glob
from typing import Optional
import numpy as np
from tqdm import tqdm
from upath import UPath
from ..config import (
CROP_NAME,
RAW_NAME,
SEARCH_PATH,
S3_CROP_NAME,
S3_RAW_NAME,
S3_SEARCH_PATH,
GT_S3_BUCKET,
RAW_S3_BUCKET,
)
# TODO: Consolidate with get_formatted_fields
def get_dataset_name(
raw_path: str, search_path: str = SEARCH_PATH, raw_name: str = RAW_NAME
) -> str:
"""
Get the name of the dataset from the raw path.
"""
path_base = search_path.format(dataset="{dataset}", name=raw_name)
    assert "{dataset}" in path_base, (
        f"search_path {search_path} must contain " + "{dataset}"
    )
for rp, sp in zip(
UPath(raw_path).path.split("/"), UPath(path_base).path.split("/")
):
if sp == "{dataset}":
return rp
raise ValueError(
f"Could not find dataset name in {raw_path} with {search_path} as template"
)
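# Illustrative example (hypothetical paths; suppose SEARCH_PATH were
# "/data/{dataset}/{name}"):
# >>> get_dataset_name("/data/jrc_hela-2/raw.zarr/em", "/data/{dataset}/{name}")
# 'jrc_hela-2'
# The raw path is compared component-by-component against the formatted template,
# and the component aligned with "{dataset}" is returned.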
# TODO: Consolidate with get_formatted_fields
def get_raw_path(crop_path: str, raw_name: str = RAW_NAME, label: str = "") -> str:
"""
Get the path to the raw data for a given crop path.
Parameters
----------
crop_path : str
The path to the crop.
raw_name : str, optional
The name of the raw data, by default RAW_NAME
label : str, optional
The label class at the crop_path, by default ""
Returns
-------
str
The path to the raw data.
"""
    # Note: str.rstrip strips a set of characters, not a suffix, so remove the
    # trailing label explicitly.
    crop_path = crop_path.removesuffix(label).rstrip(os.path.sep)
crop_name = CROP_NAME.format(crop=os.path.basename(crop_path), label="").rstrip(
os.path.sep
)
return (UPath(crop_path.removesuffix(crop_name)) / raw_name).path
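# Illustrative example (hypothetical paths; suppose CROP_NAME were "{crop}/{label}"
# and RAW_NAME were "raw"):
# >>> get_raw_path("/data/jrc_hela-2/groundtruth.zarr/crop1/mito", label="mito")
# '/data/jrc_hela-2/groundtruth.zarr/raw'
# The label and crop directories are stripped from the crop path and the configured
# raw dataset name is appended in their place.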
def get_s3_csv_string(path: str, classes: list[str], usage: str):
"""
Get the csv string for a given dataset path, to be written to the datasplit csv file.
Parameters
----------
path : str
The path to the dataset.
classes : list[str]
The classes present in the dataset.
usage : str
The usage of the dataset (train or validate).
Returns
-------
str
The csv string for the dataset.
"""
dataset_name = get_formatted_fields(path, S3_SEARCH_PATH, ["dataset", "name"])[
"dataset"
]
raw_path = UPath("s3://" + RAW_S3_BUCKET, anon=True) / S3_SEARCH_PATH.format(
dataset=dataset_name, name=S3_RAW_NAME
)
raw_zarr_path = raw_path.path.split(".zarr")[0] + ".zarr"
gt_zarr_path = (UPath("s3://" + GT_S3_BUCKET, anon=True) / path).path.split(
".zarr"
)[0] + ".zarr"
raw_ds_name = raw_path.path.removeprefix(raw_zarr_path + os.path.sep)
gt_ds_name = path.split(".zarr")[-1].removeprefix(os.path.sep)
bar_string = f"Found raw data for {dataset_name} at {raw_path}"
return (
f'"{usage}","{"s3://" + raw_zarr_path}","{raw_ds_name}","{"s3://" + gt_zarr_path}","{gt_ds_name+os.path.sep}[{",".join([c for c in classes])}]"\n',
bar_string,
)
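# The returned csv row has the form (values here are purely illustrative):
#   "train","s3://<raw-bucket>/.../em.zarr","<raw dataset>","s3://<gt-bucket>/.../groundtruth.zarr","crop1/[nuc,mito]"
# i.e. usage, raw zarr container, raw dataset name, groundtruth zarr container, and
# the groundtruth dataset name followed by the bracketed list of label classes.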
def get_csv_string(
path: str,
classes: list[str],
usage: str,
raw_name: str = RAW_NAME,
search_path: str = SEARCH_PATH,
):
"""
Get the csv string for a given dataset path, to be written to the datasplit csv file.
Parameters
----------
path : str
The path to the dataset.
classes : list[str]
The classes present in the dataset.
usage : str
The usage of the dataset (train or validate).
raw_name : str, optional
The name of the raw data. Default is RAW_NAME.
search_path : str, optional
The search path to use to find the datasets. Default is SEARCH_PATH.
Returns
-------
str
The csv string for the dataset.
"""
raw_path = get_raw_path(path, raw_name)
dataset_name = get_dataset_name(
raw_path, search_path=search_path, raw_name=raw_name
)
if not UPath(raw_path).exists():
bar_string = (
f"No raw data found for {dataset_name} at {raw_path}, trying n5 format"
)
raw_path = raw_path.replace(".zarr", ".n5")
if not UPath(raw_path).exists():
bar_string = f"No raw data found for {dataset_name} at {raw_path}, skipping"
return None, bar_string
zarr_path = raw_path.split(".n5")[0] + ".n5"
else:
zarr_path = raw_path.split(".zarr")[0] + ".zarr"
raw_ds_name = raw_path.removeprefix(zarr_path + os.path.sep)
gt_ds_name = path.removeprefix(zarr_path + os.path.sep)
bar_string = f"Found raw data for {dataset_name} at {raw_path}"
return (
f'"{usage}","{zarr_path}","{raw_ds_name}","{zarr_path}","{gt_ds_name+os.path.sep}[{",".join([c for c in classes])}]"\n',
bar_string,
)
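# The returned csv row has the form (values here are purely illustrative):
#   "train","/data/jrc_hela-2/em.zarr","recon-1/em","/data/jrc_hela-2/em.zarr","crop1/[nuc,mito]"
# i.e. usage, raw container, raw dataset name, groundtruth container (the same
# container found under SEARCH_PATH), and the groundtruth dataset name followed by
# the bracketed list of label classes. If no raw data is found in either zarr or n5
# form, (None, <message>) is returned instead.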
def make_s3_datasplit_csv(
classes: list[str] = ["nuc", "mito"],
force_all_classes: bool | str = False,
validation_prob: float = 0.1,
datasets: list[str] = ["*"],
crops: list[str] = ["*"],
csv_path: str = "datasplit.csv",
dry_run: bool = False,
**kwargs,
):
"""
Make a datasplit csv file for the given classes and datasets.
Parameters
----------
classes : list[str], optional
The classes to include in the csv, by default ["nuc", "mito"]
force_all_classes : bool | str, optional
If True, force all classes to be present in the training/validation datasets. If False, as long as at least one requested class is present, a crop will be included. If "train" or "validate", force all classes to be present in the training or validation datasets, respectively. By default False.
validation_prob : float, optional
The probability of a dataset being in the validation set, by default 0.1
datasets : list[str], optional
The datasets to include in the csv, by default ["*"], which includes all datasets
crops : list[str], optional
The crops to include in the csv, by default all crops are included. Otherwise, only the crops in the list are included.
csv_path : str, optional
The path to write the csv file to, by default "datasplit.csv"
dry_run : bool, optional
If True, do not write the csv file - just return the found datapaths. By default False
**kwargs : dict
        Additional keyword arguments are ignored; accepted only for compatibility with make_datasplit_csv.
"""
# Define the paths to the raw and groundtruth data and the label classes by crawling the directories and writing the paths to a csv file
if not dry_run:
        # csv_path is a plain file, so shutil.rmtree would not remove it; delete it
        # directly if it already exists so a fresh csv can be written.
        if os.path.exists(csv_path):
            os.remove(csv_path)
        assert not os.path.exists(
            csv_path
        ), f"CSV file {csv_path} already exists and could not be overwritten"
datapaths = {}
for dataset in datasets:
for crop in crops:
for label in classes:
these_datapaths = list(
UPath("s3://" + GT_S3_BUCKET, anon=True).glob(
S3_SEARCH_PATH.format(
dataset=dataset,
name=S3_CROP_NAME.format(crop=crop, label=label),
)
)
)
if len(these_datapaths) == 0:
continue
these_datapaths = [
path.path.removesuffix(os.path.sep + label).removeprefix(
GT_S3_BUCKET + os.path.sep
)
for path in these_datapaths
]
for path in these_datapaths:
if path not in datapaths:
datapaths[path] = []
datapaths[path].append(label)
if dry_run:
print("Dry run, not writing csv")
return datapaths
usage_dict = {
k: "train" if np.random.rand() > validation_prob else "validate"
for k in datapaths.keys()
}
num_train = num_validate = 0
bar = tqdm(datapaths.keys())
for path in bar:
print(f"Processing {path}")
usage = usage_dict[path]
if force_all_classes == usage:
if len(datapaths[path]) != len(classes):
usage = "train" if usage == "validate" else "validate"
elif force_all_classes is True:
if len(datapaths[path]) != len(classes):
usage_dict[path] = "none"
continue
usage_dict[path] = usage
csv_string, bar_string = get_s3_csv_string(path, datapaths[path], usage)
bar.set_postfix_str(bar_string)
        if csv_string is not None:
            with open(csv_path, "a") as f:
                f.write(csv_string)
            if usage == "train":
                num_train += 1
            else:
                num_validate += 1
assert num_train + num_validate > 0, "No datasets found"
print(f"Number of datasets: {num_train + num_validate}")
print(
f"Number of training datasets: {num_train} ({num_train/(num_train+num_validate)*100:.2f}%)"
)
print(
f"Number of validation datasets: {num_validate} ({num_validate/(num_train+num_validate)*100:.2f}%)"
)
print(f"CSV written to {csv_path}")
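# Illustrative usage (class names and probability are examples only; this crawls the
# S3 buckets configured in config.py, accessed anonymously, and writes rows to the
# given csv):
#   make_s3_datasplit_csv(classes=["nuc", "mito"], validation_prob=0.15,
#                         csv_path="datasplit.csv")
# With dry_run=True the crawl is performed but no csv is written; the found
# datapaths are returned instead.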
def make_datasplit_csv(
classes: list[str] = ["nuc", "mito"],
force_all_classes: bool | str = False,
validation_prob: float = 0.1,
datasets: list[str] = ["*"],
crops: list[str] = ["*"],
search_path: str = SEARCH_PATH,
raw_name: str = RAW_NAME,
crop_name: str = CROP_NAME,
csv_path: str = "datasplit.csv",
dry_run: bool = False,
):
"""
Make a datasplit csv file for the given classes and datasets.
Parameters
----------
classes : list[str], optional
The classes to include in the csv, by default ["nuc", "mito"]
force_all_classes : bool | str, optional
If True, force all classes to be present in the training/validation datasets. If False, as long as at least one requested class is present, a crop will be included. If "train" or "validate", force all classes to be present in the training or validation datasets, respectively. By default False.
validation_prob : float, optional
The probability of a dataset being in the validation set, by default 0.1
datasets : list[str], optional
The datasets to include in the csv, by default ["*"], which includes all datasets
crops : list[str], optional
The crops to include in the csv, by default all crops are included. Otherwise, only the crops in the list are included.
search_path : str, optional
The search path to use to find the datasets, by default SEARCH_PATH
raw_name : str, optional
The name of the raw data, by default RAW_NAME
crop_name : str, optional
The name of the crop, by default CROP_NAME
csv_path : str, optional
The path to write the csv file to, by default "datasplit.csv"
dry_run : bool, optional
If True, do not write the csv file - just return the found datapaths. By default False
"""
# Define the paths to the raw and groundtruth data and the label classes by crawling the directories and writing the paths to a csv file
datapaths = {}
for dataset in datasets:
for crop in crops:
for label in classes:
these_datapaths = glob(
search_path.format(
dataset=dataset, name=crop_name.format(crop=crop, label=label)
)
)
if len(these_datapaths) == 0:
continue
these_datapaths = [
path.removesuffix(os.path.sep + label) for path in these_datapaths
]
for path in these_datapaths:
if path not in datapaths:
datapaths[path] = []
datapaths[path].append(label)
if dry_run:
print("Dry run, not writing csv")
return datapaths
    # csv_path is a plain file, so shutil.rmtree would not remove it; delete it
    # directly if it already exists so a fresh csv can be written.
    if os.path.exists(csv_path):
        os.remove(csv_path)
    assert not os.path.exists(
        csv_path
    ), f"CSV file {csv_path} already exists and cannot be overwritten"
usage_dict = {
k: "train" if np.random.rand() > validation_prob else "validate"
for k in datapaths.keys()
}
    # Now enforce that there is at least one training and one validation crop, if possible
    if len(usage_dict) >= 2:
        usages = list(usage_dict.values())
        if usages.count("train") == 0:
            usage_dict[list(usage_dict.keys())[0]] = "train"
        elif usages.count("validate") == 0:
            usage_dict[list(usage_dict.keys())[0]] = "validate"
num_train = num_validate = 0
bar = tqdm(datapaths.keys())
for path in bar:
print(f"Processing {path}")
usage = usage_dict[path]
if force_all_classes == usage:
if len(datapaths[path]) != len(classes):
usage = "train" if usage == "validate" else "validate"
elif force_all_classes is True:
if len(datapaths[path]) != len(classes):
usage_dict[path] = "none"
continue
usage_dict[path] = usage
csv_string, bar_string = get_csv_string(
path, datapaths[path], usage, raw_name, search_path
)
bar.set_postfix_str(bar_string)
        if csv_string is not None:
            with open(csv_path, "a") as f:
                f.write(csv_string)
            if usage == "train":
                num_train += 1
            else:
                num_validate += 1
assert num_train + num_validate > 0, "No datasets found"
print(f"Number of datasets: {num_train + num_validate}")
print(
f"Number of training datasets: {num_train} ({num_train/(num_train+num_validate)*100:.2f}%)"
)
print(
f"Number of validation datasets: {num_validate} ({num_validate/(num_train+num_validate)*100:.2f}%)"
)
print(f"CSV written to {csv_path}")
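# Illustrative usage (class names and probability are examples only):
#   make_datasplit_csv(classes=["nuc", "mito"], validation_prob=0.15,
#                      csv_path="datasplit.csv")
# This crawls SEARCH_PATH for matching crops, randomly assigns each crop to "train"
# or "validate", and appends one quoted row per crop (see get_csv_string) to the csv.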
def get_dataset_counts(
classes: list[str] = ["nuc", "mito"],
search_path: str = SEARCH_PATH,
raw_name: str = RAW_NAME,
crop_name: str = CROP_NAME,
):
"""
Get the counts of each class in each dataset.
Parameters
----------
classes : list[str], optional
The classes to include in the csv, by default ["nuc", "mito"]
search_path : str, optional
The search path to use to find the datasets, by default SEARCH_PATH
raw_name : str, optional
The name of the raw data, by default RAW_NAME
crop_name : str, optional
The name of the crop, by default CROP_NAME
Returns
-------
dict
A dictionary of the counts of each class in each dataset.
"""
dataset_class_counts = {}
for label in classes:
these_datapaths = glob(
search_path.format(
dataset="*", name=crop_name.format(crop="*", label=label)
)
)
for path in these_datapaths:
raw_path = get_raw_path(path, raw_name, label)
dataset_name = get_dataset_name(raw_path)
if not UPath(raw_path).exists():
print(
f"No raw data found for {dataset_name} at {raw_path}, trying n5 format"
)
raw_path = raw_path.replace(".zarr", ".n5")
if not UPath(raw_path).exists():
print(
f"No raw data found for {dataset_name} at {raw_path}, skipping"
)
continue
if dataset_name not in dataset_class_counts:
dataset_class_counts[dataset_name] = {}
if label not in dataset_class_counts[dataset_name]:
dataset_class_counts[dataset_name][label] = 1
else:
dataset_class_counts[dataset_name][label] += 1
return dataset_class_counts
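# Illustrative return value (dataset names and counts are hypothetical):
#   get_dataset_counts(classes=["nuc", "mito"])
#   -> {"jrc_hela-2": {"nuc": 4, "mito": 3}, "jrc_macrophage-2": {"nuc": 2}}
# i.e. for each dataset with readable raw data, the number of crops found per class.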
def get_tested_classes(
csv_path: str = (UPath(__file__).parent / "tested_classes.csv").path,
):
"""
Get the classes that will be tested for the challenge.
Parameters
----------
csv_path : str, optional
The path to the csv file, by default "tested_classes.csv"
Returns
-------
list[str]
        A list of the classes that will be tested.
"""
tested_classes = []
with open(csv_path, "r") as f:
reader = csv.reader(f)
for row in reader:
tested_classes.append(row[0])
return tested_classes
def get_class_incl_ids(incl_ids_string):
    """
    Parse a comma-separated string of class ids into a list of ints, returning an
    empty list if the string is missing or empty.
    """
    if incl_ids_string is None or incl_ids_string == "":
        return []
    return [int(id) for id in incl_ids_string.split(",")]
def get_class_relations(
    csv_path: str = (UPath(__file__).parent / "classes.csv").path,
    named_classes: Optional[list[str]] = None,
):
    """
    Get the mutually exclusive classes for each class, based on the included-id
    relationships listed in the classes csv.
    Parameters
    ----------
    csv_path : str, optional
        The path to the csv file listing each class name, id, and included ids, by default "classes.csv"
    named_classes : Optional[list[str]], optional
        If given, restrict the results to these class names, by default None
    Returns
    -------
    dict
        A dictionary mapping each class name to the set of class names it cannot overlap with.
    """
classes_dict = {}
with open(csv_path, "r") as f:
reader = csv.reader(f)
for row in reader:
if named_classes is not None and row[0] not in named_classes:
continue
id = int(row[1])
if id not in classes_dict:
classes_dict[id] = {}
classes_dict[id]["name"] = row[0]
classes_dict[id]["incl_ids"] = set(get_class_incl_ids(*row[2:]) + [id])
# Get all potentially overlapping classes (e.g. cell can overlap with every organelle but not ecs)
class_relation_dict = {}
for id1, info1 in classes_dict.items():
if info1["name"] not in class_relation_dict:
class_relation_dict[info1["name"]] = set()
for id2, info2 in classes_dict.items():
if id1 == id2:
class_relation_dict[info1["name"]].add(info2["name"])
if len(info1["incl_ids"].intersection(info2["incl_ids"])) > 0:
class_relation_dict[info1["name"]].add(info2["name"])
class_ids = set(v["name"] for v in classes_dict.values())
for key, value in class_relation_dict.items():
class_relation_dict[key] = class_ids - value
return class_relation_dict
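# Illustrative return value (assuming a hypothetical classes.csv in which "cell" and
# "ecs" share no included ids):
#   get_class_relations(named_classes=["cell", "ecs"])
#   -> {"cell": {"ecs"}, "ecs": {"cell"}}
# Each class name maps to the set of classes it cannot overlap with (all classes
# minus itself and those sharing at least one included id).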
if __name__ == "__main__":
    """
    Usage: python datasplit.py [classes] [search_path]
    classes: A bracketed, comma-separated list of classes to include in the csv (e.g. "[nuc,er]"). Defaults to ["nuc", "er"]. If omitted, the first argument is treated as the search_path.
    search_path: The search path to use to find the datasets. Defaults to SEARCH_PATH.
    """
if len(sys.argv) > 1 and sys.argv[1][0] == "[":
classes = sys.argv[1][1:-1].split(",")
if len(sys.argv) > 2:
search_path = sys.argv[2]
else:
search_path = SEARCH_PATH
elif len(sys.argv) > 1:
search_path = sys.argv[1]
classes = ["nuc", "er"]
else:
classes = ["nuc", "er"]
search_path = SEARCH_PATH
    if os.path.exists("datasplit.csv"):
        os.remove("datasplit.csv")
make_datasplit_csv(classes=classes, search_path=search_path, validation_prob=0.1)