diff --git a/torchcast/internals/utils.py b/torchcast/internals/utils.py index ed15690..6c8258d 100644 --- a/torchcast/internals/utils.py +++ b/torchcast/internals/utils.py @@ -136,7 +136,7 @@ def ragged_cat(tensors: Sequence[torch.Tensor], @torch.no_grad() -def true1d_idx(arr: Union[np.ndarray, torch.Tensor]) -> np.ndarray: +def true1d_idx(arr: Union[np.ndarray, torch.Tensor]) -> torch.Tensor: if not isinstance(arr, torch.Tensor): arr = torch.as_tensor(arr) arr = arr.bool()