Skip to content

Commit

Permalink
Add classes/ignore_index to losses (#11)
Browse files Browse the repository at this point in the history
  • Loading branch information
isaaccorley authored Feb 27, 2024
1 parent 3401b84 commit bdea2e8
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 8 deletions.
13 changes: 13 additions & 0 deletions tests/test_losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import torchseg.losses._functional as F
from torchseg.losses import (
DiceLoss,
FocalLoss,
JaccardLoss,
MCCLoss,
SoftBCEWithLogitsLoss,
Expand Down Expand Up @@ -333,3 +334,15 @@ def test_binary_mcc_loss():

loss = criterion(y_pred, y_true)
assert float(loss) == pytest.approx(0.5, abs=eps)


@torch.no_grad()
@pytest.mark.parametrize("loss_fn", [DiceLoss, JaccardLoss, FocalLoss])
@pytest.mark.parametrize("classes", [None, [1]])
@pytest.mark.parametrize("ignore_index", [None, 0, -255])
def test_classes_arg(loss_fn, classes, ignore_index):
criterion = loss_fn(mode="multiclass", classes=classes, ignore_index=ignore_index)
y_pred = torch.zeros(1, 2, 128, 128, dtype=torch.float)
y_pred[:, 0, ...] = 1.0
y_true = torch.ones(1, 128, 128, dtype=torch.long)
criterion(y_pred, y_true)
17 changes: 11 additions & 6 deletions torchseg/losses/focal.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def __init__(
reduction: Optional[str] = "mean",
normalized: bool = False,
reduced_threshold: Optional[float] = None,
classes: Optional[list[int]] = None,
):
"""Compute Focal loss
Expand All @@ -30,6 +31,8 @@ def __init__(
normalized: Use normalized focal loss (https://arxiv.org/pdf/1909.07829.pdf)
reduced_threshold: Switch to reduced focal loss.
Note, when using this mode you should use `reduction="sum"`.
classes: List of classes that contribute in loss computation.
By default, all channels are included. Only supported in multiclass mode
Shape
- **y_pred** - torch.Tensor of shape (N, C, H, W)
Expand All @@ -44,6 +47,7 @@ def __init__(

self.mode = mode
self.ignore_index = ignore_index
self.classes = classes
self.focal_loss_fn = partial(
focal_loss_with_logits,
alpha=alpha,
Expand Down Expand Up @@ -75,13 +79,14 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
not_ignored = y_true != self.ignore_index

for cls in range(num_classes):
cls_y_true = (y_true == cls).long()
cls_y_pred = y_pred[:, cls, ...]
if self.classes is None or cls in self.classes:
cls_y_true = (y_true == cls).long()
cls_y_pred = y_pred[:, cls, ...]

if self.ignore_index is not None:
cls_y_true = cls_y_true[not_ignored]
cls_y_pred = cls_y_pred[not_ignored]
if self.ignore_index is not None:
cls_y_true = cls_y_true[not_ignored]
cls_y_pred = cls_y_pred[not_ignored]

loss += self.focal_loss_fn(cls_y_pred, cls_y_true)
loss += self.focal_loss_fn(cls_y_pred, cls_y_true)

return loss
27 changes: 25 additions & 2 deletions torchseg/losses/jaccard.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def __init__(
log_loss: bool = False,
from_logits: bool = True,
smooth: float = 0.0,
ignore_index: Optional[int] = None,
eps: float = 1e-7,
):
"""Jaccard loss for image segmentation task.
Expand All @@ -29,6 +30,8 @@ def __init__(
otherwise `1 - jaccard_coeff`
from_logits: If True, assumes input is raw logits
smooth: Smoothness constant for dice coefficient
ignore_index: Label that indicates ignored pixels
(does not contribute to loss)
eps: A small epsilon for numerical stability to avoid zero division error
(denominator will be always greater or equal to eps)
Expand All @@ -53,6 +56,7 @@ def __init__(
self.from_logits = from_logits
self.smooth = smooth
self.eps = eps
self.ignore_index = ignore_index
self.log_loss = log_loss

def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
Expand All @@ -76,17 +80,36 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
y_true = y_true.view(bs, 1, -1)
y_pred = y_pred.view(bs, 1, -1)

if self.ignore_index is not None:
mask = y_true != self.ignore_index
y_pred = y_pred * mask
y_true = y_true * mask

if self.mode == MULTICLASS_MODE:
y_true = y_true.view(bs, -1)
y_pred = y_pred.view(bs, num_classes, -1)

y_true = F.one_hot(y_true, num_classes) # N,H*W -> N,H*W, C
y_true = y_true.permute(0, 2, 1) # H, C, H*W
if self.ignore_index is not None:
mask = y_true != self.ignore_index
y_pred = y_pred * mask.unsqueeze(1)

y_true = F.one_hot(
(y_true * mask).to(torch.long), num_classes
) # N,H*W -> N,H*W, C
y_true = y_true.permute(0, 2, 1) * mask.unsqueeze(1) # N, C, H*W
else:
y_true = F.one_hot(y_true, num_classes) # N,H*W -> N,H*W, C
y_true = y_true.permute(0, 2, 1) # H, C, H*W

if self.mode == MULTILABEL_MODE:
y_true = y_true.view(bs, num_classes, -1)
y_pred = y_pred.view(bs, num_classes, -1)

if self.ignore_index is not None:
mask = y_true != self.ignore_index
y_pred = y_pred * mask
y_true = y_true * mask

scores = soft_jaccard_score(
y_pred,
y_true.type(y_pred.dtype),
Expand Down

0 comments on commit bdea2e8

Please sign in to comment.