Skip to content

Commit

Permalink
Merge branch 'feature/outliers' into feature/poisson
Browse files Browse the repository at this point in the history
  • Loading branch information
jwdink committed May 18, 2023
2 parents 1c40d7d + 7c1b3c4 commit 9ee27b1
Show file tree
Hide file tree
Showing 6 changed files with 124 additions and 54 deletions.
27 changes: 0 additions & 27 deletions .circleci/config.yml

This file was deleted.

27 changes: 27 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# GitHub Actions workflow: run the torchcast unit-test suite on every push
# and pull request, across all supported Python versions.
# (Indentation restored to valid YAML; content otherwise unchanged.)
name: Torchcast Unit Tests

on:
  push:
  pull_request:

jobs:
  test:
    name: Run tests
    runs-on: ubuntu-latest
    strategy:
      matrix:
        # Supported Python versions; keep in sync with the package metadata.
        py-version: ['3.7', '3.8', '3.9', '3.10']

    steps:

      - uses: actions/checkout@v3

      - uses: actions/setup-python@v4
        with:
          python-version: ${{ matrix.py-version }}

      - name: Install torchcast
        run: pip install .[tests]

      - name: Run tests
        run: python3 -m unittest
14 changes: 8 additions & 6 deletions torchcast/exp_smooth/exp_smooth.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import torch
from torch import Tensor

from torchcast.state_space import Predictions
from torchcast.exp_smooth.smoothing_matrix import SmoothingMatrix
from torchcast.covariance import Covariance
from torchcast.process import Process
Expand All @@ -26,24 +25,27 @@ def _mask_mats(self,
input: Tensor,
kwargs: Dict[str, Tensor]) -> Tuple[Tensor, Dict[str, Tensor]]:
# torchscript doesn't support super, see: https://github.com/pytorch/pytorch/issues/42885
new_kwargs = kwargs.copy()
if val_idx is None:
return input[groups], {k: v[groups] for k, v in kwargs.items()}
for k in ['H', 'R', 'K']:
new_kwargs[k] = kwargs[k][groups]
return input[groups], new_kwargs
else:
m1d = torch.meshgrid(groups, val_idx, indexing='ij')
m2d = torch.meshgrid(groups, val_idx, val_idx, indexing='ij')
masked_input = input[m1d[0], m1d[1]]
masked_kwargs = {
new_kwargs.update({
'H': kwargs['H'][m1d[0], m1d[1]],
'R': kwargs['R'][m2d[0], m2d[1], m2d[2]],
'K': kwargs['K'][m1d[0], m1d[1]],
}
return masked_input, masked_kwargs
})
return masked_input, new_kwargs

def _update(self,
input: Tensor,
mean: Tensor,
cov: Tensor,
kwargs: Dict[str, Tensor]) -> Tuple[Tensor, Tensor]:
kwargs: Dict[str, Tensor]) -> Tuple[Tensor, Optional[Tensor]]:
measured_mean = (kwargs['H'] @ mean.unsqueeze(-1)).squeeze(-1)
resid = input - measured_mean
new_mean = mean + (kwargs['K'] @ resid.unsqueeze(-1)).squeeze(-1)
Expand Down
63 changes: 52 additions & 11 deletions torchcast/kalman_filter/kalman_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,28 +40,57 @@ def _mask_mats(self,
val_idx: Optional[Tensor],
input: Tensor,
kwargs: Dict[str, Tensor]) -> Tuple[Tensor, Dict[str, Tensor]]:
new_kwargs = kwargs.copy()
if val_idx is None:
return input[groups], {k: v[groups] for k, v in kwargs.items()}
for k in ['H', 'R']:
new_kwargs[k] = kwargs[k][groups]
return input[groups], new_kwargs
else:
m1d = torch.meshgrid(groups, val_idx, indexing='ij')
m2d = torch.meshgrid(groups, val_idx, val_idx, indexing='ij')
masked_input = input[m1d[0], m1d[1]]
masked_kwargs = {
new_kwargs.update({
'H': kwargs['H'][m1d[0], m1d[1]],
'R': kwargs['R'][m2d[0], m2d[1], m2d[2]]
}
return masked_input, masked_kwargs
})
return masked_input, new_kwargs

def _update(self, input: Tensor, mean: Tensor, cov: Tensor, kwargs: Dict[str, Tensor]) -> Tuple[Tensor, Tensor]:
H = kwargs['H']
R = kwargs['R']
K = self._kalman_gain(cov=cov, H=H, R=R)
Ht = H.permute(0, 2, 1)
system_covariance = torch.baddbmm(R, H @ cov, Ht)

# kalman-gain:
K = self._kalman_gain(cov=cov, Ht=Ht, system_covariance=system_covariance)

