import copy
import os
import tempfile
from glob import glob
from typing import Any
import torch
import torchvision.transforms.v2 as T
from cellmap_data import CellMapDatasetWriter, CellMapImage
from cellmap_data.utils import (
array_has_singleton_dim,
is_array_2D,
permute_singleton_dimension,
)
from cellmap_data.transforms.augment import NaNtoNum
from tqdm import tqdm
from upath import UPath
from .config import CROP_NAME, PREDICTIONS_PATH, RAW_NAME, SEARCH_PATH
from .models import get_model
from .utils import (
load_safe_config,
get_test_crops,
get_test_crop_labels,
get_data_from_batch,
get_singleton_dim,
squeeze_singleton_dim,
structure_model_output,
unsqueeze_singleton_dim,
)
from .utils.datasplit import get_formatted_fields, get_raw_path
def predict_orthoplanes(
    model: torch.nn.Module, dataset_writer_kwargs: dict[str, Any], batch_size: int
):
    """
    Predict a full 3D volume by averaging 2D predictions made along the x, y,
    and z orthogonal planes (sometimes called 2.5D prediction).

    Parameters
    ----------
    model : torch.nn.Module
        The (2D) model to use for prediction.
    dataset_writer_kwargs : dict[str, Any]
        Arguments for the CellMapDatasetWriter. Must contain "classes",
        "input_arrays", and "target_arrays"; may contain "model_classes"
        (consumed by ``_predict``, not a writer argument).
    batch_size : int
        The batch size to use for prediction.
    """
    print("Predicting orthogonal planes.")
    # Make a temporary prediction for each axis. The context manager guarantees
    # the intermediate zarrs are removed even if a prediction raises.
    with tempfile.TemporaryDirectory() as tmp_dir:
        print(f"Temporary directory for predictions: {tmp_dir}")
        for axis in range(3):
            # Actually slice per axis by permuting the singleton dimension of
            # the array specs to the current axis.
            temp_kwargs = dataset_writer_kwargs.copy()
            temp_kwargs["target_path"] = os.path.join(
                tmp_dir, "output.zarr", str(axis)
            )
            # Copy input_arrays and target_arrays before permuting so the
            # caller's specs are left untouched.
            input_arrays = {k: v.copy() for k, v in temp_kwargs["input_arrays"].items()}
            target_arrays = {
                k: v.copy() for k, v in temp_kwargs["target_arrays"].items()
            }
            permute_singleton_dimension(input_arrays, axis)
            permute_singleton_dimension(target_arrays, axis)
            temp_kwargs["input_arrays"] = input_arrays
            temp_kwargs["target_arrays"] = target_arrays
            _predict(
                model,
                temp_kwargs,
                batch_size=batch_size,
            )
        # Get dataset writer for the average of predictions from x, y, and z
        # orthogonal planes. "model_classes" is prediction metadata consumed by
        # _predict and is not a CellMapDatasetWriter argument, so strip it here
        # (mirroring the filtering _predict does before building its writer).
        # TODO: Skip loading raw data
        writer_kwargs = {
            k: v for k, v in dataset_writer_kwargs.items() if k != "model_classes"
        }
        dataset_writer = CellMapDatasetWriter(**writer_kwargs)
        # Load the images for the individual single-axis predictions
        single_axis_images = {
            array_name: {
                label: [
                    CellMapImage(
                        os.path.join(tmp_dir, "output.zarr", str(axis), label),
                        target_class=label,
                        target_scale=array_info["scale"],
                        target_voxel_shape=array_info["shape"],
                        pad=True,
                        pad_value=0,
                    )
                    for axis in range(3)
                ]
                for label in dataset_writer_kwargs["classes"]
            }
            for array_name, array_info in dataset_writer_kwargs["target_arrays"].items()
        }
        # Combine the predictions from the x, y, and z orthogonal planes
        print("Combining predictions.")
        for batch in tqdm(
            dataset_writer.loader(batch_size=batch_size), dynamic_ncols=True
        ):
            # For each class, average the three single-axis predictions sampled
            # at the center of each requested block
            outputs = {}
            for array_name, images in single_axis_images.items():
                outputs[array_name] = {}
                for label in dataset_writer_kwargs["classes"]:
                    outputs[array_name][label] = []
                    for idx in batch["idx"]:
                        average_prediction = []
                        for image in images[label]:
                            average_prediction.append(
                                image[dataset_writer.get_center(idx)]
                            )
                        average_prediction = torch.stack(average_prediction).mean(dim=0)
                        outputs[array_name][label].append(average_prediction)
                    outputs[array_name][label] = torch.stack(outputs[array_name][label])
            # Save the averaged outputs
            dataset_writer[batch["idx"]] = outputs
