from typing import Any, Mapping, Optional, Sequence
import torch
import torchvision.transforms.v2 as T
from cellmap_data import CellMapDataLoader, CellMapDataSplit
from cellmap_data.transforms.augment import NaNtoNum, Normalize
def get_dataloader(
datasplit_path: str,
classes: Sequence[str],
batch_size: int,
array_info: Optional[Mapping[str, Sequence[int | float]]] = None,
input_array_info: Optional[Mapping[str, Sequence[int | float]]] = None,
target_array_info: Optional[Mapping[str, Sequence[int | float]]] = None,
spatial_transforms: Optional[Mapping[str, Any]] = None,
# TODO: Add value transforms
iterations_per_epoch: int = 1000,
random_validation: bool = False,
device: Optional[str | torch.device] = None,
) -> tuple[CellMapDataLoader, CellMapDataLoader]:
"""
    Get the train and validation dataloaders.

    This function builds the train and validation dataloaders for the given
    datasplit file, classes, batch size, array info, spatial transforms,
    iterations per epoch, validation sampling mode, and device.

Parameters
----------
datasplit_path : str
Path to the datasplit file that defines the train/val split the dataloader should use.
classes : Sequence[str]
List of classes to segment.
batch_size : int
Batch size for the dataloader.
array_info : Optional[Mapping[str, Sequence[int | float]]]
Dictionary containing the shape and scale of the data to load for the input and target. Either `array_info` or `input_array_info` & `target_array_info` must be provided.
input_array_info : Optional[Mapping[str, Sequence[int | float]]]
Dictionary containing the shape and scale of the data to load for the input.
target_array_info : Optional[Mapping[str, Sequence[int | float]]]
Dictionary containing the shape and scale of the data to load for the target.
    spatial_transforms : Optional[Mapping[str, Any]]
        Dictionary containing the spatial transformations to apply to the data.
        For example, the dictionary could contain transformations like mirror, transpose, and rotate:
spatial_transforms = {
# 3D
# Probability of applying mirror for each axis
# Values range from 0 (no mirroring) to 1 (will always mirror)
"mirror": {"axes": {"x": 0.5, "y": 0.5, "z": 0.5}},
            # Specifies the axes that will be involved in the transposition
"transpose": {"axes": ["x", "y", "z"]},
            # Defines the rotation range for each axis.
            # The rotation angle for each axis is randomly chosen within the specified range (here, -180 to 180 degrees).
"rotate": {"axes": {"x": [-180, 180], "y": [-180, 180], "z": [-180, 180]}},
# 2D (used when there is no z axis)
# "mirror": {"axes": {"x": 0.5, "y": 0.5}},
# "transpose": {"axes": ["x", "y"]},
# "rotate": {"axes": {"x": [-180, 180], "y": [-180, 180]}},
}
iterations_per_epoch : int
Number of iterations per epoch.
    random_validation : bool
        Whether to randomize the validation data draws. Useful if not evaluating on the entire validation set every time. Defaults to False.
    device : Optional[str or torch.device]
        Device to use for training. If None, defaults to "cuda" if available, else "mps" if available, else "cpu".
Returns
-------
    tuple[CellMapDataLoader, CellMapDataLoader]
Tuple containing the train and validation dataloaders.
"""
input_arrays = {
"input": input_array_info if input_array_info is not None else array_info
}
target_arrays = {
"output": target_array_info if target_array_info is not None else array_info
}
    assert (
        input_arrays["input"] is not None and target_arrays["output"] is not None
    ), "No array info provided"
value_transforms = T.Compose(
[
Normalize(),
T.ToDtype(torch.float, scale=True),
NaNtoNum({"nan": 0, "posinf": None, "neginf": None}),
],
)
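    # Auto-select the compute device when none is given: prefer CUDA, then
    # Apple-silicon MPS, then fall back to CPU.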
if device is None:
if torch.cuda.is_available():
device = "cuda"
    elif torch.backends.mps.is_available():
device = "mps"
else:
device = "cpu"
datasplit = CellMapDataSplit(
input_arrays=input_arrays,
target_arrays=target_arrays,
classes=classes,
pad=True,
csv_path=datasplit_path,
train_raw_value_transforms=value_transforms,
val_raw_value_transforms=value_transforms,
target_value_transforms=T.ToDtype(torch.float),
spatial_transforms=spatial_transforms,
device=device,
)
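    # The validation loader draws from precomputed validation blocks;
    # `is_train=random_validation` randomizes the draws when requested.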
validation_loader = CellMapDataLoader(
datasplit.validation_blocks.to(device),
classes=classes,
batch_size=batch_size,
is_train=random_validation,
device=device,
)
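    # The training loader resamples a random subset of the combined training
    # datasets each epoch (`iterations_per_epoch * batch_size` examples).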
train_loader = CellMapDataLoader(
datasplit.train_datasets_combined.to(device),
classes=classes,
batch_size=batch_size,
sampler=lambda: datasplit.train_datasets_combined.get_subset_random_sampler(
iterations_per_epoch * batch_size, weighted=False
),
device=device,
)
return train_loader, validation_loader # type: ignore
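

# Example usage: a minimal sketch. The CSV path, class names, and the
# shape/scale values below are illustrative assumptions, not values fixed by
# this module.
if __name__ == "__main__":
    train_loader, val_loader = get_dataloader(
        datasplit_path="datasplit.csv",
        classes=["mito", "er"],
        batch_size=4,
        input_array_info={"shape": [128, 128, 128], "scale": [8, 8, 8]},
        target_array_info={"shape": [128, 128, 128], "scale": [8, 8, 8]},
        spatial_transforms={
            "mirror": {"axes": {"x": 0.5, "y": 0.5, "z": 0.5}},
            "transpose": {"axes": ["x", "y", "z"]},
            "rotate": {
                "axes": {"x": [-180, 180], "y": [-180, 180], "z": [-180, 180]}
            },
        },
        iterations_per_epoch=100,
    )
    # Iterating the wrapped torch DataLoader via `.loader` is an assumption
    # about the cellmap_data API; adjust for your installed version.
    for batch in train_loader.loader:
        print(batch["input"].shape, batch["output"].shape)
        break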