diff --git a/.github/workflows/pipeline.yaml b/.github/workflows/pipeline.yaml index 8cc1018..e878e77 100644 --- a/.github/workflows/pipeline.yaml +++ b/.github/workflows/pipeline.yaml @@ -15,7 +15,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.9", "3.10", "3.11"] + python-version: ["3.9", "3.10", "3.11", "3.12"] steps: - uses: actions/checkout@v3 diff --git a/numpyro_sts/__init__.py b/numpyro_sts/__init__.py index d97ce27..60fa868 100644 --- a/numpyro_sts/__init__.py +++ b/numpyro_sts/__init__.py @@ -4,4 +4,4 @@ from .random_walk import RandomWalk from .smooth_llt import SmoothLocalLinearTrend -__version__ = "0.0.2" +__version__ = "0.0.3" diff --git a/numpyro_sts/ar.py b/numpyro_sts/ar.py index 0af5da1..8edd391 100644 --- a/numpyro_sts/ar.py +++ b/numpyro_sts/ar.py @@ -1,7 +1,8 @@ import jax.numpy as jnp import numpy as np +from jax.typing import ArrayLike -from .base import ArrayLike, LinearTimeseries +from .base import LinearTimeseries from .util import cast_to_tensor diff --git a/numpyro_sts/base.py b/numpyro_sts/base.py index 5a5fa84..ad9d562 100644 --- a/numpyro_sts/base.py +++ b/numpyro_sts/base.py @@ -1,17 +1,11 @@ -from numbers import Number -from typing import Union - import jax.numpy as jnp -import numpy as np from jax import vmap from jax.random import normal from numpyro.contrib.control_flow import scan from numpyro.distributions import Distribution, Normal, constraints, MultivariateNormal from numpyro.distributions.util import validate_sample from numpyro.util import is_prng_key - - -ArrayLike = Union[jnp.ndarray, Number, np.ndarray] +from jax.typing import ArrayLike def _broadcast_and_reshape(x: jnp.ndarray, shape, dim: int) -> jnp.ndarray: diff --git a/numpyro_sts/llt.py b/numpyro_sts/llt.py index 89408f3..0a6ce81 100644 --- a/numpyro_sts/llt.py +++ b/numpyro_sts/llt.py @@ -1,7 +1,8 @@ import jax.numpy as np from numpyro.distributions.util import promote_shapes +from jax.typing import ArrayLike -from .base import ArrayLike, LinearTimeseries +from .base import LinearTimeseries class LocalLinearTrend(LinearTimeseries): diff --git a/numpyro_sts/random_walk.py b/numpyro_sts/random_walk.py index e73259a..4541331 100644 --- a/numpyro_sts/random_walk.py +++ b/numpyro_sts/random_walk.py @@ -1,6 +1,7 @@ import jax.numpy as np +from jax.typing import ArrayLike -from .base import ArrayLike, LinearTimeseries +from .base import LinearTimeseries from .util import cast_to_tensor diff --git a/numpyro_sts/smooth_llt.py b/numpyro_sts/smooth_llt.py index 03adda4..506e3e2 100644 --- a/numpyro_sts/smooth_llt.py +++ b/numpyro_sts/smooth_llt.py @@ -1,8 +1,9 @@ import jax.numpy as jnp import numpy as np from numpyro.distributions.util import promote_shapes +from jax.typing import ArrayLike -from .base import ArrayLike, LinearTimeseries +from .base import LinearTimeseries class SmoothLocalLinearTrend(LinearTimeseries): diff --git a/pyproject.toml b/pyproject.toml index 921e31f..eccc700 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,7 +57,7 @@ profile = "black" include = ["numpyro_sts*"] [tool.bumpver] -current_version = "0.0.2" +current_version = "0.0.3" version_pattern = "MAJOR.MINOR.PATCH" commit_message = "bump version {old_version} -> {new_version}" commit = false @@ -75,3 +75,6 @@ push = false [tool.setuptools.dynamic] version = {attr = "numpyro_sts.__version__"} + +[tool.pytest.ini_options] +pythonpath = ["."] \ No newline at end of file