def _filter_outputs(
    outputs: dict[str, Any],
    classes_to_save: list[str],
    model_class_to_index: dict[str, int],
) -> dict[str, Any]:
    """Restrict model outputs to the classes that should be written.

    ``outputs`` maps array names to either a dict of per-class tensors or a
    single (B, C, ...) tensor whose channel order follows the model classes.
    """
    filtered_outputs: dict[str, Any] = {}
    for array_name, class_outputs in outputs.items():
        if isinstance(class_outputs, dict):
            # Keep only the per-class entries that should be saved
            filtered_outputs[array_name] = {
                class_name: class_tensor
                for class_name, class_tensor in class_outputs.items()
                if class_name in classes_to_save
            }
        else:
            # Tensor of shape (B, C, ...): select only the channels for
            # classes_to_save. classes_to_save is a subset of model_classes by
            # construction, and the pre-computed mapping gives O(1) lookups.
            class_indices = [model_class_to_index[c] for c in classes_to_save]
            filtered_outputs[array_name] = class_outputs[:, class_indices, ...]
    return filtered_outputs


def _predict(
    model: torch.nn.Module, dataset_writer_kwargs: dict[str, Any], batch_size: int
):
    """
    Predicts the output of a model on a large dataset by splitting it into
    blocks and predicting each block separately.

    Parameters
    ----------
    model : torch.nn.Module
        The model to use for prediction.
    dataset_writer_kwargs : dict[str, Any]
        A dictionary containing the arguments for the dataset writer. May also
        contain "model_classes" (all classes the model was trained on), which
        is consumed here and stripped before the writer is constructed.
    batch_size : int
        The batch size to use for prediction.

    Raises
    ------
    ValueError
        If no classes are specified, or if the model's output channel count
        exceeds the number of model classes without being a multiple of it.
    """
    model.eval()
    # Work on a shallow copy so the caller's dict is never mutated (it may be
    # shared across multiple prediction calls, e.g. by predict_orthoplanes);
    # nested values that need modification are copied individually below.
    dataset_writer_kwargs = dict(dataset_writer_kwargs)
    device = dataset_writer_kwargs["device"]
    input_keys = list(dataset_writer_kwargs["input_arrays"].keys())
    if "classes" not in dataset_writer_kwargs or not dataset_writer_kwargs["classes"]:
        raise ValueError("No classes specified in dataset_writer_kwargs")
    # Classes the model outputs (all it was trained on) vs. the classes to
    # actually save (filtered by the test_crop_manifest).
    model_classes = dataset_writer_kwargs.get(
        "model_classes", dataset_writer_kwargs["classes"]
    )
    # Restrict classes_to_save to only those the model knows about
    classes_to_save = [
        c for c in dataset_writer_kwargs["classes"] if c in model_classes
    ]
    dataset_writer_kwargs["classes"] = classes_to_save
    if not classes_to_save:
        print("classes_to_save is empty. Nothing to predict. Skipping.")
        return
    # Map class name -> output channel index for efficient lookup during
    # filtering; None means no filtering is needed.
    model_class_to_index = (
        {c: i for i, c in enumerate(model_classes)}
        if model_classes != classes_to_save
        else None
    )
    # Run a single dummy batch through the model to detect the number of
    # output channels (per class).
    test_batch = {
        k: torch.rand((1, *info["shape"])).unsqueeze(0).to(device)
        for k, info in dataset_writer_kwargs["input_arrays"].items()
    }
    test_inputs = get_data_from_batch(test_batch, input_keys, device)
    # Find singleton dimension if there is one. Only the first singleton
    # dimension is handled for squeezing/unsqueezing.
    singleton_dim = get_singleton_dim(
        list(dataset_writer_kwargs["input_arrays"].values())[0]["shape"]
    )
    if singleton_dim is not None:
        # test_batch tensors are laid out (batch, channel, *spatial), the same
        # layout as the dataloader batches below, so the spatial singleton
        # index is offset by 2 — matching the main prediction loop.
        test_inputs = squeeze_singleton_dim(test_inputs, singleton_dim + 2)
    with torch.no_grad():
        test_outputs = model(test_inputs)
    model_returns_class_dict = False
    num_channels_per_class = None
    if isinstance(test_outputs, dict):
        if set(test_outputs.keys()) == set(model_classes):
            # Keys are the class names; values are already per-class tensors
            model_returns_class_dict = True
        else:
            # Dict with non-class keys (e.g., resolution levels): use the first
            # value tensor to detect the channel count
            test_outputs = next(iter(test_outputs.values()))
    if not model_returns_class_dict and test_outputs.shape[1] > len(model_classes):
        if test_outputs.shape[1] % len(model_classes) == 0:
            num_channels_per_class = test_outputs.shape[1] // len(model_classes)
            # To avoid mutating shared nested dicts, deep-copy target_arrays
            # and prepend the per-class channel count to each shape.
            target_arrays_copy = copy.deepcopy(dataset_writer_kwargs["target_arrays"])
            # Use the first input array's shape to determine expected spatial
            # rank (all input arrays should have the same spatial dimensions);
            # this is loop-invariant, so compute it once.
            first_input_key = next(iter(dataset_writer_kwargs["input_arrays"]))
            expected_spatial_rank = len(
                dataset_writer_kwargs["input_arrays"][first_input_key]["shape"]
            )
            for key in target_arrays_copy.keys():
                current_shape = target_arrays_copy[key]["shape"]
                # Only prepend the channel dimension if the shape doesn't
                # already include it (rank equals the spatial rank).
                if len(current_shape) == expected_spatial_rank:
                    target_arrays_copy[key]["shape"] = (
                        num_channels_per_class,
                        *current_shape,
                    )
            dataset_writer_kwargs["target_arrays"] = target_arrays_copy
        else:
            raise ValueError(
                f"Number of output channels ({test_outputs.shape[1]}) does not match number of "
                f"classes ({len(model_classes)}). Should be a multiple of the "
                "number of classes."
            )
    del test_batch, test_inputs, test_outputs
    if "raw_value_transforms" not in dataset_writer_kwargs:
        # Default raw transform: scale to float and replace NaNs with zeros.
        dataset_writer_kwargs["raw_value_transforms"] = T.Compose(
            [
                T.ToDtype(torch.float, scale=True),
                NaNtoNum({"nan": 0, "posinf": None, "neginf": None}),
            ],
        )
    # "model_classes" is prediction metadata, not a writer argument.
    dataset_writer_kwargs.pop("model_classes", None)
    dataset_writer = CellMapDatasetWriter(**dataset_writer_kwargs)
    dataloader = dataset_writer.loader(batch_size=batch_size)
    with torch.no_grad():
        for batch in tqdm(dataloader, dynamic_ncols=True):
            # Get the inputs, handling dict vs. tensor data
            inputs = get_data_from_batch(batch, input_keys, device)
            # Squeeze out the singleton spatial dim so 2D models see 2D input,
            # then restore it on the outputs before writing.
            if singleton_dim is not None:
                inputs = squeeze_singleton_dim(inputs, singleton_dim + 2)
            outputs = model(inputs)
            if singleton_dim is not None:
                outputs = unsqueeze_singleton_dim(outputs, singleton_dim + 2)
            outputs = structure_model_output(
                outputs,
                model_classes,
                num_channels_per_class,
            )
            # Filter outputs to only include the classes that should be saved
            if model_class_to_index is not None:
                outputs = _filter_outputs(
                    outputs, classes_to_save, model_class_to_index
                )
            # Save the outputs
            dataset_writer[batch["idx"]] = outputs
