From 32c875b2e8a981ebf39ce7c0a453d2eebbd1bee9 Mon Sep 17 00:00:00 2001 From: Matt Buchovecky Date: Fri, 24 Jan 2025 23:54:11 -0800 Subject: [PATCH] fix: cast indices tensor to int to fix bug --- flair/nn/model.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/flair/nn/model.py b/flair/nn/model.py index 54c52eba3e..04a9c24a3b 100644 --- a/flair/nn/model.py +++ b/flair/nn/model.py @@ -899,7 +899,9 @@ def predict( if has_unknown_label: has_any_unknown_label = True - scores = torch.index_select(scores, 0, torch.tensor(filtered_indices, device=flair.device)) + scores = torch.index_select( + scores, 0, torch.tensor(filtered_indices, device=flair.device, dtype=torch.int32) + ) gold_labels = self._prepare_label_tensor([data_points[index] for index in filtered_indices]) overall_loss += self._calculate_loss(scores, gold_labels)[0]