Source code for cellmap_data.transforms.targets.distance
# from py_distance_transforms import transform_cuda, transform
import warnings

import torch
from scipy.ndimage import distance_transform_edt as edt
[docs]
def transform(x: torch.Tensor) -> torch.Tensor:
    """
    Compute the Euclidean distance transform of a tensor on the CPU.

    The tensor is moved to the CPU, converted to a NumPy array, passed through
    ``scipy.ndimage.distance_transform_edt`` and the result is moved back to
    the device of ``x``.

    Args:
        x (torch.Tensor): Input tensor; nonzero elements are treated as
            foreground by the SciPy routine.

    Returns:
        torch.Tensor: Distance of every element to the nearest zero element,
            on the same device as ``x``. The dtype is whatever SciPy returns
            (float64 with the default arguments used here).
    """
    # detach() is required because Tensor.numpy() refuses tensors that are
    # part of an autograd graph; the result carries no gradient either way.
    distances = edt(x.detach().cpu().numpy())
    return torch.tensor(distances).to(x.device)
[docs]
class DistanceTransform(torch.nn.Module):
    """
    Compute the distance transform of the input.

    The distance computation is delegated to the module-level ``transform``
    function and the result is clipped to ``clip``. Positions that are NaN in
    the input are NaN in the output.

    Attributes:
        use_cuda (bool): Use CUDA. Currently always False, because
            constructing with ``use_cuda=True`` raises.
        clip (list): Clip the output to the specified ``[min, max]`` range.

    Methods:
        _transform: Transform the input.
        forward: Forward pass.
    """

    def __init__(
        self, use_cuda: bool = False, clip=(-torch.inf, torch.inf)
    ) -> None:
        """
        Initialize the distance transform.

        Args:
            use_cuda (bool, optional): Use CUDA. Defaults to False.
            clip (sequence, optional): Clip the output to the specified
                ``[min, max]`` range. Defaults to (-torch.inf, torch.inf).

        Raises:
            NotImplementedError: CUDA is not supported yet.
        """
        super().__init__()
        # Actually emit the warning; merely instantiating UserWarning(...)
        # creates an exception object and silently discards it.
        warnings.warn(
            "This is still in development and may not work as expected",
            UserWarning,
            stacklevel=2,
        )
        self.use_cuda = use_cuda
        # Store a private copy so instances never share (or mutate) the
        # default or a caller-owned sequence.
        self.clip = list(clip)
        if self.use_cuda:
            raise NotImplementedError(
                "CUDA is not supported yet because testing did not return expected results."
            )

    def _transform(self, x: torch.Tensor) -> torch.Tensor:
        """
        Transform the input.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Distance transform of ``x`` clipped to ``self.clip``.

        Raises:
            NotImplementedError: If CUDA was requested and ``x`` is on a CUDA
                device.
        """
        if self.use_cuda and x.device.type == "cuda":
            raise NotImplementedError(
                "CUDA is not supported yet because testing did not return expected results."
            )
            # return transform_cuda(x)
        return transform(x).clip(self.clip[0], self.clip[1])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass.

        Args:
            x (torch.Tensor): Input tensor; may contain NaN.

        Returns:
            torch.Tensor: Clipped distance transform with NaN wherever ``x``
                is NaN.
        """
        # TODO: Need to figure out how to prevent having inaccurate distance values at the edges --> precompute
        # distance = self._transform(x.nan_to_num(0))
        distance = self._transform(x)
        # Propagate NaN positions of the input into the output.
        distance[x.isnan()] = torch.nan
        return distance
[docs]
class SignedDistanceTransform(torch.nn.Module):
    """
    Compute the signed distance transform of the input - positive within objects and negative outside.

    The signed distance is the distance transform of the input minus the
    distance transform of its logical negation, both computed with the
    module-level ``transform`` function. The signed result is clipped to
    ``clip``. Positions that are NaN in the input are NaN in the output.

    Attributes:
        use_cuda (bool): Use CUDA. Currently always False, because
            constructing with ``use_cuda=True`` raises.
        clip (list): Clip the output to the specified ``[min, max]`` range.

    Methods:
        _transform: Transform the input.
        forward: Forward pass.
    """

    def __init__(
        self, use_cuda: bool = False, clip=(-torch.inf, torch.inf)
    ) -> None:
        """
        Initialize the signed distance transform.

        Args:
            use_cuda (bool, optional): Use CUDA. Defaults to False.
            clip (sequence, optional): Clip the output to the specified
                ``[min, max]`` range. Defaults to (-torch.inf, torch.inf).

        Raises:
            NotImplementedError: CUDA is not supported yet.
        """
        super().__init__()
        # Actually emit the warning; merely instantiating UserWarning(...)
        # creates an exception object and silently discards it.
        warnings.warn(
            "This is still in development and may not work as expected",
            UserWarning,
            stacklevel=2,
        )
        self.use_cuda = use_cuda
        # Store a private copy so instances never share (or mutate) the
        # default or a caller-owned sequence.
        self.clip = list(clip)
        if self.use_cuda:
            raise NotImplementedError(
                "CUDA is not supported yet because testing did not return expected results."
            )

    def _transform(self, x: torch.Tensor) -> torch.Tensor:
        """
        Transform the input.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Signed distance transform of ``x`` clipped to
                ``self.clip``.

        Raises:
            NotImplementedError: If CUDA was requested and ``x`` is on a CUDA
                device.
        """
        if self.use_cuda and x.device.type == "cuda":
            raise NotImplementedError(
                "CUDA is not supported yet because testing did not return expected results."
            )
            # return transform_cuda(x) - transform_cuda(x.logical_not())
        # TODO: The original author flagged this formulation as needing a
        # correctness review.
        inside = transform(x)
        outside = transform(x.logical_not())
        # Clip once, after subtraction. Clipping both non-negative parts
        # separately lets the negated outside distance fall below the lower
        # bound whenever the range is not symmetric around zero.
        return (inside - outside).clip(self.clip[0], self.clip[1])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass.

        Args:
            x (torch.Tensor): Input tensor; may contain NaN.

        Returns:
            torch.Tensor: Clipped signed distance transform with NaN wherever
                ``x`` is NaN.
        """
        # TODO: Need to figure out how to prevent having inaccurate distance values at the edges --> precompute
        # distance = self._transform(x.nan_to_num(0))
        distance = self._transform(x)
        # Propagate NaN positions of the input into the output.
        distance[x.isnan()] = torch.nan
        return distance