Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli committed Apr 25, 2024
1 parent c23f6fa commit 7f55cb5
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 43 deletions.
7 changes: 4 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ authors = [{name = "Francesco Bruzzesi"}]

dependencies = [
"numpy",
"pandas",
"narwhals>=0.7.12",
"typing-extensions>=4.4.0; python_version < '3.11'",
]

Expand All @@ -37,6 +37,7 @@ issue-tracker = "https://github.com/fbruzzesi/timebasedcv/issues"

[project.optional-dependencies]
polars = ["polars"]
pandas = ["pandas"]

dev = [
"pre-commit==2.21.0",
Expand All @@ -62,8 +63,8 @@ test = [
"scikit-learn>=0.19",
]

all = ["timebasedcv[polars]"]
all-dev = ["timebasedcv[dev,docs,lint,polars,test]"]
all = ["timebasedcv[pandas,polars]"]
all-dev = ["timebasedcv[dev,docs,lint,pandas,polars,test]"]

[tool.hatch.build.targets.sdist]
only-include = ["timebasedcv"]
Expand Down
3 changes: 3 additions & 0 deletions tests/utils/backends_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import numpy as np
import pandas as pd
import pytest
import narwhals as nw

from timebasedcv.utils._backends import (
BACKEND_TO_INDEXING_METHOD,
Expand Down Expand Up @@ -49,6 +50,8 @@ def test_backend_to_indexing_method(arr, mask, expected):
"""
Tests the `BACKEND_TO_INDEXING_METHOD` dictionary with different backends.
"""
arr = nw.from_native(arr, allow_series=True, eager_only=True, strict=False)
mask = nw.from_native(mask, series_only=True, strict=False)
_type = str(type(arr))
result = BACKEND_TO_INDEXING_METHOD[_type](arr, mask)
assert np.array_equal(result, expected)
8 changes: 5 additions & 3 deletions timebasedcv/splitstate.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
from operator import le as less_or_equal
from typing import Generic

import pandas as pd

from timebasedcv.utils._funcs import pairwise, pairwise_comparison
from timebasedcv.utils._types import DateTimeLike

Expand All @@ -14,6 +12,8 @@
else:
from typing_extensions import Self # pragma: no cover

from narwhals.dependencies import get_pandas


@dataclass(frozen=True)
class SplitState(Generic[DateTimeLike]):
Expand Down Expand Up @@ -57,10 +57,12 @@ def __post_init__(self: Self) -> None:
_values = tuple(getattr(self, _attr) for _attr in _slots)
_types = tuple(type(_value) for _value in _values)

pd = get_pandas()

if not (
all(_type is datetime for _type in _types)
or all(_type is date for _type in _types)
or all(_type is pd.Timestamp for _type in _types)
or (pd is not None and all(_type is pd.Timestamp for _type in _types))
):
# cfr: https://stackoverflow.com/questions/16991948/detect-if-a-variable-is-a-datetime-object
raise TypeError("All attributes must be of type `datetime`, `date` or `pd.Timestamp`.")
Expand Down
5 changes: 4 additions & 1 deletion timebasedcv/timebasedsplit.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from datetime import timedelta
from itertools import chain
from typing import Generator, Tuple, Union, get_args
import narwhals as nw

import numpy as np

Expand Down Expand Up @@ -370,7 +371,9 @@ def split(
ts_shape = time_series.shape
if len(ts_shape) != 1:
raise ValueError(f"Time series must be 1-dimensional. Got {len(ts_shape)} dimensions.")


arrays = tuple([nw.from_native(array, eager_only=True, allow_series=True, strict=False) for array in arrays])
time_series = nw.from_native(time_series, series_only=True, strict=False)
a0 = arrays[0]
arr_len = a0.shape[0]

Expand Down
40 changes: 8 additions & 32 deletions timebasedcv/utils/_backends.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Callable, Dict, TypeVar

import numpy as np
import pandas as pd
import narwhals as nw


def default_indexing_method(arr, mask):
Expand All @@ -19,45 +19,21 @@ def default_indexing_method(arr, mask):
return arr[mask]


T_PD = TypeVar("T_PD", pd.DataFrame, pd.Series)
T_NW = TypeVar("T_NW", nw.DataFrame, nw.Series)


def pd_indexing_method(_dfs: T_PD, mask) -> T_PD:
"""Indexing method for pandas dataframes and series.
def nw_indexing_method(_dfs: T_NW, mask) -> T_NW:
"""Indexing method for Narwhals dataframes and series.
Arguments:
df: The pandas dataframe or series to index.
df: The Narwhals dataframe or series to index.
mask: The boolean mask to use for indexing.
"""
return _dfs.loc[mask]
return _dfs.filter(mask)


BACKEND_TO_INDEXING_METHOD: Dict[str, Callable] = {
str(np.ndarray): default_indexing_method,
str(pd.DataFrame): pd_indexing_method,
str(pd.Series): pd_indexing_method,
str(nw.DataFrame): nw_indexing_method,
str(nw.Series): nw_indexing_method,
}

try:
import polars as pl

T_PL = TypeVar("T_PL", pl.DataFrame, pl.Series)

def pl_indexing_method(_dfs: T_PL, mask) -> T_PL:
"""Indexing method for polars dataframes and series.
Arguments:
_dfs: The polars dataframe or series to index.
mask: The boolean mask to use for indexing.
"""
return _dfs.filter(mask)

BACKEND_TO_INDEXING_METHOD.update(
{
str(pl.DataFrame): pl_indexing_method,
str(pl.Series): pl_indexing_method,
}
)

except ImportError:
pass
10 changes: 6 additions & 4 deletions timebasedcv/utils/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@

import sys
from datetime import date, datetime
from typing import Literal, Protocol, Tuple, TypeVar, Union

import pandas as pd
from typing import Literal, Protocol, Tuple, TypeVar, Union, TYPE_CHECKING

if sys.version_info >= (3, 10):
from typing import TypeAlias # pragma: no cover
Expand All @@ -16,7 +14,11 @@
else:
from typing_extensions import Self # pragma: no cover

DateTimeLike = TypeVar("DateTimeLike", datetime, date, pd.Timestamp)
if TYPE_CHECKING:
import pandas as pd


DateTimeLike = TypeVar("DateTimeLike", datetime, date, "pd.Timestamp")
FrequencyUnit: TypeAlias = Literal["days", "seconds", "microseconds", "milliseconds", "minutes", "hours", "weeks"]
WindowType: TypeAlias = Literal["rolling", "expanding"]

Expand Down

0 comments on commit 7f55cb5

Please sign in to comment.