Skip to content

Commit

Permalink
fixes to quasi periodic kernels
Browse files Browse the repository at this point in the history
  • Loading branch information
William Wilkinson authored and William Wilkinson committed Dec 15, 2023
1 parent fad0483 commit bd6a1ee
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 28 deletions.
117 changes: 90 additions & 27 deletions bayesnewton/kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,28 @@
from warnings import warn
from tensorflow_probability.substrates.jax.math import bessel_ive

FACTORIALS = np.array([1, 1, 2, 6, 24, 120, 720, 5040, 40320, 362880, 3628800])


def factorial(i):
return FACTORIALS[i]


def coeff(j, lengthscale, order):
"""
Can be used to generate co-efficients for the quasi-periodic kernels that guarantee a valid covariance function:
q2 = np.array([coeff(j, lengthscale, order) for j in range(order + 1)])
See eq (26) of [1].
Not currently used (we use Bessel functions instead for clarity).
[1] Solin & Sarkka (2014) "Explicit Link Between Periodic Covariance Functions and State Space Models".
"""
s = sum([
# (2 * lengthscale ** 2) ** -(j + 2 * i) / (factorial(j+i) * factorial(i)) for i in range(int((order-j) / 2))
(2 * lengthscale ** 2) ** -(j + 2 * i) / (factorial(j+i) * factorial(i)) for i in range(int((order+2-j) / 2))
])
return 1 / np.exp(lengthscale ** -2) * s


class Kernel(objax.Module):
"""
Expand Down Expand Up @@ -839,8 +861,9 @@ def state_transition(self, dt):
:param dt: step size(s), Δt = tₙ - tₙ₋₁ [1]
:return: state transition matrix A [2(N+1), 2(N+1)]
"""
A = expm(self.feedback_matrix()*dt)

omega = 2 * np.pi / self.period # The angular frequency
harmonics = np.arange(self.order + 1) * omega
A = block_diag(*vmap(rotation_matrix, [None, 0])(dt, harmonics))
return A

def feedback_matrix(self):
Expand Down Expand Up @@ -934,7 +957,12 @@ def state_transition(self, dt):
:return: state transition matrix A [M+1, D, D]
"""
# The angular frequency
A = expm(self.feedback_matrix() * dt)
omega = 2 * np.pi / self.period
harmonics = np.arange(self.order + 1) * omega
A = (
np.exp(-dt / self.lengthscale_matern)
* block_diag(*vmap(rotation_matrix, [None, 0])(dt, harmonics))
)
return A

def feedback_matrix(self):
Expand Down Expand Up @@ -990,35 +1018,34 @@ def kernel_to_state_space(self, R=None):
q2 = np.array([1, *[2]*self.order]) * bessel_ive([*range(self.order+1)], self.lengthscale_periodic**(-2))
# The angular frequency
omega = 2 * np.pi / self.period
harmonics = np.arange(self.order + 1) * omega
# The model
F_p = np.kron(np.diag(np.arange(self.order + 1)), np.array([[0., -omega], [omega, 0.]]))
L_p = np.eye(2 * (self.order + 1))
# Qc_p = np.zeros(2 * (self.N + 1))
Pinf_p = np.kron(np.diag(q2), np.eye(2))
Pinf_per = q2[:, None, None] * np.eye(2)
H_p = np.kron(np.ones([1, self.order + 1]), np.array([1., 0.]))
lam = 3.0 ** 0.5 / self.lengthscale_matern
F_m = np.array([[0.0, 1.0],
[-lam ** 2, -2 * lam]])
L_m = np.array([[0],
[1]])
Qc_m = np.array([[12.0 * 3.0 ** 0.5 / self.lengthscale_matern ** 3.0 * self.variance]])
H_m = np.array([[1.0, 0.0]])
Pinf_m = np.array([[self.variance, 0.0],
[0.0, 3.0 * self.variance / self.lengthscale_matern ** 2.0]])
# F = np.kron(F_p, np.eye(2)) + np.kron(np.eye(14), F_m)
F = np.kron(F_m, np.eye(2 * (self.order + 1))) + np.kron(np.eye(2), F_p)
Pinf_mat = np.array([[self.variance, 0.0],
[0.0, 3.0 * self.variance / self.lengthscale_matern ** 2.0]])
F = block_diag(
*vmap(self.feedback_matrix_subband_matern32, [None, 0])(self.lengthscale_matern, harmonics)
)
L = np.kron(L_m, L_p)
Qc = np.kron(Qc_m, Pinf_p)
# note: Qc is always kron(Qc_m, q2I), not kron(Qc_m, Pinf_per). See eq (32) of Solin & Sarkka 2014.
Qc = block_diag(*np.kron(Qc_m, q2[:, None, None] * np.eye(2)))
H = np.kron(H_m, H_p)
Pinf = np.kron(Pinf_m, Pinf_p)
Pinf = block_diag(*np.kron(Pinf_mat, Pinf_per))
return F, L, Qc, H, Pinf

def stationary_covariance(self):
q2 = np.array([1, *[2]*self.order]) * bessel_ive([*range(self.order+1)], self.lengthscale_periodic**(-2))
Pinf_m = np.array([[self.variance, 0.0],
[0.0, 3.0 * self.variance / self.lengthscale_matern ** 2.0]])
Pinf_p = np.kron(np.diag(q2), np.eye(2))
Pinf = np.kron(Pinf_m, Pinf_p)
Pinf_mat = np.array([[self.variance, 0.0],
[0.0, 3.0 * self.variance / self.lengthscale_matern ** 2.0]])
Pinf_per = q2[:, None, None] * np.eye(2)
Pinf = block_diag(*np.kron(Pinf_mat, Pinf_per))
return Pinf

def measurement_model(self):
Expand All @@ -1034,19 +1061,55 @@ def state_transition(self, dt):
:param dt: step size(s), Δt = tₙ - tₙ₋₁ [M+1, 1]
:return: state transition matrix A [M+1, D, D]
"""
A = expm(self.feedback_matrix()*dt)

