diff --git a/torchcast/utils/training.py b/torchcast/utils/training.py index 117b471..e8c10ce 100644 --- a/torchcast/utils/training.py +++ b/torchcast/utils/training.py @@ -1,6 +1,6 @@ """ These are simple training classes for PyTorch models, with specialized subclasses for torchcast's model-classes (i.e., -when the data are too big for the :func:`torchcast.state_space.StateSpaceModel.fit()` method). Additionally, there is a +when the data are too big for the ``StateSpaceModel.fit()`` method). Additionally, there is a special class for training neural networks to embed complex seasonal patterns into lower dimensional embeddings. While the classes in this module are helpful for quick development, they are not necessarily meant to replace more @@ -134,9 +134,9 @@ def _get_batch_numel(self, batch: Dataset) -> int: class StateSpaceTrainer(BaseTrainer): """ - A trainer for a :class:`torchcast.state_space.StateSpaceModel` instance. This is for usage in contexts where the - data are too large for :func:`torchcast.state_space.StateSpaceModel.fit()` to be practical. Rather than the base - DataLoader, this class takes a :class:`torchcast.utils.TimeSeriesDataLoader`. + A trainer for a :``StateSpaceModel``. This is for usage in contexts where the data are too large for + ``StateSpaceModel.fit()`` to be practical. Rather than the base DataLoader, this class takes a + :class:`torchcast.utils.TimeSeriesDataLoader`. Usage: @@ -153,7 +153,7 @@ class StateSpaceTrainer(BaseTrainer): # log the loss, early-stopping, etc. - :param module: A :class:`torchcast.state_space.StateSpaceModel` instance (e.g. ``KalmanFilter`` or ``ExpSmoother``). + :param module: A ``StateSpaceModel`` instance (e.g. ``KalmanFilter`` or ``ExpSmoother``). :param kwargs_getter: A callable that takes a :class:`torchcast.utils.TimeSeriesDataset` and returns a dictionary of keyword-arguments to pass to each call of the module's ``forward`` method. If left unspecified and the batch has a 2nd tensor, will pass ``tensor[1]`` as the ``X`` keyword.