Commit

kwargs_getter -> dataset_to_kwargs
jwdink committed Jan 7, 2025
1 parent 08ba3b9 commit 52ae254
Showing 2 changed files with 15 additions and 15 deletions.
6 changes: 3 additions & 3 deletions docs/examples/electricity.ipynb
@@ -1509,7 +1509,7 @@
"\n",
"ss_trainer = StateSpaceTrainer(\n",
" module=kf_nn,\n",
" kwargs_getter=lambda batch: {'X' : season_trainer.module(batch.tensors[1])},\n",
" dataset_to_kwargs=lambda batch: {'X' : season_trainer.module(batch.tensors[1])},\n",
")\n",
"\n",
"## commented out since we're going with option 2 below\n",
@@ -1535,13 +1535,13 @@
"metadata": {},
"outputs": [],
"source": [
"def _kwargs_getter(batch: TimeSeriesDataset) -> dict:\n",
"def dataset_to_kwargs(batch: TimeSeriesDataset) -> dict:\n",
" seasonX = season_trainer.times_to_model_mat(batch.times()).to(dtype=torch.float, device=DEVICE)\n",
" return {'X' : season_trainer.module.season_nn(seasonX)}\n",
"\n",
"ss_trainer = StateSpaceTrainer(\n",
" module=kf_nn,\n",
" kwargs_getter=_kwargs_getter,\n",
" dataset_to_kwargs=dataset_to_kwargs,\n",
" optimizer=torch.optim.Adam(kf_nn.parameters(), lr=.05)\n",
")"
]
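
Both notebook options above satisfy the same contract for the renamed argument: dataset_to_kwargs receives each TimeSeriesDataset batch and returns the keyword arguments for the module's forward call. Below is a minimal, self-contained sketch of that contract outside the notebook; the measure name, processes, and predictor names are illustrative placeholders rather than values from the notebook, and the imports assume torchcast's usual KalmanFilter/process locations (StateSpaceTrainer is imported from the module path shown in this commit's diff).

import torch

from torchcast.kalman_filter import KalmanFilter
from torchcast.process import LinearModel, LocalTrend
from torchcast.utils import TimeSeriesDataset
from torchcast.utils.training import StateSpaceTrainer  # module path per this commit

# an illustrative model standing in for the notebook's `kf_nn`
kf = KalmanFilter(
    measures=['kW'],
    processes=[LocalTrend(id='trend'), LinearModel(id='lm', predictors=['x1', 'x2'])],
)

def dataset_to_kwargs(batch: TimeSeriesDataset) -> dict:
    # the contract: take the batch, return forward() kwargs; here the batch's
    # 2nd tensor is assumed to hold the predictors, as in the notebook's option 1
    return {'X': batch.tensors[1]}

ss_trainer = StateSpaceTrainer(
    module=kf,
    dataset_to_kwargs=dataset_to_kwargs,
    optimizer=torch.optim.Adam(kf.parameters(), lr=.05),
)
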
24 changes: 12 additions & 12 deletions torchcast/utils/training.py
@@ -154,18 +154,18 @@ class StateSpaceTrainer(BaseTrainer):
:param module: A ``StateSpaceModel`` instance (e.g. ``KalmanFilter`` or ``ExpSmoother``).
-:param kwargs_getter: A callable that takes a :class:`torchcast.utils.TimeSeriesDataset` and returns a dictionary
+:param dataset_to_kwargs: A callable that takes a :class:`torchcast.utils.TimeSeriesDataset` and returns a dictionary
of keyword-arguments to pass to each call of the module's ``forward`` method. If left unspecified and the batch
has a 2nd tensor, will pass ``tensor[1]`` as the ``X`` keyword.
:param optimizer: An optimizer (or a class to instantiate an optimizer). Default is :class:`torch.optim.Adam`.
"""

def __init__(self,
module: nn.Module,
-kwargs_getter: Optional[Sequence[str]] = None,
+dataset_to_kwargs: Optional[Sequence[str]] = None,
optimizer: Union[Optimizer, Type[Optimizer]] = torch.optim.Adam):

-self.kwargs_getter = kwargs_getter
+self.dataset_to_kwargs = dataset_to_kwargs
super().__init__(module=module, optimizer=optimizer)

def get_loss(self, pred: Predictions, y: torch.Tensor) -> torch.Tensor:
@@ -175,22 +175,22 @@ def _batch_to_args(self, batch: TimeSeriesDataset) -> Tuple[torch.Tensor, dict]:
batch = batch.to(self._device)
y = batch.tensors[0]

-if callable(self.kwargs_getter):
-kwargs = self.kwargs_getter(batch)
+if callable(self.dataset_to_kwargs):
+kwargs = self.dataset_to_kwargs(batch)
else:
-if self.kwargs_getter is None:
-self.kwargs_getter = ['X'] if len(batch.tensors) > 1 else []
+if self.dataset_to_kwargs is None:
+self.dataset_to_kwargs = ['X'] if len(batch.tensors) > 1 else []

kwargs = {}
-for i, (k, t) in enumerate(zip_longest(self.kwargs_getter, batch.tensors[1:])):
+for i, (k, t) in enumerate(zip_longest(self.dataset_to_kwargs, batch.tensors[1:])):
if k is None:
raise RuntimeError(
f"Found element-{i + 1} of the dataset.tensors, but `kwargs_getter` doesn't have enough "
f"elements: {self.kwargs_getter}"
f"Found element-{i + 1} of the dataset.tensors, but `dataset_to_kwargs` doesn't have enough "
f"elements: {self.dataset_to_kwargs}"
)
if t is None:
raise RuntimeError(
f"Found element-{i} of `kwargs_getter`, but `dataset.tensors` doesn't have enough "
f"Found element-{i} of `dataset_to_kwargs`, but `dataset.tensors` doesn't have enough "
f"elements: {batch}"
)
kwargs[k] = t
@@ -199,7 +199,7 @@ def _batch_to_args(self, batch: TimeSeriesDataset) -> Tuple[torch.Tensor, dict]:
def _get_closure(self, batch: TimeSeriesDataset, forward_kwargs: dict) -> callable:

def closure():
-# we call _batch_to_args from inside the closure in case `kwargs_getter` is callable & involves grad.
+# we call _batch_to_args from inside the closure in case `dataset_to_kwargs` is callable & involves grad.
# only scenario this would matter is if optimizer is LBFGS (or another custom optimizer that calls closure
# multiple times per step), in which case grad from callable would be lost after the first step.
y, kwargs = self._batch_to_args(batch)
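
For reference, the non-callable branch of _batch_to_args above pairs a sequence of keyword names with batch.tensors[1:], defaulting to ['X'] when the batch has a second tensor. The following standalone sketch mirrors (rather than imports) that pairing logic; the helper name tensors_to_kwargs is hypothetical and not part of torchcast.

from itertools import zip_longest
from typing import Optional, Sequence

import torch

def tensors_to_kwargs(tensors: Sequence[torch.Tensor],
                      dataset_to_kwargs: Optional[Sequence[str]] = None) -> dict:
    # mirrors the non-callable branch: default to ['X'] if there's a 2nd tensor,
    # then pair each name with the corresponding tensor after the first (the target)
    if dataset_to_kwargs is None:
        dataset_to_kwargs = ['X'] if len(tensors) > 1 else []
    kwargs = {}
    for i, (k, t) in enumerate(zip_longest(dataset_to_kwargs, tensors[1:])):
        if k is None:
            raise RuntimeError(f"tensor {i + 1} has no matching name in {dataset_to_kwargs}")
        if t is None:
            raise RuntimeError(f"name {k!r} has no matching tensor in the batch")
        kwargs[k] = t
    return kwargs

y = torch.randn(2, 10, 1)   # target, i.e. batch.tensors[0]
X = torch.randn(2, 10, 3)   # predictors, i.e. batch.tensors[1]
print(tensors_to_kwargs([y, X]))   # prints a dict with the single key 'X'
print(tensors_to_kwargs([y]))      # prints {} -- no extra tensors, so no forward() kwargs

In the trainer itself this pairing happens once per batch, with batch.tensors[0] used as the target y.
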