# The angular frequency
omega = 2 * np.pi / self.period
harmonics = np.arange(self.order + 1) * omega
A = block_diag(
*vmap(self.state_transition_subband_matern32, [None, None, 0])(dt, self.lengthscale_matern, harmonics)
)
return A

def feedback_matrix(self):
# The angular frequency
# The angular fundamental frequency
omega = 2 * np.pi / self.period
# The model
F_p = np.kron(np.diag(np.arange(self.order + 1)), np.array([[0., -omega], [omega, 0.]]))
lam = 3.0 ** 0.5 / self.lengthscale_matern
F_m = np.array([[0.0, 1.0],
[-lam ** 2, -2 * lam]])
F = np.kron(F_m, np.eye(2 * (self.order + 1))) + np.kron(np.eye(2), F_p)
harmonics = np.arange(self.order + 1) * omega
F = block_diag(
*vmap(self.feedback_matrix_subband_matern32, [None, 0])(self.lengthscale_matern, harmonics)
)
return F

@staticmethod
def state_transition_subband_matern32(dt, ell, radial_freq):
# TODO: re-use code from SubbandMatern32 kernel
"""
Calculation of the closed form discrete-time state
transition matrix A = expm(FΔt) for the Subband Matern-3/2 prior
:param dt: step size(s), Δt = tₙ - tₙ₋₁ [1]
:param ell: lengthscale of the Matern component of the kernel
:param radial_freq: the radial (i.e. angular) frequency of the oscillator
:return: state transition matrix A [4, 4]
"""
lam = np.sqrt(3.0) / ell
R = rotation_matrix(dt, radial_freq)
A = np.exp(-dt * lam) * np.block([
[(1. + dt * lam) * R, dt * R],
[-dt * lam ** 2 * R, (1. - dt * lam) * R]
])
return A

@staticmethod
def feedback_matrix_subband_matern32(len, omega):
# TODO: re-use code from SubbandMatern32 kernel
lam = 3.0 ** 0.5 / len
F_mat = np.array([[0.0, 1.0],
[-lam ** 2, -2 * lam]])
F_cos = np.array([[0.0, -omega],
[omega, 0.0]])
# F = (0 -ω 1 0
# ω 0 0 1
# -λ² 0 -2λ -ω
# 0 -λ² ω -2λ)
F = np.kron(F_mat, np.eye(2)) + np.kron(np.eye(2), F_cos)
return F


Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@ matplotlib
scipy
scikit-learn
pandas
tensorflow_probability
tensorflow_probability==0.21

0 comments on commit bd6a1ee

Please sign in to comment.