# Table of 0! .. 10!. Retained so that any code indexing it directly keeps working;
# `factorial` below no longer depends on it.
FACTORIALS = np.array([1, 1, 2, 6, 24, 120, 720, 5040, 40320, 362880, 3628800])


def factorial(i):
    """
    Exact factorial of a non-negative integer.

    A fixed 11-entry lookup table is only valid for 0 <= i <= 10: larger arguments
    either raise an IndexError or (with clamping array indexing) silently return 10!,
    and negative arguments wrap around to the end of the table. Computing the value
    exactly removes that limit.

    :param i: non-negative integer
    :return: i! as a Python int
    :raises ValueError: if i is negative
    """
    import math  # local import: the module's import block is not visible from this block

    return math.factorial(i)


def coeff(j, lengthscale, order):
    """
    Can be used to generate coefficients for the quasi-periodic kernels that guarantee a valid covariance function:
        q2 = np.array([coeff(j, lengthscale, order) for j in range(order + 1)])
    See eq (26) of [1].
    Not currently used (we use Bessel functions instead for clarity).

    [1] Solin & Sarkka (2014) "Explicit Link Between Periodic Covariance Functions and State Space Models".

    :param j: index of the coefficient (non-negative integer)
    :param lengthscale: lengthscale of the periodic kernel
    :param order: truncation order of the series (integer)
    :return: exp(-lengthscale⁻²) * Σᵢ (2 lengthscale²)^-(j+2i) / ((j+i)! i!), summed over
             i = 0, ..., int((order + 2 - j) / 2) - 1; equals 0 when that range is empty
    :raises ValueError: if j is negative (the factorial of a negative number is undefined)
    """
    # number of terms kept in the truncated series
    n_terms = int((order + 2 - j) / 2)
    s = sum(
        # the factorial product is converted to float so that very large exact integers
        # never have to be cast to a fixed-width integer type by the array library
        (2 * lengthscale ** 2) ** -(j + 2 * i) / float(factorial(j + i) * factorial(i))
        for i in range(n_terms)
    )
    return 1 / np.exp(lengthscale ** -2) * s
