Skip to content

Commit

Permalink
Fix
Browse files Browse the repository at this point in the history
  • Loading branch information
victor committed Jul 10, 2024
1 parent 3ebf941 commit 5ba892c
Showing 1 changed file with 14 additions and 12 deletions.
26 changes: 14 additions & 12 deletions numpyro_sts/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,19 @@ def _sample_shocks(
return rotated_samples.reshape(batch_shape + event_shape)


def _verify_parameters(offset, matrix, std, initial_value, std_is_matrix):
ndim = matrix.shape[-1]

assert initial_value.ndim >= 1
assert matrix.ndim >= 2 and matrix.shape[-2] == matrix.shape[-1] == ndim
assert initial_value.shape[-1] == matrix.shape[-1]

assert offset.shape[-1] == initial_value.shape[-1]

if std_is_matrix:
assert std.ndim >= 2 and std.shape[-1] == std.shape[-2] == ndim


class LinearTimeseries(Distribution):
r"""
Defines a base model for linear stochastic models with Gaussian increments.
Expand All @@ -66,17 +79,6 @@ class LinearTimeseries(Distribution):
"n": constraints.positive_integer,
}

@staticmethod
def _verify_parameters(offset, matrix, std, initial_value, std_is_matrix):
ndim = matrix.shape[-1]

assert initial_value.ndim >= 1
assert matrix.ndim >= 2 and matrix.shape[-2] == matrix.shape[-1] == ndim
assert initial_value.shape[-1] == matrix.shape[-1]

if std_is_matrix:
assert std.ndim >= 2 and std.shape[-1] == std.shape[-2] == ndim

def __init__(
self,
n: int,
Expand All @@ -89,7 +91,7 @@ def __init__(
column_mask: np.ndarray = None,
validate_args=None,
):
self._verify_parameters(offset, matrix, std, initial_value, std_is_matrix)
_verify_parameters(offset, matrix, std, initial_value, std_is_matrix)
times = jnp.arange(n)

self._std_is_matrix = std_is_matrix
Expand Down

0 comments on commit 5ba892c

Please sign in to comment.