Source code for cellmap_segmentation_challenge.utils.loss

import torch


[docs] class CellMapLossWrapper(torch.nn.modules.loss._Loss): """ Wrapper for any PyTorch loss function that is applied to the output of a model and the target. Because the target can contain NaN values, the loss function is applied only to the non-NaN values. This is done by multiplying the loss by a mask that is 1 where the target is not NaN and 0 where the target is NaN. The loss is then averaged across the non-NaN values. Parameters ---------- loss_fn : torch.nn.modules.loss._Loss or torch.nn.modules.loss._WeightedLoss The loss function to apply to the output and target. **kwargs Keyword arguments to pass to the loss function. """ def __init__( self, loss_fn: torch.nn.modules.loss._Loss | torch.nn.modules.loss._WeightedLoss, **kwargs, ): super().__init__() self.kwargs = kwargs self.kwargs["reduction"] = "none" self.loss_fn = loss_fn(**self.kwargs)
[docs] def forward(self, outputs: torch.Tensor, target: torch.Tensor): loss = self.loss_fn(outputs, target.nan_to_num(0)) loss = (loss * target.isnan().logical_not()).nanmean() return loss