Skip to content

Commit

Permalink
"Fix" and version bump (#2)
Browse files Browse the repository at this point in the history
* "Fix" and version bump

* Squeeze

* Order

---------

Co-authored-by: Victor <[email protected]>
  • Loading branch information
tingiskhan and tingiskhan authored May 31, 2024
1 parent ed148e8 commit 74f1658
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 4 deletions.
2 changes: 1 addition & 1 deletion numpyro_sts/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@
from .random_walk import RandomWalk
from .smooth_llt import SmoothLocalLinearTrend

__version__ = "0.0.1"
__version__ = "0.0.2"
5 changes: 3 additions & 2 deletions numpyro_sts/ar.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,15 @@ def __init__(
std = jnp.reshape(std, batch_shape + (1,))
mu = jnp.reshape(mu, batch_shape + (1,))

offset = mu * (1.0 - phi.sum(axis=-1))

if order > 1:
offset = mu * (1.0 - phi.sum(axis=-1))
phi = jnp.concatenate([phi, jnp.eye(order - 1, order)], axis=-2)

zeros = jnp.zeros(batch_shape + (order - 1,))
offset = jnp.concatenate([offset, zeros], axis=-1)
std = jnp.concatenate([std, zeros], axis=-1)
else:
offset = mu * (1.0 - phi.squeeze(-1))

init = jnp.reshape(initial_value if initial_value is not None else jnp.zeros(order), batch_shape + (order,))

Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ dev = [
"isort",
"ruff",
"coverage",
"bumpver",
]

[project.urls]
Expand All @@ -56,7 +57,7 @@ profile = "black"
include = ["numpyro_sts*"]

[tool.bumpver]
current_version = "0.0.1"
current_version = "0.0.2"
version_pattern = "MAJOR.MINOR.PATCH"
commit_message = "bump version {old_version} -> {new_version}"
commit = false
Expand Down

0 comments on commit 74f1658

Please sign in to comment.