cellmap_data.dataloader

Classes

CellMapDataLoader(dataset, classes[, ...])

Initialize the CellMapDataLoader

class cellmap_data.dataloader.CellMapDataLoader(dataset: CellMapMultiDataset | CellMapDataset | Subset | CellMapDatasetWriter, classes: Sequence[str], batch_size: int = 1, num_workers: int = 0, weighted_sampler: bool = False, sampler: Sampler | Callable | None = None, is_train: bool = True, rng: Generator | None = None, device: str | device | None = None, **kwargs)

Initialize the CellMapDataLoader

Parameters:
  • dataset (CellMapMultiDataset | CellMapDataset | Subset | CellMapDatasetWriter) – The dataset to load.

  • classes (Sequence[str]) – The classes to load.

  • batch_size (int) – The batch size. Defaults to 1.

  • num_workers (int) – The number of workers to use for data loading. Defaults to 0.

  • weighted_sampler (bool) – Whether to use a weighted sampler. Defaults to False.

  • sampler (Sampler | Callable | None) – The sampler to use. Defaults to None.

  • is_train (bool) – Whether the data is for training and should therefore be shuffled. Defaults to True.

  • rng (Optional[torch.Generator]) – The random number generator to use. Defaults to None.

  • device (Optional[str | torch.device]) – The device to use. Defaults to “cuda” or “mps” if available, else “cpu”.

  • **kwargs – Additional keyword arguments to pass to the DataLoader.
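The documented defaults for device, is_train, and rng can be pictured with the minimal, stdlib-only sketch below. It is not the library's implementation: resolve_device and sample_order are hypothetical helpers, the preference for CUDA over MPS when both are available is an assumption, and random.Random stands in for the torch.Generator passed as rng.

```python
import random
from typing import List, Optional


def resolve_device(
    device: Optional[str], cuda_available: bool, mps_available: bool
) -> str:
    """Hypothetical helper: pick a device the way the `device` docs describe."""
    if device is not None:
        # An explicitly requested device is used unchanged.
        return device
    # Assumption: CUDA is preferred over MPS when both are reported available.
    if cuda_available:
        return "cuda"
    if mps_available:
        return "mps"
    return "cpu"


def sample_order(n: int, is_train: bool, seed: Optional[int] = None) -> List[int]:
    """Hypothetical helper: index order for one pass over n samples."""
    indices = list(range(n))
    if is_train:
        # Training data is shuffled; seeding the generator makes the order
        # reproducible, analogous to passing a seeded torch.Generator as `rng`.
        random.Random(seed).shuffle(indices)
    return indices


print(resolve_device(None, cuda_available=False, mps_available=True))  # mps
print(sample_order(5, is_train=False))  # [0, 1, 2, 3, 4]
```

In a real PyTorch setup, the two availability flags would come from torch.cuda.is_available() and torch.backends.mps.is_available().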

__getitem__(indices: Sequence[int]) → dict

Get the items at the given indices from the DataLoader.

Parameters:
  • indices (Sequence[int])

Return type:
  dict
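One plausible shape for __getitem__, assuming it fetches each requested index from the underlying dataset and merges the results with a collate function, is sketched below in plain Python. get_batch and group_by_key are hypothetical names, and plain dicts and lists stand in for dataset samples and Tensors.

```python
from typing import Any, Callable, Dict, List, Sequence

Sample = Dict[str, Any]  # stand-in for one dataset item


def group_by_key(samples: List[Sample]) -> Dict[str, List[Any]]:
    """Illustrative collate: one list of values per key, in sample order."""
    merged: Dict[str, List[Any]] = {}
    for sample in samples:
        for key, value in sample.items():
            merged.setdefault(key, []).append(value)
    return merged


def get_batch(
    dataset: Sequence[Sample],
    indices: Sequence[int],
    collate: Callable[[List[Sample]], Dict[str, Any]] = group_by_key,
) -> Dict[str, Any]:
    """Hypothetical __getitem__ body: fetch each index, then collate."""
    samples = [dataset[i] for i in indices]
    return collate(samples)


dataset = [{"raw": 0, "label": "a"}, {"raw": 1, "label": "b"}, {"raw": 2, "label": "c"}]
print(get_batch(dataset, [2, 0]))  # {'raw': [2, 0], 'label': ['c', 'a']}
```

Results come back in the order the indices were requested, since the samples are fetched in that order before merging.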

refresh()

If the sampler is a Callable, refresh the DataLoader with the current sampler.
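The refresh pattern can be illustrated with the hypothetical RefreshableLoader below, which is not part of cellmap_data. It assumes a Callable sampler is a zero-argument factory that returns a fresh sequence of indices each time it is called; a non-callable sampler is left untouched by refresh().

```python
import random
from typing import Callable, Iterable, Iterator, List, Union

SamplerLike = Union[Iterable[int], Callable[[], Iterable[int]]]


class RefreshableLoader:
    """Hypothetical stand-in showing how a callable sampler can be refreshed."""

    def __init__(self, sampler: SamplerLike) -> None:
        self.sampler = sampler
        self.indices: List[int] = self._build()

    def _build(self) -> List[int]:
        # A callable is invoked to get a fresh index sequence; anything else is used as-is.
        source = self.sampler() if callable(self.sampler) else self.sampler
        return list(source)

    def refresh(self) -> None:
        # Only a callable sampler triggers a rebuild, mirroring the documented behavior.
        if callable(self.sampler):
            self.indices = self._build()

    def __iter__(self) -> Iterator[int]:
        return iter(self.indices)


rng = random.Random(0)
# Each refresh draws a new random subset of 4 indices out of 10.
loader = RefreshableLoader(lambda: rng.sample(range(10), 4))
first = list(loader)
loader.refresh()
second = list(loader)
print(first, second)
```

Keeping the factory rather than a fixed sampler is what lets each refresh draw a different subset, for example a new random selection of training samples per epoch.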

collate_fn(batch: list[dict]) → dict[str, Tensor]

Combine a list of dictionaries from different sources into a single dictionary for output.

Parameters:
  • batch (list[dict])

Return type:
  dict[str, Tensor]
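A stdlib-only sketch of the kind of merge collate_fn describes follows. It assumes each item in batch maps source names (hypothetical keys such as "raw" or "mito") to equally shaped arrays, represented here by plain lists standing in for Tensors. With real tensors, the final step for each key would typically be torch.stack(values), which likewise requires identical shapes.

```python
from typing import Dict, List

Array = List[float]  # stand-in for a 1-D Tensor


def collate(batch: List[Dict[str, Array]]) -> Dict[str, List[Array]]:
    """Group samples by key, stacking values along a new leading batch axis."""
    outputs: Dict[str, List[Array]] = {}
    for sample in batch:
        for key, value in sample.items():
            outputs.setdefault(key, []).append(value)
    for key, values in outputs.items():
        # torch.stack would reject mismatched shapes; mimic that with list lengths.
        if len({len(v) for v in values}) != 1:
            raise ValueError(f"inconsistent shapes for key {key!r}")
    return outputs


batch = [
    {"raw": [1.0, 2.0], "mito": [0.0, 1.0]},
    {"raw": [3.0, 4.0], "mito": [1.0, 1.0]},
]
print(collate(batch))
# {'raw': [[1.0, 2.0], [3.0, 4.0]], 'mito': [[0.0, 1.0], [1.0, 1.0]]}
```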