diff --git a/numpyro_sts/base.py b/numpyro_sts/base.py index b8bfa34..e22f5a0 100644 --- a/numpyro_sts/base.py +++ b/numpyro_sts/base.py @@ -244,18 +244,15 @@ def union(self, other: "LinearTimeseries") -> "LinearTimeseries": return model - def sample_deterministic(self, sample_shape=()) -> jnp.ndarray: + def deterministic(self) -> "LinearTimeseries": """ - Utility function for "sampling" the deterministic part of the series. - - Args: - sample_shape: See :meth:`sample`. + Constructs a deterministic version of the timeseries. Returns: - Returns samples. + A :class:`LinearTimeseries` with column mask set to False. """ - copy = LinearTimeseries( + return LinearTimeseries( self.n, self.offset, self.matrix, @@ -264,4 +261,15 @@ def sample_deterministic(self, sample_shape=()) -> jnp.ndarray: column_mask=np.zeros_like(self.column_mask), ) - return copy.sample(PRNGKey(0), sample_shape) + def sample_deterministic(self, sample_shape=()) -> jnp.ndarray: + """ + Utility function for "sampling" the deterministic part of the series. + + Args: + sample_shape: See :meth:`sample`. + + Returns: + Returns samples. + """ + + return self.deterministic().sample(PRNGKey(0), sample_shape)