# Source code for cellmap_data.dataloader
import torch
from torch.utils.data import DataLoader, Sampler, Subset
from .dataset import CellMapDataset
from .multidataset import CellMapMultiDataset
from .dataset_writer import CellMapDatasetWriter
from typing import Callable, Optional, Sequence
# [docs]
class CellMapDataLoader:
    """
    Utility class to create a DataLoader for a CellMapDataset or CellMapMultiDataset.

    Attributes:
        dataset (CellMapMultiDataset | CellMapDataset | Subset | CellMapDatasetWriter): The dataset to load.
        classes (Sequence[str]): The classes to load.
        batch_size (int): The batch size.
        num_workers (int): The number of workers to use.
        weighted_sampler (bool): Whether to use a weighted sampler.
        sampler (Sampler | Callable | None): The sampler to use.
        is_train (bool): Whether the data is for training and thus should be shuffled.
        rng (Optional[torch.Generator]): The random number generator to use.
        loader (DataLoader): The PyTorch DataLoader.
        default_kwargs (dict): The extra keyword arguments supplied by the caller,
            re-applied every time the PyTorch DataLoader is (re)built.

    Methods:
        refresh: Rebuild the DataLoader; if the sampler is a Callable, it is called again.
        collate_fn: Combine a list of dictionaries from different sources into a single dictionary for output.
    """

    def __init__(
        self,
        dataset: CellMapMultiDataset | CellMapDataset | Subset | CellMapDatasetWriter,
        classes: Sequence[str],
        batch_size: int = 1,
        num_workers: int = 0,
        weighted_sampler: bool = False,
        sampler: Sampler | Callable | None = None,
        is_train: bool = True,
        rng: Optional[torch.Generator] = None,
        device: Optional[str | torch.device] = None,
        **kwargs,
    ):
        """
        Initialize the CellMapDataLoader.

        Args:
            dataset (CellMapMultiDataset | CellMapDataset | Subset | CellMapDatasetWriter): The dataset to load.
            classes (Sequence[str]): The classes to load.
            batch_size (int): The batch size.
            num_workers (int): The number of workers to use.
            weighted_sampler (bool): Whether to use a weighted sampler. Defaults to False.
                Only consulted when ``sampler`` is None.
            sampler (Sampler | Callable | None): The sampler to use. A callable is
                invoked (with no arguments) each time the DataLoader is built.
            is_train (bool): Whether the data is for training and thus should be shuffled.
                Only consulted when no sampler is in use.
            rng (Optional[torch.Generator]): The random number generator to use.
            device (Optional[str | torch.device]): The device to use. Defaults to "cuda" or "mps" if available, else "cpu".
            `**kwargs`: Additional arguments to pass to the DataLoader.

        Raises:
            AssertionError: If ``weighted_sampler`` is requested without an explicit
                ``sampler`` and ``dataset`` is not a CellMapMultiDataset.
        """
        self.dataset = dataset
        self.classes = classes
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.weighted_sampler = weighted_sampler
        self.sampler = sampler
        self.is_train = is_train
        self.rng = rng

        if device is None:
            if torch.cuda.is_available():
                device = "cuda"
            elif torch.backends.mps.is_available():
                device = "mps"
            else:
                device = "cpu"
        self.dataset.to(device)

        # An explicitly supplied sampler always wins over the weighted-sampler flag.
        if self.sampler is None and self.weighted_sampler:
            assert isinstance(
                self.dataset, CellMapMultiDataset
            ), "Weighted sampler only relevant for CellMapMultiDataset"
            self.sampler = self.dataset.get_weighted_sampler(self.batch_size, self.rng)

        # Keep a pristine copy of the caller's extras so rebuilding the loader
        # starts from the same options every time.
        self.default_kwargs = kwargs.copy()
        # TODO: Try persistent workers
        self.loader = self._build_loader()

    def _build_loader(self) -> DataLoader:
        """Construct a PyTorch DataLoader from the current attribute values.

        The caller-supplied extras in ``default_kwargs`` are applied first; the
        ``dataset``, ``batch_size``, ``num_workers`` and ``collate_fn`` entries
        are then overwritten with this object's own values. If a sampler is
        set it is passed through (calling it first when it is callable);
        otherwise ``shuffle`` is set according to ``is_train``.
        """
        kwargs = self.default_kwargs.copy()
        kwargs.update(
            {
                "dataset": self.dataset,
                "batch_size": self.batch_size,
                "num_workers": self.num_workers,
                "collate_fn": self.collate_fn,
            }
        )
        if self.sampler is not None:
            # A callable sampler acts as a factory: call it to get a fresh sampler.
            kwargs["sampler"] = self.sampler() if callable(self.sampler) else self.sampler
        else:
            kwargs["shuffle"] = bool(self.is_train)
        return DataLoader(**kwargs)

    def __getitem__(self, indices: int | Sequence[int]) -> dict:
        """Get a collated batch for the given dataset index or indices.

        A single integer is treated as a one-element list, so the result is
        always passed through ``collate_fn``.
        """
        if isinstance(indices, int):
            indices = [indices]
        return self.collate_fn([self.loader.dataset[index] for index in indices])

    def refresh(self):
        """Rebuild the DataLoader; if the sampler is a Callable, it is called again to obtain the sampler."""
        self.loader = self._build_loader()

    def collate_fn(self, batch: list[dict]) -> dict[str, torch.Tensor]:
        """Combine a list of dictionaries from different sources into a single dictionary for output.

        Values are grouped by key in the order they appear in ``batch`` and each
        group is combined with ``torch.stack``. An empty ``batch`` yields an
        empty dictionary.
        """
        grouped: dict[str, list] = {}
        for sample in batch:
            for key, value in sample.items():
                grouped.setdefault(key, []).append(value)
        return {key: torch.stack(values) for key, values in grouped.items()}