Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add the choice for an integrator of the solution #405

Merged
merged 13 commits into from
Nov 12, 2021
2 changes: 1 addition & 1 deletion bioptim/misc/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ class SolutionIntegrator(Enum):
Selection of integrator to use integrate function
"""

DEFAULT = None
DEFAULT = ""
SCIPY_RK23 = "RK23"
SCIPY_RK45 = "RK45"
SCIPY_DOP853 = "DOP853"
Expand Down
4 changes: 1 addition & 3 deletions bioptim/optimization/solution.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,7 +561,7 @@ def __perform_integration(
if shooting_type != Shooting.MULTIPLE:
raise RuntimeError(
"Integration with direct collocation must using shooting_type=Shooting.MULTIPLE "
"if not a scipy integrator is not used"
"if a scipy integrator is not used"
)

# Copy the data
Expand Down Expand Up @@ -623,8 +623,6 @@ def __perform_integration(
x0,
t_eval=t_eval,
method=integrator.value,
atol=1e-10,
rtol=1e-10,
).y

next_state_col = (
Expand Down
33 changes: 17 additions & 16 deletions tests/test_simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,7 +480,7 @@ def test_integrate_non_continuous(shooting, merge, integrator, ode_solver):
if (
ode_solver == OdeSolver.COLLOCATION
and shooting != Shooting.MULTIPLE
and integrator != SolutionIntegrator.DEFAULT
and integrator == SolutionIntegrator.DEFAULT
):
with pytest.raises(
RuntimeError,
Expand Down Expand Up @@ -547,7 +547,7 @@ def test_integrate_multiphase(shooting, keep_intermediate_points, integrator, od
return

if ode_solver == OdeSolver.COLLOCATION:
if integrator != SolutionIntegrator.DEFAULT:
if integrator == SolutionIntegrator.DEFAULT:
if shooting != Shooting.MULTIPLE:
with pytest.raises(
RuntimeError, match="Integration with direct collocation must using shooting_type=Shooting.MULTIPLE"
Expand All @@ -566,7 +566,7 @@ def test_integrate_multiphase(shooting, keep_intermediate_points, integrator, od
return

opts["continuous"] = True
if ode_solver == OdeSolver.COLLOCATION and integrator != SolutionIntegrator.DEFAULT:
if ode_solver == OdeSolver.COLLOCATION and integrator == SolutionIntegrator.DEFAULT:
with pytest.raises(
RuntimeError,
match="Integration with direct collocation must be not continuous",
Expand All @@ -580,15 +580,15 @@ def test_integrate_multiphase(shooting, keep_intermediate_points, integrator, od
decimal = 1 if integrator != SolutionIntegrator.DEFAULT else 8
for i in range(len(sol_integrated.states)):
for k, key in enumerate(sol.states[i]):
if integrator != SolutionIntegrator.DEFAULT or shooting == Shooting.MULTIPLE:
if integrator == SolutionIntegrator.DEFAULT or shooting == Shooting.MULTIPLE:
np.testing.assert_almost_equal(
sol_integrated.states[i][key][:, [0, -1]], sol.states[i][key][:, [0, -1]], decimal=decimal
)

if keep_intermediate_points:
assert sol_integrated.states[i][key].shape == (shapes[k], n_shooting[i] * 5 + 1)
else:
if integrator != SolutionIntegrator.DEFAULT or shooting == Shooting.MULTIPLE:
if integrator == SolutionIntegrator.DEFAULT or shooting == Shooting.MULTIPLE:
np.testing.assert_almost_equal(sol_integrated.states[i][key], sol.states[i][key])
assert sol_integrated.states[i][key].shape == (shapes[k], n_shooting[i] + 1)
if ode_solver == OdeSolver.COLLOCATION:
Expand Down Expand Up @@ -636,7 +636,7 @@ def test_integrate_multiphase_merged(shooting, keep_intermediate_points, integra
_ = sol.integrate(**opts)
return

if ode_solver == OdeSolver.COLLOCATION and integrator != SolutionIntegrator.DEFAULT:
if ode_solver == OdeSolver.COLLOCATION and integrator == SolutionIntegrator.DEFAULT:
if shooting != Shooting.MULTIPLE:
with pytest.raises(
RuntimeError, match="Integration with direct collocation must using shooting_type=Shooting.MULTIPLE"
Expand All @@ -656,7 +656,7 @@ def test_integrate_multiphase_merged(shooting, keep_intermediate_points, integra

opts["merge_phases"] = True
opts["continuous"] = True
if ode_solver == OdeSolver.COLLOCATION and integrator != SolutionIntegrator.DEFAULT:
if ode_solver == OdeSolver.COLLOCATION and integrator == SolutionIntegrator.DEFAULT:
with pytest.raises(
RuntimeError,
match="Integration with direct collocation must be not continuous",
Expand All @@ -671,14 +671,14 @@ def test_integrate_multiphase_merged(shooting, keep_intermediate_points, integra
decimal = 0 if integrator != SolutionIntegrator.DEFAULT else 8
for k, key in enumerate(sol.states[0]):
expected = np.array([sol.states[0][key][:, 0], sol.states[-1][key][:, -1]]).T
if integrator != SolutionIntegrator.DEFAULT or shooting == Shooting.MULTIPLE:
if integrator == SolutionIntegrator.DEFAULT or shooting == Shooting.MULTIPLE:
np.testing.assert_almost_equal(sol_integrated.states[key][:, [0, -1]], expected, decimal=decimal)

if keep_intermediate_points:
assert sol_integrated.states[key].shape == (shapes[k], sum(n_shooting) * 5 + 1)
else:
# The interpolation prevents from comparing all points
if integrator != SolutionIntegrator.DEFAULT or shooting == Shooting.MULTIPLE:
if integrator == SolutionIntegrator.DEFAULT or shooting == Shooting.MULTIPLE:
expected = np.concatenate(
(sol.states[0][key][:, 0:1], sol.states[-1][key][:, -1][:, np.newaxis]), axis=1
)
Expand Down Expand Up @@ -730,7 +730,7 @@ def test_integrate_multiphase_non_continuous(shooting, integrator, ode_solver):
_ = sol.integrate(**opts)
return

if ode_solver == OdeSolver.COLLOCATION and integrator != SolutionIntegrator.DEFAULT:
if ode_solver == OdeSolver.COLLOCATION and integrator == SolutionIntegrator.DEFAULT:
if shooting != Shooting.MULTIPLE:
with pytest.raises(
RuntimeError, match="Integration with direct collocation must using shooting_type=Shooting.MULTIPLE"
Expand All @@ -744,15 +744,15 @@ def test_integrate_multiphase_non_continuous(shooting, integrator, ode_solver):
decimal = 1 if integrator != SolutionIntegrator.DEFAULT or ode_solver == OdeSolver.COLLOCATION else 8
for i in range(len(sol_integrated.states)):
for k, key in enumerate(sol.states[i]):
if integrator != SolutionIntegrator.DEFAULT or shooting == Shooting.MULTIPLE:
if integrator == SolutionIntegrator.DEFAULT or shooting == Shooting.MULTIPLE:
np.testing.assert_almost_equal(
sol_integrated.states[i][key][:, [0, -1]], sol.states[i][key][:, [0, -1]], decimal=decimal
)
np.testing.assert_almost_equal(
sol_integrated.states[i][key][:, [0, -2]], sol.states[i][key][:, [0, -1]], decimal=decimal
)

if ode_solver == OdeSolver.COLLOCATION and integrator != SolutionIntegrator.DEFAULT:
if ode_solver == OdeSolver.COLLOCATION and integrator == SolutionIntegrator.DEFAULT:
assert sol_integrated.states[i][key].shape == (shapes[k], n_shooting[i] * (4 + 1) + 1)
else:
assert sol_integrated.states[i][key].shape == (shapes[k], n_shooting[i] * (5 + 1) + 1)
Expand Down Expand Up @@ -805,10 +805,11 @@ def test_integrate_multiphase_merged_non_continuous(shooting, integrator, ode_so
_ = sol.integrate(**opts)

elif ode_solver == OdeSolver.COLLOCATION:
if integrator != SolutionIntegrator.DEFAULT:
if integrator == SolutionIntegrator.DEFAULT:
with pytest.raises(
RuntimeError,
match="Integration with direct collocation must using shooting_type=Shooting.MULTIPLE",
match="Integration with direct collocation must using shooting_type=Shooting.MULTIPLE "
"if a scipy integrator is not used",
):
_ = sol.integrate(**opts)
return
Expand All @@ -820,10 +821,10 @@ def test_integrate_multiphase_merged_non_continuous(shooting, integrator, ode_so
shapes = (6, 3, 3)

decimal = 0 if integrator != SolutionIntegrator.DEFAULT or ode_solver == OdeSolver.COLLOCATION else 8
steps = 4 if integrator != SolutionIntegrator.DEFAULT and ode_solver == OdeSolver.COLLOCATION else 5
steps = 4 if integrator == SolutionIntegrator.DEFAULT and ode_solver == OdeSolver.COLLOCATION else 5
for k, key in enumerate(sol.states[0]):
expected = np.array([sol.states[0][key][:, 0], sol.states[-1][key][:, -1]]).T
if integrator != SolutionIntegrator.DEFAULT or shooting == Shooting.MULTIPLE:
if integrator == SolutionIntegrator.DEFAULT or shooting == Shooting.MULTIPLE:
np.testing.assert_almost_equal(sol_integrated.states[key][:, [0, -1]], expected, decimal=decimal)
np.testing.assert_almost_equal(sol_integrated.states[key][:, [0, -2]], expected, decimal=decimal)

Expand Down