Skip to content

Commit

Permalink
fix vectorisation
Browse files Browse the repository at this point in the history
  • Loading branch information
FFroehlich committed Jan 29, 2025
1 parent 57f5ed5 commit a9abbee
Show file tree
Hide file tree
Showing 3 changed files with 206 additions and 72 deletions.
12 changes: 8 additions & 4 deletions python/sdist/amici/jax/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,7 @@ def simulate_condition(
x_preeq: jt.Float[jt.Array, "*nx"] = jnp.array([]),
mask_reinit: jt.Bool[jt.Array, "*nx"] = jnp.array([]),
x_reinit: jt.Float[jt.Array, "*nx"] = jnp.array([]),
ts_mask: jt.Bool[jt.Array, "nt"] = jnp.array([]),
ret: ReturnValue = ReturnValue.llh,
) -> tuple[jt.Float[jt.Array, "nt *nx"] | jnp.float_, dict]:
r"""
Expand Down Expand Up @@ -522,6 +523,9 @@ def simulate_condition(
else:
x = self._x0(p)

if not ts_mask.shape[0]:
ts_mask = jnp.zeros_like(my, dtype=jnp.bool_)

Check warning on line 527 in python/sdist/amici/jax/model.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/model.py#L526-L527

Added lines #L526 - L527 were not covered by tests

# Re-initialization
if x_reinit.shape[0]:
x = jnp.where(mask_reinit, x_reinit, x)
Expand Down Expand Up @@ -566,9 +570,11 @@ def simulate_condition(
)

ts = jnp.concatenate((ts_dyn, ts_posteq), axis=0)

x = jnp.concatenate((x_dyn, x_posteq), axis=0)

nllhs = self._nllhs(ts, x, p, tcl, my, iys)
nllhs = jnp.where(ts_mask, nllhs, 0.0)

Check warning on line 577 in python/sdist/amici/jax/model.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/model.py#L577

Added line #L577 was not covered by tests
llh = -jnp.sum(nllhs)

stats = dict(
Expand Down Expand Up @@ -608,10 +614,8 @@ def simulate_condition(
ys_obj = obs_trafo(self._ys(ts, x, p, tcl, iys), iy_trafos)
m_obj = obs_trafo(my, iy_trafos)
if ret == ReturnValue.chi2:
output = jnp.sum(
jnp.square(ys_obj - m_obj)
/ jnp.square(self._sigmays(ts, x, p, tcl, iys))
)
sigma_obj = self._sigmays(ts, x, p, tcl, iys)
output = jnp.sum(jnp.square((ys_obj - m_obj) / sigma_obj))

Check warning on line 618 in python/sdist/amici/jax/model.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/model.py#L617-L618

Added lines #L617 - L618 were not covered by tests
else:
output = ys_obj - m_obj
else:
Expand Down
256 changes: 198 additions & 58 deletions python/sdist/amici/jax/petab.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,17 +81,13 @@ class JAXProblem(eqx.Module):
parameters: jnp.ndarray
model: JAXModel
_parameter_mappings: dict[str, ParameterMappingForCondition]
_measurements: dict[
tuple[str, ...],
tuple[
np.ndarray,
np.ndarray,
np.ndarray,
np.ndarray,
np.ndarray,
],
]
_petab_measurement_indices: dict[tuple[str, ...], tuple[int, ...]]
_ts_dyn: np.ndarray
_ts_posteq: np.ndarray
_my: np.ndarray
_iys: np.ndarray
_iy_trafos: np.ndarray
_ts_masks: np.ndarray
_petab_measurement_indices: np.ndarray

Check warning on line 90 in python/sdist/amici/jax/petab.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/petab.py#L84-L90

Added lines #L84 - L90 were not covered by tests
_petab_problem: petab.Problem

def __init__(self, model: JAXModel, petab_problem: petab.Problem):
Expand All @@ -107,9 +103,16 @@ def __init__(self, model: JAXModel, petab_problem: petab.Problem):
scs = petab_problem.get_simulation_conditions_from_measurement_df()
self._petab_problem = petab_problem
self._parameter_mappings = self._get_parameter_mappings(scs)
self._measurements, self._petab_measurement_indices = (
self._get_measurements(scs)
)
(

Check warning on line 106 in python/sdist/amici/jax/petab.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/petab.py#L106

Added line #L106 was not covered by tests
self._ts_dyn,
self._ts_posteq,
self._my,
self._iys,
self._iy_trafos,
self._ts_masks,
self._petab_measurement_indices,
) = self._get_measurements(scs)

self.parameters = self._get_nominal_parameter_values()

def save(self, directory: Path):
Expand Down Expand Up @@ -180,17 +183,13 @@ def _get_parameter_mappings(
def _get_measurements(
self, simulation_conditions: pd.DataFrame
) -> tuple[
dict[
tuple[str, ...],
tuple[
np.ndarray,
np.ndarray,
np.ndarray,
np.ndarray,
np.ndarray,
],
],
dict[tuple[str, ...], tuple[int, ...]],
np.ndarray,
np.ndarray,
np.ndarray,
np.ndarray,
np.ndarray,
np.ndarray,
np.ndarray,
]:
"""
Get measurements for the model based on the provided simulation conditions.
Expand All @@ -199,11 +198,17 @@ def _get_measurements(
Simulation conditions to create parameter mappings for. Same format as returned by
:meth:`petab.Problem.get_simulation_conditions_from_measurement_df`.
:return:
Dictionary mapping simulation conditions to measurements (tuple of pre-equilibrium, dynamic,
post-equilibrium time points; measurements and observable indices).
tuple of padded
- dynamic time points
- post-equilibrium time points
- measurements
- observable indices
- observable transformations indices
- measurement masks
- data indices (index in petab measurement dataframe).
"""
measurements = dict()
indices = dict()
petab_indices = dict()

Check warning on line 211 in python/sdist/amici/jax/petab.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/petab.py#L211

Added line #L211 was not covered by tests
for _, simulation_condition in simulation_conditions.iterrows():
query = " & ".join(
[f"{k} == '{v}'" for k, v in simulation_condition.items()]
Expand Down Expand Up @@ -249,8 +254,89 @@ def _get_measurements(
iys,
iy_trafos,
)
indices[tuple(simulation_condition)] = tuple(index.tolist())
return measurements, indices
petab_indices[tuple(simulation_condition)] = tuple(index.tolist())

Check warning on line 257 in python/sdist/amici/jax/petab.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/petab.py#L257

Added line #L257 was not covered by tests

# compute maximum lengths
n_ts_dyn = max(

Check warning on line 260 in python/sdist/amici/jax/petab.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/petab.py#L260

Added line #L260 was not covered by tests
len(ts_dyn) for ts_dyn, _, _, _, _ in measurements.values()
)
n_ts_posteq = max(

Check warning on line 263 in python/sdist/amici/jax/petab.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/petab.py#L263

Added line #L263 was not covered by tests
len(ts_posteq) for _, ts_posteq, _, _, _ in measurements.values()
)

# pad with last value and stack
ts_dyn = np.stack(

Check warning on line 268 in python/sdist/amici/jax/petab.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/petab.py#L268

Added line #L268 was not covered by tests
[
np.pad(x, (0, n_ts_dyn - len(x)), mode="edge")
for x, _, _, _, _ in measurements.values()
]
)
ts_posteq = np.stack(

Check warning on line 274 in python/sdist/amici/jax/petab.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/petab.py#L274

Added line #L274 was not covered by tests
[
np.pad(x, (0, n_ts_posteq - len(x)), mode="edge")
for _, x, _, _, _ in measurements.values()
]
)

def pad_measurement(x_dyn, x_peq, n_ts_dyn, n_ts_posteq):
return np.concatenate(

Check warning on line 282 in python/sdist/amici/jax/petab.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/petab.py#L281-L282

Added lines #L281 - L282 were not covered by tests
(
np.pad(x_dyn, (0, n_ts_dyn - len(x_dyn)), mode="edge"),
np.pad(x_peq, (0, n_ts_posteq - len(x_peq)), mode="edge"),
)
)

my = np.stack(

Check warning on line 289 in python/sdist/amici/jax/petab.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/petab.py#L289

Added line #L289 was not covered by tests
[
pad_measurement(
x[: len(tdyn)], x[len(tdyn) :], n_ts_dyn, n_ts_posteq
)
for tdyn, tpeq, x, _, _ in measurements.values()
]
)
iys = np.stack(

Check warning on line 297 in python/sdist/amici/jax/petab.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/petab.py#L297

Added line #L297 was not covered by tests
[
pad_measurement(
x[: len(tdyn)], x[len(tdyn) :], n_ts_dyn, n_ts_posteq
)
for tdyn, tpeq, _, x, _ in measurements.values()
]
)
iy_trafos = np.stack(

Check warning on line 305 in python/sdist/amici/jax/petab.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/petab.py#L305

Added line #L305 was not covered by tests
[
pad_measurement(
x[: len(tdyn)], x[len(tdyn) :], n_ts_dyn, n_ts_posteq
)
for tdyn, tpeq, _, _, x in measurements.values()
]
)
ts_masks = np.stack(

Check warning on line 313 in python/sdist/amici/jax/petab.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/petab.py#L313

Added line #L313 was not covered by tests
[
np.concatenate(
(
np.pad(np.ones_like(tdyn), (0, n_ts_dyn - len(tdyn))),
np.pad(
np.ones_like(tpeq), (0, n_ts_posteq - len(tpeq))
),
)
)
for tdyn, tpeq, _, _, _ in measurements.values()
]
).astype(bool)
petab_indices = np.stack(

Check warning on line 326 in python/sdist/amici/jax/petab.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/petab.py#L326

Added line #L326 was not covered by tests
[
pad_measurement(
idx[: len(tdyn)], idx[len(tdyn) :], n_ts_dyn, n_ts_posteq
)
for (tdyn, tpeq, _, _, _), idx in zip(
measurements.values(), petab_indices.values()
)
]
)

# create mask for my

return ts_dyn, ts_posteq, my, iys, iy_trafos, ts_masks, petab_indices

Check warning on line 339 in python/sdist/amici/jax/petab.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/petab.py#L339

Added line #L339 was not covered by tests

def get_all_simulation_conditions(self) -> tuple[tuple[str, ...], ...]:
simulation_conditions = (
Expand Down Expand Up @@ -464,21 +550,27 @@ def update_parameters(self, p: jt.Float[jt.Array, "np"]) -> "JAXProblem":

def run_simulation(
self,
simulation_condition: tuple[str, ...],
p: jt.Float[jt.Array, "np"], # noqa: F821, F722
ts_dyn: np.ndarray,
ts_posteq: np.ndarray,
my: np.ndarray,
iys: np.ndarray,
iy_trafos: np.ndarray,
mask_reinit: jt.Bool[jt.Array, "nx"], # noqa: F821, F722
x_reinit: jt.Float[jt.Array, "nx"], # noqa: F821, F722
solver: diffrax.AbstractSolver,
controller: diffrax.AbstractStepSizeController,
steady_state_event: Callable[
..., diffrax._custom_types.BoolScalarLike
],
max_steps: jnp.int_,
x_preeq: jt.Float[jt.Array, "*nx"] = jnp.array([]), # noqa: F821, F722
ts_mask: np.ndarray = np.array([]),
ret: ReturnValue = ReturnValue.llh,
) -> tuple[jnp.float_, dict]:
"""
Run a simulation for a given simulation condition.
:param simulation_condition:
Simulation condition to run simulation for.
:param solver:
ODE solver to use for simulation
:param controller:
Expand All @@ -492,13 +584,6 @@ def run_simulation(
:return:
Tuple of output value and simulation statistics
"""
ts_dyn, ts_posteq, my, iys, iy_trafos = self._measurements[
simulation_condition
]
p = self.load_parameters(simulation_condition[0])
mask_reinit, x_reinit = self.load_reinitialisation(
simulation_condition[0], p
)
return self.model.simulate_condition(
p=p,
ts_dyn=jax.lax.stop_gradient(jnp.array(ts_dyn)),
Expand All @@ -509,6 +594,7 @@ def run_simulation(
x_preeq=x_preeq,
mask_reinit=jax.lax.stop_gradient(mask_reinit),
x_reinit=x_reinit,
ts_mask=jax.lax.stop_gradient(jnp.array(ts_mask)),
solver=solver,
controller=controller,
max_steps=max_steps,
Expand Down Expand Up @@ -598,6 +684,22 @@ def run_simulations(
if simulation_conditions is None:
simulation_conditions = problem.get_all_simulation_conditions()

p_array = jnp.stack(

Check warning on line 687 in python/sdist/amici/jax/petab.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/petab.py#L687

Added line #L687 was not covered by tests
[problem.load_parameters(sc[0]) for sc in simulation_conditions]
)
mask_reinit_array = jnp.stack(

Check warning on line 690 in python/sdist/amici/jax/petab.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/petab.py#L690

Added line #L690 was not covered by tests
[
problem.load_reinitialisation(sc[0], p)[0]
for sc, p in zip(simulation_conditions, p_array)
]
)
x_reinit_array = jnp.stack(

Check warning on line 696 in python/sdist/amici/jax/petab.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/petab.py#L696

Added line #L696 was not covered by tests
[
problem.load_reinitialisation(sc[0], p)[1]
for sc, p in zip(simulation_conditions, p_array)
]
)

preeqs = {
sc: problem.run_preequilibration(
sc, solver, controller, steady_state_event, max_steps
Expand All @@ -606,26 +708,64 @@ def run_simulations(
for sc in {sc[1] for sc in simulation_conditions if len(sc) > 1}
}

results = {
sc: problem.run_simulation(
sc,
solver,
controller,
steady_state_event,
max_steps,
preeqs.get(sc[1])[0] if len(sc) > 1 else jnp.array([]),
ret=ret,
)
for sc in simulation_conditions
}
stats = {
sc: res[1] | preeqs[sc[1]][1] if len(sc) > 1 else res[1]
for sc, res in results.items()
}
preeq_array = jnp.stack(

Check warning on line 711 in python/sdist/amici/jax/petab.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/petab.py#L711

Added line #L711 was not covered by tests
[
preeqs[sc[1]][0] if len(sc) > 1 else jnp.array([])
for sc in simulation_conditions
]
)

parallel_run_simulation = eqx.filter_vmap(

Check warning on line 718 in python/sdist/amici/jax/petab.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/petab.py#L718

Added line #L718 was not covered by tests
JAXProblem.run_simulation,
in_axes=(
None, # problem
0, # p
0, # ts_dyn
0, # ts_posteq
0, # my
0, # iys
0, # iy_trafos
0, # mask_reinit
0, # x_reinit
None, # solver
None, # controller
None, # steady_state_event
None, # max_steps
0, # preeq_array
0, # ts_masks
None, # ret
),
)

results = parallel_run_simulation(

Check warning on line 740 in python/sdist/amici/jax/petab.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/petab.py#L740

Added line #L740 was not covered by tests
problem,
p_array,
problem._ts_dyn,
problem._ts_posteq,
problem._my,
problem._iys,
problem._iy_trafos,
mask_reinit_array,
x_reinit_array,
solver,
controller,
steady_state_event,
max_steps,
preeq_array,
problem._ts_masks,
ret,
)

if ret in (ReturnValue.llh, ReturnValue.chi2):
output = sum(r for r, _ in results.values())
output = jnp.sum(results[0])

Check warning on line 760 in python/sdist/amici/jax/petab.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/petab.py#L760

Added line #L760 was not covered by tests
else:
output = {sc: res[0] for sc, res in results.items()}
output = results[0]

Check warning on line 762 in python/sdist/amici/jax/petab.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/petab.py#L762

Added line #L762 was not covered by tests

stats = (

Check warning on line 764 in python/sdist/amici/jax/petab.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/petab.py#L764

Added line #L764 was not covered by tests
results[1]["stats_dyn"] | results[1]["stats_posteq"]
if results[1]["stats_posteq"]
else results[1]["stats_dyn"]
)

return output, stats

Expand Down
Loading

0 comments on commit a9abbee

Please sign in to comment.