# residuals:
measured_mean = (H @ mean.unsqueeze(-1)).squeeze(-1)
resid = input - measured_mean
new_mean = mean + (K @ resid.unsqueeze(-1)).squeeze(-1)
new_cov = self._covariance_update(cov=cov, K=K, H=H, R=R)

# outlier-rejection:
valid_mask = self._get_update_mask(resid, system_covariance, outlier_threshold=kwargs['outlier_threshold'])

# update:
new_mean = mean.clone()
new_mean[valid_mask] = mean[valid_mask] + (K[valid_mask] @ resid[valid_mask].unsqueeze(-1)).squeeze(-1)
new_cov = cov.clone()
new_cov[valid_mask] = self._covariance_update(
cov=cov[valid_mask], K=K[valid_mask], H=H[valid_mask], R=R[valid_mask]
)

return new_mean, new_cov

@staticmethod
def _get_update_mask(resid: torch.Tensor,
system_covariance: torch.Tensor,
outlier_threshold: torch.Tensor) -> torch.Tensor:
if outlier_threshold > 0:
mdist = mahalanobis_dist(resid, system_covariance)
return mdist <= outlier_threshold
else:
return torch.ones(len(resid), dtype=torch.bool, device=resid.device)

def _covariance_update(self, cov: Tensor, K: Tensor, H: Tensor, R: Tensor) -> Tensor:
I = torch.eye(cov.shape[1], dtype=cov.dtype, device=cov.device).unsqueeze(0)
ikh = I - K @ H
Expand All @@ -71,17 +100,22 @@ def _covariance_update(self, cov: Tensor, K: Tensor, H: Tensor, R: Tensor) -> Te
return ikh @ cov

@staticmethod
def _kalman_gain(cov: Tensor, H: Tensor, R: Tensor) -> Tensor:
Ht = H.permute(0, 2, 1)
def _kalman_gain(cov: Tensor, Ht: Tensor, system_covariance: Tensor) -> Tensor:
covs_measured = cov @ Ht
system_covariance = torch.baddbmm(R, H @ cov, Ht)
A = system_covariance.permute(0, 2, 1)
B = covs_measured.permute(0, 2, 1)
Kt = torch.linalg.solve(A, B)
K = Kt.permute(0, 2, 1)
return K


def mahalanobis_dist(diff: torch.Tensor, covariance: torch.Tensor) -> torch.Tensor:
    """
    Batched mahalanobis distance of ``diff`` under ``covariance``.

    :param diff: A (groups, measure) tensor of residuals.
    :param covariance: A (groups, measure, measure) batch of covariance
        matrices.
    :return: A (groups,) tensor of distances.
    """
    # solve covariance @ y = diff via cholesky, then dist = sqrt(diff . y)
    chol = torch.linalg.cholesky(covariance)
    solved = torch.cholesky_solve(diff.unsqueeze(-1), chol).squeeze(-1)
    return (diff * solved).sum(1).sqrt()


class KalmanFilter(StateSpaceModel):
"""
Uses the full kalman-filtering algorithm for generating forecasts.
Expand All @@ -90,14 +124,19 @@ class KalmanFilter(StateSpaceModel):
:param measures: A list of strings specifying the names of the dimensions of the time-series being measured.
:param process_covariance: A module created with ``Covariance.from_processes(processes)``.
:param measure_covariance: A module created with ``Covariance.from_measures(measures)``.
:param outlier_threshold: If specified, outliers are rejected using the mahalanobis distance.
:param outlier_burnin: If outlier_threshold is specified, this specifies the number of steps to wait before
starting to reject outliers.
"""
ss_step_cls = KalmanStep

def __init__(self,
processes: Sequence[Process],
measures: Optional[Sequence[str]] = None,
process_covariance: Optional[Covariance] = None,
measure_covariance: Optional[Covariance] = None):
measure_covariance: Optional[Covariance] = None,
outlier_threshold: float = 0.,
outlier_burnin: Optional[int] = None):

initial_covariance = Covariance.from_processes(processes, cov_type='initial')

