import csv
import logging
import os
from typing import Any, Callable, Mapping, Optional, Sequence

import tensorstore
import torch
import torchvision.transforms.v2 as T
from tqdm import tqdm

from .dataset import CellMapDataset
from .multidataset import CellMapMultiDataset
from .subdataset import CellMapSubset
from .transforms.augment import Binarize, NaNtoNum, Normalize

logger = logging.getLogger(__name__)
class CellMapDataSplit:
"""
A class to split the data into training and validation datasets.
Attributes:
input_arrays (dict[str, dict[str, Sequence[int | float]]]): A dictionary containing the arrays of the dataset to input to the network. The dictionary should have the following structure::
{
"array_name": {
"shape": tuple[int],
"scale": Sequence[float],
},
...
}
target_arrays (dict[str, dict[str, Sequence[int | float]]]): A dictionary containing the arrays of the dataset to use as targets for the network. The dictionary should have the same structure as input_arrays.
classes (Sequence[str]): A list of classes for segmentation training. Class order will be preserved in the output arrays.
empty_value (int | float): The value to use for empty data. Defaults to torch.nan.
pad (bool | str): Whether to pad the data. If a string, it should be either "train" or "validate". Defaults to False.
        datasets (Optional[Mapping[str, Sequence[CellMapDataset]]]): A dictionary containing the dataset objects. Defaults to None. The dictionary should have the following structure::
            {
                "train": Iterable[CellMapDataset],
                "validate": Iterable[CellMapDataset],
            }
        dataset_dict (Optional[Mapping[str, Sequence[Mapping[str, str]]]]): A dictionary containing the dataset data. Defaults to None. The dictionary should have the following structure::
{
"train" | "validate": [{
"raw": str (path to raw data),
"gt": str (path to ground truth data),
}],
...
}
        csv_path (Optional[str]): A path to a csv file containing the dataset data. Defaults to None. Each row in the csv file should have the following structure:
            train | validate, raw container path, raw dataset subpath, gt container path, gt dataset subpath
        spatial_transforms (Optional[Mapping[str, Any]]): A dictionary of spatial transformations to apply to the data. Defaults to None. The dictionary should have the following structure::
{transform_name: {transform_args}}
        train_raw_value_transforms (Optional[Callable]): A function to apply to the raw data in training datasets, for example to add Gaussian noise. Defaults to a composition of Normalize, ToDtype, and NaNtoNum.
        val_raw_value_transforms (Optional[Callable]): A function to apply to the raw data in validation datasets, for example to normalize the raw data. Defaults to a composition of Normalize, ToDtype, and NaNtoNum.
        target_value_transforms (Optional[Callable | Sequence[Callable] | Mapping[str, Callable]]): A function to convert the ground truth data to target arrays, for example to convert the ground truth to a signed distance transform. May be a single function, a sequence of functions, or a mapping of functions by class name. For a sequence of functions, the functions are assumed to correspond, in order, to the classes in the classes list. Defaults to a composition of ToDtype and Binarize.
class_relation_dict (Optional[Mapping[str, Sequence[str]]]): A dictionary containing the class relations. The dictionary should have the following structure::
{
"class_name": [mutually_exclusive_class_name, ...],
...
}
force_has_data (bool): Whether to force the datasets to have data even if no ground truth data is found. Defaults to False. Useful for training with only raw data.
context (Optional[tensorstore.Context]): The TensorStore context for the image data. Defaults to None.
device (Optional[str | torch.device]): Device to use for the dataloaders. Defaults to None.
Note:
The csv_path, dataset_dict, and datasets arguments are mutually exclusive, but one must be supplied.
Methods:
__repr__(): Returns the string representation of the class.
from_csv(csv_path: str): Loads the dataset data from a csv file.
construct(dataset_dict: Mapping[str, Sequence[Mapping[str, str]]]): Constructs the datasets from the dataset dictionary.
        verify_datasets(): Verifies that the datasets have data, and removes those that don't from ``self.train_datasets`` and ``self.validation_datasets``.
set_raw_value_transforms(train_transforms: Optional[Callable] = None, val_transforms: Optional[Callable] = None): Sets the raw value transforms for each dataset in the training/validation multi-datasets.
set_target_value_transforms(transforms: Callable): Sets the target value transforms for each dataset in both training and validation multi-datasets.
set_spatial_transforms(spatial_transforms: dict[str, Any] | None): Sets the spatial transforms for each dataset in the training multi-dataset.
set_arrays(arrays: Mapping[str, Mapping[str, Sequence[int | float]]], type: str = "target", usage: str = "validate"): Sets the input or target arrays for the training or validation datasets.
Properties:
train_datasets_combined: A multi-dataset from the combination of all training datasets.
validation_datasets_combined: A multi-dataset from the combination of all validation datasets.
validation_blocks: A subset of the validation datasets, tiling the validation datasets with non-overlapping blocks.
class_counts: A dictionary containing the class counts for the training and validation datasets.
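    Example:
        A minimal construction sketch. The paths, array names, shapes, scales,
        classes, and transform arguments below are illustrative, not required
        values::

            datasplit = CellMapDataSplit(
                input_arrays={"raw": {"shape": (64, 64, 64), "scale": (8.0, 8.0, 8.0)}},
                target_arrays={"gt": {"shape": (64, 64, 64), "scale": (8.0, 8.0, 8.0)}},
                classes=["mito", "er"],
                csv_path="datasplit.csv",
                spatial_transforms={"mirror": {"axes": {"x": 0.5, "y": 0.5}}},
            )
            train_data = datasplit.train_datasets_combined
            val_blocks = datasplit.validation_blocks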
"""
def __init__(
self,
input_arrays: Mapping[str, Mapping[str, Sequence[int | float]]],
target_arrays: Mapping[str, Mapping[str, Sequence[int | float]]],
classes: Sequence[str],
empty_value: int | float = torch.nan,
pad: bool | str = False,
datasets: Optional[Mapping[str, Sequence[CellMapDataset]]] = None,
dataset_dict: Optional[Mapping[str, Sequence[Mapping[str, str]]]] = None,
csv_path: Optional[str] = None,
spatial_transforms: Optional[Mapping[str, Any]] = None,
train_raw_value_transforms: Optional[T.Transform] = T.Compose(
[
Normalize(),
T.ToDtype(torch.float, scale=True),
NaNtoNum({"nan": 0, "posinf": None, "neginf": None}),
],
),
val_raw_value_transforms: Optional[T.Transform] = T.Compose(
[
Normalize(),
T.ToDtype(torch.float, scale=True),
NaNtoNum({"nan": 0, "posinf": None, "neginf": None}),
],
),
target_value_transforms: Optional[T.Transform] = T.Compose(
[T.ToDtype(torch.float), Binarize()]
),
class_relation_dict: Optional[Mapping[str, Sequence[str]]] = None,
force_has_data: bool = False,
context: Optional[tensorstore.Context] = None, # type: ignore
device: Optional[str | torch.device] = None,
) -> None:
"""Initializes the CellMapDatasets class.
Args:
input_arrays (dict[str, dict[str, Sequence[int | float]]]): A dictionary containing the arrays of the dataset to input to the network. The dictionary should have the following structure::
{
"array_name": {
"shape": tuple[int],
"scale": Sequence[float],
},
...
}
target_arrays (dict[str, dict[str, Sequence[int | float]]]): A dictionary containing the arrays of the dataset to use as targets for the network. The dictionary should have the same structure as input_arrays.
classes (Sequence[str]): A list of classes for segmentation training. Class order will be preserved in the output arrays.
empty_value (int | float): The value to use for empty data. Defaults to torch.nan.
pad (bool | str): Whether to pad the data. If a string, it should be either "train" or "validate". Defaults to False.
datasets (Optional[Mapping[str, Sequence[CellMapDataset]]]): A dictionary containing the dataset objects. Defaults to None. The dictionary should have the following structure::
{
"train": Iterable[CellMapDataset],
"validate": Iterable[CellMapDataset],
                }
            dataset_dict (Optional[Mapping[str, Sequence[Mapping[str, str]]]]): A dictionary containing the dataset data. Defaults to None. The dictionary should have the following structure::
{
"train" | "validate": [{
"raw": str (path to raw data),
"gt": str (path to ground truth data),
}],
...
}
            csv_path (Optional[str]): A path to a csv file containing the dataset data. Defaults to None. Each row in the csv file should have the following structure:
                train | validate, raw container path, raw dataset subpath, gt container path, gt dataset subpath
            spatial_transforms (Optional[Mapping[str, Any]]): A dictionary of spatial transformations to apply to the data. Defaults to None. The dictionary should have the following structure::
{transform_name: {transform_args}}
            train_raw_value_transforms (Optional[Callable]): A function to apply to the raw data in training datasets, for example to add Gaussian noise. Defaults to a composition of Normalize, ToDtype, and NaNtoNum.
            val_raw_value_transforms (Optional[Callable]): A function to apply to the raw data in validation datasets, for example to normalize the raw data. Defaults to a composition of Normalize, ToDtype, and NaNtoNum.
            target_value_transforms (Optional[Callable | Sequence[Callable] | Mapping[str, Callable]]): A function to convert the ground truth data to target arrays, for example to convert the ground truth to a signed distance transform. May be a single function, a sequence of functions, or a mapping of functions by class name. For a sequence of functions, the functions are assumed to correspond, in order, to the classes in the classes list. Defaults to a composition of ToDtype and Binarize.
class_relation_dict (Optional[Mapping[str, Sequence[str]]]): A dictionary containing the class relations. The dictionary should have the following structure::
{
"class_name": [mutually_exclusive_class_name, ...],
...
}
force_has_data (bool): Whether to force the datasets to have data even if no ground truth data is found. Defaults to False. Useful for training with only raw data.
context (Optional[tensorstore.Context]): The TensorStore context for the image data. Defaults to None.
device (Optional[str | torch.device]): Device to use for the dataloaders. Defaults to None.
Note:
The csv_path, dataset_dict, and datasets arguments are mutually exclusive, but one must be supplied.
"""
logger.info("Initializing CellMapDataSplit...")
self.input_arrays = input_arrays
self.target_arrays = target_arrays
self.classes = classes
self.empty_value = empty_value
self.pad = pad
self.device = device
if isinstance(pad, str):
self.pad_training = pad.lower() == "train"
self.pad_validation = pad.lower() == "validate"
else:
self.pad_training = pad
self.pad_validation = pad
self.force_has_data = force_has_data
if datasets is not None:
self.datasets = datasets
self.train_datasets = datasets["train"]
if "validate" in datasets:
self.validation_datasets = datasets["validate"]
else:
self.validation_datasets = []
self.dataset_dict = None
elif dataset_dict is not None:
self.dataset_dict = dataset_dict
        elif csv_path is not None:
            self.dataset_dict = self.from_csv(csv_path)
        else:
            raise ValueError(
                "One of 'datasets', 'dataset_dict', or 'csv_path' must be supplied."
            )
self.spatial_transforms = spatial_transforms
self.train_raw_value_transforms = train_raw_value_transforms
self.val_raw_value_transforms = val_raw_value_transforms
self.target_value_transforms = target_value_transforms
self.class_relation_dict = class_relation_dict
self.context = context
if self.dataset_dict is not None:
self.construct(self.dataset_dict)
self.verify_datasets()
assert len(self.train_datasets) > 0, "No valid training datasets found."
logger.info("CellMapDataSplit initialized.")
def __repr__(self) -> str:
"""Returns the string representation of the class."""
return f"CellMapDataSplit(\n\tInput arrays: {self.input_arrays}\n\tTarget arrays:{self.target_arrays}\n\tClasses: {self.classes}\n\tDataset dict: {self.dataset_dict}\n\tSpatial transforms: {self.spatial_transforms}\n\tRaw value transforms: {self.train_raw_value_transforms}\n\tGT value transforms: {self.target_value_transforms}\n\tForce has data: {self.force_has_data}\n\tContext: {self.context})"
@property
def train_datasets_combined(self) -> CellMapMultiDataset:
"""A multi-dataset from the combination of all training datasets."""
try:
return self._train_datasets_combined
except AttributeError:
self._train_datasets_combined = CellMapMultiDataset(
self.classes,
self.input_arrays,
self.target_arrays,
[
ds
for ds in self.train_datasets
if self.force_has_data or ds.has_data
],
)
return self._train_datasets_combined
@property
def validation_datasets_combined(self) -> CellMapMultiDataset:
"""A multi-dataset from the combination of all validation datasets."""
try:
return self._validation_datasets_combined
except AttributeError:
if len(self.validation_datasets) == 0:
UserWarning("Validation datasets not loaded.")
self._validation_datasets_combined = CellMapMultiDataset.empty()
else:
self._validation_datasets_combined = CellMapMultiDataset(
self.classes,
self.input_arrays,
self.target_arrays,
[
ds
for ds in self.validation_datasets
if self.force_has_data or ds.has_data
],
)
return self._validation_datasets_combined
@property
def validation_blocks(self) -> CellMapSubset:
"""A subset of the validation datasets, tiling the validation datasets with non-overlapping blocks."""
try:
return self._validation_blocks
except AttributeError:
self._validation_blocks = CellMapSubset(
self.validation_datasets_combined,
self.validation_datasets_combined.validation_indices,
)
return self._validation_blocks
@property
def class_counts(self) -> dict[str, dict[str, float]]:
"""A dictionary containing the class counts for the training and validation datasets."""
try:
return self._class_counts
except AttributeError:
self._class_counts = {
"train": self.train_datasets_combined.class_counts,
"validate": self.validation_datasets_combined.class_counts,
}
return self._class_counts
    def from_csv(self, csv_path: str) -> dict[str, Sequence[dict[str, str]]]:
        """Loads the dataset_dict data from a csv file.
dataset_dict = {}
with open(csv_path, "r") as f:
reader = csv.reader(f)
print("Reading csv file...")
for row in reader:
if row[0] not in dataset_dict:
dataset_dict[row[0]] = []
dataset_dict[row[0]].append(
{
"raw": os.path.join(row[1], row[2]),
"gt": os.path.join(row[3], row[4]) if len(row) > 3 else "",
}
)
return dataset_dict
def construct(self, dataset_dict) -> None:
"""Constructs the datasets from the dataset dictionary."""
self.train_datasets = []
self.validation_datasets = []
self.datasets = {}
print("Constructing datasets...")
for data_paths in tqdm(dataset_dict["train"], desc="Training datasets"):
try:
self.train_datasets.append(
CellMapDataset(
data_paths["raw"],
data_paths["gt"],
self.classes,
self.input_arrays,
self.target_arrays,
self.spatial_transforms,
self.train_raw_value_transforms,
self.target_value_transforms,
is_train=True,
context=self.context,
force_has_data=self.force_has_data,
empty_value=self.empty_value,
class_relation_dict=self.class_relation_dict,
pad=self.pad_training,
device=self.device,
)
)
except ValueError as e:
print(f"Error loading dataset: {e}")
self.datasets["train"] = self.train_datasets
# TODO: probably want larger arrays for validation
if "validate" in dataset_dict:
for data_paths in tqdm(
dataset_dict["validate"], desc="Validation datasets"
):
try:
self.validation_datasets.append(
CellMapDataset(
data_paths["raw"],
data_paths["gt"],
self.classes,
self.input_arrays,
self.target_arrays,
raw_value_transforms=self.val_raw_value_transforms,
target_value_transforms=self.target_value_transforms,
is_train=False,
context=self.context,
force_has_data=self.force_has_data,
empty_value=self.empty_value,
class_relation_dict=self.class_relation_dict,
pad=self.pad_validation,
device=self.device,
)
)
except ValueError as e:
print(f"Error loading dataset: {e}")
self.datasets["validate"] = self.validation_datasets
def verify_datasets(self) -> None:
"""Verifies that the datasets have data, and removes ones that don't from ``self.train_datasets`` and ``self.validation_datasets``."""
if self.force_has_data:
return
print("Verifying datasets...")
verified_datasets = []
for ds in tqdm(self.train_datasets, desc="Training datasets"):
if ds.verify():
verified_datasets.append(ds)
self.train_datasets = verified_datasets
verified_datasets = []
for ds in tqdm(self.validation_datasets, desc="Validation datasets"):
if ds.verify():
verified_datasets.append(ds)
self.validation_datasets = verified_datasets
def set_arrays(
self,
arrays: Mapping[str, Mapping[str, Sequence[int | float]]],
type: str = "target",
usage: str = "validate",
) -> None:
"""Sets the input or target arrays for the training or validation datasets."""
reset_attrs = []
for dataset in self.datasets[usage]:
if type == "inputs":
dataset.input_arrays = arrays
elif type == "target":
dataset.target_arrays = arrays
else:
raise ValueError("Type must be 'inputs' or 'target'.")
dataset.reset_arrays(type)
if usage == "train":
self.train_datasets = self.datasets["train"]
reset_attrs.append("_train_datasets_combined")
elif usage == "validate":
self.validation_datasets = self.datasets["validate"]
reset_attrs.extend(["_validation_datasets_combined", "_validation_blocks"])
for attr in reset_attrs:
if hasattr(self, attr):
delattr(self, attr)
def to(self, device: str | torch.device) -> None:
"""Sets the device for the dataloaders."""
self.device = device
for dataset in self.train_datasets:
dataset.to(device)
for dataset in self.validation_datasets:
dataset.to(device)
if hasattr(self, "_train_datasets_combined"):
self._train_datasets_combined.to(device)
if hasattr(self, "_validation_datasets_combined"):
self._validation_datasets_combined.to(device)
if hasattr(self, "_validation_blocks"):
self._validation_blocks.to(device)