# %%
import os
import shutil
import sys
from glob import glob
import numpy as np
from tqdm import tqdm
from upath import UPath
# Directory four levels above this file.
REPO_ROOT = UPath(__file__).parent.parent.parent.parent
# Path template with a "{dataset}" placeholder (used twice) and a "{name}"
# placeholder; it is filled with str.format and passed to glob elsewhere in
# this file.
SEARCH_PATH = (REPO_ROOT / "data/{dataset}/{dataset}.zarr/recon-1/{name}").path
# Template substituted for "{name}" when looking for labels; it has "{crop}"
# and "{label}" placeholders.
CROP_NAME = UPath("labels/groundtruth/{crop}/{label}").path
# Default value substituted for "{name}" when building the raw data path.
RAW_NAME = UPath("em/fibsem-uint8").path
[docs]
def get_dataset_name(
    raw_path: str, search_path: str = SEARCH_PATH, raw_name: str = RAW_NAME
) -> str:
    """
    Extract the dataset name from a raw data path.

    The search path template is split into path components and walked in
    lockstep with the components of ``raw_path``. The component of
    ``raw_path`` that lines up with the first template component equal to
    ``"{dataset}"`` is returned.

    Parameters
    ----------
    raw_path : str
        The path to the raw data.
    search_path : str, optional
        The path template containing a ``{dataset}`` placeholder, by default
        SEARCH_PATH
    raw_name : str, optional
        The name substituted for ``{name}`` in the template, by default
        RAW_NAME

    Returns
    -------
    str
        The component of ``raw_path`` at the position of ``{dataset}``.

    Raises
    ------
    AssertionError
        If the formatted template does not contain ``{dataset}``.
    ValueError
        If no template component equal to ``{dataset}`` lines up with a
        component of ``raw_path``.
    """
    placeholder = "{dataset}"
    # Keep the dataset placeholder literal while filling in the name.
    template = search_path.format(dataset=placeholder, name=raw_name)
    assert placeholder in template, (
        f"search_path {search_path} must contain" + "{dataset}"
    )

    separator = os.path.sep
    aligned = zip(raw_path.split(separator), template.split(separator))
    candidates = (actual for actual, expected in aligned if expected == placeholder)
    # Path components are always strings, so None is a safe "no match" marker.
    dataset_name = next(candidates, None)
    if dataset_name is not None:
        return dataset_name

    raise ValueError(
        f"Could not find dataset name in {raw_path} with {search_path} as template"
    )
[docs]
def get_raw_path(crop_path: str, raw_name: str = RAW_NAME, label: str = "") -> str:
    """
    Get the path to the raw data for a given crop path.

    Parameters
    ----------
    crop_path : str
        The path to the crop. May optionally end with the label directory
        (in which case ``label`` must be given) and/or trailing separators.
    raw_name : str, optional
        The name of the raw data, by default RAW_NAME
    label : str, optional
        The label class at the crop_path, by default ""

    Returns
    -------
    str
        The path to the raw data.
    """
    sep = os.path.sep
    crop_path = crop_path.rstrip(sep)
    if label:
        # Remove the label directory as a whole suffix. ``str.rstrip`` takes a
        # *set of characters*, so ``rstrip(label + sep)`` would also eat the
        # tail of a crop name that happens to end in characters of the label
        # (e.g. ".../crop_n/nuc" -> ".../crop_").
        crop_path = crop_path.removesuffix(sep + label).rstrip(sep)
    # Crop directory relative to the raw data's parent, e.g.
    # "labels/groundtruth/<crop>" (the empty label leaves a trailing separator).
    crop_name = CROP_NAME.format(crop=os.path.basename(crop_path), label="").rstrip(
        sep
    )
    return (UPath(crop_path.removesuffix(crop_name)) / raw_name).path
