From 84f8ff83eacf1313b300f5b7a88c33328178dc3c Mon Sep 17 00:00:00 2001 From: Victor Date: Sun, 23 Jun 2024 09:44:59 +0200 Subject: [PATCH] Uses ArrayLike from jax.typing and sets pythonpath (#3) * Uses ArrayLike from jax.typing and sets pythonpath * Version bump * Also test 3.12 --------- Co-authored-by: victor --- .github/workflows/pipeline.yaml | 2 +- numpyro_sts/__init__.py | 2 +- numpyro_sts/ar.py | 3 ++- numpyro_sts/base.py | 8 +------- numpyro_sts/llt.py | 3 ++- numpyro_sts/random_walk.py | 3 ++- numpyro_sts/smooth_llt.py | 3 ++- pyproject.toml | 5 ++++- 8 files changed, 15 insertions(+), 14 deletions(-) diff --git a/.github/workflows/pipeline.yaml b/.github/workflows/pipeline.yaml index 8cc1018..e878e77 100644 --- a/.github/workflows/pipeline.yaml +++ b/.github/workflows/pipeline.yaml @@ -15,7 +15,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.9", "3.10", "3.11"] + python-version: ["3.9", "3.10", "3.11", "3.12"] steps: - uses: actions/checkout@v3 diff --git a/numpyro_sts/__init__.py b/numpyro_sts/__init__.py index d97ce27..60fa868 100644 --- a/numpyro_sts/__init__.py +++ b/numpyro_sts/__init__.py @@ -4,4 +4,4 @@ from .random_walk import RandomWalk from .smooth_llt import SmoothLocalLinearTrend -__version__ = "0.0.2" +__version__ = "0.0.3" diff --git a/numpyro_sts/ar.py b/numpyro_sts/ar.py index 0af5da1..8edd391 100644 --- a/numpyro_sts/ar.py +++ b/numpyro_sts/ar.py @@ -1,7 +1,8 @@ import jax.numpy as jnp import numpy as np +from jax.typing import ArrayLike -from .base import ArrayLike, LinearTimeseries +from .base import LinearTimeseries from .util import cast_to_tensor diff --git a/numpyro_sts/base.py b/numpyro_sts/base.py index 5a5fa84..ad9d562 100644 --- a/numpyro_sts/base.py +++ b/numpyro_sts/base.py @@ -1,17 +1,11 @@ -from numbers import Number -from typing import Union - import jax.numpy as jnp -import numpy as np from jax import vmap from jax.random import normal from numpyro.contrib.control_flow import scan from numpyro.distributions import Distribution, Normal, constraints, MultivariateNormal from numpyro.distributions.util import validate_sample from numpyro.util import is_prng_key - - -ArrayLike = Union[jnp.ndarray, Number, np.ndarray] +from jax.typing import ArrayLike def _broadcast_and_reshape(x: jnp.ndarray, shape, dim: int) -> jnp.ndarray: diff --git a/numpyro_sts/llt.py b/numpyro_sts/llt.py index 89408f3..0a6ce81 100644 --- a/numpyro_sts/llt.py +++ b/numpyro_sts/llt.py @@ -1,7 +1,8 @@ import jax.numpy as np from numpyro.distributions.util import promote_shapes +from jax.typing import ArrayLike -from .base import ArrayLike, LinearTimeseries +from .base import LinearTimeseries class LocalLinearTrend(LinearTimeseries): diff --git a/numpyro_sts/random_walk.py b/numpyro_sts/random_walk.py index e73259a..4541331 100644 --- a/numpyro_sts/random_walk.py +++ b/numpyro_sts/random_walk.py @@ -1,6 +1,7 @@ import jax.numpy as np +from jax.typing import ArrayLike -from .base import ArrayLike, LinearTimeseries +from .base import LinearTimeseries from .util import cast_to_tensor diff --git a/numpyro_sts/smooth_llt.py b/numpyro_sts/smooth_llt.py index 03adda4..506e3e2 100644 --- a/numpyro_sts/smooth_llt.py +++ b/numpyro_sts/smooth_llt.py @@ -1,8 +1,9 @@ import jax.numpy as jnp import numpy as np from numpyro.distributions.util import promote_shapes +from jax.typing import ArrayLike -from .base import ArrayLike, LinearTimeseries +from .base import LinearTimeseries class SmoothLocalLinearTrend(LinearTimeseries): diff --git a/pyproject.toml b/pyproject.toml index 921e31f..eccc700 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,7 +57,7 @@ profile = "black" include = ["numpyro_sts*"] [tool.bumpver] -current_version = "0.0.2" +current_version = "0.0.3" version_pattern = "MAJOR.MINOR.PATCH" commit_message = "bump version {old_version} -> {new_version}" commit = false @@ -75,3 +75,6 @@ push = false [tool.setuptools.dynamic] version = {attr = "numpyro_sts.__version__"} + +[tool.pytest.ini_options] +pythonpath = ["."] \ No newline at end of file