diff --git a/tests/test_union.py b/tests/test_union.py index 27d773f..fce675f 100644 --- a/tests/test_union.py +++ b/tests/test_union.py @@ -15,5 +15,9 @@ def test_union(): key = jrnd.PRNGKey(123) samples = combined.sample(key) + log_prob = combined.log_prob(samples) + assert samples.shape == combined.event_shape assert combined.event_shape[-1] == 7 + +