Merged changes from Aalto-main
ThoreWietzke committed Dec 22, 2023
2 parents 59f52fc + 0281955 commit 6cd6fa1
Showing 17 changed files with 157 additions and 54 deletions.
6 changes: 3 additions & 3 deletions bayesnewton/basemodels.py
@@ -803,7 +803,7 @@ def predict(self, X=None, R=None, pseudo_lik_params=None):
H = self.kernel.measurement_model()
if self.spatio_temporal:
# TODO: if R is fixed, only compute B, C once
B, C = self.kernel.spatial_conditional(X, R)
B, C = self.kernel.spatial_conditional(X, R, predict=True)
W = B @ H
test_mean = W @ state_mean
test_var = W @ state_cov @ transpose(W) + C
@@ -1060,7 +1060,7 @@ def predict(self, X, R=None):
H = self.kernel.measurement_model()
if self.spatio_temporal:
# TODO: if R is fixed, only compute B, C once
B, C = self.kernel.spatial_conditional(X, R)
B, C = self.kernel.spatial_conditional(X, R, predict=True)
W = B @ H
test_mean = W @ state_mean
test_var = W @ state_cov @ transpose(W) + C
@@ -1216,7 +1216,7 @@ def predict(self, X=None, R=None, pseudo_lik_params=None):
H = self.kernel.measurement_model()
if self.spatio_temporal:
# TODO: if R is fixed, only compute B, C once
B, C = self.kernel.spatial_conditional(X, R)
B, C = self.kernel.spatial_conditional(X, R, predict=True)
W = B @ H
test_mean = W @ state_mean
test_var = W @ state_cov @ transpose(W) + C
123 changes: 94 additions & 29 deletions bayesnewton/kernels.py
@@ -6,6 +6,28 @@
from warnings import warn
from tensorflow_probability.substrates.jax.math import bessel_ive

FACTORIALS = np.array([1, 1, 2, 6, 24, 120, 720, 5040, 40320, 362880, 3628800])


def factorial(i):
return FACTORIALS[i]


def coeff(j, lengthscale, order):
"""
Can be used to generate coefficients for the quasi-periodic kernels that guarantee a valid covariance function:
q2 = np.array([coeff(j, lengthscale, order) for j in range(order + 1)])
See eq (26) of [1].
Not currently used (we use Bessel functions instead for clarity).
[1] Solin & Sarkka (2014) "Explicit Link Between Periodic Covariance Functions and State Space Models".
"""
s = sum([
# (2 * lengthscale ** 2) ** -(j + 2 * i) / (factorial(j+i) * factorial(i)) for i in range(int((order-j) / 2))
(2 * lengthscale ** 2) ** -(j + 2 * i) / (factorial(j+i) * factorial(i)) for i in range(int((order+2-j) / 2))
])
return 1 / np.exp(lengthscale ** -2) * s
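
A minimal sketch (the order and lengthscale values here are hypothetical, for illustration only) of how this truncated-series helper relates to the Bessel-function coefficients the quasi-periodic kernels in this file actually use:

import numpy as np
from tensorflow_probability.substrates.jax.math import bessel_ive
from bayesnewton.kernels import coeff

order, lengthscale = 6, 1.5  # hypothetical values

# truncated-series coefficients from the helper above (eq (26) of Solin & Sarkka 2014)
q2_series = np.array([coeff(j, lengthscale, order) for j in range(order + 1)])

# Bessel-function coefficients, as used by kernel_to_state_space() below
q2_bessel = np.array([1, *[2] * order]) * bessel_ive([*range(order + 1)], lengthscale ** -2)

print(q2_series)
print(q2_bessel)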


class Kernel(objax.Module):
"""
@@ -39,6 +61,7 @@ def feedback_matrix(self):
raise NotImplementedError

def state_transition(self, dt):
# TODO(32): fix prediction when using expm to compute the state transition.
F = self.feedback_matrix()
A = expm(F * dt)
return A
@@ -839,8 +862,9 @@ def state_transition(self, dt):
:param dt: step size(s), Δt = tₙ - tₙ₋₁ [1]
:return: state transition matrix A [2(N+1), 2(N+1)]
"""
A = expm(self.feedback_matrix()*dt)

omega = 2 * np.pi / self.period # The angular frequency
harmonics = np.arange(self.order + 1) * omega
A = block_diag(*vmap(rotation_matrix, [None, 0])(dt, harmonics))
return A
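
Because the feedback matrix of the periodic kernel is block-diagonal with one 2x2 oscillator per harmonic jω, its matrix exponential reduces to a stack of rotation matrices, which is what the closed form above computes. A small sketch checking the two agree (dt, period and order are hypothetical; rotation_matrix is written out locally as a stand-in, with the sign convention assumed):

import numpy as np
from scipy.linalg import expm, block_diag

def rotation_matrix(dt, omega):
    # local stand-in for the rotation_matrix helper used above (sign convention assumed)
    return np.array([[np.cos(omega * dt), -np.sin(omega * dt)],
                     [np.sin(omega * dt),  np.cos(omega * dt)]])

dt, period, order = 0.1, 1.0, 2  # hypothetical values
omega = 2 * np.pi / period
harmonics = np.arange(order + 1) * omega

# closed form: one rotation block per harmonic
A_closed = block_diag(*[rotation_matrix(dt, h) for h in harmonics])

# matrix exponential of the block-diagonal feedback matrix F = kron(diag(0..order), [[0, -w], [w, 0]])
F = np.kron(np.diag(np.arange(order + 1)), np.array([[0., -omega], [omega, 0.]]))
A_expm = expm(F * dt)

print(np.allclose(A_closed, A_expm))  # True up to numerical tolerance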

def feedback_matrix(self):
@@ -934,7 +958,12 @@ def state_transition(self, dt):
:return: state transition matrix A [M+1, D, D]
"""
# The angular frequency
A = expm(self.feedback_matrix() * dt)
omega = 2 * np.pi / self.period
harmonics = np.arange(self.order + 1) * omega
A = (
np.exp(-dt / self.lengthscale_matern)
* block_diag(*vmap(rotation_matrix, [None, 0])(dt, harmonics))
)
return A

def feedback_matrix(self):
@@ -990,41 +1019,40 @@ def kernel_to_state_space(self, R=None):
q2 = np.array([1, *[2]*self.order]) * bessel_ive([*range(self.order+1)], self.lengthscale_periodic**(-2))
# The angular frequency
omega = 2 * np.pi / self.period
harmonics = np.arange(self.order + 1) * omega
# The model
F_p = np.kron(np.diag(np.arange(self.order + 1)), np.array([[0., -omega], [omega, 0.]]))
L_p = np.eye(2 * (self.order + 1))
# Qc_p = np.zeros(2 * (self.N + 1))
Pinf_p = np.kron(np.diag(q2), np.eye(2))
Pinf_per = q2[:, None, None] * np.eye(2)
H_p = np.kron(np.ones([1, self.order + 1]), np.array([1., 0.]))
lam = 3.0 ** 0.5 / self.lengthscale_matern
F_m = np.array([[0.0, 1.0],
[-lam ** 2, -2 * lam]])
L_m = np.array([[0],
[1]])
Qc_m = np.array([[12.0 * 3.0 ** 0.5 / self.lengthscale_matern ** 3.0 * self.variance]])
H_m = np.array([[1.0, 0.0]])
Pinf_m = np.array([[self.variance, 0.0],
[0.0, 3.0 * self.variance / self.lengthscale_matern ** 2.0]])
# F = np.kron(F_p, np.eye(2)) + np.kron(np.eye(14), F_m)
F = np.kron(F_m, np.eye(2 * (self.order + 1))) + np.kron(np.eye(2), F_p)
Pinf_mat = np.array([[self.variance, 0.0],
[0.0, 3.0 * self.variance / self.lengthscale_matern ** 2.0]])
F = block_diag(
*vmap(self.feedback_matrix_subband_matern32, [None, 0])(self.lengthscale_matern, harmonics)
)
L = np.kron(L_m, L_p)
Qc = np.kron(Qc_m, Pinf_p)
H = np.kron(H_m, H_p)
Pinf = np.kron(Pinf_m, Pinf_p)
# note: Qc is always kron(Qc_m, q2I), not kron(Qc_m, Pinf_per). See eq (32) of Solin & Sarkka 2014.
Qc = block_diag(*np.kron(Qc_m, q2[:, None, None] * np.eye(2)))
H = np.kron(H_p, H_m)
Pinf = block_diag(*np.kron(Pinf_mat, Pinf_per))
return F, L, Qc, H, Pinf

def stationary_covariance(self):
q2 = np.array([1, *[2]*self.order]) * bessel_ive([*range(self.order+1)], self.lengthscale_periodic**(-2))
Pinf_m = np.array([[self.variance, 0.0],
[0.0, 3.0 * self.variance / self.lengthscale_matern ** 2.0]])
Pinf_p = np.kron(np.diag(q2), np.eye(2))
Pinf = np.kron(Pinf_m, Pinf_p)
Pinf_mat = np.array([[self.variance, 0.0],
[0.0, 3.0 * self.variance / self.lengthscale_matern ** 2.0]])
Pinf_per = q2[:, None, None] * np.eye(2)
Pinf = block_diag(*np.kron(Pinf_mat, Pinf_per))
return Pinf

def measurement_model(self):
H_p = np.kron(np.ones([1, self.order + 1]), np.array([1., 0.]))
H_m = np.array([[1.0, 0.0]])
H = np.kron(H_m, H_p)
H = np.kron(H_p, H_m)
return H

def state_transition(self, dt):
@@ -1034,19 +1062,55 @@ def state_transition(self, dt):
:param dt: step size(s), Δt = tₙ - tₙ₋₁ [M+1, 1]
:return: state transition matrix A [M+1, D, D]
"""
A = expm(self.feedback_matrix()*dt)

# The angular frequency
omega = 2 * np.pi / self.period
harmonics = np.arange(self.order + 1) * omega
A = block_diag(
*vmap(self.state_transition_subband_matern32, [None, None, 0])(dt, self.lengthscale_matern, harmonics)
)
return A

def feedback_matrix(self):
# The angular frequency
# The fundamental angular frequency
omega = 2 * np.pi / self.period
# The model
F_p = np.kron(np.diag(np.arange(self.order + 1)), np.array([[0., -omega], [omega, 0.]]))
lam = 3.0 ** 0.5 / self.lengthscale_matern
F_m = np.array([[0.0, 1.0],
[-lam ** 2, -2 * lam]])
F = np.kron(F_m, np.eye(2 * (self.order + 1))) + np.kron(np.eye(2), F_p)
harmonics = np.arange(self.order + 1) * omega
F = block_diag(
*vmap(self.feedback_matrix_subband_matern32, [None, 0])(self.lengthscale_matern, harmonics)
)
return F

@staticmethod
def state_transition_subband_matern32(dt, ell, radial_freq):
# TODO: re-use code from SubbandMatern32 kernel
"""
Calculation of the closed form discrete-time state
transition matrix A = expm(FΔt) for the Subband Matern-3/2 prior
:param dt: step size(s), Δt = tₙ - tₙ₋₁ [1]
:param ell: lengthscale of the Matern component of the kernel
:param radial_freq: the radial (i.e. angular) frequency of the oscillator
:return: state transition matrix A [4, 4]
"""
lam = np.sqrt(3.0) / ell
R = rotation_matrix(dt, radial_freq)
A = np.exp(-dt * lam) * np.block([
[(1. + dt * lam) * R, dt * R],
[-dt * lam ** 2 * R, (1. - dt * lam) * R]
])
return A

@staticmethod
def feedback_matrix_subband_matern32(len, omega):
# TODO: re-use code from SubbandMatern32 kernel
lam = 3.0 ** 0.5 / len
F_mat = np.array([[0.0, 1.0],
[-lam ** 2, -2 * lam]])
F_cos = np.array([[0.0, -omega],
[omega, 0.0]])
# F = (0 -ω 1 0
# ω 0 0 1
# -λ² 0 -2λ -ω
# 0 -λ² ω -2λ)
F = np.kron(F_mat, np.eye(2)) + np.kron(np.eye(2), F_cos)
return F
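
The closed form in state_transition_subband_matern32 is expm(FΔt) written out: the feedback matrix is a commuting sum of a Matérn-3/2 part and an oscillator part, so its exponential factorises into the Matérn-3/2 transition Kronecker a rotation. A small numerical check (step size, lengthscale and frequency are hypothetical; rotation_matrix is again a local stand-in with the sign convention assumed):

import numpy as np
from scipy.linalg import expm

def rotation_matrix(dt, omega):
    # local stand-in for the rotation_matrix helper (sign convention assumed)
    return np.array([[np.cos(omega * dt), -np.sin(omega * dt)],
                     [np.sin(omega * dt),  np.cos(omega * dt)]])

dt, ell, omega = 0.3, 2.0, 2 * np.pi  # hypothetical step size, lengthscale and radial frequency
lam = np.sqrt(3.0) / ell

# closed form, as in state_transition_subband_matern32 above
R = rotation_matrix(dt, omega)
A_closed = np.exp(-dt * lam) * np.block([
    [(1. + dt * lam) * R, dt * R],
    [-dt * lam ** 2 * R, (1. - dt * lam) * R]
])

# matrix exponential of the corresponding feedback matrix
F_mat = np.array([[0.0, 1.0], [-lam ** 2, -2 * lam]])
F_cos = np.array([[0.0, -omega], [omega, 0.0]])
F = np.kron(F_mat, np.eye(2)) + np.kron(np.eye(2), F_cos)

print(np.allclose(A_closed, expm(F * dt)))  # True up to numerical tolerance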


@@ -1667,6 +1731,7 @@ def __init__(self,
assert N.shape == R.shape
self.N_ = objax.StateVar(np.array(N))
self.R_ = objax.StateVar(np.array(R))
warn("Prediction is partially broken when using the LatentExponentiallyGenerated kernel. See issue #32.")

@property
def N(self):
36 changes: 36 additions & 0 deletions bayesnewton/models.py
@@ -58,6 +58,23 @@ def __init__(self, kernel, likelihood, X, Y):
super().__init__(kernel, likelihood, X, Y)


class VariationalRiemannGP(VariationalInferenceRiemann, GaussianProcess):
"""
Variational Gaussian process [1], adapted to use conjugate computation VI [2] with PSD guarantees [3].
:param kernel: a kernel object
:param likelihood: a likelihood object
:param X: inputs
:param Y: observations
[1] Opper, Archambeau: The Variational Gaussian Approximation Revisited, Neural Computation, 2009
[2] Khan, Lin: Conjugate-Computation Variational Inference - Converting Inference in Non-Conjugate Models into
Inference in Conjugate Models, AISTATS 2017
[3] Lin, Schmidt & Khan: Handling the Positive-Definite Constraint in the Bayesian Learning Rule, ICML 2020
"""
def __init__(self, kernel, likelihood, X, Y):
super().__init__(kernel, likelihood, X, Y)
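
A minimal usage sketch for the new model class, assuming the standard bayesnewton kernel and likelihood constructors and hypothetical toy data:

import numpy as np
import bayesnewton

x = np.sort(70 * np.random.rand(100))                      # hypothetical 1-D inputs
y = np.sin(np.pi * x / 10.0) + 0.1 * np.random.randn(100)  # hypothetical observations

kern = bayesnewton.kernels.Matern32(variance=1.0, lengthscale=5.0)
lik = bayesnewton.likelihoods.Gaussian(variance=0.1)

# constructor signature as defined above: (kernel, likelihood, X, Y)
model = bayesnewton.models.VariationalRiemannGP(kernel=kern, likelihood=lik, X=x, Y=y)

# training then follows the usual demo pattern: alternate model.inference(lr=...) steps
# with gradient steps on model.energy() for the hyperparameters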


class SparseVariationalGP(VariationalInference, SparseGaussianProcess):
"""
Sparse variational Gaussian process (SVGP) [1], adapted to use conjugate computation VI [2]
@@ -76,6 +93,25 @@ def __init__(self, kernel, likelihood, X, Y, Z, opt_z=False):
super().__init__(kernel, likelihood, X, Y, Z, opt_z)


class SparseVariationalRiemannGP(VariationalInferenceRiemann, SparseGaussianProcess):
"""
Sparse variational Gaussian process (SVGP) [1], adapted to use conjugate computation VI [2] with PSD guarantees [3].
:param kernel: a kernel object
:param likelihood: a likelihood object
:param X: inputs
:param Y: observations
:param Z: inducing inputs
:param opt_z: boolean determining whether to optimise the inducing input locations
[1] Hensman, Matthews, Ghahramani: Scalable Variational Gaussian Process Classification, AISTATS 2015
[2] Khan, Lin: Conjugate-Computation Variational Inference - Converting Inference in Non-Conjugate Models into
Inference in Conjugate Models, AISTATS 2017
[3] Lin, Schmidt & Khan: Handling the Positive-Definite Constraint in the Bayesian Learning Rule, ICML 2020
"""
def __init__(self, kernel, likelihood, X, Y, Z, opt_z=False):
super().__init__(kernel, likelihood, X, Y, Z, opt_z)
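
The sparse variant is constructed the same way but with inducing inputs. Continuing the sketch above (the inducing inputs Z are chosen arbitrarily here):

z = np.linspace(np.min(x), np.max(x), num=20)  # hypothetical inducing inputs
sparse_model = bayesnewton.models.SparseVariationalRiemannGP(
    kernel=kern, likelihood=lik, X=x, Y=y, Z=z, opt_z=True
)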


SVGP = SparseVariationalGP


4 changes: 2 additions & 2 deletions demos/classification.py
@@ -14,11 +14,11 @@
x = np.concatenate([x0, np.array([50]), x1], axis=0)
# x = np.linspace(np.min(x), np.max(x), N)
f = lambda x_: 6 * np.sin(np.pi * x_ / 10.0) / (np.pi * x_ / 10.0 + 1)
y_ = f(x) + np.math.sqrt(0.05)*np.random.randn(x.shape[0])
y_ = f(x) + np.sqrt(0.05)*np.random.randn(x.shape[0])
y = np.sign(y_)
y[y == -1] = 0
x_test = np.linspace(np.min(x)-5.0, np.max(x)+5.0, num=500)
y_test = np.sign(f(x_test) + np.math.sqrt(0.05)*np.random.randn(x_test.shape[0]))
y_test = np.sign(f(x_test) + np.sqrt(0.05)*np.random.randn(x_test.shape[0]))
y_test[y_test == -1] = 0
x_plot = np.linspace(np.min(x)-10.0, np.max(x)+10.0, num=500)
z = np.linspace(min(x), max(x), num=M)
10 changes: 5 additions & 5 deletions demos/heteroscedastic.py
@@ -15,7 +15,7 @@
y_scaler = StandardScaler().fit(Y)
Xall = X_scaler.transform(X)
Yall = y_scaler.transform(Y)
x_plot = np.linspace(np.min(Xall)-0.2, np.max(Xall)+0.2, 200)
x_plot = np.linspace(np.min(Xall)-0.2, np.max(Xall)+0.2, 200)[:, None]

# Load cross-validation indices
cvind = np.loadtxt('../experiments/motorcycle/cvind.csv').astype(int)
@@ -123,17 +123,17 @@ def train_op():
link = model.likelihood.link_fn
lb = posterior_mean[:, 0] - np.sqrt(posterior_var[:, 0, 0] + link(posterior_mean[:, 1]) ** 2) * 1.96
ub = posterior_mean[:, 0] + np.sqrt(posterior_var[:, 0, 0] + link(posterior_mean[:, 1]) ** 2) * 1.96
post_mean = y_scaler.inverse_transform(posterior_mean[:, 0])
lb = y_scaler.inverse_transform(lb)
ub = y_scaler.inverse_transform(ub)
post_mean = y_scaler.inverse_transform(posterior_mean[:, 0:1])
lb = y_scaler.inverse_transform(lb[:, None])[:, 0]
ub = y_scaler.inverse_transform(ub[:, None])[:, 0]

print('plotting ...')
plt.figure(1, figsize=(12, 5))
plt.clf()
plt.plot(X_scaler.inverse_transform(X), y_scaler.inverse_transform(Y), 'k.', label='train')
plt.plot(X_scaler.inverse_transform(XT), y_scaler.inverse_transform(YT), 'r.', label='test')
plt.plot(x_pred, post_mean, 'c', label='posterior mean')
plt.fill_between(x_pred, lb, ub, color='c', alpha=0.05, label='95% confidence')
plt.fill_between(x_pred[:, 0], lb, ub, color='c', alpha=0.05, label='95% confidence')
plt.xlim(x_pred[0], x_pred[-1])
if hasattr(model, 'Z'):
plt.plot(X_scaler.inverse_transform(model.Z.value[:, 0]),
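
The reshapes above reflect that scikit-learn's StandardScaler operates on 2-D arrays of shape (n_samples, n_features), so 1-D quantities have to be lifted to a column before inverse_transform and squeezed back afterwards. A minimal illustration with hypothetical data:

import numpy as np
from sklearn.preprocessing import StandardScaler

y = np.random.randn(50, 1)                                 # (n_samples, n_features)
y_scaler = StandardScaler().fit(y)

lb = np.random.randn(50)                                   # a 1-D quantity, e.g. a lower confidence bound
lb_orig = y_scaler.inverse_transform(lb[:, None])[:, 0]    # lift to a column, then squeeze back
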
4 changes: 2 additions & 2 deletions demos/positive.py
@@ -19,9 +19,9 @@ def wiggly_time_series(x_):
# x = np.linspace(np.min(x), np.max(x), N)
# f = lambda x_: 3 * np.sin(np.pi * x_ / 10.0)
f = wiggly_time_series
y = nonlinearity(f(x)) + np.math.sqrt(0.1)*np.random.randn(x.shape[0])
y = nonlinearity(f(x)) + np.sqrt(0.1)*np.random.randn(x.shape[0])
x_test = np.linspace(np.min(x), np.max(x), num=500)
y_test = nonlinearity(f(x_test)) + np.math.sqrt(0.05)*np.random.randn(x_test.shape[0])
y_test = nonlinearity(f(x_test)) + np.sqrt(0.05)*np.random.randn(x_test.shape[0])
x_plot = np.linspace(np.min(x)-10.0, np.max(x)+10.0, num=500)

M = 20
4 changes: 2 additions & 2 deletions demos/regression.py
@@ -9,8 +9,8 @@ def wiggly_time_series(x_):
noise_var = 0.2 # true observation noise
# return 0.25 * (np.cos(0.04*x_+0.33*np.pi) * np.sin(0.2*x_) +
return (np.cos(0.04*x_+0.33*np.pi) * np.sin(0.2*x_) +
np.math.sqrt(noise_var) * np.random.normal(0, 1, x_.shape) +
# np.math.sqrt(noise_var) * np.random.uniform(-4, 4, x_.shape) +
np.sqrt(noise_var) * np.random.normal(0, 1, x_.shape) +
# np.sqrt(noise_var) * np.random.uniform(-4, 4, x_.shape) +
0.0 * x_) # 0.02 * x_)
# 0.0 * x_) + 2.5 # 0.02 * x_)

2 changes: 1 addition & 1 deletion demos/studentt.py
@@ -9,7 +9,7 @@
def wiggly_time_series(x_):
noise_var = 0.2 # true observation noise scale
return (np.cos(0.04*x_+0.33*np.pi) * np.sin(0.2*x_) +
np.math.sqrt(noise_var) * np.random.standard_t(3., x_.shape) +
np.sqrt(noise_var) * np.random.standard_t(3., x_.shape) +
0.0 * x_)


2 changes: 1 addition & 1 deletion experiments/binary/binary.py
@@ -11,7 +11,7 @@
x = np.sort(70 * np.random.rand(N))
sn = 0.01
f = lambda x_: 12. * np.sin(4 * np.pi * x_) / (0.25 * np.pi * x_ + 1)
y_ = f(x) + np.math.sqrt(sn)*np.random.randn(x.shape[0])
y_ = f(x) + np.sqrt(sn)*np.random.randn(x.shape[0])
y = np.sign(y_)
y[y == -1] = 0

6 changes: 4 additions & 2 deletions requirements.txt
@@ -1,7 +1,9 @@
jax==0.4.14
jaxlib==0.4.14
objax==1.7.0
tensorflow_probability==0.20.1
numpy
matplotlib
scipy
scipy
scikit-learn
pandas
tensorflow_probability==0.21
2 changes: 1 addition & 1 deletion tests/normaliser_test.py
@@ -13,7 +13,7 @@
def wiggly_time_series(x_):
noise_var = 0.15 # true observation noise
return (np.cos(0.04*x_+0.33*np.pi) * np.sin(0.2*x_) +
np.math.sqrt(noise_var) * np.random.normal(0, 1, x_.shape) +
np.sqrt(noise_var) * np.random.normal(0, 1, x_.shape) +
0.0 * x_) # 0.02 * x_)


2 changes: 1 addition & 1 deletion tests/test_gp_vs_markovgp_class.py
@@ -11,7 +11,7 @@ def build_data(N):
x = 100 * np.random.rand(N)
x = np.sort(x) # since MarkovGP sorts the inputs, they must also be sorted for GP
f = lambda x_: 6 * np.sin(np.pi * x_ / 10.0) / (np.pi * x_ / 10.0 + 1)
y_ = f(x) + np.math.sqrt(0.05) * np.random.randn(x.shape[0])
y_ = f(x) + np.sqrt(0.05) * np.random.randn(x.shape[0])
y = np.sign(y_)
y[y == -1] = 0
x = x[:, None]
2 changes: 1 addition & 1 deletion tests/test_gp_vs_markovgp_reg.py
@@ -9,7 +9,7 @@
def wiggly_time_series(x_):
noise_var = 0.15 # true observation noise
return (np.cos(0.04*x_+0.33*np.pi) * np.sin(0.2*x_) +
np.math.sqrt(noise_var) * np.random.normal(0, 1, x_.shape))
np.sqrt(noise_var) * np.random.normal(0, 1, x_.shape))


def build_data(N):
