"""Source code for cellmap_data.transforms.augment.nan_to_num."""
from typing import Any, Dict, Optional

import torchvision.transforms.v2 as T
[docs]
class NaNtoNum(T.Transform):
    """Replace NaN (and, by default, infinite) values in the input with finite numbers.

    Thin wrapper that applies ``nan_to_num`` to its input as a torchvision v2
    transform. Subclasses ``torchvision.transforms.v2.Transform``.

    With the default (empty) ``params``, ``torch.nan_to_num`` replaces NaN with
    zero, positive infinity with the largest finite value of the input's dtype,
    and negative infinity with the most negative finite value. Pass ``nan``,
    ``posinf`` and/or ``neginf`` in ``params`` to override these.

    Attributes:
        params (Dict[str, Any]): Keyword arguments forwarded to ``nan_to_num``.
            Defaults to ``{}``; see
            https://pytorch.org/docs/stable/generated/torch.nan_to_num.html
            for the accepted keys.

    Methods:
        _transform: Transform the input.
    """

    def __init__(self, params: Optional[Dict[str, Any]] = None) -> None:
        """Initialize the NaN to number transformation.

        Args:
            params (Optional[Dict[str, Any]]): Keyword arguments forwarded to
                ``nan_to_num``. ``None`` (the default) is treated as ``{}``,
                i.e. ``torch.nan_to_num``'s own defaults are used. See
                https://pytorch.org/docs/stable/generated/torch.nan_to_num.html
                for details.
        """
        super().__init__()
        # Store a private copy so that later mutation of the caller's dict
        # cannot silently change the behaviour of this transform.
        self.params = {} if params is None else dict(params)

    def _transform(self, x: Any, params: Dict[str, Any]) -> Any:
        """Apply ``nan_to_num`` to ``x``.

        Args:
            x: The input to transform. It must provide a ``nan_to_num`` method,
                e.g. a ``torch.Tensor``.
            params: Per-call parameters passed in by the torchvision v2
                ``Transform`` machinery. This transform ignores them and uses
                ``self.params`` instead.

        Returns:
            The result of ``x.nan_to_num(**self.params)``.
        """
        return x.nan_to_num(**self.params)