diff --git a/pyproject.toml b/pyproject.toml index 6da7c77..2d5c234 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,7 +14,7 @@ authors = [{name = "Francesco Bruzzesi"}] dependencies = [ "numpy", - "pandas", + "narwhals>=0.7.12", "typing-extensions>=4.4.0; python_version < '3.11'", ] @@ -37,6 +37,7 @@ issue-tracker = "https://github.com/fbruzzesi/timebasedcv/issues" [project.optional-dependencies] polars = ["polars"] +pandas = ["pandas"] dev = [ "pre-commit==2.21.0", @@ -62,8 +63,8 @@ test = [ "scikit-learn>=0.19", ] -all = ["timebasedcv[polars]"] -all-dev = ["timebasedcv[dev,docs,lint,polars,test]"] +all = ["timebasedcv[pandas,polars]"] +all-dev = ["timebasedcv[dev,docs,lint,pandas,polars,test]"] [tool.hatch.build.targets.sdist] only-include = ["timebasedcv"] diff --git a/tests/utils/backends_test.py b/tests/utils/backends_test.py index ab9f65f..180b513 100644 --- a/tests/utils/backends_test.py +++ b/tests/utils/backends_test.py @@ -3,6 +3,7 @@ import numpy as np import pandas as pd import pytest +import narwhals as nw from timebasedcv.utils._backends import ( BACKEND_TO_INDEXING_METHOD, @@ -49,6 +50,8 @@ def test_backend_to_indexing_method(arr, mask, expected): """ Tests the `BACKEND_TO_INDEXING_METHOD` dictionary with different backends. """ + arr = nw.from_native(arr, allow_series=True, eager_only=True, strict=False) + mask = nw.from_native(mask, series_only=True, strict=False) _type = str(type(arr)) result = BACKEND_TO_INDEXING_METHOD[_type](arr, mask) assert np.array_equal(result, expected) diff --git a/timebasedcv/splitstate.py b/timebasedcv/splitstate.py index 10034a8..e4ec0cb 100644 --- a/timebasedcv/splitstate.py +++ b/timebasedcv/splitstate.py @@ -4,8 +4,6 @@ from operator import le as less_or_equal from typing import Generic -import pandas as pd - from timebasedcv.utils._funcs import pairwise, pairwise_comparison from timebasedcv.utils._types import DateTimeLike @@ -14,6 +12,8 @@ else: from typing_extensions import Self # pragma: no cover +from narwhals.dependencies import get_pandas + @dataclass(frozen=True) class SplitState(Generic[DateTimeLike]): @@ -57,10 +57,12 @@ def __post_init__(self: Self) -> None: _values = tuple(getattr(self, _attr) for _attr in _slots) _types = tuple(type(_value) for _value in _values) + pd = get_pandas() + if not ( all(_type is datetime for _type in _types) or all(_type is date for _type in _types) - or all(_type is pd.Timestamp for _type in _types) + or (pd is not None and all(_type is pd.Timestamp for _type in _types)) ): # cfr: https://stackoverflow.com/questions/16991948/detect-if-a-variable-is-a-datetime-object raise TypeError("All attributes must be of type `datetime`, `date` or `pd.Timestamp`.") diff --git a/timebasedcv/timebasedsplit.py b/timebasedcv/timebasedsplit.py index 1b48507..00dd69e 100644 --- a/timebasedcv/timebasedsplit.py +++ b/timebasedcv/timebasedsplit.py @@ -2,6 +2,7 @@ from datetime import timedelta from itertools import chain from typing import Generator, Tuple, Union, get_args +import narwhals as nw import numpy as np @@ -370,7 +371,9 @@ def split( ts_shape = time_series.shape if len(ts_shape) != 1: raise ValueError(f"Time series must be 1-dimensional. Got {len(ts_shape)} dimensions.") - + + arrays = tuple([nw.from_native(array, eager_only=True, allow_series=True, strict=False) for array in arrays]) + time_series = nw.from_native(time_series, series_only=True, strict=False) a0 = arrays[0] arr_len = a0.shape[0] @@ -393,7 +396,7 @@ def split( train_forecast_arrays = tuple( chain.from_iterable( - (_idx_method(_arr, train_mask), _idx_method(_arr, forecast_mask)) + (nw.to_native(_idx_method(_arr, train_mask), strict=False), nw.to_native(_idx_method(_arr, forecast_mask), strict=False)) for _arr, _idx_method in zip(arrays, _index_methods) ) ) diff --git a/timebasedcv/utils/_backends.py b/timebasedcv/utils/_backends.py index c511cde..ecd4ec0 100644 --- a/timebasedcv/utils/_backends.py +++ b/timebasedcv/utils/_backends.py @@ -1,7 +1,7 @@ from typing import Callable, Dict, TypeVar import numpy as np -import pandas as pd +import narwhals as nw def default_indexing_method(arr, mask): @@ -19,45 +19,21 @@ def default_indexing_method(arr, mask): return arr[mask] -T_PD = TypeVar("T_PD", pd.DataFrame, pd.Series) +T_NW = TypeVar("T_NW", nw.DataFrame, nw.Series) -def pd_indexing_method(_dfs: T_PD, mask) -> T_PD: - """Indexing method for pandas dataframes and series. +def nw_indexing_method(_dfs: T_NW, mask) -> T_NW: + """Indexing method for Narwhals dataframes and series. Arguments: - df: The pandas dataframe or series to index. + df: The Narwhals dataframe or series to index. mask: The boolean mask to use for indexing. """ - return _dfs.loc[mask] + return _dfs.filter(mask) BACKEND_TO_INDEXING_METHOD: Dict[str, Callable] = { str(np.ndarray): default_indexing_method, - str(pd.DataFrame): pd_indexing_method, - str(pd.Series): pd_indexing_method, + str(nw.DataFrame): nw_indexing_method, + str(nw.Series): nw_indexing_method, } - -try: - import polars as pl - - T_PL = TypeVar("T_PL", pl.DataFrame, pl.Series) - - def pl_indexing_method(_dfs: T_PL, mask) -> T_PL: - """Indexing method for polars dataframes and series. - - Arguments: - _dfs: The polars dataframe or series to index. - mask: The boolean mask to use for indexing. - """ - return _dfs.filter(mask) - - BACKEND_TO_INDEXING_METHOD.update( - { - str(pl.DataFrame): pl_indexing_method, - str(pl.Series): pl_indexing_method, - } - ) - -except ImportError: - pass diff --git a/timebasedcv/utils/_types.py b/timebasedcv/utils/_types.py index 901ee34..d3fb96f 100644 --- a/timebasedcv/utils/_types.py +++ b/timebasedcv/utils/_types.py @@ -2,9 +2,7 @@ import sys from datetime import date, datetime -from typing import Literal, Protocol, Tuple, TypeVar, Union - -import pandas as pd +from typing import Literal, Protocol, Tuple, TypeVar, Union, TYPE_CHECKING if sys.version_info >= (3, 10): from typing import TypeAlias # pragma: no cover @@ -16,7 +14,11 @@ else: from typing_extensions import Self # pragma: no cover -DateTimeLike = TypeVar("DateTimeLike", datetime, date, pd.Timestamp) +if TYPE_CHECKING: + import pandas as pd + + +DateTimeLike = TypeVar("DateTimeLike", datetime, date, "pd.Timestamp") FrequencyUnit: TypeAlias = Literal["days", "seconds", "microseconds", "milliseconds", "minutes", "hours", "weeks"] WindowType: TypeAlias = Literal["rolling", "expanding"]