From b26cffc5efbae0c3b3a17916f5760b819dce4180 Mon Sep 17 00:00:00 2001 From: victor Date: Thu, 11 Jul 2024 08:56:24 +0200 Subject: [PATCH] "Fix" --- numpyro_sts/base.py | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/numpyro_sts/base.py b/numpyro_sts/base.py index b8bfa34..e22f5a0 100644 --- a/numpyro_sts/base.py +++ b/numpyro_sts/base.py @@ -244,18 +244,15 @@ def union(self, other: "LinearTimeseries") -> "LinearTimeseries": return model - def sample_deterministic(self, sample_shape=()) -> jnp.ndarray: + def deterministic(self) -> "LinearTimeseries": """ - Utility function for "sampling" the deterministic part of the series. - - Args: - sample_shape: See :meth:`sample`. + Constructs a deterministic version of the timeseries. Returns: - Returns samples. + A :class:`LinearTimeseries` with column mask set to False. """ - copy = LinearTimeseries( + return LinearTimeseries( self.n, self.offset, self.matrix, @@ -264,4 +261,15 @@ def sample_deterministic(self, sample_shape=()) -> jnp.ndarray: column_mask=np.zeros_like(self.column_mask), ) - return copy.sample(PRNGKey(0), sample_shape) + def sample_deterministic(self, sample_shape=()) -> jnp.ndarray: + """ + Utility function for "sampling" the deterministic part of the series. + + Args: + sample_shape: See :meth:`sample`. + + Returns: + Returns samples. + """ + + return self.deterministic().sample(PRNGKey(0), sample_shape)