[docs]
def get_csv_string(
    path: str,
    classes: list[str],
    usage: str,
    raw_name: str = RAW_NAME,
) -> tuple[str | None, str]:
    """
    Get the csv string for a given dataset path, to be written to the datasplit csv file.

    Parameters
    ----------
    path : str
        The path to the dataset.
    classes : list[str]
        The classes present in the dataset.
    usage : str
        The usage of the dataset (train or validate).
    raw_name : str, optional
        The name of the raw data. Default is RAW_NAME.

    Returns
    -------
    tuple[str | None, str]
        The csv row for the dataset (``None`` if no raw data could be found in
        either zarr or n5 format), and a status message suitable for a
        progress bar.
    """
    sep = os.path.sep
    raw_path = get_raw_path(path, raw_name)
    dataset_name = get_dataset_name(raw_path)

    # The ground truth lives in the container that `path` was found in. Work
    # it out before any n5 fallback rewrites `raw_path`, otherwise the
    # ground-truth dataset name below cannot be made relative to it.
    gt_container = raw_path.split(".zarr")[0] + ".zarr"
    raw_container = gt_container

    if not UPath(raw_path).exists():
        # No raw data in zarr format: try the n5 sibling instead.
        raw_path = raw_path.replace(".zarr", ".n5")
        if not UPath(raw_path).exists():
            return (
                None,
                f"No raw data found for {dataset_name} at {raw_path}, skipping",
            )
        raw_container = raw_path.split(".n5")[0] + ".n5"

    raw_ds_name = raw_path.removeprefix(raw_container + sep)
    gt_ds_name = path.removeprefix(gt_container + sep)
    class_list = ",".join(classes)
    bar_string = f"Found raw data for {dataset_name} at {raw_path}"
    csv_row = (
        f'"{usage}","{raw_container}","{raw_ds_name}",'
        f'"{gt_container}","{gt_ds_name + sep}[{class_list}]"\n'
    )
    return csv_row, bar_string
[docs]
def make_datasplit_csv(
    classes: list[str] = ["nuc", "mito"],
    force_all_classes: bool | str = False,
    validation_prob: float = 0.1,
    datasets: list[str] = ["*"],
    crops: list[str] = ["*"],
    search_path: str = SEARCH_PATH,
    raw_name: str = RAW_NAME,
    crop_name: str = CROP_NAME,
    csv_path: str = "datasplit.csv",
    dry_run: bool = False,
):
    """
    Make a datasplit csv file for the given classes and datasets.

    Parameters
    ----------
    classes : list[str], optional
        The classes to include in the csv, by default ["nuc", "mito"]
    force_all_classes : bool | str, optional
        If True, force all classes to be present in the training/validation datasets. If False, as long as at least one requested class is present, a crop will be included. If "train" or "validate", force all classes to be present in the training or validation datasets, respectively. By default False.
    validation_prob : float, optional
        The probability of a dataset being in the validation set, by default 0.1
    datasets : list[str], optional
        The datasets to include in the csv, by default ["*"], which includes all datasets
    crops : list[str], optional
        The crops to include in the csv, by default all crops are included. Otherwise, only the crops in the list are included.
    search_path : str, optional
        The search path to use to find the datasets, by default SEARCH_PATH
    raw_name : str, optional
        The name of the raw data, by default RAW_NAME
    crop_name : str, optional
        The name of the crop, by default CROP_NAME
    csv_path : str, optional
        The path to write the csv file to, by default "datasplit.csv"
    dry_run : bool, optional
        If True, do not write the csv file - just return the found datapaths. By default False

    Returns
    -------
    dict[str, list[str]] | None
        When ``dry_run`` is True, a mapping from crop path to the list of
        classes found there. Otherwise None.

    Raises
    ------
    AssertionError
        If an existing ``csv_path`` cannot be removed, or if no dataset row
        could be produced.
    """
    # Crawl the directories: map each crop path to the classes found in it.
    # The default list arguments are only read, never mutated.
    datapaths = {}
    for dataset in datasets:
        for crop in crops:
            for label in classes:
                pattern = search_path.format(
                    dataset=dataset, name=crop_name.format(crop=crop, label=label)
                )
                for found in glob(pattern):
                    crop_path = found.removesuffix(os.path.sep + label)
                    labels = datapaths.setdefault(crop_path, [])
                    # Overlapping dataset/crop patterns (or repeated classes)
                    # can match the same directory twice; record it once so
                    # the "all classes present" check below stays meaningful.
                    if label not in labels:
                        labels.append(label)

    if dry_run:
        print("Dry run, not writing csv")
        return datapaths

    # Remove any previous output. shutil.rmtree only handles directories, so
    # an existing regular file has to be removed with os.remove.
    try:
        if os.path.isdir(csv_path) and not os.path.islink(csv_path):
            shutil.rmtree(csv_path, ignore_errors=True)
        elif os.path.lexists(csv_path):
            os.remove(csv_path)
    except OSError:
        # Best effort: the check below reports the failure.
        pass
    # Raised explicitly (rather than via `assert`) so the check survives -O.
    if os.path.exists(csv_path):
        raise AssertionError(
            f"CSV file {csv_path} already exists and cannot be overwritten"
        )

    # One random draw per crop, in discovery order.
    usage_dict = {
        k: "train" if np.random.rand() > validation_prob else "validate"
        for k in datapaths
    }

    num_classes = len(set(classes))
    num_train = num_validate = 0
    rows = []
    bar = tqdm(datapaths.keys())
    for path in bar:
        print(f"Processing {path}")
        usage = usage_dict[path]
        has_all_classes = len(datapaths[path]) == num_classes
        if force_all_classes == usage:
            # This split requires every class: move incomplete crops to the
            # other split.
            if not has_all_classes:
                usage = "train" if usage == "validate" else "validate"
        elif force_all_classes is True:
            # Both splits require every class: drop incomplete crops.
            if not has_all_classes:
                usage_dict[path] = "none"
                continue
        usage_dict[path] = usage

        csv_string, bar_string = get_csv_string(path, datapaths[path], usage, raw_name)
        bar.set_postfix_str(bar_string)
        if csv_string is None:
            continue
        rows.append(csv_string)
        if usage == "train":
            num_train += 1
        else:
            num_validate += 1

    num_total = num_train + num_validate
    # Explicit raise keeps the exception type while also guarding the
    # divisions below when running under -O.
    if num_total == 0:
        raise AssertionError("No datasets found")

    with open(csv_path, "w") as f:
        f.writelines(rows)

    print(f"Number of datasets: {num_total}")
    print(
        f"Number of training datasets: {num_train} ({num_train/num_total*100:.2f}%)"
    )
    print(
        f"Number of validation datasets: {num_validate} ({num_validate/num_total*100:.2f}%)"
    )
    print(f"CSV written to {csv_path}")
