Doc tweaks (#28)
* fix circular import; simplify quickstart

* remove todo and typo

* Update air_quality.ipynb

* kwargs_getter → dataset_to_kwargs

* Update electricity.ipynb
jwdink authored Jan 7, 2025
1 parent 656ff10 commit d6d04c8
Showing 6 changed files with 64 additions and 136 deletions.
2 changes: 1 addition & 1 deletion README.rst
@@ -5,7 +5,7 @@ torchcast

1. An API designed around training and forecasting with *batches* of time-series, rather than training separate models for one time-series at a time.
2. Robust support for *multivariate* time-series, where multiple correlated measures are being forecasted.
3. Forecasting models that are hybrids: they are classic state-space models with the twist that every part is differentiable and can take advantage of PyTorch's flexibility. For `example <https://docs.strong.io/torchcast/examples/electricity.html#Training-our-Hybrid-Forecasting-Model>`_, we can use arbitrary PyTorch :class:`torch.nn.Modules` to learn seasonal variations across multiple groups, embedding complex seasonality into lower-dimensional space.
3. Forecasting models that are hybrids: they are classic state-space models with the twist that every part is differentiable and can take advantage of PyTorch's flexibility. For `example <https://docs.strong.io/torchcast/examples/electricity.html#Training-our-Hybrid-Forecasting-Model>`_, we can use arbitrary PyTorch ``torch.nn.Modules`` to learn seasonal variations across multiple groups, embedding complex seasonality into lower-dimensional space.

This repository is the work of `Strong <https://www.strong.io/>`_.
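As an illustration of point 3, here is a minimal sketch, in plain PyTorch rather than torchcast's actual API, of the kind of module that could embed complex seasonality across many groups into a low-dimensional space (all names and shapes below are hypothetical):

```python
import torch
from torch import nn


class SeasonEmbedding(nn.Module):
    """Hypothetical example: combine a learned per-group embedding with calendar
    features and map them to a low-dimensional seasonal representation that a
    state-space model could consume as predictors."""

    def __init__(self, num_groups: int, num_calendar_features: int, embed_dim: int = 5):
        super().__init__()
        self.group_embed = nn.Embedding(num_groups, embed_dim)
        self.season_nn = nn.Sequential(
            nn.Linear(num_calendar_features + embed_dim, 32),
            nn.Tanh(),
            nn.Linear(32, embed_dim),
        )

    def forward(self, group_ids: torch.Tensor, calendar_X: torch.Tensor) -> torch.Tensor:
        # group_ids: (batch,); calendar_X: (batch, timesteps, num_calendar_features)
        g = self.group_embed(group_ids).unsqueeze(1).expand(-1, calendar_X.shape[1], -1)
        return self.season_nn(torch.cat([calendar_X, g], dim=-1))
```

In the electricity example linked above, the output of a module like this is passed to the state-space model as its `X` predictor matrix.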

140 changes: 35 additions & 105 deletions docs/examples/air_quality.ipynb

Large diffs are not rendered by default.

17 changes: 8 additions & 9 deletions docs/examples/electricity.ipynb
@@ -444,7 +444,7 @@
"id": "4a82175c-56c8-454a-925f-9cabfbedd079"
},
"source": [
"The first thing you'll notice about this approach is that it's **incredibly slow to train**. The second problem is that the forecasts are **terrible**:"
"The problem is that the forecasts are **terrible**:"
]
},
{
@@ -495,10 +495,9 @@
"source": [
"### Attempt 2\n",
"\n",
"Let's see if we can improve on this. We'll leave the model unchanged but make two changes:\n",
"Let's see if we can improve on this. We'll leave the model unchanged, but train it differently. We'll use the `n_step` argument to train our model on one-week ahead forecasts, instead of one-hour ahead. This improves the efficiency of training by encouraging the model to 'care about' longer range forecasts vs. over-focusing on the easier problem of forecasting the next hour.\n",
"\n",
"- Use the `n_step` argument to train our model on one-week ahead forecasts, instead of one step (i.e. hour) ahead. This improves the efficiency of training by encouraging the model to 'care about' longer range forecasts vs. over-focusing on the easier problem of forecasting the next hour.\n",
"- Split our single series into multiple groups. This is helpful to speed up training, since pytorch has a non-trivial overhead for separate tensors: i.e., it scales well with an increasing batch-size (fewer, but bigger, tensors), but poorly with an increasing time-series length (smaller, but more, tensors)."
"Another thing we'll address is the fact that the simple (and lousy) model above was surprisingly slow to train. This is because Pytorch has a non-trivial overhead for separate tensors: i.e., it scales well with an increasing batch-size (fewer, but bigger, tensors), but poorly with an increasing time-series length (smaller, but more, tensors). So to speed things up, we'll split our single series into multiple groups."
]
},
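The splitting idea described above can be sketched, hypothetically and outside the notebook's actual code, as breaking one long hourly series into per-year sub-series so that batches contain more, shorter groups:

```python
import pandas as pd

# Hypothetical sketch (column and group names are illustrative, not the notebook's):
# split one long hourly series into per-year sub-series.
df = pd.DataFrame({
    "group": "building_1",
    "time": pd.date_range("2012-01-01", periods=3 * 8760, freq="h"),
    "kW": 1.0,  # placeholder measure
})
df["subgroup"] = df["group"] + ":" + df["time"].dt.year.astype(str)
print(df["subgroup"].nunique())  # 3 shorter series instead of one long one
```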
{
@@ -632,7 +631,7 @@
")\n",
"\n",
"ds_example_train2, _ = ds_example_building2.train_val_split(dt=SPLIT_DT, quiet=True)\n",
"# TODO: explain\n",
"# our subgroup approach leaves a few of the resulting time-series as very small periods (<60 days) we'd like to drop\n",
"ds_example_train2 = ds_example_train2[ds_example_train2.get_durations() > 1400]\n",
"ds_example_train2"
]
@@ -1453,7 +1452,7 @@
"id": "caf84e63",
"metadata": {},
"source": [
"2. Next, we create add our season features to the dataframe, and create a dataloader, passing these feature-names to the `X_colnames` argument:"
"2. Next, we add our season features to the dataframe, and create a dataloader, passing these feature-names to the `X_colnames` argument:"
]
},
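A hedged sketch of what "adding season features" might look like (the corresponding notebook cell is not shown in this diff, and the column names below are hypothetical): annual Fourier terms are added to the dataframe, and their names would then be supplied as `X_colnames` when building the dataloader.

```python
import numpy as np
import pandas as pd

# Hypothetical sketch: add annual sine/cosine terms to an hourly dataframe.
df = pd.DataFrame({"time": pd.date_range("2012-01-01", periods=24 * 365, freq="h")})
t = np.arange(len(df))
hours_per_year = 24 * 365.25
for k in (1, 2):
    df[f"sin_{k}"] = np.sin(2 * np.pi * k * t / hours_per_year)
    df[f"cos_{k}"] = np.cos(2 * np.pi * k * t / hours_per_year)

season_cols = [c for c in df.columns if c.startswith(("sin_", "cos_"))]
# `season_cols` would be passed as the `X_colnames` argument when creating the dataloader.
```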
{
@@ -1509,7 +1508,7 @@
"\n",
"ss_trainer = StateSpaceTrainer(\n",
" module=kf_nn,\n",
" kwargs_getter=lambda batch: {'X' : season_trainer.module(batch.tensors[1])},\n",
" dataset_to_kwargs=lambda batch: {'X' : season_trainer.module(batch.tensors[1])},\n",
")\n",
"\n",
"## commented out since we're going with option 2 below\n",
@@ -1535,13 +1534,13 @@
"metadata": {},
"outputs": [],
"source": [
"def _kwargs_getter(batch: TimeSeriesDataset) -> dict:\n",
"def dataset_to_kwargs(batch: TimeSeriesDataset) -> dict:\n",
" seasonX = season_trainer.times_to_model_mat(batch.times()).to(dtype=torch.float, device=DEVICE)\n",
" return {'X' : season_trainer.module.season_nn(seasonX)}\n",
"\n",
"ss_trainer = StateSpaceTrainer(\n",
" module=kf_nn,\n",
" kwargs_getter=_kwargs_getter,\n",
" dataset_to_kwargs=dataset_to_kwargs,\n",
" optimizer=torch.optim.Adam(kf_nn.parameters(), lr=.05)\n",
")"
]
9 changes: 3 additions & 6 deletions docs/quick_start.py
@@ -110,12 +110,7 @@
# `Predictions` can easily be converted to Pandas `DataFrames` for ease of inspecting predictions, comparing them to actuals, and visualizing:

# %%
df_pred = pred.to_dataframe(dataset_all, conf=None)
# bias-correction for log-transform (see https://otexts.com/fpp2/transformations.html#bias-adjustments)
df_pred['mean'] += .5 * df_pred['std'] ** 2
df_pred['lower'] = df_pred['mean'] - 1.96 * df_pred['std']
df_pred['upper'] = df_pred['mean'] + 1.96 * df_pred['std']
# inverse the log10:
df_pred = pred.to_dataframe(dataset_all)
df_pred[['actual','mean','upper','lower']] = 10 ** df_pred[['actual','mean','upper','lower']]
df_pred

@@ -137,3 +132,5 @@
pred.to_dataframe(dataset_all, type='components').query("group=='Changping'"), split_dt=SPLIT_DT,
time_colname='time', group_colname='group'
)

# %%
8 changes: 5 additions & 3 deletions torchcast/state_space/predictions.py
Expand Up @@ -12,10 +12,10 @@
from scipy import stats

from torchcast.internals.utils import get_nan_groups, is_near_zero, transpose_last_dims, class_or_instancemethod
from torchcast.utils import TimeSeriesDataset

if TYPE_CHECKING:
from torchcast.state_space import StateSpaceModel
from torchcast.utils import TimeSeriesDataset


class Predictions(nn.Module):
@@ -66,7 +66,7 @@ def __init__(self,
self._dataset_metadata = None

def set_metadata(self,
dataset: Optional[TimeSeriesDataset] = None,
dataset: Optional['TimeSeriesDataset'] = None,
group_names: Optional[Sequence[str]] = None,
start_offsets: Optional[np.ndarray] = None,
group_colname: str = 'group',
@@ -346,7 +346,7 @@ def _log_prob(self, obs: Tensor, means: Tensor, covs: Tensor) -> Tensor:
return torch.distributions.MultivariateNormal(means, covs, validate_args=False).log_prob(obs)

def to_dataframe(self,
dataset: Optional[TimeSeriesDataset] = None,
dataset: Optional['TimeSeriesDataset'] = None,
type: str = 'predictions',
group_colname: Optional[str] = None,
time_colname: Optional[str] = None,
@@ -364,6 +364,8 @@ def to_dataframe(self,
:return: A pandas DataFrame with group, 'time', 'measure', 'mean', 'lower', 'upper'. For ``type='components'``
additionally includes: 'process' and 'state_element'.
"""
from torchcast.utils import TimeSeriesDataset

multi = kwargs.pop('multi', False)
if multi is not False:
msg = "`multi` is deprecated, please use `conf` instead."
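The change above is the circular-import fix mentioned in the commit message: the module-level import of `TimeSeriesDataset` moves under `TYPE_CHECKING`, the annotations become strings, and the runtime import is deferred into `to_dataframe`. Distilled into a minimal sketch of the same pattern (not the full file):

```python
from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    # only evaluated by static type checkers, so it can no longer create an import cycle
    from torchcast.utils import TimeSeriesDataset


def to_dataframe(dataset: Optional['TimeSeriesDataset'] = None, **kwargs):
    # deferred import: by the time this is called, both modules are fully initialized
    from torchcast.utils import TimeSeriesDataset
    ...
```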
24 changes: 12 additions & 12 deletions torchcast/utils/training.py
@@ -154,18 +154,18 @@ class StateSpaceTrainer(BaseTrainer):
:param module: A ``StateSpaceModel`` instance (e.g. ``KalmanFilter`` or ``ExpSmoother``).
:param kwargs_getter: A callable that takes a :class:`torchcast.utils.TimeSeriesDataset` and returns a dictionary
:param dataset_to_kwargs: A callable that takes a :class:`torchcast.utils.TimeSeriesDataset` and returns a dictionary
of keyword-arguments to pass to each call of the module's ``forward`` method. If left unspecified and the batch
has a 2nd tensor, will pass ``tensor[1]`` as the ``X`` keyword.
:param optimizer: An optimizer (or a class to instantiate an optimizer). Default is :class:`torch.optim.Adam`.
"""

def __init__(self,
module: nn.Module,
kwargs_getter: Optional[Sequence[str]] = None,
dataset_to_kwargs: Optional[Sequence[str]] = None,
optimizer: Union[Optimizer, Type[Optimizer]] = torch.optim.Adam):

self.kwargs_getter = kwargs_getter
self.dataset_to_kwargs = dataset_to_kwargs
super().__init__(module=module, optimizer=optimizer)

def get_loss(self, pred: Predictions, y: torch.Tensor) -> torch.Tensor:
@@ -175,22 +175,22 @@ def _batch_to_args(self, batch: TimeSeriesDataset) -> Tuple[torch.Tensor, dict]:
batch = batch.to(self._device)
y = batch.tensors[0]

if callable(self.kwargs_getter):
kwargs = self.kwargs_getter(batch)
if callable(self.dataset_to_kwargs):
kwargs = self.dataset_to_kwargs(batch)
else:
if self.kwargs_getter is None:
self.kwargs_getter = ['X'] if len(batch.tensors) > 1 else []
if self.dataset_to_kwargs is None:
self.dataset_to_kwargs = ['X'] if len(batch.tensors) > 1 else []

kwargs = {}
for i, (k, t) in enumerate(zip_longest(self.kwargs_getter, batch.tensors[1:])):
for i, (k, t) in enumerate(zip_longest(self.dataset_to_kwargs, batch.tensors[1:])):
if k is None:
raise RuntimeError(
f"Found element-{i + 1} of the dataset.tensors, but `kwargs_getter` doesn't have enough "
f"elements: {self.kwargs_getter}"
f"Found element-{i + 1} of the dataset.tensors, but `dataset_to_kwargs` doesn't have enough "
f"elements: {self.dataset_to_kwargs}"
)
if t is None:
raise RuntimeError(
f"Found element-{i} of `kwargs_getter`, but `dataset.tensors` doesn't have enough "
f"Found element-{i} of `dataset_to_kwargs`, but `dataset.tensors` doesn't have enough "
f"elements: {batch}"
)
kwargs[k] = t
@@ -199,7 +199,7 @@ def _get_closure(self, batch: TimeSeriesDataset, forward_kwargs: dict) -> callable:
def _get_closure(self, batch: TimeSeriesDataset, forward_kwargs: dict) -> callable:

def closure():
# we call _batch_to_args from inside the closure in case `kwargs_getter` is callable & involves grad.
# we call _batch_to_args from inside the closure in case `dataset_to_kwargs` is callable & involves grad.
# only scenario this would matter is if optimizer is LBFGS (or another custom optimizer that calls closure
# multiple times per step), in which case grad from callable would be lost after the first step.
y, kwargs = self._batch_to_args(batch)
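To summarize the rename: `dataset_to_kwargs` (formerly `kwargs_getter`) can be a callable mapping each `TimeSeriesDataset` batch to the module's forward keyword arguments, a list of keyword names matched against `dataset.tensors[1:]`, or left unset (in which case a 2nd tensor is passed as `X`). A hedged usage sketch mirroring the electricity notebook above; `kf_nn` is assumed to be a state-space module defined earlier, and the `StateSpaceTrainer` import path is inferred from this file's location:

```python
import torch
from torchcast.utils import TimeSeriesDataset
from torchcast.utils.training import StateSpaceTrainer

def dataset_to_kwargs(batch: TimeSeriesDataset) -> dict:
    # pass the dataset's 2nd tensor as the module's `X` keyword argument
    # (this mirrors the default behavior when `dataset_to_kwargs` is left unspecified)
    return {'X': batch.tensors[1]}

ss_trainer = StateSpaceTrainer(
    module=kf_nn,  # assumed: a KalmanFilter/ExpSmoother instance built earlier
    dataset_to_kwargs=dataset_to_kwargs,
    optimizer=torch.optim.Adam(kf_nn.parameters(), lr=.05),
)
```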
