Skip to content

Commit

Permalink
Merge pull request #2 from AdeelH/ignore_index
Browse files Browse the repository at this point in the history
Fix ignore_index bug
  • Loading branch information
AdeelH authored Oct 12, 2020
2 parents de74657 + 22a239b commit f0d5355
Showing 1 changed file with 8 additions and 0 deletions.
8 changes: 8 additions & 0 deletions focal_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ def __init__(self,
self.nll_loss = nn.NLLLoss(
weight=alpha, reduction='none', ignore_index=ignore_index)

self.ignore_index = ignore_index

if reduction in ('mean', 'sum', 'none'):
self.reduction = reduction
else:
Expand All @@ -49,6 +51,12 @@ def forward(self, x: Tensor, y: Tensor) -> Tensor:
# (N, d1, d2, ..., dK) --> (N * d1 * ... * dK,)
y = y.view(-1)

unignored_mask = y != self.ignore_index
y = y[unignored_mask]
if len(y) == 0:
return 0.
x = x[unignored_mask]

# compute weighted cross entropy term: -alpha * log(pt)
log_p = F.log_softmax(x, dim=-1)
ce = self.nll_loss(log_p, y)
Expand Down

0 comments on commit f0d5355

Please sign in to comment.