[docs]
def get_dataset_counts(
    classes: list[str] = ["nuc", "mito"],
    search_path: str = SEARCH_PATH,
    raw_name: str = RAW_NAME,
    crop_name: str = CROP_NAME,
):
    """
    Count, per dataset, how many crops contain each class.

    For every class, all datasets and crops are globbed. A match is counted
    only if its raw data exists, either at the derived path or at the same
    path with ".zarr" replaced by ".n5".

    Parameters
    ----------
    classes : list[str], optional
        The classes to count, by default ["nuc", "mito"]
    search_path : str, optional
        The search path to use to find the datasets, by default SEARCH_PATH
    raw_name : str, optional
        The name of the raw data, by default RAW_NAME
    crop_name : str, optional
        The name of the crop, by default CROP_NAME

    Returns
    -------
    dict
        Maps dataset name to a dictionary mapping class name to its count.
    """
    counts = {}
    for label in classes:
        pattern = search_path.format(
            dataset="*", name=crop_name.format(crop="*", label=label)
        )
        for match in glob(pattern):
            raw_path = get_raw_path(match, raw_name, label)
            dataset_name = get_dataset_name(raw_path)

            if not UPath(raw_path).exists():
                print(
                    f"No raw data found for {dataset_name} at {raw_path}, trying n5 format"
                )
                raw_path = raw_path.replace(".zarr", ".n5")
                if not UPath(raw_path).exists():
                    print(
                        f"No raw data found for {dataset_name} at {raw_path}, skipping"
                    )
                    continue

            # Create the per-dataset tally on first sight, then bump the label.
            per_dataset = counts.setdefault(dataset_name, {})
            per_dataset[label] = per_dataset.get(label, 0) + 1
    return counts
if __name__ == "__main__":
    # Usage: python datasplit.py [classes] [search_path]
    #
    #   classes:     A bracketed, comma-separated list of classes to include
    #                in the csv, e.g. "[nuc,er]". Defaults to ["nuc", "er"].
    #   search_path: The search path to use to find the datasets. Defaults to
    #                SEARCH_PATH.
    #
    # The first argument is treated as the class list only when it starts
    # with "["; otherwise it is taken as the search path.
    args = sys.argv[1:]
    classes = ["nuc", "er"]
    search_path = SEARCH_PATH

    if args and args[0].startswith("["):
        # Strip the brackets explicitly so a missing "]" does not eat the
        # last character of the final class name.
        class_arg = args[0].removeprefix("[").removesuffix("]")
        classes = [name.strip() for name in class_arg.split(",")]
        if len(args) > 1:
            search_path = args[1]
    elif args:
        search_path = args[0]

    # Start from a clean slate; a missing file on the first run is fine.
    try:
        os.remove("datasplit.csv")
    except FileNotFoundError:
        pass

    make_datasplit_csv(classes=classes, search_path=search_path, validation_prob=0.1)