diff --git a/README.md b/README.md index 8a1385792..9bd008b2b 100644 --- a/README.md +++ b/README.md @@ -88,6 +88,7 @@ As a tour guide that uses this binder, you can watch the `bioptim` workshop that - [InterpolationType](#enum-interpolationtype) - [Shooting](#enum-shooting) - [CostType](#enum-costtype) + - [SolutionIntegrator](#enum-solutionintegrator) [Examples](#examples) - [Run examples](#run-examples) @@ -1383,7 +1384,6 @@ The accepted values are: - SPLINE: Requires five columns. It performs a cubic spline to interpolate between the nodes. - CUSTOM: User defined interpolation function - ### Enum: Shooting The type of integration to perform - MULTIPLE: resets the state at each node @@ -1396,6 +1396,16 @@ The type of cost - CONSTRAINTS: The constraints - ALL: All the previously described cost type +### Enum: SolutionIntegrator +The type of integrator used to integrate the solution of the optimal control problem +- DEFAULT: The default integrator initially chosen with [OdeSolver](#class-odesolver) +- SCIPY_RK23: The scipy integrator RK23 +- SCIPY_RK45: The scipy integrator RK45 +- SCIPY_DOP853: The scipy integrator DOP853 +- SCIPY_BDF: The scipy integrator BDF +- SCIPY_LSODA: The scipy integrator LSODA + + # Examples In this section, you will find the description of all the examples implemented with bioptim. They are ordered in separate files. Each subsection corresponds to the different files, dealing with different examples and topics. diff --git a/bioptim/__init__.py b/bioptim/__init__.py index f534d65bc..d95de9b61 100644 --- a/bioptim/__init__.py +++ b/bioptim/__init__.py @@ -165,7 +165,17 @@ from .limits.path_conditions import BoundsList, Bounds, InitialGuessList, InitialGuess, QAndQDotBounds from .limits.fatigue_path_conditions import FatigueBounds, FatigueInitialGuess from .limits.penalty_node import PenaltyNode, PenaltyNodeList -from .misc.enums import Axis, Node, InterpolationType, PlotType, ControlType, CostType, Shooting, VariableType +from .misc.enums import ( + Axis, + Node, + InterpolationType, + PlotType, + ControlType, + CostType, + Shooting, + VariableType, + SolutionIntegrator, +) from .misc.mapping import BiMappingList, BiMapping, Mapping from .optimization.non_linear_program import NonLinearProgram from .optimization.optimal_control_program import OptimalControlProgram diff --git a/bioptim/gui/plot.py b/bioptim/gui/plot.py index 9f5eb2eca..967e4f489 100644 --- a/bioptim/gui/plot.py +++ b/bioptim/gui/plot.py @@ -10,7 +10,7 @@ from casadi import Callback, nlpsol_out, nlpsol_n_out, Sparsity, DM from ..limits.path_conditions import Bounds -from ..misc.enums import PlotType, ControlType, InterpolationType, Shooting +from ..misc.enums import PlotType, ControlType, InterpolationType, Shooting, SolutionIntegrator from ..misc.mapping import Mapping from ..optimization.solution import Solution @@ -197,7 +197,7 @@ def __init__( automatically_organize: bool = True, show_bounds: bool = False, shooting_type: Shooting = Shooting.MULTIPLE, - use_scipy_integrator: bool = False, + integrator: SolutionIntegrator = SolutionIntegrator.DEFAULT, ): """ Prepares the figures during the simulation @@ -212,8 +212,9 @@ def __init__( If the axes should fit the bounds (True) or the data (False) shooting_type: Shooting The type of integration method - use_scipy_integrator: bool - Use the scipy solve_ivp integrator for RungeKutta 45 instead of currently defined integrator + integrator: SolutionIntegrator + Use the ode defined by OCP or use a separate integrator provided by scipy + """ for i in range(1, ocp.n_phases): if len(ocp.nlp[0].states["q"]) != len(ocp.nlp[i].states["q"]): @@ -235,7 +236,8 @@ def __init__( self.t = [] self.t_integrated = [] - self.use_scipy_integrator = use_scipy_integrator + self.integrator = integrator + if isinstance(self.ocp.original_phase_time, (int, float)): self.tf = [self.ocp.original_phase_time] else: @@ -294,12 +296,14 @@ def __update_time_vector(self): self.t_integrated = [] last_t = 0 for phase_idx, nlp in enumerate(self.ocp.nlp): - n_int_steps = nlp.ode_solver.steps_scipy if self.use_scipy_integrator else nlp.ode_solver.steps + n_int_steps = ( + nlp.ode_solver.steps_scipy if self.integrator != SolutionIntegrator.DEFAULT else nlp.ode_solver.steps + ) dt_ns = self.tf[phase_idx] / nlp.ns time_phase_integrated = [] last_t_int = copy(last_t) for _ in range(nlp.ns): - if nlp.ode_solver.is_direct_collocation and not self.use_scipy_integrator: + if nlp.ode_solver.is_direct_collocation and self.integrator == SolutionIntegrator.DEFAULT: time_phase_integrated.append(np.array(nlp.dynamics[0].step_time) * dt_ns + last_t_int) else: time_phase_integrated.append(np.linspace(last_t_int, last_t_int + dt_ns, n_int_steps + 1)) @@ -426,7 +430,11 @@ def legend_without_duplicate_labels(ax): ) elif plot_type == PlotType.INTEGRATED: plots_integrated = [] - n_int_steps = nlp.ode_solver.steps_scipy if self.use_scipy_integrator else nlp.ode_solver.steps + n_int_steps = ( + nlp.ode_solver.steps_scipy + if self.integrator != SolutionIntegrator.DEFAULT + else nlp.ode_solver.steps + ) zero = np.zeros(n_int_steps + 1) color = self.plot_func[variable][i].color if self.plot_func[variable][i].color else "tab:brown" for cmp in range(nlp.ns): @@ -585,7 +593,7 @@ def update_data(self, v: dict): continuous=False, shooting_type=self.shooting_type, keep_intermediate_points=True, - use_scipy_integrator=self.use_scipy_integrator, + integrator=self.integrator, ).states data_controls = sol.controls data_params = sol.parameters @@ -598,7 +606,11 @@ def update_data(self, v: dict): self.__update_xdata() for i, nlp in enumerate(self.ocp.nlp): - step_size = nlp.ode_solver.steps_scipy + 1 if self.use_scipy_integrator else nlp.ode_solver.steps + 1 + step_size = ( + nlp.ode_solver.steps_scipy + 1 + if self.integrator != SolutionIntegrator.DEFAULT + else nlp.ode_solver.steps + 1 + ) n_elements = nlp.ns * step_size + 1 state = np.ndarray((0, n_elements)) diff --git a/bioptim/misc/enums.py b/bioptim/misc/enums.py index 54c548587..41143d875 100644 --- a/bioptim/misc/enums.py +++ b/bioptim/misc/enums.py @@ -103,3 +103,16 @@ class VariableType(Enum): STATES = "states" CONTROLS = "controls" + + +class SolutionIntegrator(Enum): + """ + Selection of integrator to use integrate function + """ + + DEFAULT = "DEFAULT" + SCIPY_RK23 = "RK23" + SCIPY_RK45 = "RK45" + SCIPY_DOP853 = "DOP853" + SCIPY_BDF = "BDF" + SCIPY_LSODA = "LSODA" diff --git a/bioptim/optimization/optimal_control_program.py b/bioptim/optimization/optimal_control_program.py index 5bc256552..71cc7c496 100644 --- a/bioptim/optimization/optimal_control_program.py +++ b/bioptim/optimization/optimal_control_program.py @@ -28,7 +28,7 @@ from ..limits.penalty import PenaltyOption from ..limits.objective_functions import ObjectiveFunction from ..misc.__version__ import __version__ -from ..misc.enums import ControlType, SolverType, Shooting, PlotType, CostType +from ..misc.enums import ControlType, SolverType, Shooting, PlotType, CostType, SolutionIntegrator from ..misc.mapping import BiMappingList, Mapping from ..misc.utils import check_version from ..optimization.parameters import ParameterList, Parameter @@ -749,7 +749,7 @@ def prepare_plots( automatically_organize: bool = True, show_bounds: bool = False, shooting_type: Shooting = Shooting.MULTIPLE, - use_scipy_integrator: bool = False, + integrator: SolutionIntegrator = SolutionIntegrator.DEFAULT, ) -> PlotOcp: """ Create all the plots associated with the OCP @@ -762,8 +762,8 @@ def prepare_plots( If the ylim should fit the bounds shooting_type: Shooting What type of integration - use_scipy_integrator: bool - Use the scipy solve_ivp integrator for RungeKutta 45 instead of currently defined integrator + integrator: SolutionIntegrator + Use the ode defined by OCP or use a separate integrator provided by scipy Returns ------- @@ -775,7 +775,7 @@ def prepare_plots( automatically_organize=automatically_organize, show_bounds=show_bounds, shooting_type=shooting_type, - use_scipy_integrator=use_scipy_integrator, + integrator=integrator, ) def solve( diff --git a/bioptim/optimization/solution.py b/bioptim/optimization/solution.py index 4cadc1526..2ba27127b 100644 --- a/bioptim/optimization/solution.py +++ b/bioptim/optimization/solution.py @@ -9,7 +9,7 @@ from matplotlib import pyplot as plt from ..limits.path_conditions import InitialGuess, InitialGuessList -from ..misc.enums import ControlType, CostType, Shooting, InterpolationType, SolverType +from ..misc.enums import ControlType, CostType, Shooting, InterpolationType, SolverType, SolutionIntegrator from ..misc.utils import check_version from ..optimization.non_linear_program import NonLinearProgram from ..optimization.optimization_variable import OptimizationVariableList, OptimizationVariable @@ -492,7 +492,7 @@ def integrate( keep_intermediate_points: bool = False, merge_phases: bool = False, continuous: bool = True, - use_scipy_integrator: bool = False, + integrator: SolutionIntegrator = SolutionIntegrator.DEFAULT, ) -> Any: """ Integrate the states @@ -509,8 +509,8 @@ def integrate( continuous: bool If the arrival value of a node should be discarded [True] or kept [False]. The value of an integrated arrival node and the beginning of the next one are expected to be almost equal when the problem converged - use_scipy_integrator: bool - Ignore the dynamics defined by OCP and use a separate integrator provided by scipy + integrator: SolutionIntegrator + Use the ode defined by OCP or use a separate integrator provided by scipy Returns ------- @@ -535,7 +535,7 @@ def integrate( "Shooting.SINGLE_CONTINUOUS and continuous=False cannot be used simultaneously it is a contradiction" ) - out = self.__perform_integration(shooting_type, keep_intermediate_points, continuous, use_scipy_integrator) + out = self.__perform_integration(shooting_type, keep_intermediate_points, continuous, integrator) if merge_phases: if continuous: @@ -548,24 +548,25 @@ def integrate( return out def __perform_integration( - self, shooting_type: Shooting, keep_intermediate_points: bool, continuous: bool, use_scipy_integrator: bool + self, shooting_type: Shooting, keep_intermediate_points: bool, continuous: bool, integrator: SolutionIntegrator ): n_direct_collocation = sum([nlp.ode_solver.is_direct_collocation for nlp in self.ocp.nlp]) - if n_direct_collocation > 0 and not use_scipy_integrator: + + if n_direct_collocation > 0 and integrator == SolutionIntegrator.DEFAULT: if continuous: raise RuntimeError( - "Integration with direct collocation must be not continuous if not use_scipy_integrator" + "Integration with direct collocation must be not continuous if a scipy integrator is used" ) if shooting_type != Shooting.MULTIPLE: raise RuntimeError( "Integration with direct collocation must using shooting_type=Shooting.MULTIPLE " - "if not use_scipy_integrator" + "if a scipy integrator is not used" ) # Copy the data out = self.copy(skip_data=True) - out.recomputed_time_steps = use_scipy_integrator + out.recomputed_time_steps = integrator != SolutionIntegrator.DEFAULT out._states = [] for _ in range(len(self._states)): out._states.append({}) @@ -575,7 +576,7 @@ def __perform_integration( for p, nlp in enumerate(self.ocp.nlp): param_scaling = nlp.parameters.scaling n_states = self._states[p]["all"].shape[0] - n_steps = nlp.ode_solver.steps_scipy if use_scipy_integrator else nlp.ode_solver.steps + n_steps = nlp.ode_solver.steps_scipy if integrator != SolutionIntegrator.DEFAULT else nlp.ode_solver.steps if not continuous: n_steps += 1 if keep_intermediate_points: @@ -596,7 +597,11 @@ def __perform_integration( ) x0 += np.array(val)[:, 0] else: - col = slice(0, n_steps) if nlp.ode_solver.is_direct_collocation and not use_scipy_integrator else 0 + col = ( + slice(0, n_steps) + if nlp.ode_solver.is_direct_collocation and integrator == SolutionIntegrator.DEFAULT + else 0 + ) x0 = self._states[p]["all"][:, col] for s in range(self.ns[p]): @@ -607,13 +612,17 @@ def __perform_integration( else: raise NotImplementedError(f"ControlType {nlp.control_type} " f"not yet implemented in integrating") - if use_scipy_integrator: + if integrator != SolutionIntegrator.DEFAULT: t_init = sum(out.phase_time[:p]) / nlp.ns t_end = sum(out.phase_time[: (p + 2)]) / nlp.ns n_points = n_steps + 1 if continuous else n_steps t_eval = np.linspace(t_init, t_end, n_points) if keep_intermediate_points else [t_init, t_end] integrated = solve_ivp( - lambda t, x: np.array(nlp.dynamics_func(x, u, params))[:, 0], [t_init, t_end], x0, t_eval=t_eval + lambda t, x: np.array(nlp.dynamics_func(x, u, params))[:, 0], + [t_init, t_end], + x0, + t_eval=t_eval, + method=integrator.value, ).y next_state_col = ( @@ -853,7 +862,7 @@ def graphs( show_bounds: bool = False, show_now: bool = True, shooting_type: Shooting = Shooting.MULTIPLE, - use_scipy_integrator: bool = False, + integrator: SolutionIntegrator = SolutionIntegrator.DEFAULT, ): """ Show the graphs of the simulation @@ -868,14 +877,14 @@ def graphs( If the show method should be called. This is blocking shooting_type: Shooting The type of interpolation - use_scipy_integrator: bool + integrator: SolutionIntegrator Use the scipy solve_ivp integrator for RungeKutta 45 instead of currently defined integrator """ if self.is_merged or self.is_interpolated or self.is_integrated: raise NotImplementedError("It is not possible to graph a modified Solution yet") - plot_ocp = self.ocp.prepare_plots(automatically_organize, show_bounds, shooting_type, use_scipy_integrator) + plot_ocp = self.ocp.prepare_plots(automatically_organize, show_bounds, shooting_type, integrator) plot_ocp.update_data(self.vector) if show_now: plt.show() diff --git a/tests/test_simulate.py b/tests/test_simulate.py index 382475542..6619c20b6 100644 --- a/tests/test_simulate.py +++ b/tests/test_simulate.py @@ -4,7 +4,7 @@ import pytest import numpy as np -from bioptim import Shooting, OdeSolver +from bioptim import Shooting, OdeSolver, SolutionIntegrator def test_merge_phases_one_phase(): @@ -167,8 +167,8 @@ def test_interpolate_multiphases_merge_phase(): @pytest.mark.parametrize("ode_solver", [OdeSolver.RK4, OdeSolver.COLLOCATION]) -@pytest.mark.parametrize("use_scipy", [True, False]) -def test_integrate(use_scipy, ode_solver): +@pytest.mark.parametrize("integrator", [SolutionIntegrator.SCIPY_RK45, SolutionIntegrator.DEFAULT]) +def test_integrate(integrator, ode_solver): # Load pendulum from bioptim.examples.getting_started import pendulum as ocp_module @@ -185,7 +185,7 @@ def test_integrate(use_scipy, ode_solver): sol = ocp.solve() - opts = {"shooting_type": Shooting.MULTIPLE, "keep_intermediate_points": False, "use_scipy_integrator": use_scipy} + opts = {"shooting_type": Shooting.MULTIPLE, "keep_intermediate_points": False, "integrator": integrator} with pytest.raises( ValueError, match="Shooting.MULTIPLE and keep_intermediate_points=False " @@ -194,7 +194,7 @@ def test_integrate(use_scipy, ode_solver): _ = sol.integrate(**opts) opts["keep_intermediate_points"] = True - if ode_solver == OdeSolver.COLLOCATION and not use_scipy: + if ode_solver == OdeSolver.COLLOCATION and integrator == SolutionIntegrator.DEFAULT: with pytest.raises(RuntimeError, match="Integration with direct collocation must be not continuous"): sol.integrate(**opts) return @@ -202,7 +202,7 @@ def test_integrate(use_scipy, ode_solver): sol_integrated = sol.integrate(**opts) shapes = (4, 2, 2) - decimal = 5 if use_scipy else 8 + decimal = 5 if integrator != SolutionIntegrator.DEFAULT else 8 for i, key in enumerate(sol.states): np.testing.assert_almost_equal( sol_integrated.states[key][:, [0, -1]], sol.states[key][:, [0, -1]], decimal=decimal @@ -241,7 +241,7 @@ def test_integrate_single_shoot(keep_intermediate_points, ode_solver): sol = ocp.solve() - opts = {"keep_intermediate_points": keep_intermediate_points, "use_scipy_integrator": False} + opts = {"keep_intermediate_points": keep_intermediate_points, "integrator": SolutionIntegrator.DEFAULT} if ode_solver == OdeSolver.COLLOCATION: with pytest.raises(RuntimeError, match="Integration with direct collocation must be not continuous"): sol.integrate(**opts) @@ -296,7 +296,7 @@ def test_integrate_single_shoot_use_scipy(keep_intermediate_points, ode_solver): sol = ocp.solve() - opts = {"keep_intermediate_points": keep_intermediate_points, "use_scipy_integrator": True} + opts = {"keep_intermediate_points": keep_intermediate_points, "integrator": SolutionIntegrator.SCIPY_RK45} sol_integrated = sol.integrate(**opts) shapes = (4, 2, 2) @@ -436,8 +436,8 @@ def test_integrate_single_shoot_use_scipy(keep_intermediate_points, ode_solver): @pytest.mark.parametrize("ode_solver", [OdeSolver.RK4, OdeSolver.COLLOCATION]) @pytest.mark.parametrize("shooting", [Shooting.SINGLE_CONTINUOUS, Shooting.MULTIPLE, Shooting.SINGLE]) @pytest.mark.parametrize("merge", [False, True]) -@pytest.mark.parametrize("use_scipy", [False, True]) -def test_integrate_non_continuous(shooting, merge, use_scipy, ode_solver): +@pytest.mark.parametrize("integrator", [SolutionIntegrator.DEFAULT, SolutionIntegrator.SCIPY_RK45]) +def test_integrate_non_continuous(shooting, merge, integrator, ode_solver): # Load pendulum from bioptim.examples.getting_started import pendulum as ocp_module @@ -458,7 +458,7 @@ def test_integrate_non_continuous(shooting, merge, use_scipy, ode_solver): "shooting_type": shooting, "continuous": False, "keep_intermediate_points": False, - "use_scipy_integrator": use_scipy, + "integrator": integrator, } if shooting == Shooting.SINGLE_CONTINUOUS: with pytest.raises( @@ -477,7 +477,11 @@ def test_integrate_non_continuous(shooting, merge, use_scipy, ode_solver): opts["keep_intermediate_points"] = True opts["merge_phases"] = merge - if ode_solver == OdeSolver.COLLOCATION and shooting != Shooting.MULTIPLE and not use_scipy: + if ( + ode_solver == OdeSolver.COLLOCATION + and shooting != Shooting.MULTIPLE + and integrator == SolutionIntegrator.DEFAULT + ): with pytest.raises( RuntimeError, match="Integration with direct collocation must using shooting_type=Shooting.MULTIPLE", @@ -488,14 +492,14 @@ def test_integrate_non_continuous(shooting, merge, use_scipy, ode_solver): sol_integrated = sol.integrate(**opts) shapes = (4, 2, 2) - decimal = 1 if use_scipy or ode_solver == OdeSolver.COLLOCATION else 8 + decimal = 1 if integrator != SolutionIntegrator.DEFAULT or ode_solver == OdeSolver.COLLOCATION else 8 for i, key in enumerate(sol.states): np.testing.assert_almost_equal( sol_integrated.states[key][:, [0, -1]], sol.states[key][:, [0, -1]], decimal=decimal ) if ode_solver == OdeSolver.COLLOCATION: - if use_scipy: + if integrator != SolutionIntegrator.DEFAULT: assert sol_integrated.states[key].shape == (shapes[i], n_shooting * (5 + 1) + 1) else: assert sol_integrated.states[key].shape == (shapes[i], n_shooting * (4 + 1) + 1) @@ -515,8 +519,8 @@ def test_integrate_non_continuous(shooting, merge, use_scipy, ode_solver): @pytest.mark.parametrize("ode_solver", [OdeSolver.RK4, OdeSolver.COLLOCATION]) @pytest.mark.parametrize("shooting", [Shooting.SINGLE_CONTINUOUS, Shooting.MULTIPLE, Shooting.SINGLE]) @pytest.mark.parametrize("keep_intermediate_points", [True, False]) -@pytest.mark.parametrize("use_scipy", [False, True]) -def test_integrate_multiphase(shooting, keep_intermediate_points, use_scipy, ode_solver): +@pytest.mark.parametrize("integrator", [SolutionIntegrator.DEFAULT, SolutionIntegrator.SCIPY_RK45]) +def test_integrate_multiphase(shooting, keep_intermediate_points, integrator, ode_solver): # Load pendulum from bioptim.examples.getting_started import example_multiphase as ocp_module @@ -531,7 +535,7 @@ def test_integrate_multiphase(shooting, keep_intermediate_points, use_scipy, ode "shooting_type": shooting, "continuous": False, "keep_intermediate_points": keep_intermediate_points, - "use_scipy_integrator": use_scipy, + "integrator": integrator, } if shooting == Shooting.SINGLE_CONTINUOUS: @@ -543,7 +547,7 @@ def test_integrate_multiphase(shooting, keep_intermediate_points, use_scipy, ode return if ode_solver == OdeSolver.COLLOCATION: - if not use_scipy: + if integrator == SolutionIntegrator.DEFAULT: if shooting != Shooting.MULTIPLE: with pytest.raises( RuntimeError, match="Integration with direct collocation must using shooting_type=Shooting.MULTIPLE" @@ -562,7 +566,7 @@ def test_integrate_multiphase(shooting, keep_intermediate_points, use_scipy, ode return opts["continuous"] = True - if ode_solver == OdeSolver.COLLOCATION and not use_scipy: + if ode_solver == OdeSolver.COLLOCATION and integrator == SolutionIntegrator.DEFAULT: with pytest.raises( RuntimeError, match="Integration with direct collocation must be not continuous", @@ -573,10 +577,10 @@ def test_integrate_multiphase(shooting, keep_intermediate_points, use_scipy, ode sol_integrated = sol.integrate(**opts) shapes = (6, 3, 3) - decimal = 1 if use_scipy else 8 + decimal = 1 if integrator != SolutionIntegrator.DEFAULT else 8 for i in range(len(sol_integrated.states)): for k, key in enumerate(sol.states[i]): - if not use_scipy or shooting == Shooting.MULTIPLE: + if integrator == SolutionIntegrator.DEFAULT or shooting == Shooting.MULTIPLE: np.testing.assert_almost_equal( sol_integrated.states[i][key][:, [0, -1]], sol.states[i][key][:, [0, -1]], decimal=decimal ) @@ -584,7 +588,7 @@ def test_integrate_multiphase(shooting, keep_intermediate_points, use_scipy, ode if keep_intermediate_points: assert sol_integrated.states[i][key].shape == (shapes[k], n_shooting[i] * 5 + 1) else: - if not use_scipy or shooting == Shooting.MULTIPLE: + if integrator == SolutionIntegrator.DEFAULT or shooting == Shooting.MULTIPLE: np.testing.assert_almost_equal(sol_integrated.states[i][key], sol.states[i][key]) assert sol_integrated.states[i][key].shape == (shapes[k], n_shooting[i] + 1) if ode_solver == OdeSolver.COLLOCATION: @@ -603,8 +607,8 @@ def test_integrate_multiphase(shooting, keep_intermediate_points, use_scipy, ode @pytest.mark.parametrize("ode_solver", [OdeSolver.RK4, OdeSolver.COLLOCATION]) @pytest.mark.parametrize("shooting", [Shooting.SINGLE_CONTINUOUS, Shooting.MULTIPLE, Shooting.SINGLE]) @pytest.mark.parametrize("keep_intermediate_points", [True, False]) -@pytest.mark.parametrize("use_scipy", [False, True]) -def test_integrate_multiphase_merged(shooting, keep_intermediate_points, use_scipy, ode_solver): +@pytest.mark.parametrize("integrator", [SolutionIntegrator.DEFAULT, SolutionIntegrator.SCIPY_RK45]) +def test_integrate_multiphase_merged(shooting, keep_intermediate_points, integrator, ode_solver): # Load pendulum from bioptim.examples.getting_started import example_multiphase as ocp_module @@ -621,7 +625,7 @@ def test_integrate_multiphase_merged(shooting, keep_intermediate_points, use_sci "shooting_type": shooting, "continuous": False, "keep_intermediate_points": keep_intermediate_points, - "use_scipy_integrator": use_scipy, + "integrator": integrator, } if shooting == Shooting.SINGLE_CONTINUOUS: @@ -632,7 +636,7 @@ def test_integrate_multiphase_merged(shooting, keep_intermediate_points, use_sci _ = sol.integrate(**opts) return - if ode_solver == OdeSolver.COLLOCATION and not use_scipy: + if ode_solver == OdeSolver.COLLOCATION and integrator == SolutionIntegrator.DEFAULT: if shooting != Shooting.MULTIPLE: with pytest.raises( RuntimeError, match="Integration with direct collocation must using shooting_type=Shooting.MULTIPLE" @@ -652,7 +656,7 @@ def test_integrate_multiphase_merged(shooting, keep_intermediate_points, use_sci opts["merge_phases"] = True opts["continuous"] = True - if ode_solver == OdeSolver.COLLOCATION and not use_scipy: + if ode_solver == OdeSolver.COLLOCATION and integrator == SolutionIntegrator.DEFAULT: with pytest.raises( RuntimeError, match="Integration with direct collocation must be not continuous", @@ -664,17 +668,17 @@ def test_integrate_multiphase_merged(shooting, keep_intermediate_points, use_sci sol_integrated = sol.integrate(**opts) shapes = (6, 3, 3) - decimal = 0 if use_scipy else 8 + decimal = 0 if integrator != SolutionIntegrator.DEFAULT else 8 for k, key in enumerate(sol.states[0]): expected = np.array([sol.states[0][key][:, 0], sol.states[-1][key][:, -1]]).T - if not use_scipy or shooting == Shooting.MULTIPLE: + if integrator == SolutionIntegrator.DEFAULT or shooting == Shooting.MULTIPLE: np.testing.assert_almost_equal(sol_integrated.states[key][:, [0, -1]], expected, decimal=decimal) if keep_intermediate_points: assert sol_integrated.states[key].shape == (shapes[k], sum(n_shooting) * 5 + 1) else: # The interpolation prevents from comparing all points - if not use_scipy or shooting == Shooting.MULTIPLE: + if integrator == SolutionIntegrator.DEFAULT or shooting == Shooting.MULTIPLE: expected = np.concatenate( (sol.states[0][key][:, 0:1], sol.states[-1][key][:, -1][:, np.newaxis]), axis=1 ) @@ -698,8 +702,8 @@ def test_integrate_multiphase_merged(shooting, keep_intermediate_points, use_sci @pytest.mark.parametrize("ode_solver", [OdeSolver.RK4, OdeSolver.COLLOCATION]) @pytest.mark.parametrize("shooting", [Shooting.SINGLE_CONTINUOUS, Shooting.MULTIPLE, Shooting.SINGLE]) -@pytest.mark.parametrize("use_scipy", [False, True]) -def test_integrate_multiphase_non_continuous(shooting, use_scipy, ode_solver): +@pytest.mark.parametrize("integrator", [SolutionIntegrator.DEFAULT, SolutionIntegrator.SCIPY_RK45]) +def test_integrate_multiphase_non_continuous(shooting, integrator, ode_solver): # Load pendulum from bioptim.examples.getting_started import example_multiphase as ocp_module @@ -714,7 +718,7 @@ def test_integrate_multiphase_non_continuous(shooting, use_scipy, ode_solver): "shooting_type": shooting, "continuous": False, "keep_intermediate_points": True, - "use_scipy_integrator": use_scipy, + "integrator": integrator, } if shooting == Shooting.SINGLE_CONTINUOUS: @@ -726,7 +730,7 @@ def test_integrate_multiphase_non_continuous(shooting, use_scipy, ode_solver): _ = sol.integrate(**opts) return - if ode_solver == OdeSolver.COLLOCATION and not use_scipy: + if ode_solver == OdeSolver.COLLOCATION and integrator == SolutionIntegrator.DEFAULT: if shooting != Shooting.MULTIPLE: with pytest.raises( RuntimeError, match="Integration with direct collocation must using shooting_type=Shooting.MULTIPLE" @@ -737,10 +741,10 @@ def test_integrate_multiphase_non_continuous(shooting, use_scipy, ode_solver): sol_integrated = sol.integrate(**opts) shapes = (6, 3, 3) - decimal = 1 if use_scipy or ode_solver == OdeSolver.COLLOCATION else 8 + decimal = 1 if integrator != SolutionIntegrator.DEFAULT or ode_solver == OdeSolver.COLLOCATION else 8 for i in range(len(sol_integrated.states)): for k, key in enumerate(sol.states[i]): - if not use_scipy or shooting == Shooting.MULTIPLE: + if integrator == SolutionIntegrator.DEFAULT or shooting == Shooting.MULTIPLE: np.testing.assert_almost_equal( sol_integrated.states[i][key][:, [0, -1]], sol.states[i][key][:, [0, -1]], decimal=decimal ) @@ -748,7 +752,7 @@ def test_integrate_multiphase_non_continuous(shooting, use_scipy, ode_solver): sol_integrated.states[i][key][:, [0, -2]], sol.states[i][key][:, [0, -1]], decimal=decimal ) - if ode_solver == OdeSolver.COLLOCATION and not use_scipy: + if ode_solver == OdeSolver.COLLOCATION and integrator == SolutionIntegrator.DEFAULT: assert sol_integrated.states[i][key].shape == (shapes[k], n_shooting[i] * (4 + 1) + 1) else: assert sol_integrated.states[i][key].shape == (shapes[k], n_shooting[i] * (5 + 1) + 1) @@ -768,8 +772,8 @@ def test_integrate_multiphase_non_continuous(shooting, use_scipy, ode_solver): @pytest.mark.parametrize("ode_solver", [OdeSolver.RK4, OdeSolver.COLLOCATION]) @pytest.mark.parametrize("shooting", [Shooting.SINGLE_CONTINUOUS, Shooting.MULTIPLE, Shooting.SINGLE]) -@pytest.mark.parametrize("use_scipy", [False, True]) -def test_integrate_multiphase_merged_non_continuous(shooting, use_scipy, ode_solver): +@pytest.mark.parametrize("integrator", [SolutionIntegrator.DEFAULT, SolutionIntegrator.SCIPY_RK45]) +def test_integrate_multiphase_merged_non_continuous(shooting, integrator, ode_solver): # Load pendulum from bioptim.examples.getting_started import example_multiphase as ocp_module @@ -783,7 +787,7 @@ def test_integrate_multiphase_merged_non_continuous(shooting, use_scipy, ode_sol "shooting_type": shooting, "continuous": False, "keep_intermediate_points": False, - "use_scipy_integrator": use_scipy, + "integrator": integrator, } if shooting == Shooting.SINGLE_CONTINUOUS: with pytest.raises( @@ -801,10 +805,11 @@ def test_integrate_multiphase_merged_non_continuous(shooting, use_scipy, ode_sol _ = sol.integrate(**opts) elif ode_solver == OdeSolver.COLLOCATION: - if not use_scipy: + if integrator == SolutionIntegrator.DEFAULT: with pytest.raises( RuntimeError, - match="Integration with direct collocation must using shooting_type=Shooting.MULTIPLE", + match="Integration with direct collocation must using shooting_type=Shooting.MULTIPLE " + "if a scipy integrator is not used", ): _ = sol.integrate(**opts) return @@ -815,11 +820,11 @@ def test_integrate_multiphase_merged_non_continuous(shooting, use_scipy, ode_sol sol_integrated = sol.integrate(**opts) shapes = (6, 3, 3) - decimal = 0 if use_scipy or ode_solver == OdeSolver.COLLOCATION else 8 - steps = 4 if not use_scipy and ode_solver == OdeSolver.COLLOCATION else 5 + decimal = 0 if integrator != SolutionIntegrator.DEFAULT or ode_solver == OdeSolver.COLLOCATION else 8 + steps = 4 if integrator == SolutionIntegrator.DEFAULT and ode_solver == OdeSolver.COLLOCATION else 5 for k, key in enumerate(sol.states[0]): expected = np.array([sol.states[0][key][:, 0], sol.states[-1][key][:, -1]]).T - if not use_scipy or shooting == Shooting.MULTIPLE: + if integrator == SolutionIntegrator.DEFAULT or shooting == Shooting.MULTIPLE: np.testing.assert_almost_equal(sol_integrated.states[key][:, [0, -1]], expected, decimal=decimal) np.testing.assert_almost_equal(sol_integrated.states[key][:, [0, -2]], expected, decimal=decimal) diff --git a/tests/utils.py b/tests/utils.py index dfa1a47a0..2526a58e7 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -18,6 +18,7 @@ InitialGuess, Shooting, Solver, + SolutionIntegrator, ) @@ -117,7 +118,7 @@ def simulate(sol, decimal_value=7): merge_phases=True, shooting_type=Shooting.SINGLE_CONTINUOUS, keep_intermediate_points=True, - use_scipy_integrator=False, + integrator=SolutionIntegrator.DEFAULT, ) # Evaluate the final error of the single shooting integration versus the finale node