diff --git a/numpyro_sts/base.py b/numpyro_sts/base.py index b71496e..b8bfa34 100644 --- a/numpyro_sts/base.py +++ b/numpyro_sts/base.py @@ -40,6 +40,19 @@ def _sample_shocks( return rotated_samples.reshape(batch_shape + event_shape) +def _verify_parameters(offset, matrix, std, initial_value, std_is_matrix): + ndim = matrix.shape[-1] + + assert initial_value.ndim >= 1 + assert matrix.ndim >= 2 and matrix.shape[-2] == matrix.shape[-1] == ndim + assert initial_value.shape[-1] == matrix.shape[-1] + + assert offset.shape[-1] == initial_value.shape[-1] + + if std_is_matrix: + assert std.ndim >= 2 and std.shape[-1] == std.shape[-2] == ndim + + class LinearTimeseries(Distribution): r""" Defines a base model for linear stochastic models with Gaussian increments. @@ -66,17 +79,6 @@ class LinearTimeseries(Distribution): "n": constraints.positive_integer, } - @staticmethod - def _verify_parameters(offset, matrix, std, initial_value, std_is_matrix): - ndim = matrix.shape[-1] - - assert initial_value.ndim >= 1 - assert matrix.ndim >= 2 and matrix.shape[-2] == matrix.shape[-1] == ndim - assert initial_value.shape[-1] == matrix.shape[-1] - - if std_is_matrix: - assert std.ndim >= 2 and std.shape[-1] == std.shape[-2] == ndim - def __init__( self, n: int, @@ -89,7 +91,7 @@ def __init__( column_mask: np.ndarray = None, validate_args=None, ): - self._verify_parameters(offset, matrix, std, initial_value, std_is_matrix) + _verify_parameters(offset, matrix, std, initial_value, std_is_matrix) times = jnp.arange(n) self._std_is_matrix = std_is_matrix