diff --git a/tests/test_sts.py b/tests/test_sts.py index 5f7bf89..0edd439 100644 --- a/tests/test_sts.py +++ b/tests/test_sts.py @@ -74,10 +74,12 @@ def numpyro_model(n, y_): mcmc.run(key, y.shape[0], y_=y) samples = mcmc.get_samples() - quantiles = np.quantile(samples["std"], [0.001, 0.999]) + low, high = np.quantile(samples["std"], [0.001, 0.999]) - assert (quantiles[0] <= true_model.std <= quantiles[1]).all() + assert (low <= true_model.std <= high).all() + # NB: this should be fetched from the predictive distribution rather... + assert samples["std"].std() <= 1.0 @pt.mark.parametrize("shape", [(), (10,)])