Source code for cellmap_segmentation_challenge.predict

import os
import tempfile
from glob import glob
from typing import Any

import torch
import torchvision.transforms.v2 as T
from cellmap_data import CellMapDatasetWriter, CellMapImage
from cellmap_data.transforms.augment import NaNtoNum, Normalize
from tqdm import tqdm
from upath import UPath

from .config import CROP_NAME, PREDICTIONS_PATH, RAW_NAME, SEARCH_PATH
from .models import load_best_val, load_latest
from .utils import load_safe_config, get_test_crops
from .utils.datasplit import get_formatted_fields, get_raw_path


def predict_orthoplanes(
    model: torch.nn.Module, dataset_writer_kwargs: dict[str, Any], batch_size: int
):
    print("Predicting orthogonal planes.")

    # Make a temporary prediction for each axis
    tmp_dir = tempfile.TemporaryDirectory()
    print(f"Temporary directory for predictions: {tmp_dir.name}")
    for axis in range(3):
        temp_kwargs = dataset_writer_kwargs.copy()
        temp_kwargs["target_path"] = os.path.join(
            tmp_dir.name, "output.zarr", str(axis)
        )
        _predict(
            model,
            temp_kwargs,
            batch_size=batch_size,
        )

    # Get dataset writer for the average of predictions from x, y, and z orthogonal planes
    # TODO: Skip loading raw data
    dataset_writer = CellMapDatasetWriter(**dataset_writer_kwargs)

    # Load the images for the individual predictions
    single_axis_images = {
        array_name: {
            label: [
                CellMapImage(
                    os.path.join(tmp_dir.name, "output.zarr", str(axis), label),
                    target_class=label,
                    target_scale=array_info["scale"],
                    target_voxel_shape=array_info["shape"],
                    pad=True,
                    pad_value=0,
                )
                for axis in range(3)
            ]
            for label in dataset_writer_kwargs["classes"]
        }
        for array_name, array_info in dataset_writer_kwargs["target_arrays"].items()
    }

    # Combine the predictions from the x, y, and z orthogonal planes
    print("Combining predictions.")
    for batch in tqdm(
        dataset_writer.loader(batch_size=batch_size), dynamic_ncols=True
    ):
        # For each class, get the predictions from the x, y, and z orthogonal planes
        outputs = {}
        for array_name, images in single_axis_images.items():
            outputs[array_name] = {}
            for label in dataset_writer_kwargs["classes"]:
                outputs[array_name][label] = []
                for idx in batch["idx"]:
                    average_prediction = []
                    for image in images[label]:
                        average_prediction.append(
                            image[dataset_writer.get_center(idx)]
                        )
                    average_prediction = torch.stack(average_prediction).mean(dim=0)
                    outputs[array_name][label].append(average_prediction)
                outputs[array_name][label] = torch.stack(outputs[array_name][label])

        # Save the outputs
        dataset_writer[batch["idx"]] = outputs

    tmp_dir.cleanup()
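# Hedged sketch (not part of the original module): the combination step above
# reduces to a simple mean over the three single-axis predictions for each
# center. `_demo_orthoplane_mean` is a hypothetical helper illustrating that
# reduction with dummy tensors; the shapes are illustrative, not prescribed.
def _demo_orthoplane_mean():
    # Stand-ins for the three per-axis CellMapImage reads at one center
    per_axis = [torch.rand(1, 128, 128, 128) for _ in range(3)]
    # Stack along a new leading dim and average it away
    averaged = torch.stack(per_axis).mean(dim=0)
    assert averaged.shape == (1, 128, 128, 128)
    return averaged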
def _predict(
    model: torch.nn.Module, dataset_writer_kwargs: dict[str, Any], batch_size: int
):
    """
    Predicts the output of a model on a large dataset by splitting it into
    blocks and predicting each block separately.

    Parameters
    ----------
    model : torch.nn.Module
        The model to use for prediction.
    dataset_writer_kwargs : dict[str, Any]
        A dictionary containing the arguments for the dataset writer.
    batch_size : int
        The batch size to use for prediction.
    """
    value_transforms = T.Compose(
        [
            Normalize(),
            T.ToDtype(torch.float, scale=True),
            NaNtoNum({"nan": 0, "posinf": None, "neginf": None}),
        ],
    )

    dataset_writer = CellMapDatasetWriter(
        **dataset_writer_kwargs, raw_value_transforms=value_transforms
    )
    dataloader = dataset_writer.loader(batch_size=batch_size)

    model.eval()
    with torch.no_grad():
        for batch in tqdm(dataloader, dynamic_ncols=True):
            # Get the inputs and run the model once per batch
            inputs = batch["input"]
            outputs = {"output": model(inputs)}

            # Save the outputs
            dataset_writer[batch["idx"]] = outputs
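# Hedged sketch (not part of the original module): the raw-value transform
# stack used by `_predict`, applied to a standalone tensor. This assumes the
# cellmap_data transforms accept plain tensors outside a CellMapDatasetWriter;
# `_demo_value_transforms` is a hypothetical helper, not library API.
def _demo_value_transforms():
    value_transforms = T.Compose(
        [
            Normalize(),  # rescale raw intensities
            T.ToDtype(torch.float, scale=True),  # cast to float32
            NaNtoNum({"nan": 0, "posinf": None, "neginf": None}),  # scrub NaN/Inf
        ],
    )
    raw_block = torch.rand(1, 128, 128)  # stand-in for one raw input block
    return value_transforms(raw_block)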
def predict(
    config_path: str,
    crops: str = "test",
    output_path: str = PREDICTIONS_PATH,
    do_orthoplanes: bool = True,
    overwrite: bool = False,
    search_path: str = SEARCH_PATH,
    raw_name: str = RAW_NAME,
    crop_name: str = CROP_NAME,
):
    """
    Given a model configuration file and a list of crop numbers, predicts the
    output of a model on a large dataset by splitting it into blocks and
    predicting each block separately.

    Parameters
    ----------
    config_path : str
        The path to the model configuration file. This can be the same as the
        config file used for training.
    crops : str, optional
        A comma-separated list of crop numbers to predict on, or "test" to
        predict on the entire test set. Default is "test".
    output_path : str, optional
        The path to save the output predictions to, formatted as a string with
        placeholders for the dataset, crop number, and label. Default is
        PREDICTIONS_PATH set in `cellmap-segmentation/config.py`.
    do_orthoplanes : bool, optional
        Whether to compute the average of predictions from x, y, and z
        orthogonal planes for the full 3D volume. This is sometimes called
        2.5D prediction; it expects a model with a 2D input shape that yields
        2D outputs. Default is True.
    overwrite : bool, optional
        Whether to overwrite the output dataset if it already exists. Default
        is False.
    search_path : str, optional
        The path to search for the raw dataset, with placeholders for dataset
        and name. Default is SEARCH_PATH set in
        `cellmap-segmentation/config.py`.
    raw_name : str, optional
        The name of the raw dataset. Default is RAW_NAME set in
        `cellmap-segmentation/config.py`.
    crop_name : str, optional
        The name of the crop dataset with placeholders for crop and label.
        Default is CROP_NAME set in `cellmap-segmentation/config.py`.
""" config = load_safe_config(config_path) classes = config.classes batch_size = getattr(config, "batch_size", 8) input_array_info = getattr( config, "input_array_info", {"shape": (1, 128, 128), "scale": (8, 8, 8)} ) target_array_info = getattr(config, "target_array_info", input_array_info) model_name = getattr(config, "model_name", "2d_unet") model_to_load = getattr(config, "model_to_load", model_name) model = config.model load_model = getattr(config, "load_model", "latest") model_save_path = getattr( config, "model_save_path", UPath("checkpoints/{model_name}_{epoch}.pth").path ) logs_save_path = getattr( config, "logs_save_path", UPath("tensorboard/{model_name}").path ) # %% Check that the GPU is available if getattr(config, "device", None) is not None: device = config.device elif torch.cuda.is_available(): device = "cuda" elif torch.backends.mps.is_available(): device = "mps" else: device = "cpu" print(f"Prediction device: {device}") # %% Move model to device model = model.to(device) # Optionally, load a pre-trained model if load_model.lower() == "latest": # Check to see if there are any checkpoints and if so load the latest one # Use the command below for loading the latest model, otherwise comment it out load_latest(model_save_path.format(epoch="*", model_name=model_to_load), model) elif load_model.lower() == "best": # Load the checkpoint with the best validation score # Use the command below for loading the epoch with the best validation score, otherwise comment it out load_best_val( logs_save_path.format(model_name=model_to_load), model_save_path.format(epoch="{epoch}", model_name=model_to_load), model, ) if do_orthoplanes and any([s == 1 for s in input_array_info["shape"]]): # If the model is a 2D model, compute the average of predictions from x, y, and z orthogonal planes predict_func = predict_orthoplanes else: predict_func = _predict input_arrays = {"input": input_array_info} target_arrays = {"output": target_array_info} assert ( input_arrays is not None and target_arrays is not None ), "No array info provided" # Get the crops to predict on if crops == "test": test_crops = get_test_crops() dataset_writers = [] for crop in test_crops: # Get path to raw dataset raw_path = search_path.format(dataset=crop.dataset, name=raw_name) # Get the boundaries of the crop target_bounds = { "output": { axis: [ crop.gt_source.translation[i], crop.gt_source.translation[i] + crop.gt_source.voxel_size[i] * crop.gt_source.shape[i], ] for i, axis in enumerate("zyx") }, } # Create the writer dataset_writers.append( { "raw_path": raw_path, "target_path": output_path.format( crop=f"crop{crop.id}", dataset=crop.dataset, ), "classes": classes, "input_arrays": input_arrays, "target_arrays": target_arrays, "target_bounds": target_bounds, "overwrite": overwrite, "device": device, } ) else: crop_list = crops.split(",") crop_paths = [] for i, crop in enumerate(crop_list): if (isinstance(crop, str) and crop.isnumeric()) or isinstance(crop, int): crop = f"crop{crop}" crop_list[i] = crop # type: ignore crop_paths.extend( glob( search_path.format( dataset="*", name=crop_name.format(crop=crop, label="") ).rstrip(os.path.sep) ) ) dataset_writers = [] for crop, crop_path in zip(crop_list, crop_paths): # type: ignore # Get path to raw dataset raw_path = get_raw_path(crop_path, label="") # Get the boundaries of the crop gt_images = { array_name: CellMapImage( str(UPath(crop_path) / classes[0]), target_class=classes[0], target_scale=array_info["scale"], target_voxel_shape=array_info["shape"], pad=True, pad_value=0, ) 
                for array_name, array_info in target_arrays.items()
            }
            target_bounds = {
                array_name: image.bounding_box
                for array_name, image in gt_images.items()
            }

            dataset = get_formatted_fields(raw_path, search_path, ["{dataset}"])[
                "dataset"
            ]

            # Create the writer
            dataset_writers.append(
                {
                    "raw_path": raw_path,
                    "target_path": output_path.format(crop=crop, dataset=dataset),
                    "classes": classes,
                    "input_arrays": input_arrays,
                    "target_arrays": target_arrays,
                    "target_bounds": target_bounds,
                    "overwrite": overwrite,
                    "device": device,
                }
            )

    for dataset_writer in dataset_writers:
        predict_func(model, dataset_writer, batch_size)
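# Hedged usage sketch (not part of the original module): calling `predict`
# from a script. "train_config.py" is a hypothetical config path and the crop
# IDs are illustrative; only the keyword arguments shown in the signature
# above are assumed.
if __name__ == "__main__":
    # Predict on the full test set with orthoplane averaging (the default):
    predict("train_config.py", crops="test")
    # Or predict on specific crops, overwriting any existing output:
    # predict("train_config.py", crops="116,118", overwrite=True)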