diff --git a/python/sdist/amici/jax/petab.py b/python/sdist/amici/jax/petab.py index cfc3a3d8ca..24f97c000a 100644 --- a/python/sdist/amici/jax/petab.py +++ b/python/sdist/amici/jax/petab.py @@ -548,6 +548,12 @@ def update_parameters(self, p: jt.Float[jt.Array, "np"]) -> "JAXProblem": """ return eqx.tree_at(lambda p: p.parameters, self, p) + @eqx.filter_vmap( + in_axes={ + "max_steps": None, + "self": None, + }, # only list arguments here where eqx.is_array(0) is not the right thing + ) def run_simulation( self, p: jt.Float[jt.Array, "np"], # noqa: F821, F722 @@ -605,9 +611,17 @@ def run_simulation( ret=ret, ) + @eqx.filter_vmap( + in_axes={ + "max_steps": None, + "self": None, + }, # only list arguments here where eqx.is_array(0) is not the right thing + ) def run_preequilibration( self, - simulation_condition: str, + p: jt.Float[jt.Array, "np"], # noqa: F821, F722 + mask_reinit: jt.Bool[jt.Array, "nx"], # noqa: F821, F722 + x_reinit: jt.Float[jt.Array, "nx"], # noqa: F821, F722 solver: diffrax.AbstractSolver, controller: diffrax.AbstractStepSizeController, steady_state_event: Callable[ @@ -629,10 +643,6 @@ def run_preequilibration( :return: Pre-equilibration state """ - p = self.load_parameters(simulation_condition) - mask_reinit, x_reinit = self.load_reinitialisation( - simulation_condition, p - ) return self.model.preequilibrate_condition( p=p, mask_reinit=mask_reinit, @@ -684,61 +694,43 @@ def run_simulations( if simulation_conditions is None: simulation_conditions = problem.get_all_simulation_conditions() - p_array = jnp.stack( - [problem.load_parameters(sc[0]) for sc in simulation_conditions] - ) - mask_reinit_array = jnp.stack( - [ - problem.load_reinitialisation(sc[0], p)[0] - for sc, p in zip(simulation_conditions, p_array) - ] - ) - x_reinit_array = jnp.stack( - [ - problem.load_reinitialisation(sc[0], p)[1] - for sc, p in zip(simulation_conditions, p_array) - ] + dynamic_conditions = [sc[0] for sc in simulation_conditions] + preequilibration_conditions = list( + {sc[1] for sc in simulation_conditions if len(sc) > 1} ) - preeqs = { - sc: problem.run_preequilibration( - sc, solver, controller, steady_state_event, max_steps + if preequilibration_conditions: + p_array, mask_reinit_array, x_reinit_array = _prepare_conditions( + problem, preequilibration_conditions ) - # only run preequilibration once per condition - for sc in {sc[1] for sc in simulation_conditions if len(sc) > 1} - } + preeqs, preresults = problem.run_preequilibration( + p_array, + mask_reinit_array, + x_reinit_array, + solver, + controller, + steady_state_event, + max_steps, + ) + else: + preresults = { + "stats_preeq": None, + } preeq_array = jnp.stack( [ - preeqs[sc[1]][0] if len(sc) > 1 else jnp.array([]) + preeqs[preequilibration_conditions.index(sc[1]), :] + if len(sc) > 1 + else jnp.array([]) for sc in simulation_conditions ] ) - parallel_run_simulation = eqx.filter_vmap( - JAXProblem.run_simulation, - in_axes=( - None, # problem - 0, # p - 0, # ts_dyn - 0, # ts_posteq - 0, # my - 0, # iys - 0, # iy_trafos - 0, # mask_reinit - 0, # x_reinit - None, # solver - None, # controller - None, # steady_state_event - None, # max_steps - 0, # preeq_array - 0, # ts_masks - None, # ret - ), + ### simulation + p_array, mask_reinit_array, x_reinit_array = _prepare_conditions( + problem, dynamic_conditions ) - - results = parallel_run_simulation( - problem, + output, results = problem.run_simulation( p_array, problem._ts_dyn, problem._ts_posteq, @@ -757,17 +749,33 @@ def run_simulations( ) if ret in (ReturnValue.llh, ReturnValue.chi2): - output = jnp.sum(results[0]) - else: - output = results[0] + output = jnp.sum(output) + + return output, { + "preequilibration_conditions": preequilibration_conditions, + "dynamic_conditions": dynamic_conditions, + "dynamic": results["stats_dyn"], + "posteq": results["stats_posteq"], + "preeq": preresults["stats_preeq"], + } - stats = ( - results[1]["stats_dyn"] | results[1]["stats_posteq"] - if results[1]["stats_posteq"] - else results[1]["stats_dyn"] - ) - return output, stats +def _prepare_conditions(problem: JAXProblem, conditions: Iterable[str]): + p_array = jnp.stack([problem.load_parameters(sc) for sc in conditions]) + + mask_reinit_array = jnp.stack( + [ + problem.load_reinitialisation(sc, p)[0] + for sc, p in zip(conditions, p_array) + ] + ) + x_reinit_array = jnp.stack( + [ + problem.load_reinitialisation(sc, p)[1] + for sc, p in zip(conditions, p_array) + ] + ) + return p_array, mask_reinit_array, x_reinit_array def petab_simulate(