Expand All @@ -111,6 +150,8 @@ def __init__(self,
processes=processes,
measures=measures,
measure_covariance=measure_covariance,
outlier_threshold=outlier_threshold,
outlier_burnin=outlier_burnin
)
self.process_covariance = process_covariance.set_id('process_covariance')
self.initial_covariance = initial_covariance.set_id('initial_covariance')
Expand Down
20 changes: 17 additions & 3 deletions torchcast/state_space/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,25 @@ class StateSpaceModel(nn.Module):
:param processes: A list of :class:`.Process` modules.
:param measures: A list of strings specifying the names of the dimensions of the time-series being measured.
:param measure_covariance: A module created with ``Covariance.from_measures(measures)``.
:param outlier_threshold: If specified, outliers are rejected using the mahalanobis distance.
:param outlier_burnin: If outlier_threshold is specified, this specifies the number of steps to wait before
starting to reject outliers.
"""
ss_step_cls: Type[StateSpaceStep]

def __init__(self,
processes: Sequence[Process],
measures: Optional[Sequence[str]] = None,
measure_covariance: Optional[Covariance] = None):
measure_covariance: Optional[Covariance] = None,
outlier_threshold: float = 0.0,
outlier_burnin: Optional[int] = None):
super().__init__()

self.outlier_threshold = outlier_threshold
if self.outlier_threshold and outlier_burnin is None:
raise ValueError("If `outlier_threshold` is set, `outlier_burnin` must be set as well.")
self.outlier_burnin = outlier_burnin or 0

if isinstance(measures, str):
measures = [measures]
warn(f"`measures` should be a list of strings not a string; interpreted as `{measures}`.")
Expand Down Expand Up @@ -455,18 +465,22 @@ def _script_forward(self,
mean1s.append(mean1step)
cov1s.append(cov1step)
if t < len(inputs):
update_kwargs_t = {k: v[t] for k, v in update_kwargs.items()}
update_kwargs_t['outlier_threshold'] = torch.tensor(
self.outlier_threshold if t > self.outlier_burnin else 0.
)
meanu, covu = self.ss_step.update(
inputs[t],
mean1step,
cov1step,
{k: v[t] for k, v in update_kwargs.items()}
update_kwargs_t,
)
else:
meanu, covu = mean1step, cov1step
meanus.append(meanu)
covus.append(covu)

# 2nd loop to get n_step updates:
# 2nd loop to get n_step predicts:
# idx: Dict[int, int] = {}
meanps: Dict[int, Tensor] = {}
covps: Dict[int, Tensor] = {}
Expand Down
27 changes: 20 additions & 7 deletions torchcast/utils/data.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
import datetime
import itertools

from typing import Sequence, Any, Union, Optional, Tuple, Callable
from typing import Sequence, Any, Union, Optional, Tuple, Callable, TYPE_CHECKING
from warnings import warn

import numpy as np
import torch

from torch import Tensor
from torch.utils.data import TensorDataset, DataLoader, ConcatDataset, Dataset

from torchcast.internals.utils import ragged_cat, true1d_idx

if TYPE_CHECKING:
from pandas import DataFrame


class TimeSeriesDataset(TensorDataset):
"""
Expand Down Expand Up @@ -573,14 +577,15 @@ def from_dataframe(cls,


def complete_times(data: 'DataFrame',
group_colname: str,
group_colnames: Sequence[str] = None,
time_colname: Optional[str] = None,
dt_unit: Optional[str] = None):
dt_unit: Optional[str] = None,
group_colname: Optional[str] = None):
"""
Given a dataframe of time-series, convert implicit missings within each time-series to explicit missings.
:param data: A pandas dataframe.
:param group_colname: The column name for the groups.
:param group_colnames: The column name(s) for the groups.
:param time_colname: The column name for the times. Will attempt to guess based on common labels.
:param dt_unit: A :class:`numpy.datetime64` or string representing the datetime increments. If not supplied will
try to guess based on the smallest difference in the data.
Expand All @@ -589,6 +594,14 @@ def complete_times(data: 'DataFrame',
"""
import pandas as pd

if isinstance(group_colnames, str):
group_colnames = [group_colnames]
elif group_colnames is None:
if group_colname is None:
raise TypeError("Missing required argument `group_colnames`")
warn("Please pass `group_colnames` instead of `group_colname`", DeprecationWarning)
group_colnames = [group_colname]

if time_colname is None:
for col in ('datetime', 'date', 'timestamp', 'time', 'dt'):
if col in data.columns:
Expand All @@ -607,7 +620,7 @@ def complete_times(data: 'DataFrame',
)

df_group_summary = data. \
groupby(group_colname). \
groupby(group_colnames). \
agg(_min=(time_colname, 'min'),
_max=(time_colname, 'max')). \
reset_index()
Expand All @@ -618,9 +631,9 @@ def complete_times(data: 'DataFrame',
merge(df_group_summary.assign(_cj=1), how='left', on=['_cj'])
# filter to min/max for each group
df_cj = df_cj. \
loc[df_cj[time_colname].between(df_cj['_min'], df_cj['_max']), [group_colname, time_colname]]. \
loc[df_cj[time_colname].between(df_cj['_min'], df_cj['_max']), group_colnames + [time_colname]]. \
reset_index(drop=True)
return df_cj.merge(data, how='left', on=[group_colname, time_colname])
return df_cj.merge(data, how='left', on=group_colnames + [time_colname])


def chunk_grouped_data(*tensors, group_ids: Sequence) -> Sequence[Tuple[Tensor]]:
Expand Down

0 comments on commit 9ee27b1

Please sign in to comment.