Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

try using narwhals to make dataframe-agnostic #38

Merged
merged 2 commits into from
Apr 30, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ authors = [{name = "Francesco Bruzzesi"}]

dependencies = [
"numpy",
"pandas",
"narwhals>=0.7.14",
"typing-extensions>=4.4.0; python_version < '3.11'",
]

Expand All @@ -37,6 +37,7 @@ issue-tracker = "https://github.com/fbruzzesi/timebasedcv/issues"

[project.optional-dependencies]
polars = ["polars"]
pandas = ["pandas"]
Copy link
Owner

@FBruzzesi FBruzzesi Apr 30, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are there min versions required by narwhals? It seems it doesn't from pyproject.toml

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks!

I'll test this backwards and see how far we can get, but .filter I'd expect should go back essentially forever

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would pandas>=1.2.0 and Polars>=0.20.3 work for you? I've set those as the minimum, and have tested with the minimum versions

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM


dev = [
"pre-commit==2.21.0",
Expand All @@ -62,8 +63,8 @@ test = [
"scikit-learn>=0.19",
]

all = ["timebasedcv[polars]"]
all-dev = ["timebasedcv[dev,docs,lint,polars,test]"]
all = ["timebasedcv[pandas,polars]"]
all-dev = ["timebasedcv[dev,docs,lint,pandas,polars,test]"]

[tool.hatch.build.targets.sdist]
only-include = ["timebasedcv"]
Expand Down
3 changes: 3 additions & 0 deletions tests/utils/backends_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import numpy as np
import pandas as pd
import pytest
import narwhals as nw

from timebasedcv.utils._backends import (
BACKEND_TO_INDEXING_METHOD,
Expand Down Expand Up @@ -49,6 +50,8 @@ def test_backend_to_indexing_method(arr, mask, expected):
"""
Tests the `BACKEND_TO_INDEXING_METHOD` dictionary with different backends.
"""
arr = nw.from_native(arr, allow_series=True, eager_only=True, strict=False)
mask = nw.from_native(mask, series_only=True, strict=False)
_type = str(type(arr))
result = BACKEND_TO_INDEXING_METHOD[_type](arr, mask)
assert np.array_equal(result, expected)
8 changes: 5 additions & 3 deletions timebasedcv/splitstate.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
from operator import le as less_or_equal
from typing import Generic

import pandas as pd

from timebasedcv.utils._funcs import pairwise, pairwise_comparison
from timebasedcv.utils._types import DateTimeLike

Expand All @@ -14,6 +12,8 @@
else:
from typing_extensions import Self # pragma: no cover

from narwhals.dependencies import get_pandas


@dataclass(frozen=True)
class SplitState(Generic[DateTimeLike]):
Expand Down Expand Up @@ -57,10 +57,12 @@ def __post_init__(self: Self) -> None:
_values = tuple(getattr(self, _attr) for _attr in _slots)
_types = tuple(type(_value) for _value in _values)

pd = get_pandas()

if not (
all(_type is datetime for _type in _types)
or all(_type is date for _type in _types)
or all(_type is pd.Timestamp for _type in _types)
or (pd is not None and all(_type is pd.Timestamp for _type in _types))
):
# cfr: https://stackoverflow.com/questions/16991948/detect-if-a-variable-is-a-datetime-object
raise TypeError("All attributes must be of type `datetime`, `date` or `pd.Timestamp`.")
Expand Down
7 changes: 5 additions & 2 deletions timebasedcv/timebasedsplit.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from datetime import timedelta
from itertools import chain
from typing import Generator, Tuple, Union, get_args
import narwhals as nw

import numpy as np

Expand Down Expand Up @@ -370,7 +371,9 @@ def split(
ts_shape = time_series.shape
if len(ts_shape) != 1:
raise ValueError(f"Time series must be 1-dimensional. Got {len(ts_shape)} dimensions.")


arrays = tuple([nw.from_native(array, eager_only=True, allow_series=True, strict=False) for array in arrays])
time_series = nw.from_native(time_series, series_only=True, strict=False)
a0 = arrays[0]
arr_len = a0.shape[0]

Expand All @@ -393,7 +396,7 @@ def split(

train_forecast_arrays = tuple(
chain.from_iterable(
(_idx_method(_arr, train_mask), _idx_method(_arr, forecast_mask))
(nw.to_native(_idx_method(_arr, train_mask), strict=False), nw.to_native(_idx_method(_arr, forecast_mask), strict=False))
for _arr, _idx_method in zip(arrays, _index_methods)
)
)
Expand Down
40 changes: 8 additions & 32 deletions timebasedcv/utils/_backends.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Callable, Dict, TypeVar

import numpy as np
import pandas as pd
import narwhals as nw


def default_indexing_method(arr, mask):
Expand All @@ -19,45 +19,21 @@ def default_indexing_method(arr, mask):
return arr[mask]


T_PD = TypeVar("T_PD", pd.DataFrame, pd.Series)
T_NW = TypeVar("T_NW", nw.DataFrame, nw.Series)


def pd_indexing_method(_dfs: T_PD, mask) -> T_PD:
"""Indexing method for pandas dataframes and series.
def nw_indexing_method(_dfs: T_NW, mask) -> T_NW:
"""Indexing method for Narwhals dataframes and series.
Arguments:
df: The pandas dataframe or series to index.
df: The Narwhals dataframe or series to index.
mask: The boolean mask to use for indexing.
"""
return _dfs.loc[mask]
return _dfs.filter(mask)


BACKEND_TO_INDEXING_METHOD: Dict[str, Callable] = {
str(np.ndarray): default_indexing_method,
str(pd.DataFrame): pd_indexing_method,
str(pd.Series): pd_indexing_method,
str(nw.DataFrame): nw_indexing_method,
str(nw.Series): nw_indexing_method,
}

try:
import polars as pl

T_PL = TypeVar("T_PL", pl.DataFrame, pl.Series)

def pl_indexing_method(_dfs: T_PL, mask) -> T_PL:
"""Indexing method for polars dataframes and series.
Arguments:
_dfs: The polars dataframe or series to index.
mask: The boolean mask to use for indexing.
"""
return _dfs.filter(mask)

BACKEND_TO_INDEXING_METHOD.update(
{
str(pl.DataFrame): pl_indexing_method,
str(pl.Series): pl_indexing_method,
}
)

except ImportError:
pass
10 changes: 6 additions & 4 deletions timebasedcv/utils/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@

import sys
from datetime import date, datetime
from typing import Literal, Protocol, Tuple, TypeVar, Union

import pandas as pd
from typing import Literal, Protocol, Tuple, TypeVar, Union, TYPE_CHECKING

if sys.version_info >= (3, 10):
from typing import TypeAlias # pragma: no cover
Expand All @@ -16,7 +14,11 @@
else:
from typing_extensions import Self # pragma: no cover

DateTimeLike = TypeVar("DateTimeLike", datetime, date, pd.Timestamp)
if TYPE_CHECKING:
import pandas as pd


DateTimeLike = TypeVar("DateTimeLike", datetime, date, "pd.Timestamp")
FrequencyUnit: TypeAlias = Literal["days", "seconds", "microseconds", "milliseconds", "minutes", "hours", "weeks"]
WindowType: TypeAlias = Literal["rolling", "expanding"]

Expand Down