Source code for cellmap_data.transforms.augment.binarize

from typing import Any, Dict
import torchvision.transforms.v2 as T
import torch


[docs] class Binarize(T.Transform): """Binarize the input tensor. Subclasses torchvision.transforms.Transform. Methods: _transform: Transform the input. """ def __init__(self, threshold=0) -> None: """Initialize the normalization transformation.""" super().__init__() self.threshold = threshold def _transform(self, x: Any, params: Dict[str, Any] | None = None) -> Any: """Transform the input.""" out = (x > self.threshold).to(x.dtype) out[x.isnan()] *= torch.nan return out def __repr__(self) -> str: """Return a string representation of the transformation.""" return f"{self.__class__.__name__}(threshold={self.threshold})"
[docs] def transform(self, x: Any, params: Dict[str, Any] | None = None) -> Any: """Transform the input.""" return self._transform(x, params)