Skip to content

Commit

Permalink
Black
Browse files Browse the repository at this point in the history
  • Loading branch information
victor committed Jul 10, 2024
1 parent be76704 commit 3ebf941
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
4 changes: 3 additions & 1 deletion numpyro_sts/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ def _loc_transition(state, offset, matrix) -> jnp.ndarray:
return offset + (matrix @ state[..., None]).reshape(state.shape)


def _sample_shocks(key: PRNGKey, event_shape: Tuple[int, ...], batch_shape: Tuple[int, ...], selector: jnp.ndarray) -> jnp.ndarray:
def _sample_shocks(
key: PRNGKey, event_shape: Tuple[int, ...], batch_shape: Tuple[int, ...], selector: jnp.ndarray
) -> jnp.ndarray:
shock_shape = event_shape[:-1] + selector.shape[-1:]

flat_shape = () if not batch_shape else (reduce(lambda u, v: u * v, batch_shape),)
Expand Down
2 changes: 1 addition & 1 deletion numpyro_sts/periodic/cyclical.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class Cyclical(LinearTimeseries):
"""

def __init__(self, n: int, periodicity: ArrayLike, std: ArrayLike, initial_value: ArrayLike, **kwargs):
lamda, = cast_to_tensor(periodicity)
(lamda,) = cast_to_tensor(periodicity)

cos_lamda = jnp.cos(lamda)
sin_lamda = jnp.sin(lamda)
Expand Down

0 comments on commit 3ebf941

Please sign in to comment.