diff --git a/docs/examples/electricity.ipynb b/docs/examples/electricity.ipynb
index 0187063..3af4337 100644
--- a/docs/examples/electricity.ipynb
+++ b/docs/examples/electricity.ipynb
@@ -1509,7 +1509,7 @@
     "\n",
     "ss_trainer = StateSpaceTrainer(\n",
     "    module=kf_nn,\n",
-    "    kwargs_getter=lambda batch: {'X' : season_trainer.module(batch.tensors[1])},\n",
+    "    dataset_to_kwargs=lambda batch: {'X' : season_trainer.module(batch.tensors[1])},\n",
     ")\n",
     "\n",
     "## commented out since we're going with option 2 below\n",
@@ -1535,13 +1535,13 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "def _kwargs_getter(batch: TimeSeriesDataset) -> dict:\n",
+    "def dataset_to_kwargs(batch: TimeSeriesDataset) -> dict:\n",
     "    seasonX = season_trainer.times_to_model_mat(batch.times()).to(dtype=torch.float, device=DEVICE)\n",
     "    return {'X' : season_trainer.module.season_nn(seasonX)}\n",
     "\n",
     "ss_trainer = StateSpaceTrainer(\n",
     "    module=kf_nn,\n",
-    "    kwargs_getter=_kwargs_getter,\n",
+    "    dataset_to_kwargs=dataset_to_kwargs,\n",
     "    optimizer=torch.optim.Adam(kf_nn.parameters(), lr=.05)\n",
     ")"
   ]
diff --git a/torchcast/utils/training.py b/torchcast/utils/training.py
index e8c10ce..954c202 100644
--- a/torchcast/utils/training.py
+++ b/torchcast/utils/training.py
@@ -154,7 +154,7 @@ class StateSpaceTrainer(BaseTrainer):
 
     :param module: A ``StateSpaceModel`` instance (e.g. ``KalmanFilter`` or ``ExpSmoother``).
-    :param kwargs_getter: A callable that takes a :class:`torchcast.utils.TimeSeriesDataset` and returns a dictionary
+    :param dataset_to_kwargs: A callable that takes a :class:`torchcast.utils.TimeSeriesDataset` and returns a dictionary
     of keyword-arguments to pass to each call of the module's ``forward`` method. If left unspecified and the batch
     has a 2nd tensor, will pass ``tensor[1]`` as the ``X`` keyword.
    :param optimizer: An optimizer (or a class to instantiate an optimizer). Default is :class:`torch.optim.Adam`.
@@ -162,10 +162,10 @@ class StateSpaceTrainer(BaseTrainer):
 
     def __init__(self,
                  module: nn.Module,
-                 kwargs_getter: Optional[Sequence[str]] = None,
+                 dataset_to_kwargs: Optional[Sequence[str]] = None,
                  optimizer: Union[Optimizer, Type[Optimizer]] = torch.optim.Adam):
 
-        self.kwargs_getter = kwargs_getter
+        self.dataset_to_kwargs = dataset_to_kwargs
         super().__init__(module=module, optimizer=optimizer)
 
     def get_loss(self, pred: Predictions, y: torch.Tensor) -> torch.Tensor:
@@ -175,22 +175,22 @@ def _batch_to_args(self, batch: TimeSeriesDataset) -> Tuple[torch.Tensor, dict]:
         batch = batch.to(self._device)
         y = batch.tensors[0]
 
-        if callable(self.kwargs_getter):
-            kwargs = self.kwargs_getter(batch)
+        if callable(self.dataset_to_kwargs):
+            kwargs = self.dataset_to_kwargs(batch)
         else:
-            if self.kwargs_getter is None:
-                self.kwargs_getter = ['X'] if len(batch.tensors) > 1 else []
+            if self.dataset_to_kwargs is None:
+                self.dataset_to_kwargs = ['X'] if len(batch.tensors) > 1 else []
 
             kwargs = {}
-            for i, (k, t) in enumerate(zip_longest(self.kwargs_getter, batch.tensors[1:])):
+            for i, (k, t) in enumerate(zip_longest(self.dataset_to_kwargs, batch.tensors[1:])):
                 if k is None:
                     raise RuntimeError(
-                        f"Found element-{i + 1} of the dataset.tensors, but `kwargs_getter` doesn't have enough "
-                        f"elements: {self.kwargs_getter}"
+                        f"Found element-{i + 1} of the dataset.tensors, but `dataset_to_kwargs` doesn't have enough "
+                        f"elements: {self.dataset_to_kwargs}"
                     )
                 if t is None:
                     raise RuntimeError(
-                        f"Found element-{i} of `kwargs_getter`, but `dataset.tensors` doesn't have enough "
+                        f"Found element-{i} of `dataset_to_kwargs`, but `dataset.tensors` doesn't have enough "
                         f"elements: {batch}"
                     )
                 kwargs[k] = t
@@ -199,7 +199,7 @@ def _batch_to_args(self, batch: TimeSeriesDataset) -> Tuple[torch.Tensor, dict]:
     def _get_closure(self, batch: TimeSeriesDataset, forward_kwargs: dict) -> callable:
 
         def closure():
-            # we call _batch_to_args from inside the closure in case `kwargs_getter` is callable & involves grad.
+            # we call _batch_to_args from inside the closure in case `dataset_to_kwargs` is callable & involves grad.
             # only scenario this would matter is if optimizer is LBFGS (or another custom optimizer that calls closure
             # multiple times per step), in which case grad from callable would be lost after the first step.
             y, kwargs = self._batch_to_args(batch)
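
For anyone updating code against this rename, below is a minimal usage sketch of the new ``dataset_to_kwargs`` argument, based only on the docstring and notebook cells changed above. ``kf_nn`` (the ``KalmanFilter`` module) and ``season_trainer`` (the pre-trained seasonal-embedding trainer) are placeholders defined earlier in ``electricity.ipynb``, not here.

import torch
from torchcast.utils import TimeSeriesDataset
from torchcast.utils.training import StateSpaceTrainer

# Option 1 -- rely on the default: with dataset_to_kwargs left as None, a batch's
# 2nd tensor (batch.tensors[1]) is passed to the module's forward() as `X`.
ss_trainer = StateSpaceTrainer(module=kf_nn)

# Option 2 -- pass a callable that maps each TimeSeriesDataset batch to forward() kwargs
# (here re-using the notebook's pre-trained seasonal embedding network):
def dataset_to_kwargs(batch: TimeSeriesDataset) -> dict:
    return {'X': season_trainer.module(batch.tensors[1])}

ss_trainer = StateSpaceTrainer(
    module=kf_nn,
    dataset_to_kwargs=dataset_to_kwargs,
    optimizer=torch.optim.Adam(kf_nn.parameters(), lr=.05),
)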