From 92871d51daa18f5ce20961ec66a68eba139d6e40 Mon Sep 17 00:00:00 2001 From: paspf Date: Mon, 28 Oct 2024 15:58:15 +0100 Subject: [PATCH] Use PyTorchs.to() function for Tensor dtype conversion. --- segment_anything/predictor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/segment_anything/predictor.py b/segment_anything/predictor.py index 8a6e6d816..651092630 100644 --- a/segment_anything/predictor.py +++ b/segment_anything/predictor.py @@ -161,8 +161,8 @@ def predict( ) masks_np = masks[0].detach().cpu().numpy() - iou_predictions_np = iou_predictions[0].detach().cpu().numpy() - low_res_masks_np = low_res_masks[0].detach().cpu().numpy() + iou_predictions_np = iou_predictions[0].detach().cpu().to(torch.float32).numpy() + low_res_masks_np = low_res_masks[0].detach().cpu().to(torch.float32).numpy() return masks_np, iou_predictions_np, low_res_masks_np @torch.no_grad()