# Source code for cellmap_data.transforms.augment.gaussian_blur
import torch
import torch.nn.functional as F
# [docs]  (link marker left over from the rendered HTML source page)
class GaussianBlur(torch.nn.Module):
    """Per-channel Gaussian blur for batched 2D images or 3D volumes.

    One normalized Gaussian kernel is built at construction time and applied
    to every channel independently (grouped convolution with one group per
    channel). The convolution is padded by ``kernel_size // 2`` on every
    spatial axis, so the output has the same spatial size as the input.
    """

    def __init__(self, kernel_size: int = 3, sigma: float = 0.1, dim: int = 2):
        """
        Initialize a Gaussian Blur module.

        Args:
            kernel_size (int): Size of the Gaussian kernel (should be odd).
            sigma (float): Standard deviation of the Gaussian distribution.
                Must be non-zero. Only ``sigma**2`` enters the kernel, so the
                sign of ``sigma`` has no effect.
            dim (int): Dimensionality (2 or 3) for applying the blur.

        Raises:
            AssertionError: If ``dim`` is not 2 or 3, or if ``kernel_size``
                is even.
            ValueError: If ``sigma`` is zero.
        """
        super().__init__()
        assert dim in (2, 3), "Only 2D or 3D Gaussian blur is supported."
        assert kernel_size % 2 == 1, "Kernel size should be an odd number."
        if sigma == 0:
            # A zero sigma makes the exponent 0/0 at the kernel centre and
            # would silently produce an all-NaN kernel.
            raise ValueError("sigma must be non-zero.")
        self.kernel_size = kernel_size
        self.kernel_shape = (kernel_size,) * dim
        self.sigma = sigma
        self.dim = dim
        self.padding = kernel_size // 2
        # Registered as a buffer so that ``module.to(...)`` carries the kernel
        # along with the module; non-persistent so that ``state_dict()`` stays
        # empty, exactly as when the kernel was a plain attribute.
        self.register_buffer(
            "kernel", self._create_gaussian_kernel(), persistent=False
        )

    def _create_gaussian_kernel(self):
        """Create a Gaussian kernel for 2D or 3D convolution.

        Returns:
            torch.Tensor: Tensor of shape ``self.kernel_shape`` whose entries
            sum to 1.
        """
        # Integer offsets from the kernel centre, e.g. [-1, 0, 1] for size 3.
        coords = torch.arange(self.kernel_size) - self.kernel_size // 2
        grids = torch.meshgrid(*([coords] * self.dim), indexing="ij")
        squared_radius = torch.stack([grid**2 for grid in grids]).sum(dim=0)
        kernel = torch.exp(-squared_radius / (2 * self.sigma**2))
        return kernel / kernel.sum()  # Normalize

    def conv(self, x: torch.Tensor, kernel: torch.Tensor) -> torch.Tensor:
        """Run the grouped 2D/3D convolution of ``x`` with ``kernel``.

        This is a regular method rather than a lambda stored on the instance,
        so the module can be pickled.

        Args:
            x (torch.Tensor): Input whose axis 1 is the channel axis.
            kernel (torch.Tensor): Weight with one filter per input channel,
                shaped ``(channels, 1, *self.kernel_shape)``.

        Returns:
            torch.Tensor: The convolved tensor.
        """
        conv_fn = F.conv2d if self.dim == 2 else F.conv3d
        return conv_fn(x, kernel, padding=self.padding, groups=x.shape[1])

    def forward(self, x: torch.Tensor):
        """Apply Gaussian blur to the input tensor.

        Args:
            x (torch.Tensor): Batched input with channels on axis 1, i.e.
                ``(N, C, H, W)`` for ``dim=2`` or ``(N, C, D, H, W)`` for
                ``dim=3``.

        Returns:
            torch.Tensor: Blurred tensor produced by the padded grouped
            convolution.
        """
        kernel = self.kernel.to(x.device)
        if x.is_floating_point():
            # Match the input precision (float16/bfloat16/float64); without
            # this the convolution raises a dtype-mismatch error. Non-float
            # inputs are left alone so they behave as they did before.
            kernel = kernel.to(x.dtype)
        # Add batch and channel dimensions
        kernel = kernel.view(1, 1, *self.kernel_shape)
        # Repeat for all channels
        kernel = kernel.repeat(x.shape[1], 1, *([1] * self.dim))
        return self.conv(x, kernel)
if __name__ == "__main__":
    # Demo: blur one random batch for each supported dimensionality.

    # 2D case: 4 images, 3 channels, 128 x 128 pixels.
    blur_2d = GaussianBlur(dim=2, sigma=1.0, kernel_size=5)
    image_2d = torch.rand(4, 3, 128, 128)

    # 3D case: 2 volumes, 3 channels, 32 x 32 x 32 voxels.
    blur_3d = GaussianBlur(dim=3, sigma=1.0, kernel_size=5)
    image_3d = torch.rand(2, 3, 32, 32, 32)

    blurred_2d = blur_2d(image_2d)
    blurred_3d = blur_3d(image_3d)