Skip to content

Commit

Permalink
docstrings and fix test warning
Browse files Browse the repository at this point in the history
  • Loading branch information
FBruzzesi committed May 12, 2024
1 parent f1733b9 commit c3a7ef0
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 16 deletions.
8 changes: 4 additions & 4 deletions tests/timebasedsplit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,15 @@
valid_kwargs = {
"frequency": "days",
"train_size": 7,
"forecast_horizon": 3,
"gap": 0,
"stride": 2,
"forecast_horizon": 4,
"gap": 1,
"stride": 3,
"window": "rolling",
}


start_dt = pd.Timestamp(2023, 1, 1)
end_dt = pd.Timestamp(2023, 3, 31)
end_dt = pd.Timestamp(2023, 1, 31)

time_series = pd.Series(pd.date_range(start_dt, end_dt, freq="D"))
size = len(time_series)
Expand Down
34 changes: 22 additions & 12 deletions timebasedcv/timebasedsplit.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,15 @@ class _CoreTimeBasedSplit:
Arguments:
frequency: The frequency of the time series. Must be one of "days", "seconds", "microseconds", "milliseconds",
"minutes", "hours", "weeks". These are the only valid values for the `unit` argument of the `timedelta`.
"minutes", "hours", "weeks". These are the only keyword arguments accepted by `timedelta` from Python's
`datetime` standard library module.
train_size: The size of the training set.
forecast_horizon: The size of the forecast horizon.
forecast_horizon: The size of the forecast horizon, i.e. the size of the test set.
gap: The size of the gap between the training set and the forecast horizon.
stride: The size of the stride between consecutive splits. Notice that if stride is not provided (or set to 0),
it is set to `forecast_horizon`.
window: The type of window to use. Must be one of "rolling" or "expanding".
it falls back to the `forecast_horizon` quantity.
window: The type of window to use, either "rolling" or "expanding".
mode: Determines in which order the splits are generated, either "forward" or "backward".
Raises:
ValueError: If `frequency` is not one of "days", "seconds", "microseconds", "milliseconds", "minutes", "hours",
Expand Down Expand Up @@ -266,13 +268,15 @@ class TimeBasedSplit(_CoreTimeBasedSplit):
Arguments:
frequency: The frequency of the time series. Must be one of "days", "seconds", "microseconds", "milliseconds",
"minutes", "hours", "weeks". These are the only valid values for the `unit` argument of the `timedelta`.
"minutes", "hours", "weeks". These are the only keyword arguments accepted by `timedelta` from Python's
`datetime` standard library module.
train_size: The size of the training set.
forecast_horizon: The size of the forecast horizon.
forecast_horizon: The size of the forecast horizon, i.e. the size of the test set.
gap: The size of the gap between the training set and the forecast horizon.
stride: The size of the stride between consecutive splits. Notice that if stride is not provided (or set to 0),
it is set to `forecast_horizon`.
window: The type of window to use. Must be one of "rolling" or "expanding".
it falls back to the `forecast_horizon` quantity.
window: The type of window to use, either "rolling" or "expanding".
mode: Determines in which order the splits are generated, either "forward" or "backward".
Raises:
ValueError: If `frequency` is not one of "days", "seconds", "microseconds", "milliseconds", "minutes", "hours",
Expand Down Expand Up @@ -499,6 +503,7 @@ def __init__( # noqa: PLR0913
forecast_horizon: int,
gap: int = 0,
stride: Union[int, None] = None,
mode: ModeType,
) -> None:
super().__init__(
frequency=frequency,
Expand All @@ -507,6 +512,7 @@ def __init__( # noqa: PLR0913
gap=gap,
stride=stride,
window="expanding",
mode=mode,
)


Expand All @@ -523,6 +529,7 @@ def __init__( # noqa: PLR0913
forecast_horizon: int,
gap: int = 0,
stride: Union[int, None] = None,
mode: ModeType,
) -> None:
super().__init__(
frequency=frequency,
Expand All @@ -531,6 +538,7 @@ def __init__( # noqa: PLR0913
gap=gap,
stride=stride,
window="rolling",
mode=mode,
)


Expand All @@ -549,9 +557,10 @@ class TimeBasedCVSplitter(BaseCrossValidator):
Arguments:
frequency: The frequency of the time series. Must be one of "days", "seconds", "microseconds", "milliseconds",
"minutes", "hours", "weeks". These are the only valid values for the `unit` argument of the `timedelta`.
"minutes", "hours", "weeks". These are the only keyword arguments accepted by `timedelta` from Python's
`datetime` standard library module.
train_size: The size of the training set.
forecast_horizon: The size of the forecast horizon.
forecast_horizon: The size of the forecast horizon, i.e. the size of the test set.
time_series: The time series used to create boolean mask for splits. It is not required to be sorted, but it
must support:
Expand All @@ -561,8 +570,9 @@ class TimeBasedCVSplitter(BaseCrossValidator):
- `.shape` attribute.
gap: The size of the gap between the training set and the forecast horizon.
stride: The size of the stride between consecutive splits. Notice that if stride is not provided (or set to 0),
it is set to `forecast_horizon`.
window: The type of window to use. Must be one of "rolling" or "expanding".
it falls back to the `forecast_horizon` quantity.
window: The type of window to use, either "rolling" or "expanding".
mode: Determines in which order the splits are generated, either "forward" or "backward".
start_dt: The start of the time period. If provided, it is used in place of the `time_series.min()`.
end_dt: The end of the time period. If provided, it is used in place of the `time_series.max()`.
Expand Down

0 comments on commit c3a7ef0

Please sign in to comment.