Skip to content

Commit

Permalink
decorators, preequilibration
Browse files Browse the repository at this point in the history
  • Loading branch information
FFroehlich committed Jan 29, 2025
1 parent a9abbee commit 95207d7
Showing 1 changed file with 66 additions and 58 deletions.
124 changes: 66 additions & 58 deletions python/sdist/amici/jax/petab.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,6 +548,12 @@ def update_parameters(self, p: jt.Float[jt.Array, "np"]) -> "JAXProblem":
"""
return eqx.tree_at(lambda p: p.parameters, self, p)

@eqx.filter_vmap(

Check warning on line 551 in python/sdist/amici/jax/petab.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/petab.py#L551

Added line #L551 was not covered by tests
in_axes={
"max_steps": None,
"self": None,
}, # only list arguments here where eqx.is_array(0) is not the right thing
)
def run_simulation(
self,
p: jt.Float[jt.Array, "np"], # noqa: F821, F722
Expand Down Expand Up @@ -605,9 +611,17 @@ def run_simulation(
ret=ret,
)

@eqx.filter_vmap(

Check warning on line 614 in python/sdist/amici/jax/petab.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/petab.py#L614

Added line #L614 was not covered by tests
in_axes={
"max_steps": None,
"self": None,
}, # only list arguments here where eqx.is_array(0) is not the right thing
)
def run_preequilibration(
self,
simulation_condition: str,
p: jt.Float[jt.Array, "np"], # noqa: F821, F722
mask_reinit: jt.Bool[jt.Array, "nx"], # noqa: F821, F722
x_reinit: jt.Float[jt.Array, "nx"], # noqa: F821, F722
solver: diffrax.AbstractSolver,
controller: diffrax.AbstractStepSizeController,
steady_state_event: Callable[
Expand All @@ -629,10 +643,6 @@ def run_preequilibration(
:return:
Pre-equilibration state
"""
p = self.load_parameters(simulation_condition)
mask_reinit, x_reinit = self.load_reinitialisation(
simulation_condition, p
)
return self.model.preequilibrate_condition(
p=p,
mask_reinit=mask_reinit,
Expand Down Expand Up @@ -684,61 +694,43 @@ def run_simulations(
if simulation_conditions is None:
simulation_conditions = problem.get_all_simulation_conditions()

p_array = jnp.stack(
[problem.load_parameters(sc[0]) for sc in simulation_conditions]
)
mask_reinit_array = jnp.stack(
[
problem.load_reinitialisation(sc[0], p)[0]
for sc, p in zip(simulation_conditions, p_array)
]
)
x_reinit_array = jnp.stack(
[
problem.load_reinitialisation(sc[0], p)[1]
for sc, p in zip(simulation_conditions, p_array)
]
dynamic_conditions = [sc[0] for sc in simulation_conditions]
preequilibration_conditions = list(

Check warning on line 698 in python/sdist/amici/jax/petab.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/petab.py#L697-L698

Added lines #L697 - L698 were not covered by tests
{sc[1] for sc in simulation_conditions if len(sc) > 1}
)

preeqs = {
sc: problem.run_preequilibration(
sc, solver, controller, steady_state_event, max_steps
if preequilibration_conditions:
p_array, mask_reinit_array, x_reinit_array = _prepare_conditions(

Check warning on line 703 in python/sdist/amici/jax/petab.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/petab.py#L702-L703

Added lines #L702 - L703 were not covered by tests
problem, preequilibration_conditions
)
# only run preequilibration once per condition
for sc in {sc[1] for sc in simulation_conditions if len(sc) > 1}
}
preeqs, preresults = problem.run_preequilibration(

Check warning on line 706 in python/sdist/amici/jax/petab.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/petab.py#L706

Added line #L706 was not covered by tests
p_array,
mask_reinit_array,
x_reinit_array,
solver,
controller,
steady_state_event,
max_steps,
)
else:
preresults = {

Check warning on line 716 in python/sdist/amici/jax/petab.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/petab.py#L716

Added line #L716 was not covered by tests
"stats_preeq": None,
}

preeq_array = jnp.stack(

Check warning on line 720 in python/sdist/amici/jax/petab.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/petab.py#L720

Added line #L720 was not covered by tests
[
preeqs[sc[1]][0] if len(sc) > 1 else jnp.array([])
preeqs[preequilibration_conditions.index(sc[1]), :]
if len(sc) > 1
else jnp.array([])
for sc in simulation_conditions
]
)

parallel_run_simulation = eqx.filter_vmap(
JAXProblem.run_simulation,
in_axes=(
None, # problem
0, # p
0, # ts_dyn
0, # ts_posteq
0, # my
0, # iys
0, # iy_trafos
0, # mask_reinit
0, # x_reinit
None, # solver
None, # controller
None, # steady_state_event
None, # max_steps
0, # preeq_array
0, # ts_masks
None, # ret
),
### simulation
p_array, mask_reinit_array, x_reinit_array = _prepare_conditions(

Check warning on line 730 in python/sdist/amici/jax/petab.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/petab.py#L730

Added line #L730 was not covered by tests
problem, dynamic_conditions
)

results = parallel_run_simulation(
problem,
output, results = problem.run_simulation(

Check warning on line 733 in python/sdist/amici/jax/petab.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/petab.py#L733

Added line #L733 was not covered by tests
p_array,
problem._ts_dyn,
problem._ts_posteq,
Expand All @@ -757,17 +749,33 @@ def run_simulations(
)

if ret in (ReturnValue.llh, ReturnValue.chi2):
output = jnp.sum(results[0])
else:
output = results[0]
output = jnp.sum(output)

Check warning on line 752 in python/sdist/amici/jax/petab.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/petab.py#L751-L752

Added lines #L751 - L752 were not covered by tests

return output, {

Check warning on line 754 in python/sdist/amici/jax/petab.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/petab.py#L754

Added line #L754 was not covered by tests
"preequilibration_conditions": preequilibration_conditions,
"dynamic_conditions": dynamic_conditions,
"dynamic": results["stats_dyn"],
"posteq": results["stats_posteq"],
"preeq": preresults["stats_preeq"],
}

stats = (
results[1]["stats_dyn"] | results[1]["stats_posteq"]
if results[1]["stats_posteq"]
else results[1]["stats_dyn"]
)

return output, stats
def _prepare_conditions(problem: JAXProblem, conditions: Iterable[str]):
p_array = jnp.stack([problem.load_parameters(sc) for sc in conditions])

Check warning on line 764 in python/sdist/amici/jax/petab.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/petab.py#L763-L764

Added lines #L763 - L764 were not covered by tests

mask_reinit_array = jnp.stack(

Check warning on line 766 in python/sdist/amici/jax/petab.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/petab.py#L766

Added line #L766 was not covered by tests
[
problem.load_reinitialisation(sc, p)[0]
for sc, p in zip(conditions, p_array)
]
)
x_reinit_array = jnp.stack(

Check warning on line 772 in python/sdist/amici/jax/petab.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/petab.py#L772

Added line #L772 was not covered by tests
[
problem.load_reinitialisation(sc, p)[1]
for sc, p in zip(conditions, p_array)
]
)
return p_array, mask_reinit_array, x_reinit_array

Check warning on line 778 in python/sdist/amici/jax/petab.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/petab.py#L778

Added line #L778 was not covered by tests


def petab_simulate(
Expand Down

0 comments on commit 95207d7

Please sign in to comment.