Skip to content

Commit

Permalink
Feature/seasonality (#4)
Browse files Browse the repository at this point in the history
* Adds a seasonal timeseries

* initial commit

* Rename to periodic

* More assertions

* Adds test

* Adds trigonometric seasonality

* Adds testing

* Adds event shape

* Adds optionals and notebook

* NB

* Use correct mask, NB

* Sigh

* NB

* Correct

* Fix

* Removes mask functionality in favor of more general matrix multiplication

* Fixes tests

* Temp

* Removes trigonometric.py for now

* Fixes

* Slightly improved

* Fix

* Fix

* Adds test, to fix

* Seasonal working

* FIx

* FIx

* FIx

* Adds deterministic sampling procedure

* ix

* Fix

* Fix

* Slightly cleaner

* Fix

* Fix

* Adds trigonometric series

* Minor cleanup

* Docs

* Black

* Fix

* "Fix"

* Test fix

* version

* Think this is the neatest way

* Predictions

* Black

* NB

* Bug fix

* NB

* NB

* NB

* NB

* rename

* fix

* Also ensure that posterior std is lower than prior

* Backwards compatibility

* Docs

* Do not sample

* Typing

* Adds test just in case

* Test deterministic seasonal

* Adds selector to stuff

* Fix

* NB

* Fix

* Fix

---------

Co-authored-by: victor <baj>
  • Loading branch information
tingiskhan authored Nov 4, 2024
1 parent 84f8ff8 commit c5bb414
Show file tree
Hide file tree
Showing 12 changed files with 1,292 additions and 75 deletions.
964 changes: 964 additions & 0 deletions notebooks/mauna-loa-co2.ipynb

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion numpyro_sts/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,6 @@
from .llt import LocalLinearTrend
from .random_walk import RandomWalk
from .smooth_llt import SmoothLocalLinearTrend
from . import periodic

__version__ = "0.0.3"
__version__ = "0.1.0"
5 changes: 2 additions & 3 deletions numpyro_sts/ar.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ def __init__(
offset = mu * (1.0 - phi.squeeze(-1))

init = jnp.reshape(initial_value if initial_value is not None else jnp.zeros(order), batch_shape + (order,))
mask = np.eye(order, 1, dtype=np.bool_).squeeze(-1)

mask = np.array([True] + (order - 1) * [False], dtype=jnp.bool_)

super().__init__(n, offset, phi, std, init, mask=mask, **kwargs)
super().__init__(n, offset, phi, std, init, column_mask=mask, **kwargs)
222 changes: 159 additions & 63 deletions numpyro_sts/base.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
import warnings
from functools import reduce
from typing import Tuple

import jax.numpy as jnp
import numpy as np
from jax import vmap
from jax.random import normal
from jax.random import normal, PRNGKey
from numpyro.contrib.control_flow import scan
from numpyro.distributions import Distribution, Normal, constraints, MultivariateNormal
from numpyro.distributions.util import validate_sample
from numpyro.util import is_prng_key
from jax.typing import ArrayLike
import jax.scipy.linalg as linalg


def _broadcast_and_reshape(x: jnp.ndarray, shape, dim: int) -> jnp.ndarray:
Expand All @@ -17,6 +23,37 @@ def _loc_transition(state, offset, matrix) -> jnp.ndarray:
return offset + (matrix @ state[..., None]).reshape(state.shape)


def _sample_shocks(
    key: PRNGKey, event_shape: Tuple[int, ...], batch_shape: Tuple[int, ...], selector: jnp.ndarray
) -> jnp.ndarray:
    """
    Draws standard normal shocks and maps them through the selector matrix.

    Args:
        key: PRNG key used for drawing the shocks.
        event_shape: Event shape of the series; every axis but the last is kept for the raw draws.
        batch_shape: Leading shape of the draws, may be empty.
        selector: Matrix applied to every draw; the size of its last axis is the number of shocks drawn per step.

    Returns:
        An array of shape ``batch_shape + event_shape``.
    """

    raw_shape = event_shape[:-1] + selector.shape[-1:]
    target_shape = batch_shape + event_shape

    if not batch_shape:
        draws = normal(key, shape=raw_shape)
        mapped = jnp.matmul(selector, draws[..., None]).squeeze(-1)

        return mapped.reshape(target_shape)

    # All batch axes are collapsed into one leading axis so that a single vmap covers them.
    num_batches = reduce(lambda u, v: u * v, batch_shape)
    draws = normal(key, shape=(num_batches,) + raw_shape)

    tiled_selector = jnp.broadcast_to(selector, draws.shape[:1] + selector.shape)
    mapped = vmap(jnp.matmul)(tiled_selector, draws[..., None]).squeeze(-1)

    return mapped.reshape(target_shape)


def _verify_parameters(offset, matrix, std, initial_value, std_is_matrix):
    """
    Verifies that the shapes of the model parameters are congruent.

    Args:
        offset: Offset of the transition, its last axis must match the state dimension.
        matrix: Transition matrix, must be square in its two last axes.
        std: Standard deviation of the innovations; only inspected when ``std_is_matrix`` is set.
        initial_value: Initial state, its last axis must match the state dimension.
        std_is_matrix: Whether ``std`` is matrix valued, in which case it must be square and match the state dimension.

    Raises:
        AssertionError: If any of the shapes are not congruent.
    """

    # Rank checks come first so that malformed (0-d) inputs fail with a message instead of an IndexError.
    assert initial_value.ndim >= 1, "'initial_value' must have at least one dimension!"
    assert offset.ndim >= 1, "'offset' must have at least one dimension!"
    assert (
        matrix.ndim >= 2 and matrix.shape[-2] == matrix.shape[-1]
    ), "'matrix' must be square in its two last dimensions!"

    ndim = matrix.shape[-1]

    assert initial_value.shape[-1] == ndim, "Last dimension of 'initial_value' does not match 'matrix'!"
    assert offset.shape[-1] == initial_value.shape[-1], "Last dimension of 'offset' does not match 'initial_value'!"

    if std_is_matrix:
        assert (
            std.ndim >= 2 and std.shape[-1] == std.shape[-2] == ndim
        ), "'std' must be a square matrix with the same dimension as 'matrix'!"


class LinearTimeseries(Distribution):
r"""
Defines a base model for linear stochastic models with Gaussian increments.
Expand All @@ -26,11 +63,11 @@ class LinearTimeseries(Distribution):
matrix: Matrix of linear combination of states. Of size :math:`[batch size] \times dimension \times dimension`.
std: Standard deviation of innovations. Of size :math:`[batch size] \times dimension`.
initial_value: Initial value of the time series. Of size :math:`[batch size] \times dimension`.
mask: Mask for removing specific columns from calculation.
column_mask: Mask for constructing the "selector" matrix.
"""

pytree_data_fields = ("offset", "matrix", "std", "initial_value", "mask")
pytree_aux_fields = ("n", "_sample_shape", "_column_mask", "_std_is_matrix")
pytree_data_fields = ("offset", "matrix", "std", "initial_value")
pytree_aux_fields = ("n", "_std_is_matrix", "column_mask", "selector")

support = constraints.real_matrix
has_enumerate_support = False
Expand All @@ -43,16 +80,6 @@ class LinearTimeseries(Distribution):
"n": constraints.positive_integer,
}

@staticmethod
def _verify_parameters(offset, matrix, std, initial_value, std_is_matrix):
ndim = matrix.shape[-1]

assert initial_value.ndim >= 1
assert matrix.ndim >= 2 and matrix.shape[-2] == matrix.shape[-1] == ndim

if std_is_matrix:
assert std.ndim >= 2 and std.shape[-1] == std.shape[-2] == ndim

def __init__(
self,
n: int,
Expand All @@ -62,10 +89,15 @@ def __init__(
initial_value: ArrayLike,
*,
std_is_matrix: bool = False,
mask: ArrayLike = None,
column_mask: np.ndarray = None,
validate_args=None,
**kwargs,
):
self._verify_parameters(offset, matrix, std, initial_value, std_is_matrix)
if "mask" in kwargs:
warnings.warn("'mask' is deprecated in favor of 'column_mask'", DeprecationWarning)
column_mask = kwargs.pop("mask")

_verify_parameters(offset, matrix, std, initial_value, std_is_matrix)
times = jnp.arange(n)

self._std_is_matrix = std_is_matrix
Expand All @@ -85,50 +117,37 @@ def __init__(
std_shape = parameter_shape if not self._std_is_matrix else parameter_shape + initial_value.shape[-1:]
self.std = jnp.broadcast_to(std, std_shape)

cols_to_sample = event_shape[-1]
if mask is not None:
assert mask.shape == event_shape[-1:], "Shapes not congruent!"
cols_to_sample = mask.sum(axis=-1)

self._column_mask = mask
self._shock_shape = times.shape + (cols_to_sample,)

super().__init__(batch_shape=batch_shape, event_shape=event_shape, validate_args=validate_args)

def _sample_shocks(self, key, batch_shape) -> jnp.ndarray:
samples = normal(key, shape=batch_shape + self._shock_shape)

if self._column_mask is None:
return samples
if column_mask is None:
column_mask = np.ones(self.event_shape[-1], dtype=np.bool_)

result = jnp.zeros(batch_shape + self.event_shape, dtype=samples.dtype)
return result.at[..., self._column_mask].set(samples)
self.column_mask = column_mask
self.selector = np.eye(self.event_shape[-1])[..., self.column_mask]

def sample(self, key, sample_shape=()):
assert is_prng_key(key)

batch_shape = sample_shape + self.batch_shape

def body(state, xs):
(eps_tp1,) = xs
def body(state, eps_tp1):
x_tp1 = self.sample_from_shock(state, eps_tp1)

return x_tp1, x_tp1

def scan_fn(init, noise):
return scan(body, init, (noise,))
return scan(body, init, noise)

batch_shape = sample_shape + self.batch_shape

eps = self._sample_shocks(key, batch_shape)
shocks = _sample_shocks(key, self.event_shape, batch_shape, self.selector)
inits = jnp.broadcast_to(self.initial_value, sample_shape + self.initial_value.shape)

batch_dim = len(batch_shape)
if batch_dim:
eps = jnp.moveaxis(eps, -2, 0)
_, samples = scan_fn(inits, eps)
shocks = jnp.moveaxis(shocks, -2, 0)
_, samples = scan_fn(inits, shocks)

return jnp.moveaxis(samples, 0, -2)

return scan_fn(inits, eps)[-1]
return scan_fn(inits, shocks)[-1]

@validate_sample
def log_prob(self, value):
Expand All @@ -140,40 +159,46 @@ def log_prob(self, value):
initial_value = jnp.broadcast_to(initial_value, sample_shape + initial_value.shape[-2:])

stacked = jnp.concatenate([initial_value, value], axis=-2)
loc_fun = _loc_transition

if sample_shape:
stacked_reshape = stacked.reshape((-1,) + stacked.shape[-2:])
x_tm1 = stacked_reshape[:, :-1]
offset = self.offset
matrix = self.matrix
std = self.std

offset = _broadcast_and_reshape(self.offset, sample_shape, -1)
matrix = _broadcast_and_reshape(self.matrix, sample_shape, -2)
selector = self.selector
inverse_fun = jnp.matmul

loc = vmap(_loc_transition)(x_tm1, offset, matrix).reshape(sample_shape + self.event_shape)
else:
x_tm1 = stacked[:-1]
loc = _loc_transition(x_tm1, self.offset, self.matrix)
if sample_shape:
stacked = stacked.reshape((-1,) + stacked.shape[-2:])

x_t = stacked[..., 1:, :]
offset = _broadcast_and_reshape(offset, sample_shape, -1)
matrix = _broadcast_and_reshape(matrix, sample_shape, -2)
std = _broadcast_and_reshape(std, sample_shape, -2 if self._std_is_matrix else -1)

std = self.std
if not self._std_is_matrix:
std = jnp.expand_dims(std, -2)
selector = jnp.broadcast_to(selector, stacked.shape[:1] + selector.shape)

if self._column_mask is not None:
loc = loc[..., self._column_mask]
std = std[..., self._column_mask]
loc_fun = vmap(_loc_transition)
inverse_fun = vmap(inverse_fun)

x_tm1 = stacked[..., :-1, :]
x_t = stacked[..., 1:, :]

if self._std_is_matrix:
std = std[..., self._column_mask, :]
loc = loc_fun(x_tm1, offset, matrix)

x_t = x_t[..., self._column_mask]
transposed_selector = selector.swapaxes(-1, -2)
loc = inverse_fun(transposed_selector, loc[..., None]).squeeze(-1)
x_t = inverse_fun(transposed_selector, x_t[..., None]).squeeze(-1)

if not self._std_is_matrix:
dist = Normal(loc, std).to_event(1)
std = inverse_fun(transposed_selector, std[..., None]).swapaxes(-1, -2)
dist = Normal(loc, std).to_event(2)
else:
dist = MultivariateNormal(loc, scale_tril=std)
std = inverse_fun(transposed_selector, std)
std = jnp.expand_dims(std, -3)

return dist.log_prob(x_t).sum(axis=-1)
dist = MultivariateNormal(loc, scale_tril=std).to_event(1)

return dist.log_prob(x_t).reshape(sample_shape)

def sample_from_shock(self, x_t, eps_t: jnp.ndarray) -> jnp.ndarray:
"""
Expand All @@ -193,3 +218,74 @@ def sample_from_shock(self, x_t, eps_t: jnp.ndarray) -> jnp.ndarray:
return loc + self.std * eps_t

return loc + (self.std @ eps_t[..., None]).squeeze(-1)

def union(self, other: "LinearTimeseries") -> "LinearTimeseries":
    """
    Combines self with other series to create a joint series.

    The transition matrices are stacked block diagonally, whereas the offsets, initial values, standard deviations
    and column masks are concatenated along the last axis.

    Args:
        other: Other series to combine.

    Returns:
        Returns a new instance of :class:`LinearTimeseries`.

    Raises:
        AssertionError: If the number of steps differ, if either series is batched, or if either series uses a
            matrix valued ``std``.
    """

    assert self.n == other.n, "Number of steps do not match!"
    batch_shape = jnp.broadcast_shapes(self.batch_shape, other.batch_shape)

    assert not batch_shape, "Currently does not support batch shapes!"

    # The joint series is always built with a vector valued ``std``; concatenating matrices along the last axis
    # would not yield a valid joint scale, so fail loudly instead.
    assert not (self._std_is_matrix or other._std_is_matrix), "Currently does not support matrix valued 'std'!"

    matrix = linalg.block_diag(self.matrix, other.matrix)
    offset = jnp.concatenate([self.offset, other.offset], axis=-1)
    initial_value = jnp.concatenate([self.initial_value, other.initial_value], axis=-1)

    std = jnp.concatenate([self.std, other.std], axis=-1)
    mask = np.concatenate([self.column_mask, other.column_mask], axis=-1)

    model = LinearTimeseries(self.n, offset, matrix, std, initial_value, column_mask=mask, std_is_matrix=False)

    return model

def deterministic(self) -> jnp.ndarray:
    """
    Constructs a deterministic version of the timeseries and "samples" it.

    The copy is built with an all-``False`` column mask, so its selector matrix has no columns and no shocks enter
    the recursion. The key passed to ``sample`` is therefore a fixed placeholder.

    Returns:
        An array.
    """

    model = LinearTimeseries(
        self.n,
        self.offset,
        self.matrix,
        self.std,
        self.initial_value,
        # Forward the std layout, otherwise a matrix valued ``std`` is treated as a vector valued one.
        std_is_matrix=self._std_is_matrix,
        column_mask=np.zeros_like(self.column_mask),
    )

    return model.sample(PRNGKey(0))

def predict(self, n: int, value: jnp.ndarray) -> "LinearTimeseries":
    """
    Creates a "prediction" instance of self.

    The new series reuses the offset, matrix, standard deviation and column mask of self, but starts from ``value``
    and runs for ``n`` steps.

    Args:
        n: Number of future predictions.
        value: New start value.

    Returns:
        Returns new instance of :class:`LinearTimeseries`.
    """

    shared_parameters = (self.offset, self.matrix, self.std)
    options = {"std_is_matrix": self._std_is_matrix, "column_mask": self.column_mask}

    return LinearTimeseries(n, *shared_parameters, value, **options)
3 changes: 3 additions & 0 deletions numpyro_sts/periodic/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .time import TimeSeasonal
from .cyclical import Cyclical
from .trigonometric import TrigonometricSeasonal
29 changes: 29 additions & 0 deletions numpyro_sts/periodic/cyclical.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import jax.numpy as jnp
from numpy.typing import ArrayLike

from ..base import LinearTimeseries
from ..util import cast_to_tensor


class Cyclical(LinearTimeseries):
    r"""
    Represents a cyclical component by means of a two dimensional rotation:

    .. math::
        x_{t + 1} =
        \begin{pmatrix} \cos\lambda & \sin\lambda \\ -\sin\lambda & \cos\lambda \end{pmatrix} x_t + \epsilon_{t + 1}

    where :math:`\lambda` is the value passed as ``periodicity``.

    Args:
        n: Number of time steps.
        periodicity: Periodicity of component; used directly as the angle :math:`\lambda` of the rotation.
        std: Standard deviation of the innovations, applied to both state components.
        initial_value: Initial value of the two dimensional state.
    """

    def __init__(self, n: int, periodicity: ArrayLike, std: ArrayLike, initial_value: ArrayLike, **kwargs):
        (angle,) = cast_to_tensor(periodicity)
        cosine, sine = jnp.cos(angle), jnp.sin(angle)

        rotation = jnp.stack(
            [
                jnp.stack([cosine, sine], axis=-1),
                jnp.stack([-sine, cosine], axis=-1),
            ],
            axis=-2,
        )

        zero_offset = jnp.zeros(2)
        scale = jnp.full_like(zero_offset, std)

        super().__init__(n, zero_offset, rotation, scale, initial_value, **kwargs)
33 changes: 33 additions & 0 deletions numpyro_sts/periodic/time.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from numpy.typing import ArrayLike
import jax.numpy as jnp
import numpy as np

from ..base import LinearTimeseries
from ..util import cast_to_tensor


class TimeSeasonal(LinearTimeseries):
    r"""
    Represents a periodic series of the form:

    .. math::
        \gamma_{t + 1} = -\sum_{j = 1}^{s - 1} \gamma_{t + 1 - j} + \epsilon_{t + 1}

    where :math:`s` is the number of seasons. The state holds the :math:`s - 1` most recent seasonal effects, and
    only its first component receives noise.

    Args:
        n: Number of time steps.
        num_seasons: Number of seasons to include, must be at least 2.
        std: Standard deviation of the innovation of the first state component.
        initial_value: Initial value of the state.
    """

    def __init__(self, n: int, num_seasons: int, std: ArrayLike, initial_value: ArrayLike, **kwargs):
        assert num_seasons >= 2, "'num_seasons' must be at least 2!"

        # First row sums (negatively) over the state, remaining rows shift the state one step down.
        top = -jnp.ones([1, num_seasons - 1])
        bottom = jnp.eye(num_seasons - 2, num_seasons - 1)

        matrix = jnp.concatenate([top, bottom], axis=-2)
        offset = jnp.zeros_like(top).squeeze(-2)

        std, initial_value = cast_to_tensor(std, initial_value)

        # Pad with zeros that carry the leading shape of ``std`` so that batched values concatenate as well.
        padding = jnp.zeros(std.shape + (num_seasons - 2,))
        std = jnp.concatenate([std[..., None], padding], axis=-1)

        mask = np.eye(num_seasons - 1, 1, dtype=np.bool_).squeeze(-1)

        super().__init__(n, offset, matrix, std, initial_value, column_mask=mask, **kwargs)
Loading

0 comments on commit c5bb414

Please sign in to comment.