diff --git a/src/fiesta/inference/likelihood.py b/src/fiesta/inference/likelihood.py index 75023ec..8e7caf8 100644 --- a/src/fiesta/inference/likelihood.py +++ b/src/fiesta/inference/likelihood.py @@ -136,24 +136,24 @@ def evaluate(self, theta = {**theta, **self.fixed_params} mag_abs: dict[str, Array] = self.model.predict(theta) - mag_app = jax.tree.map(lambda x: mag_app_from_mag_abs(x, theta["luminosity_distance"]), + mag_app = jax.tree_util.tree_map(lambda x: mag_app_from_mag_abs(x, theta["luminosity_distance"]), mag_abs) # Interpolate the mags to the times of interest - mag_est_det = jax.tree.map(lambda t, m: jnp.interp(t, self.model.times, m), + mag_est_det = jax.tree_util.tree_map(lambda t, m: jnp.interp(t, self.model.times, m), self.times_det, mag_app) - mag_est_nondet = jax.tree.map(lambda t, m: jnp.interp(t, self.model.times, m), + mag_est_nondet = jax.tree_util.tree_map(lambda t, m: jnp.interp(t, self.model.times, m), self.times_nondet, mag_app) # Get chisq - chisq = jax.tree.map(self.get_chisq_filt, + chisq = jax.tree_util.tree_map(self.get_chisq_filt, mag_est_det, self.mag_det, self.sigma, self.detection_limit) chisq_flatten, _ = jax.flatten_util.ravel_pytree(chisq) chisq_total = jnp.sum(chisq_flatten).astype(jnp.float64) # Get gaussprob: - gaussprob = jax.tree.map(self.get_gaussprob_filt, + gaussprob = jax.tree_util.tree_map(self.get_gaussprob_filt, mag_est_nondet, self.mag_nondet, self.error_budget) gaussprob_flatten, _ = jax.flatten_util.ravel_pytree(gaussprob) gaussprob_total = jnp.sum(gaussprob_flatten).astype(jnp.float64)