from glob import glob
from importlib.machinery import SourceFileLoader
import tempfile
from time import time
import torch
# torch.multiprocessing.set_start_method("spawn", force=True)
from typing import Optional, Sequence
import os
import daisy
from funlib.persistence import open_ds, prepare_ds, Array
import numpy as np
from upath import UPath
import zarr
from cellmap_segmentation_challenge.utils.datasplit import (
CROP_NAME,
REPO_ROOT,
SEARCH_PATH,
get_raw_path,
)
def find_level(
    group, target_scale: Sequence[float]
) -> tuple[str, np.ndarray, np.ndarray]:
    """
    Finds the multiscale level that is closest to the target scale.

    The levels listed in ``group.attrs["multiscales"][0]["datasets"]`` are
    scanned in order. The scan stops at the first level whose scale exceeds the
    target scale along any axis, and the level just before it is returned. If
    the very first level already exceeds the target, that first level is
    returned; if no level exceeds it, the last level is returned. This relies
    on the levels being listed from finest to coarsest.

    Parameters
    ----------
    group : zarr.Group
        The group to search for the multiscale levels.
    target_scale : Sequence[float]
        The target scale to find the closest multiscale level to.

    Returns
    -------
    tuple[str, np.ndarray, np.ndarray]
        The name of the closest multiscale level, the offset of the level, and the scale of the level.
        The offset and scale always belong to the returned level. A level without a
        "translation" transform gets an all-zero offset. If the group lists no levels,
        ``(None, None, None)`` is returned.

    Raises
    ------
    ValueError
        If a level has no "scale" coordinate transformation.
    """
    best: tuple = (None, None, None)
    for level in group.attrs["multiscales"][0]["datasets"]:
        # Start from a clean slate so nothing leaks over from the previous level
        offset = None
        scale = None
        for transform in level["coordinateTransformations"]:
            if "translation" in transform:
                offset = np.array(transform["translation"])
            if "scale" in transform:
                scale = np.array(transform["scale"])
        if scale is None:
            raise ValueError(
                f"Multiscale level {level['path']!r} has no 'scale' transformation."
            )
        if offset is None:
            # No translation listed: treat the level as starting at the origin
            offset = np.zeros_like(scale)

        current = (level["path"], offset, scale)
        # Strict comparison: a level that matches the target exactly is kept
        if any(ts < s for ts, s in zip(target_scale, scale)):
            return current if best[0] is None else best
        best = current
    return best  # type: ignore
def get_crop_roi(crop_path: str, output_scale: Sequence[int]) -> str:
    """
    Get the ROI of a crop from the crop path, formatted as a string: [start1:end1,start2:end2,...].

    The multiscale level is chosen with `find_level`. Each start is the
    integer-truncated offset of that level, and each end is the start plus the
    level's array shape multiplied by its scale.

    Parameters
    ----------
    crop_path : str
        The path to the crop.
    output_scale : Sequence[int]
        The target scale used to choose which multiscale level of the crop to measure.

    Returns
    -------
    roi_str : str
        The ROI of the crop formatted as a string.
    """
    level_name, level_offset, level_scale = find_level(
        zarr.open(crop_path, mode="r"), output_scale  # type: ignore
    )
    level_array = zarr.open(UPath(crop_path) / level_name, mode="r")  # type: ignore

    begin = np.array(level_offset).astype(int)
    extent = np.array(level_array.shape) * level_scale
    end = (begin + extent).astype(int)

    spans = ",".join(f"{lo!s}:{hi!s}" for lo, hi in zip(begin, end))
    return f"[{spans}]"
def get_output_shape(
    model: torch.nn.Module, input_shape: Sequence[int]
) -> Sequence[int]:
    """
    Computes the output shape of a model given an input shape.

    A tensor of zeros is pushed through the model once. Singleton axes of
    `input_shape` are dropped first, then a batch axis and a channel axis are
    prepended, so a shape such as ``(1, 128, 128)`` is fed to the model as
    ``(1, 1, 128, 128)``.

    Parameters
    ----------
    model : torch.nn.Module
        The model to compute the output shape for.
    input_shape : Sequence[int]
        The input shape of the model.

    Returns
    -------
    Sequence[int]
        The output shape of the model, without the batch and channel axes.
    """
    input_tensor = torch.zeros(input_shape).squeeze()[None, None, ...]
    # A parameter-free model has no device to query; run it on the CPU
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    # Only the shape is needed, so skip building an autograd graph
    with torch.no_grad():
        output_tensor = model(input_tensor.to(device))
    return output_tensor.shape[2:]
def predict_ortho_planes(
    model: torch.nn.Module,
    in_dataset: str | os.PathLike,
    out_dataset: str | os.PathLike,
    input_block_shape: Sequence[int],
    channels: Sequence[str] | dict[str | int, str],
    roi: Optional[str] = None,
    output_voxel_size: Optional[Sequence[int]] = None,
    min_raw: float = 0,
    max_raw: float = 255,
) -> None:
    """
    Predicts the average 3D output of a 2D model on a large dataset by splitting it into blocks and predicting each block separately, then averaging the predictions from the x, y, and z orthogonal planes.

    Three intermediate predictions are written to a temporary zarr store, read
    back in full, averaged per label, and saved to `out_dataset`. The temporary
    store is removed when the function exits, including on error.

    Parameters
    ----------
    model : torch.nn.Module
        The model to use for prediction.
    in_dataset: str | os.PathLike
        The path to the input dataset.
    out_dataset: str | os.PathLike
        The path to the output dataset, not including the name of the channel.
    input_block_shape : Sequence[int]
        The shape of the input slices to use for prediction.
    channels : Sequence[str] | dict[str | int, str]
        The label classes to predict. The output will be saved in separate datasets for each label class. If multiple output channels belong to the same label class (such as for x,y,z affinities), they can be combined by indicating channel:label matches with a dictionary and specifying the channels as a string of the form "0-2". Example: {"0-2":"nuc"}. For single channel per label class predictions, the channels can be specified as strings in a list or tuple.
    roi : str, optional
        The region of interest to predict. If None, the entire dataset will be predicted.
        The format is a string of the form "[start1:end1,start2:end2,...]".
    output_voxel_size : Sequence[int | float], optional
        The voxel size of the output data. Default is the same as the input voxel size.
    min_raw : float, optional
        The minimum value of the raw data. Default is 0.
    max_raw : float, optional
        The maximum value of the raw data. Default is 255.
    """
    print("Predicting orthogonal planes.")

    # The context manager removes the scratch directory even if a prediction fails
    with tempfile.TemporaryDirectory() as tmp_dir:
        print(f"Temporary directory for predictions: {tmp_dir}")
        tmp_store = os.path.join(tmp_dir, "output.zarr")

        # Make a temporary prediction for each axis.
        # NOTE(review): every iteration passes identical inputs to `_predict`;
        # only the output path differs, so the three predictions are not taken
        # along different axes here. Confirm whether per-axis slicing is
        # expected to happen in `_predict`.
        for axis in range(3):
            _predict(
                model,
                in_dataset,
                os.path.join(tmp_store, str(axis)),
                input_block_shape,
                channels,
                roi,
                output_voxel_size,
                min_raw,
                max_raw,
            )

        # Combine the predictions from the x, y, and z orthogonal planes
        labels = channels if isinstance(channels, Sequence) else channels.values()
        for label in labels:
            axis_datasets = [
                open_ds(os.path.join(tmp_store, str(axis), label))
                for axis in range(3)
            ]
            combined_predictions = np.mean([ds[:] for ds in axis_datasets], axis=0)

            # Mirror the geometry of the intermediate datasets. The shape is
            # given in voxels (world-unit ROI shape divided by the voxel size),
            # matching how `_predict` creates its datasets.
            example_ds = axis_datasets[-1]
            dataset = prepare_ds(
                UPath(f"{out_dataset}/{label}").path,
                shape=example_ds.roi.shape / example_ds.voxel_size,
                offset=example_ds.roi.offset,
                voxel_size=example_ds.voxel_size,
                chunk_shape=example_ds.chunk_shape,
                dtype=example_ds.dtype,
            )
            dataset[:] = combined_predictions
def _predict(
    model: torch.nn.Module,
    in_dataset: str | os.PathLike,
    out_dataset: str | os.PathLike,
    input_block_shape: Sequence[int],
    channels: Sequence[str] | dict[str | int, str],
    roi: Optional[str] = None,
    output_voxel_size: Optional[Sequence[int]] = None,
    min_raw: float = 0,
    max_raw: float = 255,
) -> None:
    """
    Predicts the output of a model on a large dataset by splitting it into blocks
    and predicting each block separately.

    The model is put in eval mode and moved to the GPU when CUDA is available.
    Raw data is rescaled from ``[min_raw, max_raw]`` to ``[-1, 1]`` before being
    passed to the model. Blocks are processed sequentially in this process.

    Parameters
    ----------
    model : torch.nn.Module
        The model to use for prediction.
    in_dataset: str | os.PathLike
        The path to the input dataset.
    out_dataset: str | os.PathLike
        The path to the output dataset, not including the name of the channel.
    input_block_shape : Sequence[int]
        The shape of the input blocks to use for prediction. A 2D shape is
        treated as a single slice by prepending an axis of size 1.
    channels : Sequence[str] | dict[str | int, str]
        The label classes to predict. The output will be saved in separate datasets for each label class. If multiple output channels belong to the same label class (such as for x,y,z affinities), they can be combined by indicating channel:label matches with a dictionary and specifying the channels as a string of the form "0-2". Example: {"0-2":"nuc"}. For single channel per label class predictions, the channels can be specified as strings in a list or tuple.
    roi : str, optional
        The region of interest to predict. If None, the entire dataset will be predicted.
        The format is a string of the form "[start1:end1,start2:end2,...]".
    output_voxel_size : Sequence[int | float], optional
        The voxel size of the output data. Default is the same as the input voxel size.
    min_raw : float, optional
        The minimum value of the raw data. Default is 0.
    max_raw : float, optional
        The maximum value of the raw data. Default is 255.

    Raises
    ------
    ValueError
        If `min_raw` equals `max_raw` (the normalization would divide by zero),
        or if a channel key does not name at least one channel index.
    """
    if max_raw == min_raw:
        raise ValueError(
            "`min_raw` and `max_raw` must differ to normalize the raw data."
        )

    model.eval()
    if torch.cuda.is_available():
        model = model.cuda()
    # A parameter-free model has no device to query; run it on the CPU
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    shift = min_raw
    scale = max_raw - min_raw

    raw_dataset = open_ds(in_dataset)
    if output_voxel_size is None:
        output_voxel_size = raw_dataset.voxel_size
    output_voxel_size = daisy.Coordinate(output_voxel_size)

    # Blocks are always handled as 3D; 2D shapes become single slices
    if len(input_block_shape) == 2:
        input_block_shape = (1,) + tuple(input_block_shape)
    output_block_shape = get_output_shape(model, input_block_shape)
    if len(output_block_shape) == 2:
        output_block_shape = (1,) + tuple(output_block_shape)

    # Block geometry in world units. `context` is the margin by which the
    # input block extends beyond the output block on each side.
    read_shape = daisy.Coordinate(input_block_shape) * raw_dataset.voxel_size
    write_shape = daisy.Coordinate(output_block_shape) * output_voxel_size
    context = (read_shape - write_shape) / 2
    read_roi = daisy.Roi((0,) * read_shape.dims, read_shape)
    write_roi = read_roi.grow(-context, -context)

    if roi is not None:
        print(f"Predicting on ROI: {roi}")
        # Parse "[start1:end1,start2:end2,...]" into per-axis starts and ends
        parsed_start, parsed_end = zip(
            *[
                tuple(int(coord) for coord in axis.split(":"))
                for axis in roi.strip("[]").split(",")
            ]
        )
        parsed_roi = daisy.Roi(
            daisy.Coordinate(parsed_start),
            daisy.Coordinate(parsed_end) - daisy.Coordinate(parsed_start),
        )
        total_write_roi = parsed_roi.snap_to_grid(output_voxel_size)
        total_read_roi = total_write_roi.grow(context, context).snap_to_grid(
            raw_dataset.voxel_size, mode="grow"
        )
    else:
        total_read_roi = raw_dataset.roi
        total_write_roi = total_read_roi.grow(-context, -context).snap_to_grid(
            output_voxel_size
        )

    if isinstance(channels, Sequence):
        channels = {str(i): c for i, c in enumerate(channels)}

    # One output dataset per label. Channel keys are resolved to output-channel
    # indexes once here instead of once per block.
    write_targets = []
    for channel, label in channels.items():
        dataset = prepare_ds(
            UPath(f"{out_dataset}/{label}").path,
            shape=total_write_roi.shape / output_voxel_size,
            offset=total_write_roi.offset,
            voxel_size=output_voxel_size,
            chunk_shape=output_block_shape,
            dtype=np.float32,
            units=[
                "nm",
            ]
            * len(output_voxel_size),
        )
        write_targets.append((_parse_channel_indexes(channel), dataset))

    print(f"Predicting on {in_dataset} and saving to {out_dataset} using {device}.")

    def predict_worker(block):
        start_time = time()
        # `fill_value` equals `max_raw`; presumably it is used for voxels that
        # fall outside the raw dataset. TODO confirm against funlib.persistence.
        raw_block = raw_dataset.to_ndarray(
            roi=block.read_roi, fill_value=shift + scale
        ).astype(np.float32)
        # Rescale from [min_raw, max_raw] to [-1, 1]
        raw_input = (2.0 * (raw_block - shift) / scale) - 1.0
        print(f"Reading from {block.read_roi} in {time() - start_time:.2f} s.")

        # Drop singleton axes, then add the batch and channel axes
        squeezed = raw_input.squeeze()
        ndims = len(squeezed.shape)
        raw_input = squeezed[None, None, ...]

        block_write_roi = block.write_roi
        with torch.no_grad():
            output = (
                model(
                    torch.from_numpy(raw_input).to(device=device, dtype=torch.float32)
                )[0]
                .cpu()
                .numpy()
            )
        if ndims == 2:
            # Restore the slice axis dropped by the squeeze above
            output = output[:, None, ...]

        predictions = Array(
            output,
            block_write_roi.offset,
            output_voxel_size,
        )
        write_data = predictions.to_ndarray(block_write_roi)

        for indexes, target in write_targets:
            if len(indexes) > 1:
                target[block_write_roi] = np.stack(
                    [write_data[j] for j in indexes], axis=0
                )
            else:
                target[block_write_roi] = write_data[indexes[0]]

    task = daisy.Task(
        f"predict_{in_dataset}",
        total_roi=total_read_roi,
        read_roi=read_roi,
        write_roi=write_roi,
        process_function=predict_worker,
        check_function=None,
        read_write_conflict=False,
        num_workers=1,
        max_retries=0,
        timeout=None,
    )
    daisy.run_blockwise([task], multiprocessing=False)


def _parse_channel_indexes(channel: str | int) -> list[int]:
    """Expand a channel key such as ``3``, ``"3"`` or ``"0-2"`` into output channel indexes."""
    key = str(channel)
    if "-" in key:
        bounds = key.split("-")
        indexes = list(range(int(bounds[0]), int(bounds[1]) + 1))
    else:
        indexes = [int(key)]
    if not indexes:
        raise ValueError(f"Channel key {channel!r} does not name any channel index.")
    return indexes
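# NOTE: a stray "[docs]" link marker from the rendered documentation page was here; it is not part of the code.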
def predict(
    config_path: str,
    crops: str = "test",
    output_path: str = UPath(
        REPO_ROOT / "data/predictions/predictions.zarr/{crop}"
    ).path,
    do_orthoplanes: bool = True,
):
    """
    Given a model configuration file and list of crop numbers, predicts the output of a model on a large dataset by splitting it into blocks
    and predicting each block separately.

    Parameters
    ----------
    config_path : str
        The path to the model configuration file. This can be the same as the config file used for training.
        The file is executed as Python code, so only trusted files should be passed. It must define
        `model`, `input_array_info`, `target_array_info` and `classes`.
    crops: str, optional
        A comma-separated list of crop numbers to predict on, or "test" to predict on the entire test set. Default is "test".
    output_path: str, optional
        The path to save the output predictions to, formatted as a string with a placeholder for the crop name (`{crop}`). A `{label}` placeholder, if present, is filled with an empty string; the label name is appended to the path when each label's dataset is written. Default is "cellmap-segmentation-challenge/data/predictions/predictions.zarr/{crop}".
    do_orthoplanes: bool, optional
        Whether to compute the average of predictions from x, y, and z orthogonal planes for the full 3D volume. This is sometimes called 2.5D predictions. It expects a model that yields 2D outputs. Similarly, it expects the input shape to the model to be 2D. Only takes effect when the configured input shape has an axis of size 1. Default is True.

    Raises
    ------
    ValueError
        If `crops` is neither "test" nor a comma-separated list of numbers.
    FileNotFoundError
        If the raw dataset for a crop does not exist at the selected scale level.
    """
    config = SourceFileLoader(UPath(config_path).stem, str(config_path)).load_module()
    model = config.model
    input_block_shape = config.input_array_info["shape"]
    input_scale = config.input_array_info["scale"]
    output_scale = config.target_array_info["scale"]
    classes = config.classes

    if do_orthoplanes and any(s == 1 for s in input_block_shape):
        # If the model is a 2D model, compute the average of predictions from x, y, and z orthogonal planes
        predict_func = predict_ortho_planes
    else:
        predict_func = _predict

    # Get the crops to predict on
    if crops == "test":
        # TODO: Could make this more general to work for any class label
        raw_search_label = "test"
        crop_search_label = ""
        crop_paths = glob(
            SEARCH_PATH.format(
                dataset="*", name=CROP_NAME.format(crop="*", label="test")
            )
        )
    else:
        crop_list = [crop.strip() for crop in crops.split(",")]
        if not all(crop.isnumeric() for crop in crop_list):
            raise ValueError("Crop numbers must be numeric or `test`.")
        raw_search_label = ""
        crop_search_label = classes[0]
        crop_paths = []
        for crop in crop_list:
            crop_paths.extend(
                glob(
                    SEARCH_PATH.format(
                        dataset="*", name=CROP_NAME.format(crop=f"crop{crop}", label="")
                    ).rstrip(os.path.sep)
                )
            )

    # Resolve every crop before predicting so a bad crop fails early
    crop_args = []
    for crop_path in crop_paths:
        # Find raw scale level
        raw_path = get_raw_path(crop_path, label=raw_search_label)
        raw_group = zarr.open(raw_path, mode="r")
        scale_level, _, _ = find_level(raw_group, input_scale)
        in_dataset = UPath(raw_path) / scale_level
        if not in_dataset.exists():
            raise FileNotFoundError(f"Input dataset {in_dataset} does not exist.")

        roi = get_crop_roi(str(UPath(crop_path) / crop_search_label), output_scale)
        crop = UPath(crop_path).stem
        crop_args.append(
            {
                "in_dataset": str(in_dataset),
                "roi": roi,
                "out_dataset": output_path.format(crop=crop, label=""),
            }
        )

    for args in crop_args:
        predict_func(
            model,
            input_block_shape=input_block_shape,
            channels=classes,
            output_voxel_size=output_scale,
            **args,
        )
def _parse_block_shape(text: str) -> tuple[int, ...]:
    """Parse a block shape given on the command line, e.g. ``"128,128"`` or ``"(1, 128, 128)"``."""
    parts = [part.strip() for part in text.strip().strip("()[]").split(",")]
    shape = tuple(int(part) for part in parts if part)
    if not shape:
        raise ValueError(f"Could not parse a block shape from {text!r}.")
    return shape


def _parse_channels(text: str) -> list[str] | dict[str, str]:
    """
    Parse the channel specification given on the command line.

    Accepted forms:
    - ``"nuc,mito"``: one output channel per label, in order.
    - ``"0-2:nuc,3:mito"``: explicit ``channel:label`` pairs.
    - ``"{'0-2': 'nuc'}"`` or ``"['nuc', 'mito']"``: a Python literal.
    """
    import ast

    text = text.strip()
    if text.startswith(("{", "[", "(")):
        try:
            parsed = ast.literal_eval(text)
        except (SyntaxError, ValueError) as err:
            raise ValueError(f"Could not parse channels from {text!r}.") from err
        if isinstance(parsed, dict):
            return {str(key): str(label) for key, label in parsed.items()}
        if isinstance(parsed, (list, tuple)):
            return [str(label) for label in parsed]
        raise ValueError(f"Channels must be a list or a dict, got {text!r}.")

    entries = [entry.strip() for entry in text.split(",") if entry.strip()]
    if not entries:
        raise ValueError("At least one channel must be given.")
    paired = [":" in entry for entry in entries]
    if any(paired):
        if not all(paired):
            raise ValueError(
                "Either every channel entry or none must use the `channel:label` form."
            )
        return {
            key.strip(): label.strip()
            for key, label in (entry.split(":", 1) for entry in entries)
        }
    return entries


def main(argv=None) -> None:
    """
    Command-line entry point: load a model from a script and predict on one dataset.

    Parameters
    ----------
    argv : list of str, optional
        The command-line arguments to parse. Defaults to `sys.argv[1:]`.
    """
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model",
        "-m",
        type=str,
        required=True,
        help="Path to a script that will load the model. Often this can be the path to the training script (such as with the examples).",
    )
    parser.add_argument(
        "--in_dataset",
        "-in",
        type=str,
        required=True,
        help="Full path to the input dataset. This dataset should contain the raw data to predict on. Example: `/path/to/test_raw.zarr/em/s0`",
    )
    parser.add_argument(
        "--out_dataset",
        "-out",
        type=str,
        required=True,
        help="Path to the output dataset, minus the class label name(s). For example, the output dataset should be `/path/to/outputs.zarr/dataset_1`.",
    )
    parser.add_argument(
        "--input_block_shape",
        "-in_shape",
        type=_parse_block_shape,
        required=True,
        help="Shape of the input blocks to use for prediction, as comma-separated integers. Example: `128,128`.",
    )
    parser.add_argument(
        "--channels",
        "-ch",
        type=_parse_channels,
        required=True,
        help="The label classes to predict and their corresponding channels. Give either comma-separated labels (`nuc,mito`) or comma-separated `channel:label` pairs. Specify multiple channels for a single label class as a string of the form '0-2'. Example: `0-2:nuc` or {'0-2':'nuc'}.",
    )
    parser.add_argument(
        "-roi",
        type=str,
        default=None,
        help="Region of interest to predict. Default is to use the entire ROI of the input dataset. Format is a string of the form '[start1:end1,start2:end2,...]'.",
    )
    parser.add_argument(
        "--min_raw",
        "-min",
        type=float,
        default=0,
        help="Minimum value of the raw data. Default is 0.",
    )
    parser.add_argument(
        "--max_raw",
        "-max",
        type=float,
        default=255,
        help="Maximum value of the raw data. Default is 255.",
    )
    parser.add_argument(
        "--do_ortho_planes",
        "-ortho",
        action="store_true",
        help="Whether to compute the average of predictions from x, y, and z orthogonal planes for the full 3D volume. This is sometimes called 2.5D predictions. It expects a model that yields 2D outputs. Similarly, it expects the `input_shape` to be 2D (i.e. a sequence of 2 integers).",
    )
    args = parser.parse_args(argv)

    # The model script is executed as Python code and must define `model`
    model_script = SourceFileLoader(
        UPath(args.model).stem, str(args.model)
    ).load_module()
    model = model_script.model

    print(f"Predicting on dataset {args.in_dataset} and saving to {args.out_dataset}.")
    if args.do_ortho_planes:
        # If the model is a 2D model, compute the average of predictions from x, y, and z orthogonal planes
        if len(args.input_block_shape) > 2:
            raise ValueError(
                "The input shape must be 2D for computing orthogonal planes."
            )
        predict_func = predict_ortho_planes
    else:
        predict_func = _predict

    # Keyword arguments keep `roi`, `min_raw` and `max_raw` from sliding into
    # the wrong parameters (`output_voxel_size` sits between `roi` and `min_raw`).
    predict_func(
        model,
        args.in_dataset,
        args.out_dataset,
        args.input_block_shape,
        args.channels,
        roi=args.roi,
        min_raw=args.min_raw,
        max_raw=args.max_raw,
    )


if __name__ == "__main__":
    main()