Skip to content

Commit

Permalink
Added custom order for periodic kernel
Browse files Browse the repository at this point in the history
  • Loading branch information
ThoreWietzke committed May 10, 2023
1 parent 29fec22 commit 4bbf0b4
Showing 1 changed file with 4 additions and 39 deletions.
43 changes: 4 additions & 39 deletions bayesnewton/kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from jax.scipy.linalg import cho_factor, cho_solve, block_diag, expm
from .utils import scaled_squared_euclid_dist, softplus, softplus_inv, rotation_matrix
from warnings import warn
from tensorflow_probability.substrates import jax as tfp


class Kernel(objax.Module):
Expand Down Expand Up @@ -1471,7 +1472,6 @@ class Periodic(Kernel):
period, p
The associated continuous-time state space model matrices are constructed via
a sum of cosines.
TODO: allow for orders other than 6
"""
def __init__(self, variance=1.0, lengthscale=1.0, period=1.0, order=6, fix_variance=False):
self.transformed_lengthscale = objax.TrainVar(np.array(softplus_inv(lengthscale)))
Expand All @@ -1483,22 +1483,6 @@ def __init__(self, variance=1.0, lengthscale=1.0, period=1.0, order=6, fix_varia
super().__init__()
self.name = 'Periodic'
self.order = order
self.M = np.meshgrid(np.arange(self.order + 1), np.arange(self.order + 1))[1]
factorial_mesh_M = np.array([[1., 1., 1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1., 1., 1.],
[2., 2., 2., 2., 2., 2., 2.],
[6., 6., 6., 6., 6., 6., 6.],
[24., 24., 24., 24., 24., 24., 24.],
[120., 120., 120., 120., 120., 120., 120.],
[720., 720., 720., 720., 720., 720., 720.]])
b = np.array([[1., 0., 0., 0., 0., 0., 0.],
[0., 2., 0., 0., 0., 0., 0.],
[2., 0., 2., 0., 0., 0., 0.],
[0., 6., 0., 2., 0., 0., 0.],
[6., 0., 8., 0., 2., 0., 0.],
[0., 20., 0., 10., 0., 2., 0.],
[20., 0., 30., 0., 12., 0., 2.]])
self.b_fmK_2M = b * (1. / factorial_mesh_M) * (2. ** -self.M)

@property
def variance(self):
Expand All @@ -1513,8 +1497,7 @@ def period(self):
return softplus(self.transformed_period.value)

def kernel_to_state_space(self, R=None):
a = self.b_fmK_2M * self.lengthscale ** (-2. * self.M) * np.exp(-1. / self.lengthscale ** 2.) * self.variance
q2 = np.sum(a, axis=0)
q2 = np.array([1, *[2]*self.order]) * self.variance * tfp.math.bessel_ive([*range(self.order+1)], self.lengthscale)
# The angular frequency
omega = 2 * np.pi / self.period
# The model
Expand All @@ -1526,8 +1509,7 @@ def kernel_to_state_space(self, R=None):
return F, L, Qc, H, Pinf

def stationary_covariance(self):
a = self.b_fmK_2M * self.lengthscale ** (-2. * self.M) * np.exp(-1. / self.lengthscale ** 2.) * self.variance
q2 = np.sum(a, axis=0)
q2 = np.array([1, *[2]*self.order]) * self.variance * tfp.math.bessel_ive([*range(self.order+1)], self.lengthscale)
Pinf = np.kron(np.diag(q2), np.eye(2))
return Pinf

Expand All @@ -1542,24 +1524,7 @@ def state_transition(self, dt):
:param dt: step size(s), Δt = tₙ - tₙ₋₁ [1]
:return: state transition matrix A [2(N+1), 2(N+1)]
"""
omega = 2 * np.pi / self.period # The angular frequency
harmonics = np.arange(self.order + 1) * omega
R0 = rotation_matrix(dt, harmonics[0])
R1 = rotation_matrix(dt, harmonics[1])
R2 = rotation_matrix(dt, harmonics[2])
R3 = rotation_matrix(dt, harmonics[3])
R4 = rotation_matrix(dt, harmonics[4])
R5 = rotation_matrix(dt, harmonics[5])
R6 = rotation_matrix(dt, harmonics[6])
A = np.block([
[R0, np.zeros([2, 12])],
[np.zeros([2, 2]), R1, np.zeros([2, 10])],
[np.zeros([2, 4]), R2, np.zeros([2, 8])],
[np.zeros([2, 6]), R3, np.zeros([2, 6])],
[np.zeros([2, 8]), R4, np.zeros([2, 4])],
[np.zeros([2, 10]), R5, np.zeros([2, 2])],
[np.zeros([2, 12]), R6]
])
A = expm(self.feedback_matrix()*dt)
return A

def feedback_matrix(self):
Expand Down

0 comments on commit 4bbf0b4

Please sign in to comment.