Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add focal loss option for classification #9

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 29 additions & 6 deletions spanet/network/jet_reconstruction/jet_reconstruction_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,12 +186,35 @@ def add_classification_loss(
current_target = targets[key]

weight = None if not self.balance_classifications else self.classification_weights[key]
current_loss = F.cross_entropy(
current_prediction,
current_target,
ignore_index=-1,
weight=weight
)
if self.options.classification_focal_gamma == 0:
current_loss = F.cross_entropy(
current_prediction,
current_target,
ignore_index=-1,
weight=weight
)
else:
# From https://github.com/AdeelH/pytorch-multi-class-focal-loss/blob/master/focal_loss.py
log_p = F.log_softmax(current_prediction, dim=1)
ce = F.nll_loss(
log_p,
current_target,
ignore_index=-1,
weight=weight,
reduction='none'
)
# Get true class column from each row
all_rows = torch.arange(len(current_target))
log_pt = log_p[all_rows, current_target]
# Compute focal term: (1 - pt)^gamma
focal_term = (1 - log_pt.exp()) ** self.options.classification_focal_gamma
# Full loss: -alpha * ((1 - pt)^gamma) * log(pt)
if weight is None:
# Take mean
current_loss = torch.mean(focal_term * ce)
else:
# Divide by sum of class weights
current_loss = torch.sum(focal_term * ce) / weight[current_target].sum()

classification_terms.append(self.options.classification_loss_scale * current_loss)

Expand Down
3 changes: 3 additions & 0 deletions spanet/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,9 @@ def __init__(self, event_info_file: str = "", training_file: str = "", validatio
# Scalar term for classification Cross Entropy loss term
self.classification_loss_scale: float = 0.0

# Gamma exponent for classification focal loss. Setting it to 0.0 will disable focal loss and use regular cross-entropy.
self.classification_focal_gamma: float = 0.0

# Automatically balance loss terms using Jacobians.
self.balance_losses: bool = True

Expand Down