Skip to content

Commit

Permalink
Merge pull request #29 from ThoreWietzke/fix_periodic_kernel
Browse files Browse the repository at this point in the history
Fixed calculation of the periodic kernel and dependencies
  • Loading branch information
wil-j-wil authored Dec 14, 2023
2 parents f75f9c9 + ef00ff7 commit e9d6970
Showing 1 changed file with 7 additions and 57 deletions.
64 changes: 7 additions & 57 deletions bayesnewton/kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from jax.scipy.linalg import cho_factor, cho_solve, block_diag, expm
from .utils import scaled_squared_euclid_dist, softplus, softplus_inv, rotation_matrix
from warnings import warn
from tensorflow_probability.substrates import jax as tfp
from tensorflow_probability.substrates.jax.math import bessel_ive


class Kernel(objax.Module):
Expand Down Expand Up @@ -812,7 +812,7 @@ def period(self):
return softplus(self.transformed_period.value)

def kernel_to_state_space(self, R=None):
q2 = np.array([1, *[2]*self.order]) * self.variance * tfp.math.bessel_ive([*range(self.order+1)], self.lengthscale)
q2 = np.array([1, *[2]*self.order]) * self.variance * bessel_ive([*range(self.order+1)], self.lengthscale**(-2))
# The angular frequency
omega = 2 * np.pi / self.period
# The model
Expand All @@ -824,7 +824,7 @@ def kernel_to_state_space(self, R=None):
return F, L, Qc, H, Pinf

def stationary_covariance(self):
q2 = np.array([1, *[2]*self.order]) * self.variance * tfp.math.bessel_ive([*range(self.order+1)], self.lengthscale)
q2 = np.array([1, *[2]*self.order]) * self.variance * bessel_ive([*range(self.order+1)], self.lengthscale**(-2))
Pinf = np.kron(np.diag(q2), np.eye(2))
return Pinf

Expand All @@ -840,10 +840,6 @@ def state_transition(self, dt):
:return: state transition matrix A [2(N+1), 2(N+1)]
"""
A = expm(self.feedback_matrix()*dt)
# omega = 2 * np.pi / self.period # The angular frequency
# harmonics = np.arange(self.order + 1) * omega
#
# A = block_diag(*[rotation_matrix(dt, val) for val in harmonics])

return A

Expand Down Expand Up @@ -896,7 +892,7 @@ def K(self, X, X2):
raise NotImplementedError

def kernel_to_state_space(self, R=None):
q2 = np.array([1, *[2]*self.order]) * tfp.math.bessel_ive([*range(self.order+1)], self.lengthscale_periodic)
q2 = np.array([1, *[2]*self.order]) * bessel_ive([*range(self.order+1)], self.lengthscale_periodic**(-2))
# The angular frequency
omega = 2 * np.pi / self.period
# The model
Expand All @@ -915,31 +911,13 @@ def kernel_to_state_space(self, R=None):
Qc = np.kron(Qc_m, Pinf_p)
H = np.kron(H_m, H_p)
Pinf = np.kron(Pinf_m, Pinf_p)
# Pinf = block_diag(
# np.kron(Pinf_m, q2[0] * np.eye(2)),
# np.kron(Pinf_m, q2[1] * np.eye(2)),
# np.kron(Pinf_m, q2[2] * np.eye(2)),
# np.kron(Pinf_m, q2[3] * np.eye(2)),
# np.kron(Pinf_m, q2[4] * np.eye(2)),
# np.kron(Pinf_m, q2[5] * np.eye(2)),
# np.kron(Pinf_m, q2[6] * np.eye(2)),
# )
return F, L, Qc, H, Pinf

def stationary_covariance(self):
q2 = np.array([1, *[2]*self.order]) * tfp.math.bessel_ive([*range(self.order+1)], self.lengthscale_periodic)
q2 = np.array([1, *[2]*self.order]) * bessel_ive([*range(self.order+1)], self.lengthscale_periodic**(-2))
Pinf_m = np.array([[self.variance]])
Pinf_p = np.kron(np.diag(q2), np.eye(2))
Pinf = np.kron(Pinf_m, Pinf_p)
# Pinf = block_diag(
# np.kron(Pinf_m, q2[0] * np.eye(2)),
# np.kron(Pinf_m, q2[1] * np.eye(2)),
# np.kron(Pinf_m, q2[2] * np.eye(2)),
# np.kron(Pinf_m, q2[3] * np.eye(2)),
# np.kron(Pinf_m, q2[4] * np.eye(2)),
# np.kron(Pinf_m, q2[5] * np.eye(2)),
# np.kron(Pinf_m, q2[6] * np.eye(2)),
# )
return Pinf

def measurement_model(self):
Expand All @@ -957,16 +935,6 @@ def state_transition(self, dt):
"""
# The angular frequency
A = expm(self.feedback_matrix() * dt)
# omega = 2 * np.pi / self.period
# harmonics = np.arange(self.order + 1) * omega
# R0 = rotation_matrix(dt, harmonics[0])
# R1 = rotation_matrix(dt, harmonics[1])
# R2 = rotation_matrix(dt, harmonics[2])
# R3 = rotation_matrix(dt, harmonics[3])
# R4 = rotation_matrix(dt, harmonics[4])
# R5 = rotation_matrix(dt, harmonics[5])
# R6 = rotation_matrix(dt, harmonics[6])
# A = np.exp(-dt / self.lengthscale_matern) * block_diag(R0, R1, R2, R3, R4, R5, R6)
return A

