Skip to content

Commit

Permalink
Change jax.tree.map() to jax.tree_util.tree_map for jax version compa…
Browse files Browse the repository at this point in the history
…tibility
  • Loading branch information
ThibeauWouters committed Dec 19, 2024
1 parent 8fe9cb6 commit 6ed815f
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions src/fiesta/inference/likelihood.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,24 +136,24 @@ def evaluate(self,

theta = {**theta, **self.fixed_params}
mag_abs: dict[str, Array] = self.model.predict(theta)
mag_app = jax.tree.map(lambda x: mag_app_from_mag_abs(x, theta["luminosity_distance"]),
mag_app = jax.tree_util.tree_map(lambda x: mag_app_from_mag_abs(x, theta["luminosity_distance"]),
mag_abs)

# Interpolate the mags to the times of interest
mag_est_det = jax.tree.map(lambda t, m: jnp.interp(t, self.model.times, m),
mag_est_det = jax.tree_util.tree_map(lambda t, m: jnp.interp(t, self.model.times, m),
self.times_det, mag_app)

mag_est_nondet = jax.tree.map(lambda t, m: jnp.interp(t, self.model.times, m),
mag_est_nondet = jax.tree_util.tree_map(lambda t, m: jnp.interp(t, self.model.times, m),
self.times_nondet, mag_app)

# Get chisq
chisq = jax.tree.map(self.get_chisq_filt,
chisq = jax.tree_util.tree_map(self.get_chisq_filt,
mag_est_det, self.mag_det, self.sigma, self.detection_limit)
chisq_flatten, _ = jax.flatten_util.ravel_pytree(chisq)
chisq_total = jnp.sum(chisq_flatten).astype(jnp.float64)

# Get gaussprob:
gaussprob = jax.tree.map(self.get_gaussprob_filt,
gaussprob = jax.tree_util.tree_map(self.get_gaussprob_filt,
mag_est_nondet, self.mag_nondet, self.error_budget)
gaussprob_flatten, _ = jax.flatten_util.ravel_pytree(gaussprob)
gaussprob_total = jnp.sum(gaussprob_flatten).astype(jnp.float64)
Expand Down

0 comments on commit 6ed815f

Please sign in to comment.