diff --git a/numpyro_sts/base.py b/numpyro_sts/base.py index f00d9a1..14e0711 100644 --- a/numpyro_sts/base.py +++ b/numpyro_sts/base.py @@ -1,3 +1,4 @@ +import warnings from functools import cached_property, reduce from typing import Tuple @@ -90,7 +91,12 @@ def __init__( std_is_matrix: bool = False, column_mask: np.ndarray = None, validate_args=None, + **kwargs, ): + if "mask" in kwargs: + warnings.warn("'mask' is deprecated in favor of 'column_mask'", DeprecationWarning) + column_mask = kwargs.pop("mask") + _verify_parameters(offset, matrix, std, initial_value, std_is_matrix) times = jnp.arange(n)