Skip to content

Commit

Permalink
replace links for now to avoid circular
Browse files Browse the repository at this point in the history
  • Loading branch information
jwdink committed Jan 7, 2025
1 parent b312f4c commit f6e59cc
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions torchcast/utils/training.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""
These are simple training classes for PyTorch models, with specialized subclasses for torchcast's model-classes (i.e.,
when the data are too big for the :func:`torchcast.state_space.StateSpaceModel.fit()` method). Additionally, there is a
when the data are too big for the ``StateSpaceModel.fit()`` method). Additionally, there is a
special class for training neural networks to embed complex seasonal patterns into lower dimensional embeddings.
While the classes in this module are helpful for quick development, they are not necessarily meant to replace more
Expand Down Expand Up @@ -134,9 +134,9 @@ def _get_batch_numel(self, batch: Dataset) -> int:

class StateSpaceTrainer(BaseTrainer):
"""
A trainer for a :class:`torchcast.state_space.StateSpaceModel` instance. This is for usage in contexts where the
data are too large for :func:`torchcast.state_space.StateSpaceModel.fit()` to be practical. Rather than the base
DataLoader, this class takes a :class:`torchcast.utils.TimeSeriesDataLoader`.
A trainer for a :``StateSpaceModel``. This is for usage in contexts where the data are too large for
``StateSpaceModel.fit()`` to be practical. Rather than the base DataLoader, this class takes a
:class:`torchcast.utils.TimeSeriesDataLoader`.
Usage:
Expand All @@ -153,7 +153,7 @@ class StateSpaceTrainer(BaseTrainer):
# log the loss, early-stopping, etc.
:param module: A :class:`torchcast.state_space.StateSpaceModel` instance (e.g. ``KalmanFilter`` or ``ExpSmoother``).
:param module: A ``StateSpaceModel`` instance (e.g. ``KalmanFilter`` or ``ExpSmoother``).
:param kwargs_getter: A callable that takes a :class:`torchcast.utils.TimeSeriesDataset` and returns a dictionary
of keyword-arguments to pass to each call of the module's ``forward`` method. If left unspecified and the batch
has a 2nd tensor, will pass ``tensor[1]`` as the ``X`` keyword.
Expand Down

0 comments on commit f6e59cc

Please sign in to comment.