From 4bbf0b4357af98539d2416304d9d9150d5d20cb1 Mon Sep 17 00:00:00 2001 From: Thore Wietzke Date: Wed, 10 May 2023 08:22:36 +0200 Subject: [PATCH] Added custom order for periodic kernel --- bayesnewton/kernels.py | 43 ++++-------------------------------------- 1 file changed, 4 insertions(+), 39 deletions(-) diff --git a/bayesnewton/kernels.py b/bayesnewton/kernels.py index 8a97509..a2ed53b 100644 --- a/bayesnewton/kernels.py +++ b/bayesnewton/kernels.py @@ -4,6 +4,7 @@ from jax.scipy.linalg import cho_factor, cho_solve, block_diag, expm from .utils import scaled_squared_euclid_dist, softplus, softplus_inv, rotation_matrix from warnings import warn +from tensorflow_probability.substrates import jax as tfp class Kernel(objax.Module): @@ -1471,7 +1472,6 @@ class Periodic(Kernel): period, p The associated continuous-time state space model matrices are constructed via a sum of cosines. - TODO: allow for orders other than 6 """ def __init__(self, variance=1.0, lengthscale=1.0, period=1.0, order=6, fix_variance=False): self.transformed_lengthscale = objax.TrainVar(np.array(softplus_inv(lengthscale))) @@ -1483,22 +1483,6 @@ def __init__(self, variance=1.0, lengthscale=1.0, period=1.0, order=6, fix_varia super().__init__() self.name = 'Periodic' self.order = order - self.M = np.meshgrid(np.arange(self.order + 1), np.arange(self.order + 1))[1] - factorial_mesh_M = np.array([[1., 1., 1., 1., 1., 1., 1.], - [1., 1., 1., 1., 1., 1., 1.], - [2., 2., 2., 2., 2., 2., 2.], - [6., 6., 6., 6., 6., 6., 6.], - [24., 24., 24., 24., 24., 24., 24.], - [120., 120., 120., 120., 120., 120., 120.], - [720., 720., 720., 720., 720., 720., 720.]]) - b = np.array([[1., 0., 0., 0., 0., 0., 0.], - [0., 2., 0., 0., 0., 0., 0.], - [2., 0., 2., 0., 0., 0., 0.], - [0., 6., 0., 2., 0., 0., 0.], - [6., 0., 8., 0., 2., 0., 0.], - [0., 20., 0., 10., 0., 2., 0.], - [20., 0., 30., 0., 12., 0., 2.]]) - self.b_fmK_2M = b * (1. / factorial_mesh_M) * (2. ** -self.M) @property def variance(self): @@ -1513,8 +1497,7 @@ def period(self): return softplus(self.transformed_period.value) def kernel_to_state_space(self, R=None): - a = self.b_fmK_2M * self.lengthscale ** (-2. * self.M) * np.exp(-1. / self.lengthscale ** 2.) * self.variance - q2 = np.sum(a, axis=0) + q2 = np.array([1, *[2]*self.order]) * self.variance * tfp.math.bessel_ive([*range(self.order+1)], self.lengthscale) # The angular frequency omega = 2 * np.pi / self.period # The model @@ -1526,8 +1509,7 @@ def kernel_to_state_space(self, R=None): return F, L, Qc, H, Pinf def stationary_covariance(self): - a = self.b_fmK_2M * self.lengthscale ** (-2. * self.M) * np.exp(-1. / self.lengthscale ** 2.) * self.variance - q2 = np.sum(a, axis=0) + q2 = np.array([1, *[2]*self.order]) * self.variance * tfp.math.bessel_ive([*range(self.order+1)], self.lengthscale) Pinf = np.kron(np.diag(q2), np.eye(2)) return Pinf @@ -1542,24 +1524,7 @@ def state_transition(self, dt): :param dt: step size(s), Δt = tₙ - tₙ₋₁ [1] :return: state transition matrix A [2(N+1), 2(N+1)] """ - omega = 2 * np.pi / self.period # The angular frequency - harmonics = np.arange(self.order + 1) * omega - R0 = rotation_matrix(dt, harmonics[0]) - R1 = rotation_matrix(dt, harmonics[1]) - R2 = rotation_matrix(dt, harmonics[2]) - R3 = rotation_matrix(dt, harmonics[3]) - R4 = rotation_matrix(dt, harmonics[4]) - R5 = rotation_matrix(dt, harmonics[5]) - R6 = rotation_matrix(dt, harmonics[6]) - A = np.block([ - [R0, np.zeros([2, 12])], - [np.zeros([2, 2]), R1, np.zeros([2, 10])], - [np.zeros([2, 4]), R2, np.zeros([2, 8])], - [np.zeros([2, 6]), R3, np.zeros([2, 6])], - [np.zeros([2, 8]), R4, np.zeros([2, 4])], - [np.zeros([2, 10]), R5, np.zeros([2, 2])], - [np.zeros([2, 12]), R6] - ]) + A = expm(self.feedback_matrix()*dt) return A def feedback_matrix(self):