From 80a3ace0562f384834fe092d6bf7c0ba1cb6cdbe Mon Sep 17 00:00:00 2001 From: Mingxue Gu Date: Fri, 12 Jul 2024 03:41:37 +0000 Subject: [PATCH] fix data type --- scripts/utils/trans_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/utils/trans_utils.py b/scripts/utils/trans_utils.py index c6dc0c0..8aa4404 100644 --- a/scripts/utils/trans_utils.py +++ b/scripts/utils/trans_utils.py @@ -349,7 +349,7 @@ def __call__( pred += 0.5 # inplace mapping to avoid cloning pred for i in range(1, object_num + 1): frac = i + 0.5 - pred[pred == frac] = data["label_prompt"][i - 1].to(pred.dtype) + pred[pred == frac] = torch.tensor(data["label_prompt"][i - 1]).to(pred.dtype) pred[pred == 0.5] = 0.0 data[keys] = pred return data