def predict(
    config_path: str,
    crops: str = "test",
    output_path: str = PREDICTIONS_PATH,
    do_orthoplanes: bool = True,
    overwrite: bool = False,
    search_path: str = SEARCH_PATH,
    raw_name: str = RAW_NAME,
    crop_name: str = CROP_NAME,
):
    """
    Given a model configuration file and list of crop numbers, predicts the output of a model on a large dataset by splitting it into blocks and predicting each block separately.

    Parameters
    ----------
    config_path : str
        The path to the model configuration file. This can be the same as the config file used for training.
    crops: str, optional
        A comma-separated list of crop numbers to predict on, or "test" to predict on the entire test set. Default is "test".
        When crops="test", only the labels specified in the test_crop_manifest for each crop will be saved.
        If a crop's test_crop_manifest specifies labels that the model wasn't trained on, those labels will be
        automatically filtered out (i.e., only the intersection of model classes and crop labels will be saved).
    output_path: str, optional
        The path to save the output predictions to, formatted as a string with a placeholders for the dataset, crop number, and label. Default is PREDICTIONS_PATH set in `cellmap-segmentation/config.py`.
    do_orthoplanes: bool, optional
        Whether to compute the average of predictions from x, y, and z orthogonal planes for the full 3D volume. This is sometimes called 2.5D predictions. It expects a model that yields 2D outputs. Similarly, it expects the input shape to the model to be 2D. Default is True for 2D models.
    overwrite: bool, optional
        Whether to overwrite the output dataset if it already exists. Default is False.
    search_path: str, optional
        The path to search for the raw dataset, with placeholders for dataset and name. Default is SEARCH_PATH set in `cellmap-segmentation/config.py`.
    raw_name: str, optional
        The name of the raw dataset. Default is RAW_NAME set in `cellmap-segmentation/config.py`.
    crop_name: str, optional
        The name of the crop dataset with placeholders for crop and label. Default is CROP_NAME set in `cellmap-segmentation/config.py`.

    Notes
    -----
    When crops="test", the function will only save predictions for labels that are specified
    in the test_crop_manifest for each specific crop AND that the model was trained on (the
    intersection of both sets). This ensures that only the labels that will be scored are saved,
    reducing storage requirements and processing time.
    """
    config = load_safe_config(config_path)
    classes = config.classes
    batch_size = getattr(config, "batch_size", 8)
    input_array_info = getattr(
        config, "input_array_info", {"shape": (1, 128, 128), "scale": (8, 8, 8)}
    )
    target_array_info = getattr(config, "target_array_info", input_array_info)
    value_transforms = getattr(
        config,
        "value_transforms",
        T.Compose(
            [
                T.ToDtype(torch.float, scale=True),
                NaNtoNum({"nan": 0, "posinf": None, "neginf": None}),
            ],
        ),
    )
    model = config.model
    # %% Check that the GPU is available; the config may pin a device explicitly
    if getattr(config, "device", None) is not None:
        device = config.device
    elif torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"
    print(f"Prediction device: {device}")
    # %% Move model to device
    model = model.to(device)
    # Optionally, load a pre-trained model
    checkpoint_epoch = get_model(config)
    if checkpoint_epoch is not None:
        print(f"Loaded model checkpoint from epoch: {checkpoint_epoch}")
    if do_orthoplanes and (
        array_has_singleton_dim(input_array_info)
        or is_array_2D(input_array_info, summary=any)
    ):
        # If the model is a 2D model, compute the average of predictions from
        # x, y, and z orthogonal planes
        predict_func = predict_orthoplanes
    elif is_array_2D(input_array_info, summary=any) or is_array_2D(
        target_array_info, summary=any
    ):
        # 2D model but orthoplanes disabled: move the singleton to the z axis
        # so predictions are made on z slices only.
        if is_array_2D(input_array_info, summary=any):
            permute_singleton_dimension(input_array_info, axis=0)
        if is_array_2D(target_array_info, summary=any):
            permute_singleton_dimension(target_array_info, axis=0)
        print(
            "Warning: Model appears to be 2D, but do_orthoplanes is set to False. Predictions will be made only on z slices."
        )
        predict_func = _predict
    else:
        predict_func = _predict
    assert (
        input_array_info is not None and target_array_info is not None
    ), "No array info provided"
    input_arrays = {"input": input_array_info}
    target_arrays = {"output": target_array_info}
    # Get the crops to predict on
    if crops == "test":
        test_crops = get_test_crops()
        dataset_writers = []
        for crop in test_crops:
            # Get path to raw dataset
            raw_path = search_path.format(dataset=crop.dataset, name=raw_name)
            # Get the boundaries of the crop (world coordinates per zyx axis)
            target_bounds = {
                "output": {
                    axis: [
                        crop.gt_source.translation[i],
                        crop.gt_source.translation[i]
                        + crop.gt_source.voxel_size[i] * crop.gt_source.shape[i],
                    ]
                    for i, axis in enumerate("zyx")
                },
            }
            # Get the labels that should be scored for this specific crop from
            # the test_crop_manifest
            crop_labels = get_test_crop_labels(crop.id)
            # Filter to only include labels that are in the model's classes
            filtered_classes = [c for c in classes if c in crop_labels]
            # If there are no matching labels between the model and this crop, skip it
            if not filtered_classes:
                tqdm.write(
                    f"Skipping crop {crop.id} (dataset={crop.dataset}) because there are "
                    f"no labels in common between model classes {classes} and crop labels {crop_labels}."
                )
                continue
            # Create the writer
            # Note: We pass all classes to the model for prediction, but only
            # the filtered classes will be saved by the CellMapDatasetWriter
            dataset_writers.append(
                {
                    "raw_path": raw_path,
                    "target_path": output_path.format(
                        crop=f"crop{crop.id}",
                        dataset=crop.dataset,
                    ),
                    "classes": filtered_classes,
                    "model_classes": classes,  # All classes the model was trained on
                    "input_arrays": input_arrays,
                    "target_arrays": target_arrays,
                    "target_bounds": target_bounds,
                    "overwrite": overwrite,
                    "device": device,
                    "raw_value_transforms": value_transforms,
                }
            )
    else:
        # Resolve each requested crop to its dataset path(s). A crop may exist
        # in multiple datasets (the glob uses dataset="*"), so keep each found
        # path explicitly paired with its crop name — a flat list zipped
        # against the crop list would misalign crops and paths.
        crop_pairs: list[tuple[str, str]] = []
        for crop in crops.split(","):
            if (isinstance(crop, str) and crop.isnumeric()) or isinstance(crop, int):
                crop = f"crop{crop}"
            found_paths = glob(
                search_path.format(
                    dataset="*", name=crop_name.format(crop=crop, label="")
                ).rstrip(os.path.sep)
            )
            if not found_paths:
                tqdm.write(f"No dataset found for {crop}; skipping.")
            crop_pairs.extend((crop, path) for path in found_paths)
        dataset_writers = []
        for crop, crop_path in crop_pairs:
            # Get path to raw dataset
            raw_path = get_raw_path(crop_path, label="")
            # Get the boundaries of the crop from the ground-truth images
            gt_images = {
                array_name: CellMapImage(
                    str(UPath(crop_path) / classes[0]),
                    target_class=classes[0],
                    target_scale=array_info["scale"],
                    target_voxel_shape=array_info["shape"],
                    pad=True,
                    pad_value=0,
                )
                for array_name, array_info in target_arrays.items()
            }
            target_bounds = {
                array_name: image.bounding_box
                for array_name, image in gt_images.items()
            }
            dataset = get_formatted_fields(raw_path, search_path, ["{dataset}"])[
                "dataset"
            ]
            # Create the writer
            dataset_writers.append(
                {
                    "raw_path": raw_path,
                    "target_path": output_path.format(crop=crop, dataset=dataset),
                    "classes": classes,
                    "input_arrays": input_arrays,
                    "target_arrays": target_arrays,
                    "target_bounds": target_bounds,
                    "overwrite": overwrite,
                    "device": device,
                    "raw_value_transforms": value_transforms,
                }
            )
    for dataset_writer in dataset_writers:
        predict_func(model, dataset_writer, batch_size)