diff --git a/bayesnewton/kernels.py b/bayesnewton/kernels.py index 708ed0b..dab7e37 100644 --- a/bayesnewton/kernels.py +++ b/bayesnewton/kernels.py @@ -4,7 +4,7 @@ from jax.scipy.linalg import cho_factor, cho_solve, block_diag, expm from .utils import scaled_squared_euclid_dist, softplus, softplus_inv, rotation_matrix from warnings import warn -from tensorflow_probability.substrates import jax as tfp +from tensorflow_probability.substrates.jax.math import bessel_ive class Kernel(objax.Module): @@ -812,7 +812,7 @@ def period(self): return softplus(self.transformed_period.value) def kernel_to_state_space(self, R=None): - q2 = np.array([1, *[2]*self.order]) * self.variance * tfp.math.bessel_ive([*range(self.order+1)], self.lengthscale) + q2 = np.array([1, *[2]*self.order]) * self.variance * bessel_ive([*range(self.order+1)], self.lengthscale**(-2)) # The angular frequency omega = 2 * np.pi / self.period # The model @@ -824,7 +824,7 @@ def kernel_to_state_space(self, R=None): return F, L, Qc, H, Pinf def stationary_covariance(self): - q2 = np.array([1, *[2]*self.order]) * self.variance * tfp.math.bessel_ive([*range(self.order+1)], self.lengthscale) + q2 = np.array([1, *[2]*self.order]) * self.variance * bessel_ive([*range(self.order+1)], self.lengthscale**(-2)) Pinf = np.kron(np.diag(q2), np.eye(2)) return Pinf @@ -840,10 +840,6 @@ def state_transition(self, dt): :return: state transition matrix A [2(N+1), 2(N+1)] """ A = expm(self.feedback_matrix()*dt) - # omega = 2 * np.pi / self.period # The angular frequency - # harmonics = np.arange(self.order + 1) * omega - # - # A = block_diag(*[rotation_matrix(dt, val) for val in harmonics]) return A @@ -896,7 +892,7 @@ def K(self, X, X2): raise NotImplementedError def kernel_to_state_space(self, R=None): - q2 = np.array([1, *[2]*self.order]) * tfp.math.bessel_ive([*range(self.order+1)], self.lengthscale_periodic) + q2 = np.array([1, *[2]*self.order]) * bessel_ive([*range(self.order+1)], self.lengthscale_periodic**(-2)) # The angular frequency omega = 2 * np.pi / self.period # The model @@ -915,31 +911,13 @@ def kernel_to_state_space(self, R=None): Qc = np.kron(Qc_m, Pinf_p) H = np.kron(H_m, H_p) Pinf = np.kron(Pinf_m, Pinf_p) - # Pinf = block_diag( - # np.kron(Pinf_m, q2[0] * np.eye(2)), - # np.kron(Pinf_m, q2[1] * np.eye(2)), - # np.kron(Pinf_m, q2[2] * np.eye(2)), - # np.kron(Pinf_m, q2[3] * np.eye(2)), - # np.kron(Pinf_m, q2[4] * np.eye(2)), - # np.kron(Pinf_m, q2[5] * np.eye(2)), - # np.kron(Pinf_m, q2[6] * np.eye(2)), - # ) return F, L, Qc, H, Pinf def stationary_covariance(self): - q2 = np.array([1, *[2]*self.order]) * tfp.math.bessel_ive([*range(self.order+1)], self.lengthscale_periodic) + q2 = np.array([1, *[2]*self.order]) * bessel_ive([*range(self.order+1)], self.lengthscale_periodic**(-2)) Pinf_m = np.array([[self.variance]]) Pinf_p = np.kron(np.diag(q2), np.eye(2)) Pinf = np.kron(Pinf_m, Pinf_p) - # Pinf = block_diag( - # np.kron(Pinf_m, q2[0] * np.eye(2)), - # np.kron(Pinf_m, q2[1] * np.eye(2)), - # np.kron(Pinf_m, q2[2] * np.eye(2)), - # np.kron(Pinf_m, q2[3] * np.eye(2)), - # np.kron(Pinf_m, q2[4] * np.eye(2)), - # np.kron(Pinf_m, q2[5] * np.eye(2)), - # np.kron(Pinf_m, q2[6] * np.eye(2)), - # ) return Pinf def measurement_model(self): @@ -957,16 +935,6 @@ def state_transition(self, dt): """ # The angular frequency A = expm(self.feedback_matrix() * dt) - # omega = 2 * np.pi / self.period - # harmonics = np.arange(self.order + 1) * omega - # R0 = rotation_matrix(dt, harmonics[0]) - # R1 = rotation_matrix(dt, harmonics[1]) - # R2 = rotation_matrix(dt, harmonics[2]) - # R3 = rotation_matrix(dt, harmonics[3]) - # R4 = rotation_matrix(dt, harmonics[4]) - # R5 = rotation_matrix(dt, harmonics[5]) - # R6 = rotation_matrix(dt, harmonics[6]) - # A = np.exp(-dt / self.lengthscale_matern) * block_diag(R0, R1, R2, R3, R4, R5, R6) return A def feedback_matrix(self): @@ -1019,7 +987,7 @@ def K(self, X, X2): raise NotImplementedError def kernel_to_state_space(self, R=None): - q2 = np.array([1, *[2]*self.order]) * tfp.math.bessel_ive([*range(self.order+1)], self.lengthscale_periodic) + q2 = np.array([1, *[2]*self.order]) * bessel_ive([*range(self.order+1)], self.lengthscale_periodic**(-2)) # The angular frequency omega = 2 * np.pi / self.period # The model @@ -1043,32 +1011,14 @@ def kernel_to_state_space(self, R=None): Qc = np.kron(Qc_m, Pinf_p) H = np.kron(H_m, H_p) Pinf = np.kron(Pinf_m, Pinf_p) - # Pinf = block_diag( - # np.kron(Pinf_m, q2[0] * np.eye(2)), - # np.kron(Pinf_m, q2[1] * np.eye(2)), - # np.kron(Pinf_m, q2[2] * np.eye(2)), - # np.kron(Pinf_m, q2[3] * np.eye(2)), - # np.kron(Pinf_m, q2[4] * np.eye(2)), - # np.kron(Pinf_m, q2[5] * np.eye(2)), - # np.kron(Pinf_m, q2[6] * np.eye(2)), - # ) return F, L, Qc, H, Pinf def stationary_covariance(self): - q2 = np.array([1, *[2]*self.order]) * tfp.math.bessel_ive([*range(self.order+1)], self.lengthscale_periodic) + q2 = np.array([1, *[2]*self.order]) * bessel_ive([*range(self.order+1)], self.lengthscale_periodic**(-2)) Pinf_m = np.array([[self.variance, 0.0], [0.0, 3.0 * self.variance / self.lengthscale_matern ** 2.0]]) Pinf_p = np.kron(np.diag(q2), np.eye(2)) Pinf = np.kron(Pinf_m, Pinf_p) - # Pinf = block_diag( - # np.kron(Pinf_m, q2[0] * np.eye(2)), - # np.kron(Pinf_m, q2[1] * np.eye(2)), - # np.kron(Pinf_m, q2[2] * np.eye(2)), - # np.kron(Pinf_m, q2[3] * np.eye(2)), - # np.kron(Pinf_m, q2[4] * np.eye(2)), - # np.kron(Pinf_m, q2[5] * np.eye(2)), - # np.kron(Pinf_m, q2[6] * np.eye(2)), - # ) return Pinf def measurement_model(self):