"""Source code for cellmap_data.transforms.augment.random_contrast."""

from typing import Sequence
import torch
from cellmap_data.utils import torch_max_value


class RandomContrast(torch.nn.Module):
    """
    Randomly change the contrast of the input.

    Each call draws a contrast ratio uniformly from ``contrast_range`` and
    blends the input with its mean along dimension 0:
    ``ratio * x + (1 - ratio) * x.mean(dim=0)``. Ratios below 1 pull values
    toward that mean (lower contrast); ratios above 1 push them away from it
    (higher contrast).

    Attributes:
        contrast_range (Sequence[float]): ``(low, high)`` bounds for the
            sampled contrast ratio.

    Methods:
        forward: Forward pass.
    """

    def __init__(self, contrast_range: Sequence[float] = (0.5, 1.5)) -> None:
        """
        Initialize the random contrast.

        Args:
            contrast_range (Sequence[float], optional): ``(low, high)`` range
                the contrast ratio is sampled from. Defaults to (0.5, 1.5).
        """
        super().__init__()
        self.contrast_range = contrast_range

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply a randomly sampled contrast change to ``x``.

        Args:
            x: Input tensor. It is not modified; integer dtypes are supported.

        Returns:
            A new tensor with the same dtype as ``x``, clamped to
            ``[0, torch_max_value(x.dtype)]``. NaN/inf entries of the input
            are treated as 0 before blending.
        """
        low, high = self.contrast_range[0], self.contrast_range[1]
        # torch.rand samples from [0, 1), so ratio lies in [low, high).
        ratio = float(torch.rand(1) * (high - low) + low)
        bound = torch_max_value(x.dtype)

        # Tensor.mean() rejects integer dtypes, so do the blend in floating point.
        work = x if x.is_floating_point() else x.to(torch.get_default_dtype())
        # Sanitize a copy (never the caller's tensor) *before* blending, so a
        # NaN cannot propagate through the mean into other elements and the
        # final cast never sees NaN (casting NaN to an integer dtype is undefined).
        work = torch.nan_to_num(work, nan=0.0, posinf=0.0, neginf=0.0)

        blended = ratio * work + (1.0 - ratio) * work.mean(dim=0, keepdim=True)
        return blended.clamp(0, bound).to(x.dtype)