diff --git a/bayesnewton/basemodels.py b/bayesnewton/basemodels.py
index ac221dd..b992f34 100644
--- a/bayesnewton/basemodels.py
+++ b/bayesnewton/basemodels.py
@@ -803,7 +803,7 @@ def predict(self, X=None, R=None, pseudo_lik_params=None):
         H = self.kernel.measurement_model()
         if self.spatio_temporal:
             # TODO: if R is fixed, only compute B, C once
-            B, C = self.kernel.spatial_conditional(X, R)
+            B, C = self.kernel.spatial_conditional(X, R, predict=True)
             W = B @ H
             test_mean = W @ state_mean
             test_var = W @ state_cov @ transpose(W) + C
@@ -1060,7 +1060,7 @@ def predict(self, X, R=None):
         H = self.kernel.measurement_model()
         if self.spatio_temporal:
             # TODO: if R is fixed, only compute B, C once
-            B, C = self.kernel.spatial_conditional(X, R)
+            B, C = self.kernel.spatial_conditional(X, R, predict=True)
             W = B @ H
             test_mean = W @ state_mean
             test_var = W @ state_cov @ transpose(W) + C
@@ -1216,7 +1216,7 @@ def predict(self, X=None, R=None, pseudo_lik_params=None):
         H = self.kernel.measurement_model()
         if self.spatio_temporal:
             # TODO: if R is fixed, only compute B, C once
-            B, C = self.kernel.spatial_conditional(X, R)
+            B, C = self.kernel.spatial_conditional(X, R, predict=True)
             W = B @ H
             test_mean = W @ state_mean
             test_var = W @ state_cov @ transpose(W) + C
diff --git a/bayesnewton/kernels.py b/bayesnewton/kernels.py
index dab7e37..4c09a03 100644
--- a/bayesnewton/kernels.py
+++ b/bayesnewton/kernels.py
@@ -6,6 +6,28 @@
 from warnings import warn
 from tensorflow_probability.substrates.jax.math import bessel_ive
 
+FACTORIALS = np.array([1, 1, 2, 6, 24, 120, 720, 5040, 40320, 362880, 3628800])
+
+
+def factorial(i):
+    return FACTORIALS[i]
+
+
+def coeff(j, lengthscale, order):
+    """
+    Can be used to generate co-efficients for the quasi-periodic kernels that guarantee a valid covariance function:
+        q2 = np.array([coeff(j, lengthscale, order) for j in range(order + 1)])
+    See eq (26) of [1].
+    Not currently used (we use Bessel functions instead for clarity).
+
+    [1] Solin & Sarkka (2014) "Explicit Link Between Periodic Covariance Functions and State Space Models".
+    """
+    s = sum([
+        # (2 * lengthscale ** 2) ** -(j + 2 * i) / (factorial(j+i) * factorial(i)) for i in range(int((order-j) / 2))
+        (2 * lengthscale ** 2) ** -(j + 2 * i) / (factorial(j+i) * factorial(i)) for i in range(int((order+2-j) / 2))
+    ])
+    return 1 / np.exp(lengthscale ** -2) * s
+
 
 class Kernel(objax.Module):
     """
@@ -39,6 +61,7 @@ def feedback_matrix(self):
         raise NotImplementedError
 
     def state_transition(self, dt):
+        # TODO(32): fix prediction when using expm to compute the state transition.
         F = self.feedback_matrix()
         A = expm(F * dt)
         return A
@@ -839,8 +862,9 @@ def state_transition(self, dt):
         :param dt: step size(s), Δt = tₙ - tₙ₋₁ [1]
         :return: state transition matrix A [2(N+1), 2(N+1)]
         """
-        A = expm(self.feedback_matrix()*dt)
-
+        omega = 2 * np.pi / self.period  # The angular frequency
+        harmonics = np.arange(self.order + 1) * omega
+        A = block_diag(*vmap(rotation_matrix, [None, 0])(dt, harmonics))
         return A
 
     def feedback_matrix(self):
@@ -934,7 +958,12 @@ def state_transition(self, dt):
         :return: state transition matrix A [M+1, D, D]
         """
         # The angular frequency
-        A = expm(self.feedback_matrix() * dt)
+        omega = 2 * np.pi / self.period
+        harmonics = np.arange(self.order + 1) * omega
+        A = (
+            np.exp(-dt / self.lengthscale_matern)
+            * block_diag(*vmap(rotation_matrix, [None, 0])(dt, harmonics))
+        )
         return A
 
     def feedback_matrix(self):
@@ -990,41 +1019,40 @@ def kernel_to_state_space(self, R=None):
         q2 = np.array([1, *[2]*self.order]) * bessel_ive([*range(self.order+1)], self.lengthscale_periodic**(-2))
         # The angular frequency
         omega = 2 * np.pi / self.period
+        harmonics = np.arange(self.order + 1) * omega
         # The model
-        F_p = np.kron(np.diag(np.arange(self.order + 1)), np.array([[0., -omega], [omega, 0.]]))
         L_p = np.eye(2 * (self.order + 1))
         # Qc_p = np.zeros(2 * (self.N + 1))
-        Pinf_p = np.kron(np.diag(q2), np.eye(2))
+        Pinf_per = q2[:, None, None] * np.eye(2)
         H_p = np.kron(np.ones([1, self.order + 1]), np.array([1., 0.]))
-        lam = 3.0 ** 0.5 / self.lengthscale_matern
-        F_m = np.array([[0.0, 1.0],
-                        [-lam ** 2, -2 * lam]])
         L_m = np.array([[0], [1]])
         Qc_m = np.array([[12.0 * 3.0 ** 0.5 / self.lengthscale_matern ** 3.0 * self.variance]])
         H_m = np.array([[1.0, 0.0]])
-        Pinf_m = np.array([[self.variance, 0.0],
-                           [0.0, 3.0 * self.variance / self.lengthscale_matern ** 2.0]])
-        # F = np.kron(F_p, np.eye(2)) + np.kron(np.eye(14), F_m)
-        F = np.kron(F_m, np.eye(2 * (self.order + 1))) + np.kron(np.eye(2), F_p)
+        Pinf_mat = np.array([[self.variance, 0.0],
+                             [0.0, 3.0 * self.variance / self.lengthscale_matern ** 2.0]])
+        F = block_diag(
+            *vmap(self.feedback_matrix_subband_matern32, [None, 0])(self.lengthscale_matern, harmonics)
+        )
         L = np.kron(L_m, L_p)
-        Qc = np.kron(Qc_m, Pinf_p)
-        H = np.kron(H_m, H_p)
-        Pinf = np.kron(Pinf_m, Pinf_p)
+        # note: Qc is always kron(Qc_m, q2I), not kron(Qc_m, Pinf_per). See eq (32) of Solin & Sarkka 2014.
+        Qc = block_diag(*np.kron(Qc_m, q2[:, None, None] * np.eye(2)))
+        H = np.kron(H_p, H_m)
+        Pinf = block_diag(*np.kron(Pinf_mat, Pinf_per))
         return F, L, Qc, H, Pinf
 
     def stationary_covariance(self):
         q2 = np.array([1, *[2]*self.order]) * bessel_ive([*range(self.order+1)], self.lengthscale_periodic**(-2))
-        Pinf_m = np.array([[self.variance, 0.0],
-                           [0.0, 3.0 * self.variance / self.lengthscale_matern ** 2.0]])
-        Pinf_p = np.kron(np.diag(q2), np.eye(2))
-        Pinf = np.kron(Pinf_m, Pinf_p)
+        Pinf_mat = np.array([[self.variance, 0.0],
+                             [0.0, 3.0 * self.variance / self.lengthscale_matern ** 2.0]])
+        Pinf_per = q2[:, None, None] * np.eye(2)
+        Pinf = block_diag(*np.kron(Pinf_mat, Pinf_per))
         return Pinf
 
     def measurement_model(self):
         H_p = np.kron(np.ones([1, self.order + 1]), np.array([1., 0.]))
         H_m = np.array([[1.0, 0.0]])
-        H = np.kron(H_m, H_p)
+        H = np.kron(H_p, H_m)
         return H
 
     def state_transition(self, dt):
@@ -1034,19 +1062,55 @@ def state_transition(self, dt):
         :param dt: step size(s), Δt = tₙ - tₙ₋₁ [M+1, 1]
         :return: state transition matrix A [M+1, D, D]
         """
-        A = expm(self.feedback_matrix()*dt)
-
+        # The angular frequency
+        omega = 2 * np.pi / self.period
+        harmonics = np.arange(self.order + 1) * omega
+        A = block_diag(
+            *vmap(self.state_transition_subband_matern32, [None, None, 0])(dt, self.lengthscale_matern, harmonics)
+        )
         return A
 
     def feedback_matrix(self):
-        # The angular frequency
+        # The angular fundamental frequency
         omega = 2 * np.pi / self.period
-        # The model
-        F_p = np.kron(np.diag(np.arange(self.order + 1)), np.array([[0., -omega], [omega, 0.]]))
-        lam = 3.0 ** 0.5 / self.lengthscale_matern
-        F_m = np.array([[0.0, 1.0],
-                        [-lam ** 2, -2 * lam]])
-        F = np.kron(F_m, np.eye(2 * (self.order + 1))) + np.kron(np.eye(2), F_p)
+        harmonics = np.arange(self.order + 1) * omega
+        F = block_diag(
+            *vmap(self.feedback_matrix_subband_matern32, [None, 0])(self.lengthscale_matern, harmonics)
+        )
+        return F
+
+    @staticmethod
+    def state_transition_subband_matern32(dt, ell, radial_freq):
+        # TODO: re-use code from SubbandMatern32 kernel
+        """
+        Calculation of the closed form discrete-time state
+        transition matrix A = expm(FΔt) for the Subband Matern-3/2 prior
+        :param dt: step size(s), Δt = tₙ - tₙ₋₁ [1]
+        :param ell: lengthscale of the Matern component of the kernel
+        :param radial_freq: the radial (i.e. angular) frequency of the oscillator
+        :return: state transition matrix A [4, 4]
+        """
+        lam = np.sqrt(3.0) / ell
+        R = rotation_matrix(dt, radial_freq)
+        A = np.exp(-dt * lam) * np.block([
+            [(1. + dt * lam) * R, dt * R],
+            [-dt * lam ** 2 * R, (1. - dt * lam) * R]
+        ])
+        return A
+
+    @staticmethod
+    def feedback_matrix_subband_matern32(len, omega):
+        # TODO: re-use code from SubbandMatern32 kernel
+        lam = 3.0 ** 0.5 / len
+        F_mat = np.array([[0.0, 1.0],
+                          [-lam ** 2, -2 * lam]])
+        F_cos = np.array([[0.0, -omega],
+                          [omega, 0.0]])
+        # F = ( 0   -ω    1    0
+        #       ω    0    0    1
+        #      -λ²   0   -2λ  -ω
+        #       0   -λ²   ω   -2λ )
+        F = np.kron(F_mat, np.eye(2)) + np.kron(np.eye(2), F_cos)
         return F
 
 
@@ -1667,6 +1731,7 @@ def __init__(self,
         assert N.shape == R.shape
         self.N_ = objax.StateVar(np.array(N))
         self.R_ = objax.StateVar(np.array(R))
+        warn("Prediction is partially broken when using the LatentExponentiallyGenerated kernel. See issue #32.")
 
     @property
     def N(self):
diff --git a/bayesnewton/models.py b/bayesnewton/models.py
index c386e20..ff6e32c 100644
--- a/bayesnewton/models.py
+++ b/bayesnewton/models.py
@@ -58,6 +58,23 @@ def __init__(self, kernel, likelihood, X, Y):
         super().__init__(kernel, likelihood, X, Y)
 
 
+class VariationalRiemannGP(VariationalInferenceRiemann, GaussianProcess):
+    """
+    Variational Gaussian process [1], adapted to use conjugate computation VI [2] with PSD guarantees [3].
+    :param kernel: a kernel object
+    :param likelihood: a likelihood object
+    :param X: inputs
+    :param Y: observations
+
+    [1] Opper, Archambeau: The Variational Gaussian Approximation Revisited, Neural Computation, 2009
+    [2] Khan, Lin: Conjugate-Computation Variational Inference - Converting Inference in Non-Conjugate Models into
+        Inference in Conjugate Models, AISTATS 2017
+    [3] Lin, Schmidt & Khan: Handling the Positive-Definite Constraint in the Bayesian Learning Rule, ICML 2020
+    """
+    def __init__(self, kernel, likelihood, X, Y):
+        super().__init__(kernel, likelihood, X, Y)
+
+
 class SparseVariationalGP(VariationalInference, SparseGaussianProcess):
     """
     Sparse variational Gaussian process (SVGP) [1], adapted to use conjugate computation VI [2]
@@ -76,6 +93,25 @@ def __init__(self, kernel, likelihood, X, Y, Z, opt_z=False):
         super().__init__(kernel, likelihood, X, Y, Z, opt_z)
 
 
+class SparseVariationalRiemannGP(VariationalInferenceRiemann, SparseGaussianProcess):
+    """
+    Sparse variational Gaussian process (SVGP) [1], adapted to use conjugate computation VI [2] with PSD guarantees [3].
+    :param kernel: a kernel object
+    :param likelihood: a likelihood object
+    :param X: inputs
+    :param Y: observations
+    :param Z: inducing inputs
+    :param opt_z: boolean determining whether to optimise the inducing input locations
+
+    [1] Hensman, Matthews, Ghahramani: Scalable Variational Gaussian Process Classification, AISTATS 2015
+    [2] Khan, Lin: Conjugate-Computation Variational Inference - Converting Inference in Non-Conjugate Models into
+        Inference in Conjugate Models, AISTATS 2017
+    [3] Lin, Schmidt & Khan: Handling the Positive-Definite Constraint in the Bayesian Learning Rule, ICML 2020
+    """
+    def __init__(self, kernel, likelihood, X, Y, Z, opt_z=False):
+        super().__init__(kernel, likelihood, X, Y, Z, opt_z)
+
+
 SVGP = SparseVariationalGP
diff --git a/demos/classification.py b/demos/classification.py
index bc37716..4501b62 100644
--- a/demos/classification.py
+++ b/demos/classification.py
@@ -14,11 +14,11 @@
 x = np.concatenate([x0, np.array([50]), x1], axis=0)
 # x = np.linspace(np.min(x), np.max(x), N)
 f = lambda x_: 6 * np.sin(np.pi * x_ / 10.0) / (np.pi * x_ / 10.0 + 1)
-y_ = f(x) + np.math.sqrt(0.05)*np.random.randn(x.shape[0])
+y_ = f(x) + np.sqrt(0.05)*np.random.randn(x.shape[0])
 y = np.sign(y_)
 y[y == -1] = 0
 x_test = np.linspace(np.min(x)-5.0, np.max(x)+5.0, num=500)
-y_test = np.sign(f(x_test) + np.math.sqrt(0.05)*np.random.randn(x_test.shape[0]))
+y_test = np.sign(f(x_test) + np.sqrt(0.05)*np.random.randn(x_test.shape[0]))
 y_test[y_test == -1] = 0
 x_plot = np.linspace(np.min(x)-10.0, np.max(x)+10.0, num=500)
 z = np.linspace(min(x), max(x), num=M)
diff --git a/demos/heteroscedastic.py b/demos/heteroscedastic.py
index 8da6dbf..a5dce48 100644
--- a/demos/heteroscedastic.py
+++ b/demos/heteroscedastic.py
@@ -15,7 +15,7 @@
 y_scaler = StandardScaler().fit(Y)
 Xall = X_scaler.transform(X)
 Yall = y_scaler.transform(Y)
-x_plot = np.linspace(np.min(Xall)-0.2, np.max(Xall)+0.2, 200)
+x_plot = np.linspace(np.min(Xall)-0.2, np.max(Xall)+0.2, 200)[:, None]
 
 # Load cross-validation indices
 cvind = np.loadtxt('../experiments/motorcycle/cvind.csv').astype(int)
@@ -123,9 +123,9 @@ def train_op():
 link = model.likelihood.link_fn
 lb = posterior_mean[:, 0] - np.sqrt(posterior_var[:, 0, 0] + link(posterior_mean[:, 1]) ** 2) * 1.96
 ub = posterior_mean[:, 0] + np.sqrt(posterior_var[:, 0, 0] + link(posterior_mean[:, 1]) ** 2) * 1.96
-post_mean = y_scaler.inverse_transform(posterior_mean[:, 0])
-lb = y_scaler.inverse_transform(lb)
-ub = y_scaler.inverse_transform(ub)
+post_mean = y_scaler.inverse_transform(posterior_mean[:, 0:1])
+lb = y_scaler.inverse_transform(lb[:, None])[:, 0]
+ub = y_scaler.inverse_transform(ub[:, None])[:, 0]
 
 print('plotting ...')
 plt.figure(1, figsize=(12, 5))
@@ -133,7 +133,7 @@ def train_op():
 plt.plot(X_scaler.inverse_transform(X), y_scaler.inverse_transform(Y), 'k.', label='train')
 plt.plot(X_scaler.inverse_transform(XT), y_scaler.inverse_transform(YT), 'r.', label='test')
 plt.plot(x_pred, post_mean, 'c', label='posterior mean')
-plt.fill_between(x_pred, lb, ub, color='c', alpha=0.05, label='95% confidence')
+plt.fill_between(x_pred[:, 0], lb, ub, color='c', alpha=0.05, label='95% confidence')
 plt.xlim(x_pred[0], x_pred[-1])
 if hasattr(model, 'Z'):
     plt.plot(X_scaler.inverse_transform(model.Z.value[:, 0]),
diff --git a/demos/positive.py b/demos/positive.py
index 8b51093..479b02e 100644
--- a/demos/positive.py
+++ b/demos/positive.py
@@ -19,9 +19,9 @@ def wiggly_time_series(x_):
 # x = np.linspace(np.min(x), np.max(x), N)
 # f = lambda x_: 3 * np.sin(np.pi * x_ / 10.0)
 f = wiggly_time_series
-y = nonlinearity(f(x)) + np.math.sqrt(0.1)*np.random.randn(x.shape[0])
+y = nonlinearity(f(x)) + np.sqrt(0.1)*np.random.randn(x.shape[0])
 x_test = np.linspace(np.min(x), np.max(x), num=500)
-y_test = nonlinearity(f(x_test)) + np.math.sqrt(0.05)*np.random.randn(x_test.shape[0])
+y_test = nonlinearity(f(x_test)) + np.sqrt(0.05)*np.random.randn(x_test.shape[0])
 x_plot = np.linspace(np.min(x)-10.0, np.max(x)+10.0, num=500)
 M = 20
diff --git a/demos/regression.py b/demos/regression.py
index 30b5cd0..e9fbf54 100644
--- a/demos/regression.py
+++ b/demos/regression.py
@@ -9,8 +9,8 @@ def wiggly_time_series(x_):
     noise_var = 0.2  # true observation noise
     # return 0.25 * (np.cos(0.04*x_+0.33*np.pi) * np.sin(0.2*x_) +
     return (np.cos(0.04*x_+0.33*np.pi) * np.sin(0.2*x_) +
-            np.math.sqrt(noise_var) * np.random.normal(0, 1, x_.shape) +
-            # np.math.sqrt(noise_var) * np.random.uniform(-4, 4, x_.shape) +
+            np.sqrt(noise_var) * np.random.normal(0, 1, x_.shape) +
+            # np.sqrt(noise_var) * np.random.uniform(-4, 4, x_.shape) +
             0.0 * x_)  # 0.02 * x_)
             # 0.0 * x_) + 2.5  # 0.02 * x_)
diff --git a/demos/studentt.py b/demos/studentt.py
index a4042ee..ef444cb 100644
--- a/demos/studentt.py
+++ b/demos/studentt.py
@@ -9,7 +9,7 @@ def wiggly_time_series(x_):
     noise_var = 0.2  # true observation noise scale
     return (np.cos(0.04*x_+0.33*np.pi) * np.sin(0.2*x_) +
-            np.math.sqrt(noise_var) * np.random.standard_t(3., x_.shape) +
+            np.sqrt(noise_var) * np.random.standard_t(3., x_.shape) +
             0.0 * x_)
diff --git a/experiments/binary/binary.py b/experiments/binary/binary.py
index 2b5669c..58aede7 100644
--- a/experiments/binary/binary.py
+++ b/experiments/binary/binary.py
@@ -11,7 +11,7 @@
 x = np.sort(70 * np.random.rand(N))
 sn = 0.01
 f = lambda x_: 12. * np.sin(4 * np.pi * x_) / (0.25 * np.pi * x_ + 1)
-y_ = f(x) + np.math.sqrt(sn)*np.random.randn(x.shape[0])
+y_ = f(x) + np.sqrt(sn)*np.random.randn(x.shape[0])
 y = np.sign(y_)
 y[y == -1] = 0
diff --git a/requirements.txt b/requirements.txt
index 7b1a2a5..d406e9f 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,7 +1,9 @@
 jax==0.4.14
 jaxlib==0.4.14
 objax==1.7.0
-tensorflow_probability==0.20.1
 numpy
 matplotlib
-scipy
\ No newline at end of file
+scipy
+scikit-learn
+pandas
+tensorflow_probability==0.21
\ No newline at end of file
diff --git a/tests/normaliser_test.py b/tests/normaliser_test.py
index a21c3d5..9c9f1fd 100644
--- a/tests/normaliser_test.py
+++ b/tests/normaliser_test.py
@@ -13,7 +13,7 @@ def wiggly_time_series(x_):
     noise_var = 0.15  # true observation noise
     return (np.cos(0.04*x_+0.33*np.pi) * np.sin(0.2*x_) +
-            np.math.sqrt(noise_var) * np.random.normal(0, 1, x_.shape) +
+            np.sqrt(noise_var) * np.random.normal(0, 1, x_.shape) +
             0.0 * x_)  # 0.02 * x_)
diff --git a/tests/test_gp_vs_markovgp_class.py b/tests/test_gp_vs_markovgp_class.py
index bbfc6c0..3d6365d 100644
--- a/tests/test_gp_vs_markovgp_class.py
+++ b/tests/test_gp_vs_markovgp_class.py
@@ -11,7 +11,7 @@ def build_data(N):
     x = 100 * np.random.rand(N)
     x = np.sort(x)  # since MarkovGP sorts the inputs, they must also be sorted for GP
     f = lambda x_: 6 * np.sin(np.pi * x_ / 10.0) / (np.pi * x_ / 10.0 + 1)
-    y_ = f(x) + np.math.sqrt(0.05) * np.random.randn(x.shape[0])
+    y_ = f(x) + np.sqrt(0.05) * np.random.randn(x.shape[0])
     y = np.sign(y_)
     y[y == -1] = 0
     x = x[:, None]
diff --git a/tests/test_gp_vs_markovgp_reg.py b/tests/test_gp_vs_markovgp_reg.py
index cf1999f..5fb65db 100644
--- a/tests/test_gp_vs_markovgp_reg.py
+++ b/tests/test_gp_vs_markovgp_reg.py
@@ -9,7 +9,7 @@ def wiggly_time_series(x_):
     noise_var = 0.15  # true observation noise
     return (np.cos(0.04*x_+0.33*np.pi) * np.sin(0.2*x_) +
-            np.math.sqrt(noise_var) * np.random.normal(0, 1, x_.shape))
+            np.sqrt(noise_var) * np.random.normal(0, 1, x_.shape))
 
 
 def build_data(N):
diff --git a/tests/test_sparsemarkov.py b/tests/test_sparsemarkov.py
index d55ba7b..91c199b 100644
--- a/tests/test_sparsemarkov.py
+++ b/tests/test_sparsemarkov.py
@@ -12,7 +12,7 @@ def wiggly_time_series(x_):
     noise_var = 0.15  # true observation noise
     return (np.cos(0.04*x_+0.33*np.pi) * np.sin(0.2*x_) +
-            np.math.sqrt(noise_var) * np.random.normal(0, 1, x_.shape))
+            np.sqrt(noise_var) * np.random.normal(0, 1, x_.shape))
 
 
 def build_data(N):
diff --git a/tests/test_vs_exact_marg_lik.py b/tests/test_vs_exact_marg_lik.py
index 10b1e5e..e0b5965 100644
--- a/tests/test_vs_exact_marg_lik.py
+++ b/tests/test_vs_exact_marg_lik.py
@@ -9,7 +9,7 @@ def wiggly_time_series(x_):
     noise_var = 0.15  # true observation noise
     return (np.cos(0.04*x_+0.33*np.pi) * np.sin(0.2*x_) +
-            np.math.sqrt(noise_var) * np.random.normal(0, 1, x_.shape))
+            np.sqrt(noise_var) * np.random.normal(0, 1, x_.shape))
 
 
 def build_data(N):
diff --git a/tests/test_vs_gpflow_class.py b/tests/test_vs_gpflow_class.py
index 391b475..44f5d82 100644
--- a/tests/test_vs_gpflow_class.py
+++ b/tests/test_vs_gpflow_class.py
@@ -15,7 +15,7 @@ def build_data(N):
     # np.random.seed(12345)
     x = 100 * np.random.rand(N)
     f = lambda x_: 6 * np.sin(np.pi * x_ / 10.0) / (np.pi * x_ / 10.0 + 1)
-    y_ = f(x) + np.math.sqrt(0.05) * np.random.randn(x.shape[0])
+    y_ = f(x) + np.sqrt(0.05) * np.random.randn(x.shape[0])
     y = np.sign(y_)
     y[y == -1] = 0
     x = x[:, None]
diff --git a/tests/test_vs_gpflow_reg.py b/tests/test_vs_gpflow_reg.py
index 4be2d58..635a225 100644
--- a/tests/test_vs_gpflow_reg.py
+++ b/tests/test_vs_gpflow_reg.py
@@ -12,7 +12,7 @@ def wiggly_time_series(x_):
     noise_var = 0.15  # true observation noise
     return (np.cos(0.04*x_+0.33*np.pi) * np.sin(0.2*x_) +
-            np.math.sqrt(noise_var) * np.random.normal(0, 1, x_.shape))
+            np.sqrt(noise_var) * np.random.normal(0, 1, x_.shape))
 
 
 def build_data(N):
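
Note (not part of the patch): the kernels.py changes replace expm-based state transitions for the quasi-periodic kernels with closed-form expressions. The closed form for a single subband Matern-3/2 component can be sanity-checked against the matrix exponential it replaces. The standalone sketch below assumes only numpy and scipy are available; the helper names rotation_matrix, feedback_matrix and closed_form_transition are illustrative, not repository functions.

# Sketch: verify that the closed-form subband Matern-3/2 transition equals expm(F * dt),
# where F = kron(F_mat, I2) + kron(I2, F_cos) as in feedback_matrix_subband_matern32.
import numpy as np
from scipy.linalg import expm

def rotation_matrix(dt, omega):
    # 2x2 rotation through angle omega * dt
    return np.array([[np.cos(omega * dt), -np.sin(omega * dt)],
                     [np.sin(omega * dt),  np.cos(omega * dt)]])

def feedback_matrix(ell, omega):
    # continuous-time feedback matrix of the subband Matern-3/2 component
    lam = 3.0 ** 0.5 / ell
    F_mat = np.array([[0.0, 1.0], [-lam ** 2, -2 * lam]])
    F_cos = np.array([[0.0, -omega], [omega, 0.0]])
    return np.kron(F_mat, np.eye(2)) + np.kron(np.eye(2), F_cos)

def closed_form_transition(dt, ell, omega):
    # closed-form A = expm(F dt), as in state_transition_subband_matern32
    lam = 3.0 ** 0.5 / ell
    R = rotation_matrix(dt, omega)
    return np.exp(-dt * lam) * np.block([[(1. + dt * lam) * R, dt * R],
                                         [-dt * lam ** 2 * R, (1. - dt * lam) * R]])

dt, ell, omega = 0.3, 1.7, 2 * np.pi / 5.0
assert np.allclose(expm(feedback_matrix(ell, omega) * dt),
                   closed_form_transition(dt, ell, omega))

The identity holds because kron(F_mat, I2) and kron(I2, F_cos) commute, so the exponential factorises into the Matern-3/2 transition Kronecker-multiplied with a rotation.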