def feedback_matrix(self):
Expand Down Expand Up @@ -1019,7 +987,7 @@ def K(self, X, X2):
raise NotImplementedError

def kernel_to_state_space(self, R=None):
q2 = np.array([1, *[2]*self.order]) * tfp.math.bessel_ive([*range(self.order+1)], self.lengthscale_periodic)
q2 = np.array([1, *[2]*self.order]) * bessel_ive([*range(self.order+1)], self.lengthscale_periodic**(-2))
# The angular frequency
omega = 2 * np.pi / self.period
# The model
Expand All @@ -1043,32 +1011,14 @@ def kernel_to_state_space(self, R=None):
Qc = np.kron(Qc_m, Pinf_p)
H = np.kron(H_m, H_p)
Pinf = np.kron(Pinf_m, Pinf_p)
# Pinf = block_diag(
# np.kron(Pinf_m, q2[0] * np.eye(2)),
# np.kron(Pinf_m, q2[1] * np.eye(2)),
# np.kron(Pinf_m, q2[2] * np.eye(2)),
# np.kron(Pinf_m, q2[3] * np.eye(2)),
# np.kron(Pinf_m, q2[4] * np.eye(2)),
# np.kron(Pinf_m, q2[5] * np.eye(2)),
# np.kron(Pinf_m, q2[6] * np.eye(2)),
# )
return F, L, Qc, H, Pinf

def stationary_covariance(self):
q2 = np.array([1, *[2]*self.order]) * tfp.math.bessel_ive([*range(self.order+1)], self.lengthscale_periodic)
q2 = np.array([1, *[2]*self.order]) * bessel_ive([*range(self.order+1)], self.lengthscale_periodic**(-2))
Pinf_m = np.array([[self.variance, 0.0],
[0.0, 3.0 * self.variance / self.lengthscale_matern ** 2.0]])
Pinf_p = np.kron(np.diag(q2), np.eye(2))
Pinf = np.kron(Pinf_m, Pinf_p)
# Pinf = block_diag(
# np.kron(Pinf_m, q2[0] * np.eye(2)),
# np.kron(Pinf_m, q2[1] * np.eye(2)),
# np.kron(Pinf_m, q2[2] * np.eye(2)),
# np.kron(Pinf_m, q2[3] * np.eye(2)),
# np.kron(Pinf_m, q2[4] * np.eye(2)),
# np.kron(Pinf_m, q2[5] * np.eye(2)),
# np.kron(Pinf_m, q2[6] * np.eye(2)),
# )
return Pinf

def measurement_model(self):
Expand Down

0 comments on commit e9d6970

Please sign in to comment.