From bdea2e86ab881127468f8cf7a39c69ebc64c731c Mon Sep 17 00:00:00 2001 From: Isaac Corley <22203655+isaaccorley@users.noreply.github.com> Date: Tue, 27 Feb 2024 07:59:00 -0600 Subject: [PATCH] Add classes/ignore_index to losses (#11) --- tests/test_losses.py | 13 +++++++++++++ torchseg/losses/focal.py | 17 +++++++++++------ torchseg/losses/jaccard.py | 27 +++++++++++++++++++++++++-- 3 files changed, 49 insertions(+), 8 deletions(-) diff --git a/tests/test_losses.py b/tests/test_losses.py index 941b7359..74d07b3b 100644 --- a/tests/test_losses.py +++ b/tests/test_losses.py @@ -5,6 +5,7 @@ import torchseg.losses._functional as F from torchseg.losses import ( DiceLoss, + FocalLoss, JaccardLoss, MCCLoss, SoftBCEWithLogitsLoss, @@ -333,3 +334,15 @@ def test_binary_mcc_loss(): loss = criterion(y_pred, y_true) assert float(loss) == pytest.approx(0.5, abs=eps) + + +@torch.no_grad() +@pytest.mark.parametrize("loss_fn", [DiceLoss, JaccardLoss, FocalLoss]) +@pytest.mark.parametrize("classes", [None, [1]]) +@pytest.mark.parametrize("ignore_index", [None, 0, -255]) +def test_classes_arg(loss_fn, classes, ignore_index): + criterion = loss_fn(mode="multiclass", classes=classes, ignore_index=ignore_index) + y_pred = torch.zeros(1, 2, 128, 128, dtype=torch.float) + y_pred[:, 0, ...] = 1.0 + y_true = torch.ones(1, 128, 128, dtype=torch.long) + criterion(y_pred, y_true) diff --git a/torchseg/losses/focal.py b/torchseg/losses/focal.py index 6677d47d..d0974483 100644 --- a/torchseg/losses/focal.py +++ b/torchseg/losses/focal.py @@ -18,6 +18,7 @@ def __init__( reduction: Optional[str] = "mean", normalized: bool = False, reduced_threshold: Optional[float] = None, + classes: Optional[list[int]] = None, ): """Compute Focal loss @@ -30,6 +31,8 @@ def __init__( normalized: Use normalized focal loss (https://arxiv.org/pdf/1909.07829.pdf) reduced_threshold: Switch to reduced focal loss. Note, when using this mode you should use `reduction="sum"`. + classes: List of classes that contribute in loss computation. + By default, all channels are included. Only supported in multiclass mode Shape - **y_pred** - torch.Tensor of shape (N, C, H, W) @@ -44,6 +47,7 @@ def __init__( self.mode = mode self.ignore_index = ignore_index + self.classes = classes self.focal_loss_fn = partial( focal_loss_with_logits, alpha=alpha, @@ -75,13 +79,14 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: not_ignored = y_true != self.ignore_index for cls in range(num_classes): - cls_y_true = (y_true == cls).long() - cls_y_pred = y_pred[:, cls, ...] + if self.classes is None or cls in self.classes: + cls_y_true = (y_true == cls).long() + cls_y_pred = y_pred[:, cls, ...] - if self.ignore_index is not None: - cls_y_true = cls_y_true[not_ignored] - cls_y_pred = cls_y_pred[not_ignored] + if self.ignore_index is not None: + cls_y_true = cls_y_true[not_ignored] + cls_y_pred = cls_y_pred[not_ignored] - loss += self.focal_loss_fn(cls_y_pred, cls_y_true) + loss += self.focal_loss_fn(cls_y_pred, cls_y_true) return loss diff --git a/torchseg/losses/jaccard.py b/torchseg/losses/jaccard.py index 22abe176..9aea149f 100644 --- a/torchseg/losses/jaccard.py +++ b/torchseg/losses/jaccard.py @@ -16,6 +16,7 @@ def __init__( log_loss: bool = False, from_logits: bool = True, smooth: float = 0.0, + ignore_index: Optional[int] = None, eps: float = 1e-7, ): """Jaccard loss for image segmentation task. @@ -29,6 +30,8 @@ def __init__( otherwise `1 - jaccard_coeff` from_logits: If True, assumes input is raw logits smooth: Smoothness constant for dice coefficient + ignore_index: Label that indicates ignored pixels + (does not contribute to loss) eps: A small epsilon for numerical stability to avoid zero division error (denominator will be always greater or equal to eps) @@ -53,6 +56,7 @@ def __init__( self.from_logits = from_logits self.smooth = smooth self.eps = eps + self.ignore_index = ignore_index self.log_loss = log_loss def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: @@ -76,17 +80,36 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: y_true = y_true.view(bs, 1, -1) y_pred = y_pred.view(bs, 1, -1) + if self.ignore_index is not None: + mask = y_true != self.ignore_index + y_pred = y_pred * mask + y_true = y_true * mask + if self.mode == MULTICLASS_MODE: y_true = y_true.view(bs, -1) y_pred = y_pred.view(bs, num_classes, -1) - y_true = F.one_hot(y_true, num_classes) # N,H*W -> N,H*W, C - y_true = y_true.permute(0, 2, 1) # H, C, H*W + if self.ignore_index is not None: + mask = y_true != self.ignore_index + y_pred = y_pred * mask.unsqueeze(1) + + y_true = F.one_hot( + (y_true * mask).to(torch.long), num_classes + ) # N,H*W -> N,H*W, C + y_true = y_true.permute(0, 2, 1) * mask.unsqueeze(1) # N, C, H*W + else: + y_true = F.one_hot(y_true, num_classes) # N,H*W -> N,H*W, C + y_true = y_true.permute(0, 2, 1) # H, C, H*W if self.mode == MULTILABEL_MODE: y_true = y_true.view(bs, num_classes, -1) y_pred = y_pred.view(bs, num_classes, -1) + if self.ignore_index is not None: + mask = y_true != self.ignore_index + y_pred = y_pred * mask + y_true = y_true * mask + scores = soft_jaccard_score( y_pred, y_true.type(y_pred.dtype),