Skip to content

Commit

Permalink
Merge pull request #34 from ThoreWietzke/AaltoML-main
Browse files Browse the repository at this point in the history
Removed numba requirement und bumped jax to 0.4.14 for native windows support
  • Loading branch information
wil-j-wil authored Dec 22, 2023
2 parents 0281955 + d27991a commit f72ae9a
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 44 deletions.
23 changes: 12 additions & 11 deletions bayesnewton/inference.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import objax
import jax.numpy as np
from jax import vmap
from jax import vmap, Array
from .utils import (
diag,
transpose,
Expand Down Expand Up @@ -48,9 +48,10 @@ class InferenceMixin(abc.ABC):
TODO: re-derive and re-implement QuasiNewton methods
TODO: move as much of the generic functionality as possible from the base model class to this class.
"""

num_data: float
Y: np.DeviceArray
ind: np.DeviceArray
Y: Array
ind: Array
pseudo_likelihood: GaussianDistribution
posterior_mean: objax.StateVar
posterior_var: objax.StateVar
Expand Down Expand Up @@ -231,8 +232,8 @@ class ExpectationPropagation(InferenceMixin):
compute_full_pseudo_lik: classmethod
compute_log_lik: classmethod
compute_ep_energy_terms: classmethod
mask_y: np.DeviceArray
mask_pseudo_y: np.DeviceArray
mask_y: Array
mask_pseudo_y: Array

def update_variational_params(self, batch_ind=None, lr=1., cubature=None, ensure_psd=True, **kwargs):
"""
Expand Down Expand Up @@ -332,8 +333,8 @@ class PosteriorLinearisation(InferenceMixin):
# TODO: remove these when possible
cavity_distribution: classmethod
compute_full_pseudo_lik: classmethod
mask_y: np.DeviceArray
mask_pseudo_y: np.DeviceArray
mask_y: Array
mask_pseudo_y: Array

def update_variational_params(self, batch_ind=None, lr=1., cubature=None, **kwargs):
"""
Expand Down Expand Up @@ -432,8 +433,8 @@ class PosteriorLinearisation2ndOrder(PosteriorLinearisation):
"""
# TODO: remove these when possible
compute_full_pseudo_lik: classmethod
mask_y: np.DeviceArray
mask_pseudo_y: np.DeviceArray
mask_y: Array
mask_pseudo_y: Array

def update_variational_params(self, batch_ind=None, lr=1., cubature=None, ensure_psd=True, **kwargs):
"""
Expand Down Expand Up @@ -574,8 +575,8 @@ class PosteriorLinearisation2ndOrderGaussNewton(PosteriorLinearisation):
"""
# TODO: remove these when possible
compute_full_pseudo_lik: classmethod
mask_y: np.DeviceArray
mask_pseudo_y: np.DeviceArray
mask_y: Array
mask_pseudo_y: Array

def update_variational_params(self, batch_ind=None, lr=1., cubature=None, **kwargs):
"""
Expand Down
29 changes: 0 additions & 29 deletions bayesnewton/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,6 @@
from jax.lax import scan
# from matplotlib._png import read_png
import math
from functools import partial
import numba as nb
import numpy as onp

LOG2PI = math.log(2 * math.pi)
INV2PI = (2 * math.pi) ** -1
Expand Down Expand Up @@ -520,7 +517,6 @@ def gaussian_expected_log_lik(Y, q_mu, q_covar, noise, mask=None):
:return:
E[log 𝓝(yₙ|fₙ,σ²)] = ∫ log 𝓝(yₙ|fₙ,σ²) 𝓝(fₙ|mₙ,vₙ) dfₙ
"""

if mask is not None:
# build a mask for computing the log likelihood of a partially observed multivariate Gaussian
maskv = mask.reshape(-1, 1)
Expand Down Expand Up @@ -633,31 +629,6 @@ def rotation_matrix(dt, omega):
return R


@partial(nb.jit, nopython=True)
def nb_balance_ss(F: onp.ndarray,
iters: int) -> onp.ndarray:
"""
taken from https://github.com/EEA-sensors/parallel-gps/blob/main/pssgp/kernels/math_utils.py
"""
dim = F.shape[0]
dtype = F.dtype
d = onp.ones((dim,), dtype=dtype)
for k in range(iters):
for i in range(dim):
tmp = onp.copy(F[:, i])
tmp[i] = 0.
c = onp.linalg.norm(tmp, 2)
tmp2 = onp.copy(F[i, :])
tmp2[i] = 0.

r = onp.linalg.norm(tmp2, 2)
f = onp.sqrt(r / c)
d[i] *= f
F[:, i] *= f
F[i, :] /= f
return d


def balance(F: np.ndarray,
iters: int) -> np.ndarray:
dim = F.shape[0]
Expand Down
7 changes: 3 additions & 4 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
jax==0.4.2
jaxlib==0.4.2
objax==1.6.0
numba
jax==0.4.14
jaxlib==0.4.14
objax==1.7.0
numpy
matplotlib
scipy
Expand Down
7 changes: 7 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,13 @@
version=__version__,
packages=find_packages(),
python_requires='>=3.6',
install_requires=[
"jax==0.4.14",
"jaxlib==0.4.14",
"objax==1.7.0",
"tensorflow_probability==0.21",
"numpy>=1.22"
],
url='https://github.com/AaltoML/BayesNewton',
license='Apache-2.0',
)

0 comments on commit f72ae9a

Please sign in to comment.