cellmap_data.multidataset

Classes

CellMapMultiDataset(classes, input_arrays, ...)

This class is used to combine multiple datasets into a single dataset.

class cellmap_data.multidataset.CellMapMultiDataset(classes: Sequence[str], input_arrays: Mapping[str, Mapping[str, Sequence[int | float]]], target_arrays: Mapping[str, Mapping[str, Sequence[int | float]]], datasets: Sequence[CellMapDataset])

This class combines multiple CellMapDataset objects into a single dataset. It is a subclass of PyTorch's ConcatDataset and keeps the same API, retrieving raw and ground-truth data from each of the combined datasets. See the CellMapDataset class for more information on the dataset object.

Parameters:
  • classes (Sequence[str])

  • input_arrays (Mapping[str, Mapping[str, Sequence[int | float]]])

  • target_arrays (Mapping[str, Mapping[str, Sequence[int | float]]])

  • datasets (Sequence[CellMapDataset])

classes: Sequence[str]

The classes in the dataset.

input_arrays: Mapping[str, Mapping[str, Sequence[int | float]]]

The input arrays for each dataset in the multi-dataset.

target_arrays: Mapping[str, Mapping[str, Sequence[int | float]]]

The target arrays for each dataset in the multi-dataset.
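The nested mapping type for input_arrays and target_arrays is easier to read with a concrete value. The sketch below assumes the "shape" and "scale" keys commonly used with cellmap_data (one entry per spatial axis); treat those key names, the array names "raw" and "gt", and the check_array_specs helper as illustrative assumptions rather than the library's validation logic.

```python
from typing import Mapping, Sequence

# Assumed layout: array name -> {"shape": voxels per axis, "scale": voxel size per axis}.
input_arrays = {
    "raw": {"shape": (1, 128, 128), "scale": (8, 8, 8)},
}
target_arrays = {
    "gt": {"shape": (1, 128, 128), "scale": (8, 8, 8)},
}


def check_array_specs(specs: Mapping[str, Mapping[str, Sequence[float]]]) -> bool:
    """Hypothetical helper: verify each spec has matching numeric shape/scale."""
    for name, spec in specs.items():
        shape, scale = spec["shape"], spec["scale"]
        if len(shape) != len(scale):
            raise ValueError(f"{name}: shape and scale must have the same length")
        if not all(isinstance(v, (int, float)) for v in (*shape, *scale)):
            raise TypeError(f"{name}: shape and scale entries must be int or float")
    return True


check_array_specs(input_arrays)
check_array_specs(target_arrays)
```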

datasets: Sequence[CellMapDataset]

The datasets to be combined into the multi-dataset.

Type:

List[Dataset[_T_co]]

Methods:
to(device: str | torch.device) -> CellMapMultiDataset: Moves the multi-dataset to the specified device.

get_weighted_sampler(batch_size: int = 1, rng: Optional[torch.Generator] = None) -> WeightedRandomSampler: Returns a weighted random sampler for the multi-dataset.

get_subset_random_sampler(num_samples: int, weighted: bool = True, rng: Optional[torch.Generator] = None) -> torch.utils.data.SubsetRandomSampler: Returns a random sampler that samples num_samples from the multi-dataset.

get_indices(chunk_size: Mapping[str, int]) -> Sequence[int]: Returns the indices of the multi-dataset that will tile all of the datasets according to the requested chunk_size.

set_raw_value_transforms(transforms: Callable) -> None: Sets the raw value transforms for each dataset in the multi-dataset.

set_target_value_transforms(transforms: Callable) -> None: Sets the target value transforms for each dataset in the multi-dataset.

set_spatial_transforms(spatial_transforms: Mapping[str, Any] | None) -> None: Sets the spatial transforms for each dataset in the multi-dataset.

Properties:
class_counts: Mapping[str, float]

Returns the number of samples in each class for each dataset in the multi-dataset, as well as the total number of samples in each class.

class_weights: Mapping[str, float]

Returns the class weights for the multi-dataset based on the number of samples in each class.

dataset_weights: Mapping[CellMapDataset, float]

Returns the weights for each dataset in the multi-dataset based on the number of samples of each class in each dataset.

sample_weights: Sequence[float]

Returns the weights for each sample in the multi-dataset based on the number of samples in each dataset.

validation_indices: Sequence[int]

Returns the indices of the validation set for each dataset in the multi-dataset.

datasets: List[Dataset[_T_co]]
property class_counts: dict[str, float]

Returns the number of samples in each class for each dataset in the multi-dataset, as well as the total number of samples in each class.

property class_weights: Mapping[str, float]

Returns the class weights for the multi-dataset based on the number of samples in each class.
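One way to turn per-class counts into weights is to give rarer classes larger weights. The sketch below uses the ratio of background to foreground voxels per class; this is a swapped-in, commonly used heuristic, not necessarily the formula CellMapMultiDataset applies, and the helper name and the count dictionaries are hypothetical.

```python
from typing import Dict, Mapping


def class_weights_from_counts(
    fg_counts: Mapping[str, int], bg_counts: Mapping[str, int]
) -> Dict[str, float]:
    """Hypothetical helper: weight each class by its background/foreground ratio.

    Classes with few foreground voxels get large weights. A class with no
    foreground voxels gets weight 0.0 here (an arbitrary choice for the sketch).
    """
    weights = {}
    for cls, fg in fg_counts.items():
        weights[cls] = bg_counts[cls] / fg if fg else 0.0
    return weights


fg = {"mito": 1_000, "er": 500}
bg = {"mito": 9_000, "er": 9_500}
print(class_weights_from_counts(fg, bg))  # {'mito': 9.0, 'er': 19.0}
```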

property dataset_weights: Mapping[CellMapDataset, float]

Returns the weights for each dataset in the multi-dataset based on the number of samples in each dataset.

property sample_weights: Sequence[float]

Returns the weights for each sample in the multi-dataset based on the number of samples in each dataset.
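Because the multi-dataset is a concatenation, a per-dataset weight can be expanded into a flat per-sample list in the same order as the datasets. The sketch below assumes every sample in a dataset simply inherits its dataset's weight; the helper name is hypothetical and the library may scale the weights differently (for example, by dataset length).

```python
from itertools import chain
from typing import List, Sequence


def expand_to_sample_weights(
    dataset_weights: Sequence[float], dataset_lengths: Sequence[int]
) -> List[float]:
    """Hypothetical helper: repeat each dataset's weight once per sample.

    The output follows concatenation order, so position i in the result is
    the weight of global index i in the multi-dataset.
    """
    return list(
        chain.from_iterable([w] * n for w, n in zip(dataset_weights, dataset_lengths))
    )


weights = expand_to_sample_weights([0.5, 2.0], [3, 2])
print(weights)  # [0.5, 0.5, 0.5, 2.0, 2.0]
```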

property validation_indices: Sequence[int]

Returns the indices of the validation set for each dataset in the multi-dataset.

to(device: str | device) → CellMapMultiDataset

Moves the multi-dataset to the specified device.

Parameters:

device (str | device)

Return type:

CellMapMultiDataset

get_weighted_sampler(batch_size: int = 1, rng: Generator | None = None) → WeightedRandomSampler

Returns a weighted random sampler for the multi-dataset.

Parameters:
  • batch_size (int)

  • rng (Generator | None)

Return type:

WeightedRandomSampler
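torch.utils.data.WeightedRandomSampler draws indices with probability proportional to their weights, with replacement by default. The plain-Python sketch below reproduces that behavior with random.choices so it runs without PyTorch; the integer seed stands in for the torch.Generator passed as rng. How batch_size is turned into the number of draws is not documented here, so num_samples is left as an explicit, hypothetical parameter.

```python
import random
from typing import List, Optional, Sequence


def weighted_sample_indices(
    sample_weights: Sequence[float], num_samples: int, seed: Optional[int] = None
) -> List[int]:
    """Draw indices with replacement, each with probability proportional to its weight.

    Stand-in for WeightedRandomSampler(weights, num_samples, replacement=True).
    Zero-weight indices are never drawn.
    """
    rng = random.Random(seed)
    return rng.choices(range(len(sample_weights)), weights=sample_weights, k=num_samples)


idx = weighted_sample_indices([0.5, 0.5, 0.5, 2.0, 2.0], num_samples=10, seed=0)
```

Seeding makes the draw reproducible, which mirrors why the method accepts an rng argument.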

get_subset_random_sampler(num_samples: int, weighted: bool = True, rng: Generator | None = None) → SubsetRandomSampler

Returns a random sampler that samples num_samples indices from the multi-dataset.

Parameters:
  • num_samples (int)

  • weighted (bool)

  • rng (Generator | None)

Return type:

SubsetRandomSampler
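A SubsetRandomSampler yields a fixed list of indices in random order without replacement. The sketch below first picks num_samples distinct indices, biased by sample weight when weighted is true, then shuffles them. Weighted selection here uses Efraimidis-Spirakis keys (u ** (1 / w)), a standard way to sample without replacement proportionally to weight; this is a swapped-in technique, and cellmap_data may select indices differently. The helper name is hypothetical.

```python
import random
from typing import List, Optional, Sequence


def subset_indices(
    sample_weights: Sequence[float],
    num_samples: int,
    weighted: bool = True,
    seed: Optional[int] = None,
) -> List[int]:
    """Hypothetical helper: choose distinct indices, then shuffle them."""
    rng = random.Random(seed)
    n = len(sample_weights)
    if num_samples > n:
        raise ValueError("num_samples cannot exceed the dataset length")
    if weighted:
        # Larger weights tend to produce larger keys; keep the top num_samples.
        keyed = [
            (rng.random() ** (1.0 / w), i)
            for i, w in enumerate(sample_weights)
            if w > 0
        ]
        if len(keyed) < num_samples:
            raise ValueError("not enough indices with positive weight")
        chosen = [i for _, i in sorted(keyed, reverse=True)[:num_samples]]
    else:
        chosen = rng.sample(range(n), num_samples)
    # A SubsetRandomSampler would then iterate over these in random order.
    rng.shuffle(chosen)
    return chosen
```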

get_indices(chunk_size: Mapping[str, int]) → Sequence[int]

Returns the indices of the multi-dataset that tile all of the constituent datasets according to the requested chunk_size.

Parameters:

chunk_size (Mapping[str, int])

Return type:

Sequence[int]
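Tiling indices computed inside each dataset are local to that dataset, so they must be shifted into the concatenated index space. The sketch below assumes each child dataset has already produced its own local tile indices for the given chunk_size, and offsets each list by the total length of the datasets before it; the helper name and the example indices are hypothetical.

```python
from itertools import accumulate
from typing import List, Sequence


def offset_local_indices(
    per_dataset_indices: Sequence[Sequence[int]], dataset_lengths: Sequence[int]
) -> List[int]:
    """Hypothetical helper: map each dataset's local indices to global indices."""
    # Start offset of dataset i is the combined length of datasets 0..i-1.
    offsets = [0, *accumulate(dataset_lengths)][:-1]
    combined: List[int] = []
    for local, offset in zip(per_dataset_indices, offsets):
        combined.extend(offset + i for i in local)
    return combined


# Dataset A has 10 samples tiled at every 4th index; dataset B has 6 tiled every 3rd.
print(offset_local_indices([[0, 4, 8], [0, 3]], [10, 6]))  # [0, 4, 8, 10, 13]
```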

set_raw_value_transforms(transforms: Callable) → None

Sets the raw value transforms for each dataset in the multi-dataset.

Parameters:

transforms (Callable)

Return type:

None

set_target_value_transforms(transforms: Callable) → None

Sets the target value transforms for each dataset in the multi-dataset.

Parameters:

transforms (Callable)

Return type:

None
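Both value-transform setters apply one callable to every child dataset. The sketch below shows that fan-out pattern with stub classes; StubDataset, StubMultiDataset, the attribute names raw_value_transforms and target_value_transforms, and the example transforms are assumptions for illustration, not the library's internals.

```python
from typing import Callable, List, Optional


class StubDataset:
    """Hypothetical stand-in for CellMapDataset holding value transforms."""

    def __init__(self) -> None:
        self.raw_value_transforms: Optional[Callable] = None
        self.target_value_transforms: Optional[Callable] = None


class StubMultiDataset:
    def __init__(self, datasets: List[StubDataset]) -> None:
        self.datasets = datasets

    def set_raw_value_transforms(self, transforms: Callable) -> None:
        # Broadcast the same callable to every child dataset.
        for ds in self.datasets:
            ds.raw_value_transforms = transforms

    def set_target_value_transforms(self, transforms: Callable) -> None:
        for ds in self.datasets:
            ds.target_value_transforms = transforms


def scale_uint8(x):
    """Example raw transform: map 8-bit intensities to [0, 1]."""
    return x / 255.0


multi = StubMultiDataset([StubDataset(), StubDataset()])
multi.set_raw_value_transforms(scale_uint8)
multi.set_target_value_transforms(lambda y: float(y > 0))  # binarize labels
print(multi.datasets[1].raw_value_transforms(255))  # 1.0
```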

set_spatial_transforms(spatial_transforms: Mapping[str, Any] | None) → None

Sets the spatial transforms for each dataset in the multi-dataset.

Parameters:

spatial_transforms (Mapping[str, Any] | None)

Return type:

None
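The spatial_transforms argument is a mapping from transform name to its settings. The fragment below is a hypothetical example of such a mapping; the keys "mirror", "transpose", and "rotate" and their nested layouts are assumptions about what CellMapDataset accepts, so check the CellMapDataset documentation before relying on them. Passing None presumably disables spatial augmentation, which is typical for validation data.

```python
# Hypothetical spatial-transform spec; key names and layouts are assumptions.
spatial_transforms = {
    "mirror": {"axes": {"x": 0.5, "y": 0.5}},  # flip probability per axis
    "transpose": {"axes": ["x", "y"]},  # axes that may be swapped
    "rotate": {"axes": {"z": [-180, 180]}},  # rotation range in degrees about z
}

# No spatial augmentation, e.g. for a validation multi-dataset.
validation_spatial_transforms = None
```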

static empty() → CellMapMultiDataset

Creates an empty multi-dataset.

Return type:

CellMapMultiDataset

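cumulative_sizes: List[int]

Running totals of the lengths of the combined datasets, inherited from ConcatDataset and used to map a global index to a dataset and a local sample index.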
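ConcatDataset locates a global index by binary-searching cumulative_sizes for the first running total greater than the index. The sketch below reproduces that lookup in plain Python, without PyTorch; the function names are hypothetical, and the explicit IndexError is an addition for clarity (ConcatDataset itself relies on the datasets list raising for out-of-range indices).

```python
import bisect
from itertools import accumulate
from typing import List, Sequence, Tuple


def cumulative_sizes(lengths: Sequence[int]) -> List[int]:
    """Running totals of dataset lengths, as ConcatDataset stores them."""
    return list(accumulate(lengths))


def locate(idx: int, sizes: Sequence[int]) -> Tuple[int, int]:
    """Map a global index to (dataset_idx, sample_idx), mirroring ConcatDataset."""
    total = sizes[-1]
    if idx < 0:
        if -idx > total:
            raise ValueError("absolute value of index should not exceed dataset length")
        idx += total
    if idx >= total:
        raise IndexError("index out of range")
    # The first running total strictly greater than idx identifies the dataset.
    dataset_idx = bisect.bisect_right(sizes, idx)
    sample_idx = idx if dataset_idx == 0 else idx - sizes[dataset_idx - 1]
    return dataset_idx, sample_idx


sizes = cumulative_sizes([10, 6, 4])  # [10, 16, 20]
print(locate(12, sizes))  # (1, 2)
```

Using bisect_right means an index equal to a running total, such as 10 here, correctly falls into the next dataset as its sample 0.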