# Source code for cellmap_segmentation_challenge.utils.visualize

from typing import Sequence
import torch
import os
import matplotlib.pyplot as plt
from cellmap_data.utils import get_image_dict
from upath import UPath


def save_result_figs(
    inputs: torch.Tensor,
    outputs: torch.Tensor,
    targets: torch.Tensor,
    classes: Sequence[str],
    figures_save_path: str,
) -> None:
    """
    Save the input, output, and target images to the specified path.

    The figures are built by ``get_image_dict`` and written one file per
    figure label. Every figure is closed afterwards, even if saving fails.

    Parameters
    ----------
    inputs : torch.Tensor
        The input images.
    outputs : torch.Tensor
        The output images.
    targets : torch.Tensor
        The target images.
    classes : Sequence[str]
        The classes present in the images.
    figures_save_path : str
        Path template for the saved figures. The ``{label}`` placeholder is
        filled in with each figure's label via ``str.format``. It may appear
        in the file name or in a directory component. Missing parent
        directories are created for each formatted path.

    Raises
    ------
    KeyError
        If the template contains ``str.format`` replacement fields other
        than ``{label}``.
    """
    figs = get_image_dict(inputs, outputs, targets, classes)
    try:
        for label, fig in figs.items():
            # Resolve the concrete destination first so the directory we
            # create is exactly the one savefig writes into. This also works
            # when ``{label}`` is part of a directory name.
            save_path = UPath(figures_save_path.format(label=label)).path
            save_dir = os.path.dirname(save_path)
            # dirname() is "" for a bare file name; os.makedirs("") would
            # raise FileNotFoundError, so only create a real directory.
            if save_dir:
                os.makedirs(save_dir, exist_ok=True)
            fig.savefig(save_path)
    finally:
        # Close every figure even if a save failed part-way, so matplotlib
        # does not accumulate open figures across calls.
        for fig in figs.values():
            plt.close(fig)