Skip to content

Commit

Permalink
Uses ArrayLike from jax.typing and sets pythonpath (#3)
Browse files Browse the repository at this point in the history
* Uses ArrayLike from jax.typing and sets pythonpath

* Version bump

* Also test 3.12

---------

Co-authored-by: victor <baj>
  • Loading branch information
tingiskhan authored Jun 23, 2024
1 parent 74f1658 commit 84f8ff8
Show file tree
Hide file tree
Showing 8 changed files with 15 additions and 14 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pipeline.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.9", "3.10", "3.11"]
python-version: ["3.9", "3.10", "3.11", "3.12"]

steps:
- uses: actions/checkout@v3
Expand Down
2 changes: 1 addition & 1 deletion numpyro_sts/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@
from .random_walk import RandomWalk
from .smooth_llt import SmoothLocalLinearTrend

__version__ = "0.0.2"
__version__ = "0.0.3"
3 changes: 2 additions & 1 deletion numpyro_sts/ar.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import jax.numpy as jnp
import numpy as np
from jax.typing import ArrayLike

from .base import ArrayLike, LinearTimeseries
from .base import LinearTimeseries
from .util import cast_to_tensor


Expand Down
8 changes: 1 addition & 7 deletions numpyro_sts/base.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,11 @@
from numbers import Number
from typing import Union

import jax.numpy as jnp
import numpy as np
from jax import vmap
from jax.random import normal
from numpyro.contrib.control_flow import scan
from numpyro.distributions import Distribution, Normal, constraints, MultivariateNormal
from numpyro.distributions.util import validate_sample
from numpyro.util import is_prng_key


ArrayLike = Union[jnp.ndarray, Number, np.ndarray]
from jax.typing import ArrayLike


def _broadcast_and_reshape(x: jnp.ndarray, shape, dim: int) -> jnp.ndarray:
Expand Down
3 changes: 2 additions & 1 deletion numpyro_sts/llt.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import jax.numpy as np
from numpyro.distributions.util import promote_shapes
from jax.typing import ArrayLike

from .base import ArrayLike, LinearTimeseries
from .base import LinearTimeseries


class LocalLinearTrend(LinearTimeseries):
Expand Down
3 changes: 2 additions & 1 deletion numpyro_sts/random_walk.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import jax.numpy as np
from jax.typing import ArrayLike

from .base import ArrayLike, LinearTimeseries
from .base import LinearTimeseries
from .util import cast_to_tensor


Expand Down
3 changes: 2 additions & 1 deletion numpyro_sts/smooth_llt.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import jax.numpy as jnp
import numpy as np
from numpyro.distributions.util import promote_shapes
from jax.typing import ArrayLike

from .base import ArrayLike, LinearTimeseries
from .base import LinearTimeseries


class SmoothLocalLinearTrend(LinearTimeseries):
Expand Down
5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ profile = "black"
include = ["numpyro_sts*"]

[tool.bumpver]
current_version = "0.0.2"
current_version = "0.0.3"
version_pattern = "MAJOR.MINOR.PATCH"
commit_message = "bump version {old_version} -> {new_version}"
commit = false
Expand All @@ -75,3 +75,6 @@ push = false

[tool.setuptools.dynamic]
version = {attr = "numpyro_sts.__version__"}

[tool.pytest.ini_options]
pythonpath = ["."]

0 comments on commit 84f8ff8

Please sign in to comment.