Source code for cellmap_data.transforms.augment.gaussian_noise
import torch
[docs]
class GaussianNoise(torch.nn.Module):
    """
    Add Gaussian noise to the input. Subclasses torch.nn.Module.

    Attributes:
        mean (float): Mean of the noise.
        std (float): Standard deviation of the noise.

    Methods:
        forward: Forward pass.
    """

    def __init__(self, mean: float = 0.0, std: float = 1.0) -> None:
        """
        Initialize the Gaussian noise.

        Args:
            mean (float, optional): Mean of the noise. Defaults to 0.0.
            std (float, optional): Standard deviation of the noise. Defaults to 1.0.
        """
        super().__init__()
        self.mean = mean
        self.std = std

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Return ``x`` plus element-wise noise drawn from N(mean, std**2).

        The noise is sampled directly on ``x.device``. For floating-point
        inputs it is sampled in ``x.dtype`` so the output keeps the input's
        precision; for any other dtype it is sampled in torch's default float
        dtype and the sum follows standard type promotion.

        Args:
            x (torch.Tensor): Input tensor of any shape.

        Returns:
            torch.Tensor: A new tensor with the same shape as ``x``.
        """
        # Sample on the target device (no CPU generation + copy) and, for float
        # inputs, in the input's dtype (no silent float16 -> float32 promotion,
        # no float32-precision noise on float64 data).
        noise_dtype = x.dtype if x.is_floating_point() else torch.get_default_dtype()
        noise = torch.normal(
            mean=self.mean,
            std=self.std,
            size=x.shape,
            dtype=noise_dtype,
            device=x.device,
        )
        return x + noise

    def extra_repr(self) -> str:
        """Show the noise parameters in the module's ``repr``."""
        return f"mean={self.mean}, std={self.std}"