Source code for cellmap_data.transforms.augment.gaussian_blur

import torch
import torch.nn.functional as F


class GaussianBlur(torch.nn.Module):
    """Blur 2D or 3D tensors with a fixed, normalized Gaussian kernel.

    The blur is implemented as a grouped convolution (``groups=channels``) whose
    weights are one Gaussian kernel repeated for every channel and excluded from
    gradient computation. The convolution uses ``padding="same"`` with
    ``padding_mode="replicate"``, so the spatial size of the input is preserved.
    """

    def __init__(
        self, kernel_size: int = 3, sigma: float = 0.1, dim: int = 2, channels: int = 1
    ):
        """
        Initialize a Gaussian Blur module.

        Args:
            kernel_size (int): Size of the Gaussian kernel (should be odd).
            sigma (float): Standard deviation of the Gaussian distribution.
            dim (int): Dimensionality (2 or 3) for applying the blur.
            channels (int): Number of input channels (default is 1).

        Raises:
            AssertionError: If ``dim`` is not 2 or 3, or ``kernel_size`` is even.
        """
        super().__init__()
        assert dim in (2, 3), "Only 2D or 3D Gaussian blur is supported."
        assert kernel_size % 2 == 1, "Kernel size should be an odd number."

        self.kernel_size = kernel_size
        self.kernel_shape = (kernel_size,) * dim
        self.sigma = sigma
        self.dim = dim
        self.kernel = self._create_gaussian_kernel()

        self.conv = {2: torch.nn.Conv2d, 3: torch.nn.Conv3d}[dim](
            in_channels=channels,
            out_channels=channels,
            kernel_size=self.kernel_shape,
            bias=False,
            padding="same",  # Automatically pads to keep output size same as input
            groups=channels,  # Apply the same kernel to each channel independently
            padding_mode="replicate",  # Use 'replicate' padding to avoid artifacts
        )

        # Conv weights are laid out as (out_channels, in_channels / groups, *kernel),
        # so the single kernel is repeated once per channel.
        kernel = self.kernel.view(1, 1, *self.kernel_shape)
        kernel = kernel.repeat(channels, 1, *(1,) * self.dim)
        self.conv.weight.data = kernel
        self.conv.weight.requires_grad = False  # Freeze the kernel weights

    def _create_gaussian_kernel(self):
        """Create a Gaussian kernel for 2D or 3D convolution.

        Returns:
            torch.Tensor: Tensor of shape ``(kernel_size,) * dim`` whose entries
            sum to 1.
        """
        # Integer offsets from the kernel centre, e.g. [-1, 0, 1] for size 3.
        coords = torch.arange(self.kernel_size) - self.kernel_size // 2
        axes_coords = torch.meshgrid(*([coords] * self.dim), indexing="ij")
        # Squared distance from the centre, summed over all axes.
        squared_distance = torch.sum(
            torch.stack([coord**2 for coord in axes_coords]), dim=0
        )
        kernel = torch.exp(-squared_distance / (2 * self.sigma**2))
        kernel /= kernel.sum()  # Normalize
        return kernel

    def forward(self, x: torch.Tensor):
        """Apply Gaussian blur to the input tensor.

        Args:
            x (torch.Tensor): Either a spatial-only tensor with exactly ``dim``
                axes, or a tensor that already carries the leading axes the
                convolution expects. The input is cast to ``torch.float``.

        Returns:
            torch.Tensor: Blurred ``torch.float`` tensor. For a spatial-only
            input the result has the same shape as ``x``.
        """
        # Keep the convolution on the same device as the incoming data.
        self.conv.to(x.device, non_blocking=True)

        if x.ndim == self.dim:
            # Spatial-only input: add temporary batch and channel axes for the
            # convolution, then restore the caller's original shape. The
            # temporary channel axis has size 1, so this path expects a module
            # built with channels=1.
            out = self.conv(x[None, None].to(torch.float))
            return out.reshape(x.shape)

        return self.conv(x.to(torch.float))
# Example usage
# image_2d = torch.rand(4, 3, 128, 128)  # Batch of 2D images with 3 channels
# image_3d = torch.rand(2, 3, 32, 32, 32)  # Batch of 3D volumes with 3 channels
# blur_2d = GaussianBlur(kernel_size=5, sigma=1.0, dim=2, channels=3)
# blur_3d = GaussianBlur(kernel_size=5, sigma=1.0, dim=3, channels=3)
# blurred_2d = blur_2d(image_2d)
# blurred_3d = blur_3d(image_3d)