Skip to content

Commit

Permalink
Merge pull request #3601 from ZipRecruiter/mattb.fix.classifier-score…
Browse files Browse the repository at this point in the history
…-index-bug

fix: cast indices tensor to int to fix bug
  • Loading branch information
alanakbik authored Jan 27, 2025
2 parents 087e441 + 32c875b commit e2865f7
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion flair/nn/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -899,7 +899,9 @@ def predict(

if has_unknown_label:
has_any_unknown_label = True
scores = torch.index_select(scores, 0, torch.tensor(filtered_indices, device=flair.device))
scores = torch.index_select(
scores, 0, torch.tensor(filtered_indices, device=flair.device, dtype=torch.int32)
)

gold_labels = self._prepare_label_tensor([data_points[index] for index in filtered_indices])
overall_loss += self._calculate_loss(scores, gold_labels)[0]
Expand Down

0 comments on commit e2865f7

Please sign in to comment.