diff --git a/CHANGELOG.md b/CHANGELOG.md index 80cda39db7..b65d1342f6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,7 @@ ## Features - Added sensitivity calculation support for `pybamm.Simulation` and `pybamm.Experiment` ([#4415](https://github.com/pybamm-team/PyBaMM/pull/4415)) +- Added OpenMP parallelization to IDAKLU solver for lists of input parameters ([#4449](https://github.com/pybamm-team/PyBaMM/pull/4449)) ## Optimizations - Removed the `start_step_offset` setting and disabled minimum `dt` warnings for drive cycles with the (`IDAKLUSolver`). ([#4416](https://github.com/pybamm-team/PyBaMM/pull/4416)) diff --git a/CMakeLists.txt b/CMakeLists.txt index ad56ac34ca..a7f68ce7a0 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -19,7 +19,7 @@ endif() project(idaklu) -set(CMAKE_CXX_STANDARD 14) +set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_CXX_EXTENSIONS OFF) set(CMAKE_EXPORT_COMPILE_COMMANDS 1) @@ -82,6 +82,8 @@ pybind11_add_module(idaklu src/pybamm/solvers/c_solvers/idaklu/idaklu_solver.hpp src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolver.cpp src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolver.hpp + src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolverGroup.cpp + src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolverGroup.hpp src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolverOpenMP.inl src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolverOpenMP.hpp src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolverOpenMP_solvers.cpp @@ -94,6 +96,8 @@ pybind11_add_module(idaklu src/pybamm/solvers/c_solvers/idaklu/common.cpp src/pybamm/solvers/c_solvers/idaklu/Solution.cpp src/pybamm/solvers/c_solvers/idaklu/Solution.hpp + src/pybamm/solvers/c_solvers/idaklu/SolutionData.cpp + src/pybamm/solvers/c_solvers/idaklu/SolutionData.hpp src/pybamm/solvers/c_solvers/idaklu/Options.hpp src/pybamm/solvers/c_solvers/idaklu/Options.cpp # IDAKLU expressions / function evaluation [abstract] @@ -138,6 +142,23 @@ set_target_properties( INSTALL_RPATH_USE_LINK_PATH TRUE ) +# openmp +if (${CMAKE_SYSTEM_NAME} MATCHES "Darwin") + execute_process( + COMMAND "brew" "--prefix" + OUTPUT_VARIABLE HOMEBREW_PREFIX + OUTPUT_STRIP_TRAILING_WHITESPACE) + if (OpenMP_ROOT) + set(OpenMP_ROOT "${OpenMP_ROOT}:${HOMEBREW_PREFIX}/opt/libomp") + else() + set(OpenMP_ROOT "${HOMEBREW_PREFIX}/opt/libomp") + endif() +endif() +find_package(OpenMP) +if(OpenMP_CXX_FOUND) + target_link_libraries(idaklu PRIVATE OpenMP::OpenMP_CXX) +endif() + set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} ${PROJECT_SOURCE_DIR}) # Sundials find_package(SUNDIALS REQUIRED) diff --git a/src/pybamm/solvers/base_solver.py b/src/pybamm/solvers/base_solver.py index 1df9aef35f..dc1ac0cd72 100644 --- a/src/pybamm/solvers/base_solver.py +++ b/src/pybamm/solvers/base_solver.py @@ -86,6 +86,10 @@ def supports_interp(self): def root_method(self): return self._root_method + @property + def supports_parallel_solve(self): + return False + @root_method.setter def root_method(self, method): if method == "casadi": @@ -896,17 +900,8 @@ def solve( pybamm.logger.verbose( f"Calling solver for {t_eval[start_index]} < t < {t_eval[end_index - 1]}" ) - ninputs = len(model_inputs_list) - if ninputs == 1: - new_solution = self._integrate( - model, - t_eval[start_index:end_index], - model_inputs_list[0], - t_interp=t_interp, - ) - new_solutions = [new_solution] - elif model.convert_to_format == "jax": - # Jax can parallelize over the inputs efficiently + if self.supports_parallel_solve: + # Jax and IDAKLU solver can accept a list of inputs new_solutions = self._integrate( model, t_eval[start_index:end_index], @@ -914,18 +909,28 @@ def solve( t_interp, ) else: - with mp.get_context(self._mp_context).Pool(processes=nproc) as p: - new_solutions = p.starmap( - self._integrate, - zip( - [model] * ninputs, - [t_eval[start_index:end_index]] * ninputs, - model_inputs_list, - [t_interp] * ninputs, - ), + ninputs = len(model_inputs_list) + if ninputs == 1: + new_solution = self._integrate( + model, + t_eval[start_index:end_index], + model_inputs_list[0], + t_interp=t_interp, ) - p.close() - p.join() + new_solutions = [new_solution] + else: + with mp.get_context(self._mp_context).Pool(processes=nproc) as p: + new_solutions = p.starmap( + self._integrate, + zip( + [model] * ninputs, + [t_eval[start_index:end_index]] * ninputs, + model_inputs_list, + [t_interp] * ninputs, + ), + ) + p.close() + p.join() # Setting the solve time for each segment. # pybamm.Solution.__add__ assumes attribute solve_time. solve_time = timer.time() @@ -995,7 +1000,7 @@ def solve( ) # Return solution(s) - if ninputs == 1: + if len(solutions) == 1: return solutions[0] else: return solutions @@ -1350,7 +1355,13 @@ def step( # Step pybamm.logger.verbose(f"Stepping for {t_start_shifted:.0f} < t < {t_end:.0f}") timer.reset() - solution = self._integrate(model, t_eval, model_inputs, t_interp) + + # API for _integrate is different for JaxSolver and IDAKLUSolver + if self.supports_parallel_solve: + solutions = self._integrate(model, t_eval, [model_inputs], t_interp) + solution = solutions[0] + else: + solution = self._integrate(model, t_eval, model_inputs, t_interp) solution.solve_time = timer.time() # Check if extrapolation occurred diff --git a/src/pybamm/solvers/c_solvers/idaklu.cpp b/src/pybamm/solvers/c_solvers/idaklu.cpp index 3ef0194403..db7147feb2 100644 --- a/src/pybamm/solvers/c_solvers/idaklu.cpp +++ b/src/pybamm/solvers/c_solvers/idaklu.cpp @@ -9,6 +9,7 @@ #include #include "idaklu/idaklu_solver.hpp" +#include "idaklu/IDAKLUSolverGroup.hpp" #include "idaklu/IdakluJax.hpp" #include "idaklu/common.hpp" #include "idaklu/Expressions/Casadi/CasadiFunctions.hpp" @@ -26,15 +27,17 @@ casadi::Function generate_casadi_function(const std::string &data) namespace py = pybind11; PYBIND11_MAKE_OPAQUE(std::vector); +PYBIND11_MAKE_OPAQUE(std::vector); PYBIND11_MODULE(idaklu, m) { m.doc() = "sundials solvers"; // optional module docstring py::bind_vector>(m, "VectorNdArray"); + py::bind_vector>(m, "VectorSolution"); - py::class_(m, "IDAKLUSolver") - .def("solve", &IDAKLUSolver::solve, + py::class_(m, "IDAKLUSolverGroup") + .def("solve", &IDAKLUSolverGroup::solve, "perform a solve", py::arg("t_eval"), py::arg("t_interp"), @@ -43,8 +46,8 @@ PYBIND11_MODULE(idaklu, m) py::arg("inputs"), py::return_value_policy::take_ownership); - m.def("create_casadi_solver", &create_idaklu_solver, - "Create a casadi idaklu solver object", + m.def("create_casadi_solver_group", &create_idaklu_solver_group, + "Create a group of casadi idaklu solver objects", py::arg("number_of_states"), py::arg("number_of_parameters"), py::arg("rhs_alg"), @@ -70,8 +73,8 @@ PYBIND11_MODULE(idaklu, m) py::return_value_policy::take_ownership); #ifdef IREE_ENABLE - m.def("create_iree_solver", &create_idaklu_solver, - "Create a iree idaklu solver object", + m.def("create_iree_solver_group", &create_idaklu_solver_group, + "Create a group of iree idaklu solver objects", py::arg("number_of_states"), py::arg("number_of_parameters"), py::arg("rhs_alg"), diff --git a/src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolver.hpp b/src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolver.hpp index 29b451e6d3..379d64783a 100644 --- a/src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolver.hpp +++ b/src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolver.hpp @@ -2,7 +2,8 @@ #define PYBAMM_IDAKLU_CASADI_SOLVER_HPP #include "common.hpp" -#include "Solution.hpp" +#include "SolutionData.hpp" + /** * Abstract base class for solutions that can use different solvers and vector @@ -24,14 +25,17 @@ class IDAKLUSolver ~IDAKLUSolver() = default; /** - * @brief Abstract solver method that returns a Solution class + * @brief Abstract solver method that executes the solver */ - virtual Solution solve( - np_array t_eval_np, - np_array t_interp_np, - np_array y0_np, - np_array yp0_np, - np_array_dense inputs) = 0; + virtual SolutionData solve( + const std::vector &t_eval, + const std::vector &t_interp, + const realtype *y0, + const realtype *yp0, + const realtype *inputs, + bool save_adaptive_steps, + bool save_interp_steps + ) = 0; /** * Abstract method to initialize the solver, once vectors and solver classes diff --git a/src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolverGroup.cpp b/src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolverGroup.cpp new file mode 100644 index 0000000000..8a76d73cfe --- /dev/null +++ b/src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolverGroup.cpp @@ -0,0 +1,145 @@ +#include "IDAKLUSolverGroup.hpp" +#include +#include + +std::vector IDAKLUSolverGroup::solve( + np_array t_eval_np, + np_array t_interp_np, + np_array y0_np, + np_array yp0_np, + np_array inputs) { + DEBUG("IDAKLUSolverGroup::solve"); + + // If t_interp is empty, save all adaptive steps + bool save_adaptive_steps = t_interp_np.size() == 0; + + const realtype* t_eval_begin = t_eval_np.data(); + const realtype* t_eval_end = t_eval_begin + t_eval_np.size(); + const realtype* t_interp_begin = t_interp_np.data(); + const realtype* t_interp_end = t_interp_begin + t_interp_np.size(); + + // Process the time inputs + // 1. Get the sorted and unique t_eval vector + auto const t_eval = makeSortedUnique(t_eval_begin, t_eval_end); + + // 2.1. Get the sorted and unique t_interp vector + auto const t_interp_unique_sorted = makeSortedUnique(t_interp_begin, t_interp_end); + + // 2.2 Remove the t_eval values from t_interp + auto const t_interp_setdiff = setDiff(t_interp_unique_sorted.begin(), t_interp_unique_sorted.end(), t_eval_begin, t_eval_end); + + // 2.3 Finally, get the sorted and unique t_interp vector with t_eval values removed + auto const t_interp = makeSortedUnique(t_interp_setdiff.begin(), t_interp_setdiff.end()); + + int const number_of_evals = t_eval.size(); + int const number_of_interps = t_interp.size(); + + // setDiff removes entries of t_interp that overlap with + // t_eval, so we need to check if we need to interpolate any unique points. + // This is not the same as save_adaptive_steps since some entries of t_interp + // may be removed by setDiff + bool save_interp_steps = number_of_interps > 0; + + // 3. Check if the timestepping entries are valid + if (number_of_evals < 2) { + throw std::invalid_argument( + "t_eval must have at least 2 entries" + ); + } else if (save_interp_steps) { + if (t_interp.front() < t_eval.front()) { + throw std::invalid_argument( + "t_interp values must be greater than the smallest t_eval value: " + + std::to_string(t_eval.front()) + ); + } else if (t_interp.back() > t_eval.back()) { + throw std::invalid_argument( + "t_interp values must be less than the greatest t_eval value: " + + std::to_string(t_eval.back()) + ); + } + } + + auto n_coeffs = number_of_states + number_of_parameters * number_of_states; + + // check y0 and yp0 and inputs have the correct dimensions + if (y0_np.ndim() != 2) + throw std::domain_error("y0 has wrong number of dimensions. Expected 2 but got " + std::to_string(y0_np.ndim())); + if (yp0_np.ndim() != 2) + throw std::domain_error("yp0 has wrong number of dimensions. Expected 2 but got " + std::to_string(yp0_np.ndim())); + if (inputs.ndim() != 2) + throw std::domain_error("inputs has wrong number of dimensions. Expected 2 but got " + std::to_string(inputs.ndim())); + + auto number_of_groups = y0_np.shape()[0]; + + // check y0 and yp0 and inputs have the correct shape + if (y0_np.shape()[1] != n_coeffs) + throw std::domain_error( + "y0 has wrong number of cols. Expected " + std::to_string(n_coeffs) + + " but got " + std::to_string(y0_np.shape()[1])); + + if (yp0_np.shape()[1] != n_coeffs) + throw std::domain_error( + "yp0 has wrong number of cols. Expected " + std::to_string(n_coeffs) + + " but got " + std::to_string(yp0_np.shape()[1])); + + if (yp0_np.shape()[0] != number_of_groups) + throw std::domain_error( + "yp0 has wrong number of rows. Expected " + std::to_string(number_of_groups) + + " but got " + std::to_string(yp0_np.shape()[0])); + + if (inputs.shape()[0] != number_of_groups) + throw std::domain_error( + "inputs has wrong number of rows. Expected " + std::to_string(number_of_groups) + + " but got " + std::to_string(inputs.shape()[0])); + + const std::size_t solves_per_thread = number_of_groups / m_solvers.size(); + const std::size_t remainder_solves = number_of_groups % m_solvers.size(); + + const realtype *y0 = y0_np.data(); + const realtype *yp0 = yp0_np.data(); + const realtype *inputs_data = inputs.data(); + + std::vector results(number_of_groups); + + std::optional exception; + + omp_set_num_threads(m_solvers.size()); + #pragma omp parallel for + for (int i = 0; i < m_solvers.size(); i++) { + try { + for (int j = 0; j < solves_per_thread; j++) { + const std::size_t index = i * solves_per_thread + j; + const realtype *y = y0 + index * y0_np.shape(1); + const realtype *yp = yp0 + index * yp0_np.shape(1); + const realtype *input = inputs_data + index * inputs.shape(1); + results[index] = m_solvers[i]->solve(t_eval, t_interp, y, yp, input, save_adaptive_steps, save_interp_steps); + } + } catch (std::exception &e) { + // If an exception is thrown, we need to catch it and rethrow it outside the parallel region + #pragma omp critical + { + exception = e; + } + } + } + + if (exception.has_value()) { + py::set_error(PyExc_ValueError, exception->what()); + throw py::error_already_set(); + } + + for (int i = 0; i < remainder_solves; i++) { + const std::size_t index = number_of_groups - remainder_solves + i; + const realtype *y = y0 + index * y0_np.shape(1); + const realtype *yp = yp0 + index * yp0_np.shape(1); + const realtype *input = inputs_data + index * inputs.shape(1); + results[index] = m_solvers[i]->solve(t_eval, t_interp, y, yp, input, save_adaptive_steps, save_interp_steps); + } + + // create solutions (needs to be serial as we're using the Python GIL) + std::vector solutions(number_of_groups); + for (int i = 0; i < number_of_groups; i++) { + solutions[i] = results[i].generate_solution(); + } + return solutions; +} diff --git a/src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolverGroup.hpp b/src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolverGroup.hpp new file mode 100644 index 0000000000..609b3b6fca --- /dev/null +++ b/src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolverGroup.hpp @@ -0,0 +1,48 @@ +#ifndef PYBAMM_IDAKLU_SOLVER_GROUP_HPP +#define PYBAMM_IDAKLU_SOLVER_GROUP_HPP + +#include "IDAKLUSolver.hpp" +#include "common.hpp" + +/** + * @brief class for a group of solvers. + */ +class IDAKLUSolverGroup +{ +public: + + /** + * @brief Default constructor + */ + IDAKLUSolverGroup(std::vector> solvers, int number_of_states, int number_of_parameters): + m_solvers(std::move(solvers)), + number_of_states(number_of_states), + number_of_parameters(number_of_parameters) + {} + + // no copy constructor (unique_ptr cannot be copied) + IDAKLUSolverGroup(IDAKLUSolverGroup &) = delete; + + /** + * @brief Default destructor + */ + ~IDAKLUSolverGroup() = default; + + /** + * @brief solver method that returns a vector of Solutions + */ + std::vector solve( + np_array t_eval_np, + np_array t_interp_np, + np_array y0_np, + np_array yp0_np, + np_array inputs); + + + private: + std::vector> m_solvers; + int number_of_states; + int number_of_parameters; +}; + +#endif // PYBAMM_IDAKLU_SOLVER_GROUP_HPP diff --git a/src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolverOpenMP.hpp b/src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolverOpenMP.hpp index ca710fbff6..36c2872c3e 100644 --- a/src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolverOpenMP.hpp +++ b/src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolverOpenMP.hpp @@ -52,6 +52,7 @@ class IDAKLUSolverOpenMP : public IDAKLUSolver int const number_of_states; // cppcheck-suppress unusedStructMember int const number_of_parameters; // cppcheck-suppress unusedStructMember int const number_of_events; // cppcheck-suppress unusedStructMember + int number_of_timesteps; int precon_type; // cppcheck-suppress unusedStructMember N_Vector yy, yp, avtol; // y, y', and absolute tolerance N_Vector *yyS; // cppcheck-suppress unusedStructMember @@ -106,12 +107,16 @@ class IDAKLUSolverOpenMP : public IDAKLUSolver /** * @brief The main solve method that solves for each variable and time step */ - Solution solve( - np_array t_eval_np, - np_array t_interp_np, - np_array y0_np, - np_array yp0_np, - np_array_dense inputs) override; + SolutionData solve( + const std::vector &t_eval, + const std::vector &t_interp, + const realtype *y0, + const realtype *yp0, + const realtype *inputs, + bool save_adaptive_steps, + bool save_interp_steps + ) override; + /** * @brief Concrete implementation of initialization method diff --git a/src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolverOpenMP.inl b/src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolverOpenMP.inl index fd8eb38257..313c4ce12a 100644 --- a/src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolverOpenMP.inl +++ b/src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolverOpenMP.inl @@ -3,6 +3,7 @@ #include #include "common.hpp" +#include "SolutionData.hpp" template IDAKLUSolverOpenMP::IDAKLUSolverOpenMP( @@ -86,11 +87,20 @@ IDAKLUSolverOpenMP::IDAKLUSolverOpenMP( template void IDAKLUSolverOpenMP::AllocateVectors() { + DEBUG("IDAKLUSolverOpenMP::AllocateVectors (num_threads = " << setup_opts.num_threads << ")"); // Create vectors - yy = N_VNew_OpenMP(number_of_states, setup_opts.num_threads, sunctx); - yp = N_VNew_OpenMP(number_of_states, setup_opts.num_threads, sunctx); - avtol = N_VNew_OpenMP(number_of_states, setup_opts.num_threads, sunctx); - id = N_VNew_OpenMP(number_of_states, setup_opts.num_threads, sunctx); + if (setup_opts.num_threads == 1) { + yy = N_VNew_Serial(number_of_states, sunctx); + yp = N_VNew_Serial(number_of_states, sunctx); + avtol = N_VNew_Serial(number_of_states, sunctx); + id = N_VNew_Serial(number_of_states, sunctx); + } else { + DEBUG("IDAKLUSolverOpenMP::AllocateVectors OpenMP"); + yy = N_VNew_OpenMP(number_of_states, setup_opts.num_threads, sunctx); + yp = N_VNew_OpenMP(number_of_states, setup_opts.num_threads, sunctx); + avtol = N_VNew_OpenMP(number_of_states, setup_opts.num_threads, sunctx); + id = N_VNew_OpenMP(number_of_states, setup_opts.num_threads, sunctx); + } } template @@ -290,6 +300,7 @@ void IDAKLUSolverOpenMP::Initialize() { template IDAKLUSolverOpenMP::~IDAKLUSolverOpenMP() { + DEBUG("IDAKLUSolverOpenMP::~IDAKLUSolverOpenMP"); // Free memory if (sensitivity) { IDASensFree(ida_mem); @@ -313,63 +324,24 @@ IDAKLUSolverOpenMP::~IDAKLUSolverOpenMP() { } template -Solution IDAKLUSolverOpenMP::solve( - np_array t_eval_np, - np_array t_interp_np, - np_array y0_np, - np_array yp0_np, - np_array_dense inputs +SolutionData IDAKLUSolverOpenMP::solve( + const std::vector &t_eval, + const std::vector &t_interp, + const realtype *y0, + const realtype *yp0, + const realtype *inputs, + bool save_adaptive_steps, + bool save_interp_steps ) { DEBUG("IDAKLUSolver::solve"); + const int number_of_evals = t_eval.size(); + const int number_of_interps = t_interp.size(); - // If t_interp is empty, save all adaptive steps - bool save_adaptive_steps = t_interp_np.unchecked<1>().size() == 0; - - // Process the time inputs - // 1. Get the sorted and unique t_eval vector - auto const t_eval = makeSortedUnique(t_eval_np); - - // 2.1. Get the sorted and unique t_interp vector - auto const t_interp_unique_sorted = makeSortedUnique(t_interp_np); - - // 2.2 Remove the t_eval values from t_interp - auto const t_interp_setdiff = setDiff(t_interp_unique_sorted, t_eval); - - // 2.3 Finally, get the sorted and unique t_interp vector with t_eval values removed - auto const t_interp = makeSortedUnique(t_interp_setdiff); - - int const number_of_evals = t_eval.size(); - int const number_of_interps = t_interp.size(); - - // setDiff removes entries of t_interp that overlap with - // t_eval, so we need to check if we need to interpolate any unique points. - // This is not the same as save_adaptive_steps since some entries of t_interp - // may be removed by setDiff - bool save_interp_steps = number_of_interps > 0; - - // 3. Check if the timestepping entries are valid - if (number_of_evals < 2) { - throw std::invalid_argument( - "t_eval must have at least 2 entries" - ); - } else if (save_interp_steps) { - if (t_interp.front() < t_eval.front()) { - throw std::invalid_argument( - "t_interp values must be greater than the smallest t_eval value: " - + std::to_string(t_eval.front()) - ); - } else if (t_interp.back() > t_eval.back()) { - throw std::invalid_argument( - "t_interp values must be less than the greatest t_eval value: " - + std::to_string(t_eval.back()) - ); - } + if (t.size() < number_of_evals + number_of_interps) { + InitializeStorage(number_of_evals + number_of_interps); } - // Initialize length_of_return_vector, t, y, and yS - InitializeStorage(number_of_evals + number_of_interps); - int i_save = 0; realtype t0 = t_eval.front(); @@ -386,24 +358,11 @@ Solution IDAKLUSolverOpenMP::solve( t_interp_next = t_interp[0]; } - auto y0 = y0_np.unchecked<1>(); - auto yp0 = yp0_np.unchecked<1>(); auto n_coeffs = number_of_states + number_of_parameters * number_of_states; - if (y0.size() != n_coeffs) { - throw std::domain_error( - "y0 has wrong size. Expected " + std::to_string(n_coeffs) + - " but got " + std::to_string(y0.size())); - } else if (yp0.size() != n_coeffs) { - throw std::domain_error( - "yp0 has wrong size. Expected " + std::to_string(n_coeffs) + - " but got " + std::to_string(yp0.size())); - } - // set inputs - auto p_inputs = inputs.unchecked<2>(); for (int i = 0; i < functions->inputs.size(); i++) { - functions->inputs[i] = p_inputs(i, 0); + functions->inputs[i] = inputs[i]; } // Setup consistent initialization @@ -543,8 +502,8 @@ Solution IDAKLUSolverOpenMP::solve( PrintStats(); } - int const number_of_timesteps = i_save; - int count; + // store number of timesteps so we can generate the solution later + number_of_timesteps = i_save; // Copy the data to return as numpy arrays @@ -554,23 +513,9 @@ Solution IDAKLUSolverOpenMP::solve( t_return[i] = t[i]; } - py::capsule free_t_when_done( - t_return, - [](void *f) { - realtype *vect = reinterpret_cast(f); - delete[] vect; - } - ); - - np_array t_ret = np_array( - number_of_timesteps, - &t_return[0], - free_t_when_done - ); - // States, y realtype *y_return = new realtype[number_of_timesteps * length_of_return_vector]; - count = 0; + int count = 0; for (size_t i = 0; i < number_of_timesteps; i++) { for (size_t j = 0; j < length_of_return_vector; j++) { y_return[count] = y[i][j]; @@ -578,20 +523,6 @@ Solution IDAKLUSolverOpenMP::solve( } } - py::capsule free_y_when_done( - y_return, - [](void *f) { - realtype *vect = reinterpret_cast(f); - delete[] vect; - } - ); - - np_array y_ret = np_array( - number_of_timesteps * length_of_return_vector, - &y_return[0], - free_y_when_done - ); - // Sensitivity states, yS // Note: Ordering of vector is different if computing outputs vs returning // the complete state vector @@ -614,43 +545,7 @@ Solution IDAKLUSolverOpenMP::solve( } } - py::capsule free_yS_when_done( - yS_return, - [](void *f) { - realtype *vect = reinterpret_cast(f); - delete[] vect; - } - ); - - np_array yS_ret = np_array( - vector { - arg_sens0, - arg_sens1, - arg_sens2 - }, - &yS_return[0], - free_yS_when_done - ); - - // Final state slice, yterm - py::capsule free_yterm_when_done( - yterm_return, - [](void *f) { - realtype *vect = reinterpret_cast(f); - delete[] vect; - } - ); - - np_array y_term = np_array( - length_of_final_sv_slice, - &yterm_return[0], - free_yterm_when_done - ); - - // Store the solution - Solution sol(retval, t_ret, y_ret, yS_ret, y_term); - - return sol; + return SolutionData(retval, number_of_timesteps, length_of_return_vector, arg_sens0, arg_sens1, arg_sens2, length_of_final_sv_slice, t_return, y_return, yS_return, yterm_return); } template @@ -828,9 +723,8 @@ void IDAKLUSolverOpenMP::SetStepOutputSensitivities( template void IDAKLUSolverOpenMP::CheckErrors(int const & flag) { if (flag < 0) { - auto message = (std::string("IDA failed with flag ") + std::to_string(flag)).c_str(); - py::set_error(PyExc_ValueError, message); - throw py::error_already_set(); + auto message = std::string("IDA failed with flag ") + std::to_string(flag); + throw std::runtime_error(message.c_str()); } } diff --git a/src/pybamm/solvers/c_solvers/idaklu/Options.cpp b/src/pybamm/solvers/c_solvers/idaklu/Options.cpp index b6a33e016e..51544040ee 100644 --- a/src/pybamm/solvers/c_solvers/idaklu/Options.cpp +++ b/src/pybamm/solvers/c_solvers/idaklu/Options.cpp @@ -1,6 +1,7 @@ #include "Options.hpp" #include #include +#include using namespace std::string_literals; @@ -11,9 +12,25 @@ SetupOptions::SetupOptions(py::dict &py_opts) precon_half_bandwidth(py_opts["precon_half_bandwidth"].cast()), precon_half_bandwidth_keep(py_opts["precon_half_bandwidth_keep"].cast()), num_threads(py_opts["num_threads"].cast()), + num_solvers(py_opts["num_solvers"].cast()), linear_solver(py_opts["linear_solver"].cast()), linsol_max_iterations(py_opts["linsol_max_iterations"].cast()) { + if (num_solvers > num_threads) + { + throw std::domain_error( + "Number of solvers must be less than or equal to the number of threads" + ); + } + + // input num_threads is the overall number of threads to use. num_solvers of these + // will be used to run solvers in parallel, leaving num_threads / num_solvers threads + // to be used by each solver. From here on num_threads is the number of threads to be used by each solver + num_threads = static_cast( + std::floor( + static_cast(num_threads) / static_cast(num_solvers) + ) + ); using_sparse_matrix = true; using_banded_matrix = false; diff --git a/src/pybamm/solvers/c_solvers/idaklu/Options.hpp b/src/pybamm/solvers/c_solvers/idaklu/Options.hpp index 66a175cfff..d0c0c1d766 100644 --- a/src/pybamm/solvers/c_solvers/idaklu/Options.hpp +++ b/src/pybamm/solvers/c_solvers/idaklu/Options.hpp @@ -15,6 +15,7 @@ struct SetupOptions { int precon_half_bandwidth; int precon_half_bandwidth_keep; int num_threads; + int num_solvers; // IDALS linear solver interface std::string linear_solver; // klu, lapack, spbcg int linsol_max_iterations; diff --git a/src/pybamm/solvers/c_solvers/idaklu/Solution.hpp b/src/pybamm/solvers/c_solvers/idaklu/Solution.hpp index 72d48fa644..a43e6a7174 100644 --- a/src/pybamm/solvers/c_solvers/idaklu/Solution.hpp +++ b/src/pybamm/solvers/c_solvers/idaklu/Solution.hpp @@ -9,6 +9,11 @@ class Solution { public: + /** + * @brief Default Constructor + */ + Solution() = default; + /** * @brief Constructor */ @@ -17,6 +22,11 @@ class Solution { } + /** + * @brief Default copy from another Solution + */ + Solution(const Solution &solution) = default; + int flag; np_array t; np_array y; diff --git a/src/pybamm/solvers/c_solvers/idaklu/SolutionData.cpp b/src/pybamm/solvers/c_solvers/idaklu/SolutionData.cpp new file mode 100644 index 0000000000..00c2ddbccc --- /dev/null +++ b/src/pybamm/solvers/c_solvers/idaklu/SolutionData.cpp @@ -0,0 +1,67 @@ +#include "SolutionData.hpp" + +Solution SolutionData::generate_solution() { + py::capsule free_t_when_done( + t_return, + [](void *f) { + realtype *vect = reinterpret_cast(f); + delete[] vect; + } + ); + + np_array t_ret = np_array( + number_of_timesteps, + &t_return[0], + free_t_when_done + ); + + py::capsule free_y_when_done( + y_return, + [](void *f) { + realtype *vect = reinterpret_cast(f); + delete[] vect; + } + ); + + np_array y_ret = np_array( + number_of_timesteps * length_of_return_vector, + &y_return[0], + free_y_when_done + ); + + py::capsule free_yS_when_done( + yS_return, + [](void *f) { + realtype *vect = reinterpret_cast(f); + delete[] vect; + } + ); + + np_array yS_ret = np_array( + std::vector { + arg_sens0, + arg_sens1, + arg_sens2 + }, + &yS_return[0], + free_yS_when_done + ); + + // Final state slice, yterm + py::capsule free_yterm_when_done( + yterm_return, + [](void *f) { + realtype *vect = reinterpret_cast(f); + delete[] vect; + } + ); + + np_array y_term = np_array( + length_of_final_sv_slice, + &yterm_return[0], + free_yterm_when_done + ); + + // Store the solution + return Solution(flag, t_ret, y_ret, yS_ret, y_term); +} diff --git a/src/pybamm/solvers/c_solvers/idaklu/SolutionData.hpp b/src/pybamm/solvers/c_solvers/idaklu/SolutionData.hpp new file mode 100644 index 0000000000..815e41daca --- /dev/null +++ b/src/pybamm/solvers/c_solvers/idaklu/SolutionData.hpp @@ -0,0 +1,73 @@ +#ifndef PYBAMM_IDAKLU_SOLUTION_DATA_HPP +#define PYBAMM_IDAKLU_SOLUTION_DATA_HPP + + +#include "common.hpp" +#include "Solution.hpp" + +/** + * @brief SolutionData class. Contains all the data needed to create a Solution + */ +class SolutionData +{ + public: + /** + * @brief Default constructor + */ + SolutionData() = default; + + /** + * @brief constructor using fields + */ + SolutionData( + int flag, + int number_of_timesteps, + int length_of_return_vector, + int arg_sens0, + int arg_sens1, + int arg_sens2, + int length_of_final_sv_slice, + realtype *t_return, + realtype *y_return, + realtype *yS_return, + realtype *yterm_return): + flag(flag), + number_of_timesteps(number_of_timesteps), + length_of_return_vector(length_of_return_vector), + arg_sens0(arg_sens0), + arg_sens1(arg_sens1), + arg_sens2(arg_sens2), + length_of_final_sv_slice(length_of_final_sv_slice), + t_return(t_return), + y_return(y_return), + yS_return(yS_return), + yterm_return(yterm_return) + {} + + + /** + * @brief Default copy from another SolutionData + */ + SolutionData(const SolutionData &solution_data) = default; + + /** + * @brief Create a solution object from this data + */ + Solution generate_solution(); + +private: + + int flag; + int number_of_timesteps; + int length_of_return_vector; + int arg_sens0; + int arg_sens1; + int arg_sens2; + int length_of_final_sv_slice; + realtype *t_return; + realtype *y_return; + realtype *yS_return; + realtype *yterm_return; +}; + +#endif // PYBAMM_IDAKLU_SOLUTION_DATA_HPP diff --git a/src/pybamm/solvers/c_solvers/idaklu/common.cpp b/src/pybamm/solvers/c_solvers/idaklu/common.cpp index bf38acc56a..161c14f340 100644 --- a/src/pybamm/solvers/c_solvers/idaklu/common.cpp +++ b/src/pybamm/solvers/c_solvers/idaklu/common.cpp @@ -11,21 +11,9 @@ std::vector numpy2realtype(const np_array& input_np) { return output; } -std::vector setDiff(const std::vector& A, const std::vector& B) { - std::vector result; - if (!(A.empty())) { - std::set_difference(A.begin(), A.end(), B.begin(), B.end(), std::back_inserter(result)); - } - return result; -} -std::vector makeSortedUnique(const std::vector& input) { - std::unordered_set uniqueSet(input.begin(), input.end()); // Remove duplicates - std::vector uniqueVector(uniqueSet.begin(), uniqueSet.end()); // Convert to vector - std::sort(uniqueVector.begin(), uniqueVector.end()); // Sort the vector - return uniqueVector; -} std::vector makeSortedUnique(const np_array& input_np) { - return makeSortedUnique(numpy2realtype(input_np)); + const auto input_vec = numpy2realtype(input_np); + return makeSortedUnique(input_vec.begin(), input_vec.end()); } diff --git a/src/pybamm/solvers/c_solvers/idaklu/common.hpp b/src/pybamm/solvers/c_solvers/idaklu/common.hpp index 3289326541..58be90932e 100644 --- a/src/pybamm/solvers/c_solvers/idaklu/common.hpp +++ b/src/pybamm/solvers/c_solvers/idaklu/common.hpp @@ -31,8 +31,8 @@ #include namespace py = pybind11; -using np_array = py::array_t; -using np_array_dense = py::array_t; +// note: we rely on c_style ordering for numpy arrays so don't change this! +using np_array = py::array_t; using np_array_int = py::array_t; /** @@ -83,12 +83,25 @@ std::vector numpy2realtype(const np_array& input_np); /** * @brief Utility function to compute the set difference of two vectors */ -std::vector setDiff(const std::vector& A, const std::vector& B); +template +std::vector setDiff(const T1 a_begin, const T1 a_end, const T2 b_begin, const T2 b_end) { + std::vector result; + if (std::distance(a_begin, a_end) > 0) { + std::set_difference(a_begin, a_end, b_begin, b_end, std::back_inserter(result)); + } + return result; +} /** * @brief Utility function to make a sorted and unique vector */ -std::vector makeSortedUnique(const std::vector& input); +template +std::vector makeSortedUnique(const T input_begin, const T input_end) { + std::unordered_set uniqueSet(input_begin, input_end); // Remove duplicates + std::vector uniqueVector(uniqueSet.begin(), uniqueSet.end()); // Convert to vector + std::sort(uniqueVector.begin(), uniqueVector.end()); // Sort the vector + return uniqueVector; +} std::vector makeSortedUnique(const np_array& input_np); @@ -126,8 +139,7 @@ std::vector makeSortedUnique(const np_array& input_np); } \ std::cout << "]" << std::endl; } -#define DEBUG_v(v, M) {\ - int N = 2; \ +#define DEBUG_v(v, N) {\ std::cout << #v << "[n=" << N << "] = ["; \ for (int i = 0; i < N; i++) { \ std::cout << v[i]; \ diff --git a/src/pybamm/solvers/c_solvers/idaklu/idaklu_solver.hpp b/src/pybamm/solvers/c_solvers/idaklu/idaklu_solver.hpp index ce1765aa82..dcc1e4f8cc 100644 --- a/src/pybamm/solvers/c_solvers/idaklu/idaklu_solver.hpp +++ b/src/pybamm/solvers/c_solvers/idaklu/idaklu_solver.hpp @@ -2,6 +2,7 @@ #define PYBAMM_CREATE_IDAKLU_SOLVER_HPP #include "IDAKLUSolverOpenMP_solvers.hpp" +#include "IDAKLUSolverGroup.hpp" #include #include @@ -12,52 +13,21 @@ */ template IDAKLUSolver *create_idaklu_solver( - int number_of_states, + std::unique_ptr functions, int number_of_parameters, - const typename ExprSet::BaseFunctionType &rhs_alg, - const typename ExprSet::BaseFunctionType &jac_times_cjmass, const np_array_int &jac_times_cjmass_colptrs, const np_array_int &jac_times_cjmass_rowvals, const int jac_times_cjmass_nnz, const int jac_bandwidth_lower, const int jac_bandwidth_upper, - const typename ExprSet::BaseFunctionType &jac_action, - const typename ExprSet::BaseFunctionType &mass_action, - const typename ExprSet::BaseFunctionType &sens, - const typename ExprSet::BaseFunctionType &events, const int number_of_events, np_array rhs_alg_id, np_array atol_np, double rel_tol, int inputs_length, - const std::vector& var_fcns, - const std::vector& dvar_dy_fcns, - const std::vector& dvar_dp_fcns, - py::dict py_opts + SolverOptions solver_opts, + SetupOptions setup_opts ) { - auto setup_opts = SetupOptions(py_opts); - auto solver_opts = SolverOptions(py_opts); - auto functions = std::make_unique( - rhs_alg, - jac_times_cjmass, - jac_times_cjmass_nnz, - jac_bandwidth_lower, - jac_bandwidth_upper, - jac_times_cjmass_rowvals, - jac_times_cjmass_colptrs, - inputs_length, - jac_action, - mass_action, - sens, - events, - number_of_states, - number_of_events, - number_of_parameters, - var_fcns, - dvar_dy_fcns, - dvar_dp_fcns, - setup_opts - ); IDAKLUSolver *idakluSolver = nullptr; @@ -189,4 +159,88 @@ IDAKLUSolver *create_idaklu_solver( return idakluSolver; } +/** + * @brief Create a group of solvers using create_idaklu_solver + */ +template +IDAKLUSolverGroup *create_idaklu_solver_group( + int number_of_states, + int number_of_parameters, + const typename ExprSet::BaseFunctionType &rhs_alg, + const typename ExprSet::BaseFunctionType &jac_times_cjmass, + const np_array_int &jac_times_cjmass_colptrs, + const np_array_int &jac_times_cjmass_rowvals, + const int jac_times_cjmass_nnz, + const int jac_bandwidth_lower, + const int jac_bandwidth_upper, + const typename ExprSet::BaseFunctionType &jac_action, + const typename ExprSet::BaseFunctionType &mass_action, + const typename ExprSet::BaseFunctionType &sens, + const typename ExprSet::BaseFunctionType &events, + const int number_of_events, + np_array rhs_alg_id, + np_array atol_np, + double rel_tol, + int inputs_length, + const std::vector& var_fcns, + const std::vector& dvar_dy_fcns, + const std::vector& dvar_dp_fcns, + py::dict py_opts +) { + auto setup_opts = SetupOptions(py_opts); + auto solver_opts = SolverOptions(py_opts); + + + std::vector> solvers; + for (int i = 0; i < setup_opts.num_solvers; i++) { + // Note: we can't copy an ExprSet as it contains raw pointers to the functions + // So we create it in the loop + auto functions = std::make_unique( + rhs_alg, + jac_times_cjmass, + jac_times_cjmass_nnz, + jac_bandwidth_lower, + jac_bandwidth_upper, + jac_times_cjmass_rowvals, + jac_times_cjmass_colptrs, + inputs_length, + jac_action, + mass_action, + sens, + events, + number_of_states, + number_of_events, + number_of_parameters, + var_fcns, + dvar_dy_fcns, + dvar_dp_fcns, + setup_opts + ); + solvers.emplace_back( + std::unique_ptr( + create_idaklu_solver( + std::move(functions), + number_of_parameters, + jac_times_cjmass_colptrs, + jac_times_cjmass_rowvals, + jac_times_cjmass_nnz, + jac_bandwidth_lower, + jac_bandwidth_upper, + number_of_events, + rhs_alg_id, + atol_np, + rel_tol, + inputs_length, + solver_opts, + setup_opts + ) + ) + ); + } + + return new IDAKLUSolverGroup(std::move(solvers), number_of_states, number_of_parameters); +} + + + #endif // PYBAMM_CREATE_IDAKLU_SOLVER_HPP diff --git a/src/pybamm/solvers/idaklu_solver.py b/src/pybamm/solvers/idaklu_solver.py index 08f86b3264..ea3903b139 100644 --- a/src/pybamm/solvers/idaklu_solver.py +++ b/src/pybamm/solvers/idaklu_solver.py @@ -29,7 +29,8 @@ idaklu = importlib.util.module_from_spec(idaklu_spec) if idaklu_spec.loader: idaklu_spec.loader.exec_module(idaklu) - except ImportError: # pragma: no cover + except ImportError as e: # pragma: no cover + print(f"Error loading idaklu: {e}") idaklu_spec = None @@ -78,8 +79,10 @@ class IDAKLUSolver(pybamm.BaseSolver): options = { # Print statistics of the solver after every solve "print_stats": False, - # Number of threads available for OpenMP + # Number of threads available for OpenMP (must be greater than or equal to `num_solvers`) "num_threads": 1, + # Number of solvers to use in parallel (for solving multiple sets of input parameters in parallel) + "num_solvers": num_threads, # Evaluation engine to use for jax, can be 'jax'(native) or 'iree' "jax_evaluator": "jax", ## Linear solver interface @@ -182,6 +185,7 @@ def __init__( "precon_half_bandwidth": 5, "precon_half_bandwidth_keep": 5, "num_threads": 1, + "num_solvers": 1, "jax_evaluator": "jax", "linear_solver": "SUNLinSol_KLU", "linsol_max_iterations": 5, @@ -209,6 +213,8 @@ def __init__( if options is None: options = default_options else: + if "num_threads" in options and "num_solvers" not in options: + options["num_solvers"] = options["num_threads"] for key, value in default_options.items(): if key not in options: options[key] = value @@ -443,7 +449,7 @@ def inputs_to_dict(inputs): if model.convert_to_format == "casadi": # Serialize casadi functions - idaklu_solver_fcn = idaklu.create_casadi_solver + idaklu_solver_fcn = idaklu.create_casadi_solver_group rhs_algebraic = idaklu.generate_function(rhs_algebraic.serialize()) jac_times_cjmass = idaklu.generate_function(jac_times_cjmass.serialize()) jac_rhs_algebraic_action = idaklu.generate_function( @@ -457,7 +463,7 @@ def inputs_to_dict(inputs): and self._options["jax_evaluator"] == "iree" ): # Convert Jax functions to MLIR (also, demote to single precision) - idaklu_solver_fcn = idaklu.create_iree_solver + idaklu_solver_fcn = idaklu.create_iree_solver_group pybamm.demote_expressions_to_32bit = True if pybamm.demote_expressions_to_32bit: warnings.warn( @@ -726,7 +732,11 @@ def _check_mlir_conversion(self, name, mlir: str): def _demote_64_to_32(self, x: pybamm.EvaluatorJax): return pybamm.EvaluatorJax._demote_64_to_32(x) - def _integrate(self, model, t_eval, inputs_dict=None, t_interp=None): + @property + def supports_parallel_solve(self): + return True + + def _integrate(self, model, t_eval, inputs_list=None, t_interp=None): """ Solve a DAE model defined by residuals with initial conditions y0. @@ -736,22 +746,30 @@ def _integrate(self, model, t_eval, inputs_dict=None, t_interp=None): The model whose solution to calculate. t_eval : numeric type The times at which to stop the integration due to a discontinuity in time. - inputs_dict : dict, optional + inputs_list: list of dict, optional Any input parameters to pass to the model when solving. t_interp : None, list or ndarray, optional The times (in seconds) at which to interpolate the solution. Defaults to `None`, which returns the adaptive time-stepping times. """ - inputs_dict = inputs_dict or {} - # stack inputs - if inputs_dict: - arrays_to_stack = [np.array(x).reshape(-1, 1) for x in inputs_dict.values()] - inputs = np.vstack(arrays_to_stack) + inputs_list = inputs_list or [{}] + + # stack inputs so that they are a 2D array of shape (number_of_inputs, number_of_parameters) + if inputs_list and inputs_list[0]: + inputs = np.vstack( + [ + np.hstack([np.array(x).reshape(-1) for x in inputs_dict.values()]) + for inputs_dict in inputs_list + ] + ) else: inputs = np.array([[]]) - y0full = model.y0full - ydot0full = model.ydot0full + # stack y0full and ydot0full so they are a 2D array of shape (number_of_inputs, number_of_states + number_of_parameters * number_of_states) + # note that y0full and ydot0full are currently 1D arrays (i.e. independent of inputs), but in the future we will support + # different initial conditions for different inputs (see https://github.com/pybamm-team/PyBaMM/pull/4260). For now we just repeat the same initial conditions for each input + y0full = np.vstack([model.y0full] * len(inputs_list)) + ydot0full = np.vstack([model.ydot0full] * len(inputs_list)) atol = getattr(model, "atol", self.atol) atol = self._check_atol_type(atol, y0full.size) @@ -761,7 +779,7 @@ def _integrate(self, model, t_eval, inputs_dict=None, t_interp=None): model.convert_to_format == "jax" and self._options["jax_evaluator"] == "iree" ): - sol = self._setup["solver"].solve( + solns = self._setup["solver"].solve( t_eval, t_interp, y0full, @@ -773,6 +791,12 @@ def _integrate(self, model, t_eval, inputs_dict=None, t_interp=None): raise pybamm.SolverError("Unsupported IDAKLU solver configuration.") integration_time = timer.time() + return [ + self._post_process_solution(soln, model, integration_time, inputs_dict) + for soln, inputs_dict in zip(solns, inputs_list) + ] + + def _post_process_solution(self, sol, model, integration_time, inputs_dict): number_of_sensitivity_parameters = self._setup[ "number_of_sensitivity_parameters" ] diff --git a/src/pybamm/solvers/jax_solver.py b/src/pybamm/solvers/jax_solver.py index da5fd4983a..bfcdef1882 100644 --- a/src/pybamm/solvers/jax_solver.py +++ b/src/pybamm/solvers/jax_solver.py @@ -185,6 +185,10 @@ def solve_model_bdf(inputs): else: return jax.jit(solve_model_bdf) + @property + def supports_parallel_solve(self): + return True + def _integrate(self, model, t_eval, inputs=None, t_interp=None): """ Solve a model defined by dydt with initial conditions y0. @@ -200,7 +204,7 @@ def _integrate(self, model, t_eval, inputs=None, t_interp=None): Returns ------- - object + list of `pybamm.Solution` An object containing the times and values of the solution, as well as various diagnostic messages. @@ -301,6 +305,4 @@ async def solve_model_async(inputs_v): sol.integration_time = integration_time solutions.append(sol) - if len(solutions) == 1: - return solutions[0] return solutions diff --git a/tests/unit/test_solvers/test_idaklu_solver.py b/tests/unit/test_solvers/test_idaklu_solver.py index 32e289b3e0..213226bb4c 100644 --- a/tests/unit/test_solvers/test_idaklu_solver.py +++ b/tests/unit/test_solvers/test_idaklu_solver.py @@ -63,6 +63,47 @@ def test_ida_roberts_klu(self): true_solution = 0.1 * solution.t np.testing.assert_array_almost_equal(solution.y[0, :], true_solution) + def test_multiple_inputs(self): + model = pybamm.BaseModel() + var = pybamm.Variable("var") + rate = pybamm.InputParameter("rate") + model.rhs = {var: -rate * var} + model.initial_conditions = {var: 2} + disc = pybamm.Discretisation() + disc.process_model(model) + + for num_threads, num_solvers in [ + [1, None], + [2, None], + [8, None], + [8, 1], + [8, 2], + [8, 7], + ]: + options = {"num_threads": num_threads} + if num_solvers is not None: + options["num_solvers"] = num_solvers + solver = pybamm.IDAKLUSolver(rtol=1e-5, atol=1e-5, options=options) + t_interp = np.linspace(0, 1, 10) + t_eval = [t_interp[0], t_interp[-1]] + ninputs = 8 + inputs_list = [{"rate": 0.01 * (i + 1)} for i in range(ninputs)] + + solutions = solver.solve( + model, t_eval, inputs=inputs_list, t_interp=t_interp + ) + + # check solution + for inputs, solution in zip(inputs_list, solutions): + print("checking solution", inputs, solution.all_inputs) + np.testing.assert_array_equal(solution.t, t_interp) + np.testing.assert_allclose( + solution.y[0], + 2 * np.exp(-inputs["rate"] * solution.t), + atol=1e-4, + rtol=1e-4, + ) + def test_model_events(self): for form in ["casadi", "iree"]: if (form == "iree") and (not pybamm.has_jax() or not pybamm.has_iree()):