diff --git a/.all-contributorsrc b/.all-contributorsrc index 00b86f3474..caadc74b78 100644 --- a/.all-contributorsrc +++ b/.all-contributorsrc @@ -747,7 +747,8 @@ "profile": "https://github.com/AbhishekChaudharii", "contributions": [ "doc", - "code" + "code", + "test" ] }, { diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 0f503f09a7..6984abf32c 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -10,7 +10,7 @@ src/pybamm/meshes/ @martinjrobins @rtimms @valentinsulzer @rtimms src/pybamm/models/ @brosaplanella @DrSOKane @rtimms @valentinsulzer @TomTranter @rtimms src/pybamm/parameters/ @brosaplanella @DrSOKane @rtimms @valentinsulzer @TomTranter @rtimms @kratman src/pybamm/plotting/ @martinjrobins @rtimms @Saransh-cpp @valentinsulzer @rtimms @kratman @agriyakhetarpal -src/pybamm/solvers/ @martinjrobins @rtimms @valentinsulzer @TomTranter @rtimms +src/pybamm/solvers/ @martinjrobins @rtimms @valentinsulzer @TomTranter @rtimms @MarcBerliner src/pybamm/spatial_methods/ @martinjrobins @rtimms @valentinsulzer @rtimms src/pybamm/* @pybamm-team/maintainers # the files directly under /pybamm/, will not recurse diff --git a/.github/workflows/scorecard.yml b/.github/workflows/scorecard.yml index 784ccebc4f..3186f0e49b 100644 --- a/.github/workflows/scorecard.yml +++ b/.github/workflows/scorecard.yml @@ -68,6 +68,6 @@ jobs: # Upload the results to GitHub's code scanning dashboard (optional). # Commenting out will disable upload of results to your repo's Code Scanning dashboard - name: "Upload to code-scanning" - uses: github/codeql-action/upload-sarif@4dd16135b69a43b6c8efb853346f8437d92d3c93 # v3.26.6 + uses: github/codeql-action/upload-sarif@294a9d92911152fe08befb9ec03e240add280cb3 # v3.26.8 with: sarif_file: results.sarif diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 985cd0291a..306118e254 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -4,7 +4,7 @@ ci: repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: "v0.6.4" + rev: "v0.6.7" hooks: - id: ruff args: [--fix, --show-fixes] diff --git a/CHANGELOG.md b/CHANGELOG.md index faa0a83ea8..38e03775d9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,21 +4,26 @@ - Added basic telemetry to record which functions are being run. See [Telemetry section in the User Guide](https://docs.pybamm.org/en/latest/source/user_guide/index.html#telemetry) for more information. ([#4441](https://github.com/pybamm-team/PyBaMM/pull/4441)) - Added sensitivity calculation support for `pybamm.Simulation` and `pybamm.Experiment` ([#4415](https://github.com/pybamm-team/PyBaMM/pull/4415)) +- Added OpenMP parallelization to IDAKLU solver for lists of input parameters ([#4449](https://github.com/pybamm-team/PyBaMM/pull/4449)) +- Added phase-dependent particle options to LAM ([#4369](https://github.com/pybamm-team/PyBaMM/pull/4369)) +- Added a lithium ion equivalent circuit model with split open circuit voltages for each electrode (`SplitOCVR`). ([#4330](https://github.com/pybamm-team/PyBaMM/pull/4330)) ## Optimizations +- Performance refactor of JAX BDF Solver with default Jax method set to `"BDF"`. ([#4456](https://github.com/pybamm-team/PyBaMM/pull/4456)) +- Improved performance of initialization and reinitialization of ODEs in the (`IDAKLUSolver`). ([#4453](https://github.com/pybamm-team/PyBaMM/pull/4453)) - Removed the `start_step_offset` setting and disabled minimum `dt` warnings for drive cycles with the (`IDAKLUSolver`). ([#4416](https://github.com/pybamm-team/PyBaMM/pull/4416)) -## Features +## Bug Fixes -- Added phase-dependent particle options to LAM #4369 +- Fixed bug where IDAKLU solver failed when `output variables` were specified and an extrapolation event is present. ([#4440](https://github.com/pybamm-team/PyBaMM/pull/4440)) ## Breaking changes +- Removed the deprecation warning for the chemistry argument in ParameterValues ([#4466](https://github.com/pybamm-team/PyBaMM/pull/4466)) - The parameters "... electrode OCP entropic change [V.K-1]" and "... electrode volume change" are now expected to be functions of stoichiometry only instead of functions of both stoichiometry and maximum concentration ([#4427](https://github.com/pybamm-team/PyBaMM/pull/4427)) - Renamed `set_events` function to `add_events_from` to better reflect its purpose. ([#4421](https://github.com/pybamm-team/PyBaMM/pull/4421)) - # [v24.9.0](https://github.com/pybamm-team/PyBaMM/tree/v24.9.0) - 2024-09-03 ## Features diff --git a/CMakeLists.txt b/CMakeLists.txt index ad56ac34ca..a7f68ce7a0 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -19,7 +19,7 @@ endif() project(idaklu) -set(CMAKE_CXX_STANDARD 14) +set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_CXX_EXTENSIONS OFF) set(CMAKE_EXPORT_COMPILE_COMMANDS 1) @@ -82,6 +82,8 @@ pybind11_add_module(idaklu src/pybamm/solvers/c_solvers/idaklu/idaklu_solver.hpp src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolver.cpp src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolver.hpp + src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolverGroup.cpp + src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolverGroup.hpp src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolverOpenMP.inl src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolverOpenMP.hpp src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolverOpenMP_solvers.cpp @@ -94,6 +96,8 @@ pybind11_add_module(idaklu src/pybamm/solvers/c_solvers/idaklu/common.cpp src/pybamm/solvers/c_solvers/idaklu/Solution.cpp src/pybamm/solvers/c_solvers/idaklu/Solution.hpp + src/pybamm/solvers/c_solvers/idaklu/SolutionData.cpp + src/pybamm/solvers/c_solvers/idaklu/SolutionData.hpp src/pybamm/solvers/c_solvers/idaklu/Options.hpp src/pybamm/solvers/c_solvers/idaklu/Options.cpp # IDAKLU expressions / function evaluation [abstract] @@ -138,6 +142,23 @@ set_target_properties( INSTALL_RPATH_USE_LINK_PATH TRUE ) +# openmp +if (${CMAKE_SYSTEM_NAME} MATCHES "Darwin") + execute_process( + COMMAND "brew" "--prefix" + OUTPUT_VARIABLE HOMEBREW_PREFIX + OUTPUT_STRIP_TRAILING_WHITESPACE) + if (OpenMP_ROOT) + set(OpenMP_ROOT "${OpenMP_ROOT}:${HOMEBREW_PREFIX}/opt/libomp") + else() + set(OpenMP_ROOT "${HOMEBREW_PREFIX}/opt/libomp") + endif() +endif() +find_package(OpenMP) +if(OpenMP_CXX_FOUND) + target_link_libraries(idaklu PRIVATE OpenMP::OpenMP_CXX) +endif() + set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} ${PROJECT_SOURCE_DIR}) # Sundials find_package(SUNDIALS REQUIRED) diff --git a/all_contributors.md b/all_contributors.md index 49a5f1d1ef..b701ef2a13 100644 --- a/all_contributors.md +++ b/all_contributors.md @@ -91,7 +91,7 @@ Thanks goes to these wonderful people ([emoji key](https://allcontributors.org/d Agnik Bakshi
Agnik Bakshi

📖 RuiheLi
RuiheLi

💻 ⚠️ chmabaur
chmabaur

🐛 💻 - Abhishek Chaudhari
Abhishek Chaudhari

📖 💻 + Abhishek Chaudhari
Abhishek Chaudhari

📖 💻 ⚠️ Shubham Bhardwaj
Shubham Bhardwaj

🚇 Jonathan Lauber
Jonathan Lauber

🚇 diff --git a/docs/source/api/models/lithium_ion/ecm_split_ocv.rst b/docs/source/api/models/lithium_ion/ecm_split_ocv.rst new file mode 100644 index 0000000000..a7d833cf55 --- /dev/null +++ b/docs/source/api/models/lithium_ion/ecm_split_ocv.rst @@ -0,0 +1,7 @@ +Equivalent Circuit Model with Split OCV (SplitOCVR) +===================================================== + +.. autoclass:: pybamm.lithium_ion.SplitOCVR + :members: + +.. footbibliography:: diff --git a/docs/source/api/models/lithium_ion/index.rst b/docs/source/api/models/lithium_ion/index.rst index 1a72c3c662..52efe44d6b 100644 --- a/docs/source/api/models/lithium_ion/index.rst +++ b/docs/source/api/models/lithium_ion/index.rst @@ -12,3 +12,4 @@ Lithium-ion Models msmr yang2017 electrode_soh + ecm_split_ocv diff --git a/examples/scripts/multiprocess_jax_solver.py b/examples/scripts/multiprocess_jax_solver.py new file mode 100644 index 0000000000..8192256ed1 --- /dev/null +++ b/examples/scripts/multiprocess_jax_solver.py @@ -0,0 +1,57 @@ +import pybamm +import time +import numpy as np + + +# This script provides an example for massively vectorised +# model solves using the JAX BDF solver. First, +# we set up the model and process parameters +model = pybamm.lithium_ion.SPM() +model.convert_to_format = "jax" +model.events = [] # remove events (not supported in jax) +geometry = model.default_geometry +param = pybamm.ParameterValues("Chen2020") +param.update({"Current function [A]": "[input]"}) +param.process_geometry(geometry) +param.process_model(model) + +# Discretise and setup solver +mesh = pybamm.Mesh(geometry, model.default_submesh_types, model.default_var_pts) +disc = pybamm.Discretisation(mesh, model.default_spatial_methods) +disc.process_model(model) +t_eval = np.linspace(0, 3600, 100) +solver = pybamm.JaxSolver(atol=1e-6, rtol=1e-6, method="BDF") + +# Set number of vectorised solves +values = np.linspace(0.01, 1.0, 1000) +inputs = [{"Current function [A]": value} for value in values] + +# Run solve for all inputs, with a just-in-time compilation +# occurring on the first solve. All sequential solves will +# use the compiled code, with a large performance improvement. +start_time = time.time() +sol = solver.solve(model, t_eval, inputs=inputs) +print(f"Time taken: {time.time() - start_time}") # 1.3s + +# Rerun the vectorised solve, showing performance improvement +start_time = time.time() +compiled_sol = solver.solve(model, t_eval, inputs=inputs) +print(f"Compiled time taken: {time.time() - start_time}") # 0.42s + +# Plot one of the solves +plot = pybamm.QuickPlot( + sol[5], + [ + "Negative particle concentration [mol.m-3]", + "Electrolyte concentration [mol.m-3]", + "Positive particle concentration [mol.m-3]", + "Current [A]", + "Negative electrode potential [V]", + "Electrolyte potential [V]", + "Positive electrode potential [V]", + "Voltage [V]", + ], + time_unit="seconds", + spatial_unit="um", +) +plot.dynamic_plot() diff --git a/noxfile.py b/noxfile.py index 6567ed167c..14bafcca47 100644 --- a/noxfile.py +++ b/noxfile.py @@ -222,7 +222,7 @@ def run_scripts(session): # https://bitbucket.org/pybtex-devs/pybtex/issues/169/replace-pkg_resources-with # is fixed session.install("setuptools", silent=False) - session.install("-e", ".[all,dev]", silent=False) + session.install("-e", ".[all,dev,jax]", silent=False) session.run("python", "-m", "pytest", "-m", "scripts") diff --git a/pyproject.toml b/pyproject.toml index 484c2bb334..48317c7e2a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -281,6 +281,3 @@ module = [ "pybamm.models.full_battery_models.base_battery_model.*" ] disable_error_code = "attr-defined" - -[tool.poetry.scripts] -post-install = "scripts.post_install:run" diff --git a/src/pybamm/models/full_battery_models/lithium_ion/__init__.py b/src/pybamm/models/full_battery_models/lithium_ion/__init__.py index b02868dbe9..556e8de31c 100644 --- a/src/pybamm/models/full_battery_models/lithium_ion/__init__.py +++ b/src/pybamm/models/full_battery_models/lithium_ion/__init__.py @@ -24,8 +24,9 @@ from .Yang2017 import Yang2017 from .mpm import MPM from .msmr import MSMR +from .basic_splitOCVR import SplitOCVR __all__ = ['Yang2017', 'base_lithium_ion_model', 'basic_dfn', 'basic_dfn_composite', 'basic_dfn_half_cell', 'basic_spm', 'dfn', 'electrode_soh', 'electrode_soh_half_cell', 'mpm', 'msmr', - 'newman_tobias', 'spm', 'spme'] + 'newman_tobias', 'spm', 'spme', 'basic_splitOCVR'] diff --git a/src/pybamm/models/full_battery_models/lithium_ion/basic_splitOCVR.py b/src/pybamm/models/full_battery_models/lithium_ion/basic_splitOCVR.py new file mode 100644 index 0000000000..386a5c08fc --- /dev/null +++ b/src/pybamm/models/full_battery_models/lithium_ion/basic_splitOCVR.py @@ -0,0 +1,100 @@ +# +# Equivalent Circuit Model with split OCV +# +import pybamm + + +class SplitOCVR(pybamm.BaseModel): + """Basic Equivalent Circuit Model that uses two OCV functions + for each electrode. This model is easily parameterizable with minimal parameters. + This class differs from the :class: pybamm.equivalent_circuit.Thevenin() due + to dual OCV functions to make up the voltage from each electrode. + + Parameters + ---------- + name: str, optional + The name of the model. + """ + + def __init__(self, name="ECM with split OCV"): + super().__init__(name) + + ###################### + # Variables + ###################### + # All variables are only time-dependent + # No domain definition needed + + theta_n = pybamm.Variable("Negative particle stoichiometry") + theta_p = pybamm.Variable("Positive particle stoichiometry") + Q = pybamm.Variable("Discharge capacity [A.h]") + V = pybamm.Variable("Voltage [V]") + + # model is isothermal + I = pybamm.FunctionParameter("Current function [A]", {"Time [s]": pybamm.t}) + + # Capacity equation + self.rhs[Q] = I / 3600 + self.initial_conditions[Q] = pybamm.Scalar(0) + + # Capacity in each electrode + Q_n = pybamm.Parameter("Negative electrode capacity [A.h]") + Q_p = pybamm.Parameter("Positive electrode capacity [A.h]") + + # State of charge electrode equations + theta_n_0 = pybamm.Parameter("Negative electrode initial stoichiometry") + theta_p_0 = pybamm.Parameter("Positive electrode initial stoichiometry") + self.rhs[theta_n] = -I / Q_n / 3600 + self.rhs[theta_p] = I / Q_p / 3600 + self.initial_conditions[theta_n] = theta_n_0 + self.initial_conditions[theta_p] = theta_p_0 + + # Resistance for IR expression + R = pybamm.Parameter("Ohmic resistance [Ohm]") + + # Open-circuit potential for each electrode + Un = pybamm.FunctionParameter( + "Negative electrode OCP [V]", {"Negative particle stoichiometry": theta_n} + ) + Up = pybamm.FunctionParameter( + "Positive electrode OCP [V]", {"Positive particle stoichiometry": theta_p} + ) + + # Voltage expression + V = Up - Un - I * R + + # Parameters for Voltage cutoff + voltage_high_cut = pybamm.Parameter("Upper voltage cut-off [V]") + voltage_low_cut = pybamm.Parameter("Lower voltage cut-off [V]") + + self.variables = { + "Negative particle stoichiometry": theta_n, + "Positive particle stoichiometry": theta_p, + "Current [A]": I, + "Discharge capacity [A.h]": Q, + "Voltage [V]": V, + "Times [s]": pybamm.t, + "Positive electrode OCP [V]": Up, + "Negative electrode OCP [V]": Un, + "Current function [A]": I, + } + + # Events specify points at which a solution should terminate + self.events += [ + pybamm.Event("Minimum voltage [V]", V - voltage_low_cut), + pybamm.Event("Maximum voltage [V]", voltage_high_cut - V), + pybamm.Event("Maximum Negative Electrode stoichiometry", 0.999 - theta_n), + pybamm.Event("Maximum Positive Electrode stoichiometry", 0.999 - theta_p), + pybamm.Event("Minimum Negative Electrode stoichiometry", theta_n - 0.0001), + pybamm.Event("Minimum Positive Electrode stoichiometry", theta_p - 0.0001), + ] + + @property + def default_quick_plot_variables(self): + return [ + "Voltage [V]", + ["Negative particle stoichiometry", "Positive particle stoichiometry"], + "Negative electrode OCP [V]", + "Positive electrode OCP [V]", + "Current [A]", + ] diff --git a/src/pybamm/parameters/parameter_values.py b/src/pybamm/parameters/parameter_values.py index 43c1ea17ce..0a0e49cd8f 100644 --- a/src/pybamm/parameters/parameter_values.py +++ b/src/pybamm/parameters/parameter_values.py @@ -35,15 +35,7 @@ class ParameterValues: """ - def __init__(self, values, chemistry=None): - if chemistry is not None: - raise ValueError( - "The 'chemistry' keyword argument has been deprecated. " - "Call `ParameterValues` with a dictionary dictionary of " - "parameter values, or the name of a parameter set (string), " - "as the single argument, e.g. `ParameterValues('Chen2020')`.", - ) - + def __init__(self, values): # add physical constants as default values self._dict_items = pybamm.FuzzyDict( { @@ -930,3 +922,9 @@ def print_evaluated_parameters(self, evaluated_parameters, output_file): file.write((s + " : {:10.4g}\n").format(name, value)) else: file.write((s + " : {:10.3E}\n").format(name, value)) + + def __contains__(self, key): + return key in self._dict_items + + def __iter__(self): + return iter(self._dict_items) diff --git a/src/pybamm/solvers/base_solver.py b/src/pybamm/solvers/base_solver.py index 1df9aef35f..a7fc79fe3a 100644 --- a/src/pybamm/solvers/base_solver.py +++ b/src/pybamm/solvers/base_solver.py @@ -86,6 +86,14 @@ def supports_interp(self): def root_method(self): return self._root_method + @property + def supports_parallel_solve(self): + return False + + @property + def requires_explicit_sensitivities(self): + return True + @root_method.setter def root_method(self, method): if method == "casadi": @@ -137,7 +145,7 @@ def set_up(self, model, inputs=None, t_eval=None, ics_only=False): # see if we need to form the explicit sensitivity equations calculate_sensitivities_explicit = ( - model.calculate_sensitivities and not isinstance(self, pybamm.IDAKLUSolver) + model.calculate_sensitivities and self.requires_explicit_sensitivities ) self._set_up_model_sensitivities_inplace( @@ -490,11 +498,7 @@ def _set_up_model_sensitivities_inplace( # if we have a mass matrix, we need to extend it def extend_mass_matrix(M): M_extend = [M.entries] * (num_parameters + 1) - M_extend_pybamm = pybamm.Matrix(block_diag(M_extend, format="csr")) - return M_extend_pybamm - - model.mass_matrix = extend_mass_matrix(model.mass_matrix) - model.mass_matrix = extend_mass_matrix(model.mass_matrix) + return pybamm.Matrix(block_diag(M_extend, format="csr")) model.mass_matrix = extend_mass_matrix(model.mass_matrix) @@ -896,17 +900,8 @@ def solve( pybamm.logger.verbose( f"Calling solver for {t_eval[start_index]} < t < {t_eval[end_index - 1]}" ) - ninputs = len(model_inputs_list) - if ninputs == 1: - new_solution = self._integrate( - model, - t_eval[start_index:end_index], - model_inputs_list[0], - t_interp=t_interp, - ) - new_solutions = [new_solution] - elif model.convert_to_format == "jax": - # Jax can parallelize over the inputs efficiently + if self.supports_parallel_solve: + # Jax and IDAKLU solver can accept a list of inputs new_solutions = self._integrate( model, t_eval[start_index:end_index], @@ -914,18 +909,28 @@ def solve( t_interp, ) else: - with mp.get_context(self._mp_context).Pool(processes=nproc) as p: - new_solutions = p.starmap( - self._integrate, - zip( - [model] * ninputs, - [t_eval[start_index:end_index]] * ninputs, - model_inputs_list, - [t_interp] * ninputs, - ), + ninputs = len(model_inputs_list) + if ninputs == 1: + new_solution = self._integrate( + model, + t_eval[start_index:end_index], + model_inputs_list[0], + t_interp=t_interp, ) - p.close() - p.join() + new_solutions = [new_solution] + else: + with mp.get_context(self._mp_context).Pool(processes=nproc) as p: + new_solutions = p.starmap( + self._integrate, + zip( + [model] * ninputs, + [t_eval[start_index:end_index]] * ninputs, + model_inputs_list, + [t_interp] * ninputs, + ), + ) + p.close() + p.join() # Setting the solve time for each segment. # pybamm.Solution.__add__ assumes attribute solve_time. solve_time = timer.time() @@ -995,7 +1000,7 @@ def solve( ) # Return solution(s) - if ninputs == 1: + if len(solutions) == 1: return solutions[0] else: return solutions @@ -1350,7 +1355,13 @@ def step( # Step pybamm.logger.verbose(f"Stepping for {t_start_shifted:.0f} < t < {t_end:.0f}") timer.reset() - solution = self._integrate(model, t_eval, model_inputs, t_interp) + + # API for _integrate is different for JaxSolver and IDAKLUSolver + if self.supports_parallel_solve: + solutions = self._integrate(model, t_eval, [model_inputs], t_interp) + solution = solutions[0] + else: + solution = self._integrate(model, t_eval, model_inputs, t_interp) solution.solve_time = timer.time() # Check if extrapolation occurred @@ -1478,8 +1489,12 @@ def check_extrapolation(self, solution, events): # second pass: check if the extrapolation events are within the tolerance last_state = solution.last_state - t = last_state.all_ts[0][0] - y = last_state.all_ys[0][:, 0] + if solution.t_event: + t = solution.t_event[0] + y = solution.y_event[:, 0] + else: + t = last_state.all_ts[0][0] + y = last_state.all_ys[0][:, 0] inputs = last_state.all_inputs[0] if isinstance(y, casadi.DM): diff --git a/src/pybamm/solvers/c_solvers/idaklu.cpp b/src/pybamm/solvers/c_solvers/idaklu.cpp index 3ef0194403..db7147feb2 100644 --- a/src/pybamm/solvers/c_solvers/idaklu.cpp +++ b/src/pybamm/solvers/c_solvers/idaklu.cpp @@ -9,6 +9,7 @@ #include #include "idaklu/idaklu_solver.hpp" +#include "idaklu/IDAKLUSolverGroup.hpp" #include "idaklu/IdakluJax.hpp" #include "idaklu/common.hpp" #include "idaklu/Expressions/Casadi/CasadiFunctions.hpp" @@ -26,15 +27,17 @@ casadi::Function generate_casadi_function(const std::string &data) namespace py = pybind11; PYBIND11_MAKE_OPAQUE(std::vector); +PYBIND11_MAKE_OPAQUE(std::vector); PYBIND11_MODULE(idaklu, m) { m.doc() = "sundials solvers"; // optional module docstring py::bind_vector>(m, "VectorNdArray"); + py::bind_vector>(m, "VectorSolution"); - py::class_(m, "IDAKLUSolver") - .def("solve", &IDAKLUSolver::solve, + py::class_(m, "IDAKLUSolverGroup") + .def("solve", &IDAKLUSolverGroup::solve, "perform a solve", py::arg("t_eval"), py::arg("t_interp"), @@ -43,8 +46,8 @@ PYBIND11_MODULE(idaklu, m) py::arg("inputs"), py::return_value_policy::take_ownership); - m.def("create_casadi_solver", &create_idaklu_solver, - "Create a casadi idaklu solver object", + m.def("create_casadi_solver_group", &create_idaklu_solver_group, + "Create a group of casadi idaklu solver objects", py::arg("number_of_states"), py::arg("number_of_parameters"), py::arg("rhs_alg"), @@ -70,8 +73,8 @@ PYBIND11_MODULE(idaklu, m) py::return_value_policy::take_ownership); #ifdef IREE_ENABLE - m.def("create_iree_solver", &create_idaklu_solver, - "Create a iree idaklu solver object", + m.def("create_iree_solver_group", &create_idaklu_solver_group, + "Create a group of iree idaklu solver objects", py::arg("number_of_states"), py::arg("number_of_parameters"), py::arg("rhs_alg"), diff --git a/src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolver.hpp b/src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolver.hpp index 29b451e6d3..379d64783a 100644 --- a/src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolver.hpp +++ b/src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolver.hpp @@ -2,7 +2,8 @@ #define PYBAMM_IDAKLU_CASADI_SOLVER_HPP #include "common.hpp" -#include "Solution.hpp" +#include "SolutionData.hpp" + /** * Abstract base class for solutions that can use different solvers and vector @@ -24,14 +25,17 @@ class IDAKLUSolver ~IDAKLUSolver() = default; /** - * @brief Abstract solver method that returns a Solution class + * @brief Abstract solver method that executes the solver */ - virtual Solution solve( - np_array t_eval_np, - np_array t_interp_np, - np_array y0_np, - np_array yp0_np, - np_array_dense inputs) = 0; + virtual SolutionData solve( + const std::vector &t_eval, + const std::vector &t_interp, + const realtype *y0, + const realtype *yp0, + const realtype *inputs, + bool save_adaptive_steps, + bool save_interp_steps + ) = 0; /** * Abstract method to initialize the solver, once vectors and solver classes diff --git a/src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolverGroup.cpp b/src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolverGroup.cpp new file mode 100644 index 0000000000..8a76d73cfe --- /dev/null +++ b/src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolverGroup.cpp @@ -0,0 +1,145 @@ +#include "IDAKLUSolverGroup.hpp" +#include +#include + +std::vector IDAKLUSolverGroup::solve( + np_array t_eval_np, + np_array t_interp_np, + np_array y0_np, + np_array yp0_np, + np_array inputs) { + DEBUG("IDAKLUSolverGroup::solve"); + + // If t_interp is empty, save all adaptive steps + bool save_adaptive_steps = t_interp_np.size() == 0; + + const realtype* t_eval_begin = t_eval_np.data(); + const realtype* t_eval_end = t_eval_begin + t_eval_np.size(); + const realtype* t_interp_begin = t_interp_np.data(); + const realtype* t_interp_end = t_interp_begin + t_interp_np.size(); + + // Process the time inputs + // 1. Get the sorted and unique t_eval vector + auto const t_eval = makeSortedUnique(t_eval_begin, t_eval_end); + + // 2.1. Get the sorted and unique t_interp vector + auto const t_interp_unique_sorted = makeSortedUnique(t_interp_begin, t_interp_end); + + // 2.2 Remove the t_eval values from t_interp + auto const t_interp_setdiff = setDiff(t_interp_unique_sorted.begin(), t_interp_unique_sorted.end(), t_eval_begin, t_eval_end); + + // 2.3 Finally, get the sorted and unique t_interp vector with t_eval values removed + auto const t_interp = makeSortedUnique(t_interp_setdiff.begin(), t_interp_setdiff.end()); + + int const number_of_evals = t_eval.size(); + int const number_of_interps = t_interp.size(); + + // setDiff removes entries of t_interp that overlap with + // t_eval, so we need to check if we need to interpolate any unique points. + // This is not the same as save_adaptive_steps since some entries of t_interp + // may be removed by setDiff + bool save_interp_steps = number_of_interps > 0; + + // 3. Check if the timestepping entries are valid + if (number_of_evals < 2) { + throw std::invalid_argument( + "t_eval must have at least 2 entries" + ); + } else if (save_interp_steps) { + if (t_interp.front() < t_eval.front()) { + throw std::invalid_argument( + "t_interp values must be greater than the smallest t_eval value: " + + std::to_string(t_eval.front()) + ); + } else if (t_interp.back() > t_eval.back()) { + throw std::invalid_argument( + "t_interp values must be less than the greatest t_eval value: " + + std::to_string(t_eval.back()) + ); + } + } + + auto n_coeffs = number_of_states + number_of_parameters * number_of_states; + + // check y0 and yp0 and inputs have the correct dimensions + if (y0_np.ndim() != 2) + throw std::domain_error("y0 has wrong number of dimensions. Expected 2 but got " + std::to_string(y0_np.ndim())); + if (yp0_np.ndim() != 2) + throw std::domain_error("yp0 has wrong number of dimensions. Expected 2 but got " + std::to_string(yp0_np.ndim())); + if (inputs.ndim() != 2) + throw std::domain_error("inputs has wrong number of dimensions. Expected 2 but got " + std::to_string(inputs.ndim())); + + auto number_of_groups = y0_np.shape()[0]; + + // check y0 and yp0 and inputs have the correct shape + if (y0_np.shape()[1] != n_coeffs) + throw std::domain_error( + "y0 has wrong number of cols. Expected " + std::to_string(n_coeffs) + + " but got " + std::to_string(y0_np.shape()[1])); + + if (yp0_np.shape()[1] != n_coeffs) + throw std::domain_error( + "yp0 has wrong number of cols. Expected " + std::to_string(n_coeffs) + + " but got " + std::to_string(yp0_np.shape()[1])); + + if (yp0_np.shape()[0] != number_of_groups) + throw std::domain_error( + "yp0 has wrong number of rows. Expected " + std::to_string(number_of_groups) + + " but got " + std::to_string(yp0_np.shape()[0])); + + if (inputs.shape()[0] != number_of_groups) + throw std::domain_error( + "inputs has wrong number of rows. Expected " + std::to_string(number_of_groups) + + " but got " + std::to_string(inputs.shape()[0])); + + const std::size_t solves_per_thread = number_of_groups / m_solvers.size(); + const std::size_t remainder_solves = number_of_groups % m_solvers.size(); + + const realtype *y0 = y0_np.data(); + const realtype *yp0 = yp0_np.data(); + const realtype *inputs_data = inputs.data(); + + std::vector results(number_of_groups); + + std::optional exception; + + omp_set_num_threads(m_solvers.size()); + #pragma omp parallel for + for (int i = 0; i < m_solvers.size(); i++) { + try { + for (int j = 0; j < solves_per_thread; j++) { + const std::size_t index = i * solves_per_thread + j; + const realtype *y = y0 + index * y0_np.shape(1); + const realtype *yp = yp0 + index * yp0_np.shape(1); + const realtype *input = inputs_data + index * inputs.shape(1); + results[index] = m_solvers[i]->solve(t_eval, t_interp, y, yp, input, save_adaptive_steps, save_interp_steps); + } + } catch (std::exception &e) { + // If an exception is thrown, we need to catch it and rethrow it outside the parallel region + #pragma omp critical + { + exception = e; + } + } + } + + if (exception.has_value()) { + py::set_error(PyExc_ValueError, exception->what()); + throw py::error_already_set(); + } + + for (int i = 0; i < remainder_solves; i++) { + const std::size_t index = number_of_groups - remainder_solves + i; + const realtype *y = y0 + index * y0_np.shape(1); + const realtype *yp = yp0 + index * yp0_np.shape(1); + const realtype *input = inputs_data + index * inputs.shape(1); + results[index] = m_solvers[i]->solve(t_eval, t_interp, y, yp, input, save_adaptive_steps, save_interp_steps); + } + + // create solutions (needs to be serial as we're using the Python GIL) + std::vector solutions(number_of_groups); + for (int i = 0; i < number_of_groups; i++) { + solutions[i] = results[i].generate_solution(); + } + return solutions; +} diff --git a/src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolverGroup.hpp b/src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolverGroup.hpp new file mode 100644 index 0000000000..609b3b6fca --- /dev/null +++ b/src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolverGroup.hpp @@ -0,0 +1,48 @@ +#ifndef PYBAMM_IDAKLU_SOLVER_GROUP_HPP +#define PYBAMM_IDAKLU_SOLVER_GROUP_HPP + +#include "IDAKLUSolver.hpp" +#include "common.hpp" + +/** + * @brief class for a group of solvers. + */ +class IDAKLUSolverGroup +{ +public: + + /** + * @brief Default constructor + */ + IDAKLUSolverGroup(std::vector> solvers, int number_of_states, int number_of_parameters): + m_solvers(std::move(solvers)), + number_of_states(number_of_states), + number_of_parameters(number_of_parameters) + {} + + // no copy constructor (unique_ptr cannot be copied) + IDAKLUSolverGroup(IDAKLUSolverGroup &) = delete; + + /** + * @brief Default destructor + */ + ~IDAKLUSolverGroup() = default; + + /** + * @brief solver method that returns a vector of Solutions + */ + std::vector solve( + np_array t_eval_np, + np_array t_interp_np, + np_array y0_np, + np_array yp0_np, + np_array inputs); + + + private: + std::vector> m_solvers; + int number_of_states; + int number_of_parameters; +}; + +#endif // PYBAMM_IDAKLU_SOLVER_GROUP_HPP diff --git a/src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolverOpenMP.hpp b/src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolverOpenMP.hpp index ca710fbff6..92eede3643 100644 --- a/src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolverOpenMP.hpp +++ b/src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolverOpenMP.hpp @@ -52,8 +52,9 @@ class IDAKLUSolverOpenMP : public IDAKLUSolver int const number_of_states; // cppcheck-suppress unusedStructMember int const number_of_parameters; // cppcheck-suppress unusedStructMember int const number_of_events; // cppcheck-suppress unusedStructMember + int number_of_timesteps; int precon_type; // cppcheck-suppress unusedStructMember - N_Vector yy, yp, avtol; // y, y', and absolute tolerance + N_Vector yy, yp, y_cache, avtol; // y, y', y cache vector, and absolute tolerance N_Vector *yyS; // cppcheck-suppress unusedStructMember N_Vector *ypS; // cppcheck-suppress unusedStructMember N_Vector id; // rhs_alg_id @@ -69,6 +70,7 @@ class IDAKLUSolverOpenMP : public IDAKLUSolver vector res_dvar_dp; bool const sensitivity; // cppcheck-suppress unusedStructMember bool const save_outputs_only; // cppcheck-suppress unusedStructMember + bool is_ODE; // cppcheck-suppress unusedStructMember int length_of_return_vector; // cppcheck-suppress unusedStructMember vector t; // cppcheck-suppress unusedStructMember vector> y; // cppcheck-suppress unusedStructMember @@ -106,12 +108,16 @@ class IDAKLUSolverOpenMP : public IDAKLUSolver /** * @brief The main solve method that solves for each variable and time step */ - Solution solve( - np_array t_eval_np, - np_array t_interp_np, - np_array y0_np, - np_array yp0_np, - np_array_dense inputs) override; + SolutionData solve( + const std::vector &t_eval, + const std::vector &t_interp, + const realtype *y0, + const realtype *yp0, + const realtype *inputs, + bool save_adaptive_steps, + bool save_interp_steps + ) override; + /** * @brief Concrete implementation of initialization method @@ -153,6 +159,32 @@ class IDAKLUSolverOpenMP : public IDAKLUSolver */ void PrintStats(); + /** + * @brief Set a consistent initialization for ODEs + */ + void ReinitializeIntegrator(const realtype& t_val); + + /** + * @brief Set a consistent initialization for the system of equations + */ + void ConsistentInitialization( + const realtype& t_val, + const realtype& t_next, + const int& icopt); + + /** + * @brief Set a consistent initialization for DAEs + */ + void ConsistentInitializationDAE( + const realtype& t_val, + const realtype& t_next, + const int& icopt); + + /** + * @brief Set a consistent initialization for ODEs + */ + void ConsistentInitializationODE(const realtype& t_val); + /** * @brief Extend the adaptive arrays by 1 */ diff --git a/src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolverOpenMP.inl b/src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolverOpenMP.inl index fd8eb38257..56f546facf 100644 --- a/src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolverOpenMP.inl +++ b/src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolverOpenMP.inl @@ -3,6 +3,7 @@ #include #include "common.hpp" +#include "SolutionData.hpp" template IDAKLUSolverOpenMP::IDAKLUSolverOpenMP( @@ -82,15 +83,30 @@ IDAKLUSolverOpenMP::IDAKLUSolverOpenMP( if (this->setup_opts.preconditioner != "none") { precon_type = SUN_PREC_LEFT; } + + // The default is to solve a DAE for generality. This may be changed + // to an ODE during the Initialize() call + is_ODE = false; } template void IDAKLUSolverOpenMP::AllocateVectors() { + DEBUG("IDAKLUSolverOpenMP::AllocateVectors (num_threads = " << setup_opts.num_threads << ")"); // Create vectors - yy = N_VNew_OpenMP(number_of_states, setup_opts.num_threads, sunctx); - yp = N_VNew_OpenMP(number_of_states, setup_opts.num_threads, sunctx); - avtol = N_VNew_OpenMP(number_of_states, setup_opts.num_threads, sunctx); - id = N_VNew_OpenMP(number_of_states, setup_opts.num_threads, sunctx); + if (setup_opts.num_threads == 1) { + yy = N_VNew_Serial(number_of_states, sunctx); + yp = N_VNew_Serial(number_of_states, sunctx); + y_cache = N_VNew_Serial(number_of_states, sunctx); + avtol = N_VNew_Serial(number_of_states, sunctx); + id = N_VNew_Serial(number_of_states, sunctx); + } else { + DEBUG("IDAKLUSolverOpenMP::AllocateVectors OpenMP"); + yy = N_VNew_OpenMP(number_of_states, setup_opts.num_threads, sunctx); + yp = N_VNew_OpenMP(number_of_states, setup_opts.num_threads, sunctx); + y_cache = N_VNew_OpenMP(number_of_states, setup_opts.num_threads, sunctx); + avtol = N_VNew_OpenMP(number_of_states, setup_opts.num_threads, sunctx); + id = N_VNew_OpenMP(number_of_states, setup_opts.num_threads, sunctx); + } } template @@ -279,9 +295,13 @@ void IDAKLUSolverOpenMP::Initialize() { realtype *id_val; id_val = N_VGetArrayPointer(id); - int ii; - for (ii = 0; ii < number_of_states; ii++) { + // Determine if the system is an ODE + is_ODE = number_of_states > 0; + for (int ii = 0; ii < number_of_states; ii++) { id_val[ii] = id_np_val[ii]; + // check if id_val[ii] approximately equals 1 (>0.999) handles + // cases where id_val[ii] is not exactly 1 due to numerical errors + is_ODE &= id_val[ii] > 0.999; } // Variable types: differential (1) and algebraic (0) @@ -290,6 +310,7 @@ void IDAKLUSolverOpenMP::Initialize() { template IDAKLUSolverOpenMP::~IDAKLUSolverOpenMP() { + DEBUG("IDAKLUSolverOpenMP::~IDAKLUSolverOpenMP"); // Free memory if (sensitivity) { IDASensFree(ida_mem); @@ -301,6 +322,7 @@ IDAKLUSolverOpenMP::~IDAKLUSolverOpenMP() { N_VDestroy(avtol); N_VDestroy(yy); N_VDestroy(yp); + N_VDestroy(y_cache); N_VDestroy(id); if (sensitivity) { @@ -313,63 +335,24 @@ IDAKLUSolverOpenMP::~IDAKLUSolverOpenMP() { } template -Solution IDAKLUSolverOpenMP::solve( - np_array t_eval_np, - np_array t_interp_np, - np_array y0_np, - np_array yp0_np, - np_array_dense inputs +SolutionData IDAKLUSolverOpenMP::solve( + const std::vector &t_eval, + const std::vector &t_interp, + const realtype *y0, + const realtype *yp0, + const realtype *inputs, + bool save_adaptive_steps, + bool save_interp_steps ) { DEBUG("IDAKLUSolver::solve"); + const int number_of_evals = t_eval.size(); + const int number_of_interps = t_interp.size(); - // If t_interp is empty, save all adaptive steps - bool save_adaptive_steps = t_interp_np.unchecked<1>().size() == 0; - - // Process the time inputs - // 1. Get the sorted and unique t_eval vector - auto const t_eval = makeSortedUnique(t_eval_np); - - // 2.1. Get the sorted and unique t_interp vector - auto const t_interp_unique_sorted = makeSortedUnique(t_interp_np); - - // 2.2 Remove the t_eval values from t_interp - auto const t_interp_setdiff = setDiff(t_interp_unique_sorted, t_eval); - - // 2.3 Finally, get the sorted and unique t_interp vector with t_eval values removed - auto const t_interp = makeSortedUnique(t_interp_setdiff); - - int const number_of_evals = t_eval.size(); - int const number_of_interps = t_interp.size(); - - // setDiff removes entries of t_interp that overlap with - // t_eval, so we need to check if we need to interpolate any unique points. - // This is not the same as save_adaptive_steps since some entries of t_interp - // may be removed by setDiff - bool save_interp_steps = number_of_interps > 0; - - // 3. Check if the timestepping entries are valid - if (number_of_evals < 2) { - throw std::invalid_argument( - "t_eval must have at least 2 entries" - ); - } else if (save_interp_steps) { - if (t_interp.front() < t_eval.front()) { - throw std::invalid_argument( - "t_interp values must be greater than the smallest t_eval value: " - + std::to_string(t_eval.front()) - ); - } else if (t_interp.back() > t_eval.back()) { - throw std::invalid_argument( - "t_interp values must be less than the greatest t_eval value: " - + std::to_string(t_eval.back()) - ); - } + if (t.size() < number_of_evals + number_of_interps) { + InitializeStorage(number_of_evals + number_of_interps); } - // Initialize length_of_return_vector, t, y, and yS - InitializeStorage(number_of_evals + number_of_interps); - int i_save = 0; realtype t0 = t_eval.front(); @@ -386,24 +369,11 @@ Solution IDAKLUSolverOpenMP::solve( t_interp_next = t_interp[0]; } - auto y0 = y0_np.unchecked<1>(); - auto yp0 = yp0_np.unchecked<1>(); auto n_coeffs = number_of_states + number_of_parameters * number_of_states; - if (y0.size() != n_coeffs) { - throw std::domain_error( - "y0 has wrong size. Expected " + std::to_string(n_coeffs) + - " but got " + std::to_string(y0.size())); - } else if (yp0.size() != n_coeffs) { - throw std::domain_error( - "yp0 has wrong size. Expected " + std::to_string(n_coeffs) + - " but got " + std::to_string(yp0.size())); - } - // set inputs - auto p_inputs = inputs.unchecked<2>(); for (int i = 0; i < functions->inputs.size(); i++) { - functions->inputs[i] = p_inputs(i, 0); + functions->inputs[i] = inputs[i]; } // Setup consistent initialization @@ -427,21 +397,16 @@ Solution IDAKLUSolverOpenMP::solve( SetSolverOptions(); - CheckErrors(IDAReInit(ida_mem, t0, yy, yp)); - if (sensitivity) { - CheckErrors(IDASensReInit(ida_mem, IDA_SIMULTANEOUS, yyS, ypS)); - } - // Prepare first time step i_eval = 1; realtype t_eval_next = t_eval[i_eval]; + // Consistent initialization + ReinitializeIntegrator(t0); int const init_type = solver_opts.init_all_y_ic ? IDA_Y_INIT : IDA_YA_YDP_INIT; if (solver_opts.calc_ic) { - DEBUG("IDACalcIC"); - // IDACalcIC will throw a warning if it fails to find initial conditions - IDACalcIC(ida_mem, init_type, t_eval_next); + ConsistentInitialization(t0, t_eval_next, init_type); } if (sensitivity) { @@ -521,12 +486,8 @@ Solution IDAKLUSolverOpenMP::solve( CheckErrors(IDASetStopTime(ida_mem, t_eval_next)); // Reinitialize the solver to deal with the discontinuity at t = t_val. - // We must reinitialize the algebraic terms, so do not use init_type. - IDACalcIC(ida_mem, IDA_YA_YDP_INIT, t_eval_next); - CheckErrors(IDAReInit(ida_mem, t_val, yy, yp)); - if (sensitivity) { - CheckErrors(IDASensReInit(ida_mem, IDA_SIMULTANEOUS, yyS, ypS)); - } + ReinitializeIntegrator(t_val); + ConsistentInitialization(t_val, t_eval_next, IDA_YA_YDP_INIT); } t_prev = t_val; @@ -543,8 +504,8 @@ Solution IDAKLUSolverOpenMP::solve( PrintStats(); } - int const number_of_timesteps = i_save; - int count; + // store number of timesteps so we can generate the solution later + number_of_timesteps = i_save; // Copy the data to return as numpy arrays @@ -554,23 +515,9 @@ Solution IDAKLUSolverOpenMP::solve( t_return[i] = t[i]; } - py::capsule free_t_when_done( - t_return, - [](void *f) { - realtype *vect = reinterpret_cast(f); - delete[] vect; - } - ); - - np_array t_ret = np_array( - number_of_timesteps, - &t_return[0], - free_t_when_done - ); - // States, y realtype *y_return = new realtype[number_of_timesteps * length_of_return_vector]; - count = 0; + int count = 0; for (size_t i = 0; i < number_of_timesteps; i++) { for (size_t j = 0; j < length_of_return_vector; j++) { y_return[count] = y[i][j]; @@ -578,20 +525,6 @@ Solution IDAKLUSolverOpenMP::solve( } } - py::capsule free_y_when_done( - y_return, - [](void *f) { - realtype *vect = reinterpret_cast(f); - delete[] vect; - } - ); - - np_array y_ret = np_array( - number_of_timesteps * length_of_return_vector, - &y_return[0], - free_y_when_done - ); - // Sensitivity states, yS // Note: Ordering of vector is different if computing outputs vs returning // the complete state vector @@ -614,43 +547,7 @@ Solution IDAKLUSolverOpenMP::solve( } } - py::capsule free_yS_when_done( - yS_return, - [](void *f) { - realtype *vect = reinterpret_cast(f); - delete[] vect; - } - ); - - np_array yS_ret = np_array( - vector { - arg_sens0, - arg_sens1, - arg_sens2 - }, - &yS_return[0], - free_yS_when_done - ); - - // Final state slice, yterm - py::capsule free_yterm_when_done( - yterm_return, - [](void *f) { - realtype *vect = reinterpret_cast(f); - delete[] vect; - } - ); - - np_array y_term = np_array( - length_of_final_sv_slice, - &yterm_return[0], - free_yterm_when_done - ); - - // Store the solution - Solution sol(retval, t_ret, y_ret, yS_ret, y_term); - - return sol; + return SolutionData(retval, number_of_timesteps, length_of_return_vector, arg_sens0, arg_sens1, arg_sens2, length_of_final_sv_slice, t_return, y_return, yS_return, yterm_return); } template @@ -668,6 +565,53 @@ void IDAKLUSolverOpenMP::ExtendAdaptiveArrays() { } } +template +void IDAKLUSolverOpenMP::ReinitializeIntegrator(const realtype& t_val) { + DEBUG("IDAKLUSolver::ReinitializeIntegrator"); + CheckErrors(IDAReInit(ida_mem, t_val, yy, yp)); + if (sensitivity) { + CheckErrors(IDASensReInit(ida_mem, IDA_SIMULTANEOUS, yyS, ypS)); + } +} + +template +void IDAKLUSolverOpenMP::ConsistentInitialization( + const realtype& t_val, + const realtype& t_next, + const int& icopt) { + DEBUG("IDAKLUSolver::ConsistentInitialization"); + + if (is_ODE && icopt == IDA_YA_YDP_INIT) { + ConsistentInitializationODE(t_val); + } else { + ConsistentInitializationDAE(t_val, t_next, icopt); + } +} + +template +void IDAKLUSolverOpenMP::ConsistentInitializationDAE( + const realtype& t_val, + const realtype& t_next, + const int& icopt) { + DEBUG("IDAKLUSolver::ConsistentInitializationDAE"); + IDACalcIC(ida_mem, icopt, t_next); +} + +template +void IDAKLUSolverOpenMP::ConsistentInitializationODE( + const realtype& t_val) { + DEBUG("IDAKLUSolver::ConsistentInitializationODE"); + + // For ODEs where the mass matrix M = I, we can simplify the problem + // by analytically computing the yp values. If we take our implicit + // DAE system res(t,y,yp) = f(t,y) - I*yp, then yp = res(t,y,0). This + // avoids an expensive call to IDACalcIC. + realtype *y_cache_val = N_VGetArrayPointer(y_cache); + std::memset(y_cache_val, 0, number_of_states * sizeof(realtype)); + // Overwrite yp + residual_eval(t_val, yy, y_cache, yp, functions.get()); +} + template void IDAKLUSolverOpenMP::SetStep( realtype &tval, @@ -828,9 +772,8 @@ void IDAKLUSolverOpenMP::SetStepOutputSensitivities( template void IDAKLUSolverOpenMP::CheckErrors(int const & flag) { if (flag < 0) { - auto message = (std::string("IDA failed with flag ") + std::to_string(flag)).c_str(); - py::set_error(PyExc_ValueError, message); - throw py::error_already_set(); + auto message = std::string("IDA failed with flag ") + std::to_string(flag); + throw std::runtime_error(message.c_str()); } } diff --git a/src/pybamm/solvers/c_solvers/idaklu/Options.cpp b/src/pybamm/solvers/c_solvers/idaklu/Options.cpp index b6a33e016e..51544040ee 100644 --- a/src/pybamm/solvers/c_solvers/idaklu/Options.cpp +++ b/src/pybamm/solvers/c_solvers/idaklu/Options.cpp @@ -1,6 +1,7 @@ #include "Options.hpp" #include #include +#include using namespace std::string_literals; @@ -11,9 +12,25 @@ SetupOptions::SetupOptions(py::dict &py_opts) precon_half_bandwidth(py_opts["precon_half_bandwidth"].cast()), precon_half_bandwidth_keep(py_opts["precon_half_bandwidth_keep"].cast()), num_threads(py_opts["num_threads"].cast()), + num_solvers(py_opts["num_solvers"].cast()), linear_solver(py_opts["linear_solver"].cast()), linsol_max_iterations(py_opts["linsol_max_iterations"].cast()) { + if (num_solvers > num_threads) + { + throw std::domain_error( + "Number of solvers must be less than or equal to the number of threads" + ); + } + + // input num_threads is the overall number of threads to use. num_solvers of these + // will be used to run solvers in parallel, leaving num_threads / num_solvers threads + // to be used by each solver. From here on num_threads is the number of threads to be used by each solver + num_threads = static_cast( + std::floor( + static_cast(num_threads) / static_cast(num_solvers) + ) + ); using_sparse_matrix = true; using_banded_matrix = false; diff --git a/src/pybamm/solvers/c_solvers/idaklu/Options.hpp b/src/pybamm/solvers/c_solvers/idaklu/Options.hpp index 66a175cfff..d0c0c1d766 100644 --- a/src/pybamm/solvers/c_solvers/idaklu/Options.hpp +++ b/src/pybamm/solvers/c_solvers/idaklu/Options.hpp @@ -15,6 +15,7 @@ struct SetupOptions { int precon_half_bandwidth; int precon_half_bandwidth_keep; int num_threads; + int num_solvers; // IDALS linear solver interface std::string linear_solver; // klu, lapack, spbcg int linsol_max_iterations; diff --git a/src/pybamm/solvers/c_solvers/idaklu/Solution.hpp b/src/pybamm/solvers/c_solvers/idaklu/Solution.hpp index 72d48fa644..a43e6a7174 100644 --- a/src/pybamm/solvers/c_solvers/idaklu/Solution.hpp +++ b/src/pybamm/solvers/c_solvers/idaklu/Solution.hpp @@ -9,6 +9,11 @@ class Solution { public: + /** + * @brief Default Constructor + */ + Solution() = default; + /** * @brief Constructor */ @@ -17,6 +22,11 @@ class Solution { } + /** + * @brief Default copy from another Solution + */ + Solution(const Solution &solution) = default; + int flag; np_array t; np_array y; diff --git a/src/pybamm/solvers/c_solvers/idaklu/SolutionData.cpp b/src/pybamm/solvers/c_solvers/idaklu/SolutionData.cpp new file mode 100644 index 0000000000..00c2ddbccc --- /dev/null +++ b/src/pybamm/solvers/c_solvers/idaklu/SolutionData.cpp @@ -0,0 +1,67 @@ +#include "SolutionData.hpp" + +Solution SolutionData::generate_solution() { + py::capsule free_t_when_done( + t_return, + [](void *f) { + realtype *vect = reinterpret_cast(f); + delete[] vect; + } + ); + + np_array t_ret = np_array( + number_of_timesteps, + &t_return[0], + free_t_when_done + ); + + py::capsule free_y_when_done( + y_return, + [](void *f) { + realtype *vect = reinterpret_cast(f); + delete[] vect; + } + ); + + np_array y_ret = np_array( + number_of_timesteps * length_of_return_vector, + &y_return[0], + free_y_when_done + ); + + py::capsule free_yS_when_done( + yS_return, + [](void *f) { + realtype *vect = reinterpret_cast(f); + delete[] vect; + } + ); + + np_array yS_ret = np_array( + std::vector { + arg_sens0, + arg_sens1, + arg_sens2 + }, + &yS_return[0], + free_yS_when_done + ); + + // Final state slice, yterm + py::capsule free_yterm_when_done( + yterm_return, + [](void *f) { + realtype *vect = reinterpret_cast(f); + delete[] vect; + } + ); + + np_array y_term = np_array( + length_of_final_sv_slice, + &yterm_return[0], + free_yterm_when_done + ); + + // Store the solution + return Solution(flag, t_ret, y_ret, yS_ret, y_term); +} diff --git a/src/pybamm/solvers/c_solvers/idaklu/SolutionData.hpp b/src/pybamm/solvers/c_solvers/idaklu/SolutionData.hpp new file mode 100644 index 0000000000..815e41daca --- /dev/null +++ b/src/pybamm/solvers/c_solvers/idaklu/SolutionData.hpp @@ -0,0 +1,73 @@ +#ifndef PYBAMM_IDAKLU_SOLUTION_DATA_HPP +#define PYBAMM_IDAKLU_SOLUTION_DATA_HPP + + +#include "common.hpp" +#include "Solution.hpp" + +/** + * @brief SolutionData class. Contains all the data needed to create a Solution + */ +class SolutionData +{ + public: + /** + * @brief Default constructor + */ + SolutionData() = default; + + /** + * @brief constructor using fields + */ + SolutionData( + int flag, + int number_of_timesteps, + int length_of_return_vector, + int arg_sens0, + int arg_sens1, + int arg_sens2, + int length_of_final_sv_slice, + realtype *t_return, + realtype *y_return, + realtype *yS_return, + realtype *yterm_return): + flag(flag), + number_of_timesteps(number_of_timesteps), + length_of_return_vector(length_of_return_vector), + arg_sens0(arg_sens0), + arg_sens1(arg_sens1), + arg_sens2(arg_sens2), + length_of_final_sv_slice(length_of_final_sv_slice), + t_return(t_return), + y_return(y_return), + yS_return(yS_return), + yterm_return(yterm_return) + {} + + + /** + * @brief Default copy from another SolutionData + */ + SolutionData(const SolutionData &solution_data) = default; + + /** + * @brief Create a solution object from this data + */ + Solution generate_solution(); + +private: + + int flag; + int number_of_timesteps; + int length_of_return_vector; + int arg_sens0; + int arg_sens1; + int arg_sens2; + int length_of_final_sv_slice; + realtype *t_return; + realtype *y_return; + realtype *yS_return; + realtype *yterm_return; +}; + +#endif // PYBAMM_IDAKLU_SOLUTION_DATA_HPP diff --git a/src/pybamm/solvers/c_solvers/idaklu/common.cpp b/src/pybamm/solvers/c_solvers/idaklu/common.cpp index bf38acc56a..161c14f340 100644 --- a/src/pybamm/solvers/c_solvers/idaklu/common.cpp +++ b/src/pybamm/solvers/c_solvers/idaklu/common.cpp @@ -11,21 +11,9 @@ std::vector numpy2realtype(const np_array& input_np) { return output; } -std::vector setDiff(const std::vector& A, const std::vector& B) { - std::vector result; - if (!(A.empty())) { - std::set_difference(A.begin(), A.end(), B.begin(), B.end(), std::back_inserter(result)); - } - return result; -} -std::vector makeSortedUnique(const std::vector& input) { - std::unordered_set uniqueSet(input.begin(), input.end()); // Remove duplicates - std::vector uniqueVector(uniqueSet.begin(), uniqueSet.end()); // Convert to vector - std::sort(uniqueVector.begin(), uniqueVector.end()); // Sort the vector - return uniqueVector; -} std::vector makeSortedUnique(const np_array& input_np) { - return makeSortedUnique(numpy2realtype(input_np)); + const auto input_vec = numpy2realtype(input_np); + return makeSortedUnique(input_vec.begin(), input_vec.end()); } diff --git a/src/pybamm/solvers/c_solvers/idaklu/common.hpp b/src/pybamm/solvers/c_solvers/idaklu/common.hpp index 3289326541..58be90932e 100644 --- a/src/pybamm/solvers/c_solvers/idaklu/common.hpp +++ b/src/pybamm/solvers/c_solvers/idaklu/common.hpp @@ -31,8 +31,8 @@ #include namespace py = pybind11; -using np_array = py::array_t; -using np_array_dense = py::array_t; +// note: we rely on c_style ordering for numpy arrays so don't change this! +using np_array = py::array_t; using np_array_int = py::array_t; /** @@ -83,12 +83,25 @@ std::vector numpy2realtype(const np_array& input_np); /** * @brief Utility function to compute the set difference of two vectors */ -std::vector setDiff(const std::vector& A, const std::vector& B); +template +std::vector setDiff(const T1 a_begin, const T1 a_end, const T2 b_begin, const T2 b_end) { + std::vector result; + if (std::distance(a_begin, a_end) > 0) { + std::set_difference(a_begin, a_end, b_begin, b_end, std::back_inserter(result)); + } + return result; +} /** * @brief Utility function to make a sorted and unique vector */ -std::vector makeSortedUnique(const std::vector& input); +template +std::vector makeSortedUnique(const T input_begin, const T input_end) { + std::unordered_set uniqueSet(input_begin, input_end); // Remove duplicates + std::vector uniqueVector(uniqueSet.begin(), uniqueSet.end()); // Convert to vector + std::sort(uniqueVector.begin(), uniqueVector.end()); // Sort the vector + return uniqueVector; +} std::vector makeSortedUnique(const np_array& input_np); @@ -126,8 +139,7 @@ std::vector makeSortedUnique(const np_array& input_np); } \ std::cout << "]" << std::endl; } -#define DEBUG_v(v, M) {\ - int N = 2; \ +#define DEBUG_v(v, N) {\ std::cout << #v << "[n=" << N << "] = ["; \ for (int i = 0; i < N; i++) { \ std::cout << v[i]; \ diff --git a/src/pybamm/solvers/c_solvers/idaklu/idaklu_solver.hpp b/src/pybamm/solvers/c_solvers/idaklu/idaklu_solver.hpp index ce1765aa82..dcc1e4f8cc 100644 --- a/src/pybamm/solvers/c_solvers/idaklu/idaklu_solver.hpp +++ b/src/pybamm/solvers/c_solvers/idaklu/idaklu_solver.hpp @@ -2,6 +2,7 @@ #define PYBAMM_CREATE_IDAKLU_SOLVER_HPP #include "IDAKLUSolverOpenMP_solvers.hpp" +#include "IDAKLUSolverGroup.hpp" #include #include @@ -12,52 +13,21 @@ */ template IDAKLUSolver *create_idaklu_solver( - int number_of_states, + std::unique_ptr functions, int number_of_parameters, - const typename ExprSet::BaseFunctionType &rhs_alg, - const typename ExprSet::BaseFunctionType &jac_times_cjmass, const np_array_int &jac_times_cjmass_colptrs, const np_array_int &jac_times_cjmass_rowvals, const int jac_times_cjmass_nnz, const int jac_bandwidth_lower, const int jac_bandwidth_upper, - const typename ExprSet::BaseFunctionType &jac_action, - const typename ExprSet::BaseFunctionType &mass_action, - const typename ExprSet::BaseFunctionType &sens, - const typename ExprSet::BaseFunctionType &events, const int number_of_events, np_array rhs_alg_id, np_array atol_np, double rel_tol, int inputs_length, - const std::vector& var_fcns, - const std::vector& dvar_dy_fcns, - const std::vector& dvar_dp_fcns, - py::dict py_opts + SolverOptions solver_opts, + SetupOptions setup_opts ) { - auto setup_opts = SetupOptions(py_opts); - auto solver_opts = SolverOptions(py_opts); - auto functions = std::make_unique( - rhs_alg, - jac_times_cjmass, - jac_times_cjmass_nnz, - jac_bandwidth_lower, - jac_bandwidth_upper, - jac_times_cjmass_rowvals, - jac_times_cjmass_colptrs, - inputs_length, - jac_action, - mass_action, - sens, - events, - number_of_states, - number_of_events, - number_of_parameters, - var_fcns, - dvar_dy_fcns, - dvar_dp_fcns, - setup_opts - ); IDAKLUSolver *idakluSolver = nullptr; @@ -189,4 +159,88 @@ IDAKLUSolver *create_idaklu_solver( return idakluSolver; } +/** + * @brief Create a group of solvers using create_idaklu_solver + */ +template +IDAKLUSolverGroup *create_idaklu_solver_group( + int number_of_states, + int number_of_parameters, + const typename ExprSet::BaseFunctionType &rhs_alg, + const typename ExprSet::BaseFunctionType &jac_times_cjmass, + const np_array_int &jac_times_cjmass_colptrs, + const np_array_int &jac_times_cjmass_rowvals, + const int jac_times_cjmass_nnz, + const int jac_bandwidth_lower, + const int jac_bandwidth_upper, + const typename ExprSet::BaseFunctionType &jac_action, + const typename ExprSet::BaseFunctionType &mass_action, + const typename ExprSet::BaseFunctionType &sens, + const typename ExprSet::BaseFunctionType &events, + const int number_of_events, + np_array rhs_alg_id, + np_array atol_np, + double rel_tol, + int inputs_length, + const std::vector& var_fcns, + const std::vector& dvar_dy_fcns, + const std::vector& dvar_dp_fcns, + py::dict py_opts +) { + auto setup_opts = SetupOptions(py_opts); + auto solver_opts = SolverOptions(py_opts); + + + std::vector> solvers; + for (int i = 0; i < setup_opts.num_solvers; i++) { + // Note: we can't copy an ExprSet as it contains raw pointers to the functions + // So we create it in the loop + auto functions = std::make_unique( + rhs_alg, + jac_times_cjmass, + jac_times_cjmass_nnz, + jac_bandwidth_lower, + jac_bandwidth_upper, + jac_times_cjmass_rowvals, + jac_times_cjmass_colptrs, + inputs_length, + jac_action, + mass_action, + sens, + events, + number_of_states, + number_of_events, + number_of_parameters, + var_fcns, + dvar_dy_fcns, + dvar_dp_fcns, + setup_opts + ); + solvers.emplace_back( + std::unique_ptr( + create_idaklu_solver( + std::move(functions), + number_of_parameters, + jac_times_cjmass_colptrs, + jac_times_cjmass_rowvals, + jac_times_cjmass_nnz, + jac_bandwidth_lower, + jac_bandwidth_upper, + number_of_events, + rhs_alg_id, + atol_np, + rel_tol, + inputs_length, + solver_opts, + setup_opts + ) + ) + ); + } + + return new IDAKLUSolverGroup(std::move(solvers), number_of_states, number_of_parameters); +} + + + #endif // PYBAMM_CREATE_IDAKLU_SOLVER_HPP diff --git a/src/pybamm/solvers/idaklu_solver.py b/src/pybamm/solvers/idaklu_solver.py index 08f86b3264..2b2d852697 100644 --- a/src/pybamm/solvers/idaklu_solver.py +++ b/src/pybamm/solvers/idaklu_solver.py @@ -29,7 +29,8 @@ idaklu = importlib.util.module_from_spec(idaklu_spec) if idaklu_spec.loader: idaklu_spec.loader.exec_module(idaklu) - except ImportError: # pragma: no cover + except ImportError as e: # pragma: no cover + print(f"Error loading idaklu: {e}") idaklu_spec = None @@ -78,8 +79,10 @@ class IDAKLUSolver(pybamm.BaseSolver): options = { # Print statistics of the solver after every solve "print_stats": False, - # Number of threads available for OpenMP + # Number of threads available for OpenMP (must be greater than or equal to `num_solvers`) "num_threads": 1, + # Number of solvers to use in parallel (for solving multiple sets of input parameters in parallel) + "num_solvers": num_threads, # Evaluation engine to use for jax, can be 'jax'(native) or 'iree' "jax_evaluator": "jax", ## Linear solver interface @@ -182,6 +185,7 @@ def __init__( "precon_half_bandwidth": 5, "precon_half_bandwidth_keep": 5, "num_threads": 1, + "num_solvers": 1, "jax_evaluator": "jax", "linear_solver": "SUNLinSol_KLU", "linsol_max_iterations": 5, @@ -209,6 +213,8 @@ def __init__( if options is None: options = default_options else: + if "num_threads" in options and "num_solvers" not in options: + options["num_solvers"] = options["num_threads"] for key, value in default_options.items(): if key not in options: options[key] = value @@ -443,7 +449,7 @@ def inputs_to_dict(inputs): if model.convert_to_format == "casadi": # Serialize casadi functions - idaklu_solver_fcn = idaklu.create_casadi_solver + idaklu_solver_fcn = idaklu.create_casadi_solver_group rhs_algebraic = idaklu.generate_function(rhs_algebraic.serialize()) jac_times_cjmass = idaklu.generate_function(jac_times_cjmass.serialize()) jac_rhs_algebraic_action = idaklu.generate_function( @@ -457,7 +463,7 @@ def inputs_to_dict(inputs): and self._options["jax_evaluator"] == "iree" ): # Convert Jax functions to MLIR (also, demote to single precision) - idaklu_solver_fcn = idaklu.create_iree_solver + idaklu_solver_fcn = idaklu.create_iree_solver_group pybamm.demote_expressions_to_32bit = True if pybamm.demote_expressions_to_32bit: warnings.warn( @@ -726,7 +732,15 @@ def _check_mlir_conversion(self, name, mlir: str): def _demote_64_to_32(self, x: pybamm.EvaluatorJax): return pybamm.EvaluatorJax._demote_64_to_32(x) - def _integrate(self, model, t_eval, inputs_dict=None, t_interp=None): + @property + def supports_parallel_solve(self): + return True + + @property + def requires_explicit_sensitivities(self): + return False + + def _integrate(self, model, t_eval, inputs_list=None, t_interp=None): """ Solve a DAE model defined by residuals with initial conditions y0. @@ -736,22 +750,30 @@ def _integrate(self, model, t_eval, inputs_dict=None, t_interp=None): The model whose solution to calculate. t_eval : numeric type The times at which to stop the integration due to a discontinuity in time. - inputs_dict : dict, optional + inputs_list: list of dict, optional Any input parameters to pass to the model when solving. t_interp : None, list or ndarray, optional The times (in seconds) at which to interpolate the solution. Defaults to `None`, which returns the adaptive time-stepping times. """ - inputs_dict = inputs_dict or {} - # stack inputs - if inputs_dict: - arrays_to_stack = [np.array(x).reshape(-1, 1) for x in inputs_dict.values()] - inputs = np.vstack(arrays_to_stack) + inputs_list = inputs_list or [{}] + + # stack inputs so that they are a 2D array of shape (number_of_inputs, number_of_parameters) + if inputs_list and inputs_list[0]: + inputs = np.vstack( + [ + np.hstack([np.array(x).reshape(-1) for x in inputs_dict.values()]) + for inputs_dict in inputs_list + ] + ) else: inputs = np.array([[]]) - y0full = model.y0full - ydot0full = model.ydot0full + # stack y0full and ydot0full so they are a 2D array of shape (number_of_inputs, number_of_states + number_of_parameters * number_of_states) + # note that y0full and ydot0full are currently 1D arrays (i.e. independent of inputs), but in the future we will support + # different initial conditions for different inputs (see https://github.com/pybamm-team/PyBaMM/pull/4260). For now we just repeat the same initial conditions for each input + y0full = np.vstack([model.y0full] * len(inputs_list)) + ydot0full = np.vstack([model.ydot0full] * len(inputs_list)) atol = getattr(model, "atol", self.atol) atol = self._check_atol_type(atol, y0full.size) @@ -761,7 +783,7 @@ def _integrate(self, model, t_eval, inputs_dict=None, t_interp=None): model.convert_to_format == "jax" and self._options["jax_evaluator"] == "iree" ): - sol = self._setup["solver"].solve( + solns = self._setup["solver"].solve( t_eval, t_interp, y0full, @@ -773,6 +795,12 @@ def _integrate(self, model, t_eval, inputs_dict=None, t_interp=None): raise pybamm.SolverError("Unsupported IDAKLU solver configuration.") integration_time = timer.time() + return [ + self._post_process_solution(soln, model, integration_time, inputs_dict) + for soln, inputs_dict in zip(solns, inputs_list) + ] + + def _post_process_solution(self, sol, model, integration_time, inputs_dict): number_of_sensitivity_parameters = self._setup[ "number_of_sensitivity_parameters" ] @@ -964,7 +992,7 @@ def _rhs_dot_consistent_initialization(self, y0, model, time, inputs_dict): rhs0 = rhs_alg0[: model.len_rhs] - # for the differential terms, ydot = -M^-1 * (rhs) + # for the differential terms, ydot = M^-1 * (rhs) ydot0[: model.len_rhs] = model.mass_matrix_inv.entries @ rhs0 return ydot0 diff --git a/src/pybamm/solvers/jax_bdf_solver.py b/src/pybamm/solvers/jax_bdf_solver.py index 6f0c62b9a8..3db82ca0da 100644 --- a/src/pybamm/solvers/jax_bdf_solver.py +++ b/src/pybamm/solvers/jax_bdf_solver.py @@ -111,68 +111,63 @@ def fun_bind_inputs(y, t): jac_bind_inputs = jax.jacfwd(fun_bind_inputs, argnums=0) - t0 = t_eval[0] - h0 = t_eval[1] - t0 - + t0, h0 = t_eval[0], t_eval[1] - t_eval[0] stepper = _bdf_init( fun_bind_inputs, jac_bind_inputs, mass, t0, y0, h0, rtol, atol ) - i = 0 y_out = jnp.empty((len(t_eval), len(y0)), dtype=y0.dtype) - init_state = [stepper, t_eval, i, y_out] - def cond_fun(state): - _, t_eval, i, _ = state + _, _, i, _ = state return i < len(t_eval) def body_fun(state): stepper, t_eval, i, y_out = state stepper = _bdf_step(stepper, fun_bind_inputs, jac_bind_inputs) - index = jnp.searchsorted(t_eval, stepper.t) - index = index.astype( - "int" + t_eval.dtype.name[-2:] - ) # Coerce index to correct type + index = jnp.searchsorted(t_eval, stepper.t).astype(jnp.int32) - def for_body(j, y_out): - t = t_eval[j] - y_out = y_out.at[jnp.index_exp[j, :]].set(_bdf_interpolate(stepper, t)) - return y_out + def interpolate_and_update(j, y_out): + y = _bdf_interpolate(stepper, t_eval[j]) + return y_out.at[j].set(y) - y_out = jax.lax.fori_loop(i, index, for_body, y_out) - return [stepper, t_eval, index, y_out] + y_out = jax.lax.fori_loop(i, index, interpolate_and_update, y_out) + return stepper, t_eval, index, y_out + + init_state = (stepper, t_eval, 0, y_out) + _, _, _, y_out = jax.lax.while_loop(cond_fun, body_fun, init_state) - stepper, t_eval, i, y_out = jax.lax.while_loop(cond_fun, body_fun, init_state) return y_out - BDFInternalStates = [ - "t", - "atol", - "rtol", - "M", - "newton_tol", - "order", - "h", - "n_equal_steps", - "D", - "y0", - "scale_y0", - "kappa", - "gamma", - "alpha", - "c", - "error_const", - "J", - "LU", - "U", - "psi", - "n_function_evals", - "n_jacobian_evals", - "n_lu_decompositions", - "n_steps", - "consistent_y0_failed", - ] - BDFState = collections.namedtuple("BDFState", BDFInternalStates) + BDFState = collections.namedtuple( + "BDFState", + [ + "t", + "atol", + "rtol", + "M", + "newton_tol", + "order", + "h", + "n_equal_steps", + "D", + "y0", + "scale_y0", + "kappa", + "gamma", + "alpha", + "c", + "error_const", + "J", + "LU", + "U", + "psi", + "n_function_evals", + "n_jacobian_evals", + "n_lu_decompositions", + "n_steps", + "consistent_y0_failed", + ], + ) jax.tree_util.register_pytree_node( BDFState, lambda xs: (tuple(xs), None), lambda _, xs: BDFState(*xs) @@ -211,62 +206,70 @@ def _bdf_init(fun, jac, mass, t0, y0, h0, rtol, atol): absolute tolerance for the solver """ - state = {} - state["t"] = t0 - state["atol"] = atol - state["rtol"] = rtol - state["M"] = mass EPS = jnp.finfo(y0.dtype).eps - state["newton_tol"] = jnp.maximum(10 * EPS / rtol, jnp.minimum(0.03, rtol**0.5)) + # Scaling and tolerance initialisation scale_y0 = atol + rtol * jnp.abs(y0) + newton_tol = jnp.maximum(10 * EPS / rtol, jnp.minimum(0.03, rtol**0.5)) + y0, not_converged = _select_initial_conditions( - fun, mass, t0, y0, state["newton_tol"], scale_y0 + fun, mass, t0, y0, newton_tol, scale_y0 ) - state["consistent_y0_failed"] = not_converged + # Compute initial function and step size f0 = fun(y0, t0) - order = 1 - state["order"] = order - state["h"] = _select_initial_step(atol, rtol, fun, t0, y0, f0, h0) - state["n_equal_steps"] = 0 + h = _select_initial_step(atol, rtol, fun, t0, y0, f0, h0) + + # Initialise the difference matrix, D D = jnp.empty((MAX_ORDER + 1, len(y0)), dtype=y0.dtype) D = D.at[jnp.index_exp[0, :]].set(y0) - D = D.at[jnp.index_exp[1, :]].set(f0 * state["h"]) - state["D"] = D - state["y0"] = y0 - state["scale_y0"] = scale_y0 + D = D.at[jnp.index_exp[1, :]].set(f0 * h) # kappa values for difference orders, taken from Table 1 of [1] kappa = jnp.array([0, -0.1850, -1 / 9, -0.0823, -0.0415, 0]) gamma = jnp.hstack((0, jnp.cumsum(1 / jnp.arange(1, MAX_ORDER + 1)))) alpha = 1.0 / ((1 - kappa) * gamma) - c = state["h"] * alpha[order] + c = h * alpha[1] error_const = kappa * gamma + 1 / jnp.arange(1, MAX_ORDER + 2) - state["kappa"] = kappa - state["gamma"] = gamma - state["alpha"] = alpha - state["c"] = c - state["error_const"] = error_const - + # Jacobian and LU decomp J = jac(y0, t0) - state["J"] = J - - state["LU"] = jax.scipy.linalg.lu_factor(state["M"] - c * J) - - state["U"] = _compute_R(order, 1) - state["psi"] = None - - state["n_function_evals"] = 2 - state["n_jacobian_evals"] = 1 - state["n_lu_decompositions"] = 1 - state["n_steps"] = 0 + LU = jax.scipy.linalg.lu_factor(mass - c * J) + U = _compute_R(1, 1) # Order 1 + + # Create initial BDFState + state = BDFState( + t=t0, + atol=atol, + rtol=rtol, + M=mass, + newton_tol=newton_tol, + consistent_y0_failed=not_converged, + order=1, + h=h, + n_equal_steps=0, + D=D, + y0=y0, + scale_y0=scale_y0, + kappa=kappa, + gamma=gamma, + alpha=alpha, + c=c, + error_const=error_const, + J=J, + LU=LU, + U=U, + psi=None, + n_function_evals=2, + n_jacobian_evals=1, + n_lu_decompositions=1, + n_steps=0, + ) - tuple_state = BDFState(*[state[k] for k in BDFInternalStates]) - y0, scale_y0 = _predict(tuple_state, D) - psi = _update_psi(tuple_state, D) - return tuple_state._replace(y0=y0, scale_y0=scale_y0, psi=psi) + # Predict initial y0, scale_yo, update state + y0, scale_y0 = _predict(state, D) + psi = _update_psi(state, D) + return state._replace(y0=y0, scale_y0=scale_y0, psi=psi) def _compute_R(order, factor): """ @@ -374,10 +377,8 @@ def _predict(state, D): """ predict forward to new step (eq 2 in [1]) """ - n = len(state.y0) - order = state.order - orders = jnp.repeat(jnp.arange(MAX_ORDER + 1).reshape(-1, 1), n, axis=1) - subD = jnp.where(orders <= order, D, 0) + orders = jnp.arange(MAX_ORDER + 1)[:, None] + subD = jnp.where(orders <= state.order, D, 0) y0 = jnp.sum(subD, axis=0) scale_y0 = state.atol + state.rtol * jnp.abs(state.y0) return y0, scale_y0 @@ -397,7 +398,7 @@ def _update_psi(state, D): def _update_difference_for_next_step(state, d): """ - update of difference equations can be done efficiently + Update of difference equations can be done efficiently by reusing d and D. From first equation on page 4 of [1]: @@ -409,34 +410,21 @@ def _update_difference_for_next_step(state, d): Combining these gives the following algorithm """ order = state.order - D = state.D - D = D.at[jnp.index_exp[order + 2]].set(d - D[order + 1]) - D = D.at[jnp.index_exp[order + 1]].set(d) - i = order - while_state = [i, D] - - def while_cond(while_state): - i, _ = while_state - return i >= 0 + D = state.D.at[order + 2].set(d - state.D[order + 1]) + D = D.at[order + 1].set(d) - def while_body(while_state): - i, D = while_state - D = D.at[jnp.index_exp[i]].add(D[i + 1]) - i -= 1 - return [i, D] + def update_D(i, D): + return D.at[order - i].add(D[order - i + 1]) - i, D = jax.lax.while_loop(while_cond, while_body, while_state) - - return D + return jax.lax.fori_loop(0, order + 1, update_D, D) def _update_step_size_and_lu(state, factor): + """ + Update step size and recompute LU decomposition. + """ state = _update_step_size(state, factor) - - # redo lu (c has changed) LU = jax.scipy.linalg.lu_factor(state.M - state.c * state.J) - n_lu_decompositions = state.n_lu_decompositions + 1 - - return state._replace(LU=LU, n_lu_decompositions=n_lu_decompositions) + return state._replace(LU=LU, n_lu_decompositions=state.n_lu_decompositions + 1) def _update_step_size(state, factor): """ @@ -449,7 +437,6 @@ def _update_step_size(state, factor): """ order = state.order h = state.h * factor - n_equal_steps = 0 c = h * state.alpha[order] # update D using equations in section 3.2 of [1] @@ -461,19 +448,14 @@ def _update_step_size(state, factor): RU = jnp.where( jnp.logical_and(I <= order, J <= order), RU, jnp.identity(MAX_ORDER + 1) ) - D = state.D - D = jnp.dot(RU.T, D) - # D = jax.ops.index_update(D, jax.ops.index[:order + 1], - # jnp.dot(RU.T, D[:order + 1])) + D = jnp.dot(RU.T, state.D) - # update psi (D has changed) + # update psi, y0 (D has changed) psi = _update_psi(state, D) - - # update y0 (D has changed) y0, scale_y0 = _predict(state, D) return state._replace( - n_equal_steps=n_equal_steps, + n_equal_steps=0, h=h, c=c, D=D, @@ -484,27 +466,23 @@ def _update_step_size(state, factor): def _update_jacobian(state, jac): """ - we update the jacobian using J(t_{n+1}, y^0_{n+1}) + Update the jacobian using J(t_{n+1}, y^0_{n+1}) following the scipy bdf implementation rather than J(t_n, y_n) as per [1] """ J = jac(state.y0, state.t + state.h) - n_jacobian_evals = state.n_jacobian_evals + 1 LU = jax.scipy.linalg.lu_factor(state.M - state.c * J) - n_lu_decompositions = state.n_lu_decompositions + 1 return state._replace( J=J, - n_jacobian_evals=n_jacobian_evals, + n_jacobian_evals=state.n_jacobian_evals + 1, LU=LU, - n_lu_decompositions=n_lu_decompositions, + n_lu_decompositions=state.n_lu_decompositions + 1, ) def _newton_iteration(state, fun): - tol = state.newton_tol - c = state.c - psi = state.psi + """ + Perform Newton iteration to solve the system. + """ y0 = state.y0 - LU = state.LU - M = state.M scale_y0 = state.scale_y0 t = state.t + state.h d = jnp.zeros(y0.shape, dtype=y0.dtype) @@ -522,17 +500,20 @@ def while_cond(while_state): def while_body(while_state): k, converged, dy_norm_old, d, y, n_function_evals = while_state + f_eval = fun(y, t) n_function_evals += 1 - b = c * f_eval - M @ (psi + d) - dy = jax.scipy.linalg.lu_solve(LU, b) + b = state.c * f_eval - state.M @ (state.psi + d) + dy = jax.scipy.linalg.lu_solve(state.LU, b) dy_norm = jnp.sqrt(jnp.mean((dy / scale_y0) ** 2)) rate = dy_norm / dy_norm_old # if iteration is not going to converge in NEWTON_MAXITER # (assuming the current rate), then abort pred = rate >= 1 - pred += rate ** (NEWTON_MAXITER - k) / (1 - rate) * dy_norm > tol + pred += ( + rate ** (NEWTON_MAXITER - k) / (1 - rate) * dy_norm > state.newton_tol + ) pred *= dy_norm_old >= 0 k += pred * (NEWTON_MAXITER - k - 1) @@ -541,7 +522,7 @@ def while_body(while_state): # if converged then break out of iteration early pred = dy_norm_old >= 0.0 - pred *= rate / (1 - rate) * dy_norm < tol + pred *= rate / (1 - rate) * dy_norm < state.newton_tol converged = (dy_norm == 0.0) + pred dy_norm_old = dy_norm @@ -564,7 +545,6 @@ def _prepare_next_step(state, d): def _prepare_next_step_order_change(state, d, y, n_iter): order = state.order - D = _update_difference_for_next_step(state, d) # Note: we are recalculating these from the while loop above, could re-use? @@ -586,7 +566,6 @@ def _prepare_next_step_order_change(state, d, y, n_iter): rms_norm(state.error_const[order + 1] * D[order + 2] / scale_y), jnp.inf, ) - error_norms = jnp.array([error_m_norm, error_norm, error_p_norm]) factors = error_norms ** (-1 / (jnp.arange(3) + order)) @@ -595,111 +574,89 @@ def _prepare_next_step_order_change(state, d, y, n_iter): max_index = jnp.argmax(factors) order += max_index - 1 + # New step size factor factor = jnp.minimum(MAX_FACTOR, safety * factors[max_index]) - - new_state = _update_step_size_and_lu(state._replace(D=D, order=order), factor) - return new_state + new_state = state._replace(D=D, order=order) + return _update_step_size_and_lu(new_state, factor) def _bdf_step(state, fun, jac): - # print('bdf_step', state.t, state.h) - # we will try and use the old jacobian unless convergence of newton iteration - # fails - updated_jacobian = False - # initialise step size and try to make the step, - # iterate, reducing step size until error is in bounds - step_accepted = False - y = jnp.empty_like(state.y0) - d = jnp.empty_like(state.y0) - n_iter = -1 - - # loop until step is accepted - while_state = [state, step_accepted, updated_jacobian, y, d, n_iter] + """ + Perform a BDF step. - def while_cond(while_state): - _, step_accepted, _, _, _, _ = while_state - return step_accepted == False # noqa: E712 + We will try and use the old jacobian unless + convergence of newton iteration fails. + """ - def while_body(while_state): - state, step_accepted, updated_jacobian, y, d, n_iter = while_state + def step_iteration(while_state): + state, updated_jacobian = while_state - # solve BDF equation using y0 as starting point + # Solve BDF equation using Newton iteration converged, n_iter, y, d, state = _newton_iteration(state, fun) - not_converged = converged == False # noqa: E712 - - # newton iteration did not converge, but jacobian has already been - # evaluated so reduce step size by 0.3 (as per [1]) and try again - state = tree_map( - partial(jnp.where, not_converged * updated_jacobian), - _update_step_size_and_lu(state, 0.3), - state, - ) - # if not_converged * updated_jacobian: - # print('not converged, update step size by 0.3') - # if not_converged * (updated_jacobian == False): - # print('not converged, update jacobian') - - # if not converged and jacobian not updated, then update the jacobian and - # try again - (state, updated_jacobian) = tree_map( - partial( - jnp.where, - not_converged * (updated_jacobian == False), # noqa: E712 + # Update Jacobian or reduce step size if not converged + # Evaluated so reduce step size by 0.3 (as per [1]) and try again + state, updated_jacobian = jax.lax.cond( + ~converged, + lambda s, uj: jax.lax.cond( + uj, + lambda s: (_update_step_size_and_lu(s, 0.3), True), + lambda s: (_update_jacobian(s, jac), True), + s, ), - (_update_jacobian(state, jac), True), - (state, False + updated_jacobian), + lambda s, uj: (s, uj), + state, + updated_jacobian, ) safety = 0.9 * (2 * NEWTON_MAXITER + 1) / (2 * NEWTON_MAXITER + n_iter) scale_y = state.atol + state.rtol * jnp.abs(y) + # Calculate error and updated step size factor # combine eq 3, 4 and 6 from [1] to obtain error # Note that error = C_k * h^{k+1} y^{k+1} # and d = D^{k+1} y_{n+1} \approx h^{k+1} y^{k+1} error = state.error_const[state.order] * d - error_norm = rms_norm(error / scale_y) - # calculate optimal step size factor as per eq 2.46 of [2] - factor = jnp.maximum( - MIN_FACTOR, safety * error_norm ** (-1 / (state.order + 1)) + # Calculate optimal step size factor as per eq 2.46 of [2] + factor = jnp.clip( + safety * error_norm ** (-1 / (state.order + 1)), MIN_FACTOR, None ) - # if converged * (error_norm > 1): - # print( - # "converged, but error is too large", - # error_norm, - # factor, - # d, - # scale_y, - # ) - - (state, step_accepted) = tree_map( - partial(jnp.where, converged * (error_norm > 1)), - (_update_step_size_and_lu(state, factor), False), - (state, converged), + # Update step size if error is too large + state = jax.lax.cond( + converged & (error_norm > 1), + lambda s: _update_step_size_and_lu(s, factor), + lambda s: s, + state, ) - return [state, step_accepted, updated_jacobian, y, d, n_iter] - - state, step_accepted, updated_jacobian, y, d, n_iter = jax.lax.while_loop( - while_cond, while_body, while_state + step_accepted = converged & (error_norm <= 1) + return (state, updated_jacobian), (step_accepted, y, d, n_iter) + + # Iterate until step is accepted + (state, _), (_, y, d, n_iter) = jax.lax.while_loop( + lambda carry_and_aux: ~carry_and_aux[1][0], + lambda carry_and_aux: step_iteration(carry_and_aux[0]), + ( + (state, False), + (False, jnp.empty_like(state.y0), jnp.empty_like(state.y0), -1), + ), ) - # take the accepted step + # Update state for the accepted step n_steps = state.n_steps + 1 t = state.t + state.h - - # a change in order is only done after running at order k for k + 1 steps - # (see page 83 of [2]) n_equal_steps = state.n_equal_steps + 1 - state = state._replace(n_equal_steps=n_equal_steps, t=t, n_steps=n_steps) - state = tree_map( - partial(jnp.where, n_equal_steps < state.order + 1), - _prepare_next_step(state, d), - _prepare_next_step_order_change(state, d, y, n_iter), + # Prepare for the next step, potentially changing order + # (see page 83 of [2]) + state = jax.lax.cond( + n_equal_steps < state.order + 1, + lambda s: _prepare_next_step(s, d), + lambda s: _prepare_next_step_order_change(s, d, y, n_iter), + state, ) return state @@ -710,8 +667,6 @@ def _bdf_interpolate(state, t_eval): definition of the interpolating polynomial can be found on page 7 of [1] """ - order = state.order - t = state.t h = state.h D = state.D j = 0 @@ -721,11 +676,11 @@ def _bdf_interpolate(state, t_eval): def while_cond(while_state): j, _, _ = while_state - return j < order + return j < state.order def while_body(while_state): j, time_factor, order_summation = while_state - time_factor *= (t_eval - (t - h * j)) / (h * (1 + j)) + time_factor *= (t_eval - (state.t - h * j)) / (h * (1 + j)) order_summation += D[j + 1] * time_factor j += 1 return [j, time_factor, order_summation] @@ -972,13 +927,13 @@ def jax_bdf_integrate(func, y0, t_eval, *args, rtol=1e-6, atol=1e-6, mass=None): """ Backward Difference formula (BDF) implicit multistep integrator. The basic algorithm is derived in :footcite:t:`byrne1975polyalgorithm`. This particular implementation - follows that implemented in the Matlab routine ode15s described in - :footcite:t:`shampine1997matlab` and the SciPy implementation - :footcite:t:`Virtanen2020` which features the NDF formulas for improved stability, - with associated differences in the error constants, and calculates the jacobian at - J(t_{n+1}, y^0_{n+1}). This implementation was based on that implemented in the - SciPy library :footcite:t:`Virtanen2020`, which also mainly follows - :footcite:t:`shampine1997matlab` but uses the more standard jacobian update. + follows the Matlab routine ode15s described in :footcite:t:`shampine1997matlab` + and the SciPy implementation :footcite:t:`Virtanen2020` which features + the NDF formulas for improved stability, with associated differences in the + error constants, and calculates the jacobian at J(t_{n+1}, y^0_{n+1}). This + implementation was based on that implemented in the SciPy library + :footcite:t:`Virtanen2020`, which also mainly follows :footcite:t:`shampine1997matlab` + but uses the more standard jacobian update. Parameters ---------- diff --git a/src/pybamm/solvers/jax_solver.py b/src/pybamm/solvers/jax_solver.py index da5fd4983a..a1f1733ed6 100644 --- a/src/pybamm/solvers/jax_solver.py +++ b/src/pybamm/solvers/jax_solver.py @@ -31,10 +31,10 @@ class JaxSolver(pybamm.BaseSolver): Parameters ---------- method: str, optional (see `jax.experimental.ode.odeint` for details) - * 'RK45' (default) uses jax.experimental.ode.odeint - * 'BDF' uses custom jax_bdf_integrate (see `jax_bdf_integrate.py` for details) + * 'BDF' (default) uses custom jax_bdf_integrate (see `jax_bdf_integrate.py` for details) + * 'RK45' uses jax.experimental.ode.odeint root_method: str, optional - Method to use to calculate consistent initial conditions. By default this uses + Method to use to calculate consistent initial conditions. By default, this uses the newton chord method internal to the jax bdf solver, otherwise choose from the set of default options defined in docs for pybamm.BaseSolver rtol : float, optional @@ -52,7 +52,7 @@ class JaxSolver(pybamm.BaseSolver): def __init__( self, - method="RK45", + method="BDF", root_method=None, rtol=1e-6, atol=1e-6, @@ -185,6 +185,14 @@ def solve_model_bdf(inputs): else: return jax.jit(solve_model_bdf) + @property + def supports_parallel_solve(self): + return True + + @property + def requires_explicit_sensitivities(self): + return False + def _integrate(self, model, t_eval, inputs=None, t_interp=None): """ Solve a model defined by dydt with initial conditions y0. @@ -200,7 +208,7 @@ def _integrate(self, model, t_eval, inputs=None, t_interp=None): Returns ------- - object + list of `pybamm.Solution` An object containing the times and values of the solution, as well as various diagnostic messages. @@ -301,6 +309,4 @@ async def solve_model_async(inputs_v): sol.integration_time = integration_time solutions.append(sol) - if len(solutions) == 1: - return solutions[0] return solutions diff --git a/tests/integration/test_models/test_full_battery_models/test_lithium_ion/test_splitOCVR.py b/tests/integration/test_models/test_full_battery_models/test_lithium_ion/test_splitOCVR.py new file mode 100644 index 0000000000..ef4391acdd --- /dev/null +++ b/tests/integration/test_models/test_full_battery_models/test_lithium_ion/test_splitOCVR.py @@ -0,0 +1,56 @@ +# +# Test that the model works with an example parameter set +# +import pybamm +import numpy as np +import tests + + +class TestSplitOCVR: + def test_basic_processing(self): + # example parameters + qp0 = 8.73231852 + qn0 = 5.82761507 + theta0_n = 0.9013973983641687 * 0.9 + theta0_p = 0.5142305254580026 * 0.83 + + # OCV functions + def Un(theta_n): + Un = ( + 0.1493 + + 0.8493 * np.exp(-61.79 * theta_n) + + 0.3824 * np.exp(-665.8 * theta_n) + - np.exp(39.42 * theta_n - 41.92) + - 0.03131 * np.arctan(25.59 * theta_n - 4.099) + - 0.009434 * np.arctan(32.49 * theta_n - 15.74) + ) + return Un + + def Up(theta_p): + Up = ( + -10.72 * theta_p**4 + + 23.88 * theta_p**3 + - 16.77 * theta_p**2 + + 2.595 * theta_p + + 4.563 + ) + return Up + + pars = pybamm.ParameterValues( + { + "Positive electrode capacity [A.h]": qp0, + "Ohmic resistance [Ohm]": 0.001, + "Negative electrode initial stoichiometry": theta0_n, + "Lower voltage cut-off [V]": 2.8, + "Positive electrode initial stoichiometry": theta0_p, + "Upper voltage cut-off [V]": 4.2, + "Negative electrode capacity [A.h]": qn0, + "Current function [A]": 5, + "Positive electrode OCP [V]": Up, + "Negative electrode OCP [V]": Un, + "Nominal cell capacity [A.h]": 5, + } + ) + model = pybamm.lithium_ion.SplitOCVR() + modeltest = tests.StandardModelTest(model) + modeltest.test_all(param=pars) diff --git a/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_splitOCVR.py b/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_splitOCVR.py new file mode 100644 index 0000000000..e807ec1607 --- /dev/null +++ b/tests/unit/test_models/test_full_battery_models/test_lithium_ion/test_splitOCVR.py @@ -0,0 +1,15 @@ +# +# Test for the ecm split-OCV model +# +import pybamm + + +class TestSplitOCVR: + def test_ecmsplitocv_well_posed(self): + model = pybamm.lithium_ion.SplitOCVR() + model.check_well_posedness() + + def test_get_default_quick_plot_variables(self): + model = pybamm.lithium_ion.SplitOCVR() + variables = model.default_quick_plot_variables + assert "Current [A]" in variables diff --git a/tests/unit/test_parameters/test_parameter_values.py b/tests/unit/test_parameters/test_parameter_values.py index 4086abea6d..cdcfb30ede 100644 --- a/tests/unit/test_parameters/test_parameter_values.py +++ b/tests/unit/test_parameters/test_parameter_values.py @@ -17,6 +17,7 @@ ) from pybamm.expression_tree.exceptions import OptionError import casadi +from pybamm.parameters.parameter_values import ParameterValues class TestParameterValues: @@ -36,12 +37,6 @@ def test_init(self): param = pybamm.ParameterValues({"a": 1, "chemistry": "lithium-ion"}) assert "chemistry" not in param.keys() - # chemistry kwarg removed - with pytest.raises( - ValueError, match="'chemistry' keyword argument has been deprecated" - ): - pybamm.ParameterValues(None, chemistry="lithium-ion") - # junk param values rejected with pytest.raises(ValueError, match="'Junk' is not a valid parameter set."): pybamm.ParameterValues("Junk") @@ -1029,3 +1024,23 @@ def test_exchange_current_density_plating(self): match="referring to the reaction at the surface of a lithium metal electrode", ): parameter_values.evaluate(param) + + def test_contains_method(self): + """Test for __contains__ method to check the functionality of 'in' keyword""" + parameter_values = ParameterValues( + {"Negative particle radius [m]": 1e-6, "Positive particle radius [m]": 2e-6} + ) + assert ( + "Negative particle radius [m]" in parameter_values + ), "Key should be found in parameter_values" + assert ( + "Invalid key" not in parameter_values + ), "Non-existent key should not be found" + + def test_iter_method(self): + """Test for __iter__ method to check if we can iterate over keys""" + parameter_values = ParameterValues( + values={"Negative particle radius [m]": 1e-6} + ) + pv = [i for i in parameter_values] + assert len(pv) == 5, "Should have 5 keys" diff --git a/tests/unit/test_solvers/test_base_solver.py b/tests/unit/test_solvers/test_base_solver.py index 6753513e72..1733cafc4c 100644 --- a/tests/unit/test_solvers/test_base_solver.py +++ b/tests/unit/test_solvers/test_base_solver.py @@ -19,6 +19,7 @@ def test_base_solver_init(self): assert solver.rtol == 1e-5 solver.rtol = 1e-7 assert solver.rtol == 1e-7 + assert solver.requires_explicit_sensitivities def test_root_method_init(self): solver = pybamm.BaseSolver(root_method="casadi") diff --git a/tests/unit/test_solvers/test_idaklu_solver.py b/tests/unit/test_solvers/test_idaklu_solver.py index 32e289b3e0..b049729ae3 100644 --- a/tests/unit/test_solvers/test_idaklu_solver.py +++ b/tests/unit/test_solvers/test_idaklu_solver.py @@ -63,6 +63,47 @@ def test_ida_roberts_klu(self): true_solution = 0.1 * solution.t np.testing.assert_array_almost_equal(solution.y[0, :], true_solution) + def test_multiple_inputs(self): + model = pybamm.BaseModel() + var = pybamm.Variable("var") + rate = pybamm.InputParameter("rate") + model.rhs = {var: -rate * var} + model.initial_conditions = {var: 2} + disc = pybamm.Discretisation() + disc.process_model(model) + + for num_threads, num_solvers in [ + [1, None], + [2, None], + [8, None], + [8, 1], + [8, 2], + [8, 7], + ]: + options = {"num_threads": num_threads} + if num_solvers is not None: + options["num_solvers"] = num_solvers + solver = pybamm.IDAKLUSolver(rtol=1e-5, atol=1e-5, options=options) + t_interp = np.linspace(0, 1, 10) + t_eval = [t_interp[0], t_interp[-1]] + ninputs = 8 + inputs_list = [{"rate": 0.01 * (i + 1)} for i in range(ninputs)] + + solutions = solver.solve( + model, t_eval, inputs=inputs_list, t_interp=t_interp + ) + + # check solution + for inputs, solution in zip(inputs_list, solutions): + print("checking solution", inputs, solution.all_inputs) + np.testing.assert_array_equal(solution.t, t_interp) + np.testing.assert_allclose( + solution.y[0], + 2 * np.exp(-inputs["rate"] * solution.t), + atol=1e-4, + rtol=1e-4, + ) + def test_model_events(self): for form in ["casadi", "iree"]: if (form == "iree") and (not pybamm.has_jax() or not pybamm.has_iree()): @@ -1121,3 +1162,24 @@ def test_python_idaklu_deprecation_errors(self): match="Unsupported evaluation engine for convert_to_format=jax", ): _ = solver.solve(model, t_eval) + + def test_extrapolation_events_with_output_variables(self): + # Make sure the extrapolation checks work with output variables + model = pybamm.BaseModel() + v = pybamm.Variable("v") + c = pybamm.Variable("c") + model.variables = {"v": v, "c": c} + model.rhs = {v: -1, c: 0} + model.initial_conditions = {v: 1, c: 2} + model.events.append( + pybamm.Event( + "Triggered event", + v - 0.5, + pybamm.EventType.INTERPOLANT_EXTRAPOLATION, + ) + ) + solver = pybamm.IDAKLUSolver(output_variables=["c"]) + solver.set_up(model) + + with pytest.warns(pybamm.SolverWarning, match="extrapolation occurred for"): + solver.solve(model, t_eval=[0, 1]) diff --git a/tests/unit/test_solvers/test_jax_solver.py b/tests/unit/test_solvers/test_jax_solver.py index f7e5b8d3b6..8f43eda3c7 100644 --- a/tests/unit/test_solvers/test_jax_solver.py +++ b/tests/unit/test_solvers/test_jax_solver.py @@ -237,3 +237,11 @@ def test_get_solve(self): y = solver({"rate": 0.2}) np.testing.assert_allclose(y[0], np.exp(-0.2 * t_eval), rtol=1e-6, atol=1e-6) + + # Reset solver, test passing `calculate_sensitivities` + for method in ["RK45", "BDF"]: + solver = pybamm.JaxSolver(method=method, rtol=1e-8, atol=1e-8) + solution_sens = solver.solve( + model, t_eval, inputs={"rate": 0.1}, calculate_sensitivities=True + ) + assert len(solution_sens.sensitivities) == 0