Skip to content

Commit

Permalink
fix: cast indices tensor to int to fix bug
Browse files Browse the repository at this point in the history
  • Loading branch information
MattGPT-ai committed Jan 25, 2025
1 parent 30974f2 commit 32c875b
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion flair/nn/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -899,7 +899,9 @@ def predict(

if has_unknown_label:
has_any_unknown_label = True
scores = torch.index_select(scores, 0, torch.tensor(filtered_indices, device=flair.device))
scores = torch.index_select(
scores, 0, torch.tensor(filtered_indices, device=flair.device, dtype=torch.int32)
)

gold_labels = self._prepare_label_tensor([data_points[index] for index in filtered_indices])
overall_loss += self._calculate_loss(scores, gold_labels)[0]
Expand Down

0 comments on commit 32c875b

Please sign in to comment.