From 192a06a9fcf1b134680de89c82aa03a7521d0eeb Mon Sep 17 00:00:00 2001 From: jamesvrt Date: Thu, 14 Dec 2023 16:07:23 -0800 Subject: [PATCH] Fix: `utils.true1d_idx` should return `np.ndarray` --- torchcast/internals/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchcast/internals/utils.py b/torchcast/internals/utils.py index ed15690..ab8572f 100644 --- a/torchcast/internals/utils.py +++ b/torchcast/internals/utils.py @@ -142,7 +142,7 @@ def true1d_idx(arr: Union[np.ndarray, torch.Tensor]) -> np.ndarray: arr = arr.bool() if len(arr.shape) > 1: raise ValueError("Expected 1d array.") - return arr.nonzero(as_tuple=True)[0] + return arr.nonzero(as_tuple=True)[0].numpy() def is_near_zero(tens: torch.Tensor, rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bool = False) -> torch.Tensor: