Skip to content

Commit

Permalink
"Fix"
Browse files Browse the repository at this point in the history
  • Loading branch information
victor committed Jul 11, 2024
1 parent 5ba892c commit b26cffc
Showing 1 changed file with 16 additions and 8 deletions.
24 changes: 16 additions & 8 deletions numpyro_sts/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,18 +244,15 @@ def union(self, other: "LinearTimeseries") -> "LinearTimeseries":

return model

def sample_deterministic(self, sample_shape=()) -> jnp.ndarray:
def deterministic(self) -> "LinearTimeseries":
"""
Utility function for "sampling" the deterministic part of the series.
Args:
sample_shape: See :meth:`sample`.
Constructs a deterministic version of the timeseries.
Returns:
Returns samples.
A :class:`LinearTimeseries` with column mask set to False.
"""

copy = LinearTimeseries(
return LinearTimeseries(
self.n,
self.offset,
self.matrix,
Expand All @@ -264,4 +261,15 @@ def sample_deterministic(self, sample_shape=()) -> jnp.ndarray:
column_mask=np.zeros_like(self.column_mask),
)

return copy.sample(PRNGKey(0), sample_shape)
def sample_deterministic(self, sample_shape=()) -> jnp.ndarray:
"""
Utility function for "sampling" the deterministic part of the series.
Args:
sample_shape: See :meth:`sample`.
Returns:
Returns samples.
"""

return self.deterministic().sample(PRNGKey(0), sample_shape)

0 comments on commit b26cffc

Please sign in to comment.