Source code for cellmap_data.transforms.targets.cellpose
import torch
[docs]
class CellposeFlow:
"""
Cellpose flow transform.
Args:
ndim (int): Number of dimensions.
device (str | None, optional): Device to use. Defaults to None
(use GPU if available, else CPU).
"""
def __init__(self, ndim: int, device: str | None = None) -> None:
UserWarning("This is still in development and may not work as expected")
from cellpose.dynamics import masks_to_flows_gpu_3d, masks_to_flows
from cellpose.dynamics import masks_to_flows_gpu as masks_to_flows_gpu_2d
self.ndim = ndim
if device is None:
if torch.cuda.is_available():
device = "cuda"
elif torch.backends.mps.is_available():
device = "mps"
else:
device = "cpu"
_device = torch.device(device)
if device == "cuda" or device == "mps":
if ndim == 3:
flows_func = lambda x: masks_to_flows_gpu_3d(x, device=_device)
elif ndim == 2:
flows_func = lambda x: masks_to_flows_gpu_2d(x, device=_device)
else:
raise ValueError(f"Unsupported dimension {ndim}")
else:
flows_func = lambda x: masks_to_flows(x, device=_device)
self.flows_func = flows_func
self.device = _device
def __call__(self, masks: torch.Tensor) -> torch.Tensor:
# flows, _ = masks_to_flows(
# (masks > 0).squeeze().numpy().astype(int), device=self.device
# )
flows, centers = self.flows_func( # type: ignore
(masks > 0).squeeze().cpu().numpy().astype(int)
)
flows = torch.tensor(flows)
flows[:, masks.isnan().squeeze()] = torch.nan
flows = flows[None, ...]
if self.ndim == 2:
flows = flows[None, ...]
return flows.to(masks.device) # type: ignore