Skip to content

Commit

Permalink
changes for jax 0.4.14
Browse files Browse the repository at this point in the history
  • Loading branch information
ThoreWietzke committed Aug 29, 2023
1 parent dafe496 commit 339e986
Showing 1 changed file with 12 additions and 11 deletions.
23 changes: 12 additions & 11 deletions bayesnewton/inference.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import objax
import jax.numpy as np
from jax import vmap
from jax import vmap, Array
from .utils import (
diag,
transpose,
Expand Down Expand Up @@ -48,9 +48,10 @@ class InferenceMixin(abc.ABC):
TODO: re-derive and re-implement QuasiNewton methods
TODO: move as much of the generic functionality as possible from the base model class to this class.
"""

num_data: float
Y: np.DeviceArray
ind: np.DeviceArray
Y: Array
ind: Array
pseudo_likelihood: GaussianDistribution
posterior_mean: objax.StateVar
posterior_var: objax.StateVar
Expand Down Expand Up @@ -231,8 +232,8 @@ class ExpectationPropagation(InferenceMixin):
compute_full_pseudo_lik: classmethod
compute_log_lik: classmethod
compute_ep_energy_terms: classmethod
mask_y: np.DeviceArray
mask_pseudo_y: np.DeviceArray
mask_y: Array
mask_pseudo_y: Array

def update_variational_params(self, batch_ind=None, lr=1., cubature=None, ensure_psd=True, **kwargs):
"""
Expand Down Expand Up @@ -332,8 +333,8 @@ class PosteriorLinearisation(InferenceMixin):
# TODO: remove these when possible
cavity_distribution: classmethod
compute_full_pseudo_lik: classmethod
mask_y: np.DeviceArray
mask_pseudo_y: np.DeviceArray
mask_y: Array
mask_pseudo_y: Array

def update_variational_params(self, batch_ind=None, lr=1., cubature=None, **kwargs):
"""
Expand Down Expand Up @@ -432,8 +433,8 @@ class PosteriorLinearisation2ndOrder(PosteriorLinearisation):
"""
# TODO: remove these when possible
compute_full_pseudo_lik: classmethod
mask_y: np.DeviceArray
mask_pseudo_y: np.DeviceArray
mask_y: Array
mask_pseudo_y: Array

def update_variational_params(self, batch_ind=None, lr=1., cubature=None, ensure_psd=True, **kwargs):
"""
Expand Down Expand Up @@ -574,8 +575,8 @@ class PosteriorLinearisation2ndOrderGaussNewton(PosteriorLinearisation):
"""
# TODO: remove these when possible
compute_full_pseudo_lik: classmethod
mask_y: np.DeviceArray
mask_pseudo_y: np.DeviceArray
mask_y: Array
mask_pseudo_y: Array

def update_variational_params(self, batch_ind=None, lr=1., cubature=None, **kwargs):
"""
Expand Down

0 comments on commit 339e986

Please sign in to comment.