Source code for cellmap_data.transforms.augment.normalize

from typing import Any, Dict
import torchvision.transforms.v2 as T


class Normalize(T.Transform):
    """Min-max normalize the input tensor.

    Subclasses torchvision.transforms.v2.Transform.

    The minimum and maximum are taken over the whole input after
    ``nan_to_num()`` has been applied, while the rescaling itself is applied
    to the original values. NaN entries therefore do not turn the statistics
    into NaN, but they are still NaN in the returned tensor.

    Methods:
        _transform: Transform the input.
    """

    def __init__(self) -> None:
        """Initialize the normalization transformation."""
        super().__init__()

    def _transform(self, x: Any, params: Dict[str, Any]) -> Any:
        """Rescale ``x`` as ``(x - min) / (max - min)``.

        Args:
            x: Input to normalize. It must provide the tensor methods
                ``numel()`` and ``nan_to_num()`` (e.g. a ``torch.Tensor``).
            params: Transform parameters; not used by this transform.

        Returns:
            The rescaled input. ``x`` is returned unchanged when it has no
            elements or when its sanitized maximum equals its sanitized
            minimum, since no meaningful range exists in either case.
        """
        # Reductions over zero elements raise, so pass empty input through.
        if x.numel() == 0:
            return x

        # Sanitize once and reuse the result for both reductions instead of
        # materializing two separate copies of the tensor.
        # NOTE(review): the values substituted for NaN/inf take part in the
        # min/max (NaN becomes 0 with nan_to_num's defaults), so they can
        # widen the range - confirm this is the intended behavior.
        sanitized = x.nan_to_num()
        min_val = sanitized.min()
        diff = sanitized.max() - min_val

        # Constant input: dividing by a zero range would yield NaN/inf.
        if diff == 0:
            return x
        return (x - min_val) / diff