
TimeSeriesDataset .split_measures #26

Merged: 2 commits merged into main from hotfix/split_measures on Dec 27, 2024

Conversation

@jwdink (Collaborator) commented Dec 10, 2024

Haven't tested at all, but I'm thinking something like this...

Then we'd do:

dataset = TimeSeriesDataset.from_dataframe(
    dataframe=usage,
    dt_unit="D",
    group_colname=UsageColumns.DEVICE_ID,
    time_colname=UsageColumns.DATE,
    device=self.device,
    y_colnames=[self.measure_name],
    X_colnames=self.var_predict_colnames + self.static_features,
)

dataset = dataset.split_measures([self.measure_name], self.var_predict_colnames, self.static_features)
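For context, here's a rough pure-Python sketch of the idea behind a `split_measures` method (a hypothetical helper, not the actual torchcast implementation): map each group of measure names to its column indices in the dataset tensor, validating the names along the way.

```python
def split_measure_indices(all_measures, *groups):
    """Hypothetical helper illustrating the split_measures idea: map each
    group of measure names (e.g. y columns, time-varying predictors,
    static features) to column indices in the combined measures axis."""
    out = []
    for group in groups:
        missing = [m for m in group if m not in all_measures]
        if missing:
            raise ValueError(f"unknown measures: {missing}")
        out.append([all_measures.index(m) for m in group])
    return out
```

The real method would presumably then slice the dataset's tensors along the measures dimension using these index lists, returning one tensor per group.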

@jwdink jwdink requested a review from jamesvrt December 10, 2024 00:34
@jamesvrt

Initializing used_gb_bc_trend.position to -3.4558966159820557
  0%|                                                                                           | 0/12 [00:00<?, ?it/s]Traceback (most recent call last):
  File "/home/jamesvrt/code/202308_bark/bark-ds-etl/scripts/phone-forecast-and-optimize.py", line 485, in <module>
    main()
  File "/home/jamesvrt/code/202308_bark/bark-ds-etl/scripts/phone-forecast-and-optimize.py", line 382, in main
    df_backtest = run_backtest(
  File "/home/jamesvrt/code/202308_bark/bark-ds-etl/scripts/phone-forecast-and-optimize.py", line 255, in run_backtest
    df_forecast = get_forecast_with_caching(
  File "/home/jamesvrt/code/202308_bark/bark-ds-etl/scripts/phone-forecast-and-optimize.py", line 318, in get_forecast_with_caching
    model.fit(data)
  File "/home/jamesvrt/code/202308_bark/bark-ds-etl/barkdsetl/forecasting/torchcast_forecaster.py", line 312, in fit
    self._fit_kalman_filter(dataset_train, **kwargs)
  File "/home/jamesvrt/code/202308_bark/bark-ds-etl/barkdsetl/forecasting/torchcast_forecaster.py", line 374, in _fit_kalman_filter
    self.kalman_filter_.fit(
  File "/home/jamesvrt/miniconda3/envs/bark/lib/python3.10/site-packages/torchcast/state_space/base.py", line 169, in fit
    train_loss = optimizer.step(closure).item()
  File "/home/jamesvrt/miniconda3/envs/bark/lib/python3.10/site-packages/torch/optim/optimizer.py", line 487, in wrapper
    out = func(*args, **kwargs)
  File "/home/jamesvrt/miniconda3/envs/bark/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/jamesvrt/miniconda3/envs/bark/lib/python3.10/site-packages/torch/optim/lbfgs.py", line 330, in step
    orig_loss = closure()
  File "/home/jamesvrt/miniconda3/envs/bark/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/jamesvrt/miniconda3/envs/bark/lib/python3.10/site-packages/torchcast/state_space/base.py", line 153, in closure
    pred = self(y, **kwargs)
  File "/home/jamesvrt/miniconda3/envs/bark/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/jamesvrt/miniconda3/envs/bark/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/jamesvrt/miniconda3/envs/bark/lib/python3.10/site-packages/torchcast/state_space/base.py", line 337, in forward
    preds, updates, design_mats = self._script_forward(
  File "/home/jamesvrt/miniconda3/envs/bark/lib/python3.10/site-packages/torchcast/state_space/base.py", line 478, in _script_forward
    predict_kwargs, update_kwargs = self._build_design_mats(
  File "/home/jamesvrt/miniconda3/envs/bark/lib/python3.10/site-packages/torchcast/kalman_filter/kalman_filter.py", line 168, in _build_design_mats
    Fs, Hs = self._build_transition_and_measure_mats(kwargs_per_process, num_groups, out_timesteps)
  File "/home/jamesvrt/miniconda3/envs/bark/lib/python3.10/site-packages/torchcast/state_space/base.py", line 589, in _build_transition_and_measure_mats
    pH, pF = process(inputs=p_kwargs, num_groups=num_groups, num_times=out_timesteps)
  File "/home/jamesvrt/miniconda3/envs/bark/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/jamesvrt/miniconda3/envs/bark/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/jamesvrt/miniconda3/envs/bark/lib/python3.10/site-packages/torchcast/process/base.py", line 61, in forward
    H = self._build_h_mat(inputs, num_groups, num_times)
  File "/home/jamesvrt/miniconda3/envs/bark/lib/python3.10/site-packages/torchcast/process/regression.py", line 62, in _build_h_mat
    X = inputs['X']
KeyError: 'X'

@jwdink (Collaborator, Author) commented Dec 12, 2024

@jamesvrt try again with 62ae734? This change will also let us avoid the awkward monkey-patching I did here.

@jamesvrt

Initializing used_gb_bc_trend.position to -3.4558966159820557
Epoch: 0; Loss: 2.1363:   8%|████▉                                                      | 1/12 [00:02<00:28,  2.59s/it]Traceback (most recent call last):
  File "/home/jamesvrt/code/202308_bark/bark-ds-etl/scripts/phone-forecast-and-optimize.py", line 485, in <module>
    main()
  File "/home/jamesvrt/code/202308_bark/bark-ds-etl/scripts/phone-forecast-and-optimize.py", line 382, in main
    df_backtest = run_backtest(
  File "/home/jamesvrt/code/202308_bark/bark-ds-etl/scripts/phone-forecast-and-optimize.py", line 255, in run_backtest
    df_forecast = get_forecast_with_caching(
  File "/home/jamesvrt/code/202308_bark/bark-ds-etl/scripts/phone-forecast-and-optimize.py", line 318, in get_forecast_with_caching
    model.fit(data)
  File "/home/jamesvrt/code/202308_bark/bark-ds-etl/barkdsetl/forecasting/torchcast_forecaster.py", line 313, in fit
    self._fit_kalman_filter(dataset_train, **kwargs)
  File "/home/jamesvrt/code/202308_bark/bark-ds-etl/barkdsetl/forecasting/torchcast_forecaster.py", line 393, in _fit_kalman_filter
    raise e
  File "/home/jamesvrt/code/202308_bark/bark-ds-etl/barkdsetl/forecasting/torchcast_forecaster.py", line 375, in _fit_kalman_filter
    self.kalman_filter_.fit(
  File "/home/jamesvrt/miniconda3/envs/bark/lib/python3.10/site-packages/torchcast/state_space/base.py", line 169, in fit
    train_loss = optimizer.step(closure).item()
  File "/home/jamesvrt/miniconda3/envs/bark/lib/python3.10/site-packages/torch/optim/optimizer.py", line 487, in wrapper
    out = func(*args, **kwargs)
  File "/home/jamesvrt/miniconda3/envs/bark/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/jamesvrt/miniconda3/envs/bark/lib/python3.10/site-packages/torch/optim/lbfgs.py", line 444, in step
    loss, flat_grad, t, ls_func_evals = _strong_wolfe(
  File "/home/jamesvrt/miniconda3/envs/bark/lib/python3.10/site-packages/torch/optim/lbfgs.py", line 48, in _strong_wolfe
    f_new, g_new = obj_func(x, t, d)
  File "/home/jamesvrt/miniconda3/envs/bark/lib/python3.10/site-packages/torch/optim/lbfgs.py", line 442, in obj_func
    return self._directional_evaluate(closure, x, t, d)
  File "/home/jamesvrt/miniconda3/envs/bark/lib/python3.10/site-packages/torch/optim/lbfgs.py", line 296, in _directional_evaluate
    loss = float(closure())
  File "/home/jamesvrt/miniconda3/envs/bark/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/jamesvrt/miniconda3/envs/bark/lib/python3.10/site-packages/torchcast/state_space/base.py", line 157, in closure
    loss.backward()
  File "/home/jamesvrt/miniconda3/envs/bark/lib/python3.10/site-packages/torch/_tensor.py", line 581, in backward
    torch.autograd.backward(
  File "/home/jamesvrt/miniconda3/envs/bark/lib/python3.10/site-packages/torch/autograd/__init__.py", line 347, in backward
    _engine_run_backward(
  File "/home/jamesvrt/miniconda3/envs/bark/lib/python3.10/site-packages/torch/autograd/graph.py", line 825, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.
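For reference, a minimal standalone sketch (not torchcast code) of the pattern this error usually points at: LBFGS may evaluate its closure several times per `step()`, so the forward pass and loss must be rebuilt inside the closure on every call — reusing a loss built once outside the closure reproduces the RuntimeError.

```python
import torch

w = torch.nn.Parameter(torch.tensor(1.0))
x = torch.tensor(2.0)
opt = torch.optim.LBFGS([w], max_iter=2)

def closure():
    # The loss is rebuilt on EVERY call, so each backward() sees a fresh
    # graph. Computing the loss once outside the closure and calling
    # loss.backward() here instead would hit "backward through the graph
    # a second time" on the second closure evaluation.
    opt.zero_grad()
    loss = (w * x - 1.0) ** 2
    loss.backward()
    return loss

orig_loss = opt.step(closure)  # returns the first closure evaluation
```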

@jwdink (Collaborator, Author) commented Dec 12, 2024

That one will be harder to fix... EDIT: nvm, I think I got it.

@jwdink (Collaborator, Author) commented Dec 12, 2024

@jamesvrt

Completes training but fails at prediction. I can't figure out what the difference would be, since both use your changes to _make_forward_kwargs().

Initializing used_gb_bc_trend.position to -3.4558966159820557
Epoch: 16; Loss: 0.7233: 100%|███████████████████████████████████████████████████████| 12/12 [00:13<00:00,  1.14s/it]
/home/jamesvrt/miniconda3/envs/bark/lib/python3.10/site-packages/torchcast/state_space/base.py:608: UserWarning: There are unused keyword arguments:
{'callable_kwargs'}
  warn(f"There are unused keyword arguments:\n{unused}")
Traceback (most recent call last):
  File "/home/jamesvrt/code/202308_bark/bark-ds-etl/scripts/phone-forecast-and-optimize.py", line 485, in <module>
    main()
  File "/home/jamesvrt/code/202308_bark/bark-ds-etl/scripts/phone-forecast-and-optimize.py", line 382, in main
    df_backtest = run_backtest(
  File "/home/jamesvrt/code/202308_bark/bark-ds-etl/scripts/phone-forecast-and-optimize.py", line 255, in run_backtest
    df_forecast = get_forecast_with_caching(
  File "/home/jamesvrt/code/202308_bark/bark-ds-etl/scripts/phone-forecast-and-optimize.py", line 318, in get_forecast_with_caching
    model.fit(data)
  File "/home/jamesvrt/code/202308_bark/bark-ds-etl/barkdsetl/forecasting/torchcast_forecaster.py", line 314, in fit
    self._fit_kalman_filter(dataset_train, **kwargs)
  File "/home/jamesvrt/code/202308_bark/bark-ds-etl/barkdsetl/forecasting/torchcast_forecaster.py", line 382, in _fit_kalman_filter
    pred = self.kalman_filter_(
  File "/home/jamesvrt/miniconda3/envs/bark/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/jamesvrt/miniconda3/envs/bark/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/jamesvrt/miniconda3/envs/bark/lib/python3.10/site-packages/torchcast/state_space/base.py", line 337, in forward
    preds, updates, design_mats = self._script_forward(
  File "/home/jamesvrt/miniconda3/envs/bark/lib/python3.10/site-packages/torchcast/state_space/base.py", line 478, in _script_forward
    predict_kwargs, update_kwargs = self._build_design_mats(
  File "/home/jamesvrt/miniconda3/envs/bark/lib/python3.10/site-packages/torchcast/kalman_filter/kalman_filter.py", line 168, in _build_design_mats
    Fs, Hs = self._build_transition_and_measure_mats(kwargs_per_process, num_groups, out_timesteps)
  File "/home/jamesvrt/miniconda3/envs/bark/lib/python3.10/site-packages/torchcast/state_space/base.py", line 589, in _build_transition_and_measure_mats
    pH, pF = process(inputs=p_kwargs, num_groups=num_groups, num_times=out_timesteps)
  File "/home/jamesvrt/miniconda3/envs/bark/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/jamesvrt/miniconda3/envs/bark/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/jamesvrt/miniconda3/envs/bark/lib/python3.10/site-packages/torchcast/process/base.py", line 61, in forward
    H = self._build_h_mat(inputs, num_groups, num_times)
  File "/home/jamesvrt/miniconda3/envs/bark/lib/python3.10/site-packages/torchcast/process/regression.py", line 63, in _build_h_mat
    X = inputs[self.expected_kwargs[0]]
KeyError: 'static_features_module_output'
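A hypothetical illustration of this failure class (the dict contents below are made up): each process looks up its input by the names in its `expected_kwargs`, so the kwargs dict built at predict time has to carry exactly the same keys as the one built at fit time.

```python
def lookup_process_input(expected_kwargs, inputs):
    # Mirrors the failing line: X = inputs[self.expected_kwargs[0]]
    return inputs[expected_kwargs[0]]

fit_inputs = {"static_features_module_output": [0.1, 0.2]}   # key present
predict_inputs = {"X": [0.1, 0.2]}   # key mismatch -> KeyError at predict
```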

@jwdink (Collaborator, Author) commented Dec 13, 2024

@jamesvrt commented Dec 16, 2024

Making progress. Now forecast() fails:

Traceback (most recent call last):
  File "/home/jamesvrt/code/202308_bark/bark-ds-etl/scripts/phone-forecast-and-optimize.py", line 485, in <module>
    main()
  File "/home/jamesvrt/code/202308_bark/bark-ds-etl/scripts/phone-forecast-and-optimize.py", line 382, in main
    df_backtest = run_backtest(
  File "/home/jamesvrt/code/202308_bark/bark-ds-etl/scripts/phone-forecast-and-optimize.py", line 255, in run_backtest
    df_forecast = get_forecast_with_caching(
  File "/home/jamesvrt/code/202308_bark/bark-ds-etl/scripts/phone-forecast-and-optimize.py", line 319, in get_forecast_with_caching
    df_forecast = model.forecast(df_for_fct, days=forecast_horizon)
  File "/home/jamesvrt/code/202308_bark/bark-ds-etl/barkdsetl/forecasting/torchcast_forecaster.py", line 490, in forecast
    for dataset, preds in self._iter_forecasts(usage_df, days, batch_size=2500, **kwargs):
  File "/home/jamesvrt/miniconda3/envs/bark/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 36, in generator_context
    response = gen.send(None)
  File "/home/jamesvrt/code/202308_bark/bark-ds-etl/barkdsetl/forecasting/torchcast_forecaster.py", line 453, in _iter_forecasts
    preds = self.kalman_filter_(dataset.tensors[0], **fkwargs)
  File "/home/jamesvrt/miniconda3/envs/bark/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/jamesvrt/miniconda3/envs/bark/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/jamesvrt/miniconda3/envs/bark/lib/python3.10/site-packages/torchcast/state_space/base.py", line 337, in forward
    preds, updates, design_mats = self._script_forward(
  File "/home/jamesvrt/miniconda3/envs/bark/lib/python3.10/site-packages/torchcast/state_space/base.py", line 478, in _script_forward
    predict_kwargs, update_kwargs = self._build_design_mats(
  File "/home/jamesvrt/miniconda3/envs/bark/lib/python3.10/site-packages/torchcast/kalman_filter/kalman_filter.py", line 168, in _build_design_mats
    Fs, Hs = self._build_transition_and_measure_mats(kwargs_per_process, num_groups, out_timesteps)
  File "/home/jamesvrt/miniconda3/envs/bark/lib/python3.10/site-packages/torchcast/state_space/base.py", line 589, in _build_transition_and_measure_mats
    pH, pF = process(inputs=p_kwargs, num_groups=num_groups, num_times=out_timesteps)
  File "/home/jamesvrt/miniconda3/envs/bark/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/jamesvrt/miniconda3/envs/bark/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/jamesvrt/miniconda3/envs/bark/lib/python3.10/site-packages/torchcast/process/base.py", line 61, in forward
    H = self._build_h_mat(inputs, num_groups, num_times)
  File "/home/jamesvrt/miniconda3/envs/bark/lib/python3.10/site-packages/torchcast/process/regression.py", line 64, in _build_h_mat
    assert not torch.isnan(X).any()
AssertionError

I'm running with 10k devices, and there are no NaNs in the static features. Again, I took a look at debugging this myself, but I can't follow what's going on at this level.

Thanks for your help getting this sorted.
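One way to narrow this down (a hypothetical debugging helper, not torchcast API): report which tensor kwargs contain NaNs, and how many, right before they reach the filter. In forecasting, NaNs often enter when predictor tensors aren't extended to cover the forecast horizon.

```python
import torch

def report_nan_inputs(**kwargs):
    """Hypothetical helper: count NaNs in each tensor kwarg before it
    reaches the Kalman filter, to localize the failing design matrix."""
    return {
        name: int(torch.isnan(t).sum())
        for name, t in kwargs.items()
        if torch.is_tensor(t) and torch.isnan(t).any()
    }
```

E.g. call `report_nan_inputs(**fkwargs)` just before the `self.kalman_filter_(...)` call in `_iter_forecasts` to see which input is carrying the NaNs.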

@jwdink jwdink marked this pull request as ready for review December 27, 2024 19:38
@jwdink jwdink merged commit 5de80a8 into main Dec 27, 2024
5 checks passed
@jwdink jwdink deleted the hotfix/split_measures branch December 27, 2024 19:38