cellmap_data.CellMapDataset

class cellmap_data.CellMapDataset(raw_path: str, target_path: str, classes: Sequence[str], input_arrays: Mapping[str, Mapping[str, Sequence[int | float]]], target_arrays: Mapping[str, Mapping[str, Sequence[int | float]]], spatial_transforms: Mapping[str, Mapping] | None = None, raw_value_transforms: Callable | None = None, target_value_transforms: Callable | Sequence[Callable] | Mapping[str, Callable] | None = None, class_relation_dict: Mapping[str, Sequence[str]] | None = None, is_train: bool = False, axis_order: str = 'zyx', context: Context | None = None, rng: Generator | None = None, force_has_data: bool = False, empty_value: float | int = nan, pad: bool = False, device: str | device | None = None)

Initializes the CellMapDataset class.

Parameters:
  • raw_path (str) – The path to the raw data.

  • target_path (str) – The path to the ground truth data.

  • classes (Sequence[str]) – A list of classes for segmentation training. Class order will be preserved in the output arrays. Classes not contained in the dataset will be filled in with zeros.

  • input_arrays (Mapping[str, Mapping[str, Sequence[int | float]]]) –

    A dictionary containing the arrays of the dataset to input to the network. The dictionary should have the following structure:

    {
        "array_name": {
            "shape": tuple[int],
            "scale": Sequence[float],
        },
        ...
    }
    

    where 'array_name' is the name of the array, 'shape' is the shape of the array in voxels, and 'scale' is the scale of the array in world units.

  • target_arrays (Mapping[str, Mapping[str, Sequence[int | float]]]) – A dictionary containing the arrays of the dataset to use as targets for the network. The dictionary should have the same structure as ‘input_arrays’.

  • spatial_transforms (Optional[Mapping[str, Any]] = None, optional) –

    A dictionary of the spatial transformations to apply to the data, keyed by transform name. Defaults to None. The dictionary should have the following structure:

    {transform_name: {transform_args}}
    

  • raw_value_transforms (Optional[Callable], optional) – A function to apply to the raw data, for example to normalize it. Defaults to None.

  • target_value_transforms (Optional[Callable | Sequence[Callable] | Mapping[str, Callable]], optional) – A function to convert the ground truth data to target arrays, for example to a signed distance transform. Defaults to None. May be a single function, a list of functions, or a dictionary of functions keyed by class. A list is assumed to correspond, in order, to the classes in the ‘classes’ list. The keys of a dictionary should correspond to the classes in the ‘classes’ list. Each function should return a tensor of the same shape as its input tensor. Note that target transforms are applied to the ground truth data and should generally not be combined with true-negative data inferred using the ‘class_relation_dict’.

  • is_train (bool, optional) – Whether the dataset is for training. Defaults to False.

  • context (Optional[tensorstore.Context], optional) – The context for the image data. Defaults to None.

  • rng (Optional[torch.Generator], optional) – A random number generator. Defaults to None.

  • force_has_data (bool, optional) – Whether to force the dataset to report that it has data. Defaults to False.

  • empty_value (float | int, optional) – The value to fill in for empty data. Defaults to torch.nan.

  • pad (bool, optional) – Whether to pad the image data to match requested arrays. Defaults to False.

  • device (Optional[str | torch.device], optional) – The device for the dataset. Defaults to None. If None, the device will be set to “cuda” if available, otherwise “mps” if available, otherwise “cpu”.

  • class_relation_dict (Mapping[str, Sequence[str]] | None)

  • axis_order (str)
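
As a minimal sketch of how these arguments fit together, the snippet below builds `input_arrays`, `target_arrays`, and `spatial_transforms` dictionaries in the structure described above. The array names (`"raw_8nm"`, `"labels_8nm"`), the shapes and scales, and the helper `world_extent` are hypothetical and chosen only for illustration; they are not part of `cellmap_data`.

```python
# Hypothetical array specifications: "shape" is in voxels,
# "scale" is the voxel size in world units, both in axis order (e.g. 'zyx').
input_arrays = {
    "raw_8nm": {"shape": (64, 64, 64), "scale": (8.0, 8.0, 8.0)},
}
target_arrays = {
    "labels_8nm": {"shape": (32, 32, 32), "scale": (8.0, 8.0, 8.0)},
}

# Spatial transforms follow the {transform_name: {transform_args}} layout.
spatial_transforms = {
    "mirror": {"axes": {"x": 0.5, "y": 0.5, "z": 0.1}},
    "transpose": {"axes": ["x", "z"]},
}


def world_extent(array_info):
    """Illustrative helper: extent of one array in world units (shape * scale)."""
    return tuple(n * s for n, s in zip(array_info["shape"], array_info["scale"]))


print(world_extent(input_arrays["raw_8nm"]))
print(world_extent(target_arrays["labels_8nm"]))
```

These dictionaries would then be passed as the `input_arrays`, `target_arrays`, and `spatial_transforms` arguments, together with `raw_path`, `target_path`, and `classes`.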

__init__(raw_path: str, target_path: str, classes: Sequence[str], input_arrays: Mapping[str, Mapping[str, Sequence[int | float]]], target_arrays: Mapping[str, Mapping[str, Sequence[int | float]]], spatial_transforms: Mapping[str, Mapping] | None = None, raw_value_transforms: Callable | None = None, target_value_transforms: Callable | Sequence[Callable] | Mapping[str, Callable] | None = None, class_relation_dict: Mapping[str, Sequence[str]] | None = None, is_train: bool = False, axis_order: str = 'zyx', context: Context | None = None, rng: Generator | None = None, force_has_data: bool = False, empty_value: float | int = nan, pad: bool = False, device: str | device | None = None) -> None

Initializes the CellMapDataset class.

Return type:

None

Methods

__init__(raw_path, target_path, classes, ...)

Initializes the CellMapDataset class.

empty()

Creates an empty dataset.

generate_spatial_transforms()

When 'self.is_train' is True, generates random spatial transforms for the dataset, based on the user-specified transforms.

get_empty_store(array_info, device)

Returns an empty store, based on the requested array.

get_indices(chunk_size)

Returns the indices of the dataset that will tile the dataset according to the chunk_size.

get_label_array(label, i, array_info, ...)

Returns a target array source for a specific class in the dataset.

get_target_array(array_info)

Returns a target array source for the dataset.

reset_arrays([type])

Sets the arrays for the dataset to return.

set_raw_value_transforms(transforms)

Sets the raw value transforms for the dataset.

set_target_value_transforms(transforms)

Sets the ground truth value transforms for the dataset.

to(device)

Sets the device for the dataset.

verify()

Verifies that the dataset is valid to draw samples from.

Attributes

bounding_box

Returns the bounding box of the dataset.

bounding_box_shape

Returns the shape of the bounding box of the dataset in voxels of the largest voxel size requested.

center

Returns the center of the dataset in world units.

class_counts

Returns the number of pixels for each class in the ground truth data, normalized by the resolution.

class_weights

Returns the class weights for the dataset based on the number of samples in each class.

device

Returns the device for the dataset.

largest_voxel_sizes

Returns the largest voxel size of the dataset.

sampling_box

Returns the sampling box of the dataset (i.e. where centers can be drawn from and still have full samples drawn from within the bounding box).

sampling_box_shape

Returns the shape of the sampling box of the dataset in voxels of the largest voxel size requested.

size

Returns the size of the dataset in voxels of the largest voxel size requested.

validation_indices

Returns the indices of the dataset that will produce non-overlapping tiles for use in validation, based on the largest requested voxel size.

property center: Mapping[str, float] | None

Returns the center of the dataset in world units.

property largest_voxel_sizes: Mapping[str, float]

Returns the largest voxel size of the dataset.

property bounding_box: Mapping[str, list[float]]

Returns the bounding box of the dataset.

property bounding_box_shape: Mapping[str, int]

Returns the shape of the bounding box of the dataset in voxels of the largest voxel size requested.

property sampling_box: Mapping[str, list[float]]

Returns the sampling box of the dataset (i.e. where centers can be drawn from and still have full samples drawn from within the bounding box).
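
The relationship between the bounding box and the sampling box can be sketched as follows: each side of the bounding box is pulled in by half the world-unit extent of a sample, so that any center drawn from the result yields a sample lying fully inside the bounding box. The function name `shrink_to_sampling_box`, its arguments, and the choice to return `None` when a sample does not fit are assumptions made for illustration, not the library's implementation.

```python
def shrink_to_sampling_box(bounding_box, shape, scale, axis_order="zyx"):
    """Shrink a bounding box by half the sample extent on each side.

    bounding_box maps axis -> [start, stop] in world units; shape is in
    voxels and scale in world units per voxel, both given in axis_order.
    """
    sampling_box = {}
    for axis, n_voxels, voxel_size in zip(axis_order, shape, scale):
        half_extent = n_voxels * voxel_size / 2
        start, stop = bounding_box[axis]
        if stop - start < 2 * half_extent:
            # Assumption: signal "no valid centers" with None.
            return None
        sampling_box[axis] = [start + half_extent, stop - half_extent]
    return sampling_box


box = shrink_to_sampling_box(
    {"z": [0, 1000], "y": [0, 1000], "x": [0, 2000]},
    shape=(64, 64, 64),
    scale=(8, 8, 8),
)
print(box)
```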

property sampling_box_shape: dict[str, int]

Returns the shape of the sampling box of the dataset in voxels of the largest voxel size requested.

property size: int

Returns the size of the dataset in voxels of the largest voxel size requested.

property class_counts: Mapping[str, Mapping[str, float]]

Returns the number of pixels for each class in the ground truth data, normalized by the resolution.

property class_weights: Mapping[str, float]

Returns the class weights for the dataset based on the number of samples in each class. Classes without any samples will have a weight of NaN.

property validation_indices: Sequence[int]

Returns the indices of the dataset that will produce non-overlapping tiles for use in validation, based on the largest requested voxel size.

property device: device

Returns the device for the dataset.

__len__() -> int

Returns the length of the dataset, determined by the number of coordinates that could be sampled as the center for an array request.

Return type:

int

__getitem__(idx: int) -> dict[str, Tensor]

Returns a crop of the input and target data as PyTorch tensors, corresponding to the coordinate of the unwrapped index.

Parameters:

idx (int)

Return type:

dict[str, Tensor]
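
One way to picture the "unwrapped index" is as a flat index over the voxels of the sampling box, unraveled into a per-axis coordinate and then converted to world units. The sketch below assumes row-major ordering (last axis varies fastest) and that the center is the sampling box start plus the voxel offset times the voxel size; both are assumptions, and the helper names are hypothetical.

```python
import math


def unravel_index(idx, shape, axis_order="zyx"):
    """Convert a flat index to per-axis voxel coordinates (last axis fastest)."""
    coords = {}
    for axis in reversed(axis_order):
        idx, coords[axis] = divmod(idx, shape[axis])
    return {axis: coords[axis] for axis in axis_order}


def index_to_center(idx, sampling_box_shape, sampling_box, voxel_size, axis_order="zyx"):
    """Map a flat index to a center coordinate in world units (illustrative only)."""
    voxel = unravel_index(idx, sampling_box_shape, axis_order)
    return {
        axis: sampling_box[axis][0] + voxel[axis] * voxel_size[axis]
        for axis in axis_order
    }


shape = {"z": 2, "y": 3, "x": 4}
n_centers = math.prod(shape.values())  # number of candidate centers
center = index_to_center(
    17,
    shape,
    {"z": [100, 116], "y": [0, 24], "x": [0, 32]},
    {"z": 8, "y": 8, "x": 8},
)
print(n_centers, center)
```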

get_empty_store(array_info: Mapping[str, Sequence[int]], device: device) -> Tensor

Returns an empty store, based on the requested array.

Parameters:
  • array_info (Mapping[str, Sequence[int]])

  • device (device)

Return type:

Tensor

get_target_array(array_info: Mapping[str, Sequence[int | float]]) -> dict[str, CellMapImage | EmptyImage | Sequence[str]]

Returns a target array source for the dataset. Creates a dictionary of image sources for each class in the dataset. For classes that are not present in the ground truth data, the data can be inferred from the other classes in the dataset. This is useful for training segmentation networks with mutually exclusive classes.

Parameters:

array_info (Mapping[str, Sequence[int | float]])

Return type:

dict[str, CellMapImage | EmptyImage | Sequence[str]]
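
The idea of inferring a missing class from mutually exclusive classes can be sketched as below. The sketch assumes that `class_relation_dict` maps a class to the classes it cannot overlap with, and that a voxel is a known negative (0) for the missing class wherever any of those classes is present and unknown (`empty_value`) elsewhere. The class names, the flat-list masks, and the helper `infer_missing_class` are hypothetical; the library's actual logic may differ.

```python
import math


def infer_missing_class(related_masks, empty_value=math.nan):
    """Infer true negatives for a class with no ground truth.

    related_masks is a list of equal-length 0/1 sequences, one per
    mutually exclusive class. Returns 0.0 where any related class is
    present and empty_value where nothing can be inferred.
    """
    n = len(related_masks[0])
    return [
        0.0 if any(mask[i] for mask in related_masks) else empty_value
        for i in range(n)
    ]


# Hypothetical relation: "mito" cannot overlap "er" or "nucleus".
class_relation_dict = {"mito": ["er", "nucleus"]}
masks = {"er": [1, 0, 0, 0], "nucleus": [0, 0, 1, 0]}
inferred = infer_missing_class([masks[name] for name in class_relation_dict["mito"]])
print(inferred)
```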

get_label_array(label: str, i: int, array_info: Mapping[str, Sequence[int | float]], empty_store: Tensor) -> CellMapImage | EmptyImage | Sequence[str]

Returns a target array source for a specific class in the dataset.

Parameters:
  • label (str)

  • i (int)

  • array_info (Mapping[str, Sequence[int | float]])

  • empty_store (Tensor)

Return type:

CellMapImage | EmptyImage | Sequence[str]

verify() -> bool

Verifies that the dataset is valid to draw samples from.

Return type:

bool

get_indices(chunk_size: Mapping[str, int]) -> Sequence[int]

Returns the indices of the dataset that will tile the dataset according to the chunk_size.

Parameters:

chunk_size (Mapping[str, int])

Return type:

Sequence[int]
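
As a rough sketch of tiling, the function below steps through a volume in increments of `chunk_size` along each axis and converts each tile's starting voxel coordinate to a flat index. Row-major flattening in `axis_order` and the helper name `tile_indices` are assumptions for illustration, not the library's implementation.

```python
from itertools import product


def tile_indices(shape, chunk_size, axis_order="zyx"):
    """Flat indices of tile origins that cover `shape` in steps of `chunk_size`."""
    # Starting voxel offsets of the tiles along each axis.
    starts = [range(0, shape[axis], chunk_size[axis]) for axis in axis_order]
    indices = []
    for coord in product(*starts):
        flat = 0
        for axis, c in zip(axis_order, coord):
            flat = flat * shape[axis] + c  # row-major ravel
        indices.append(flat)
    return indices


indices = tile_indices({"z": 4, "y": 4, "x": 4}, {"z": 2, "y": 2, "x": 2})
print(indices)
```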

to(device: str | device) -> CellMapDataset

Sets the device for the dataset.

Parameters:

device (str | device)

Return type:

CellMapDataset

generate_spatial_transforms() -> Mapping[str, Any] | None

When ‘self.is_train’ is True, generates random spatial transforms for the dataset, based on the user-specified transforms.

Available spatial transforms:
  • “mirror”: Mirrors the data along the specified axes. Parameters are the probabilities of mirroring along each axis, formatted as a dictionary of axis: probability pairs. Example: {“mirror”: {“axes”: {“x”: 0.5, “y”: 0.5, “z”:0.1}}} will mirror the data along the x and y axes with a 50% probability, and along the z axis with a 10% probability.

  • “transpose”: Transposes the data along the specified axes. Parameters are the axes to transpose, formatted as a list. Example: {“transpose”: {“axes”: [“x”, “z”]}} will randomly transpose the data along the x and z axes.

  • “rotate”: Rotates the data around the specified axes within the specified angle ranges. Parameters are the axes to rotate and the angle ranges, formatted as a dictionary of axis: [min_angle, max_angle] pairs. Example: {“rotate”: {“axes”: {“x”: [-180,180], “y”: [-180,180], “z”:[-180,180]}}} will rotate the data around the x, y, and z axes by an angle between -180 and 180 degrees.

Return type:

Mapping[str, Any] | None
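
To illustrate how a configuration in this format could be turned into one concrete set of random transforms, here is a minimal sketch using Python's standard `random` module. The function `sample_spatial_transforms` and the structure of its return value are assumptions for illustration; the library's own output format may differ.

```python
import random


def sample_spatial_transforms(config, rng=None):
    """Draw one random realization of the configured spatial transforms."""
    rng = rng or random.Random()
    sampled = {}
    if "mirror" in config:
        # Mirror each axis independently with its configured probability.
        sampled["mirror"] = [
            axis for axis, p in config["mirror"]["axes"].items() if rng.random() < p
        ]
    if "transpose" in config:
        # Randomly permute the listed axes among themselves.
        axes = list(config["transpose"]["axes"])
        shuffled = axes[:]
        rng.shuffle(shuffled)
        sampled["transpose"] = dict(zip(axes, shuffled))
    if "rotate" in config:
        # Draw one angle per axis, uniformly within [min_angle, max_angle].
        sampled["rotate"] = {
            axis: rng.uniform(lo, hi)
            for axis, (lo, hi) in config["rotate"]["axes"].items()
        }
    return sampled


config = {
    "mirror": {"axes": {"x": 1.0, "y": 0.0}},
    "transpose": {"axes": ["x", "z"]},
    "rotate": {"axes": {"z": [-180, 180]}},
}
sampled = sample_spatial_transforms(config, random.Random(0))
print(sampled)
```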

set_raw_value_transforms(transforms: Callable) -> None

Sets the raw value transforms for the dataset.

Parameters:

transforms (Callable)

Return type:

None

set_target_value_transforms(transforms: Callable) -> None

Sets the ground truth value transforms for the dataset.

Parameters:

transforms (Callable)

Return type:

None
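
The constructor's `target_value_transforms` parameter accepts a single function, a list of functions, or a dictionary of functions. The sketch below shows one way those three forms could be normalized into a per-class mapping, following the rules stated in the parameter description (a list pairs with `classes` in order, a dictionary is keyed by class name). The helper `per_class_transforms` and its error handling are hypothetical and not part of `cellmap_data`.

```python
from collections.abc import Mapping, Sequence


def per_class_transforms(transforms, classes):
    """Normalize a callable, list, or dict of transforms to {class: transform}."""
    if transforms is None:
        return {c: None for c in classes}
    if isinstance(transforms, Mapping):
        missing = [c for c in classes if c not in transforms]
        if missing:
            raise KeyError(f"no transform given for classes: {missing}")
        return {c: transforms[c] for c in classes}
    if callable(transforms):
        # A single function is shared by every class.
        return {c: transforms for c in classes}
    if isinstance(transforms, Sequence):
        if len(transforms) != len(classes):
            raise ValueError("need exactly one transform per class")
        return dict(zip(classes, transforms))
    raise TypeError("unsupported type for target value transforms")


classes = ["mito", "er"]
shared = per_class_transforms(abs, classes)
by_order = per_class_transforms([abs, float], classes)
by_name = per_class_transforms({"er": float, "mito": abs}, classes)
print(by_order == by_name)
```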

reset_arrays(type: str = 'target') -> None

Sets the arrays for the dataset to return.

Parameters:

type (str)

Return type:

None

static empty() -> CellMapDataset

Creates an empty dataset.

Return type:

CellMapDataset