diff --git a/numpyro_sts/base.py b/numpyro_sts/base.py index 88c5945..b71496e 100644 --- a/numpyro_sts/base.py +++ b/numpyro_sts/base.py @@ -22,7 +22,9 @@ def _loc_transition(state, offset, matrix) -> jnp.ndarray: return offset + (matrix @ state[..., None]).reshape(state.shape) -def _sample_shocks(key: PRNGKey, event_shape: Tuple[int, ...], batch_shape: Tuple[int, ...], selector: jnp.ndarray) -> jnp.ndarray: +def _sample_shocks( + key: PRNGKey, event_shape: Tuple[int, ...], batch_shape: Tuple[int, ...], selector: jnp.ndarray +) -> jnp.ndarray: shock_shape = event_shape[:-1] + selector.shape[-1:] flat_shape = () if not batch_shape else (reduce(lambda u, v: u * v, batch_shape),) diff --git a/numpyro_sts/periodic/cyclical.py b/numpyro_sts/periodic/cyclical.py index ed705e1..67d736a 100644 --- a/numpyro_sts/periodic/cyclical.py +++ b/numpyro_sts/periodic/cyclical.py @@ -14,7 +14,7 @@ class Cyclical(LinearTimeseries): """ def __init__(self, n: int, periodicity: ArrayLike, std: ArrayLike, initial_value: ArrayLike, **kwargs): - lamda, = cast_to_tensor(periodicity) + (lamda,) = cast_to_tensor(periodicity) cos_lamda = jnp.cos(lamda) sin_lamda = jnp.sin(lamda)