R=None): q2 = np.array([1, *[2]*self.order]) * bessel_ive([*range(self.order+1)], self.lengthscale_periodic**(-2)) # The angular frequency omega = 2 * np.pi / self.period + harmonics = np.arange(self.order + 1) * omega # The model - F_p = np.kron(np.diag(np.arange(self.order + 1)), np.array([[0., -omega], [omega, 0.]])) L_p = np.eye(2 * (self.order + 1)) # Qc_p = np.zeros(2 * (self.N + 1)) - Pinf_p = np.kron(np.diag(q2), np.eye(2)) + Pinf_per = q2[:, None, None] * np.eye(2) H_p = np.kron(np.ones([1, self.order + 1]), np.array([1., 0.])) - lam = 3.0 ** 0.5 / self.lengthscale_matern - F_m = np.array([[0.0, 1.0], - [-lam ** 2, -2 * lam]]) L_m = np.array([[0], [1]]) Qc_m = np.array([[12.0 * 3.0 ** 0.5 / self.lengthscale_matern ** 3.0 * self.variance]]) H_m = np.array([[1.0, 0.0]]) - Pinf_m = np.array([[self.variance, 0.0], - [0.0, 3.0 * self.variance / self.lengthscale_matern ** 2.0]]) - # F = np.kron(F_p, np.eye(2)) + np.kron(np.eye(14), F_m) - F = np.kron(F_m, np.eye(2 * (self.order + 1))) + np.kron(np.eye(2), F_p) + Pinf_mat = np.array([[self.variance, 0.0], + [0.0, 3.0 * self.variance / self.lengthscale_matern ** 2.0]]) + F = block_diag( + *vmap(self.feedback_matrix_subband_matern32, [None, 0])(self.lengthscale_matern, harmonics) + ) L = np.kron(L_m, L_p) - Qc = np.kron(Qc_m, Pinf_p) + # note: Qc is always kron(Qc_m, q2I), not kron(Qc_m, Pinf_per). See eq (32) of Solin & Sarkka 2014. 
+ Qc = block_diag(*np.kron(Qc_m, q2[:, None, None] * np.eye(2))) H = np.kron(H_m, H_p) - Pinf = np.kron(Pinf_m, Pinf_p) + Pinf = block_diag(*np.kron(Pinf_mat, Pinf_per)) return F, L, Qc, H, Pinf def stationary_covariance(self): q2 = np.array([1, *[2]*self.order]) * bessel_ive([*range(self.order+1)], self.lengthscale_periodic**(-2)) - Pinf_m = np.array([[self.variance, 0.0], - [0.0, 3.0 * self.variance / self.lengthscale_matern ** 2.0]]) - Pinf_p = np.kron(np.diag(q2), np.eye(2)) - Pinf = np.kron(Pinf_m, Pinf_p) + Pinf_mat = np.array([[self.variance, 0.0], + [0.0, 3.0 * self.variance / self.lengthscale_matern ** 2.0]]) + Pinf_per = q2[:, None, None] * np.eye(2) + Pinf = block_diag(*np.kron(Pinf_mat, Pinf_per)) return Pinf def measurement_model(self): @@ -1034,19 +1061,55 @@ def state_transition(self, dt): :param dt: step size(s), Δt = tₙ - tₙ₋₁ [M+1, 1] :return: state transition matrix A [M+1, D, D] """ - A = expm(self.feedback_matrix()*dt) - + # The angular frequency + omega = 2 * np.pi / self.period + harmonics = np.arange(self.order + 1) * omega + A = block_diag( + *vmap(self.state_transition_subband_matern32, [None, None, 0])(dt, self.lengthscale_matern, harmonics) + ) return A def feedback_matrix(self): - # The angular frequency + # The angular fundamental frequency omega = 2 * np.pi / self.period - # The model - F_p = np.kron(np.diag(np.arange(self.order + 1)), np.array([[0., -omega], [omega, 0.]])) - lam = 3.0 ** 0.5 / self.lengthscale_matern - F_m = np.array([[0.0, 1.0], - [-lam ** 2, -2 * lam]]) - F = np.kron(F_m, np.eye(2 * (self.order + 1))) + np.kron(np.eye(2), F_p) + harmonics = np.arange(self.order + 1) * omega + F = block_diag( + *vmap(self.feedback_matrix_subband_matern32, [None, 0])(self.lengthscale_matern, harmonics) + ) + return F + + @staticmethod + def state_transition_subband_matern32(dt, ell, radial_freq): + # TODO: re-use code from SubbandMatern32 kernel + """ + Calculation of the closed form discrete-time state + transition matrix A = 
expm(FΔt) for the Subband Matern-3/2 prior + :param dt: step size(s), Δt = tₙ - tₙ₋₁ [1] + :param ell: lengthscale of the Matern component of the kernel + :param radial_freq: the radial (i.e. angular) frequency of the oscillator + :return: state transition matrix A [4, 4] + """ + lam = np.sqrt(3.0) / ell + R = rotation_matrix(dt, radial_freq) + A = np.exp(-dt * lam) * np.block([ + [(1. + dt * lam) * R, dt * R], + [-dt * lam ** 2 * R, (1. - dt * lam) * R] + ]) + return A + + @staticmethod + def feedback_matrix_subband_matern32(len, omega): + # TODO: re-use code from SubbandMatern32 kernel + lam = 3.0 ** 0.5 / len + F_mat = np.array([[0.0, 1.0], + [-lam ** 2, -2 * lam]]) + F_cos = np.array([[0.0, -omega], + [omega, 0.0]]) + # F = (0 -ω 1 0 + # ω 0 0 1 + # -λ² 0 -2λ -ω + # 0 -λ² ω -2λ) + F = np.kron(F_mat, np.eye(2)) + np.kron(np.eye(2), F_cos) return F diff --git a/requirements.txt b/requirements.txt index f3686f4..cea5c11 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,4 +7,4 @@ matplotlib scipy scikit-learn pandas -tensorflow_probability \ No newline at end of file +tensorflow_probability==0.21 \ No newline at end of file