From b6fa67a5ca1fcce7a3599c885a11b1479c89aa54 Mon Sep 17 00:00:00 2001 From: Pradyot Ranjan <99216956+prady0t@users.noreply.github.com> Date: Sat, 28 Sep 2024 15:49:23 +0530 Subject: [PATCH 1/6] Final pytest migration (#4443) * Final pytest migration Signed-off-by: Pradyot Ranjan <99216956+pradyotRanjan@users.noreply.github.com> * style: pre-commit fixes * Fixing assert failures Signed-off-by: Pradyot Ranjan <99216956+pradyotRanjan@users.noreply.github.com> * converting test_idaklu_jax.py and test_solution.py to pytest Signed-off-by: Pradyot Ranjan <99216956+pradyotRanjan@users.noreply.github.com> * style: pre-commit fixes * removinh style failures Signed-off-by: Pradyot Ranjan <99216956+pradyotRanjan@users.noreply.github.com> * using single argument Signed-off-by: Pradyot Ranjan <99216956+pradyotRanjan@users.noreply.github.com> * Updating docs Signed-off-by: Pradyot Ranjan <99216956+pradyotRanjan@users.noreply.github.com> --------- Signed-off-by: Pradyot Ranjan <99216956+pradyotRanjan@users.noreply.github.com> Co-authored-by: Pradyot Ranjan <99216956+pradyotRanjan@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- docs/source/user_guide/installation/index.rst | 2 +- pyproject.toml | 4 +- .../test_simulation_with_experiment.py | 168 +++-- .../test_binary_operators.py | 654 +++++++++--------- .../test_base_battery_model.py | 271 ++++---- .../test_parameter_sets/test_OKane2022.py | 18 +- .../test_parameters_with_default_models.py | 17 +- tests/unit/test_solvers/test_idaklu_jax.py | 135 ++-- tests/unit/test_solvers/test_scipy_solver.py | 55 +- tests/unit/test_solvers/test_solution.py | 157 ++--- 10 files changed, 692 insertions(+), 789 deletions(-) diff --git a/docs/source/user_guide/installation/index.rst b/docs/source/user_guide/installation/index.rst index d6411348c5..18e62c5dfa 100644 --- a/docs/source/user_guide/installation/index.rst +++ b/docs/source/user_guide/installation/index.rst @@ -145,8 +145,8 @@ Dependency `pre-commit `__ \- dev For managing and maintaining multi-language pre-commit hooks. `ruff `__ \- dev For code formatting. `nox `__ \- dev For running testing sessions in multiple environments. +`pytest-subtests `__ \- dev For subtests pytest fixture. `pytest-cov `__ \- dev For calculating test coverage. -`parameterized `__ \- dev For test parameterization. `pytest `__ 6.0.0 dev For running the test suites. `pytest-doctestplus `__ \- dev For running doctests. `pytest-xdist `__ \- dev For running tests in parallel across distributed workers. 
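The hunks that follow apply one consistent set of unittest-to-pytest conversions: test classes drop the unittest.TestCase base, self.assert* calls become bare assert statements, assertRaisesRegex becomes pytest.raises(..., match=...), assertAlmostEqual becomes pytest.approx, unittest.skipIf becomes pytest.mark.skipif, self.subTest is replaced by the subtests fixture from pytest-subtests, and unittest.mock access goes through the mocker fixture from pytest-mock. As a rough illustration of those patterns (not part of the patch, and with no PyBaMM names), the toy test module below sketches each one; it assumes pytest plus the pytest-subtests and pytest-mock plugins are installed.

# Illustrative sketch only: the unittest-to-pytest patterns used in this
# migration, shown on a self-contained toy test module. Assumes pytest with
# the pytest-subtests and pytest-mock plugins, which provide the "subtests"
# and "mocker" fixtures used below.
import sys

import pytest


def divide(a, b):
    return a / b


class TestDivide:  # plain class, no unittest.TestCase base
    def test_values(self):
        # self.assertEqual(x, y)       -> assert x == y
        assert divide(6, 3) == 2
        # self.assertAlmostEqual(x, y) -> assert x == pytest.approx(y)
        assert divide(1, 3) == pytest.approx(0.3333, abs=1e-4)

    def test_error_message(self):
        # self.assertRaisesRegex(Err, "msg") -> pytest.raises(Err, match="msg")
        with pytest.raises(ZeroDivisionError, match="division"):
            divide(1, 0)

    # @unittest.skipIf(cond, "why") -> @pytest.mark.skipif(cond, reason="why")
    @pytest.mark.skipif(sys.version_info < (3, 9), reason="requires Python 3.9")
    def test_skipif(self):
        assert divide(4, 2) == 2

    def test_subtests(self, subtests):
        # with self.subTest(x=x): -> with subtests.test(x=x)  (pytest-subtests)
        for x in [1, 2, 4]:
            with subtests.test(x=x):
                assert divide(x, x) == 1

    def test_mocker_any(self, mocker):
        # unittest.mock.ANY -> mocker.ANY  (pytest-mock)
        assert {"name": "+", "id": 42} == {"name": "+", "id": mocker.ANY}

With those plugins installed, running pytest collects such a class directly, with no unittest machinery.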
diff --git a/pyproject.toml b/pyproject.toml index 7fb1a5ce95..d2c487ede4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -105,12 +105,11 @@ dev = [ "pytest-cov", # For doctest "pytest-doctestplus", - # For test parameterization - "parameterized>=0.9", # pytest and its plugins "pytest>=6", "pytest-xdist", "pytest-mock", + "pytest-subtests", # For testing Jupyter notebooks "nbmake", # To access the metadata for python packages @@ -230,6 +229,7 @@ minversion = "8" required_plugins = [ "pytest-xdist", "pytest-mock", + "pytest-subtests", ] norecursedirs = 'pybind11*' addopts = [ diff --git a/tests/unit/test_experiments/test_simulation_with_experiment.py b/tests/unit/test_experiments/test_simulation_with_experiment.py index 4f981ba04c..394d64b257 100644 --- a/tests/unit/test_experiments/test_simulation_with_experiment.py +++ b/tests/unit/test_experiments/test_simulation_with_experiment.py @@ -1,11 +1,11 @@ # # Test setting up a simulation with an experiment # +import pytest import casadi import pybamm import numpy as np import os -import unittest from datetime import datetime @@ -15,7 +15,7 @@ def default_duration(self, value): return 1 -class TestSimulationExperiment(unittest.TestCase): +class TestSimulationExperiment: def test_set_up(self): experiment = pybamm.Experiment( [ @@ -29,24 +29,22 @@ def test_set_up(self): sim = pybamm.Simulation(model, experiment=experiment) sim.build_for_experiment() - self.assertEqual(sim.experiment.args, experiment.args) + assert sim.experiment.args == experiment.args steps = sim.experiment.steps model_I = sim.experiment_unique_steps_to_model[ steps[1].basic_repr() ] # CC charge model_V = sim.experiment_unique_steps_to_model[steps[2].basic_repr()] # CV hold - self.assertIn( - "Current cut-off [A] [experiment]", - [event.name for event in model_V.events], - ) - self.assertIn( - "Charge voltage cut-off [V] [experiment]", - [event.name for event in model_I.events], - ) + assert "Current cut-off [A] [experiment]" in [ + event.name for event in model_V.events + ] + assert "Charge voltage cut-off [V] [experiment]" in [ + event.name for event in model_I.events + ] # fails if trying to set up with something that isn't an experiment - with self.assertRaisesRegex(TypeError, "experiment must be"): + with pytest.raises(TypeError, match="experiment must be"): pybamm.Simulation(model, experiment=0) def test_setup_experiment_string_or_list(self): @@ -54,17 +52,14 @@ def test_setup_experiment_string_or_list(self): sim = pybamm.Simulation(model, experiment="Discharge at C/20 for 1 hour") sim.build_for_experiment() - self.assertEqual(len(sim.experiment.steps), 1) - self.assertEqual( - sim.experiment.steps[0].description, - "Discharge at C/20 for 1 hour", - ) + assert len(sim.experiment.steps) == 1 + assert sim.experiment.steps[0].description == "Discharge at C/20 for 1 hour" sim = pybamm.Simulation( model, experiment=["Discharge at C/20 for 1 hour", pybamm.step.rest(60)], ) sim.build_for_experiment() - self.assertEqual(len(sim.experiment.steps), 2) + assert len(sim.experiment.steps) == 2 def test_run_experiment(self): s = pybamm.step.string @@ -84,8 +79,8 @@ def test_run_experiment(self): sim = pybamm.Simulation(model, experiment=experiment) # test the callback here sol = sim.solve(callbacks=pybamm.callbacks.Callback()) - self.assertEqual(sol.termination, "final time") - self.assertEqual(len(sol.cycles), 1) + assert sol.termination == "final time" + assert len(sol.cycles) == 1 # Test outputs np.testing.assert_array_equal(sol.cycles[0].steps[0]["C-rate"].data, 1 / 20) @@ 
-128,23 +123,23 @@ def test_run_experiment(self): # Solve again starting from solution sol2 = sim.solve(starting_solution=sol) - self.assertEqual(sol2.termination, "final time") - self.assertGreater(sol2.t[-1], sol.t[-1]) - self.assertEqual(sol2.cycles[0], sol.cycles[0]) - self.assertEqual(len(sol2.cycles), 2) + assert sol2.termination == "final time" + assert sol2.t[-1] > sol.t[-1] + assert sol2.cycles[0] == sol.cycles[0] + assert len(sol2.cycles) == 2 # Solve again starting from solution but only inputting the cycle sol2 = sim.solve(starting_solution=sol.cycles[-1]) - self.assertEqual(sol2.termination, "final time") - self.assertGreater(sol2.t[-1], sol.t[-1]) - self.assertEqual(len(sol2.cycles), 2) + assert sol2.termination == "final time" + assert sol2.t[-1] > sol.t[-1] + assert len(sol2.cycles) == 2 # Check starting solution is unchanged - self.assertEqual(len(sol.cycles), 1) + assert len(sol.cycles) == 1 # save sol2.save("test_experiment.sav") sol3 = pybamm.load("test_experiment.sav") - self.assertEqual(len(sol3.cycles), 2) + assert len(sol3.cycles) == 2 os.remove("test_experiment.sav") def test_run_experiment_multiple_times(self): @@ -168,7 +163,9 @@ def test_run_experiment_multiple_times(self): sol1["Voltage [V]"].data, sol2["Voltage [V]"].data ) - @unittest.skipIf(not pybamm.has_idaklu(), "idaklu solver is not installed") + @pytest.mark.skipif( + not pybamm.has_idaklu(), reason="idaklu solver is not installed" + ) def test_run_experiment_cccv_solvers(self): experiment_2step = pybamm.Experiment( [ @@ -199,9 +196,11 @@ def test_run_experiment_cccv_solvers(self): solutions[1]["Current [A]"](solutions[0].t), decimal=0, ) - self.assertEqual(solutions[1].termination, "final time") + assert solutions[1].termination == "final time" - @unittest.skipIf(not pybamm.has_idaklu(), "idaklu solver is not installed") + @pytest.mark.skipif( + not pybamm.has_idaklu(), reason="idaklu solver is not installed" + ) def test_solve_with_sensitivities_and_experiment(self): experiment_2step = pybamm.Experiment( [ @@ -276,7 +275,7 @@ def test_solve_with_sensitivities_and_experiment(self): ) / len(sens_casadi) ) - self.assertLess(error, 1.0) + assert error < 1.0 def test_run_experiment_drive_cycle(self): drive_cycle = np.array([np.arange(10), np.arange(10)]).T @@ -292,9 +291,8 @@ def test_run_experiment_drive_cycle(self): model = pybamm.lithium_ion.SPM() sim = pybamm.Simulation(model, experiment=experiment) sim.build_for_experiment() - self.assertEqual( - sorted([step.basic_repr() for step in experiment.steps]), - sorted(list(sim.experiment_unique_steps_to_model.keys())), + assert sorted([step.basic_repr() for step in experiment.steps]) == sorted( + list(sim.experiment_unique_steps_to_model.keys()) ) def test_run_experiment_breaks_early_infeasible(self): @@ -308,7 +306,7 @@ def test_run_experiment_breaks_early_infeasible(self): t_eval, solver=pybamm.CasadiSolver(), callbacks=pybamm.callbacks.Callback() ) pybamm.set_logging_level("WARNING") - self.assertEqual(sim._solution.termination, "event: Minimum voltage [V]") + assert sim._solution.termination == "event: Minimum voltage [V]" def test_run_experiment_breaks_early_error(self): s = pybamm.step.string @@ -331,8 +329,8 @@ def test_run_experiment_breaks_early_error(self): solver=solver, ) sol = sim.solve() - self.assertEqual(len(sol.cycles), 1) - self.assertEqual(len(sol.cycles[0].steps), 1) + assert len(sol.cycles) == 1 + assert len(sol.cycles[0].steps) == 1 # Different experiment setup style experiment = pybamm.Experiment( @@ -348,8 +346,8 @@ def 
test_run_experiment_breaks_early_error(self): solver=solver, ) sol = sim.solve() - self.assertEqual(len(sol.cycles), 1) - self.assertEqual(len(sol.cycles[0].steps), 1) + assert len(sol.cycles) == 1 + assert len(sol.cycles[0].steps) == 1 # Different callback - this is for coverage on the `Callback` class sol = sim.solve(callbacks=pybamm.callbacks.Callback()) @@ -364,8 +362,8 @@ def test_run_experiment_infeasible_time(self): model, parameter_values=parameter_values, experiment=experiment ) sol = sim.solve() - self.assertEqual(len(sol.cycles), 1) - self.assertEqual(len(sol.cycles[0].steps), 1) + assert len(sol.cycles) == 1 + assert len(sol.cycles[0].steps) == 1 def test_run_experiment_termination_capacity(self): # with percent @@ -442,7 +440,7 @@ def test_run_experiment_termination_voltage(self): # Only two cycles should be completed, only 2nd cycle should go below 4V np.testing.assert_array_less(4, np.min(sol.cycles[0]["Voltage [V]"].data)) np.testing.assert_array_less(np.min(sol.cycles[1]["Voltage [V]"].data), 4) - self.assertEqual(len(sol.cycles), 2) + assert len(sol.cycles) == 2 def test_run_experiment_termination_time_min(self): experiment = pybamm.Experiment( @@ -460,7 +458,7 @@ def test_run_experiment_termination_time_min(self): # Only two cycles should be completed, only 2nd cycle should go below 4V np.testing.assert_array_less(np.max(sol.cycles[0]["Time [s]"].data), 1500) np.testing.assert_array_equal(np.max(sol.cycles[1]["Time [s]"].data), 1500) - self.assertEqual(len(sol.cycles), 2) + assert len(sol.cycles) == 2 def test_run_experiment_termination_time_s(self): experiment = pybamm.Experiment( @@ -478,7 +476,7 @@ def test_run_experiment_termination_time_s(self): # Only two cycles should be completed, only 2nd cycle should go below 4V np.testing.assert_array_less(np.max(sol.cycles[0]["Time [s]"].data), 1500) np.testing.assert_array_equal(np.max(sol.cycles[1]["Time [s]"].data), 1500) - self.assertEqual(len(sol.cycles), 2) + assert len(sol.cycles) == 2 def test_run_experiment_termination_time_h(self): experiment = pybamm.Experiment( @@ -496,7 +494,7 @@ def test_run_experiment_termination_time_h(self): # Only two cycles should be completed, only 2nd cycle should go below 4V np.testing.assert_array_less(np.max(sol.cycles[0]["Time [s]"].data), 1800) np.testing.assert_array_equal(np.max(sol.cycles[1]["Time [s]"].data), 1800) - self.assertEqual(len(sol.cycles), 2) + assert len(sol.cycles) == 2 def test_save_at_cycles(self): experiment = pybamm.Experiment( @@ -516,22 +514,22 @@ def test_save_at_cycles(self): ) # Solution saves "None" for the cycles that are not saved for cycle_num in [2, 4, 6, 8]: - self.assertIsNone(sol.cycles[cycle_num]) + assert sol.cycles[cycle_num] is None for cycle_num in [0, 1, 3, 5, 7, 9]: - self.assertIsNotNone(sol.cycles[cycle_num]) + assert sol.cycles[cycle_num] is not None # Summary variables are not None - self.assertIsNotNone(sol.summary_variables["Capacity [A.h]"]) + assert sol.summary_variables["Capacity [A.h]"] is not None sol = sim.solve( solver=pybamm.CasadiSolver("fast with events"), save_at_cycles=[3, 4, 5, 9] ) # Note offset by 1 (0th cycle is cycle 1) for cycle_num in [1, 5, 6, 7]: - self.assertIsNone(sol.cycles[cycle_num]) + assert sol.cycles[cycle_num] is None for cycle_num in [0, 2, 3, 4, 8, 9]: # first & last cycle always saved - self.assertIsNotNone(sol.cycles[cycle_num]) + assert sol.cycles[cycle_num] is not None # Summary variables are not None - self.assertIsNotNone(sol.summary_variables["Capacity [A.h]"]) + assert 
sol.summary_variables["Capacity [A.h]"] is not None def test_cycle_summary_variables(self): # Test cycle_summary_variables works for different combinations of data and @@ -633,8 +631,8 @@ def test_run_experiment_skip_steps(self): model, parameter_values=parameter_values, experiment=experiment ) sol = sim.solve() - self.assertIsInstance(sol.cycles[0].steps[0], pybamm.EmptySolution) - self.assertIsInstance(sol.cycles[0].steps[3], pybamm.EmptySolution) + assert isinstance(sol.cycles[0].steps[0], pybamm.EmptySolution) + assert isinstance(sol.cycles[0].steps[3], pybamm.EmptySolution) # Should get the same result if we run without the charge steps # since they are skipped @@ -689,9 +687,9 @@ def test_all_empty_solution_errors(self): sim = pybamm.Simulation( model, parameter_values=parameter_values, experiment=experiment ) - with self.assertRaisesRegex( + with pytest.raises( pybamm.SolverError, - "Step 'Charge at 1C until 4.2V' is infeasible due to exceeded bounds", + match="Step 'Charge at 1C until 4.2V' is infeasible due to exceeded bounds", ): sim.solve() @@ -702,7 +700,7 @@ def test_all_empty_solution_errors(self): sim = pybamm.Simulation( model, parameter_values=parameter_values, experiment=experiment ) - with self.assertRaisesRegex(pybamm.SolverError, "All steps in the cycle"): + with pytest.raises(pybamm.SolverError, match="All steps in the cycle"): sim.solve() def test_solver_error(self): @@ -719,7 +717,7 @@ def test_solver_error(self): solver=pybamm.CasadiSolver(mode="fast"), ) - with self.assertRaisesRegex(pybamm.SolverError, "IDA_CONV_FAIL"): + with pytest.raises(pybamm.SolverError, match="IDA_CONV_FAIL"): sim.solve() def test_run_experiment_half_cell(self): @@ -749,9 +747,7 @@ def test_padding_rest_model(self): experiment = pybamm.Experiment(["Rest for 1 hour"]) sim = pybamm.Simulation(model, experiment=experiment) sim.build_for_experiment() - self.assertNotIn( - "Rest for padding", sim.experiment_unique_steps_to_model.keys() - ) + assert "Rest for padding" not in sim.experiment_unique_steps_to_model.keys() # Test padding rest model exists if there are start_times experiment = pybamm.step.string( @@ -759,13 +755,13 @@ def test_padding_rest_model(self): ) sim = pybamm.Simulation(model, experiment=experiment) sim.build_for_experiment() - self.assertIn("Rest for padding", sim.experiment_unique_steps_to_model.keys()) + assert "Rest for padding" in sim.experiment_unique_steps_to_model.keys() # Check at least there is an input parameter (temperature) - self.assertGreater( - len(sim.experiment_unique_steps_to_model["Rest for padding"].parameters), 0 + assert ( + len(sim.experiment_unique_steps_to_model["Rest for padding"].parameters) > 0 ) # Check the model is the same - self.assertIsInstance( + assert isinstance( sim.experiment_unique_steps_to_model["Rest for padding"], pybamm.lithium_ion.SPM, ) @@ -787,7 +783,7 @@ def test_run_start_time_experiment(self): ) sim = pybamm.Simulation(model, experiment=experiment) sol = sim.solve(calc_esoh=False) - self.assertEqual(sol["Time [s]"].entries[-1], 5400) + assert sol["Time [s]"].entries[-1] == 5400 # Test padding rest is added if time stamp is late experiment = pybamm.Experiment( @@ -803,7 +799,7 @@ def test_run_start_time_experiment(self): ) sim = pybamm.Simulation(model, experiment=experiment) sol = sim.solve(calc_esoh=False) - self.assertEqual(sol["Time [s]"].entries[-1], 10800) + assert sol["Time [s]"].entries[-1] == 10800 def test_starting_solution(self): model = pybamm.lithium_ion.SPM() @@ -820,7 +816,7 @@ def test_starting_solution(self): 
solution = sim.solve(save_at_cycles=[1]) # test that the last state is correct (i.e. final cycle is saved) - self.assertEqual(solution.last_state.t[-1], 1200) + assert solution.last_state.t[-1] == 1200 experiment = pybamm.Experiment( [ @@ -833,7 +829,7 @@ def test_starting_solution(self): new_solution = sim.solve(calc_esoh=False, starting_solution=solution) # test that the final time is correct (i.e. starting solution correctly set) - self.assertEqual(new_solution["Time [s]"].entries[-1], 3600) + assert new_solution["Time [s]"].entries[-1] == 3600 def test_experiment_start_time_starting_solution(self): model = pybamm.lithium_ion.SPM() @@ -855,7 +851,7 @@ def test_experiment_start_time_starting_solution(self): ) sim = pybamm.Simulation(model, experiment=experiment) - with self.assertRaisesRegex(ValueError, "experiments with `start_time`"): + with pytest.raises(ValueError, match="experiments with `start_time`"): sim.solve(starting_solution=solution) # Test starting_solution works well with start_time @@ -892,7 +888,7 @@ def test_experiment_start_time_starting_solution(self): new_solution = sim.solve(starting_solution=solution) # test that the final time is correct (i.e. starting solution correctly set) - self.assertEqual(new_solution["Time [s]"].entries[-1], 5400) + assert new_solution["Time [s]"].entries[-1] == 5400 def test_experiment_start_time_identical_steps(self): # Test that if we have the same step twice, with different start times, @@ -918,15 +914,15 @@ def test_experiment_start_time_identical_steps(self): sim.solve(calc_esoh=False) # Check that there are 4 steps - self.assertEqual(len(experiment.steps), 4) + assert len(experiment.steps) == 4 # Check that there are only 2 unique steps - self.assertEqual(len(sim.experiment.unique_steps), 2) + assert len(sim.experiment.unique_steps) == 2 # Check that there are only 3 built models (unique steps + padding rest) - self.assertEqual(len(sim.steps_to_built_models), 3) + assert len(sim.steps_to_built_models) == 3 - def test_experiment_custom_steps(self): + def test_experiment_custom_steps(self, subtests): model = pybamm.lithium_ion.SPM() # Explicit control @@ -947,7 +943,7 @@ def custom_step_voltage(variables): return 100 * (variables["Voltage [V]"] - 4.2) for control in ["differential"]: - with self.subTest(control=control): + with subtests.test(control=control): custom_step_alg = pybamm.step.CustomStepImplicit( custom_step_voltage, control=control, duration=100, period=10 ) @@ -974,20 +970,10 @@ def neg_stoich_cutoff(variables): ) sim = pybamm.Simulation(model, experiment=experiment) sol = sim.solve(calc_esoh=False) - self.assertEqual( - sol.cycles[0].steps[0].termination, - "event: Negative stoichiometry cut-off [experiment]", + assert ( + sol.cycles[0].steps[0].termination + == "event: Negative stoichiometry cut-off [experiment]" ) neg_stoich = sol["Negative electrode stoichiometry"].data - self.assertAlmostEqual(neg_stoich[-1], 0.5, places=4) - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() + assert neg_stoich[-1] == pytest.approx(0.5, abs=0.0001) diff --git a/tests/unit/test_expression_tree/test_binary_operators.py b/tests/unit/test_expression_tree/test_binary_operators.py index e1a14206b4..206aa0799c 100644 --- a/tests/unit/test_expression_tree/test_binary_operators.py +++ b/tests/unit/test_expression_tree/test_binary_operators.py @@ -2,8 +2,8 @@ # Tests for the Binary Operator classes # -import unittest 
-import unittest.mock as mock +import pytest + import numpy as np from scipy.sparse import coo_matrix @@ -19,19 +19,19 @@ } -class TestBinaryOperators(unittest.TestCase): +class TestBinaryOperators: def test_binary_operator(self): a = pybamm.Symbol("a") b = pybamm.Symbol("b") bin = pybamm.BinaryOperator("binary test", a, b) - self.assertEqual(bin.children[0].name, a.name) - self.assertEqual(bin.children[1].name, b.name) + assert bin.children[0].name == a.name + assert bin.children[1].name == b.name c = pybamm.Scalar(1) d = pybamm.Scalar(2) bin2 = pybamm.BinaryOperator("binary test", c, d) - with self.assertRaises(NotImplementedError): + with pytest.raises(NotImplementedError): bin2.evaluate() - with self.assertRaises(NotImplementedError): + with pytest.raises(NotImplementedError): bin2._binary_jac(a, b) def test_binary_operator_domains(self): @@ -39,28 +39,28 @@ def test_binary_operator_domains(self): a = pybamm.Symbol("a", domain=["negative electrode"]) b = pybamm.Symbol("b", domain=["negative electrode"]) bin1 = pybamm.BinaryOperator("binary test", a, b) - self.assertEqual(bin1.domain, ["negative electrode"]) + assert bin1.domain == ["negative electrode"] # one empty domain c = pybamm.Symbol("c", domain=[]) bin2 = pybamm.BinaryOperator("binary test", a, c) - self.assertEqual(bin2.domain, ["negative electrode"]) + assert bin2.domain == ["negative electrode"] bin3 = pybamm.BinaryOperator("binary test", c, b) - self.assertEqual(bin3.domain, ["negative electrode"]) + assert bin3.domain == ["negative electrode"] # mismatched domains d = pybamm.Symbol("d", domain=["positive electrode"]) - with self.assertRaises(pybamm.DomainError): + with pytest.raises(pybamm.DomainError): pybamm.BinaryOperator("binary test", a, d) def test_addition(self): a = pybamm.Symbol("a") b = pybamm.Symbol("b") summ = pybamm.Addition(a, b) - self.assertEqual(summ.children[0].name, a.name) - self.assertEqual(summ.children[1].name, b.name) + assert summ.children[0].name == a.name + assert summ.children[1].name == b.name # test simplifying summ2 = pybamm.Scalar(1) + pybamm.Scalar(3) - self.assertEqual(summ2, pybamm.Scalar(4)) + assert summ2 == pybamm.Scalar(4) def test_addition_numpy_array(self): a = pybamm.Symbol("a") @@ -68,32 +68,32 @@ def test_addition_numpy_array(self): # converts numpy array to vector array = np.array([1, 2, 3]) summ3 = pybamm.Addition(a, array) - self.assertIsInstance(summ3, pybamm.Addition) - self.assertIsInstance(summ3.children[0], pybamm.Symbol) - self.assertIsInstance(summ3.children[1], pybamm.Vector) + assert isinstance(summ3, pybamm.Addition) + assert isinstance(summ3.children[0], pybamm.Symbol) + assert isinstance(summ3.children[1], pybamm.Vector) summ4 = array + a - self.assertIsInstance(summ4.children[0], pybamm.Vector) + assert isinstance(summ4.children[0], pybamm.Vector) # should error if numpy array is not 1D array = np.array([[1, 2, 3], [4, 5, 6]]) - with self.assertRaisesRegex(ValueError, "left must be a 1D array"): + with pytest.raises(ValueError, match="left must be a 1D array"): pybamm.Addition(array, a) - with self.assertRaisesRegex(ValueError, "right must be a 1D array"): + with pytest.raises(ValueError, match="right must be a 1D array"): pybamm.Addition(a, array) def test_power(self): a = pybamm.Symbol("a") b = pybamm.Symbol("b") pow1 = pybamm.Power(a, b) - self.assertEqual(pow1.name, "**") - self.assertEqual(pow1.children[0].name, a.name) - self.assertEqual(pow1.children[1].name, b.name) + assert pow1.name == "**" + assert pow1.children[0].name == a.name + assert 
pow1.children[1].name == b.name a = pybamm.Scalar(4) b = pybamm.Scalar(2) pow2 = pybamm.Power(a, b) - self.assertEqual(pow2.evaluate(), 16) + assert pow2.evaluate() == 16 def test_diff(self): a = pybamm.StateVector(slice(0, 1)) @@ -101,51 +101,51 @@ def test_diff(self): y = np.array([5, 3]) # power - self.assertEqual((a**b).diff(b).evaluate(y=y), 5**3 * np.log(5)) - self.assertEqual((a**b).diff(a).evaluate(y=y), 3 * 5**2) - self.assertEqual((a**b).diff(a**b).evaluate(), 1) - self.assertEqual((a**a).diff(a).evaluate(y=y), 5**5 * np.log(5) + 5 * 5**4) - self.assertEqual((a**a).diff(b).evaluate(y=y), 0) + assert (a**b).diff(b).evaluate(y=y) == 5**3 * np.log(5) + assert (a**b).diff(a).evaluate(y=y) == 3 * 5**2 + assert (a**b).diff(a**b).evaluate() == 1 + assert (a**a).diff(a).evaluate(y=y) == 5**5 * np.log(5) + 5 * 5**4 + assert (a**a).diff(b).evaluate(y=y) == 0 # addition - self.assertEqual((a + b).diff(a).evaluate(), 1) - self.assertEqual((a + b).diff(b).evaluate(), 1) - self.assertEqual((a + b).diff(a + b).evaluate(), 1) - self.assertEqual((a + a).diff(a).evaluate(), 2) - self.assertEqual((a + a).diff(b).evaluate(), 0) + assert (a + b).diff(a).evaluate() == 1 + assert (a + b).diff(b).evaluate() == 1 + assert (a + b).diff(a + b).evaluate() == 1 + assert (a + a).diff(a).evaluate() == 2 + assert (a + a).diff(b).evaluate() == 0 # subtraction - self.assertEqual((a - b).diff(a).evaluate(), 1) - self.assertEqual((a - b).diff(b).evaluate(), -1) - self.assertEqual((a - b).diff(a - b).evaluate(), 1) - self.assertEqual((a - a).diff(a).evaluate(), 0) - self.assertEqual((a + a).diff(b).evaluate(), 0) + assert (a - b).diff(a).evaluate() == 1 + assert (a - b).diff(b).evaluate() == -1 + assert (a - b).diff(a - b).evaluate() == 1 + assert (a - a).diff(a).evaluate() == 0 + assert (a + a).diff(b).evaluate() == 0 # multiplication - self.assertEqual((a * b).diff(a).evaluate(y=y), 3) - self.assertEqual((a * b).diff(b).evaluate(y=y), 5) - self.assertEqual((a * b).diff(a * b).evaluate(y=y), 1) - self.assertEqual((a * a).diff(a).evaluate(y=y), 10) - self.assertEqual((a * a).diff(b).evaluate(y=y), 0) + assert (a * b).diff(a).evaluate(y=y) == 3 + assert (a * b).diff(b).evaluate(y=y) == 5 + assert (a * b).diff(a * b).evaluate(y=y) == 1 + assert (a * a).diff(a).evaluate(y=y) == 10 + assert (a * a).diff(b).evaluate(y=y) == 0 # matrix multiplication (not implemented) matmul = a @ b - with self.assertRaises(NotImplementedError): + with pytest.raises(NotImplementedError): matmul.diff(a) # inner - self.assertEqual(pybamm.inner(a, b).diff(a).evaluate(y=y), 3) - self.assertEqual(pybamm.inner(a, b).diff(b).evaluate(y=y), 5) - self.assertEqual(pybamm.inner(a, b).diff(pybamm.inner(a, b)).evaluate(y=y), 1) - self.assertEqual(pybamm.inner(a, a).diff(a).evaluate(y=y), 10) - self.assertEqual(pybamm.inner(a, a).diff(b).evaluate(y=y), 0) + assert pybamm.inner(a, b).diff(a).evaluate(y=y) == 3 + assert pybamm.inner(a, b).diff(b).evaluate(y=y) == 5 + assert pybamm.inner(a, b).diff(pybamm.inner(a, b)).evaluate(y=y) == 1 + assert pybamm.inner(a, a).diff(a).evaluate(y=y) == 10 + assert pybamm.inner(a, a).diff(b).evaluate(y=y) == 0 # division - self.assertEqual((a / b).diff(a).evaluate(y=y), 1 / 3) - self.assertEqual((a / b).diff(b).evaluate(y=y), -5 / 9) - self.assertEqual((a / b).diff(a / b).evaluate(y=y), 1) - self.assertEqual((a / a).diff(a).evaluate(y=y), 0) - self.assertEqual((a / a).diff(b).evaluate(y=y), 0) + assert (a / b).diff(a).evaluate(y=y) == 1 / 3 + assert (a / b).diff(b).evaluate(y=y) == -5 / 9 + assert (a / b).diff(a / 
b).evaluate(y=y) == 1 + assert (a / a).diff(a).evaluate(y=y) == 0 + assert (a / a).diff(b).evaluate(y=y) == 0 def test_printing(self): # This in not an exhaustive list of all cases. More test cases may need to @@ -154,23 +154,23 @@ def test_printing(self): b = pybamm.Parameter("b") c = pybamm.Parameter("c") d = pybamm.Parameter("d") - self.assertEqual(str(a + b), "a + b") - self.assertEqual(str(a + b + c + d), "a + b + c + d") - self.assertEqual(str((a + b) + (c + d)), "a + b + c + d") - self.assertEqual(str(a + b - c), "a + b - c") - self.assertEqual(str(a + b - c + d), "a + b - c + d") - self.assertEqual(str((a + b) - (c + d)), "a + b - (c + d)") - self.assertEqual(str((a + b) - (c - d)), "a + b - (c - d)") - - self.assertEqual(str((a + b) * (c + d)), "(a + b) * (c + d)") - self.assertEqual(str(a * b * (c + d)), "a * b * (c + d)") - self.assertEqual(str((a * b) * (c + d)), "a * b * (c + d)") - self.assertEqual(str(a * (b * (c + d))), "a * b * (c + d)") - self.assertEqual(str((a + b) / (c + d)), "(a + b) / (c + d)") - self.assertEqual(str(a + b / (c + d)), "a + b / (c + d)") - self.assertEqual(str(a * b / (c + d)), "a * b / (c + d)") - self.assertEqual(str((a * b) / (c + d)), "a * b / (c + d)") - self.assertEqual(str(a * (b / (c + d))), "a * b / (c + d)") + assert str(a + b) == "a + b" + assert str(a + b + c + d) == "a + b + c + d" + assert str((a + b) + (c + d)) == "a + b + c + d" + assert str(a + b - c) == "a + b - c" + assert str(a + b - c + d) == "a + b - c + d" + assert str((a + b) - (c + d)) == "a + b - (c + d)" + assert str((a + b) - (c - d)) == "a + b - (c - d)" + + assert str((a + b) * (c + d)) == "(a + b) * (c + d)" + assert str(a * b * (c + d)) == "a * b * (c + d)" + assert str((a * b) * (c + d)) == "a * b * (c + d)" + assert str(a * (b * (c + d))) == "a * b * (c + d)" + assert str((a + b) / (c + d)) == "(a + b) / (c + d)" + assert str(a + b / (c + d)) == "a + b / (c + d)" + assert str(a * b / (c + d)) == "a * b / (c + d)" + assert str((a * b) / (c + d)) == "a * b / (c + d)" + assert str(a * (b / (c + d))) == "a * b / (c + d)" def test_eq(self): a = pybamm.Scalar(4) @@ -178,20 +178,20 @@ def test_eq(self): bin1 = pybamm.BinaryOperator("test", a, b) bin2 = pybamm.BinaryOperator("test", a, b) bin3 = pybamm.BinaryOperator("new test", a, b) - self.assertEqual(bin1, bin2) - self.assertNotEqual(bin1, bin3) + assert bin1 == bin2 + assert bin1 != bin3 c = pybamm.Scalar(5) bin4 = pybamm.BinaryOperator("test", a, c) - self.assertEqual(bin1, bin4) + assert bin1 == bin4 d = pybamm.Scalar(42) bin5 = pybamm.BinaryOperator("test", a, d) - self.assertNotEqual(bin1, bin5) + assert bin1 != bin5 def test_number_overloading(self): a = pybamm.Scalar(4) prod = a * 3 - self.assertIsInstance(prod, pybamm.Scalar) - self.assertEqual(prod.evaluate(), 12) + assert isinstance(prod, pybamm.Scalar) + assert prod.evaluate() == 12 def test_sparse_multiply(self): row = np.array([0, 3, 1, 0]) @@ -225,11 +225,11 @@ def test_sparse_multiply(self): np.testing.assert_array_equal( (pybammS2 * pybammD2).evaluate().toarray(), S2.toarray() * D2 ) - with self.assertRaisesRegex(pybamm.ShapeError, "inconsistent shapes"): + with pytest.raises(pybamm.ShapeError, match="inconsistent shapes"): (pybammS1 * pybammS2).test_shape() - with self.assertRaisesRegex(pybamm.ShapeError, "inconsistent shapes"): + with pytest.raises(pybamm.ShapeError, match="inconsistent shapes"): (pybammS2 * pybammS1).test_shape() - with self.assertRaisesRegex(pybamm.ShapeError, "inconsistent shapes"): + with pytest.raises(pybamm.ShapeError, 
match="inconsistent shapes"): (pybammS2 * pybammS1).evaluate_ignoring_errors() # Matrix multiplication is normal matrix multiplication @@ -243,9 +243,9 @@ def test_sparse_multiply(self): np.testing.assert_array_equal((pybammD2 @ pybammS1).evaluate(), D2 * S1) np.testing.assert_array_equal((pybammS2 @ pybammD1).evaluate(), S2 * D1) np.testing.assert_array_equal((pybammD1 @ pybammS2).evaluate(), D1 * S2) - with self.assertRaisesRegex(pybamm.ShapeError, "dimension mismatch"): + with pytest.raises(pybamm.ShapeError, match="dimension mismatch"): (pybammS1 @ pybammS1).test_shape() - with self.assertRaisesRegex(pybamm.ShapeError, "dimension mismatch"): + with pytest.raises(pybamm.ShapeError, match="dimension mismatch"): (pybammS2 @ pybammS2).test_shape() def test_sparse_divide(self): @@ -294,133 +294,133 @@ def test_inner(self): disc.process_model(model) # check doesn't evaluate on edges anymore - self.assertEqual(model.variables["inner"].evaluates_on_edges("primary"), False) + assert not model.variables["inner"].evaluates_on_edges("primary") def test_source(self): u = pybamm.Variable("u", domain="current collector") v = pybamm.Variable("v", domain="current collector") source = pybamm.source(u, v) - self.assertIsInstance(source.children[0], pybamm.Mass) + assert isinstance(source.children[0], pybamm.Mass) boundary_source = pybamm.source(u, v, boundary=True) - self.assertIsInstance(boundary_source.children[0], pybamm.BoundaryMass) + assert isinstance(boundary_source.children[0], pybamm.BoundaryMass) def test_source_error(self): # test error with domain not current collector v = pybamm.Vector(np.ones(5), domain="current collector") w = pybamm.Vector(2 * np.ones(3), domain="test") - with self.assertRaisesRegex(pybamm.DomainError, "'source'"): + with pytest.raises(pybamm.DomainError, match="'source'"): pybamm.source(v, w) def test_heaviside(self): b = pybamm.StateVector(slice(0, 1)) heav = 1 < b - self.assertEqual(heav.evaluate(y=np.array([2])), 1) - self.assertEqual(heav.evaluate(y=np.array([1])), 0) - self.assertEqual(heav.evaluate(y=np.array([0])), 0) - self.assertEqual(str(heav), "1.0 < y[0:1]") + assert heav.evaluate(y=np.array([2])) == 1 + assert heav.evaluate(y=np.array([1])) == 0 + assert heav.evaluate(y=np.array([0])) == 0 + assert str(heav) == "1.0 < y[0:1]" heav = 1 >= b - self.assertEqual(heav.evaluate(y=np.array([2])), 0) - self.assertEqual(heav.evaluate(y=np.array([1])), 1) - self.assertEqual(heav.evaluate(y=np.array([0])), 1) - self.assertEqual(str(heav), "y[0:1] <= 1.0") + assert heav.evaluate(y=np.array([2])) == 0 + assert heav.evaluate(y=np.array([1])) == 1 + assert heav.evaluate(y=np.array([0])) == 1 + assert str(heav) == "y[0:1] <= 1.0" # simplifications - self.assertEqual(1 < b + 2, -1 < b) - self.assertEqual(b + 1 > 2, b > 1) + assert (1 < b + 2) == (-1 < b) + assert (b + 1 > 2) == (b > 1) # expression with a subtract expr = 2 * (b < 1) - (b > 3) - self.assertEqual(expr.evaluate(y=np.array([0])), 2) - self.assertEqual(expr.evaluate(y=np.array([2])), 0) - self.assertEqual(expr.evaluate(y=np.array([4])), -1) + assert expr.evaluate(y=np.array([0])) == 2 + assert expr.evaluate(y=np.array([2])) == 0 + assert expr.evaluate(y=np.array([4])) == -1 def test_equality(self): a = pybamm.Scalar(1) b = pybamm.StateVector(slice(0, 1)) equal = pybamm.Equality(a, b) - self.assertEqual(equal.evaluate(y=np.array([1])), 1) - self.assertEqual(equal.evaluate(y=np.array([2])), 0) - self.assertEqual(str(equal), "1.0 == y[0:1]") - self.assertEqual(equal.diff(b), 0) + assert 
equal.evaluate(y=np.array([1])) == 1
+        assert equal.evaluate(y=np.array([2])) == 0
+        assert str(equal) == "1.0 == y[0:1]"
+        assert equal.diff(b) == 0

     def test_sigmoid(self):
         a = pybamm.Scalar(1)
         b = pybamm.StateVector(slice(0, 1))
         sigm = pybamm.sigmoid(a, b, 10)
-        self.assertAlmostEqual(sigm.evaluate(y=np.array([2]))[0, 0], 1)
-        self.assertEqual(sigm.evaluate(y=np.array([1])), 0.5)
-        self.assertAlmostEqual(sigm.evaluate(y=np.array([0]))[0, 0], 0)
-        self.assertEqual(str(sigm), "0.5 + 0.5 * tanh(-10.0 + 10.0 * y[0:1])")
+        assert sigm.evaluate(y=np.array([2]))[0, 0] == pytest.approx(1)
+        assert sigm.evaluate(y=np.array([1])) == 0.5
+        assert sigm.evaluate(y=np.array([0]))[0, 0] == pytest.approx(0, abs=1e-7)
+        assert str(sigm) == "0.5 + 0.5 * tanh(-10.0 + 10.0 * y[0:1])"

         sigm = pybamm.sigmoid(b, a, 10)
-        self.assertAlmostEqual(sigm.evaluate(y=np.array([2]))[0, 0], 0)
-        self.assertEqual(sigm.evaluate(y=np.array([1])), 0.5)
-        self.assertAlmostEqual(sigm.evaluate(y=np.array([0]))[0, 0], 1)
-        self.assertEqual(str(sigm), "0.5 + 0.5 * tanh(10.0 - (10.0 * y[0:1]))")
+        assert sigm.evaluate(y=np.array([2]))[0, 0] == pytest.approx(0, abs=1e-7)
+        assert sigm.evaluate(y=np.array([1])) == 0.5
+        assert sigm.evaluate(y=np.array([0]))[0, 0] == pytest.approx(1)
+        assert str(sigm) == "0.5 + 0.5 * tanh(10.0 - (10.0 * y[0:1]))"

     def test_modulo(self):
         a = pybamm.StateVector(slice(0, 1))
         b = pybamm.Scalar(3)
         mod = a % b
-        self.assertEqual(mod.evaluate(y=np.array([4]))[0, 0], 1)
-        self.assertEqual(mod.evaluate(y=np.array([3]))[0, 0], 0)
-        self.assertEqual(mod.evaluate(y=np.array([2]))[0, 0], 2)
-        self.assertAlmostEqual(mod.evaluate(y=np.array([4.3]))[0, 0], 1.3)
-        self.assertAlmostEqual(mod.evaluate(y=np.array([2.2]))[0, 0], 2.2)
-        self.assertEqual(str(mod), "y[0:1] mod 3.0")
+        assert mod.evaluate(y=np.array([4]))[0, 0] == 1
+        assert mod.evaluate(y=np.array([3]))[0, 0] == 0
+        assert mod.evaluate(y=np.array([2]))[0, 0] == 2
+        assert mod.evaluate(y=np.array([4.3]))[0, 0] == pytest.approx(1.3)
+        assert mod.evaluate(y=np.array([2.2]))[0, 0] == pytest.approx(2.2)
+        assert str(mod) == "y[0:1] mod 3.0"

     def test_minimum_maximum(self):
         a = pybamm.Scalar(1)
         b = pybamm.StateVector(slice(0, 1))
         minimum = pybamm.minimum(a, b)
-        self.assertEqual(minimum.evaluate(y=np.array([2])), 1)
-        self.assertEqual(minimum.evaluate(y=np.array([1])), 1)
-        self.assertEqual(minimum.evaluate(y=np.array([0])), 0)
-        self.assertEqual(str(minimum), "minimum(1.0, y[0:1])")
+        assert minimum.evaluate(y=np.array([2])) == 1
+        assert minimum.evaluate(y=np.array([1])) == 1
+        assert minimum.evaluate(y=np.array([0])) == 0
+        assert str(minimum) == "minimum(1.0, y[0:1])"

         maximum = pybamm.maximum(a, b)
-        self.assertEqual(maximum.evaluate(y=np.array([2])), 2)
-        self.assertEqual(maximum.evaluate(y=np.array([1])), 1)
-        self.assertEqual(maximum.evaluate(y=np.array([0])), 1)
-        self.assertEqual(str(maximum), "maximum(1.0, y[0:1])")
+        assert maximum.evaluate(y=np.array([2])) == 2
+        assert maximum.evaluate(y=np.array([1])) == 1
+        assert maximum.evaluate(y=np.array([0])) == 1
+        assert str(maximum) == "maximum(1.0, y[0:1])"

     def test_softminus_softplus(self):
         a = pybamm.Scalar(1)
         b = pybamm.StateVector(slice(0, 1))

         minimum = pybamm.softminus(a, b, 50)
-        self.assertAlmostEqual(minimum.evaluate(y=np.array([2]))[0, 0], 1)
-        self.assertAlmostEqual(minimum.evaluate(y=np.array([0]))[0, 0], 0)
-        self.assertEqual(
-            str(minimum), "-0.02 * log(1.9287498479639178e-22 + exp(-50.0 * y[0:1]))"
+        assert minimum.evaluate(y=np.array([2]))[0, 0] == pytest.approx(1)
+        assert minimum.evaluate(y=np.array([0]))[0, 0] == pytest.approx(0)
+        assert (
+            str(minimum) == "-0.02 * log(1.9287498479639178e-22 + exp(-50.0 * y[0:1]))"
         )

         maximum = pybamm.softplus(a, b, 50)
-        self.assertAlmostEqual(maximum.evaluate(y=np.array([2]))[0, 0], 2)
-        self.assertAlmostEqual(maximum.evaluate(y=np.array([0]))[0, 0], 1)
-        self.assertEqual(
-            str(maximum)[:20],
-            "0.02 * log(5.184705528587072e+21 + exp(50.0 * y[0:1]))"[:20],
+        assert maximum.evaluate(y=np.array([2]))[0, 0] == pytest.approx(2)
+        assert maximum.evaluate(y=np.array([0]))[0, 0] == pytest.approx(1)
+        assert (
+            str(maximum)[:20]
+            == "0.02 * log(5.184705528587072e+21 + exp(50.0 * y[0:1]))"[:20]
         )
-        self.assertEqual(
-            str(maximum)[-20:],
-            "0.02 * log(5.184705528587072e+21 + exp(50.0 * y[0:1]))"[-20:],
+        assert (
+            str(maximum)[-20:]
+            == "0.02 * log(5.184705528587072e+21 + exp(50.0 * y[0:1]))"[-20:]
         )

         # Test that smooth min/max are used when the setting is changed
         pybamm.settings.min_max_mode = "soft"
         pybamm.settings.min_max_smoothing = 10
-        self.assertEqual(str(pybamm.minimum(a, b)), str(pybamm.softminus(a, b, 10)))
-        self.assertEqual(str(pybamm.maximum(a, b)), str(pybamm.softplus(a, b, 10)))
+        assert str(pybamm.minimum(a, b)) == str(pybamm.softminus(a, b, 10))
+        assert str(pybamm.maximum(a, b)) == str(pybamm.softplus(a, b, 10))
         # But exact min/max should still be used if both variables are constant
         a = pybamm.Scalar(1)
         b = pybamm.Scalar(2)
-        self.assertEqual(str(pybamm.minimum(a, b)), str(a))
-        self.assertEqual(str(pybamm.maximum(a, b)), str(b))
+        assert str(pybamm.minimum(a, b)) == str(a)
+        assert str(pybamm.maximum(a, b)) == str(b)
         # Change setting back for other tests
         pybamm.settings.set_smoothing_parameters("exact")
@@ -430,36 +430,34 @@ def test_smooth_minus_plus(self):
         b = pybamm.StateVector(slice(0, 1))

         minimum = pybamm.smooth_min(a, b, 3000)
-        self.assertAlmostEqual(minimum.evaluate(y=np.array([2]))[0, 0], 1)
-        self.assertAlmostEqual(minimum.evaluate(y=np.array([0]))[0, 0], 0)
+        assert minimum.evaluate(y=np.array([2]))[0, 0] == pytest.approx(1)
+        assert minimum.evaluate(y=np.array([0]))[0, 0] == pytest.approx(0, abs=1e-7)

         maximum = pybamm.smooth_max(a, b, 3000)
-        self.assertAlmostEqual(maximum.evaluate(y=np.array([2]))[0, 0], 2)
-        self.assertAlmostEqual(maximum.evaluate(y=np.array([0]))[0, 0], 1)
+        assert maximum.evaluate(y=np.array([2]))[0, 0] == pytest.approx(2)
+        assert maximum.evaluate(y=np.array([0]))[0, 0] == pytest.approx(1)

         minimum = pybamm.smooth_min(a, b, 1)
-        self.assertEqual(
-            str(minimum),
-            "0.5 * (1.0 + y[0:1] - sqrt(1.0 + (1.0 - y[0:1]) ** 2.0))",
+        assert (
+            str(minimum) == "0.5 * (1.0 + y[0:1] - sqrt(1.0 + (1.0 - y[0:1]) ** 2.0))"
         )

         maximum = pybamm.smooth_max(a, b, 1)
-        self.assertEqual(
-            str(maximum),
-            "0.5 * (sqrt(1.0 + (1.0 - y[0:1]) ** 2.0) + 1.0 + y[0:1])",
+        assert (
+            str(maximum) == "0.5 * (sqrt(1.0 + (1.0 - y[0:1]) ** 2.0) + 1.0 + y[0:1])"
         )

         # Test that smooth min/max are used when the setting is changed
         pybamm.settings.min_max_mode = "smooth"
         pybamm.settings.min_max_smoothing = 1
-        self.assertEqual(str(pybamm.minimum(a, b)), str(pybamm.smooth_min(a, b, 1)))
-        self.assertEqual(str(pybamm.maximum(a, b)), str(pybamm.smooth_max(a, b, 1)))
+        assert str(pybamm.minimum(a, b)) == str(pybamm.smooth_min(a, b, 1))
+        assert str(pybamm.maximum(a, b)) == str(pybamm.smooth_max(a, b, 1))

         pybamm.settings.min_max_smoothing = 3000
         a = pybamm.Scalar(1)
         b = pybamm.Scalar(2)
-        self.assertEqual(str(pybamm.minimum(a, b)), str(a))
-        self.assertEqual(str(pybamm.maximum(a, b)), str(b))
+        assert str(pybamm.minimum(a, b)) == str(a)
+        assert str(pybamm.maximum(a, b)) ==
str(b) # Change setting back for other tests pybamm.settings.set_smoothing_parameters("exact") @@ -480,134 +478,133 @@ def test_binary_simplifications(self): broad2_edge = pybamm.PrimaryBroadcastToEdges(2, "domain") # power - self.assertEqual((c**0), pybamm.Scalar(1)) - self.assertEqual((0**c), pybamm.Scalar(0)) - self.assertEqual((c**1), c) + assert (c**0) == pybamm.Scalar(1) + assert (0**c) == pybamm.Scalar(0) + assert (c**1) == c # power with broadcasts - self.assertEqual((c**broad2), pybamm.PrimaryBroadcast(c**2, "domain")) - self.assertEqual((broad2**c), pybamm.PrimaryBroadcast(2**c, "domain")) - self.assertEqual( - (broad2 ** pybamm.PrimaryBroadcast(c, "domain")), - pybamm.PrimaryBroadcast(2**c, "domain"), - ) + assert (c**broad2) == pybamm.PrimaryBroadcast(c**2, "domain") + assert (broad2**c) == pybamm.PrimaryBroadcast(2**c, "domain") + assert ( + broad2 ** pybamm.PrimaryBroadcast(c, "domain") + ) == pybamm.PrimaryBroadcast(2**c, "domain") # power with broadcasts to edge - self.assertIsInstance(var**broad2_edge, pybamm.Power) - self.assertEqual((var**broad2_edge).left, var) - self.assertEqual((var**broad2_edge).right, broad2_edge) + assert isinstance(var**broad2_edge, pybamm.Power) + assert (var**broad2_edge).left == var + assert (var**broad2_edge).right == broad2_edge # addition - self.assertEqual(a + b, pybamm.Scalar(1)) - self.assertEqual(b + b, pybamm.Scalar(2)) - self.assertEqual(b + a, pybamm.Scalar(1)) - self.assertEqual(0 + b, pybamm.Scalar(1)) - self.assertEqual(0 + c, c) - self.assertEqual(c + 0, c) + assert a + b == pybamm.Scalar(1) + assert b + b == pybamm.Scalar(2) + assert b + a == pybamm.Scalar(1) + assert 0 + b == pybamm.Scalar(1) + assert 0 + c == c + assert c + 0 == c # addition with subtraction - self.assertEqual(c + (d - c), d) - self.assertEqual((c - d) + d, c) + assert c + (d - c) == d + assert (c - d) + d == c # addition with broadcast zero - self.assertIsInstance((1 + broad0), pybamm.PrimaryBroadcast) + assert isinstance((1 + broad0), pybamm.PrimaryBroadcast) np.testing.assert_array_equal((1 + broad0).child.evaluate(), 1) np.testing.assert_array_equal((1 + broad0).domain, "domain") - self.assertIsInstance((broad0 + 1), pybamm.PrimaryBroadcast) + assert isinstance((broad0 + 1), pybamm.PrimaryBroadcast) np.testing.assert_array_equal((broad0 + 1).child.evaluate(), 1) np.testing.assert_array_equal((broad0 + 1).domain, "domain") # addition with broadcasts - self.assertEqual((c + broad2), pybamm.PrimaryBroadcast(c + 2, "domain")) - self.assertEqual((broad2 + c), pybamm.PrimaryBroadcast(2 + c, "domain")) + assert (c + broad2) == pybamm.PrimaryBroadcast(c + 2, "domain") + assert (broad2 + c) == pybamm.PrimaryBroadcast(2 + c, "domain") # addition with negate - self.assertEqual(c + -d, c - d) - self.assertEqual(-c + d, d - c) + assert c + -d == c - d + assert -c + d == d - c # subtraction - self.assertEqual(a - b, pybamm.Scalar(-1)) - self.assertEqual(b - b, pybamm.Scalar(0)) - self.assertEqual(b - a, pybamm.Scalar(1)) + assert a - b == pybamm.Scalar(-1) + assert b - b == pybamm.Scalar(0) + assert b - a == pybamm.Scalar(1) # subtraction with addition - self.assertEqual(c - (d + c), -d) - self.assertEqual(c - (c - d), d) - self.assertEqual((c + d) - d, c) - self.assertEqual((d + c) - d, c) - self.assertEqual((d - c) - d, -c) + assert c - (d + c) == -d + assert c - (c - d) == d + assert (c + d) - d == c + assert (d + c) - d == c + assert (d - c) - d == -c # subtraction with broadcasts - self.assertEqual((c - broad2), pybamm.PrimaryBroadcast(c - 2, "domain")) - 
self.assertEqual((broad2 - c), pybamm.PrimaryBroadcast(2 - c, "domain")) + assert (c - broad2) == pybamm.PrimaryBroadcast(c - 2, "domain") + assert (broad2 - c) == pybamm.PrimaryBroadcast(2 - c, "domain") # subtraction from itself - self.assertEqual((c - c), pybamm.Scalar(0)) - self.assertEqual((broad2 - broad2), broad0) + assert (c - c) == pybamm.Scalar(0) + assert (broad2 - broad2) == broad0 # subtraction with negate - self.assertEqual((c - (-d)), c + d) + assert (c - (-d)) == c + d # addition and subtraction with matrix zero - self.assertEqual(b + v, pybamm.Vector(np.ones((10, 1)))) - self.assertEqual(v + b, pybamm.Vector(np.ones((10, 1)))) - self.assertEqual(b - v, pybamm.Vector(np.ones((10, 1)))) - self.assertEqual(v - b, pybamm.Vector(-np.ones((10, 1)))) + assert b + v == pybamm.Vector(np.ones((10, 1))) + assert v + b == pybamm.Vector(np.ones((10, 1))) + assert b - v == pybamm.Vector(np.ones((10, 1))) + assert v - b == pybamm.Vector(-np.ones((10, 1))) # multiplication - self.assertEqual(a * b, pybamm.Scalar(0)) - self.assertEqual(b * a, pybamm.Scalar(0)) - self.assertEqual(b * b, pybamm.Scalar(1)) - self.assertEqual(a * a, pybamm.Scalar(0)) - self.assertEqual(a * c, pybamm.Scalar(0)) - self.assertEqual(c * a, pybamm.Scalar(0)) - self.assertEqual(b * c, c) + assert a * b == pybamm.Scalar(0) + assert b * a == pybamm.Scalar(0) + assert b * b == pybamm.Scalar(1) + assert a * a == pybamm.Scalar(0) + assert a * c == pybamm.Scalar(0) + assert c * a == pybamm.Scalar(0) + assert b * c == c # multiplication with -1 - self.assertEqual((c * -1), (-c)) - self.assertEqual((-1 * c), (-c)) + assert (c * -1) == (-c) + assert (-1 * c) == (-c) # multiplication with a negation - self.assertEqual((-c * -f), (c * f)) - self.assertEqual((-c * 4), (c * -4)) - self.assertEqual((4 * -c), (-4 * c)) + assert (-c * -f) == (c * f) + assert (-c * 4) == (c * -4) + assert (4 * -c) == (-4 * c) # multiplication with division - self.assertEqual((c * (d / c)), d) - self.assertEqual((c / d) * d, c) + assert (c * (d / c)) == d + assert (c / d) * d == c # multiplication with broadcasts - self.assertEqual((c * broad2), pybamm.PrimaryBroadcast(c * 2, "domain")) - self.assertEqual((broad2 * c), pybamm.PrimaryBroadcast(2 * c, "domain")) + assert (c * broad2) == pybamm.PrimaryBroadcast(c * 2, "domain") + assert (broad2 * c) == pybamm.PrimaryBroadcast(2 * c, "domain") # multiplication with matrix zero - self.assertEqual(b * v, pybamm.Vector(np.zeros((10, 1)))) - self.assertEqual(v * b, pybamm.Vector(np.zeros((10, 1)))) + assert b * v == pybamm.Vector(np.zeros((10, 1))) + assert v * b == pybamm.Vector(np.zeros((10, 1))) # multiplication with matrix one - self.assertEqual((f * v1), f) - self.assertEqual((v1 * f), f) + assert (f * v1) == f + assert (v1 * f) == f # multiplication with matrix minus one - self.assertEqual((f * (-v1)), (-f)) - self.assertEqual(((-v1) * f), (-f)) + assert (f * (-v1)) == (-f) + assert ((-v1) * f) == (-f) # multiplication with broadcast - self.assertEqual((var * broad2), (var * 2)) - self.assertEqual((broad2 * var), (2 * var)) + assert (var * broad2) == (var * 2) + assert (broad2 * var) == (2 * var) # multiplication with broadcast one - self.assertEqual((var * broad1), var) - self.assertEqual((broad1 * var), var) + assert (var * broad1) == var + assert (broad1 * var) == var # multiplication with broadcast minus one - self.assertEqual((var * -broad1), (-var)) - self.assertEqual((-broad1 * var), (-var)) + assert (var * -broad1) == (-var) + assert (-broad1 * var) == (-var) # division by itself - 
self.assertEqual((c / c), pybamm.Scalar(1)) - self.assertEqual((broad2 / broad2), broad1) + assert (c / c) == pybamm.Scalar(1) + assert (broad2 / broad2) == broad1 # division with a negation - self.assertEqual((-c / -f), (c / f)) - self.assertEqual((-c / 4), -0.25 * c) - self.assertEqual((4 / -c), (-4 / c)) + assert (-c / -f) == (c / f) + assert (-c / 4) == -0.25 * c + assert (4 / -c) == (-4 / c) # division with multiplication - self.assertEqual((c * d) / c, d) - self.assertEqual((d * c) / c, d) + assert (c * d) / c == d + assert (d * c) / c == d # division with broadcasts - self.assertEqual((c / broad2), pybamm.PrimaryBroadcast(c / 2, "domain")) - self.assertEqual((broad2 / c), pybamm.PrimaryBroadcast(2 / c, "domain")) + assert (c / broad2) == pybamm.PrimaryBroadcast(c / 2, "domain") + assert (broad2 / c) == pybamm.PrimaryBroadcast(2 / c, "domain") # division with matrix one - self.assertEqual((f / v1), f) - self.assertEqual((f / -v1), (-f)) + assert (f / v1) == f + assert (f / -v1) == (-f) # division by zero - with self.assertRaises(ZeroDivisionError): + with pytest.raises(ZeroDivisionError): b / a # division with a common term - self.assertEqual((2 * c) / (2 * var), (c / var)) - self.assertEqual((c * 2) / (var * 2), (c / var)) + assert (2 * c) / (2 * var) == (c / var) + assert (c * 2) / (var * 2) == (c / var) def test_binary_simplifications_concatenations(self): def conc_broad(x, y, z): @@ -625,10 +622,10 @@ def conc_broad(x, y, z): pybamm.InputParameter("y"), pybamm.InputParameter("z"), ) - self.assertEqual((a + 4), conc_broad(5, 6, 7)) - self.assertEqual((4 + a), conc_broad(5, 6, 7)) - self.assertEqual((a + b), conc_broad(12, 14, 16)) - self.assertIsInstance((a + c), pybamm.Concatenation) + assert (a + 4) == conc_broad(5, 6, 7) + assert (4 + a) == conc_broad(5, 6, 7) + assert (a + b) == conc_broad(12, 14, 16) + assert isinstance((a + c), pybamm.Concatenation) # No simplifications if all are Variable or StateVector objects v = pybamm.concatenation( @@ -636,8 +633,8 @@ def conc_broad(x, y, z): pybamm.Variable("y", "separator"), pybamm.Variable("z", "positive electrode"), ) - self.assertIsInstance((v * v), pybamm.Multiplication) - self.assertIsInstance((a * v), pybamm.Multiplication) + assert isinstance((v * v), pybamm.Multiplication) + assert isinstance((a * v), pybamm.Multiplication) def test_advanced_binary_simplifications(self): # MatMul simplifications that often appear when discretising spatial operators @@ -650,120 +647,120 @@ def test_advanced_binary_simplifications(self): # Do A@B first if it is constant expr = A @ (B @ var) - self.assertEqual(expr, ((A @ B) @ var)) + assert expr == ((A @ B) @ var) # Distribute the @ operator to a sum if one of the symbols being summed is # constant expr = A @ (var + vec) - self.assertEqual(expr, ((A @ var) + (A @ vec))) + assert expr == ((A @ var) + (A @ vec)) expr = A @ (var - vec) - self.assertEqual(expr, ((A @ var) - (A @ vec))) + assert expr == ((A @ var) - (A @ vec)) expr = A @ ((B @ var) + vec) - self.assertEqual(expr, (((A @ B) @ var) + (A @ vec))) + assert expr == (((A @ B) @ var) + (A @ vec)) expr = A @ ((B @ var) - vec) - self.assertEqual(expr, (((A @ B) @ var) - (A @ vec))) + assert expr == (((A @ B) @ var) - (A @ vec)) # Distribute the @ operator to a sum if both symbols being summed are matmuls expr = A @ (B @ var + C @ var2) - self.assertEqual(expr, ((A @ B) @ var + (A @ C) @ var2)) + assert expr == ((A @ B) @ var + (A @ C) @ var2) expr = A @ (B @ var - C @ var2) - self.assertEqual(expr, ((A @ B) @ var - (A @ C) @ var2)) + assert 
expr == ((A @ B) @ var - (A @ C) @ var2) # Reduce (A@var + B@var) to ((A+B)@var) expr = A @ var + B @ var - self.assertEqual(expr, ((A + B) @ var)) + assert expr == ((A + B) @ var) # Do A*e first if it is constant expr = A @ (5 * var) - self.assertEqual(expr, ((A * 5) @ var)) + assert expr == ((A * 5) @ var) expr = A @ (var * 5) - self.assertEqual(expr, ((A * 5) @ var)) + assert expr == ((A * 5) @ var) # Do A/e first if it is constant expr = A @ (var / 2) - self.assertEqual(expr, ((A / 2) @ var)) + assert expr == ((A / 2) @ var) # Do (vec*A) first if it is constant expr = vec * (A @ var) - self.assertEqual(expr, ((vec * A) @ var)) + assert expr == ((vec * A) @ var) expr = (A @ var) * vec - self.assertEqual(expr, ((vec * A) @ var)) + assert expr == ((vec * A) @ var) # Do (A/vec) first if it is constant expr = (A @ var) / vec - self.assertIsInstance(expr, pybamm.MatrixMultiplication) + assert isinstance(expr, pybamm.MatrixMultiplication) np.testing.assert_array_almost_equal(expr.left.evaluate(), (A / vec).evaluate()) - self.assertEqual(expr.children[1], var) + assert expr.children[1] == var # simplify additions and subtractions expr = 7 + (var + 5) - self.assertEqual(expr, (12 + var)) + assert expr == (12 + var) expr = 7 + (5 + var) - self.assertEqual(expr, (12 + var)) + assert expr == (12 + var) expr = (var + 5) + 7 - self.assertEqual(expr, (var + 12)) + assert expr == (var + 12) expr = (5 + var) + 7 - self.assertEqual(expr, (12 + var)) + assert expr == (12 + var) expr = 7 + (var - 5) - self.assertEqual(expr, (2 + var)) + assert expr == (2 + var) expr = 7 + (5 - var) - self.assertEqual(expr, (12 - var)) + assert expr == (12 - var) expr = (var - 5) + 7 - self.assertEqual(expr, (var + 2)) + assert expr == (var + 2) expr = (5 - var) + 7 - self.assertEqual(expr, (12 - var)) + assert expr == (12 - var) expr = 7 - (var + 5) - self.assertEqual(expr, (2 - var)) + assert expr == (2 - var) expr = 7 - (5 + var) - self.assertEqual(expr, (2 - var)) + assert expr == (2 - var) expr = (var + 5) - 7 - self.assertEqual(expr, (var + -2)) + assert expr == (var + -2) expr = (5 + var) - 7 - self.assertEqual(expr, (-2 + var)) + assert expr == (-2 + var) expr = 7 - (var - 5) - self.assertEqual(expr, (12 - var)) + assert expr == (12 - var) expr = 7 - (5 - var) - self.assertEqual(expr, (2 + var)) + assert expr == (2 + var) expr = (var - 5) - 7 - self.assertEqual(expr, (var - 12)) + assert expr == (var - 12) expr = (5 - var) - 7 - self.assertEqual(expr, (-2 - var)) + assert expr == (-2 - var) expr = var - (var + var2) - self.assertEqual(expr, -var2) + assert expr == -var2 # simplify multiplications and divisions expr = 10 * (var * 5) - self.assertEqual(expr, 50 * var) + assert expr == 50 * var expr = (var * 5) * 10 - self.assertEqual(expr, var * 50) + assert expr == var * 50 expr = 10 * (5 * var) - self.assertEqual(expr, 50 * var) + assert expr == 50 * var expr = (5 * var) * 10 - self.assertEqual(expr, 50 * var) + assert expr == 50 * var expr = 10 * (var / 5) - self.assertEqual(expr, (10 / 5) * var) + assert expr == (10 / 5) * var expr = (var / 5) * 10 - self.assertEqual(expr, var * (10 / 5)) + assert expr == var * (10 / 5) expr = (var * 5) / 10 - self.assertEqual(expr, var * (5 / 10)) + assert expr == var * (5 / 10) expr = (5 * var) / 10 - self.assertEqual(expr, (5 / 10) * var) + assert expr == (5 / 10) * var expr = 5 / (10 * var) - self.assertEqual(expr, (5 / 10) / var) + assert expr == (5 / 10) / var expr = 5 / (var * 10) - self.assertEqual(expr, (5 / 10) / var) + assert expr == (5 / 10) / var expr = (5 / var) / 
10 - self.assertEqual(expr, (5 / 10) / var) + assert expr == (5 / 10) / var expr = 5 / (10 / var) - self.assertEqual(expr, (5 / 10) * var) + assert expr == (5 / 10) * var expr = 5 / (var / 10) - self.assertEqual(expr, 50 / var) + assert expr == 50 / var # use power rules on multiplications and divisions expr = (var * 5) ** 2 - self.assertEqual(expr, var**2 * 25) + assert expr == var**2 * 25 expr = (5 * var) ** 2 - self.assertEqual(expr, 25 * var**2) + assert expr == 25 * var**2 expr = (5 / var) ** 2 - self.assertEqual(expr, 25 / var**2) + assert expr == 25 / var**2 def test_inner_simplifications(self): a1 = pybamm.Scalar(0) @@ -776,116 +773,105 @@ def test_inner_simplifications(self): np.testing.assert_array_equal( pybamm.inner(a1, M2).evaluate().toarray(), M1.entries ) - self.assertEqual(pybamm.inner(a1, a2).evaluate(), 0) + assert pybamm.inner(a1, a2).evaluate() == 0 np.testing.assert_array_equal( pybamm.inner(M2, a1).evaluate().toarray(), M1.entries ) - self.assertEqual(pybamm.inner(a2, a1).evaluate(), 0) + assert pybamm.inner(a2, a1).evaluate() == 0 np.testing.assert_array_equal( pybamm.inner(M1, a3).evaluate().toarray(), M1.entries ) np.testing.assert_array_equal(pybamm.inner(v1, a3).evaluate(), 3 * v1.entries) - self.assertEqual(pybamm.inner(a2, a3).evaluate(), 3) - self.assertEqual(pybamm.inner(a3, a2).evaluate(), 3) - self.assertEqual(pybamm.inner(a3, a3).evaluate(), 9) + assert pybamm.inner(a2, a3).evaluate() == 3 + assert pybamm.inner(a3, a2).evaluate() == 3 + assert pybamm.inner(a3, a3).evaluate() == 9 def test_to_equation(self): # Test print_name pybamm.Addition.print_name = "test" - self.assertEqual(pybamm.Addition(1, 2).to_equation(), sympy.Symbol("test")) + assert pybamm.Addition(1, 2).to_equation() == sympy.Symbol("test") # Test Power - self.assertEqual(pybamm.Power(7, 2).to_equation(), 49) + assert pybamm.Power(7, 2).to_equation() == 49 # Test Division - self.assertEqual(pybamm.Division(10, 2).to_equation(), 5) + assert pybamm.Division(10, 2).to_equation() == 5 # Test Matrix Multiplication arr1 = pybamm.Array([[1, 0], [0, 1]]) arr2 = pybamm.Array([[4, 1], [2, 2]]) - self.assertEqual( - pybamm.MatrixMultiplication(arr1, arr2).to_equation(), - sympy.Matrix([[4.0, 1.0], [2.0, 2.0]]), + assert pybamm.MatrixMultiplication(arr1, arr2).to_equation() == sympy.Matrix( + [[4.0, 1.0], [2.0, 2.0]] ) # Test EqualHeaviside - self.assertEqual(pybamm.EqualHeaviside(1, 0).to_equation(), False) + assert not pybamm.EqualHeaviside(1, 0).to_equation() # Test NotEqualHeaviside - self.assertEqual(pybamm.NotEqualHeaviside(2, 4).to_equation(), True) + assert pybamm.NotEqualHeaviside(2, 4).to_equation() - def test_to_json(self): + def test_to_json(self, mocker): # Test Addition add_json = { "name": "+", - "id": mock.ANY, + "id": mocker.ANY, "domains": EMPTY_DOMAINS, } add = pybamm.Addition(2, 4) - self.assertEqual(add.to_json(), add_json) + assert add.to_json() == add_json add_json["children"] = [pybamm.Scalar(2), pybamm.Scalar(4)] - self.assertEqual(pybamm.Addition._from_json(add_json), add) + assert pybamm.Addition._from_json(add_json) == add # Test Power pow_json = { "name": "**", - "id": mock.ANY, + "id": mocker.ANY, "domains": EMPTY_DOMAINS, } pow = pybamm.Power(7, 2) - self.assertEqual(pow.to_json(), pow_json) + assert pow.to_json() == pow_json pow_json["children"] = [pybamm.Scalar(7), pybamm.Scalar(2)] - self.assertEqual(pybamm.Power._from_json(pow_json), pow) + assert pybamm.Power._from_json(pow_json) == pow # Test Division div_json = { "name": "/", - "id": mock.ANY, + "id": mocker.ANY, 
"domains": EMPTY_DOMAINS, } div = pybamm.Division(10, 5) - self.assertEqual(div.to_json(), div_json) + assert div.to_json() == div_json div_json["children"] = [pybamm.Scalar(10), pybamm.Scalar(5)] - self.assertEqual(pybamm.Division._from_json(div_json), div) + assert pybamm.Division._from_json(div_json) == div # Test EqualHeaviside equal_json = { "name": "<=", - "id": mock.ANY, + "id": mocker.ANY, "domains": EMPTY_DOMAINS, } equal_h = pybamm.EqualHeaviside(2, 4) - self.assertEqual(equal_h.to_json(), equal_json) + assert equal_h.to_json() == equal_json equal_json["children"] = [pybamm.Scalar(2), pybamm.Scalar(4)] - self.assertEqual(pybamm.EqualHeaviside._from_json(equal_json), equal_h) + assert pybamm.EqualHeaviside._from_json(equal_json) == equal_h # Test notEqualHeaviside not_equal_json = { "name": "<", - "id": mock.ANY, + "id": mocker.ANY, "domains": EMPTY_DOMAINS, } ne_h = pybamm.NotEqualHeaviside(2, 4) - self.assertEqual(ne_h.to_json(), not_equal_json) + assert ne_h.to_json() == not_equal_json not_equal_json["children"] = [pybamm.Scalar(2), pybamm.Scalar(4)] - self.assertEqual(pybamm.NotEqualHeaviside._from_json(not_equal_json), ne_h) - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() + assert pybamm.NotEqualHeaviside._from_json(not_equal_json) == ne_h diff --git a/tests/unit/test_models/test_full_battery_models/test_base_battery_model.py b/tests/unit/test_models/test_full_battery_models/test_base_battery_model.py index 033dcf5345..95d0b53a64 100644 --- a/tests/unit/test_models/test_full_battery_models/test_base_battery_model.py +++ b/tests/unit/test_models/test_full_battery_models/test_base_battery_model.py @@ -2,9 +2,9 @@ # Tests for the base battery model class # +import pytest from pybamm.models.full_battery_models.base_battery_model import BatteryModelOptions import pybamm -import unittest import io from contextlib import redirect_stdout import os @@ -56,7 +56,7 @@ """ -class TestBaseBatteryModel(unittest.TestCase): +class TestBaseBatteryModel: def test_process_parameters_and_discretise(self): model = pybamm.lithium_ion.SPM() # Set up geometry and parameters @@ -72,9 +72,9 @@ def test_process_parameters_and_discretise(self): * model.variables["X-averaged negative particle concentration [mol.m-3]"] ) processed_c = model.process_parameters_and_discretise(c, parameter_values, disc) - self.assertIsInstance(processed_c, pybamm.Multiplication) - self.assertIsInstance(processed_c.left, pybamm.Scalar) - self.assertIsInstance(processed_c.right, pybamm.StateVector) + assert isinstance(processed_c, pybamm.Multiplication) + assert isinstance(processed_c.left, pybamm.Scalar) + assert isinstance(processed_c.right, pybamm.StateVector) # Process flux manually and check result against flux computed in particle # submodel c_n = model.variables["X-averaged negative particle concentration [mol.m-3]"] @@ -89,47 +89,39 @@ def test_process_parameters_and_discretise(self): flux_2 = model.variables["X-averaged negative particle flux [mol.m-2.s-1]"] param_flux_2 = parameter_values.process_symbol(flux_2) disc_flux_2 = disc.process_symbol(param_flux_2) - self.assertEqual(flux_1, disc_flux_2) + assert flux_1 == disc_flux_2 def test_summary_variables(self): model = pybamm.BaseBatteryModel() model.variables["var"] = pybamm.Scalar(1) model.summary_variables = ["var"] - self.assertEqual(model.summary_variables, ["var"]) - with self.assertRaisesRegex(KeyError, "No cycling variable 
defined"): + assert model.summary_variables == ["var"] + with pytest.raises(KeyError, match="No cycling variable defined"): model.summary_variables = ["bad var"] def test_default_geometry(self): model = pybamm.BaseBatteryModel({"dimensionality": 0}) - self.assertEqual( - model.default_geometry["current collector"]["z"]["position"], 1 - ) + assert model.default_geometry["current collector"]["z"]["position"] == 1 model = pybamm.BaseBatteryModel({"dimensionality": 1}) - self.assertEqual(model.default_geometry["current collector"]["z"]["min"], 0) + assert model.default_geometry["current collector"]["z"]["min"] == 0 model = pybamm.BaseBatteryModel({"dimensionality": 2}) - self.assertEqual(model.default_geometry["current collector"]["y"]["min"], 0) + assert model.default_geometry["current collector"]["y"]["min"] == 0 def test_default_submesh_types(self): model = pybamm.BaseBatteryModel({"dimensionality": 0}) - self.assertTrue( - issubclass( - model.default_submesh_types["current collector"], - pybamm.SubMesh0D, - ) + assert issubclass( + model.default_submesh_types["current collector"], + pybamm.SubMesh0D, ) model = pybamm.BaseBatteryModel({"dimensionality": 1}) - self.assertTrue( - issubclass( - model.default_submesh_types["current collector"], - pybamm.Uniform1DSubMesh, - ) + assert issubclass( + model.default_submesh_types["current collector"], + pybamm.Uniform1DSubMesh, ) model = pybamm.BaseBatteryModel({"dimensionality": 2}) - self.assertTrue( - issubclass( - model.default_submesh_types["current collector"], - pybamm.ScikitUniform2DSubMesh, - ) + assert issubclass( + model.default_submesh_types["current collector"], + pybamm.ScikitUniform2DSubMesh, ) def test_default_var_pts(self): @@ -149,44 +141,44 @@ def test_default_var_pts(self): "R_p": 30, } model = pybamm.BaseBatteryModel({"dimensionality": 0}) - self.assertDictEqual(var_pts, model.default_var_pts) + assert var_pts == model.default_var_pts var_pts.update({"x_n": 10, "x_s": 10, "x_p": 10}) model = pybamm.BaseBatteryModel({"dimensionality": 2}) - self.assertDictEqual(var_pts, model.default_var_pts) + assert var_pts == model.default_var_pts def test_default_spatial_methods(self): model = pybamm.BaseBatteryModel({"dimensionality": 0}) - self.assertIsInstance( + assert isinstance( model.default_spatial_methods["current collector"], pybamm.ZeroDimensionalSpatialMethod, ) model = pybamm.BaseBatteryModel({"dimensionality": 1}) - self.assertIsInstance( + assert isinstance( model.default_spatial_methods["current collector"], pybamm.FiniteVolume ) model = pybamm.BaseBatteryModel({"dimensionality": 2}) - self.assertIsInstance( + assert isinstance( model.default_spatial_methods["current collector"], pybamm.ScikitFiniteElement, ) def test_options(self): - with self.assertRaisesRegex(pybamm.OptionError, "Option"): + with pytest.raises(pybamm.OptionError, match="Option"): pybamm.BaseBatteryModel({"bad option": "bad option"}) - with self.assertRaisesRegex(pybamm.OptionError, "current collector model"): + with pytest.raises(pybamm.OptionError, match="current collector model"): pybamm.BaseBatteryModel({"current collector": "bad current collector"}) - with self.assertRaisesRegex(pybamm.OptionError, "thermal"): + with pytest.raises(pybamm.OptionError, match="thermal"): pybamm.BaseBatteryModel({"thermal": "bad thermal"}) - with self.assertRaisesRegex(pybamm.OptionError, "cell geometry"): + with pytest.raises(pybamm.OptionError, match="cell geometry"): pybamm.BaseBatteryModel({"cell geometry": "bad geometry"}) - with 
self.assertRaisesRegex(pybamm.OptionError, "dimensionality"): + with pytest.raises(pybamm.OptionError, match="dimensionality"): pybamm.BaseBatteryModel({"dimensionality": 5}) - with self.assertRaisesRegex(pybamm.OptionError, "current collector"): + with pytest.raises(pybamm.OptionError, match="current collector"): pybamm.BaseBatteryModel( {"dimensionality": 1, "current collector": "bad option"} ) - with self.assertRaisesRegex(pybamm.OptionError, "1D current collectors"): + with pytest.raises(pybamm.OptionError, match="1D current collectors"): pybamm.BaseBatteryModel( { "current collector": "potential pair", @@ -194,7 +186,7 @@ def test_options(self): "thermal": "x-full", } ) - with self.assertRaisesRegex(pybamm.OptionError, "2D current collectors"): + with pytest.raises(pybamm.OptionError, match="2D current collectors"): pybamm.BaseBatteryModel( { "current collector": "potential pair", @@ -202,58 +194,54 @@ def test_options(self): "thermal": "x-full", } ) - with self.assertRaisesRegex(pybamm.OptionError, "surface form"): + with pytest.raises(pybamm.OptionError, match="surface form"): pybamm.BaseBatteryModel({"surface form": "bad surface form"}) - with self.assertRaisesRegex(pybamm.OptionError, "convection"): + with pytest.raises(pybamm.OptionError, match="convection"): pybamm.BaseBatteryModel({"convection": "bad convection"}) - with self.assertRaisesRegex( - pybamm.OptionError, "cannot have transverse convection in 0D model" + with pytest.raises( + pybamm.OptionError, match="cannot have transverse convection in 0D model" ): pybamm.BaseBatteryModel({"convection": "full transverse"}) - with self.assertRaisesRegex(pybamm.OptionError, "particle"): + with pytest.raises(pybamm.OptionError, match="particle"): pybamm.BaseBatteryModel({"particle": "bad particle"}) - with self.assertRaisesRegex(pybamm.OptionError, "working electrode"): + with pytest.raises(pybamm.OptionError, match="working electrode"): pybamm.BaseBatteryModel({"working electrode": "bad working electrode"}) - with self.assertRaisesRegex(pybamm.OptionError, "The 'negative' working"): + with pytest.raises(pybamm.OptionError, match="The 'negative' working"): pybamm.BaseBatteryModel({"working electrode": "negative"}) - with self.assertRaisesRegex(pybamm.OptionError, "particle shape"): + with pytest.raises(pybamm.OptionError, match="particle shape"): pybamm.BaseBatteryModel({"particle shape": "bad particle shape"}) - with self.assertRaisesRegex(pybamm.OptionError, "operating mode"): + with pytest.raises(pybamm.OptionError, match="operating mode"): pybamm.BaseBatteryModel({"operating mode": "bad operating mode"}) - with self.assertRaisesRegex(pybamm.OptionError, "electrolyte conductivity"): + with pytest.raises(pybamm.OptionError, match="electrolyte conductivity"): pybamm.BaseBatteryModel( {"electrolyte conductivity": "bad electrolyte conductivity"} ) # SEI options - with self.assertRaisesRegex(pybamm.OptionError, "SEI"): + with pytest.raises(pybamm.OptionError, match="SEI"): pybamm.BaseBatteryModel({"SEI": "bad sei"}) - with self.assertRaisesRegex(pybamm.OptionError, "SEI film resistance"): + with pytest.raises(pybamm.OptionError, match="SEI film resistance"): pybamm.BaseBatteryModel({"SEI film resistance": "bad SEI film resistance"}) - with self.assertRaisesRegex(pybamm.OptionError, "SEI porosity change"): + with pytest.raises(pybamm.OptionError, match="SEI porosity change"): pybamm.BaseBatteryModel({"SEI porosity change": "bad SEI porosity change"}) # changing defaults based on other options model = pybamm.BaseBatteryModel() - 
self.assertEqual(model.options["SEI film resistance"], "none") + assert model.options["SEI film resistance"] == "none" model = pybamm.BaseBatteryModel({"SEI": "constant"}) - self.assertEqual(model.options["SEI film resistance"], "distributed") - self.assertEqual( - model.options["total interfacial current density as a state"], "true" - ) + assert model.options["SEI film resistance"] == "distributed" + assert model.options["total interfacial current density as a state"] == "true" model = pybamm.BaseBatteryModel( {"SEI film resistance": "average", "particle phases": "2"} ) - self.assertEqual( - model.options["total interfacial current density as a state"], "true" - ) - with self.assertRaisesRegex(pybamm.OptionError, "must be 'true'"): + assert model.options["total interfacial current density as a state"] == "true" + with pytest.raises(pybamm.OptionError, match="must be 'true'"): pybamm.BaseBatteryModel( { "SEI film resistance": "distributed", "total interfacial current density as a state": "false", } ) - with self.assertRaisesRegex(pybamm.OptionError, "must be 'true'"): + with pytest.raises(pybamm.OptionError, match="must be 'true'"): pybamm.BaseBatteryModel( { "SEI film resistance": "average", @@ -263,9 +251,9 @@ def test_options(self): ) # loss of active material model - with self.assertRaisesRegex(pybamm.OptionError, "loss of active material"): + with pytest.raises(pybamm.OptionError, match="loss of active material"): pybamm.BaseBatteryModel({"loss of active material": "bad LAM model"}) - with self.assertRaisesRegex(pybamm.OptionError, "loss of active material"): + with pytest.raises(pybamm.OptionError, match="loss of active material"): # can't have a 3-tuple pybamm.BaseBatteryModel( { @@ -281,11 +269,11 @@ def test_options(self): model = pybamm.BaseBatteryModel( {"loss of active material": "stress-driven", "SEI on cracks": "true"} ) - self.assertEqual( - model.options["particle mechanics"], - ("swelling and cracking", "swelling only"), + assert model.options["particle mechanics"] == ( + "swelling and cracking", + "swelling only", ) - self.assertEqual(model.options["stress-induced diffusion"], "true") + assert model.options["stress-induced diffusion"] == "true" model = pybamm.BaseBatteryModel( { "working electrode": "positive", @@ -293,29 +281,27 @@ def test_options(self): "SEI on cracks": "true", } ) - self.assertEqual(model.options["particle mechanics"], "swelling and cracking") - self.assertEqual(model.options["stress-induced diffusion"], "true") + assert model.options["particle mechanics"] == "swelling and cracking" + assert model.options["stress-induced diffusion"] == "true" # crack model - with self.assertRaisesRegex(pybamm.OptionError, "particle mechanics"): + with pytest.raises(pybamm.OptionError, match="particle mechanics"): pybamm.BaseBatteryModel({"particle mechanics": "bad particle cracking"}) - with self.assertRaisesRegex(pybamm.OptionError, "particle cracking"): + with pytest.raises(pybamm.OptionError, match="particle cracking"): pybamm.BaseBatteryModel({"particle cracking": "bad particle cracking"}) # SEI on cracks - with self.assertRaisesRegex(pybamm.OptionError, "SEI on cracks"): + with pytest.raises(pybamm.OptionError, match="SEI on cracks"): pybamm.BaseBatteryModel({"SEI on cracks": "bad SEI on cracks"}) - with self.assertRaisesRegex(pybamm.OptionError, "'SEI on cracks' is 'true'"): + with pytest.raises(pybamm.OptionError, match="'SEI on cracks' is 'true'"): pybamm.BaseBatteryModel( {"SEI on cracks": "true", "particle mechanics": "swelling only"} ) # plating model - 
with self.assertRaisesRegex(pybamm.OptionError, "lithium plating"): + with pytest.raises(pybamm.OptionError, match="lithium plating"): pybamm.BaseBatteryModel({"lithium plating": "bad plating"}) - with self.assertRaisesRegex( - pybamm.OptionError, "lithium plating porosity change" - ): + with pytest.raises(pybamm.OptionError, match="lithium plating porosity change"): pybamm.BaseBatteryModel( { "lithium plating porosity change": "bad lithium " @@ -324,16 +310,16 @@ def test_options(self): ) # contact resistance - with self.assertRaisesRegex(pybamm.OptionError, "contact resistance"): + with pytest.raises(pybamm.OptionError, match="contact resistance"): pybamm.BaseBatteryModel({"contact resistance": "bad contact resistance"}) - with self.assertRaisesRegex(NotImplementedError, "Contact resistance not yet"): + with pytest.raises(NotImplementedError, match="Contact resistance not yet"): pybamm.BaseBatteryModel( { "contact resistance": "true", "operating mode": "explicit power", } ) - with self.assertRaisesRegex(NotImplementedError, "Contact resistance not yet"): + with pytest.raises(NotImplementedError, match="Contact resistance not yet"): pybamm.BaseBatteryModel( { "contact resistance": "true", @@ -342,29 +328,29 @@ def test_options(self): ) # stress-induced diffusion - with self.assertRaisesRegex(pybamm.OptionError, "cannot have stress"): + with pytest.raises(pybamm.OptionError, match="cannot have stress"): pybamm.BaseBatteryModel({"stress-induced diffusion": "true"}) # hydrolysis - with self.assertRaisesRegex(pybamm.OptionError, "surface formulation"): + with pytest.raises(pybamm.OptionError, match="surface formulation"): pybamm.lead_acid.LOQS({"hydrolysis": "true", "surface form": "false"}) # timescale - with self.assertRaisesRegex(pybamm.OptionError, "timescale"): + with pytest.raises(pybamm.OptionError, match="timescale"): pybamm.BaseBatteryModel({"timescale": "bad timescale"}) # thermal x-lumped - with self.assertRaisesRegex(pybamm.OptionError, "x-lumped"): + with pytest.raises(pybamm.OptionError, match="x-lumped"): pybamm.lithium_ion.BaseModel( {"cell geometry": "arbitrary", "thermal": "x-lumped"} ) # thermal half-cell - with self.assertRaisesRegex(pybamm.OptionError, "X-full"): + with pytest.raises(pybamm.OptionError, match="X-full"): pybamm.BaseBatteryModel( {"thermal": "x-full", "working electrode": "positive"} ) - with self.assertRaisesRegex(pybamm.OptionError, "X-lumped"): + with pytest.raises(pybamm.OptionError, match="X-lumped"): pybamm.BaseBatteryModel( { "dimensionality": 2, @@ -374,7 +360,7 @@ def test_options(self): ) # thermal heat of mixing - with self.assertRaisesRegex(NotImplementedError, "Heat of mixing"): + with pytest.raises(NotImplementedError, match="Heat of mixing"): pybamm.BaseBatteryModel( { "heat of mixing": "true", @@ -383,35 +369,35 @@ def test_options(self): ) # surface thermal model - with self.assertRaisesRegex(pybamm.OptionError, "surface temperature"): + with pytest.raises(pybamm.OptionError, match="surface temperature"): pybamm.BaseBatteryModel( {"surface temperature": "lumped", "thermal": "x-full"} ) # phases - with self.assertRaisesRegex(pybamm.OptionError, "multiple particle phases"): + with pytest.raises(pybamm.OptionError, match="multiple particle phases"): pybamm.BaseBatteryModel({"particle phases": "2", "surface form": "false"}) # msmr - with self.assertRaisesRegex(pybamm.OptionError, "MSMR"): + with pytest.raises(pybamm.OptionError, match="MSMR"): pybamm.BaseBatteryModel({"open-circuit potential": "MSMR"}) - with 
self.assertRaisesRegex(pybamm.OptionError, "MSMR"): + with pytest.raises(pybamm.OptionError, match="MSMR"): pybamm.BaseBatteryModel({"particle": "MSMR"}) - with self.assertRaisesRegex(pybamm.OptionError, "MSMR"): + with pytest.raises(pybamm.OptionError, match="MSMR"): pybamm.BaseBatteryModel({"intercalation kinetics": "MSMR"}) - with self.assertRaisesRegex(pybamm.OptionError, "MSMR"): + with pytest.raises(pybamm.OptionError, match="MSMR"): pybamm.BaseBatteryModel( {"open-circuit potential": "MSMR", "particle": "MSMR"} ) - with self.assertRaisesRegex(pybamm.OptionError, "MSMR"): + with pytest.raises(pybamm.OptionError, match="MSMR"): pybamm.BaseBatteryModel( {"open-circuit potential": "MSMR", "intercalation kinetics": "MSMR"} ) - with self.assertRaisesRegex(pybamm.OptionError, "MSMR"): + with pytest.raises(pybamm.OptionError, match="MSMR"): pybamm.BaseBatteryModel( {"particle": "MSMR", "intercalation kinetics": "MSMR"} ) - with self.assertRaisesRegex(pybamm.OptionError, "MSMR"): + with pytest.raises(pybamm.OptionError, match="MSMR"): pybamm.BaseBatteryModel( { "open-circuit potential": "MSMR", @@ -423,7 +409,7 @@ def test_options(self): def test_build_twice(self): model = pybamm.lithium_ion.SPM() # need to pick a model to set vars and build - with self.assertRaisesRegex(pybamm.ModelError, "Model already built"): + with pytest.raises(pybamm.ModelError, match="Model already built"): model.build_model() def test_get_coupled_variables(self): @@ -431,37 +417,37 @@ def test_get_coupled_variables(self): model.submodels["current collector"] = pybamm.current_collector.Uniform( model.param ) - with self.assertRaisesRegex(pybamm.ModelError, "Missing variable"): + with pytest.raises(pybamm.ModelError, match="Missing variable"): model.build_model() def test_default_solver(self): model = pybamm.BaseBatteryModel() - self.assertIsInstance(model.default_solver, pybamm.CasadiSolver) + assert isinstance(model.default_solver, pybamm.CasadiSolver) # check that default_solver gives you a new solver, not an internal object solver = model.default_solver solver = pybamm.BaseModel() - self.assertIsInstance(model.default_solver, pybamm.CasadiSolver) - self.assertIsInstance(solver, pybamm.BaseModel) + assert isinstance(model.default_solver, pybamm.CasadiSolver) + assert isinstance(solver, pybamm.BaseModel) # check that adding algebraic variables gives algebraic solver a = pybamm.Variable("a") model.algebraic = {a: a - 1} - self.assertIsInstance(model.default_solver, pybamm.CasadiAlgebraicSolver) + assert isinstance(model.default_solver, pybamm.CasadiAlgebraicSolver) def test_option_type(self): # no entry gets default options model = pybamm.BaseBatteryModel() - self.assertIsInstance(model.options, pybamm.BatteryModelOptions) + assert isinstance(model.options, pybamm.BatteryModelOptions) # dict options get converted to BatteryModelOptions model = pybamm.BaseBatteryModel({"thermal": "isothermal"}) - self.assertIsInstance(model.options, pybamm.BatteryModelOptions) + assert isinstance(model.options, pybamm.BatteryModelOptions) # special dict types are not changed options = pybamm.FuzzyDict({"thermal": "isothermal"}) model = pybamm.BaseBatteryModel(options) - self.assertEqual(model.options, options) + assert model.options == options def test_save_load_model(self): model = pybamm.lithium_ion.SPM() @@ -479,7 +465,7 @@ def test_save_load_model(self): ) # raises error if variables are saved without mesh - with self.assertRaises(ValueError): + with pytest.raises(ValueError): model.save_model( filename="test_base_battery_model", 
variables=model.variables ) @@ -487,71 +473,56 @@ def test_save_load_model(self): os.remove("test_base_battery_model.json") -class TestOptions(unittest.TestCase): +class TestOptions: def test_print_options(self): with io.StringIO() as buffer, redirect_stdout(buffer): BatteryModelOptions(OPTIONS_DICT).print_options() output = buffer.getvalue() - self.assertEqual(output, PRINT_OPTIONS_OUTPUT) + assert output == PRINT_OPTIONS_OUTPUT def test_option_phases(self): options = BatteryModelOptions({}) - self.assertEqual( - options.phases, {"negative": ["primary"], "positive": ["primary"]} - ) + assert options.phases == {"negative": ["primary"], "positive": ["primary"]} options = BatteryModelOptions({"particle phases": ("1", "2")}) - self.assertEqual( - options.phases, - {"negative": ["primary"], "positive": ["primary", "secondary"]}, - ) + assert options.phases == { + "negative": ["primary"], + "positive": ["primary", "secondary"], + } def test_domain_options(self): options = BatteryModelOptions( {"particle": ("Fickian diffusion", "quadratic profile")} ) - self.assertEqual(options.negative["particle"], "Fickian diffusion") - self.assertEqual(options.positive["particle"], "quadratic profile") + assert options.negative["particle"] == "Fickian diffusion" + assert options.positive["particle"] == "quadratic profile" # something that is the same in both domains - self.assertEqual(options.negative["thermal"], "isothermal") - self.assertEqual(options.positive["thermal"], "isothermal") + assert options.negative["thermal"] == "isothermal" + assert options.positive["thermal"] == "isothermal" def test_domain_phase_options(self): options = BatteryModelOptions( {"particle mechanics": (("swelling only", "swelling and cracking"), "none")} ) - self.assertEqual( - options.negative["particle mechanics"], - ("swelling only", "swelling and cracking"), + assert options.negative["particle mechanics"] == ( + "swelling only", + "swelling and cracking", ) - self.assertEqual( - options.negative.primary["particle mechanics"], "swelling only" + assert options.negative.primary["particle mechanics"] == "swelling only" + assert ( + options.negative.secondary["particle mechanics"] == "swelling and cracking" ) - self.assertEqual( - options.negative.secondary["particle mechanics"], "swelling and cracking" - ) - self.assertEqual(options.positive["particle mechanics"], "none") - self.assertEqual(options.positive.primary["particle mechanics"], "none") - self.assertEqual(options.positive.secondary["particle mechanics"], "none") + assert options.positive["particle mechanics"] == "none" + assert options.positive.primary["particle mechanics"] == "none" + assert options.positive.secondary["particle mechanics"] == "none" def test_whole_cell_domains(self): options = BatteryModelOptions({"working electrode": "positive"}) - self.assertEqual( - options.whole_cell_domains, ["separator", "positive electrode"] - ) + assert options.whole_cell_domains == ["separator", "positive electrode"] options = BatteryModelOptions({}) - self.assertEqual( - options.whole_cell_domains, - ["negative electrode", "separator", "positive electrode"], - ) - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() + assert options.whole_cell_domains == [ + "negative electrode", + "separator", + "positive electrode", + ] diff --git a/tests/unit/test_parameters/test_parameter_sets/test_OKane2022.py 
b/tests/unit/test_parameters/test_parameter_sets/test_OKane2022.py index 014b467715..91fa8ef87e 100644 --- a/tests/unit/test_parameters/test_parameter_sets/test_OKane2022.py +++ b/tests/unit/test_parameters/test_parameter_sets/test_OKane2022.py @@ -2,11 +2,11 @@ # Tests for O'Kane (2022) parameter set # +import pytest import pybamm -import unittest -class TestOKane2022(unittest.TestCase): +class TestOKane2022: def test_functions(self): param = pybamm.ParameterValues("OKane2022") sto = pybamm.Scalar(0.9) @@ -40,16 +40,6 @@ def test_functions(self): } for name, value in fun_test.items(): - self.assertAlmostEqual( - param.evaluate(param[name](*value[0])), value[1], places=4 + assert param.evaluate(param[name](*value[0])) == pytest.approx( + value[1], abs=0.0001 ) - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() diff --git a/tests/unit/test_parameters/test_parameter_sets/test_parameters_with_default_models.py b/tests/unit/test_parameters/test_parameter_sets/test_parameters_with_default_models.py index d7133a73e0..77fc3d66e7 100644 --- a/tests/unit/test_parameters/test_parameter_sets/test_parameters_with_default_models.py +++ b/tests/unit/test_parameters/test_parameter_sets/test_parameters_with_default_models.py @@ -3,11 +3,10 @@ # import pybamm -import unittest -class TestParameterValuesWithModel(unittest.TestCase): - def test_parameter_values_with_model(self): +class TestParameterValuesWithModel: + def test_parameter_values_with_model(self, subtests): param_to_model = { "Ai2020": pybamm.lithium_ion.DFN( {"particle mechanics": "swelling and cracking"} @@ -46,16 +45,6 @@ def test_parameter_values_with_model(self): # Loop over each parameter set, testing that parameters can be set for param, model in param_to_model.items(): - with self.subTest(param=param): + with subtests.test(param=param): parameter_values = pybamm.ParameterValues(param) parameter_values.process_model(model) - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() diff --git a/tests/unit/test_solvers/test_idaklu_jax.py b/tests/unit/test_solvers/test_idaklu_jax.py index a99f108f40..53abb94c83 100644 --- a/tests/unit/test_solvers/test_idaklu_jax.py +++ b/tests/unit/test_solvers/test_idaklu_jax.py @@ -2,11 +2,10 @@ # Tests for the KLU-Jax interface class # -from parameterized import parameterized +import pytest import pybamm import numpy as np -import unittest testcase = [] if pybamm.has_idaklu() and pybamm.has_jax(): @@ -86,21 +85,21 @@ def no_jit(f): # Check the interface throws an appropriate error if either IDAKLU or JAX not available -@unittest.skipIf( +@pytest.mark.skipif( pybamm.has_idaklu() and pybamm.has_jax(), - "Both IDAKLU and JAX are available", + reason="Both IDAKLU and JAX are available", ) -class TestIDAKLUJax_NoJax(unittest.TestCase): +class TestIDAKLUJax_NoJax: def test_instantiate_fails(self): - with self.assertRaises(ModuleNotFoundError): + with pytest.raises(ModuleNotFoundError): pybamm.IDAKLUJax([], [], []) -@unittest.skipIf( +@pytest.mark.skipif( not pybamm.has_idaklu() or not pybamm.has_jax(), - "IDAKLU Solver and/or JAX are not available", + reason="IDAKLU Solver and/or JAX are not available", ) -class TestIDAKLUJax(unittest.TestCase): +class TestIDAKLUJax: # Initialisation tests def test_initialise_twice(self): @@ -110,7 +109,7 @@ def 
test_initialise_twice(self): output_variables=output_variables, calculate_sensitivities=True, ) - with self.assertWarns(UserWarning): + with pytest.warns(UserWarning): idaklu_jax_solver.jaxify( model, t_eval, @@ -127,15 +126,15 @@ def test_uninitialised(self): ) # simulate failure in initialisation idaklu_jax_solver.jaxpr = None - with self.assertRaises(pybamm.SolverError): + with pytest.raises(pybamm.SolverError): idaklu_jax_solver.get_jaxpr() - with self.assertRaises(pybamm.SolverError): + with pytest.raises(pybamm.SolverError): idaklu_jax_solver.jax_value() - with self.assertRaises(pybamm.SolverError): + with pytest.raises(pybamm.SolverError): idaklu_jax_solver.jax_grad() def test_no_output_variables(self): - with self.assertRaises(pybamm.SolverError): + with pytest.raises(pybamm.SolverError): idaklu_solver.jaxify( model, t_eval, @@ -170,63 +169,63 @@ def test_no_inputs(self): # Scalar evaluation - @parameterized.expand(testcase, skip_on_empty=True) + @pytest.mark.parametrize("output_variables,idaklu_jax_solver,f,wrapper", testcase) def test_f_scalar(self, output_variables, idaklu_jax_solver, f, wrapper): out = wrapper(f)(t_eval[k], inputs) np.testing.assert_allclose( out, np.array([sim[outvar](t_eval[k]) for outvar in output_variables]).T ) - @parameterized.expand(testcase, skip_on_empty=True) + @pytest.mark.parametrize("output_variables,idaklu_jax_solver,f,wrapper", testcase) def test_f_vector(self, output_variables, idaklu_jax_solver, f, wrapper): out = wrapper(f)(t_eval, inputs) np.testing.assert_allclose( out, np.array([sim[outvar](t_eval) for outvar in output_variables]).T ) - @parameterized.expand(testcase, skip_on_empty=True) + @pytest.mark.parametrize("output_variables,idaklu_jax_solver,f,wrapper", testcase) def test_f_vmap(self, output_variables, idaklu_jax_solver, f, wrapper): out = wrapper(jax.vmap(f, in_axes=in_axes))(t_eval, inputs) np.testing.assert_allclose( out, np.array([sim[outvar](t_eval) for outvar in output_variables]).T ) - @parameterized.expand(testcase, skip_on_empty=True) + @pytest.mark.parametrize("output_variables,idaklu_jax_solver,f,wrapper", testcase) def test_f_batch_over_inputs(self, output_variables, idaklu_jax_solver, f, wrapper): inputs_mock = np.array([1.0, 2.0, 3.0]) - with self.assertRaises(NotImplementedError): + with pytest.raises(NotImplementedError): wrapper(jax.vmap(f, in_axes=(None, 0)))(t_eval, inputs_mock) # Get all vars (should mirror test_f_* [above]) - @parameterized.expand(testcase, skip_on_empty=True) + @pytest.mark.parametrize("output_variables,idaklu_jax_solver,f,wrapper", testcase) def test_getvars_call_signature( self, output_variables, idaklu_jax_solver, f, wrapper ): if wrapper == jax.jit: return # test does not involve a JAX expression - with self.assertRaises(ValueError): + with pytest.raises(ValueError): idaklu_jax_solver.get_vars() # no variable name specified idaklu_jax_solver.get_vars(output_variables) # (okay) idaklu_jax_solver.get_vars(f, output_variables) # (okay) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): idaklu_jax_solver.get_vars(1, 2, 3) # too many arguments - @parameterized.expand(testcase, skip_on_empty=True) + @pytest.mark.parametrize("output_variables,idaklu_jax_solver,f,wrapper", testcase) def test_getvars_scalar(self, output_variables, idaklu_jax_solver, f, wrapper): out = wrapper(idaklu_jax_solver.get_vars(output_variables))(t_eval[k], inputs) np.testing.assert_allclose( out, np.array([sim[outvar](t_eval[k]) for outvar in output_variables]).T ) - @parameterized.expand(testcase, 
skip_on_empty=True) + @pytest.mark.parametrize("output_variables,idaklu_jax_solver,f,wrapper", testcase) def test_getvars_vector(self, output_variables, idaklu_jax_solver, f, wrapper): out = wrapper(idaklu_jax_solver.get_vars(output_variables))(t_eval, inputs) np.testing.assert_allclose( out, np.array([sim[outvar](t_eval) for outvar in output_variables]).T ) - @parameterized.expand(testcase, skip_on_empty=True) + @pytest.mark.parametrize("output_variables,idaklu_jax_solver,f,wrapper", testcase) def test_getvars_vector_array( self, output_variables, idaklu_jax_solver, f, wrapper ): @@ -236,7 +235,7 @@ def test_getvars_vector_array( out = idaklu_jax_solver.get_vars(array, output_variables) np.testing.assert_allclose(out, array) - @parameterized.expand(testcase, skip_on_empty=True) + @pytest.mark.parametrize("output_variables,idaklu_jax_solver,f,wrapper", testcase) def test_getvars_vmap(self, output_variables, idaklu_jax_solver, f, wrapper): out = wrapper( jax.vmap( @@ -250,20 +249,20 @@ def test_getvars_vmap(self, output_variables, idaklu_jax_solver, f, wrapper): # Isolate single output variable - @parameterized.expand(testcase, skip_on_empty=True) + @pytest.mark.parametrize("output_variables,idaklu_jax_solver,f,wrapper", testcase) def test_getvar_call_signature( self, output_variables, idaklu_jax_solver, f, wrapper ): if wrapper == jax.jit: return # test does not involve a JAX expression - with self.assertRaises(ValueError): + with pytest.raises(ValueError): idaklu_jax_solver.get_var() # no variable name specified idaklu_jax_solver.get_var(output_variables[0]) # (okay) idaklu_jax_solver.get_var(f, output_variables[0]) # (okay) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): idaklu_jax_solver.get_var(1, 2, 3) # too many arguments - @parameterized.expand(testcase, skip_on_empty=True) + @pytest.mark.parametrize("output_variables,idaklu_jax_solver,f,wrapper", testcase) def test_getvar_scalar_float_jaxpr( self, output_variables, idaklu_jax_solver, f, wrapper ): @@ -272,7 +271,7 @@ def test_getvar_scalar_float_jaxpr( out = wrapper(idaklu_jax_solver.get_var(outvar))(float(t_eval[k]), inputs) np.testing.assert_allclose(out, sim[outvar](float(t_eval[k]))) - @parameterized.expand(testcase, skip_on_empty=True) + @pytest.mark.parametrize("output_variables,idaklu_jax_solver,f,wrapper", testcase) def test_getvar_scalar_float_f( self, output_variables, idaklu_jax_solver, f, wrapper ): @@ -283,35 +282,35 @@ def test_getvar_scalar_float_f( ) np.testing.assert_allclose(out, sim[outvar](float(t_eval[k]))) - @parameterized.expand(testcase, skip_on_empty=True) + @pytest.mark.parametrize("output_variables,idaklu_jax_solver,f,wrapper", testcase) def test_getvar_scalar_jaxpr(self, output_variables, idaklu_jax_solver, f, wrapper): # Per variable checks using the default JAX expression (self.jaxpr) for outvar in output_variables: out = wrapper(idaklu_jax_solver.get_var(outvar))(t_eval[k], inputs) np.testing.assert_allclose(out, sim[outvar](t_eval[k])) - @parameterized.expand(testcase, skip_on_empty=True) + @pytest.mark.parametrize("output_variables,idaklu_jax_solver,f,wrapper", testcase) def test_getvar_scalar_f(self, output_variables, idaklu_jax_solver, f, wrapper): # Per variable checks using a provided JAX expression (f) for outvar in output_variables: out = wrapper(idaklu_jax_solver.get_var(outvar))(t_eval[k], inputs) np.testing.assert_allclose(out, sim[outvar](t_eval[k])) - @parameterized.expand(testcase, skip_on_empty=True) + 
@pytest.mark.parametrize("output_variables,idaklu_jax_solver,f,wrapper", testcase) def test_getvar_vector_jaxpr(self, output_variables, idaklu_jax_solver, f, wrapper): # Per variable checks using the default JAX expression (self.jaxpr) for outvar in output_variables: out = wrapper(idaklu_jax_solver.get_var(outvar))(t_eval, inputs) np.testing.assert_allclose(out, sim[outvar](t_eval)) - @parameterized.expand(testcase, skip_on_empty=True) + @pytest.mark.parametrize("output_variables,idaklu_jax_solver,f,wrapper", testcase) def test_getvar_vector_f(self, output_variables, idaklu_jax_solver, f, wrapper): # Per variable checks using a provided JAX expression (f) for outvar in output_variables: out = wrapper(idaklu_jax_solver.get_var(f, outvar))(t_eval, inputs) np.testing.assert_allclose(out, sim[outvar](t_eval)) - @parameterized.expand(testcase, skip_on_empty=True) + @pytest.mark.parametrize("output_variables,idaklu_jax_solver,f,wrapper", testcase) def test_getvar_vector_array(self, output_variables, idaklu_jax_solver, f, wrapper): # Per variable checks using a provided np.ndarray if wrapper == jax.jit: @@ -321,7 +320,7 @@ def test_getvar_vector_array(self, output_variables, idaklu_jax_solver, f, wrapp out = idaklu_jax_solver.get_var(array, outvar) np.testing.assert_allclose(out, sim[outvar](t_eval)) - @parameterized.expand(testcase, skip_on_empty=True) + @pytest.mark.parametrize("output_variables,idaklu_jax_solver,f,wrapper", testcase) def test_getvar_vmap(self, output_variables, idaklu_jax_solver, f, wrapper): for outvar in output_variables: out = wrapper( @@ -334,7 +333,7 @@ def test_getvar_vmap(self, output_variables, idaklu_jax_solver, f, wrapper): # Differentiation rules (jacfwd) - @parameterized.expand(testcase, skip_on_empty=True) + @pytest.mark.parametrize("output_variables,idaklu_jax_solver,f,wrapper", testcase) def test_jacfwd_scalar(self, output_variables, idaklu_jax_solver, f, wrapper): out = wrapper(jax.jacfwd(f, argnums=1))(t_eval[k], inputs) flat_out, _ = tree_flatten(out) @@ -348,7 +347,7 @@ def test_jacfwd_scalar(self, output_variables, idaklu_jax_solver, f, wrapper): ).T np.testing.assert_allclose(flat_out, check.flatten()) - @parameterized.expand(testcase, skip_on_empty=True) + @pytest.mark.parametrize("output_variables,idaklu_jax_solver,f,wrapper", testcase) def test_jacfwd_vector(self, output_variables, idaklu_jax_solver, f, wrapper): out = wrapper(jax.jacfwd(f, argnums=1))(t_eval, inputs) flat_out, _ = tree_flatten(out) @@ -365,7 +364,7 @@ def test_jacfwd_vector(self, output_variables, idaklu_jax_solver, f, wrapper): f"Got: {flat_out}\nExpected: {check}", ) - @parameterized.expand(testcase, skip_on_empty=True) + @pytest.mark.parametrize("output_variables,idaklu_jax_solver,f,wrapper", testcase) def test_jacfwd_vmap(self, output_variables, idaklu_jax_solver, f, wrapper): out = wrapper( jax.vmap( @@ -384,11 +383,11 @@ def test_jacfwd_vmap(self, output_variables, idaklu_jax_solver, f, wrapper): ) np.testing.assert_allclose(flat_out, check.flatten()) - @parameterized.expand(testcase, skip_on_empty=True) + @pytest.mark.parametrize("output_variables,idaklu_jax_solver,f,wrapper", testcase) def test_jacfwd_vmap_wrt_time( self, output_variables, idaklu_jax_solver, f, wrapper ): - with self.assertRaises(NotImplementedError): + with pytest.raises(NotImplementedError): wrapper( jax.vmap( jax.jacfwd(f, argnums=0), @@ -396,12 +395,12 @@ def test_jacfwd_vmap_wrt_time( ), )(t_eval, inputs) - @parameterized.expand(testcase, skip_on_empty=True) + 
@pytest.mark.parametrize("output_variables,idaklu_jax_solver,f,wrapper", testcase) def test_jacfwd_batch_over_inputs( self, output_variables, idaklu_jax_solver, f, wrapper ): inputs_mock = np.array([1.0, 2.0, 3.0]) - with self.assertRaises(NotImplementedError): + with pytest.raises(NotImplementedError): wrapper( jax.vmap( jax.jacfwd(f, argnums=1), @@ -411,7 +410,7 @@ def test_jacfwd_batch_over_inputs( # Differentiation rules (jacrev) - @parameterized.expand(testcase, skip_on_empty=True) + @pytest.mark.parametrize("output_variables,idaklu_jax_solver,f,wrapper", testcase) def test_jacrev_scalar(self, output_variables, idaklu_jax_solver, f, wrapper): out = wrapper(jax.jacrev(f, argnums=1))(t_eval[k], inputs) flat_out, _ = tree_flatten(out) @@ -425,7 +424,7 @@ def test_jacrev_scalar(self, output_variables, idaklu_jax_solver, f, wrapper): ).T np.testing.assert_allclose(flat_out, check.flatten()) - @parameterized.expand(testcase, skip_on_empty=True) + @pytest.mark.parametrize("output_variables,idaklu_jax_solver,f,wrapper", testcase) def test_jacrev_vector(self, output_variables, idaklu_jax_solver, f, wrapper): out = wrapper(jax.jacrev(f, argnums=1))(t_eval, inputs) flat_out, _ = tree_flatten(out) @@ -439,7 +438,7 @@ def test_jacrev_vector(self, output_variables, idaklu_jax_solver, f, wrapper): ) np.testing.assert_allclose(flat_out, check.flatten()) - @parameterized.expand(testcase, skip_on_empty=True) + @pytest.mark.parametrize("output_variables,idaklu_jax_solver,f,wrapper", testcase) def test_jacrev_vmap(self, output_variables, idaklu_jax_solver, f, wrapper): out = wrapper( jax.vmap( @@ -458,12 +457,12 @@ def test_jacrev_vmap(self, output_variables, idaklu_jax_solver, f, wrapper): ) np.testing.assert_allclose(flat_out, check.flatten()) - @parameterized.expand(testcase, skip_on_empty=True) + @pytest.mark.parametrize("output_variables,idaklu_jax_solver,f,wrapper", testcase) def test_jacrev_batch_over_inputs( self, output_variables, idaklu_jax_solver, f, wrapper ): inputs_mock = np.array([1.0, 2.0, 3.0]) - with self.assertRaises(NotImplementedError): + with pytest.raises(NotImplementedError): wrapper( jax.vmap( jax.jacrev(f, argnums=1), @@ -473,7 +472,7 @@ def test_jacrev_batch_over_inputs( # Forward differentiation rules with get_vars (multiple) and get_var (singular) - @parameterized.expand(testcase, skip_on_empty=True) + @pytest.mark.parametrize("output_variables,idaklu_jax_solver,f,wrapper", testcase) def test_jacfwd_scalar_getvars( self, output_variables, idaklu_jax_solver, f, wrapper ): @@ -496,7 +495,7 @@ def test_jacfwd_scalar_getvars( flat_check, _ = tree_flatten(check) np.testing.assert_allclose(flat_out, flat_check) - @parameterized.expand(testcase, skip_on_empty=True) + @pytest.mark.parametrize("output_variables,idaklu_jax_solver,f,wrapper", testcase) def test_jacfwd_scalar_getvar( self, output_variables, idaklu_jax_solver, f, wrapper ): @@ -515,7 +514,7 @@ def test_jacfwd_scalar_getvar( flat_check, _ = tree_flatten(check) np.testing.assert_allclose(flat_out, flat_check) - @parameterized.expand(testcase, skip_on_empty=True) + @pytest.mark.parametrize("output_variables,idaklu_jax_solver,f,wrapper", testcase) def test_jacfwd_vector_getvars( self, output_variables, idaklu_jax_solver, f, wrapper ): @@ -539,7 +538,7 @@ def test_jacfwd_vector_getvars( flat_check, _ = tree_flatten(check) np.testing.assert_allclose(flat_out, flat_check) - @parameterized.expand(testcase, skip_on_empty=True) + @pytest.mark.parametrize("output_variables,idaklu_jax_solver,f,wrapper", testcase) def 
test_jacfwd_vector_getvar( self, output_variables, idaklu_jax_solver, f, wrapper ): @@ -558,7 +557,7 @@ def test_jacfwd_vector_getvar( flat_check, _ = tree_flatten(check) np.testing.assert_allclose(flat_out, flat_check) - @parameterized.expand(testcase, skip_on_empty=True) + @pytest.mark.parametrize("output_variables,idaklu_jax_solver,f,wrapper", testcase) def test_jacfwd_vmap_getvars(self, output_variables, idaklu_jax_solver, f, wrapper): out = wrapper( jax.vmap( @@ -577,7 +576,7 @@ def test_jacfwd_vmap_getvars(self, output_variables, idaklu_jax_solver, f, wrapp ) np.testing.assert_allclose(flat_out, check.flatten()) - @parameterized.expand(testcase, skip_on_empty=True) + @pytest.mark.parametrize("output_variables,idaklu_jax_solver,f,wrapper", testcase) def test_jacfwd_vmap_getvar(self, output_variables, idaklu_jax_solver, f, wrapper): for outvar in output_variables: out = wrapper( @@ -596,7 +595,7 @@ def test_jacfwd_vmap_getvar(self, output_variables, idaklu_jax_solver, f, wrappe # Reverse differentiation rules with get_vars (multiple) and get_var (singular) - @parameterized.expand(testcase, skip_on_empty=True) + @pytest.mark.parametrize("output_variables,idaklu_jax_solver,f,wrapper", testcase) def test_jacrev_scalar_getvars( self, output_variables, idaklu_jax_solver, f, wrapper ): @@ -619,7 +618,7 @@ def test_jacrev_scalar_getvars( flat_check, _ = tree_flatten(check) np.testing.assert_allclose(flat_out, flat_check) - @parameterized.expand(testcase, skip_on_empty=True) + @pytest.mark.parametrize("output_variables,idaklu_jax_solver,f,wrapper", testcase) def test_jacrev_scalar_getvar( self, output_variables, idaklu_jax_solver, f, wrapper ): @@ -640,7 +639,7 @@ def test_jacrev_scalar_getvar( f"Got: {flat_out}\nExpected: {check}", ) - @parameterized.expand(testcase, skip_on_empty=True) + @pytest.mark.parametrize("output_variables,idaklu_jax_solver,f,wrapper", testcase) def test_jacrev_vector_getvars( self, output_variables, idaklu_jax_solver, f, wrapper ): @@ -664,7 +663,7 @@ def test_jacrev_vector_getvars( flat_check, _ = tree_flatten(check) np.testing.assert_allclose(flat_out, flat_check) - @parameterized.expand(testcase, skip_on_empty=True) + @pytest.mark.parametrize("output_variables,idaklu_jax_solver,f,wrapper", testcase) def test_jacrev_vector_getvar( self, output_variables, idaklu_jax_solver, f, wrapper ): @@ -683,7 +682,7 @@ def test_jacrev_vector_getvar( flat_check, _ = tree_flatten(check) np.testing.assert_allclose(flat_out, flat_check) - @parameterized.expand(testcase, skip_on_empty=True) + @pytest.mark.parametrize("output_variables,idaklu_jax_solver,f,wrapper", testcase) def test_jacrev_vmap_getvars(self, output_variables, idaklu_jax_solver, f, wrapper): out = wrapper( jax.vmap( @@ -702,7 +701,7 @@ def test_jacrev_vmap_getvars(self, output_variables, idaklu_jax_solver, f, wrapp ) np.testing.assert_allclose(flat_out, check.flatten()) - @parameterized.expand(testcase, skip_on_empty=True) + @pytest.mark.parametrize("output_variables,idaklu_jax_solver,f,wrapper", testcase) def test_jacrev_vmap_getvar(self, output_variables, idaklu_jax_solver, f, wrapper): for outvar in output_variables: out = wrapper( @@ -721,7 +720,7 @@ def test_jacrev_vmap_getvar(self, output_variables, idaklu_jax_solver, f, wrappe # Gradient rule (takes single variable) - @parameterized.expand(testcase, skip_on_empty=True) + @pytest.mark.parametrize("output_variables,idaklu_jax_solver,f,wrapper", testcase) def test_grad_scalar_getvar(self, output_variables, idaklu_jax_solver, f, wrapper): for outvar in 
output_variables: out = wrapper( @@ -735,7 +734,7 @@ def test_grad_scalar_getvar(self, output_variables, idaklu_jax_solver, f, wrappe check = np.array([sim[outvar].sensitivities[invar][k] for invar in inputs]) np.testing.assert_allclose(flat_out, check.flatten()) - @parameterized.expand(testcase, skip_on_empty=True) + @pytest.mark.parametrize("output_variables,idaklu_jax_solver,f,wrapper", testcase) def test_grad_vmap_getvar(self, output_variables, idaklu_jax_solver, f, wrapper): for outvar in output_variables: out = wrapper( @@ -754,7 +753,7 @@ def test_grad_vmap_getvar(self, output_variables, idaklu_jax_solver, f, wrapper) # Value and gradient (takes single variable) - @parameterized.expand(testcase, skip_on_empty=True) + @pytest.mark.parametrize("output_variables,idaklu_jax_solver,f,wrapper", testcase) def test_value_and_grad_scalar( self, output_variables, idaklu_jax_solver, f, wrapper ): @@ -774,7 +773,7 @@ def test_value_and_grad_scalar( check = np.array([sim[outvar].sensitivities[invar][k] for invar in inputs]) np.testing.assert_allclose(flat_t, check.flatten()) - @parameterized.expand(testcase, skip_on_empty=True) + @pytest.mark.parametrize("output_variables,idaklu_jax_solver,f,wrapper", testcase) def test_value_and_grad_vmap(self, output_variables, idaklu_jax_solver, f, wrapper): for outvar in output_variables: primals, tangents = wrapper( @@ -797,7 +796,7 @@ def test_value_and_grad_vmap(self, output_variables, idaklu_jax_solver, f, wrapp # Helper functions - These return values (not jaxexprs) so cannot be JITed - @parameterized.expand(testcase, skip_on_empty=True) + @pytest.mark.parametrize("output_variables,idaklu_jax_solver,f,wrapper", testcase) def test_jax_vars(self, output_variables, idaklu_jax_solver, f, wrapper): if wrapper == jax.jit: # Skipping test_jax_vars for jax.jit, jit not supported on helper functions @@ -812,7 +811,7 @@ def test_jax_vars(self, output_variables, idaklu_jax_solver, f, wrapper): f"{outvar}: Got: {flat_out}\nExpected: {check}", ) - @parameterized.expand(testcase, skip_on_empty=True) + @pytest.mark.parametrize("output_variables,idaklu_jax_solver,f,wrapper", testcase) def test_jax_grad(self, output_variables, idaklu_jax_solver, f, wrapper): if wrapper == jax.jit: # Skipping test_jax_grad for jax.jit, jit not supported on helper functions @@ -829,7 +828,7 @@ def test_jax_grad(self, output_variables, idaklu_jax_solver, f, wrapper): # Wrap jaxified expression in another function and take the gradient - @parameterized.expand(testcase, skip_on_empty=True) + @pytest.mark.parametrize("output_variables,idaklu_jax_solver,f,wrapper", testcase) def test_grad_wrapper_sse(self, output_variables, idaklu_jax_solver, f, wrapper): # Use surrogate for experimental data data = sim["v"](t_eval) diff --git a/tests/unit/test_solvers/test_scipy_solver.py b/tests/unit/test_solvers/test_scipy_solver.py index dfaa6c7201..c5516ee880 100644 --- a/tests/unit/test_solvers/test_scipy_solver.py +++ b/tests/unit/test_solvers/test_scipy_solver.py @@ -1,15 +1,14 @@ # Tests for the Scipy Solver class # +import pytest import pybamm -import unittest import numpy as np from tests import get_mesh_for_testing, get_discretisation_for_testing import warnings -import sys -class TestScipySolver(unittest.TestCase): +class TestScipySolver: def test_model_solver_python_and_jax(self): if pybamm.has_jax(): formats = ["python", "jax"] @@ -43,10 +42,8 @@ def test_model_solver_python_and_jax(self): np.testing.assert_allclose(solution.y[0], np.exp(0.1 * solution.t)) # Test time - self.assertEqual( - 
solution.total_time, solution.solve_time + solution.set_up_time - ) - self.assertEqual(solution.termination, "final time") + assert solution.total_time == solution.solve_time + solution.set_up_time + assert solution.termination == "final time" def test_model_solver_failure(self): # Turn off warnings to ignore sqrt error @@ -65,7 +62,7 @@ def test_model_solver_failure(self): t_eval = np.linspace(0, 3, 100) solver = pybamm.ScipySolver(rtol=1e-8, atol=1e-8, method="RK45") # Expect solver to fail when y goes negative - with self.assertRaises(pybamm.SolverError): + with pytest.raises(pybamm.SolverError): solver.solve(model, t_eval) # Turn warnings back on @@ -96,7 +93,7 @@ def test_model_solver_with_event_python(self): solver = pybamm.ScipySolver(rtol=1e-8, atol=1e-8, method="RK45") t_eval = np.linspace(0, 10, 100) solution = solver.solve(model, t_eval) - self.assertLess(len(solution.t), len(t_eval)) + assert len(solution.t) < len(t_eval) np.testing.assert_array_equal(solution.t[:-1], t_eval[: len(solution.t) - 1]) np.testing.assert_allclose(solution.y[0], np.exp(-0.1 * solution.t)) np.testing.assert_equal(solution.t_event[0], solution.t[-1]) @@ -222,7 +219,7 @@ def test_step_different_model(self): np.testing.assert_array_almost_equal(step_sol1.y[0], np.exp(0.1 * step_sol1.t)) # Step again, the model has changed so this raises an error - with self.assertRaisesRegex(RuntimeError, "already been initialised"): + with pytest.raises(RuntimeError, match="already been initialised"): solver.step(step_sol1, model2, dt) def test_model_solver_with_inputs(self): @@ -245,11 +242,11 @@ def test_model_solver_with_inputs(self): solver = pybamm.ScipySolver(rtol=1e-8, atol=1e-8, method="RK45") t_eval = np.linspace(0, 10, 100) solution = solver.solve(model, t_eval, inputs={"rate": 0.1}) - self.assertLess(len(solution.t), len(t_eval)) + assert len(solution.t) < len(t_eval) np.testing.assert_array_equal(solution.t[:-1], t_eval[: len(solution.t) - 1]) np.testing.assert_allclose(solution.y[0], np.exp(-0.1 * solution.t)) - def test_model_solver_multiple_inputs_happy_path(self): + def test_model_solver_multiple_inputs_happy_path(self, subtests): for convert_to_format in ["python", "casadi"]: # Create model model = pybamm.BaseModel() @@ -271,7 +268,7 @@ def test_model_solver_multiple_inputs_happy_path(self): solutions = solver.solve(model, t_eval, inputs=inputs_list, nproc=2) for i in range(ninputs): - with self.subTest(i=i): + with subtests.test(i=i): solution = solutions[i] np.testing.assert_array_equal(solution.t, t_eval) np.testing.assert_allclose( @@ -304,12 +301,10 @@ def test_model_solver_multiple_inputs_discontinuity_error(self): event_type=pybamm.EventType.DISCONTINUITY, ) ] - with self.assertRaisesRegex( + with pytest.raises( pybamm.SolverError, - ( - "Cannot solve for a list of input parameters" - " sets with discontinuities" - ), + match="Cannot solve for a list of input parameters" + " sets with discontinuities", ): solver.solve(model, t_eval, inputs=inputs_list, nproc=2) @@ -332,13 +327,14 @@ def test_model_solver_multiple_inputs_initial_conditions_error(self): ninputs = 8 inputs_list = [{"rate": 0.01 * (i + 1)} for i in range(ninputs)] - with self.assertRaisesRegex( + with pytest.raises( pybamm.SolverError, - ("Input parameters cannot appear in expression " "for initial conditions."), + match="Input parameters cannot appear in expression " + "for initial conditions.", ): solver.solve(model, t_eval, inputs=inputs_list, nproc=2) - def test_model_solver_multiple_inputs_jax_format(self): + def 
test_model_solver_multiple_inputs_jax_format(self, subtests): if pybamm.has_jax(): # Create model model = pybamm.BaseModel() @@ -360,7 +356,7 @@ def test_model_solver_multiple_inputs_jax_format(self): solutions = solver.solve(model, t_eval, inputs=inputs_list, nproc=2) for i in range(ninputs): - with self.subTest(i=i): + with subtests.test(i=i): solution = solutions[i] np.testing.assert_array_equal(solution.t, t_eval) np.testing.assert_allclose( @@ -395,7 +391,7 @@ def test_model_solver_with_event_with_casadi(self): solver = pybamm.ScipySolver(rtol=1e-8, atol=1e-8, method="RK45") t_eval = np.linspace(0, 10, 100) solution = solver.solve(model_disc, t_eval) - self.assertLess(len(solution.t), len(t_eval)) + assert len(solution.t) < len(t_eval) np.testing.assert_array_equal( solution.t[:-1], t_eval[: len(solution.t) - 1] ) @@ -422,7 +418,7 @@ def test_model_solver_with_inputs_with_casadi(self): solver = pybamm.ScipySolver(rtol=1e-8, atol=1e-8, method="RK45") t_eval = np.linspace(0, 10, 100) solution = solver.solve(model, t_eval, inputs={"rate": 0.1}) - self.assertLess(len(solution.t), len(t_eval)) + assert len(solution.t) < len(t_eval) np.testing.assert_array_equal(solution.t[:-1], t_eval[: len(solution.t) - 1]) np.testing.assert_allclose(solution.y[0], np.exp(-0.1 * solution.t)) @@ -492,7 +488,7 @@ def test_scale_and_reference(self): ) -class TestScipySolverWithSensitivity(unittest.TestCase): +class TestScipySolverWithSensitivity: def test_solve_sensitivity_scalar_var_scalar_input(self): # Create model model = pybamm.BaseModel() @@ -780,12 +776,3 @@ def test_solve_sensitivity_vector_var_vector_input(self): solution["integral of var"].sensitivities["param"], np.vstack([-2 * t * np.exp(-p_eval * t) * l_n / n for t in t_eval]), ) - - -if __name__ == "__main__": - print("Add -v for more debug output") - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() diff --git a/tests/unit/test_solvers/test_solution.py b/tests/unit/test_solvers/test_solution.py index 5a584fabbf..3aff012d5b 100644 --- a/tests/unit/test_solvers/test_solution.py +++ b/tests/unit/test_solvers/test_solution.py @@ -1,11 +1,12 @@ # # Tests for the Solution class # +import pytest import os - +import io +import logging import json import pybamm -import unittest import numpy as np import pandas as pd from scipy.io import loadmat @@ -13,23 +14,23 @@ from tempfile import TemporaryDirectory -class TestSolution(unittest.TestCase): +class TestSolution: def test_init(self): t = np.linspace(0, 1) y = np.tile(t, (20, 1)) sol = pybamm.Solution(t, y, pybamm.BaseModel(), {}) np.testing.assert_array_equal(sol.t, t) np.testing.assert_array_equal(sol.y, y) - self.assertEqual(sol.t_event, None) - self.assertEqual(sol.y_event, None) - self.assertEqual(sol.termination, "final time") - self.assertEqual(sol.all_inputs, [{}]) - self.assertIsInstance(sol.all_models[0], pybamm.BaseModel) + assert sol.t_event is None + assert sol.y_event is None + assert sol.termination == "final time" + assert sol.all_inputs == [{}] + assert isinstance(sol.all_models[0], pybamm.BaseModel) def test_sensitivities(self): t = np.linspace(0, 1) y = np.tile(t, (20, 1)) - with self.assertRaises(TypeError): + with pytest.raises(TypeError): pybamm.Solution(t, y, pybamm.BaseModel(), {}, sensitivities=1.0) def test_errors(self): @@ -37,8 +38,8 @@ def test_errors(self): sol = pybamm.Solution( bad_ts, [np.ones((1, 3)), np.ones((1, 3))], pybamm.BaseModel(), {} ) - with self.assertRaisesRegex( - ValueError, "Solution time vector must be strictly 
increasing" + with pytest.raises( + ValueError, match="Solution time vector must be strictly increasing" ): sol.set_t() @@ -48,21 +49,27 @@ def test_errors(self): var = pybamm.StateVector(slice(0, 1)) model.rhs = {var: 0} model.variables = {var.name: var} - with self.assertLogs() as captured: - pybamm.Solution(ts, bad_ys, model, {}) - self.assertIn("exceeds the maximum", captured.records[0].getMessage()) - - with self.assertRaisesRegex( - TypeError, "sensitivities arg needs to be a bool or dict" + log_capture = io.StringIO() + handler = logging.StreamHandler(log_capture) + handler.setLevel(logging.ERROR) + logger = logging.getLogger("pybamm.logger") + logger.addHandler(handler) + pybamm.Solution(ts, bad_ys, model, {}) + log_output = log_capture.getvalue() + assert "exceeds the maximum" in log_output + logger.removeHandler(handler) + + with pytest.raises( + TypeError, match="sensitivities arg needs to be a bool or dict" ): pybamm.Solution(ts, bad_ys, model, {}, all_sensitivities="bad") sol = pybamm.Solution(ts, bad_ys, model, {}, all_sensitivities={}) - with self.assertRaisesRegex(TypeError, "sensitivities arg needs to be a bool"): + with pytest.raises(TypeError, match="sensitivities arg needs to be a bool"): sol.sensitivities = "bad" - with self.assertRaisesRegex( + with pytest.raises( NotImplementedError, - "Setting sensitivities is not supported if sensitivities are already provided as a dict", + match="Setting sensitivities is not supported if sensitivities are already provided as a dict", ): sol.sensitivities = True @@ -84,7 +91,7 @@ def test_add_solutions(self): sol_sum = sol1 + sol2 # Test - self.assertEqual(sol_sum.integration_time, 0.8) + assert sol_sum.integration_time == 0.8 np.testing.assert_array_equal(sol_sum.t, np.concatenate([t1, t2[1:]])) np.testing.assert_array_equal( sol_sum.y, np.concatenate([y1, y2[:, 1:]], axis=1) @@ -92,36 +99,38 @@ def test_add_solutions(self): np.testing.assert_array_equal(sol_sum.all_inputs, [{"a": 1}, {"a": 2}]) # Test sub-solutions - self.assertEqual(len(sol_sum.sub_solutions), 2) + assert len(sol_sum.sub_solutions) == 2 np.testing.assert_array_equal(sol_sum.sub_solutions[0].t, t1) np.testing.assert_array_equal(sol_sum.sub_solutions[1].t, t2) - self.assertEqual(sol_sum.sub_solutions[0].all_models[0], sol_sum.all_models[0]) + assert sol_sum.sub_solutions[0].all_models[0] == sol_sum.all_models[0] np.testing.assert_array_equal(sol_sum.sub_solutions[0].all_inputs[0]["a"], 1) - self.assertEqual(sol_sum.sub_solutions[1].all_models[0], sol2.all_models[0]) - self.assertEqual(sol_sum.all_models[1], sol2.all_models[0]) + assert sol_sum.sub_solutions[1].all_models[0] == sol2.all_models[0] + assert sol_sum.all_models[1] == sol2.all_models[0] np.testing.assert_array_equal(sol_sum.sub_solutions[1].all_inputs[0]["a"], 2) # Add solution already contained in existing solution t3 = np.array([2]) y3 = np.ones((1, 1)) sol3 = pybamm.Solution(t3, y3, pybamm.BaseModel(), {"a": 3}) - self.assertEqual((sol_sum + sol3).all_ts, sol_sum.copy().all_ts) + assert (sol_sum + sol3).all_ts == sol_sum.copy().all_ts # add None sol4 = sol3 + None - self.assertEqual(sol3.all_ys, sol4.all_ys) + assert sol3.all_ys == sol4.all_ys # radd sol5 = None + sol3 - self.assertEqual(sol3.all_ys, sol5.all_ys) + assert sol3.all_ys == sol5.all_ys # radd failure - with self.assertRaisesRegex( - pybamm.SolverError, "Only a Solution or None can be added to a Solution" + with pytest.raises( + pybamm.SolverError, + match="Only a Solution or None can be added to a Solution", ): sol3 + 2 - with 
self.assertRaisesRegex( - pybamm.SolverError, "Only a Solution or None can be added to a Solution" + with pytest.raises( + pybamm.SolverError, + match="Only a Solution or None can be added to a Solution", ): 2 + sol3 @@ -133,14 +142,12 @@ def test_add_solutions(self): all_sensitivities={"test": [np.ones((1, 3))]}, ) sol2 = pybamm.Solution(t2, y2, pybamm.BaseModel(), {}, all_sensitivities=True) - with self.assertRaisesRegex( - ValueError, "Sensitivities must be of the same type" - ): + with pytest.raises(ValueError, match="Sensitivities must be of the same type"): sol3 = sol1 + sol2 sol1 = pybamm.Solution(t1, y3, pybamm.BaseModel(), {}, all_sensitivities=False) sol2 = pybamm.Solution(t3, y3, pybamm.BaseModel(), {}, all_sensitivities={}) sol3 = sol1 + sol2 - self.assertFalse(sol3._all_sensitivities) + assert not sol3._all_sensitivities def test_add_solutions_different_models(self): # Set up first solution @@ -160,8 +167,8 @@ def test_add_solutions_different_models(self): # Test np.testing.assert_array_equal(sol_sum.t, np.concatenate([t1, t2[1:]])) - with self.assertRaisesRegex( - pybamm.SolverError, "The solution is made up from different models" + with pytest.raises( + pybamm.SolverError, match="The solution is made up from different models" ): sol_sum.y @@ -176,14 +183,14 @@ def test_copy(self): sol1.integration_time = 0.3 sol_copy = sol1.copy() - self.assertEqual(sol_copy.all_ts, sol1.all_ts) + assert sol_copy.all_ts == sol1.all_ts for ys_copy, ys1 in zip(sol_copy.all_ys, sol1.all_ys): np.testing.assert_array_equal(ys_copy, ys1) - self.assertEqual(sol_copy.all_inputs, sol1.all_inputs) - self.assertEqual(sol_copy.all_inputs_casadi, sol1.all_inputs_casadi) - self.assertEqual(sol_copy.set_up_time, sol1.set_up_time) - self.assertEqual(sol_copy.solve_time, sol1.solve_time) - self.assertEqual(sol_copy.integration_time, sol1.integration_time) + assert sol_copy.all_inputs == sol1.all_inputs + assert sol_copy.all_inputs_casadi == sol1.all_inputs_casadi + assert sol_copy.set_up_time == sol1.set_up_time + assert sol_copy.solve_time == sol1.solve_time + assert sol_copy.integration_time == sol1.integration_time def test_last_state(self): # Set up first solution @@ -196,14 +203,14 @@ def test_last_state(self): sol1.integration_time = 0.3 sol_last_state = sol1.last_state - self.assertEqual(sol_last_state.all_ts[0], 2) + assert sol_last_state.all_ts[0] == 2 np.testing.assert_array_equal(sol_last_state.all_ys[0], 2) - self.assertEqual(sol_last_state.all_inputs, sol1.all_inputs[-1:]) - self.assertEqual(sol_last_state.all_inputs_casadi, sol1.all_inputs_casadi[-1:]) - self.assertEqual(sol_last_state.all_models, sol1.all_models[-1:]) - self.assertEqual(sol_last_state.set_up_time, 0) - self.assertEqual(sol_last_state.solve_time, 0) - self.assertEqual(sol_last_state.integration_time, 0) + assert sol_last_state.all_inputs == sol1.all_inputs[-1:] + assert sol_last_state.all_inputs_casadi == sol1.all_inputs_casadi[-1:] + assert sol_last_state.all_models == sol1.all_models[-1:] + assert sol_last_state.set_up_time == 0 + assert sol_last_state.solve_time == 0 + assert sol_last_state.integration_time == 0 def test_cycles(self): model = pybamm.lithium_ion.SPM() @@ -215,14 +222,14 @@ def test_cycles(self): ) sim = pybamm.Simulation(model, experiment=experiment) sol = sim.solve() - self.assertEqual(len(sol.cycles), 2) + assert len(sol.cycles) == 2 len_cycle_1 = len(sol.cycles[0].t) - self.assertIsInstance(sol.cycles[0], pybamm.Solution) + assert isinstance(sol.cycles[0], pybamm.Solution) 
np.testing.assert_array_equal(sol.cycles[0].t, sol.t[:len_cycle_1]) np.testing.assert_array_equal(sol.cycles[0].y, sol.y[:, :len_cycle_1]) - self.assertIsInstance(sol.cycles[1], pybamm.Solution) + assert isinstance(sol.cycles[1], pybamm.Solution) np.testing.assert_array_equal(sol.cycles[1].t, sol.t[len_cycle_1:]) np.testing.assert_allclose(sol.cycles[1].y, sol.y[:, len_cycle_1:]) @@ -230,7 +237,7 @@ def test_total_time(self): sol = pybamm.Solution(np.array([0]), np.array([[1, 2]]), pybamm.BaseModel(), {}) sol.set_up_time = 0.5 sol.solve_time = 1.2 - self.assertEqual(sol.total_time, 1.7) + assert sol.total_time == 1.7 def test_getitem(self): model = pybamm.BaseModel() @@ -244,13 +251,13 @@ def test_getitem(self): # test create a new processed variable c_sol = solution["c"] - self.assertIsInstance(c_sol, pybamm.ProcessedVariable) + assert isinstance(c_sol, pybamm.ProcessedVariable) np.testing.assert_array_equal(c_sol.entries, c_sol(solution.t)) # test call an already created variable solution.update("2c") twoc_sol = solution["2c"] - self.assertIsInstance(twoc_sol, pybamm.ProcessedVariable) + assert isinstance(twoc_sol, pybamm.ProcessedVariable) np.testing.assert_array_equal(twoc_sol.entries, twoc_sol(solution.t)) np.testing.assert_array_equal(twoc_sol.entries, 2 * c_sol.entries) @@ -283,12 +290,12 @@ def test_save(self): solution = pybamm.ScipySolver().solve(model, np.linspace(0, 1)) # test save data - with self.assertRaises(ValueError): + with pytest.raises(ValueError): solution.save_data(f"{test_stub}.pickle") # set variables first then save solution.update(["c", "d"]) - with self.assertRaisesRegex(ValueError, "pickle"): + with pytest.raises(ValueError, match="pickle"): solution.save_data(to_format="pickle") solution.save_data(f"{test_stub}.pickle") @@ -302,12 +309,12 @@ def test_save(self): np.testing.assert_array_equal(solution.data["c"], data_load["c"].flatten()) np.testing.assert_array_equal(solution.data["d"], data_load["d"]) - with self.assertRaisesRegex(ValueError, "matlab"): + with pytest.raises(ValueError, match="matlab"): solution.save_data(to_format="matlab") # to matlab with bad variables name fails solution.update(["c + d"]) - with self.assertRaisesRegex(ValueError, "Invalid character"): + with pytest.raises(ValueError, match="Invalid character"): solution.save_data(f"{test_stub}.mat", to_format="matlab") # Works if providing alternative name solution.save_data( @@ -319,8 +326,8 @@ def test_save(self): np.testing.assert_array_equal(solution.data["c + d"], data_load["c_plus_d"]) # to csv - with self.assertRaisesRegex( - ValueError, "only 0D variables can be saved to csv" + with pytest.raises( + ValueError, match="only 0D variables can be saved to csv" ): solution.save_data(f"{test_stub}.csv", to_format="csv") # only save "c" and "2c" @@ -330,7 +337,7 @@ def test_save(self): # check string is the same as the file with open(f"{test_stub}.csv") as f: # need to strip \r chars for windows - self.assertEqual(csv_str.replace("\r", ""), f.read()) + assert csv_str.replace("\r", "") == f.read() # read csv df = pd.read_csv(f"{test_stub}.csv") @@ -344,7 +351,7 @@ def test_save(self): # check string is the same as the file with open(f"{test_stub}.json") as f: # need to strip \r chars for windows - self.assertEqual(json_str.replace("\r", ""), f.read()) + assert json_str.replace("\r", "") == f.read() # check if string has the right values json_data = json.loads(json_str) @@ -352,17 +359,15 @@ def test_save(self): np.testing.assert_array_almost_equal(json_data["d"], solution.data["d"]) # raise 
error if format is unknown - with self.assertRaisesRegex( - ValueError, "format 'wrong_format' not recognised" + with pytest.raises( + ValueError, match="format 'wrong_format' not recognised" ): solution.save_data(f"{test_stub}.csv", to_format="wrong_format") # test save whole solution solution.save(f"{test_stub}.pickle") solution_load = pybamm.load(f"{test_stub}.pickle") - self.assertEqual( - solution.all_models[0].name, solution_load.all_models[0].name - ) + assert solution.all_models[0].name == solution_load.all_models[0].name np.testing.assert_array_equal( solution["c"].entries, solution_load["c"].entries ) @@ -412,14 +417,4 @@ def test_solution_evals_with_inputs(self): inputs = {"Negative electrode conductivity [S.m-1]": 0.1} sim.solve(t_eval=np.linspace(0, 10, 10), inputs=inputs) time = sim.solution["Time [h]"](sim.solution.t) - self.assertEqual(len(time), 10) - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() + assert len(time) == 10 From 78c24106f45f9f04787856ed1908710ce210356c Mon Sep 17 00:00:00 2001 From: Pradyot Ranjan <99216956+prady0t@users.noreply.github.com> Date: Tue, 1 Oct 2024 18:53:55 +0530 Subject: [PATCH 2/6] Removing all instances of unittest (#4472) * Removing all instances of unittest Signed-off-by: Pradyot Ranjan <99216956+pradyotRanjan@users.noreply.github.com> * style: pre-commit fixes * Adding ruff rules Signed-off-by: Pradyot Ranjan <99216956+pradyotRanjan@users.noreply.github.com> --------- Signed-off-by: Pradyot Ranjan <99216956+pradyotRanjan@users.noreply.github.com> Co-authored-by: Pradyot Ranjan <99216956+pradyotRanjan@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- pyproject.toml | 2 + .../test_lithium_ion/test_compare_outputs.py | 12 +--- .../test_interface/test_butler_volmer.py | 57 +++++++------------ tests/testcase.py | 17 ------ .../test_concatenations.py | 7 +-- .../test_expression_tree/test_functions.py | 5 +- .../test_input_parameter.py | 5 +- .../test_operations/test_evaluate_python.py | 11 ---- .../test_parameter_sets/test_Ecker2015.py | 18 ++---- 9 files changed, 33 insertions(+), 101 deletions(-) delete mode 100644 tests/testcase.py diff --git a/pyproject.toml b/pyproject.toml index d2c487ede4..328388bed7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -196,6 +196,8 @@ extend-select = [ "YTT", # flake8-2020 "TID252", # relative-imports "S101", # to identify use of assert statement + "PT027", # remove unittest style assertion + "PT009", # Use pytest.raises instead of unittest-style ] ignore = [ "E741", # Ambiguous variable name diff --git a/tests/integration/test_models/test_full_battery_models/test_lithium_ion/test_compare_outputs.py b/tests/integration/test_models/test_full_battery_models/test_lithium_ion/test_compare_outputs.py index 57067d6e3b..8ae7393f83 100644 --- a/tests/integration/test_models/test_full_battery_models/test_lithium_ion/test_compare_outputs.py +++ b/tests/integration/test_models/test_full_battery_models/test_lithium_ion/test_compare_outputs.py @@ -4,11 +4,10 @@ import pybamm import numpy as np -import unittest from tests import StandardOutputComparison -class TestCompareOutputs(unittest.TestCase): +class TestCompareOutputs: def test_compare_outputs_surface_form(self): # load models options = [ @@ -142,12 +141,3 @@ def test_compare_narrow_size_distribution(self): # compare outputs comparison = 
StandardOutputComparison(solutions) comparison.test_all(skip_first_timestep=True) - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - unittest.main() diff --git a/tests/integration/test_models/test_submodels/test_interface/test_butler_volmer.py b/tests/integration/test_models/test_submodels/test_interface/test_butler_volmer.py index f1f02350cd..7b2100792e 100644 --- a/tests/integration/test_models/test_submodels/test_interface/test_butler_volmer.py +++ b/tests/integration/test_models/test_submodels/test_interface/test_butler_volmer.py @@ -2,15 +2,15 @@ # Tests for the electrode-electrolyte interface equations # +import pytest import pybamm from tests import get_discretisation_for_testing -import unittest import numpy as np -class TestButlerVolmer(unittest.TestCase): - def setUp(self): +class TestButlerVolmer: + def setup_method(self): self.delta_phi_s_n = pybamm.Variable( "surface potential difference [V]", ["negative electrode"], @@ -73,7 +73,7 @@ def setUp(self): "reaction source terms [A.m-3]": 1, } - def tearDown(self): + def teardown_method(self): del self.variables del self.c_e_n del self.c_e_p @@ -114,12 +114,12 @@ def test_creation(self): ] # negative electrode Butler-Volmer is Multiplication - self.assertIsInstance(j_n, pybamm.Multiplication) - self.assertEqual(j_n.domain, ["negative electrode"]) + assert isinstance(j_n, pybamm.Multiplication) + assert j_n.domain == ["negative electrode"] # positive electrode Butler-Volmer is Multiplication - self.assertIsInstance(j_p, pybamm.Multiplication) - self.assertEqual(j_p.domain, ["positive electrode"]) + assert isinstance(j_p, pybamm.Multiplication) + assert j_p.domain == ["positive electrode"] def test_set_parameters(self): param = pybamm.LithiumIonParameters() @@ -159,9 +159,9 @@ def test_set_parameters(self): j_p = parameter_values.process_symbol(j_p) # Test for x in j_n.pre_order(): - self.assertNotIsInstance(x, pybamm.Parameter) + assert not isinstance(x, pybamm.Parameter) for x in j_p.pre_order(): - self.assertNotIsInstance(x, pybamm.Parameter) + assert not isinstance(x, pybamm.Parameter) def test_discretisation(self): param = pybamm.LithiumIonParameters() @@ -219,17 +219,13 @@ def test_discretisation(self): [mesh["negative electrode"].nodes, mesh["positive electrode"].nodes] ) y = np.concatenate([submesh**2, submesh**3, submesh**4]) - self.assertEqual( - j_n.evaluate(None, y).shape, (mesh["negative electrode"].npts, 1) - ) - self.assertEqual( - j_p.evaluate(None, y).shape, (mesh["positive electrode"].npts, 1) - ) + assert j_n.evaluate(None, y).shape == (mesh["negative electrode"].npts, 1) + assert j_p.evaluate(None, y).shape == (mesh["positive electrode"].npts, 1) # test concatenated butler-volmer whole_cell = ["negative electrode", "separator", "positive electrode"] whole_cell_mesh = disc.mesh[whole_cell] - self.assertEqual(j.evaluate(None, y).shape, (whole_cell_mesh.npts, 1)) + assert j.evaluate(None, y).shape == (whole_cell_mesh.npts, 1) def test_diff_c_e_lead_acid(self): # With intercalation @@ -361,27 +357,12 @@ def j_p(delta_phi): j_n_FD = parameter_values.process_symbol( (j_n(delta_phi + h) - j_n(delta_phi - h)) / (2 * h) ) - self.assertAlmostEqual( - j_n_diff.evaluate(inputs={"delta_phi": 0.5}) - / j_n_FD.evaluate(inputs={"delta_phi": 0.5}), - 1, - places=5, - ) + assert j_n_diff.evaluate(inputs={"delta_phi": 0.5}) / j_n_FD.evaluate( + inputs={"delta_phi": 0.5} + ) == pytest.approx(1, abs=1e-05) j_p_FD = parameter_values.process_symbol( 
(j_p(delta_phi + h) - j_p(delta_phi - h)) / (2 * h) ) - self.assertAlmostEqual( - j_p_diff.evaluate(inputs={"delta_phi": 0.5}) - / j_p_FD.evaluate(inputs={"delta_phi": 0.5}), - 1, - places=5, - ) - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - unittest.main() + assert j_p_diff.evaluate(inputs={"delta_phi": 0.5}) / j_p_FD.evaluate( + inputs={"delta_phi": 0.5} + ) == pytest.approx(1, abs=1e-05) diff --git a/tests/testcase.py b/tests/testcase.py deleted file mode 100644 index 0b9b1f5dee..0000000000 --- a/tests/testcase.py +++ /dev/null @@ -1,17 +0,0 @@ -# -# Custom TestCase class for pybamm -# -import unittest - - -class TestCase(unittest.TestCase): - """ - Custom TestCase class for PyBaMM - TO BE REMOVED - """ - - def assertDomainEqual(self, a, b): - "Check that two domains are equal, ignoring empty domains" - a_dict = {k: v for k, v in a.items() if v != []} - b_dict = {k: v for k, v in b.items() if v != []} - self.assertEqual(a_dict, b_dict) diff --git a/tests/unit/test_expression_tree/test_concatenations.py b/tests/unit/test_expression_tree/test_concatenations.py index 1d7ccef610..9d653a4d51 100644 --- a/tests/unit/test_expression_tree/test_concatenations.py +++ b/tests/unit/test_expression_tree/test_concatenations.py @@ -2,7 +2,6 @@ # Tests for the Concatenation class and subclasses # import pytest -import unittest.mock as mock from tests import assert_domain_equal @@ -377,7 +376,7 @@ def test_to_equation(self): # Test concat_sym assert pybamm.Concatenation(a, b).to_equation() == func_symbol - def test_to_from_json(self): + def test_to_from_json(self, mocker): # test DomainConcatenation mesh = get_mesh_for_testing() a = pybamm.Symbol("a", domain=["negative electrode"]) @@ -386,7 +385,7 @@ def test_to_from_json(self): json_dict = { "name": "domain_concatenation", - "id": mock.ANY, + "id": mocker.ANY, "domains": { "primary": ["negative electrode", "separator", "positive electrode"], "secondary": [], @@ -429,7 +428,7 @@ def test_to_from_json(self): np_json = { "name": "numpy_concatenation", - "id": mock.ANY, + "id": mocker.ANY, "domains": { "primary": [], "secondary": [], diff --git a/tests/unit/test_expression_tree/test_functions.py b/tests/unit/test_expression_tree/test_functions.py index 5f5324c0ae..fca7d7ee67 100644 --- a/tests/unit/test_expression_tree/test_functions.py +++ b/tests/unit/test_expression_tree/test_functions.py @@ -3,7 +3,6 @@ # import pytest -import unittest.mock as mock import numpy as np from scipy import special @@ -399,7 +398,7 @@ def test_tanh(self): abs=1e-05, ) - def test_erf(self): + def test_erf(self, mocker): a = pybamm.InputParameter("a") fun = pybamm.erf(a) assert fun.evaluate(inputs={"a": 3}) == special.erf(3) @@ -416,7 +415,7 @@ def test_erf(self): # test creation from json input_json = { "name": "erf", - "id": mock.ANY, + "id": mocker.ANY, "function": "erf", "children": [a], } diff --git a/tests/unit/test_expression_tree/test_input_parameter.py b/tests/unit/test_expression_tree/test_input_parameter.py index 87cbe79a31..884341cb4f 100644 --- a/tests/unit/test_expression_tree/test_input_parameter.py +++ b/tests/unit/test_expression_tree/test_input_parameter.py @@ -4,7 +4,6 @@ import numpy as np import pybamm import pytest -import unittest.mock as mock class TestInputParameter: @@ -49,12 +48,12 @@ def test_errors(self): with pytest.raises(KeyError): a.evaluate() - def test_to_from_json(self): + def test_to_from_json(self, mocker): a = pybamm.InputParameter("a") json_dict = { "name": 
"a", - "id": mock.ANY, + "id": mocker.ANY, "domain": [], "expected_size": 1, } diff --git a/tests/unit/test_expression_tree/test_operations/test_evaluate_python.py b/tests/unit/test_expression_tree/test_operations/test_evaluate_python.py index e6d8a0da83..14b980b358 100644 --- a/tests/unit/test_expression_tree/test_operations/test_evaluate_python.py +++ b/tests/unit/test_expression_tree/test_operations/test_evaluate_python.py @@ -6,7 +6,6 @@ import pybamm from tests import get_discretisation_for_testing, get_1p1d_discretisation_for_testing -import unittest import numpy as np import scipy.sparse from collections import OrderedDict @@ -746,13 +745,3 @@ def test_jax_coo_matrix(self): with pytest.raises(NotImplementedError): A.multiply(v) - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() diff --git a/tests/unit/test_parameters/test_parameter_sets/test_Ecker2015.py b/tests/unit/test_parameters/test_parameter_sets/test_Ecker2015.py index 4be67175d7..d703b46200 100644 --- a/tests/unit/test_parameters/test_parameter_sets/test_Ecker2015.py +++ b/tests/unit/test_parameters/test_parameter_sets/test_Ecker2015.py @@ -2,11 +2,11 @@ # Tests for O'Kane (2022) parameter set # +import pytest import pybamm -import unittest -class TestEcker2015(unittest.TestCase): +class TestEcker2015: def test_functions(self): param = pybamm.ParameterValues("Ecker2015") sto = pybamm.Scalar(0.5) @@ -40,16 +40,6 @@ def test_functions(self): } for name, value in fun_test.items(): - self.assertAlmostEqual( - param.evaluate(param[name](*value[0])), value[1], places=4 + assert param.evaluate(param[name](*value[0])) == pytest.approx( + value[1], abs=0.0001 ) - - -if __name__ == "__main__": - print("Add -v for more debug output") - import sys - - if "-v" in sys.argv: - debug = True - pybamm.settings.debug_mode = True - unittest.main() From 591315406c1342b0b3015b1a0eaee4d475ecdbbb Mon Sep 17 00:00:00 2001 From: "Eric G. Kratz" Date: Tue, 1 Oct 2024 09:56:30 -0400 Subject: [PATCH 3/6] Require setuptools for citations (#4478) * Require setuptools for citations * Update pyproject.toml --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 328388bed7..a70f712e3e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -84,6 +84,7 @@ plot = [ "matplotlib>=3.6.0", ] cite = [ + "setuptools", # Fix for a pybtex issue "pybtex>=0.24.0", ] # Battery Parameter eXchange format From 5cf68bdbdb93191a42f7a4ec2b79debcb3faa73e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 1 Oct 2024 15:02:04 +0100 Subject: [PATCH 4/6] chore: update pre-commit hooks (#4476) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit updates: - [github.com/astral-sh/ruff-pre-commit: v0.6.7 → v0.6.8](https://github.com/astral-sh/ruff-pre-commit/compare/v0.6.7...v0.6.8) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Eric G. 
Kratz --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 306118e254..51e4d7c23e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -4,7 +4,7 @@ ci: repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: "v0.6.7" + rev: "v0.6.8" hooks: - id: ruff args: [--fix, --show-fixes] From 08e5cf24f2c3f60afbba851b3b22e514aac80c27 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 1 Oct 2024 10:20:26 -0400 Subject: [PATCH 5/6] Build(deps): bump github/codeql-action in the actions group (#4474) Bumps the actions group with 1 update: [github/codeql-action](https://github.com/github/codeql-action). Updates `github/codeql-action` from 3.26.8 to 3.26.10 - [Release notes](https://github.com/github/codeql-action/releases) - [Changelog](https://github.com/github/codeql-action/blob/main/CHANGELOG.md) - [Commits](https://github.com/github/codeql-action/compare/294a9d92911152fe08befb9ec03e240add280cb3...e2b3eafc8d227b0241d48be5f425d47c2d750a13) --- updated-dependencies: - dependency-name: github/codeql-action dependency-type: direct:production update-type: version-update:semver-patch dependency-group: actions ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/scorecard.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/scorecard.yml b/.github/workflows/scorecard.yml index 3186f0e49b..017919c4e7 100644 --- a/.github/workflows/scorecard.yml +++ b/.github/workflows/scorecard.yml @@ -68,6 +68,6 @@ jobs: # Upload the results to GitHub's code scanning dashboard (optional). # Commenting out will disable upload of results to your repo's Code Scanning dashboard - name: "Upload to code-scanning" - uses: github/codeql-action/upload-sarif@294a9d92911152fe08befb9ec03e240add280cb3 # v3.26.8 + uses: github/codeql-action/upload-sarif@e2b3eafc8d227b0241d48be5f425d47c2d750a13 # v3.26.10 with: sarif_file: results.sarif From 444ecc13295dee5b8e4cb2813d9220d5724b94b3 Mon Sep 17 00:00:00 2001 From: Marc Berliner <34451391+MarcBerliner@users.noreply.github.com> Date: Wed, 2 Oct 2024 11:54:40 -0400 Subject: [PATCH 6/6] Fast Hermite interpolation and observables (#4464) * fast observables * fix release and types * good * faster interp * `double` -> `realtype` * clean separation * good again * private members * cleanup * fix codecov, tests * naming * codecov * codacy, cse/expand * fix `try/except` * Update CHANGELOG.md * address comments * initialize `save_hermite` * fix codecov --------- Co-authored-by: Eric G. 
Kratz --- CHANGELOG.md | 1 + CMakeLists.txt | 2 + setup.py | 2 + src/pybamm/__init__.py | 2 +- src/pybamm/plotting/quick_plot.py | 17 +- src/pybamm/solvers/c_solvers/idaklu.cpp | 26 + .../c_solvers/idaklu/IDAKLUSolverOpenMP.hpp | 43 +- .../c_solvers/idaklu/IDAKLUSolverOpenMP.inl | 224 ++++- .../solvers/c_solvers/idaklu/Options.cpp | 1 + .../solvers/c_solvers/idaklu/Options.hpp | 1 + .../solvers/c_solvers/idaklu/Solution.hpp | 6 +- .../solvers/c_solvers/idaklu/SolutionData.cpp | 34 +- .../solvers/c_solvers/idaklu/SolutionData.hpp | 9 + .../solvers/c_solvers/idaklu/common.hpp | 1 + .../solvers/c_solvers/idaklu/observe.cpp | 343 +++++++ .../solvers/c_solvers/idaklu/observe.hpp | 42 + src/pybamm/solvers/idaklu_solver.py | 50 +- src/pybamm/solvers/processed_variable.py | 914 ++++++++++++++---- .../solvers/processed_variable_computed.py | 60 +- src/pybamm/solvers/solution.py | 181 +++- .../test_models/standard_output_comparison.py | 1 + .../test_simulation_with_experiment.py | 8 +- tests/unit/test_solvers/test_idaklu_solver.py | 58 +- .../test_solvers/test_processed_variable.py | 580 +++++++---- .../test_processed_variable_computed.py | 19 +- tests/unit/test_solvers/test_solution.py | 6 +- 26 files changed, 2021 insertions(+), 610 deletions(-) create mode 100644 src/pybamm/solvers/c_solvers/idaklu/observe.cpp create mode 100644 src/pybamm/solvers/c_solvers/idaklu/observe.hpp diff --git a/CHANGELOG.md b/CHANGELOG.md index f70aae0272..fd590fbc12 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,7 @@ ## Features +- Added Hermite interpolation to the (`IDAKLUSolver`) that improves the accuracy and performance of post-processing variables. ([#4464](https://github.com/pybamm-team/PyBaMM/pull/4464)) - Added sensitivity calculation support for `pybamm.Simulation` and `pybamm.Experiment` ([#4415](https://github.com/pybamm-team/PyBaMM/pull/4415)) - Added OpenMP parallelization to IDAKLU solver for lists of input parameters ([#4449](https://github.com/pybamm-team/PyBaMM/pull/4449)) - Added phase-dependent particle options to LAM diff --git a/CMakeLists.txt b/CMakeLists.txt index a7f68ce7a0..ec594e5ca5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -105,6 +105,8 @@ pybind11_add_module(idaklu src/pybamm/solvers/c_solvers/idaklu/Expressions/Base/Expression.hpp src/pybamm/solvers/c_solvers/idaklu/Expressions/Base/ExpressionSet.hpp src/pybamm/solvers/c_solvers/idaklu/Expressions/Base/ExpressionTypes.hpp + src/pybamm/solvers/c_solvers/idaklu/observe.hpp + src/pybamm/solvers/c_solvers/idaklu/observe.cpp # IDAKLU expressions - concrete implementations ${IDAKLU_EXPR_CASADI_SOURCE_FILES} ${IDAKLU_EXPR_IREE_SOURCE_FILES} diff --git a/setup.py b/setup.py index 74de1baca4..8a49bfd715 100644 --- a/setup.py +++ b/setup.py @@ -327,6 +327,8 @@ def compile_KLU(): "src/pybamm/solvers/c_solvers/idaklu/Solution.hpp", "src/pybamm/solvers/c_solvers/idaklu/Options.hpp", "src/pybamm/solvers/c_solvers/idaklu/Options.cpp", + "src/pybamm/solvers/c_solvers/idaklu/observe.hpp", + "src/pybamm/solvers/c_solvers/idaklu/observe.cpp", "src/pybamm/solvers/c_solvers/idaklu.cpp", ], ) diff --git a/src/pybamm/__init__.py b/src/pybamm/__init__.py index 36ad0b137a..51c7f49969 100644 --- a/src/pybamm/__init__.py +++ b/src/pybamm/__init__.py @@ -157,7 +157,7 @@ # Solver classes from .solvers.solution import Solution, EmptySolution, make_cycle_solution -from .solvers.processed_variable import ProcessedVariable +from .solvers.processed_variable import ProcessedVariable, process_variable from .solvers.processed_variable_computed import 
ProcessedVariableComputed from .solvers.base_solver import BaseSolver from .solvers.dummy_solver import DummySolver diff --git a/src/pybamm/plotting/quick_plot.py b/src/pybamm/plotting/quick_plot.py index 39dc974f9b..cddce58d77 100644 --- a/src/pybamm/plotting/quick_plot.py +++ b/src/pybamm/plotting/quick_plot.py @@ -419,14 +419,14 @@ def reset_axis(self): spatial_vars = self.spatial_variable_dict[key] var_min = np.min( [ - ax_min(var(self.ts_seconds[i], **spatial_vars, warn=False)) + ax_min(var(self.ts_seconds[i], **spatial_vars)) for i, variable_list in enumerate(variable_lists) for var in variable_list ] ) var_max = np.max( [ - ax_max(var(self.ts_seconds[i], **spatial_vars, warn=False)) + ax_max(var(self.ts_seconds[i], **spatial_vars)) for i, variable_list in enumerate(variable_lists) for var in variable_list ] @@ -512,7 +512,7 @@ def plot(self, t, dynamic=False): full_t = self.ts_seconds[i] (self.plots[key][i][j],) = ax.plot( full_t / self.time_scaling_factor, - variable(full_t, warn=False), + variable(full_t), color=self.colors[i], linestyle=linestyle, ) @@ -548,7 +548,7 @@ def plot(self, t, dynamic=False): linestyle = self.linestyles[j] (self.plots[key][i][j],) = ax.plot( self.first_spatial_variable[key], - variable(t_in_seconds, **spatial_vars, warn=False), + variable(t_in_seconds, **spatial_vars), color=self.colors[i], linestyle=linestyle, zorder=10, @@ -570,13 +570,13 @@ def plot(self, t, dynamic=False): y_name = next(iter(spatial_vars.keys()))[0] x = self.second_spatial_variable[key] y = self.first_spatial_variable[key] - var = variable(t_in_seconds, **spatial_vars, warn=False) + var = variable(t_in_seconds, **spatial_vars) else: x_name = next(iter(spatial_vars.keys()))[0] y_name = list(spatial_vars.keys())[1][0] x = self.first_spatial_variable[key] y = self.second_spatial_variable[key] - var = variable(t_in_seconds, **spatial_vars, warn=False).T + var = variable(t_in_seconds, **spatial_vars).T ax.set_xlabel(f"{x_name} [{self.spatial_unit}]") ax.set_ylabel(f"{y_name} [{self.spatial_unit}]") vmin, vmax = self.variable_limits[key] @@ -710,7 +710,6 @@ def slider_update(self, t): var = variable( time_in_seconds, **self.spatial_variable_dict[key], - warn=False, ) plot[i][j].set_ydata(var) var_min = min(var_min, ax_min(var)) @@ -729,11 +728,11 @@ def slider_update(self, t): if self.x_first_and_y_second[key] is False: x = self.second_spatial_variable[key] y = self.first_spatial_variable[key] - var = variable(time_in_seconds, **spatial_vars, warn=False) + var = variable(time_in_seconds, **spatial_vars) else: x = self.first_spatial_variable[key] y = self.second_spatial_variable[key] - var = variable(time_in_seconds, **spatial_vars, warn=False).T + var = variable(time_in_seconds, **spatial_vars).T # store the plot and the var data (for testing) as cant access # z data from QuadMesh or QuadContourSet object if self.is_y_z[key] is True: diff --git a/src/pybamm/solvers/c_solvers/idaklu.cpp b/src/pybamm/solvers/c_solvers/idaklu.cpp index db7147feb2..82a3cbe91c 100644 --- a/src/pybamm/solvers/c_solvers/idaklu.cpp +++ b/src/pybamm/solvers/c_solvers/idaklu.cpp @@ -9,6 +9,7 @@ #include #include "idaklu/idaklu_solver.hpp" +#include "idaklu/observe.hpp" #include "idaklu/IDAKLUSolverGroup.hpp" #include "idaklu/IdakluJax.hpp" #include "idaklu/common.hpp" @@ -27,6 +28,7 @@ casadi::Function generate_casadi_function(const std::string &data) namespace py = pybind11; PYBIND11_MAKE_OPAQUE(std::vector); +PYBIND11_MAKE_OPAQUE(std::vector); PYBIND11_MAKE_OPAQUE(std::vector); PYBIND11_MODULE(idaklu, m) @@ -34,6 
+36,7 @@ PYBIND11_MODULE(idaklu, m) m.doc() = "sundials solvers"; // optional module docstring py::bind_vector>(m, "VectorNdArray"); + py::bind_vector>(m, "VectorRealtypeNdArray"); py::bind_vector>(m, "VectorSolution"); py::class_(m, "IDAKLUSolverGroup") @@ -72,6 +75,27 @@ PYBIND11_MODULE(idaklu, m) py::arg("options"), py::return_value_policy::take_ownership); + m.def("observe", &observe, + "Observe variables", + py::arg("ts"), + py::arg("ys"), + py::arg("inputs"), + py::arg("funcs"), + py::arg("is_f_contiguous"), + py::arg("shape"), + py::return_value_policy::take_ownership); + + m.def("observe_hermite_interp", &observe_hermite_interp, + "Observe and Hermite interpolate variables", + py::arg("t_interp"), + py::arg("ts"), + py::arg("ys"), + py::arg("yps"), + py::arg("inputs"), + py::arg("funcs"), + py::arg("shape"), + py::return_value_policy::take_ownership); + #ifdef IREE_ENABLE m.def("create_iree_solver_group", &create_idaklu_solver_group, "Create a group of iree idaklu solver objects", @@ -167,7 +191,9 @@ PYBIND11_MODULE(idaklu, m) py::class_(m, "solution") .def_readwrite("t", &Solution::t) .def_readwrite("y", &Solution::y) + .def_readwrite("yp", &Solution::yp) .def_readwrite("yS", &Solution::yS) + .def_readwrite("ypS", &Solution::ypS) .def_readwrite("y_term", &Solution::y_term) .def_readwrite("flag", &Solution::flag); } diff --git a/src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolverOpenMP.hpp b/src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolverOpenMP.hpp index 92eede3643..ee2c03abff 100644 --- a/src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolverOpenMP.hpp +++ b/src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolverOpenMP.hpp @@ -54,9 +54,9 @@ class IDAKLUSolverOpenMP : public IDAKLUSolver int const number_of_events; // cppcheck-suppress unusedStructMember int number_of_timesteps; int precon_type; // cppcheck-suppress unusedStructMember - N_Vector yy, yp, y_cache, avtol; // y, y', y cache vector, and absolute tolerance + N_Vector yy, yyp, y_cache, avtol; // y, y', y cache vector, and absolute tolerance N_Vector *yyS; // cppcheck-suppress unusedStructMember - N_Vector *ypS; // cppcheck-suppress unusedStructMember + N_Vector *yypS; // cppcheck-suppress unusedStructMember N_Vector id; // rhs_alg_id realtype rtol; int const jac_times_cjmass_nnz; // cppcheck-suppress unusedStructMember @@ -70,11 +70,14 @@ class IDAKLUSolverOpenMP : public IDAKLUSolver vector res_dvar_dp; bool const sensitivity; // cppcheck-suppress unusedStructMember bool const save_outputs_only; // cppcheck-suppress unusedStructMember + bool save_hermite; // cppcheck-suppress unusedStructMember bool is_ODE; // cppcheck-suppress unusedStructMember int length_of_return_vector; // cppcheck-suppress unusedStructMember vector t; // cppcheck-suppress unusedStructMember vector> y; // cppcheck-suppress unusedStructMember + vector> yp; // cppcheck-suppress unusedStructMember vector>> yS; // cppcheck-suppress unusedStructMember + vector>> ypS; // cppcheck-suppress unusedStructMember SetupOptions const setup_opts; SolverOptions const solver_opts; @@ -144,6 +147,11 @@ class IDAKLUSolverOpenMP : public IDAKLUSolver */ void InitializeStorage(int const N); + /** + * @brief Initialize the storage for Hermite interpolation + */ + void InitializeHermiteStorage(int const N); + /** * @brief Apply user-configurable IDA options */ @@ -190,13 +198,20 @@ class IDAKLUSolverOpenMP : public IDAKLUSolver */ void ExtendAdaptiveArrays(); + /** + * @brief Extend the Hermite interpolation info by 1 + */ + void ExtendHermiteArrays(); + /** * @brief Set the step values 
*/ void SetStep( - realtype &t_val, + realtype &tval, realtype *y_val, + realtype *yp_val, vector const &yS_val, + vector const &ypS_val, int &i_save ); @@ -211,7 +226,9 @@ class IDAKLUSolverOpenMP : public IDAKLUSolver realtype &t_prev, realtype const &t_next, realtype *y_val, + realtype *yp_val, vector const &yS_val, + vector const &ypS_val, int &i_save ); @@ -255,6 +272,26 @@ class IDAKLUSolverOpenMP : public IDAKLUSolver int &i_save ); + /** + * @brief Save the output function results at the requested time + */ + void SetStepHermite( + realtype &t_val, + realtype *yp_val, + const vector &ypS_val, + int &i_save + ); + + /** + * @brief Save the output function sensitivities at the requested time + */ + void SetStepHermiteSensitivities( + realtype &t_val, + realtype *yp_val, + const vector &ypS_val, + int &i_save + ); + }; #include "IDAKLUSolverOpenMP.inl" diff --git a/src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolverOpenMP.inl b/src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolverOpenMP.inl index 56f546facf..d128ae1809 100644 --- a/src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolverOpenMP.inl +++ b/src/pybamm/solvers/c_solvers/idaklu/IDAKLUSolverOpenMP.inl @@ -1,7 +1,6 @@ #include "Expressions/Expressions.hpp" #include "sundials_functions.hpp" #include - #include "common.hpp" #include "SolutionData.hpp" @@ -48,7 +47,7 @@ IDAKLUSolverOpenMP::IDAKLUSolverOpenMP( AllocateVectors(); if (sensitivity) { yyS = N_VCloneVectorArray(number_of_parameters, yy); - ypS = N_VCloneVectorArray(number_of_parameters, yp); + yypS = N_VCloneVectorArray(number_of_parameters, yyp); } // set initial values realtype *atval = N_VGetArrayPointer(avtol); @@ -58,14 +57,14 @@ IDAKLUSolverOpenMP::IDAKLUSolverOpenMP( for (int is = 0; is < number_of_parameters; is++) { N_VConst(RCONST(0.0), yyS[is]); - N_VConst(RCONST(0.0), ypS[is]); + N_VConst(RCONST(0.0), yypS[is]); } // create Matrix objects SetMatrix(); // initialise solver - IDAInit(ida_mem, residual_eval, 0, yy, yp); + IDAInit(ida_mem, residual_eval, 0, yy, yyp); // set tolerances rtol = RCONST(rel_tol); @@ -87,6 +86,9 @@ IDAKLUSolverOpenMP::IDAKLUSolverOpenMP( // The default is to solve a DAE for generality. 
This may be changed // to an ODE during the Initialize() call is_ODE = false; + + // Will be overwritten during the solve() call + save_hermite = solver_opts.hermite_interpolation; } template @@ -95,14 +97,14 @@ void IDAKLUSolverOpenMP::AllocateVectors() { // Create vectors if (setup_opts.num_threads == 1) { yy = N_VNew_Serial(number_of_states, sunctx); - yp = N_VNew_Serial(number_of_states, sunctx); + yyp = N_VNew_Serial(number_of_states, sunctx); y_cache = N_VNew_Serial(number_of_states, sunctx); avtol = N_VNew_Serial(number_of_states, sunctx); id = N_VNew_Serial(number_of_states, sunctx); } else { DEBUG("IDAKLUSolverOpenMP::AllocateVectors OpenMP"); yy = N_VNew_OpenMP(number_of_states, setup_opts.num_threads, sunctx); - yp = N_VNew_OpenMP(number_of_states, setup_opts.num_threads, sunctx); + yyp = N_VNew_OpenMP(number_of_states, setup_opts.num_threads, sunctx); y_cache = N_VNew_OpenMP(number_of_states, setup_opts.num_threads, sunctx); avtol = N_VNew_OpenMP(number_of_states, setup_opts.num_threads, sunctx); id = N_VNew_OpenMP(number_of_states, setup_opts.num_threads, sunctx); @@ -127,6 +129,26 @@ void IDAKLUSolverOpenMP::InitializeStorage(int const N) { vector(length_of_return_vector, 0.0) ) ); + + if (save_hermite) { + InitializeHermiteStorage(N); + } +} + +template +void IDAKLUSolverOpenMP::InitializeHermiteStorage(int const N) { + yp = vector>( + N, + vector(number_of_states, 0.0) + ); + + ypS = vector>>( + N, + vector>( + number_of_parameters, + vector(number_of_states, 0.0) + ) + ); } template @@ -285,7 +307,7 @@ void IDAKLUSolverOpenMP::Initialize() { if (sensitivity) { CheckErrors(IDASensInit(ida_mem, number_of_parameters, IDA_SIMULTANEOUS, - sensitivities_eval, yyS, ypS)); + sensitivities_eval, yyS, yypS)); CheckErrors(IDASensEEtolerances(ida_mem)); } @@ -321,13 +343,13 @@ IDAKLUSolverOpenMP::~IDAKLUSolverOpenMP() { SUNMatDestroy(J); N_VDestroy(avtol); N_VDestroy(yy); - N_VDestroy(yp); + N_VDestroy(yyp); N_VDestroy(y_cache); N_VDestroy(id); if (sensitivity) { N_VDestroyVectorArray(yyS, number_of_parameters); - N_VDestroyVectorArray(ypS, number_of_parameters); + N_VDestroyVectorArray(yypS, number_of_parameters); } IDAFree(&ida_mem); @@ -349,9 +371,15 @@ SolutionData IDAKLUSolverOpenMP::solve( const int number_of_evals = t_eval.size(); const int number_of_interps = t_interp.size(); - if (t.size() < number_of_evals + number_of_interps) { - InitializeStorage(number_of_evals + number_of_interps); - } + // Hermite interpolation is only available when saving + // 1. adaptive steps and 2. 
the full solution + save_hermite = ( + solver_opts.hermite_interpolation && + save_adaptive_steps && + !save_outputs_only + ); + + InitializeStorage(number_of_evals + number_of_interps); int i_save = 0; @@ -378,12 +406,12 @@ SolutionData IDAKLUSolverOpenMP::solve( // Setup consistent initialization realtype *y_val = N_VGetArrayPointer(yy); - realtype *yp_val = N_VGetArrayPointer(yp); + realtype *yp_val = N_VGetArrayPointer(yyp); vector yS_val(number_of_parameters); vector ypS_val(number_of_parameters); for (int p = 0 ; p < number_of_parameters; p++) { yS_val[p] = N_VGetArrayPointer(yyS[p]); - ypS_val[p] = N_VGetArrayPointer(ypS[p]); + ypS_val[p] = N_VGetArrayPointer(yypS[p]); for (int i = 0; i < number_of_states; i++) { yS_val[p][i] = y0[i + (p + 1) * number_of_states]; ypS_val[p][i] = yp0[i + (p + 1) * number_of_states]; @@ -409,23 +437,27 @@ SolutionData IDAKLUSolverOpenMP::solve( ConsistentInitialization(t0, t_eval_next, init_type); } + // Set the initial stop time + IDASetStopTime(ida_mem, t_eval_next); + + // Progress one step. This must be done before the while loop to ensure + // that we can run IDAGetDky at t0 for dky = 1 + int retval = IDASolve(ida_mem, tf, &t_val, yy, yyp, IDA_ONE_STEP); + + // Store consistent initialization + CheckErrors(IDAGetDky(ida_mem, t0, 0, yy)); if (sensitivity) { - CheckErrors(IDAGetSensDky(ida_mem, t_val, 0, yyS)); + CheckErrors(IDAGetSensDky(ida_mem, t0, 0, yyS)); } - // Store Consistent initialization - SetStep(t0, y_val, yS_val, i_save); + SetStep(t0, y_val, yp_val, yS_val, ypS_val, i_save); - // Set the initial stop time - IDASetStopTime(ida_mem, t_eval_next); + // Reset the states at t = t_val. Sensitivities are handled in the while-loop + CheckErrors(IDAGetDky(ida_mem, t_val, 0, yy)); // Solve the system - int retval; DEBUG("IDASolve"); while (true) { - // Progress one step - retval = IDASolve(ida_mem, tf, &t_val, yy, yp, IDA_ONE_STEP); - if (retval < 0) { // failed break; @@ -448,18 +480,21 @@ SolutionData IDAKLUSolverOpenMP::solve( if (hit_tinterp) { // Save the interpolated state at t_prev < t < t_val, for all t in t_interp - SetStepInterp(i_interp, + SetStepInterp( + i_interp, t_interp_next, t_interp, t_val, t_prev, t_eval_next, y_val, + yp_val, yS_val, + ypS_val, i_save); } - if (hit_adaptive || hit_teval || hit_event) { + if (hit_adaptive || hit_teval || hit_event || hit_final_time) { if (hit_tinterp) { // Reset the states and sensitivities at t = t_val CheckErrors(IDAGetDky(ida_mem, t_val, 0, yy)); @@ -469,11 +504,16 @@ SolutionData IDAKLUSolverOpenMP::solve( } // Save the current state at t_val - if (hit_adaptive) { - // Dynamically allocate memory for the adaptive step - ExtendAdaptiveArrays(); + // First, check to make sure that the t_val is not equal to the current t value + // If it is, we don't want to save the current state twice + if (!hit_tinterp || t_val != t.back()) { + if (hit_adaptive) { + // Dynamically allocate memory for the adaptive step + ExtendAdaptiveArrays(); + } + + SetStep(t_val, y_val, yp_val, yS_val, ypS_val, i_save); } - SetStep(t_val, y_val, yS_val, i_save); } if (hit_final_time || hit_event) { @@ -481,7 +521,7 @@ SolutionData IDAKLUSolverOpenMP::solve( break; } else if (hit_teval) { // Set the next stop time - i_eval += 1; + i_eval++; t_eval_next = t_eval[i_eval]; CheckErrors(IDASetStopTime(ida_mem, t_eval_next)); @@ -491,6 +531,9 @@ SolutionData IDAKLUSolverOpenMP::solve( } t_prev = t_val; + + // Progress one step + retval = IDASolve(ida_mem, tf, &t_val, yy, yyp, IDA_ONE_STEP); } int const 
length_of_final_sv_slice = save_outputs_only ? number_of_states : 0; @@ -547,7 +590,50 @@ SolutionData IDAKLUSolverOpenMP::solve( } } - return SolutionData(retval, number_of_timesteps, length_of_return_vector, arg_sens0, arg_sens1, arg_sens2, length_of_final_sv_slice, t_return, y_return, yS_return, yterm_return); + realtype *yp_return = new realtype[(save_hermite ? 1 : 0) * (number_of_timesteps * number_of_states)]; + realtype *ypS_return = new realtype[(save_hermite ? 1 : 0) * (arg_sens0 * arg_sens1 * arg_sens2)]; + if (save_hermite) { + count = 0; + for (size_t i = 0; i < number_of_timesteps; i++) { + for (size_t j = 0; j < number_of_states; j++) { + yp_return[count] = yp[i][j]; + count++; + } + } + + // Sensitivity states, ypS + // Note: Ordering of vector is different if computing outputs vs returning + // the complete state vector + count = 0; + for (size_t idx0 = 0; idx0 < arg_sens0; idx0++) { + for (size_t idx1 = 0; idx1 < arg_sens1; idx1++) { + for (size_t idx2 = 0; idx2 < arg_sens2; idx2++) { + auto i = (save_outputs_only ? idx0 : idx1); + auto j = (save_outputs_only ? idx1 : idx2); + auto k = (save_outputs_only ? idx2 : idx0); + + ypS_return[count] = ypS[i][k][j]; + count++; + } + } + } + } + + return SolutionData( + retval, + number_of_timesteps, + length_of_return_vector, + arg_sens0, + arg_sens1, + arg_sens2, + length_of_final_sv_slice, + save_hermite, + t_return, + y_return, + yp_return, + yS_return, + ypS_return, + yterm_return); } template @@ -563,14 +649,30 @@ void IDAKLUSolverOpenMP::ExtendAdaptiveArrays() { if (sensitivity) { yS.emplace_back(number_of_parameters, vector(length_of_return_vector, 0.0)); } + + if (save_hermite) { + ExtendHermiteArrays(); + } +} + +template +void IDAKLUSolverOpenMP::ExtendHermiteArrays() { + DEBUG("IDAKLUSolver::ExtendHermiteArrays"); + // States + yp.emplace_back(number_of_states, 0.0); + + // Sensitivity + if (sensitivity) { + ypS.emplace_back(number_of_parameters, vector(number_of_states, 0.0)); + } } template void IDAKLUSolverOpenMP::ReinitializeIntegrator(const realtype& t_val) { DEBUG("IDAKLUSolver::ReinitializeIntegrator"); - CheckErrors(IDAReInit(ida_mem, t_val, yy, yp)); + CheckErrors(IDAReInit(ida_mem, t_val, yy, yyp)); if (sensitivity) { - CheckErrors(IDASensReInit(ida_mem, IDA_SIMULTANEOUS, yyS, ypS)); + CheckErrors(IDASensReInit(ida_mem, IDA_SIMULTANEOUS, yyS, yypS)); } } @@ -609,14 +711,16 @@ void IDAKLUSolverOpenMP::ConsistentInitializationODE( realtype *y_cache_val = N_VGetArrayPointer(y_cache); std::memset(y_cache_val, 0, number_of_states * sizeof(realtype)); // Overwrite yp - residual_eval(t_val, yy, y_cache, yp, functions.get()); + residual_eval(t_val, yy, y_cache, yyp, functions.get()); } template void IDAKLUSolverOpenMP::SetStep( realtype &tval, realtype *y_val, + realtype *yp_val, vector const &yS_val, + vector const &ypS_val, int &i_save ) { // Set adaptive step results for y and yS @@ -629,6 +733,10 @@ void IDAKLUSolverOpenMP::SetStep( SetStepOutput(tval, y_val, yS_val, i_save); } else { SetStepFull(tval, y_val, yS_val, i_save); + + if (save_hermite) { + SetStepHermite(tval, yp_val, ypS_val, i_save); + } } i_save++; @@ -644,7 +752,9 @@ void IDAKLUSolverOpenMP::SetStepInterp( realtype &t_prev, realtype const &t_eval_next, realtype *y_val, + realtype *yp_val, vector const &yS_val, + vector const &ypS_val, int &i_save ) { // Save the state at the requested time @@ -657,7 +767,7 @@ void IDAKLUSolverOpenMP::SetStepInterp( } // Memory is already allocated for the interpolated values - SetStep(t_interp_next, y_val, yS_val, 
i_save); + SetStep(t_interp_next, y_val, yp_val, yS_val, ypS_val, i_save); i_interp++; if (i_interp == (t_interp.size())) { @@ -769,6 +879,50 @@ void IDAKLUSolverOpenMP::SetStepOutputSensitivities( } } +template +void IDAKLUSolverOpenMP::SetStepHermite( + realtype &tval, + realtype *yp_val, + vector const &ypS_val, + int &i_save +) { + // Set adaptive step results for yp and ypS + DEBUG("IDAKLUSolver::SetStepHermite"); + + // States + CheckErrors(IDAGetDky(ida_mem, tval, 1, yyp)); + auto &yp_back = yp[i_save]; + for (size_t j = 0; j < length_of_return_vector; ++j) { + yp_back[j] = yp_val[j]; + + } + + // Sensitivity + if (sensitivity) { + SetStepHermiteSensitivities(tval, yp_val, ypS_val, i_save); + } +} + +template +void IDAKLUSolverOpenMP::SetStepHermiteSensitivities( + realtype &tval, + realtype *yp_val, + vector const &ypS_val, + int &i_save +) { + DEBUG("IDAKLUSolver::SetStepHermiteSensitivities"); + + // Calculate sensitivities for the full ypS array + CheckErrors(IDAGetSensDky(ida_mem, tval, 1, yypS)); + for (size_t j = 0; j < number_of_parameters; ++j) { + auto &ypS_back_j = ypS[i_save][j]; + auto &ypSval_j = ypS_val[j]; + for (size_t k = 0; k < number_of_states; ++k) { + ypS_back_j[k] = ypSval_j[k]; + } + } +} + template void IDAKLUSolverOpenMP::CheckErrors(int const & flag) { if (flag < 0) { diff --git a/src/pybamm/solvers/c_solvers/idaklu/Options.cpp b/src/pybamm/solvers/c_solvers/idaklu/Options.cpp index 51544040ee..8eb605fe77 100644 --- a/src/pybamm/solvers/c_solvers/idaklu/Options.cpp +++ b/src/pybamm/solvers/c_solvers/idaklu/Options.cpp @@ -149,6 +149,7 @@ SolverOptions::SolverOptions(py::dict &py_opts) nonlinear_convergence_coefficient(RCONST(py_opts["nonlinear_convergence_coefficient"].cast())), nonlinear_convergence_coefficient_ic(RCONST(py_opts["nonlinear_convergence_coefficient_ic"].cast())), suppress_algebraic_error(py_opts["suppress_algebraic_error"].cast()), + hermite_interpolation(py_opts["hermite_interpolation"].cast()), // IDA initial conditions calculation calc_ic(py_opts["calc_ic"].cast()), init_all_y_ic(py_opts["init_all_y_ic"].cast()), diff --git a/src/pybamm/solvers/c_solvers/idaklu/Options.hpp b/src/pybamm/solvers/c_solvers/idaklu/Options.hpp index d0c0c1d766..7418c68ec3 100644 --- a/src/pybamm/solvers/c_solvers/idaklu/Options.hpp +++ b/src/pybamm/solvers/c_solvers/idaklu/Options.hpp @@ -38,6 +38,7 @@ struct SolverOptions { double nonlinear_convergence_coefficient; double nonlinear_convergence_coefficient_ic; sunbooleantype suppress_algebraic_error; + bool hermite_interpolation; // IDA initial conditions calculation bool calc_ic; bool init_all_y_ic; diff --git a/src/pybamm/solvers/c_solvers/idaklu/Solution.hpp b/src/pybamm/solvers/c_solvers/idaklu/Solution.hpp index a43e6a7174..8227bb9da8 100644 --- a/src/pybamm/solvers/c_solvers/idaklu/Solution.hpp +++ b/src/pybamm/solvers/c_solvers/idaklu/Solution.hpp @@ -17,8 +17,8 @@ class Solution /** * @brief Constructor */ - Solution(int &retval, np_array &t_np, np_array &y_np, np_array &yS_np, np_array &y_term_np) - : flag(retval), t(t_np), y(y_np), yS(yS_np), y_term(y_term_np) + Solution(int &retval, np_array &t_np, np_array &y_np, np_array &yp_np, np_array &yS_np, np_array &ypS_np, np_array &y_term_np) + : flag(retval), t(t_np), y(y_np), yp(yp_np), yS(yS_np), ypS(ypS_np), y_term(y_term_np) { } @@ -30,7 +30,9 @@ class Solution int flag; np_array t; np_array y; + np_array yp; np_array yS; + np_array ypS; np_array y_term; }; diff --git a/src/pybamm/solvers/c_solvers/idaklu/SolutionData.cpp 
b/src/pybamm/solvers/c_solvers/idaklu/SolutionData.cpp index 00c2ddbccc..bc48c646d3 100644 --- a/src/pybamm/solvers/c_solvers/idaklu/SolutionData.cpp +++ b/src/pybamm/solvers/c_solvers/idaklu/SolutionData.cpp @@ -29,6 +29,20 @@ Solution SolutionData::generate_solution() { free_y_when_done ); + py::capsule free_yp_when_done( + yp_return, + [](void *f) { + realtype *vect = reinterpret_cast(f); + delete[] vect; + } + ); + + np_array yp_ret = np_array( + (save_hermite ? 1 : 0) * number_of_timesteps * length_of_return_vector, + &yp_return[0], + free_yp_when_done + ); + py::capsule free_yS_when_done( yS_return, [](void *f) { @@ -47,6 +61,24 @@ Solution SolutionData::generate_solution() { free_yS_when_done ); + py::capsule free_ypS_when_done( + ypS_return, + [](void *f) { + realtype *vect = reinterpret_cast(f); + delete[] vect; + } + ); + + np_array ypS_ret = np_array( + std::vector { + (save_hermite ? 1 : 0) * arg_sens0, + arg_sens1, + arg_sens2 + }, + &ypS_return[0], + free_ypS_when_done + ); + // Final state slice, yterm py::capsule free_yterm_when_done( yterm_return, @@ -63,5 +95,5 @@ Solution SolutionData::generate_solution() { ); // Store the solution - return Solution(flag, t_ret, y_ret, yS_ret, y_term); + return Solution(flag, t_ret, y_ret, yp_ret, yS_ret, ypS_ret, y_term); } diff --git a/src/pybamm/solvers/c_solvers/idaklu/SolutionData.hpp b/src/pybamm/solvers/c_solvers/idaklu/SolutionData.hpp index 815e41daca..81ca7f5221 100644 --- a/src/pybamm/solvers/c_solvers/idaklu/SolutionData.hpp +++ b/src/pybamm/solvers/c_solvers/idaklu/SolutionData.hpp @@ -27,9 +27,12 @@ class SolutionData int arg_sens1, int arg_sens2, int length_of_final_sv_slice, + bool save_hermite, realtype *t_return, realtype *y_return, + realtype *yp_return, realtype *yS_return, + realtype *ypS_return, realtype *yterm_return): flag(flag), number_of_timesteps(number_of_timesteps), @@ -38,9 +41,12 @@ class SolutionData arg_sens1(arg_sens1), arg_sens2(arg_sens2), length_of_final_sv_slice(length_of_final_sv_slice), + save_hermite(save_hermite), t_return(t_return), y_return(y_return), + yp_return(yp_return), yS_return(yS_return), + ypS_return(ypS_return), yterm_return(yterm_return) {} @@ -64,9 +70,12 @@ class SolutionData int arg_sens1; int arg_sens2; int length_of_final_sv_slice; + bool save_hermite; realtype *t_return; realtype *y_return; + realtype *yp_return; realtype *yS_return; + realtype *ypS_return; realtype *yterm_return; }; diff --git a/src/pybamm/solvers/c_solvers/idaklu/common.hpp b/src/pybamm/solvers/c_solvers/idaklu/common.hpp index 58be90932e..90672080b6 100644 --- a/src/pybamm/solvers/c_solvers/idaklu/common.hpp +++ b/src/pybamm/solvers/c_solvers/idaklu/common.hpp @@ -33,6 +33,7 @@ namespace py = pybind11; // note: we rely on c_style ordering for numpy arrays so don't change this! 
 using np_array = py::array_t<realtype>;
+using np_array_realtype = py::array_t<realtype>;
 using np_array_int = py::array_t<int64_t>;
 /**
diff --git a/src/pybamm/solvers/c_solvers/idaklu/observe.cpp b/src/pybamm/solvers/c_solvers/idaklu/observe.cpp
new file mode 100644
index 0000000000..8f1d90e55d
--- /dev/null
+++ b/src/pybamm/solvers/c_solvers/idaklu/observe.cpp
@@ -0,0 +1,343 @@
+#include "observe.hpp"
+
+int _setup_len_spatial(const std::vector<int>& shape) {
+  // Calculate the product of all dimensions except the last (spatial dimensions)
+  int size_spatial = 1;
+  for (size_t i = 0; i < shape.size() - 1; ++i) {
+    size_spatial *= shape[i];
+  }
+
+  if (size_spatial == 0 || shape.back() == 0) {
+    throw std::invalid_argument("output array must have at least one element");
+  }
+
+  return size_spatial;
+}
+
+// Coupled observe and Hermite interpolation of variables
+class HermiteInterpolator {
+public:
+  HermiteInterpolator(const py::detail::unchecked_reference<realtype, 1>& t,
+                      const py::detail::unchecked_reference<realtype, 2>& y,
+                      const py::detail::unchecked_reference<realtype, 2>& yp)
+      : t(t), y(y), yp(yp) {}
+
+  void compute_knots(const size_t j, vector<realtype>& c, vector<realtype>& d) const {
+    // Called at the start of each interval
+    const realtype h_full = t(j + 1) - t(j);
+    const realtype inv_h = 1.0 / h_full;
+    const realtype inv_h2 = inv_h * inv_h;
+    const realtype inv_h3 = inv_h2 * inv_h;
+
+    for (size_t i = 0; i < y.shape(0); ++i) {
+      realtype y_ij = y(i, j);
+      realtype yp_ij = yp(i, j);
+      realtype y_ijp1 = y(i, j + 1);
+      realtype yp_ijp1 = yp(i, j + 1);
+
+      c[i] = 3.0 * (y_ijp1 - y_ij) * inv_h2 - (2.0 * yp_ij + yp_ijp1) * inv_h;
+      d[i] = 2.0 * (y_ij - y_ijp1) * inv_h3 + (yp_ij + yp_ijp1) * inv_h2;
+    }
+  }
+
+  void interpolate(vector<realtype>& entries,
+                   realtype t_interp,
+                   const size_t j,
+                   vector<realtype>& c,
+                   vector<realtype>& d) const {
+    // Must be called after compute_knots
+    const realtype h = t_interp - t(j);
+    const realtype h2 = h * h;
+    const realtype h3 = h2 * h;
+
+    for (size_t i = 0; i < entries.size(); ++i) {
+      realtype y_ij = y(i, j);
+      realtype yp_ij = yp(i, j);
+      entries[i] = y_ij + yp_ij * h + c[i] * h2 + d[i] * h3;
+    }
+  }
+
+private:
+  const py::detail::unchecked_reference<realtype, 1>& t;
+  const py::detail::unchecked_reference<realtype, 2>& y;
+  const py::detail::unchecked_reference<realtype, 2>& yp;
+};
+
+class TimeSeriesInterpolator {
+public:
+  TimeSeriesInterpolator(const np_array_realtype& _t_interp,
+                         const vector<np_array_realtype>& _ts_data,
+                         const vector<np_array_realtype>& _ys_data,
+                         const vector<np_array_realtype>& _yps_data,
+                         const vector<np_array_realtype>& _inputs,
+                         const vector<std::shared_ptr<const casadi::Function>>& _funcs,
+                         realtype* _entries,
+                         const int _size_spatial)
+      : t_interp_np(_t_interp), ts_data_np(_ts_data), ys_data_np(_ys_data),
+        yps_data_np(_yps_data), inputs_np(_inputs), funcs(_funcs),
+        entries(_entries), size_spatial(_size_spatial) {}
+
+  void process() {
+    auto t_interp = t_interp_np.unchecked<1>();
+    ssize_t i_interp = 0;
+    int i_entries = 0;
+    const ssize_t N_interp = t_interp.size();
+
+    // Main processing within bounds
+    process_within_bounds(i_interp, i_entries, t_interp, N_interp);
+
+    // Extrapolation for remaining points
+    if (i_interp < N_interp) {
+      extrapolate_remaining(i_interp, i_entries, t_interp, N_interp);
+    }
+  }
+
+  void process_within_bounds(
+    ssize_t& i_interp,
+    int& i_entries,
+    const py::detail::unchecked_reference<realtype, 1>& t_interp,
+    const ssize_t N_interp
+  ) {
+    for (size_t i = 0; i < ts_data_np.size(); i++) {
+      const auto& t_data = ts_data_np[i].unchecked<1>();
+      const realtype t_data_final = t_data(t_data.size() - 1);
+      realtype t_interp_next = t_interp(i_interp);
+      // Continue if the next interpolation point is beyond the final data point
+      if (t_interp_next > t_data_final) {
+        continue;
+      }
+
+      const auto& y_data = ys_data_np[i].unchecked<2>();
+      const auto& yp_data = yps_data_np[i].unchecked<2>();
+      const auto input = inputs_np[i].data();
+      const auto func = *funcs[i];
+
+      resize_arrays(y_data.shape(0), funcs[i]);
+      args[1] = y_buffer.data();
+      args[2] = input;
+
+      ssize_t j = 0;
+      ssize_t j_prev = -1;
+      const auto itp = HermiteInterpolator(t_data, y_data, yp_data);
+      while (t_interp_next <= t_data_final) {
+        for (; j < t_data.size() - 2; ++j) {
+          if (t_data(j) <= t_interp_next && t_interp_next <= t_data(j + 1)) {
+            break;
+          }
+        }
+
+        if (j != j_prev) {
+          // Compute c and d for the new interval
+          itp.compute_knots(j, c, d);
+        }
+
+        itp.interpolate(y_buffer, t_interp(i_interp), j, c, d);
+
+        args[0] = &t_interp(i_interp);
+        results[0] = &entries[i_entries];
+        func(args.data(), results.data(), iw.data(), w.data(), 0);
+
+        ++i_interp;
+        if (i_interp == N_interp) {
+          return;
+        }
+        t_interp_next = t_interp(i_interp);
+        i_entries += size_spatial;
+        j_prev = j;
+      }
+    }
+  }
+
+  void extrapolate_remaining(
+    ssize_t& i_interp,
+    int& i_entries,
+    const py::detail::unchecked_reference<realtype, 1>& t_interp,
+    const ssize_t N_interp
+  ) {
+    const auto& t_data = ts_data_np.back().unchecked<1>();
+    const auto& y_data = ys_data_np.back().unchecked<2>();
+    const auto& yp_data = yps_data_np.back().unchecked<2>();
+    const auto input = inputs_np.back().data();
+    const auto func = *funcs.back();
+    const ssize_t j = t_data.size() - 2;
+
+    resize_arrays(y_data.shape(0), funcs.back());
+    args[1] = y_buffer.data();
+    args[2] = input;
+
+    const auto itp = HermiteInterpolator(t_data, y_data, yp_data);
+    itp.compute_knots(j, c, d);
+
+    for (; i_interp < N_interp; ++i_interp) {
+      const realtype t_interp_next = t_interp(i_interp);
+      itp.interpolate(y_buffer, t_interp_next, j, c, d);
+
+      args[0] = &t_interp_next;
+      results[0] = &entries[i_entries];
+      func(args.data(), results.data(), iw.data(), w.data(), 0);
+
+      i_entries += size_spatial;
+    }
+  }
+
+  void resize_arrays(const int M, std::shared_ptr<const casadi::Function> func) {
+    args.resize(func->sz_arg());
+    results.resize(func->sz_res());
+    iw.resize(func->sz_iw());
+    w.resize(func->sz_w());
+    if (y_buffer.size() < M) {
+      y_buffer.resize(M);
+      c.resize(M);
+      d.resize(M);
+    }
+  }
+
+private:
+  const np_array_realtype& t_interp_np;
+  const vector<np_array_realtype>& ts_data_np;
+  const vector<np_array_realtype>& ys_data_np;
+  const vector<np_array_realtype>& yps_data_np;
+  const vector<np_array_realtype>& inputs_np;
+  const vector<std::shared_ptr<const casadi::Function>>& funcs;
+  realtype* entries;
+  const int size_spatial;
+  vector<realtype> c;
+  vector<realtype> d;
+  vector<realtype> y_buffer;
+  vector<const realtype*> args;
+  vector<realtype*> results;
+  vector<casadi_int> iw;
+  vector<realtype> w;
+};
+
+// Observe the raw data
+class TimeSeriesProcessor {
+public:
+  TimeSeriesProcessor(const vector<np_array_realtype>& _ts,
+                      const vector<np_array_realtype>& _ys,
+                      const vector<np_array_realtype>& _inputs,
+                      const vector<std::shared_ptr<const casadi::Function>>& _funcs,
+                      realtype* _entries,
+                      const bool _is_f_contiguous,
+                      const int _size_spatial)
+      : ts(_ts), ys(_ys), inputs(_inputs), funcs(_funcs),
+        entries(_entries), is_f_contiguous(_is_f_contiguous), size_spatial(_size_spatial) {}
+
+  void process() {
+    int i_entries = 0;
+    for (size_t i = 0; i < ts.size(); i++) {
+      const auto& t = ts[i].unchecked<1>();
+      const auto& y = ys[i].unchecked<2>();
+      const auto input = inputs[i].data();
+      const auto func = *funcs[i];
+
+      resize_arrays(y.shape(0), funcs[i]);
+      args[2] = input;
+
+      for (size_t j = 0; j < t.size(); j++) {
+        const realtype t_val = t(j);
+        const realtype* y_val = is_f_contiguous ? &y(0, j) : copy_to_buffer(y_buffer, y, j);
+
+        args[0] = &t_val;
+        args[1] = y_val;
+        results[0] = &entries[i_entries];
+
+        func(args.data(), results.data(), iw.data(), w.data(), 0);
+
+        i_entries += size_spatial;
+      }
+    }
+  }
+
+private:
+  const realtype* copy_to_buffer(
+    vector<realtype>& entries,
+    const py::detail::unchecked_reference<realtype, 2>& y,
+    size_t j) {
+    for (size_t i = 0; i < entries.size(); ++i) {
+      entries[i] = y(i, j);
+    }
+
+    return entries.data();
+  }
+
+  void resize_arrays(const int M, std::shared_ptr<const casadi::Function> func) {
+    args.resize(func->sz_arg());
+    results.resize(func->sz_res());
+    iw.resize(func->sz_iw());
+    w.resize(func->sz_w());
+    if (!is_f_contiguous && y_buffer.size() < M) {
+      y_buffer.resize(M);
+    }
+  }
+
+  const vector<np_array_realtype>& ts;
+  const vector<np_array_realtype>& ys;
+  const vector<np_array_realtype>& inputs;
+  const vector<std::shared_ptr<const casadi::Function>>& funcs;
+  realtype* entries;
+  const bool is_f_contiguous;
+  int size_spatial;
+  vector<realtype> y_buffer;
+  vector<const realtype*> args;
+  vector<realtype*> results;
+  vector<casadi_int> iw;
+  vector<realtype> w;
+};
+
+const np_array_realtype observe_hermite_interp(
+  const np_array_realtype& t_interp_np,
+  const vector<np_array_realtype>& ts_np,
+  const vector<np_array_realtype>& ys_np,
+  const vector<np_array_realtype>& yps_np,
+  const vector<np_array_realtype>& inputs_np,
+  const vector<std::string>& strings,
+  const vector<int>& shape
+) {
+  const int size_spatial = _setup_len_spatial(shape);
+  const auto& funcs = setup_casadi_funcs(strings);
+  py::array_t<realtype> out_array(shape);
+  auto entries = out_array.mutable_data();
+
+  TimeSeriesInterpolator(t_interp_np, ts_np, ys_np, yps_np, inputs_np, funcs, entries, size_spatial).process();
+
+  return out_array;
+}
+
+const np_array_realtype observe(
+  const vector<np_array_realtype>& ts_np,
+  const vector<np_array_realtype>& ys_np,
+  const vector<np_array_realtype>& inputs_np,
+  const vector<std::string>& strings,
+  const bool is_f_contiguous,
+  const vector<int>& shape
+) {
+  const int size_spatial = _setup_len_spatial(shape);
+  const auto& funcs = setup_casadi_funcs(strings);
+  py::array_t<realtype> out_array(shape);
+  auto entries = out_array.mutable_data();
+
+  TimeSeriesProcessor(ts_np, ys_np, inputs_np, funcs, entries, is_f_contiguous, size_spatial).process();
+
+  return out_array;
+}
+
+const vector<std::shared_ptr<const casadi::Function>> setup_casadi_funcs(const vector<std::string>& strings) {
+  std::unordered_map<std::string, std::shared_ptr<const casadi::Function>> function_cache;
+  vector<std::shared_ptr<const casadi::Function>> funcs(strings.size());
+
+  for (size_t i = 0; i < strings.size(); ++i) {
+    const std::string& str = strings[i];
+
+    // Check if function is already in the local cache
+    if (function_cache.find(str) == function_cache.end()) {
+      // If not in the cache, create a new casadi::Function::deserialize and store it
+      function_cache[str] = std::make_shared<const casadi::Function>(casadi::Function::deserialize(str));
+    }
+
+    // Retrieve the function from the cache as a shared pointer
+    funcs[i] = function_cache[str];
+  }
+
+  return funcs;
+}
diff --git a/src/pybamm/solvers/c_solvers/idaklu/observe.hpp b/src/pybamm/solvers/c_solvers/idaklu/observe.hpp
new file mode 100644
index 0000000000..e8dc432240
--- /dev/null
+++ b/src/pybamm/solvers/c_solvers/idaklu/observe.hpp
@@ -0,0 +1,42 @@
+#ifndef PYBAMM_CREATE_OBSERVE_HPP
+#define PYBAMM_CREATE_OBSERVE_HPP
+
+#include <casadi/casadi.hpp>
+#include <memory>
+#include <string>
+#include "common.hpp"
+#include <vector>
+#include <unordered_map>
+using std::vector;
+
+/**
+ * @brief Observe and Hermite interpolate ND variables
+ */
+const np_array_realtype observe_hermite_interp(
+  const np_array_realtype& t_interp,
+  const vector<np_array_realtype>& ts,
+  const vector<np_array_realtype>& ys,
+  const vector<np_array_realtype>& yps,
+  const vector<np_array_realtype>& inputs,
+  const vector<std::string>& strings,
+  const vector<int>& shape
+);
+
+
+/**
+ * @brief Observe ND variables
+ */
+const np_array_realtype observe(
+  const vector<np_array_realtype>& ts_np,
+  const vector<np_array_realtype>& ys_np,
+  const vector<np_array_realtype>& inputs_np,
+  const vector<std::string>& strings,
+  const bool is_f_contiguous,
+
const vector& shape +); + +const vector> setup_casadi_funcs(const vector& strings); + +int _setup_len_spatial(const vector& shape); + +#endif // PYBAMM_CREATE_OBSERVE_HPP diff --git a/src/pybamm/solvers/idaklu_solver.py b/src/pybamm/solvers/idaklu_solver.py index 2b2d852697..80eaffebf4 100644 --- a/src/pybamm/solvers/idaklu_solver.py +++ b/src/pybamm/solvers/idaklu_solver.py @@ -30,6 +30,7 @@ if idaklu_spec.loader: idaklu_spec.loader.exec_module(idaklu) except ImportError as e: # pragma: no cover + idaklu = None print(f"Error loading idaklu: {e}") idaklu_spec = None @@ -133,6 +134,10 @@ class IDAKLUSolver(pybamm.BaseSolver): "nonlinear_convergence_coefficient": 0.33, # Suppress algebraic variables from error test "suppress_algebraic_error": False, + # Store Hermite interpolation data for the solution. + # Note: this option is always disabled if output_variables are given + # or if t_interp values are specified + "hermite_interpolation": True, ## Initial conditions calculation # Positive constant in the Newton iteration convergence test within the # initial condition calculation @@ -201,6 +206,7 @@ def __init__( "max_convergence_failures": 100, "nonlinear_convergence_coefficient": 0.33, "suppress_algebraic_error": False, + "hermite_interpolation": True, "nonlinear_convergence_coefficient_ic": 0.0033, "max_num_steps_ic": 50, "max_num_jacobians_ic": 40, @@ -756,6 +762,16 @@ def _integrate(self, model, t_eval, inputs_list=None, t_interp=None): The times (in seconds) at which to interpolate the solution. Defaults to `None`, which returns the adaptive time-stepping times. """ + if not ( + model.convert_to_format == "casadi" + or ( + model.convert_to_format == "jax" + and self._options["jax_evaluator"] == "iree" + ) + ): # pragma: no cover + # Shouldn't ever reach this point + raise pybamm.SolverError("Unsupported IDAKLU solver configuration.") + inputs_list = inputs_list or [{}] # stack inputs so that they are a 2D array of shape (number_of_inputs, number_of_parameters) @@ -779,20 +795,13 @@ def _integrate(self, model, t_eval, inputs_list=None, t_interp=None): atol = self._check_atol_type(atol, y0full.size) timer = pybamm.Timer() - if model.convert_to_format == "casadi" or ( - model.convert_to_format == "jax" - and self._options["jax_evaluator"] == "iree" - ): - solns = self._setup["solver"].solve( - t_eval, - t_interp, - y0full, - ydot0full, - inputs, - ) - else: # pragma: no cover - # Shouldn't ever reach this point - raise pybamm.SolverError("Unsupported IDAKLU solver configuration.") + solns = self._setup["solver"].solve( + t_eval, + t_interp, + y0full, + ydot0full, + inputs, + ) integration_time = timer.time() return [ @@ -807,7 +816,8 @@ def _post_process_solution(self, sol, model, integration_time, inputs_dict): sensitivity_names = self._setup["sensitivity_names"] number_of_timesteps = sol.t.size number_of_states = model.len_rhs_and_alg - if self.output_variables: + save_outputs_only = self.output_variables + if save_outputs_only: # Substitute empty vectors for state vector 'y' y_out = np.zeros((number_of_timesteps * number_of_states, 0)) y_event = sol.y_term @@ -838,6 +848,11 @@ def _post_process_solution(self, sol, model, integration_time, inputs_dict): else: raise pybamm.SolverError(f"FAILURE {self._solver_flag(sol.flag)}") + if sol.yp.size > 0: + yp = sol.yp.reshape((number_of_timesteps, number_of_states)).T + else: + yp = None + newsol = pybamm.Solution( sol.t, np.transpose(y_out), @@ -847,10 +862,11 @@ def _post_process_solution(self, sol, model, integration_time, inputs_dict): 
np.transpose(y_event)[:, np.newaxis], termination, all_sensitivities=yS_out, + all_yps=yp, ) + newsol.integration_time = integration_time - if not self.output_variables: - # print((newsol.y).shape) + if not save_outputs_only: return newsol # Populate variables and sensititivies dictionaries directly diff --git a/src/pybamm/solvers/processed_variable.py b/src/pybamm/solvers/processed_variable.py index 2464466348..5cf928ca7f 100644 --- a/src/pybamm/solvers/processed_variable.py +++ b/src/pybamm/solvers/processed_variable.py @@ -6,6 +6,7 @@ import pybamm from scipy.integrate import cumulative_trapezoid import xarray as xr +import bisect class ProcessedVariable: @@ -23,14 +24,11 @@ class ProcessedVariable: Note that this can be any kind of node in the expression tree, not just a :class:`pybamm.Variable`. When evaluated, returns an array of size (m,n) - base_variable_casadis : list of :class:`casadi.Function` + base_variables_casadi : list of :class:`casadi.Function` A list of casadi functions. When evaluated, returns the same thing as `base_Variable.evaluate` (but more efficiently). solution : :class:`pybamm.Solution` The solution object to be used to create the processed variables - warn : bool, optional - Whether to raise warnings when trying to evaluate time and length scales. - Default is True. """ def __init__( @@ -38,7 +36,6 @@ def __init__( base_variables, base_variables_casadi, solution, - warn=True, cumtrapz_ic=None, ): self.base_variables = base_variables @@ -46,13 +43,13 @@ def __init__( self.all_ts = solution.all_ts self.all_ys = solution.all_ys + self.all_yps = solution.all_yps self.all_inputs = solution.all_inputs self.all_inputs_casadi = solution.all_inputs_casadi self.mesh = base_variables[0].mesh self.domain = base_variables[0].domain self.domains = base_variables[0].domains - self.warn = warn self.cumtrapz_ic = cumtrapz_ic # Process spatial variables @@ -75,46 +72,431 @@ def __init__( self.base_eval_shape = self.base_variables[0].shape self.base_eval_size = self.base_variables[0].size - # xr_data_array is initialized - self._xr_data_array = None + self._xr_array_raw = None + self._entries_raw = None + self._entries_for_interp_raw = None + self._coords_raw = None - # handle 2D (in space) finite element variables differently - if ( - self.mesh - and "current collector" in self.domain - and isinstance(self.mesh, pybamm.ScikitSubMesh2D) + def initialise(self): + if self.entries_raw_initialized: + return + + entries = self.observe_raw() + + t = self.t_pts + entries_for_interp, coords = self._interp_setup(entries, t) + + self._entries_raw = entries + self._entries_for_interp_raw = entries_for_interp + self._coords_raw = coords + + def observe_and_interp(self, t, fill_value): + """ + Interpolate the variable at the given time points and y values. + t must be a sorted array of time points. + """ + + entries = self._observe_hermite_cpp(t) + processed_entries = self._observe_postfix(entries, t) + + tf = self.t_pts[-1] + if t[-1] > tf and fill_value != "extrapolate": + # fill the rest + idx = np.searchsorted(t, tf, side="right") + processed_entries[..., idx:] = fill_value + + return processed_entries + + def observe_raw(self): + """ + Evaluate the base variable at the given time points and y values. + """ + t = self.t_pts + + # For small number of points, use Python + if pybamm.has_idaklu(): + entries = self._observe_raw_cpp() + else: + # Fallback method for when IDAKLU is not available. 
To be removed + # when the C++ code is migrated to a new repo + entries = self._observe_raw_python() # pragma: no cover + + return self._observe_postfix(entries, t) + + def _setup_cpp_inputs(self, t, full_range): + pybamm.logger.debug("Setting up C++ interpolation inputs") + + ts = self.all_ts + ys = self.all_ys + yps = self.all_yps + inputs = self.all_inputs_casadi + # Find the indices of the time points to observe + if full_range: + idxs = range(len(ts)) + else: + idxs = _find_ts_indices(ts, t) + + if isinstance(idxs, list): + # Extract the time points and inputs + ts = [ts[idx] for idx in idxs] + ys = [ys[idx] for idx in idxs] + if self.hermite_interpolation: + yps = [yps[idx] for idx in idxs] + inputs = [self.all_inputs_casadi[idx] for idx in idxs] + + is_f_contiguous = _is_f_contiguous(ys) + + ts = pybamm.solvers.idaklu_solver.idaklu.VectorRealtypeNdArray(ts) + ys = pybamm.solvers.idaklu_solver.idaklu.VectorRealtypeNdArray(ys) + if self.hermite_interpolation: + yps = pybamm.solvers.idaklu_solver.idaklu.VectorRealtypeNdArray(yps) + else: + yps = None + inputs = pybamm.solvers.idaklu_solver.idaklu.VectorRealtypeNdArray(inputs) + + # Generate the serialized C++ functions only once + funcs_unique = {} + funcs = [None] * len(idxs) + for i in range(len(idxs)): + vars = self.base_variables_casadi[idxs[i]] + if vars not in funcs_unique: + funcs_unique[vars] = vars.serialize() + funcs[i] = funcs_unique[vars] + + return ts, ys, yps, funcs, inputs, is_f_contiguous + + def _observe_hermite_cpp(self, t): + pybamm.logger.debug("Observing and Hermite interpolating the variable in C++") + + ts, ys, yps, funcs, inputs, _ = self._setup_cpp_inputs(t, full_range=False) + shapes = self._shape(t) + return pybamm.solvers.idaklu_solver.idaklu.observe_hermite_interp( + t, ts, ys, yps, inputs, funcs, shapes + ) + + def _observe_raw_cpp(self): + pybamm.logger.debug("Observing the variable raw data in C++") + t = self.t_pts + ts, ys, _, funcs, inputs, is_f_contiguous = self._setup_cpp_inputs( + t, full_range=True + ) + shapes = self._shape(self.t_pts) + + return pybamm.solvers.idaklu_solver.idaklu.observe( + ts, ys, inputs, funcs, is_f_contiguous, shapes + ) + + def _observe_raw_python(self): + raise NotImplementedError # pragma: no cover + + def _observe_postfix(self, entries, t): + return entries + + def _interp_setup(self, entries, t): + raise NotImplementedError # pragma: no cover + + def _shape(self, t): + raise NotImplementedError # pragma: no cover + + def _process_spatial_variable_names(self, spatial_variable): + if len(spatial_variable) == 0: + return None + + # Extract names + raw_names = [] + for var in spatial_variable: + # Ignore tabs in domain names + if var == "tabs": + continue + if isinstance(var, str): + raw_names.append(var) + else: + raw_names.append(var.name) + + # Rename battery variables to match PyBaMM convention + if all([var.startswith("r") for var in raw_names]): + return "r" + elif all([var.startswith("x") for var in raw_names]): + return "x" + elif all([var.startswith("R") for var in raw_names]): + return "R" + elif len(raw_names) == 1: + return raw_names[0] + else: + raise NotImplementedError( + f"Spatial variable name not recognized for {spatial_variable}" + ) + + def __call__( + self, + t=None, + x=None, + r=None, + y=None, + z=None, + R=None, + fill_value=np.nan, + ): + # Check to see if we are interpolating exactly onto the raw solution time points + t_observe, observe_raw = self._check_observe_raw(t) + + # Check if the time points are sorted and unique + is_sorted = observe_raw 
or _is_sorted(t_observe) + + # Sort them if not + if not is_sorted: + idxs_sort = np.argsort(t_observe) + t_observe = t_observe[idxs_sort] + + hermite_time_interp = ( + pybamm.has_idaklu() and self.hermite_interpolation and not observe_raw + ) + + if hermite_time_interp: + entries = self.observe_and_interp(t_observe, fill_value) + + spatial_interp = any(a is not None for a in [x, r, y, z, R]) + + xr_interp = spatial_interp or not hermite_time_interp + + if xr_interp: + if hermite_time_interp: + # Already interpolated in time + t = None + entries_for_interp, coords = self._interp_setup(entries, t_observe) + else: + self.initialise() + entries_for_interp, coords = ( + self._entries_for_interp_raw, + self._coords_raw, + ) + + processed_entries = self._xr_interpolate( + entries_for_interp, + coords, + observe_raw, + t, + x, + r, + y, + z, + R, + fill_value, + ) + else: + processed_entries = entries + + if not is_sorted: + idxs_unsort = np.zeros_like(idxs_sort) + idxs_unsort[idxs_sort] = np.arange(len(t_observe)) + + processed_entries = processed_entries[..., idxs_unsort] + + # Remove a singleton time dimension if we interpolate in time with hermite + if hermite_time_interp and t_observe.size == 1: + processed_entries = np.squeeze(processed_entries, axis=-1) + + return processed_entries + + def _xr_interpolate( + self, + entries_for_interp, + coords, + observe_raw, + t=None, + x=None, + r=None, + y=None, + z=None, + R=None, + fill_value=None, + ): + """ + Evaluate the variable at arbitrary *dimensional* t (and x, r, y, z and/or R), + using interpolation + """ + if observe_raw: + if not self.xr_array_raw_initialized: + self._xr_array_raw = xr.DataArray(entries_for_interp, coords=coords) + xr_data_array = self._xr_array_raw + else: + xr_data_array = xr.DataArray(entries_for_interp, coords=coords) + + kwargs = {"t": t, "x": x, "r": r, "y": y, "z": z, "R": R} + + # Remove any None arguments + kwargs = {key: value for key, value in kwargs.items() if value is not None} + + # Use xarray interpolation, return numpy array + out = xr_data_array.interp(**kwargs, kwargs={"fill_value": fill_value}).values + + return out + + def _check_observe_raw(self, t): + """ + Checks if the raw data should be observed exactly at the solution time points + + Args: + t (np.ndarray, list, None): time points to observe + + Returns: + t_observe (np.ndarray): time points to observe + observe_raw (bool): True if observing the raw data + """ + observe_raw = (t is None) or ( + np.asarray(t).size == len(self.t_pts) and np.all(t == self.t_pts) + ) + + if observe_raw: + t_observe = self.t_pts + elif not isinstance(t, np.ndarray): + if not isinstance(t, list): + t = [t] + t_observe = np.array(t) + else: + t_observe = t + + if t_observe[0] < self.t_pts[0]: + raise ValueError( + "The interpolation points must be greater than or equal to the initial solution time" + ) + + return t_observe, observe_raw + + @property + def entries(self): + """ + Returns the raw data entries of the processed variable. If the processed + variable has not been initialized (i.e. the entries have not been + calculated), then the processed variable is initialized first. 
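+
+        A minimal usage sketch, assuming a standard simulation has been
+        solved (the model and variable name below are arbitrary examples)::
+
+            import pybamm
+
+            sim = pybamm.Simulation(pybamm.lithium_ion.SPM())
+            solution = sim.solve([0, 3600])
+            var = solution["Voltage [V]"]
+            var.entries  # first access calls initialise() and caches the raw entries
+            var.entries  # later accesses reuse the cached array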
+ """ + self.initialise() + return self._entries_raw + + @property + def data(self): + """Same as entries, but different name""" + return self.entries + + @property + def entries_raw_initialized(self): + return self._entries_raw is not None + + @property + def xr_array_raw_initialized(self): + return self._xr_array_raw is not None + + @property + def sensitivities(self): + """ + Returns a dictionary of sensitivities for each input parameter. + The keys are the input parameters, and the value is a matrix of size + (n_x * n_t, n_p), where n_x is the number of states, n_t is the number of time + points, and n_p is the size of the input parameter + """ + # No sensitivities if there are no inputs + if len(self.all_inputs[0]) == 0: + return {} + # Otherwise initialise and return sensitivities + if self._sensitivities is None: + if self.all_solution_sensitivities: + self.initialise_sensitivity_explicit_forward() + else: + raise ValueError( + "Cannot compute sensitivities. The 'calculate_sensitivities' " + "argument of the solver.solve should be changed from 'None' to " + "allow sensitivities calculations. Check solver documentation for " + "details." + ) + return self._sensitivities + + def initialise_sensitivity_explicit_forward(self): + "Set up the sensitivity dictionary" + + all_S_var = [] + for ts, ys, inputs_stacked, inputs, base_variable, dy_dp in zip( + self.all_ts, + self.all_ys, + self.all_inputs_casadi, + self.all_inputs, + self.base_variables, + self.all_solution_sensitivities["all"], ): - return self.initialise_2D_scikit_fem() + # Set up symbolic variables + t_casadi = casadi.MX.sym("t") + y_casadi = casadi.MX.sym("y", ys.shape[0]) + p_casadi = { + name: casadi.MX.sym(name, value.shape[0]) + for name, value in inputs.items() + } - # check variable shape - if len(self.base_eval_shape) == 0 or self.base_eval_shape[0] == 1: - return self.initialise_0D() + p_casadi_stacked = casadi.vertcat(*[p for p in p_casadi.values()]) - n = self.mesh.npts - base_shape = self.base_eval_shape[0] - # Try some shapes that could make the variable a 1D variable - if base_shape in [n, n + 1]: - return self.initialise_1D() + # Convert variable to casadi format for differentiating + var_casadi = base_variable.to_casadi(t_casadi, y_casadi, inputs=p_casadi) + dvar_dy = casadi.jacobian(var_casadi, y_casadi) + dvar_dp = casadi.jacobian(var_casadi, p_casadi_stacked) - # Try some shapes that could make the variable a 2D variable - first_dim_nodes = self.mesh.nodes - first_dim_edges = self.mesh.edges - second_dim_pts = self.base_variables[0].secondary_mesh.nodes - if self.base_eval_size // len(second_dim_pts) in [ - len(first_dim_nodes), - len(first_dim_edges), - ]: - return self.initialise_2D() - - # Raise error for 3D variable - raise NotImplementedError( - f"Shape not recognized for {base_variables[0]}" - + "(note processing of 3D variables is not yet implemented)" + # Convert to functions and evaluate index-by-index + dvar_dy_func = casadi.Function( + "dvar_dy", [t_casadi, y_casadi, p_casadi_stacked], [dvar_dy] + ) + dvar_dp_func = casadi.Function( + "dvar_dp", [t_casadi, y_casadi, p_casadi_stacked], [dvar_dp] + ) + for idx, t in enumerate(ts): + u = ys[:, idx] + next_dvar_dy_eval = dvar_dy_func(t, u, inputs_stacked) + next_dvar_dp_eval = dvar_dp_func(t, u, inputs_stacked) + if idx == 0: + dvar_dy_eval = next_dvar_dy_eval + dvar_dp_eval = next_dvar_dp_eval + else: + dvar_dy_eval = casadi.diagcat(dvar_dy_eval, next_dvar_dy_eval) + dvar_dp_eval = casadi.vertcat(dvar_dp_eval, next_dvar_dp_eval) + + # Compute 
sensitivity + S_var = dvar_dy_eval @ dy_dp + dvar_dp_eval + all_S_var.append(S_var) + + S_var = casadi.vertcat(*all_S_var) + sensitivities = {"all": S_var} + + # Add the individual sensitivity + start = 0 + for name, inp in self.all_inputs[0].items(): + end = start + inp.shape[0] + sensitivities[name] = S_var[:, start:end] + start = end + + # Save attribute + self._sensitivities = sensitivities + + @property + def hermite_interpolation(self): + return self.all_yps is not None + + +class ProcessedVariable0D(ProcessedVariable): + def __init__( + self, + base_variables, + base_variables_casadi, + solution, + cumtrapz_ic=None, + ): + self.dimensions = 0 + super().__init__( + base_variables, + base_variables_casadi, + solution, + cumtrapz_ic=cumtrapz_ic, ) - def initialise_0D(self): + def _observe_raw_python(self): + pybamm.logger.debug("Observing the variable raw data in Python") # initialise empty array of the correct size - entries = np.empty(len(self.t_pts)) + entries = np.empty(self._shape(self.t_pts)) idx = 0 # Evaluate the base_variable index-by-index for ts, ys, inputs, base_var_casadi in zip( @@ -126,22 +508,67 @@ def initialise_0D(self): entries[idx] = float(base_var_casadi(t, y, inputs)) idx += 1 + return entries - if self.cumtrapz_ic is not None: - entries = cumulative_trapezoid( - entries, self.t_pts, initial=float(self.cumtrapz_ic) - ) + def _observe_postfix(self, entries, _): + if self.cumtrapz_ic is None: + return entries + + return cumulative_trapezoid( + entries, self.t_pts, initial=float(self.cumtrapz_ic) + ) + def _interp_setup(self, entries, t): # save attributes for interpolation - self.entries_for_interp = entries - self.coords_for_interp = {"t": self.t_pts} + entries_for_interp = entries + coords_for_interp = {"t": t} - self.entries = entries - self.dimensions = 0 + return entries_for_interp, coords_for_interp - def initialise_1D(self, fixed_t=False): - len_space = self.base_eval_shape[0] - entries = np.empty((len_space, len(self.t_pts))) + def _shape(self, t): + return [len(t)] + + +class ProcessedVariable1D(ProcessedVariable): + """ + An object that can be evaluated at arbitrary (scalars or vectors) t and x, and + returns the (interpolated) value of the base variable at that t and x. + + Parameters + ---------- + base_variables : list of :class:`pybamm.Symbol` + A list of base variables with a method `evaluate(t,y)`, each entry of which + returns the value of that variable for that particular sub-solution. + A Solution can be comprised of sub-solutions which are the solutions of + different models. + Note that this can be any kind of node in the expression tree, not + just a :class:`pybamm.Variable`. + When evaluated, returns an array of size (m,n) + base_variables_casadi : list of :class:`casadi.Function` + A list of casadi functions. When evaluated, returns the same thing as + `base_Variable.evaluate` (but more efficiently). 
+ solution : :class:`pybamm.Solution` + The solution object to be used to create the processed variables + """ + + def __init__( + self, + base_variables, + base_variables_casadi, + solution, + cumtrapz_ic=None, + ): + self.dimensions = 1 + super().__init__( + base_variables, + base_variables_casadi, + solution, + cumtrapz_ic=cumtrapz_ic, + ) + + def _observe_raw_python(self): + pybamm.logger.debug("Observing the variable raw data in Python") + entries = np.empty(self._shape(self.t_pts)) # Evaluate the base_variable index-by-index idx = 0 @@ -153,7 +580,9 @@ def initialise_1D(self, fixed_t=False): y = ys[:, inner_idx] entries[:, idx] = base_var_casadi(t, y, inputs).full()[:, 0] idx += 1 + return entries + def _interp_setup(self, entries, t): # Get node and edge values nodes = self.mesh.nodes edges = self.mesh.edges @@ -173,8 +602,6 @@ def initialise_1D(self, fixed_t=False): ) # assign attributes for reference (either x_sol or r_sol) - self.entries = entries - self.dimensions = 1 self.spatial_variable_names = { k: self._process_spatial_variable_names(v) for k, v in self.spatial_variables.items() @@ -189,26 +616,71 @@ def initialise_1D(self, fixed_t=False): self.first_dim_pts = edges # save attributes for interpolation - self.entries_for_interp = entries_for_interp - self.coords_for_interp = {self.first_dimension: pts_for_interp, "t": self.t_pts} + coords_for_interp = {self.first_dimension: pts_for_interp, "t": t} - def initialise_2D(self): - """ - Initialise a 2D object that depends on x and r, x and z, x and R, or R and r. - """ + return entries_for_interp, coords_for_interp + + def _shape(self, t): + t_size = len(t) + space_size = self.base_eval_shape[0] + return [space_size, t_size] + + +class ProcessedVariable2D(ProcessedVariable): + """ + An object that can be evaluated at arbitrary (scalars or vectors) t and x, and + returns the (interpolated) value of the base variable at that t and x. + + Parameters + ---------- + base_variables : list of :class:`pybamm.Symbol` + A list of base variables with a method `evaluate(t,y)`, each entry of which + returns the value of that variable for that particular sub-solution. + A Solution can be comprised of sub-solutions which are the solutions of + different models. + Note that this can be any kind of node in the expression tree, not + just a :class:`pybamm.Variable`. + When evaluated, returns an array of size (m,n) + base_variables_casadi : list of :class:`casadi.Function` + A list of casadi functions. When evaluated, returns the same thing as + `base_Variable.evaluate` (but more efficiently). 
+ solution : :class:`pybamm.Solution` + The solution object to be used to create the processed variables + """ + + def __init__( + self, + base_variables, + base_variables_casadi, + solution, + cumtrapz_ic=None, + ): + self.dimensions = 2 + super().__init__( + base_variables, + base_variables_casadi, + solution, + cumtrapz_ic=cumtrapz_ic, + ) first_dim_nodes = self.mesh.nodes first_dim_edges = self.mesh.edges second_dim_nodes = self.base_variables[0].secondary_mesh.nodes - second_dim_edges = self.base_variables[0].secondary_mesh.edges if self.base_eval_size // len(second_dim_nodes) == len(first_dim_nodes): first_dim_pts = first_dim_nodes elif self.base_eval_size // len(second_dim_nodes) == len(first_dim_edges): first_dim_pts = first_dim_edges second_dim_pts = second_dim_nodes - first_dim_size = len(first_dim_pts) - second_dim_size = len(second_dim_pts) - entries = np.empty((first_dim_size, second_dim_size, len(self.t_pts))) + self.first_dim_size = len(first_dim_pts) + self.second_dim_size = len(second_dim_pts) + + def _observe_raw_python(self): + """ + Initialise a 2D object that depends on x and r, x and z, x and R, or R and r. + """ + pybamm.logger.debug("Observing the variable raw data in Python") + first_dim_size, second_dim_size, t_size = self._shape(self.t_pts) + entries = np.empty((first_dim_size, second_dim_size, t_size)) # Evaluate the base_variable index-by-index idx = 0 @@ -224,6 +696,22 @@ def initialise_2D(self): order="F", ) idx += 1 + return entries + + def _interp_setup(self, entries, t): + """ + Initialise a 2D object that depends on x and r, x and z, x and R, or R and r. + """ + first_dim_nodes = self.mesh.nodes + first_dim_edges = self.mesh.edges + second_dim_nodes = self.base_variables[0].secondary_mesh.nodes + second_dim_edges = self.base_variables[0].secondary_mesh.edges + if self.base_eval_size // len(second_dim_nodes) == len(first_dim_nodes): + first_dim_pts = first_dim_nodes + elif self.base_eval_size // len(second_dim_nodes) == len(first_dim_edges): + first_dim_pts = first_dim_edges + + second_dim_pts = second_dim_nodes # add points outside first dimension domain for extrapolation to # boundaries @@ -281,8 +769,6 @@ def initialise_2D(self): self.second_dimension = self.spatial_variable_names["secondary"] # assign attributes for reference - self.entries = entries - self.dimensions = 2 first_dim_pts_for_interp = first_dim_pts second_dim_pts_for_interp = second_dim_pts @@ -291,38 +777,68 @@ def initialise_2D(self): self.second_dim_pts = second_dim_edges # save attributes for interpolation - self.entries_for_interp = entries_for_interp - self.coords_for_interp = { + coords_for_interp = { self.first_dimension: first_dim_pts_for_interp, self.second_dimension: second_dim_pts_for_interp, - "t": self.t_pts, + "t": t, } - def initialise_2D_scikit_fem(self): + return entries_for_interp, coords_for_interp + + def _shape(self, t): + first_dim_size = self.first_dim_size + second_dim_size = self.second_dim_size + t_size = len(t) + return [first_dim_size, second_dim_size, t_size] + + +class ProcessedVariable2DSciKitFEM(ProcessedVariable2D): + """ + An object that can be evaluated at arbitrary (scalars or vectors) t and x, and + returns the (interpolated) value of the base variable at that t and x. + + Parameters + ---------- + base_variables : list of :class:`pybamm.Symbol` + A list of base variables with a method `evaluate(t,y)`, each entry of which + returns the value of that variable for that particular sub-solution. 
+ A Solution can be comprised of sub-solutions which are the solutions of + different models. + Note that this can be any kind of node in the expression tree, not + just a :class:`pybamm.Variable`. + When evaluated, returns an array of size (m,n) + base_variables_casadi : list of :class:`casadi.Function` + A list of casadi functions. When evaluated, returns the same thing as + `base_Variable.evaluate` (but more efficiently). + solution : :class:`pybamm.Solution` + The solution object to be used to create the processed variables + """ + + def __init__( + self, + base_variables, + base_variables_casadi, + solution, + cumtrapz_ic=None, + ): + self.dimensions = 2 + super(ProcessedVariable2D, self).__init__( + base_variables, + base_variables_casadi, + solution, + cumtrapz_ic=cumtrapz_ic, + ) y_sol = self.mesh.edges["y"] - len_y = len(y_sol) z_sol = self.mesh.edges["z"] - len_z = len(z_sol) - entries = np.empty((len_y, len_z, len(self.t_pts))) - # Evaluate the base_variable index-by-index - idx = 0 - for ts, ys, inputs, base_var_casadi in zip( - self.all_ts, self.all_ys, self.all_inputs_casadi, self.base_variables_casadi - ): - for inner_idx, t in enumerate(ts): - t = ts[inner_idx] - y = ys[:, inner_idx] - entries[:, :, idx] = np.reshape( - base_var_casadi(t, y, inputs).full(), - [len_y, len_z], - order="C", - ) - idx += 1 + self.first_dim_size = len(y_sol) + self.second_dim_size = len(z_sol) + + def _interp_setup(self, entries, t): + y_sol = self.mesh.edges["y"] + z_sol = self.mesh.edges["z"] # assign attributes for reference - self.entries = entries - self.dimensions = 2 self.y_sol = y_sol self.z_sol = z_sol self.first_dimension = "y" @@ -331,148 +847,116 @@ def initialise_2D_scikit_fem(self): self.second_dim_pts = z_sol # save attributes for interpolation - self.entries_for_interp = entries - self.coords_for_interp = {"y": y_sol, "z": z_sol, "t": self.t_pts} + coords_for_interp = {"y": y_sol, "z": z_sol, "t": t} - def _process_spatial_variable_names(self, spatial_variable): - if len(spatial_variable) == 0: - return None + return entries, coords_for_interp - # Extract names - raw_names = [] - for var in spatial_variable: - # Ignore tabs in domain names - if var == "tabs": - continue - if isinstance(var, str): - raw_names.append(var) - else: - raw_names.append(var.name) - # Rename battery variables to match PyBaMM convention - if all([var.startswith("r") for var in raw_names]): - return "r" - elif all([var.startswith("x") for var in raw_names]): - return "x" - elif all([var.startswith("R") for var in raw_names]): - return "R" - elif len(raw_names) == 1: - return raw_names[0] - else: - raise NotImplementedError( - f"Spatial variable name not recognized for {spatial_variable}" - ) +def process_variable(base_variables, *args, **kwargs): + mesh = base_variables[0].mesh + domain = base_variables[0].domain - def _initialize_xr_data_array(self): - """ - Initialize the xarray DataArray for interpolation. We don't do this by - default as it has some overhead (~75 us) and sometimes we only need the entries - of the processed variable, not the xarray object for interpolation. 
- """ - entries = self.entries_for_interp - coords = self.coords_for_interp - self._xr_data_array = xr.DataArray(entries, coords=coords) + # Evaluate base variable at initial time + base_eval_shape = base_variables[0].shape + base_eval_size = base_variables[0].size - def __call__(self, t=None, x=None, r=None, y=None, z=None, R=None, warn=True): - """ - Evaluate the variable at arbitrary *dimensional* t (and x, r, y, z and/or R), - using interpolation - """ - if self._xr_data_array is None: - self._initialize_xr_data_array() - kwargs = {"t": t, "x": x, "r": r, "y": y, "z": z, "R": R} - # Remove any None arguments - kwargs = {key: value for key, value in kwargs.items() if value is not None} - # Use xarray interpolation, return numpy array - return self._xr_data_array.interp(**kwargs).values + # handle 2D (in space) finite element variables differently + if ( + mesh + and "current collector" in domain + and isinstance(mesh, pybamm.ScikitSubMesh2D) + ): + return ProcessedVariable2DSciKitFEM(base_variables, *args, **kwargs) + + # check variable shape + if len(base_eval_shape) == 0 or base_eval_shape[0] == 1: + return ProcessedVariable0D(base_variables, *args, **kwargs) + + n = mesh.npts + base_shape = base_eval_shape[0] + # Try some shapes that could make the variable a 1D variable + if base_shape in [n, n + 1]: + return ProcessedVariable1D(base_variables, *args, **kwargs) + + # Try some shapes that could make the variable a 2D variable + first_dim_nodes = mesh.nodes + first_dim_edges = mesh.edges + second_dim_pts = base_variables[0].secondary_mesh.nodes + if base_eval_size // len(second_dim_pts) in [ + len(first_dim_nodes), + len(first_dim_edges), + ]: + return ProcessedVariable2D(base_variables, *args, **kwargs) + + # Raise error for 3D variable + raise NotImplementedError( + f"Shape not recognized for {base_variables[0]}" + + "(note processing of 3D variables is not yet implemented)" + ) + + +def _is_f_contiguous(all_ys): + """ + Check if all the ys are f-contiguous in memory - @property - def data(self): - """Same as entries, but different name""" - return self.entries + Args: + all_ys (list of np.ndarray): list of all ys - @property - def sensitivities(self): - """ - Returns a dictionary of sensitivities for each input parameter. - The keys are the input parameters, and the value is a matrix of size - (n_x * n_t, n_p), where n_x is the number of states, n_t is the number of time - points, and n_p is the size of the input parameter - """ - # No sensitivities if there are no inputs - if len(self.all_inputs[0]) == 0: - return {} - # Otherwise initialise and return sensitivities - if self._sensitivities is None: - if self.all_solution_sensitivities: - self.initialise_sensitivity_explicit_forward() - else: - raise ValueError( - "Cannot compute sensitivities. The 'calculate_sensitivities' " - "argument of the solver.solve should be changed from 'None' to " - "allow sensitivities calculations. Check solver documentation for " - "details." 
- ) - return self._sensitivities + Returns: + bool: True if all ys are f-contiguous + """ - def initialise_sensitivity_explicit_forward(self): - "Set up the sensitivity dictionary" + return all(isinstance(y, np.ndarray) and y.data.f_contiguous for y in all_ys) - all_S_var = [] - for ts, ys, inputs_stacked, inputs, base_variable, dy_dp in zip( - self.all_ts, - self.all_ys, - self.all_inputs_casadi, - self.all_inputs, - self.base_variables, - self.all_solution_sensitivities["all"], - ): - # Set up symbolic variables - t_casadi = casadi.MX.sym("t") - y_casadi = casadi.MX.sym("y", ys.shape[0]) - p_casadi = { - name: casadi.MX.sym(name, value.shape[0]) - for name, value in inputs.items() - } - p_casadi_stacked = casadi.vertcat(*[p for p in p_casadi.values()]) +def _is_sorted(t): + """ + Check if an array is sorted - # Convert variable to casadi format for differentiating - var_casadi = base_variable.to_casadi(t_casadi, y_casadi, inputs=p_casadi) - dvar_dy = casadi.jacobian(var_casadi, y_casadi) - dvar_dp = casadi.jacobian(var_casadi, p_casadi_stacked) + Args: + t (np.ndarray): array to check - # Convert to functions and evaluate index-by-index - dvar_dy_func = casadi.Function( - "dvar_dy", [t_casadi, y_casadi, p_casadi_stacked], [dvar_dy] - ) - dvar_dp_func = casadi.Function( - "dvar_dp", [t_casadi, y_casadi, p_casadi_stacked], [dvar_dp] - ) - for idx, t in enumerate(ts): - u = ys[:, idx] - next_dvar_dy_eval = dvar_dy_func(t, u, inputs_stacked) - next_dvar_dp_eval = dvar_dp_func(t, u, inputs_stacked) - if idx == 0: - dvar_dy_eval = next_dvar_dy_eval - dvar_dp_eval = next_dvar_dp_eval - else: - dvar_dy_eval = casadi.diagcat(dvar_dy_eval, next_dvar_dy_eval) - dvar_dp_eval = casadi.vertcat(dvar_dp_eval, next_dvar_dp_eval) + Returns: + bool: True if array is sorted + """ + return np.all(t[:-1] <= t[1:]) - # Compute sensitivity - S_var = dvar_dy_eval @ dy_dp + dvar_dp_eval - all_S_var.append(S_var) - S_var = casadi.vertcat(*all_S_var) - sensitivities = {"all": S_var} +def _find_ts_indices(ts, t): + """ + Parameters: + - ts: A list of numpy arrays (each sorted) whose values are successively increasing. + - t: A sorted list or array of values to find within ts. - # Add the individual sensitivity - start = 0 - for name, inp in self.all_inputs[0].items(): - end = start + inp.shape[0] - sensitivities[name] = S_var[:, start:end] - start = end + Returns: + - indices: A list of indices from `ts` such that at least one value of `t` falls within ts[idx]. 
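+
+    A short worked sketch of the selection logic (the arrays below are
+    arbitrary illustrative values)::
+
+        ts = [np.array([0.0, 1.0, 2.0]), np.array([2.0, 3.0, 4.0]), np.array([4.0, 5.0, 6.0])]
+        _find_ts_indices(ts, np.array([2.5, 3.5]))  # -> [1]; only ts[1] overlaps [2.5, 3.5]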
+ """ - # Save attribute - self._sensitivities = sensitivities + indices = [] + + # Get the minimum and maximum values of the target values `t` + t_min, t_max = t[0], t[-1] + + # Step 1: Use binary search to find the range of `ts` arrays where t_min and t_max could lie + low_idx = bisect.bisect_left([ts_arr[-1] for ts_arr in ts], t_min) + high_idx = bisect.bisect_right([ts_arr[0] for ts_arr in ts], t_max) + + # Step 2: Iterate over the identified range + for idx in range(low_idx, high_idx): + ts_min, ts_max = ts[idx][0], ts[idx][-1] + + # Binary search within `t` to check if any value falls within [ts_min, ts_max] + i = bisect.bisect_left(t, ts_min) + if i < len(t) and t[i] <= ts_max: + # At least one value of t is within ts[idx] + indices.append(idx) + + # extrapolating + if (t[-1] > ts[-1][-1]) and (len(indices) == 0 or indices[-1] != len(ts) - 1): + indices.append(len(ts) - 1) + + if len(indices) == len(ts): + # All indices are included + return range(len(ts)) + + return indices diff --git a/src/pybamm/solvers/processed_variable_computed.py b/src/pybamm/solvers/processed_variable_computed.py index a717c8b0cb..4f0cccc8c3 100644 --- a/src/pybamm/solvers/processed_variable_computed.py +++ b/src/pybamm/solvers/processed_variable_computed.py @@ -27,7 +27,7 @@ class ProcessedVariableComputed: Note that this can be any kind of node in the expression tree, not just a :class:`pybamm.Variable`. When evaluated, returns an array of size (m,n) - base_variable_casadis : list of :class:`casadi.Function` + base_variables_casadi : list of :class:`casadi.Function` A list of casadi functions. When evaluated, returns the same thing as `base_Variable.evaluate` (but more efficiently). base_variable_data : list of :numpy:array @@ -45,7 +45,6 @@ def __init__( base_variables_casadi, base_variables_data, solution, - warn=True, cumtrapz_ic=None, ): self.base_variables = base_variables @@ -60,7 +59,6 @@ def __init__( self.mesh = base_variables[0].mesh self.domain = base_variables[0].domain self.domains = base_variables[0].domains - self.warn = warn self.cumtrapz_ic = cumtrapz_ic # Sensitivity starts off uninitialized, only set when called @@ -82,33 +80,35 @@ def __init__( and isinstance(self.mesh, pybamm.ScikitSubMesh2D) ): self.initialise_2D_scikit_fem() + return # check variable shape - else: - if len(self.base_eval_shape) == 0 or self.base_eval_shape[0] == 1: - self.initialise_0D() - else: - n = self.mesh.npts - base_shape = self.base_eval_shape[0] - # Try some shapes that could make the variable a 1D variable - if base_shape in [n, n + 1]: - self.initialise_1D() - else: - # Try some shapes that could make the variable a 2D variable - first_dim_nodes = self.mesh.nodes - first_dim_edges = self.mesh.edges - second_dim_pts = self.base_variables[0].secondary_mesh.nodes - if self.base_eval_size // len(second_dim_pts) in [ - len(first_dim_nodes), - len(first_dim_edges), - ]: - self.initialise_2D() - else: - # Raise error for 3D variable - raise NotImplementedError( - f"Shape not recognized for {base_variables[0]} " - + "(note processing of 3D variables is not yet implemented)" - ) + if len(self.base_eval_shape) == 0 or self.base_eval_shape[0] == 1: + self.initialise_0D() + return + + n = self.mesh.npts + base_shape = self.base_eval_shape[0] + # Try some shapes that could make the variable a 1D variable + if base_shape in [n, n + 1]: + self.initialise_1D() + return + + # Try some shapes that could make the variable a 2D variable + first_dim_nodes = self.mesh.nodes + first_dim_edges = self.mesh.edges + second_dim_pts = 
self.base_variables[0].secondary_mesh.nodes + if self.base_eval_size // len(second_dim_pts) not in [ + len(first_dim_nodes), + len(first_dim_edges), + ]: + # Raise error for 3D variable + raise NotImplementedError( + f"Shape not recognized for {base_variables[0]} " + + "(note processing of 3D variables is not yet implemented)" + ) + + self.initialise_2D() def add_sensitivity(self, param, data): # unroll from sparse representation into n-d matrix @@ -203,7 +203,7 @@ def initialise_0D(self): self.entries = entries self.dimensions = 0 - def initialise_1D(self, fixed_t=False): + def initialise_1D(self): entries = self.unroll_1D() # Get node and edge values @@ -422,7 +422,7 @@ def initialise_2D_scikit_fem(self): coords={"y": y_sol, "z": z_sol, "t": self.t_pts}, ) - def __call__(self, t=None, x=None, r=None, y=None, z=None, R=None, warn=True): + def __call__(self, t=None, x=None, r=None, y=None, z=None, R=None): """ Evaluate the variable at arbitrary *dimensional* t (and x, r, y, z and/or R), using interpolation diff --git a/src/pybamm/solvers/solution.py b/src/pybamm/solvers/solution.py index 74d9ce7baf..c884e79e34 100644 --- a/src/pybamm/solvers/solution.py +++ b/src/pybamm/solvers/solution.py @@ -75,6 +75,7 @@ def __init__( y_event=None, termination="final time", all_sensitivities=False, + all_yps=None, check_solution=True, ): if not isinstance(all_ts, list): @@ -88,6 +89,10 @@ def __init__( self._all_ys_and_sens = all_ys self._all_models = all_models + if (all_yps is not None) and not isinstance(all_yps, list): + all_yps = [all_yps] + self._all_yps = all_yps + # Set up inputs if not isinstance(all_inputs, list): all_inputs_copy = dict(all_inputs) @@ -128,7 +133,7 @@ def __init__( # initialize empty variables and data self._variables = pybamm.FuzzyDict() - self.data = pybamm.FuzzyDict() + self._data = pybamm.FuzzyDict() # Add self as sub-solution for compatibility with ProcessedVariable self._sub_solutions = [self] @@ -298,6 +303,13 @@ def y(self): return self._y + @property + def data(self): + for k, v in self._variables.items(): + if k not in self._data: + self._data[k] = v.data + return self._data + @property def sensitivities(self): """Values of the sensitivities. 
Returns a dict of param_name: np_array""" @@ -400,6 +412,14 @@ def all_models(self): def all_inputs_casadi(self): return [casadi.vertcat(*inp.values()) for inp in self.all_inputs] + @property + def all_yps(self): + return self._all_yps + + @property + def hermite_interpolation(self): + return self.all_yps is not None + @property def t_event(self): """Time at which the event happens""" @@ -434,6 +454,12 @@ def first_state(self): n_states = self.all_models[0].len_rhs_and_alg for key in self._all_sensitivities: sensitivities[key] = self._all_sensitivities[key][0][-n_states:, :] + + if self.all_yps is None: + all_yps = None + else: + all_yps = self.all_yps[0][:, :1] + new_sol = Solution( self.all_ts[0][:1], self.all_ys[0][:, :1], @@ -443,6 +469,7 @@ def first_state(self): None, "final time", all_sensitivities=sensitivities, + all_yps=all_yps, ) new_sol._all_inputs_casadi = self.all_inputs_casadi[:1] new_sol._sub_solutions = self.sub_solutions[:1] @@ -467,6 +494,12 @@ def last_state(self): n_states = self.all_models[-1].len_rhs_and_alg for key in self._all_sensitivities: sensitivities[key] = self._all_sensitivities[key][-1][-n_states:, :] + + if self.all_yps is None: + all_yps = None + else: + all_yps = self.all_yps[-1][:, -1:] + new_sol = Solution( self.all_ts[-1][-1:], self.all_ys[-1][:, -1:], @@ -476,6 +509,7 @@ def last_state(self): self.y_event, self.termination, all_sensitivities=sensitivities, + all_yps=all_yps, ) new_sol._all_inputs_casadi = self.all_inputs_casadi[-1:] new_sol._sub_solutions = self.sub_solutions[-1:] @@ -528,57 +562,61 @@ def update(self, variables): if isinstance(self._all_sensitivities, bool) and self._all_sensitivities: self.extract_explicit_sensitivities() - # Convert single entry to list + # Single variable if isinstance(variables, str): variables = [variables] + # Process - for key in variables: - cumtrapz_ic = None - pybamm.logger.debug(f"Post-processing {key}") - vars_pybamm = [model.variables_and_events[key] for model in self.all_models] - - # Iterate through all models, some may be in the list several times and - # therefore only get set up once - vars_casadi = [] - for i, (model, ys, inputs, var_pybamm) in enumerate( - zip(self.all_models, self.all_ys, self.all_inputs, vars_pybamm) + for variable in variables: + self._update_variable(variable) + + def _update_variable(self, variable): + cumtrapz_ic = None + pybamm.logger.debug(f"Post-processing {variable}") + vars_pybamm = [ + model.variables_and_events[variable] for model in self.all_models + ] + + # Iterate through all models, some may be in the list several times and + # therefore only get set up once + vars_casadi = [] + for i, (model, ys, inputs, var_pybamm) in enumerate( + zip(self.all_models, self.all_ys, self.all_inputs, vars_pybamm) + ): + if ys.size == 0 and var_pybamm.has_symbol_of_classes( + pybamm.expression_tree.state_vector.StateVector ): - if ys.size == 0 and var_pybamm.has_symbol_of_classes( - pybamm.expression_tree.state_vector.StateVector - ): - raise KeyError( - f"Cannot process variable '{key}' as it was not part of the " - "solve. Please re-run the solve with `output_variables` set to " - "include this variable." 
- ) - elif isinstance(var_pybamm, pybamm.ExplicitTimeIntegral): - cumtrapz_ic = var_pybamm.initial_condition - cumtrapz_ic = cumtrapz_ic.evaluate() - var_pybamm = var_pybamm.child - var_casadi = self.process_casadi_var( - var_pybamm, - inputs, - ys.shape, - ) - model._variables_casadi[key] = var_casadi - vars_pybamm[i] = var_pybamm - elif key in model._variables_casadi: - var_casadi = model._variables_casadi[key] - else: - var_casadi = self.process_casadi_var( - var_pybamm, - inputs, - ys.shape, - ) - model._variables_casadi[key] = var_casadi - vars_casadi.append(var_casadi) - var = pybamm.ProcessedVariable( - vars_pybamm, vars_casadi, self, cumtrapz_ic=cumtrapz_ic - ) + raise KeyError( + f"Cannot process variable '{variable}' as it was not part of the " + "solve. Please re-run the solve with `output_variables` set to " + "include this variable." + ) + elif isinstance(var_pybamm, pybamm.ExplicitTimeIntegral): + cumtrapz_ic = var_pybamm.initial_condition + cumtrapz_ic = cumtrapz_ic.evaluate() + var_pybamm = var_pybamm.child + var_casadi = self.process_casadi_var( + var_pybamm, + inputs, + ys.shape, + ) + model._variables_casadi[variable] = var_casadi + vars_pybamm[i] = var_pybamm + elif variable in model._variables_casadi: + var_casadi = model._variables_casadi[variable] + else: + var_casadi = self.process_casadi_var( + var_pybamm, + inputs, + ys.shape, + ) + model._variables_casadi[variable] = var_casadi + vars_casadi.append(var_casadi) + var = pybamm.process_variable( + vars_pybamm, vars_casadi, self, cumtrapz_ic=cumtrapz_ic + ) - # Save variable and data - self._variables[key] = var - self.data[key] = var.data + self._variables[variable] = var def process_casadi_var(self, var_pybamm, inputs, ys_shape): t_MX = casadi.MX.sym("t") @@ -588,8 +626,40 @@ def process_casadi_var(self, var_pybamm, inputs, ys_shape): } inputs_MX = casadi.vertcat(*[p for p in inputs_MX_dict.values()]) var_sym = var_pybamm.to_casadi(t_MX, y_MX, inputs=inputs_MX_dict) - var_casadi = casadi.Function("variable", [t_MX, y_MX, inputs_MX], [var_sym]) - return var_casadi + + opts = { + "cse": True, + "inputs_check": False, + "is_diff_in": [False, False, False], + "is_diff_out": [False], + "regularity_check": False, + "error_on_fail": False, + "enable_jacobian": False, + } + + # Casadi has a bug where it does not correctly handle arrays with + # zeros padded at the beginning or end. To avoid this, we add and + # subtract the same number to the variable to reinforce the + # variable bounds. This does not affect the answer + epsilon = 1.0 + var_sym = (var_sym - epsilon) + epsilon + + var_casadi = casadi.Function( + "variable", + [t_MX, y_MX, inputs_MX], + [var_sym], + opts, + ) + + # Some variables, like interpolants, cannot be expanded + try: + var_casadi_out = var_casadi.expand() + except RuntimeError as error: + if "'eval_sx' not defined for" not in str(error): + raise error # pragma: no cover + var_casadi_out = var_casadi + + return var_casadi_out def __getitem__(self, key): """Read a variable from the solution. Variables are created 'just in time', i.e. 
@@ -818,13 +888,23 @@ def __add__(self, other): return new_sol # Update list of sub-solutions + hermite_interpolation = ( + other.hermite_interpolation and self.hermite_interpolation + ) if other.all_ts[0][0] == self.all_ts[-1][-1]: # Skip first time step if it is repeated all_ts = self.all_ts + [other.all_ts[0][1:]] + other.all_ts[1:] all_ys = self.all_ys + [other.all_ys[0][:, 1:]] + other.all_ys[1:] + if hermite_interpolation: + all_yps = self.all_yps + [other.all_yps[0][:, 1:]] + other.all_yps[1:] else: all_ts = self.all_ts + other.all_ts all_ys = self.all_ys + other.all_ys + if hermite_interpolation: + all_yps = self.all_yps + other.all_yps + + if not hermite_interpolation: + all_yps = None # sensitivities can be: # - bool if not using sensitivities or using explicit sensitivities which still @@ -859,6 +939,7 @@ def __add__(self, other): other.y_event, other.termination, all_sensitivities=all_sensitivities, + all_yps=all_yps, ) new_sol.closest_event_idx = other.closest_event_idx @@ -891,6 +972,7 @@ def copy(self): self.y_event, self.termination, self._all_sensitivities, + self.all_yps, ) new_sol._all_inputs_casadi = self.all_inputs_casadi new_sol._sub_solutions = self.sub_solutions @@ -1001,6 +1083,7 @@ def make_cycle_solution( sum_sols.y_event, sum_sols.termination, sum_sols._all_sensitivities, + sum_sols.all_yps, ) cycle_solution._all_inputs_casadi = sum_sols.all_inputs_casadi cycle_solution._sub_solutions = sum_sols.sub_solutions diff --git a/tests/integration/test_models/standard_output_comparison.py b/tests/integration/test_models/standard_output_comparison.py index 4d4d16e5ca..6c56894314 100644 --- a/tests/integration/test_models/standard_output_comparison.py +++ b/tests/integration/test_models/standard_output_comparison.py @@ -66,6 +66,7 @@ def compare(self, var, atol=0, rtol=0.02): # Get variable for each model model_variables = [solution[var] for solution in self.solutions] var0 = model_variables[0] + var0.initialise() spatial_pts = {} if var0.dimensions >= 1: diff --git a/tests/unit/test_experiments/test_simulation_with_experiment.py b/tests/unit/test_experiments/test_simulation_with_experiment.py index 394d64b257..b455e72393 100644 --- a/tests/unit/test_experiments/test_simulation_with_experiment.py +++ b/tests/unit/test_experiments/test_simulation_with_experiment.py @@ -83,8 +83,12 @@ def test_run_experiment(self): assert len(sol.cycles) == 1 # Test outputs - np.testing.assert_array_equal(sol.cycles[0].steps[0]["C-rate"].data, 1 / 20) - np.testing.assert_array_equal(sol.cycles[0].steps[1]["Current [A]"].data, -1) + np.testing.assert_array_almost_equal( + sol.cycles[0].steps[0]["C-rate"].data, 1 / 20 + ) + np.testing.assert_array_almost_equal( + sol.cycles[0].steps[1]["Current [A]"].data, -1 + ) np.testing.assert_array_almost_equal( sol.cycles[0].steps[2]["Voltage [V]"].data, 4.1, decimal=5 ) diff --git a/tests/unit/test_solvers/test_idaklu_solver.py b/tests/unit/test_solvers/test_idaklu_solver.py index b049729ae3..39918d73a4 100644 --- a/tests/unit/test_solvers/test_idaklu_solver.py +++ b/tests/unit/test_solvers/test_idaklu_solver.py @@ -43,9 +43,8 @@ def test_ida_roberts_klu(self): ) # Test - t_eval = np.linspace(0, 3, 100) - t_interp = t_eval - solution = solver.solve(model, t_eval, t_interp=t_interp) + t_eval = [0, 3] + solution = solver.solve(model, t_eval) # test that final time is time of event # y = 0.1 t + y0 so y=0.2 when t=2 @@ -311,8 +310,7 @@ def test_sensitivities_initial_condition(self): options={"jax_evaluator": "iree"} if form == "iree" else {}, ) - t_interp 
= np.linspace(0, 3, 100) - t_eval = [t_interp[0], t_interp[-1]] + t_eval = [0, 3] a_value = 0.1 @@ -321,7 +319,6 @@ def test_sensitivities_initial_condition(self): t_eval, inputs={"a": a_value}, calculate_sensitivities=True, - t_interp=t_interp, ) np.testing.assert_array_almost_equal( @@ -638,7 +635,7 @@ def test_failures(self): solver = pybamm.IDAKLUSolver() t_eval = [0, 3] - with pytest.raises(pybamm.SolverError, match="FAILURE IDA"): + with pytest.raises(ValueError, match="std::exception"): solver.solve(model, t_eval) def test_dae_solver_algebraic_model(self): @@ -944,7 +941,7 @@ def construct_model(): # Mock a 1D current collector and initialise (none in the model) sol["x_s [m]"].domain = ["current collector"] - sol["x_s [m]"].initialise_1D() + sol["x_s [m]"].entries def test_with_output_variables_and_sensitivities(self): # Construct a model and solve for all variables, then test @@ -1040,7 +1037,7 @@ def test_with_output_variables_and_sensitivities(self): # Mock a 1D current collector and initialise (none in the model) sol["x_s [m]"].domain = ["current collector"] - sol["x_s [m]"].initialise_1D() + sol["x_s [m]"].entries def test_bad_jax_evaluator(self): model = pybamm.lithium_ion.DFN() @@ -1108,19 +1105,40 @@ def test_simulation_period(self): def test_interpolate_time_step_start_offset(self): model = pybamm.lithium_ion.SPM() - experiment = pybamm.Experiment( - [ - "Discharge at C/10 for 10 seconds", - "Charge at C/10 for 10 seconds", - ], - period="1 seconds", - ) + + def experiment_setup(period=None): + return pybamm.Experiment( + [ + "Discharge at C/10 for 10 seconds", + "Charge at C/10 for 10 seconds", + ], + period=period, + ) + + experiment_1s = experiment_setup(period="1 seconds") solver = pybamm.IDAKLUSolver() - sim = pybamm.Simulation(model, experiment=experiment, solver=solver) - sol = sim.solve() + sim_1s = pybamm.Simulation(model, experiment=experiment_1s, solver=solver) + sol_1s = sim_1s.solve() np.testing.assert_equal( - np.nextafter(sol.sub_solutions[0].t[-1], np.inf), - sol.sub_solutions[1].t[0], + np.nextafter(sol_1s.sub_solutions[0].t[-1], np.inf), + sol_1s.sub_solutions[1].t[0], + ) + + assert not sol_1s.hermite_interpolation + + experiment = experiment_setup(period=None) + sim = pybamm.Simulation(model, experiment=experiment, solver=solver) + sol = sim.solve(model) + + assert sol.hermite_interpolation + + rtol = solver.rtol + atol = solver.atol + np.testing.assert_allclose( + sol_1s["Voltage [V]"].data, + sol["Voltage [V]"](sol_1s.t), + rtol=rtol, + atol=atol, ) def test_python_idaklu_deprecation_errors(self): diff --git a/tests/unit/test_solvers/test_processed_variable.py b/tests/unit/test_solvers/test_processed_variable.py index 6cd456347d..04de88963d 100644 --- a/tests/unit/test_solvers/test_processed_variable.py +++ b/tests/unit/test_solvers/test_processed_variable.py @@ -8,6 +8,13 @@ import numpy as np import pytest +from scipy.interpolate import CubicHermiteSpline + + +if pybamm.has_idaklu(): + _hermite_args = [True, False] +else: + _hermite_args = [False] def to_casadi(var_pybamm, y, inputs=None): @@ -27,55 +34,91 @@ def to_casadi(var_pybamm, y, inputs=None): return var_casadi -def process_and_check_2D_variable( - var, first_spatial_var, second_spatial_var, disc=None, geometry_options=None -): - # first_spatial_var should be on the "smaller" domain, i.e "r" for an "r-x" variable - if geometry_options is None: - geometry_options = {} - if disc is None: - disc = tests.get_discretisation_for_testing() - disc.set_variable_slices([var]) - - first_sol = 
disc.process_symbol(first_spatial_var).entries[:, 0] - second_sol = disc.process_symbol(second_spatial_var).entries[:, 0] - - # Keep only the first iteration of entries - first_sol = first_sol[: len(first_sol) // len(second_sol)] - var_sol = disc.process_symbol(var) - t_sol = np.linspace(0, 1) - y_sol = np.ones(len(second_sol) * len(first_sol))[:, np.newaxis] * np.linspace(0, 5) - - var_casadi = to_casadi(var_sol, y_sol) - model = tests.get_base_model_with_battery_geometry(**geometry_options) - processed_var = pybamm.ProcessedVariable( - [var_sol], - [var_casadi], - pybamm.Solution(t_sol, y_sol, model, {}), - warn=False, - ) - np.testing.assert_array_equal( - processed_var.entries, - np.reshape(y_sol, [len(first_sol), len(second_sol), len(t_sol)]), - ) - return y_sol, first_sol, second_sol, t_sol +class TestProcessedVariable: + @staticmethod + def _get_yps(y, hermite_interp, values=1): + if hermite_interp: + yp_sol = values * np.ones_like(y) + else: + yp_sol = None + return yp_sol + + @staticmethod + def _sol_default(t_sol, y_sol, yp_sol=None, model=None, inputs=None): + if inputs is None: + inputs = {} + if model is None: + model = tests.get_base_model_with_battery_geometry() + return pybamm.Solution( + t_sol, + y_sol, + model, + inputs, + all_yps=yp_sol, + ) + + def _process_and_check_2D_variable( + self, + var, + first_spatial_var, + second_spatial_var, + disc=None, + geometry_options=None, + hermite_interp=False, + ): + # first_spatial_var should be on the "smaller" domain, i.e "r" for an "r-x" variable + if geometry_options is None: + geometry_options = {} + if disc is None: + disc = tests.get_discretisation_for_testing() + disc.set_variable_slices([var]) + first_sol = disc.process_symbol(first_spatial_var).entries[:, 0] + second_sol = disc.process_symbol(second_spatial_var).entries[:, 0] -class TestProcessedVariable: - def test_processed_variable_0D(self): + # Keep only the first iteration of entries + first_sol = first_sol[: len(first_sol) // len(second_sol)] + var_sol = disc.process_symbol(var) + t_sol = np.linspace(0, 1) + y_sol = 5 * t_sol * np.ones(len(second_sol) * len(first_sol))[:, np.newaxis] + yp_sol = self._get_yps(y_sol, hermite_interp, values=5) + + var_casadi = to_casadi(var_sol, y_sol) + model = tests.get_base_model_with_battery_geometry(**geometry_options) + processed_var = pybamm.process_variable( + [var_sol], + [var_casadi], + self._sol_default(t_sol, y_sol, yp_sol, model), + ) + np.testing.assert_array_equal( + processed_var.entries, + np.reshape(y_sol, [len(first_sol), len(second_sol), len(t_sol)]), + ) + + # check that C++ and Python give the same result + if pybamm.has_idaklu(): + np.testing.assert_array_equal( + processed_var._observe_raw_cpp(), processed_var._observe_raw_python() + ) + + return y_sol, first_sol, second_sol, t_sol, yp_sol + + @pytest.mark.parametrize("hermite_interp", _hermite_args) + def test_processed_variable_0D(self, hermite_interp): # without space t = pybamm.t y = pybamm.StateVector(slice(0, 1)) var = t * y var.mesh = None + model = pybamm.BaseModel() t_sol = np.linspace(0, 1) y_sol = np.array([np.linspace(0, 5)]) + yp_sol = self._get_yps(y_sol, hermite_interp) var_casadi = to_casadi(var, y_sol) - processed_var = pybamm.ProcessedVariable( + processed_var = pybamm.process_variable( [var], [var_casadi], - pybamm.Solution(t_sol, y_sol, pybamm.BaseModel(), {}), - warn=False, + self._sol_default(t_sol, y_sol, yp_sol, model), ) np.testing.assert_array_equal(processed_var.entries, t_sol * y_sol[0]) @@ -84,18 +127,41 @@ def 
test_processed_variable_0D(self): var.mesh = None t_sol = np.array([0]) y_sol = np.array([1])[:, np.newaxis] + yp_sol = np.array([1])[:, np.newaxis] + sol = self._sol_default(t_sol, y_sol, yp_sol, model) var_casadi = to_casadi(var, y_sol) - processed_var = pybamm.ProcessedVariable( + processed_var = pybamm.process_variable( [var], [var_casadi], - pybamm.Solution(t_sol, y_sol, pybamm.BaseModel(), {}), - warn=False, + sol, ) np.testing.assert_array_equal(processed_var.entries, y_sol[0]) - # check empty sensitivity works + # check that repeated calls return the same data + data1 = processed_var.data + + assert processed_var.entries_raw_initialized + + data2 = processed_var.data + + np.testing.assert_array_equal(data1, data2) + + data_t1 = processed_var(sol.t) + + assert processed_var.xr_array_raw_initialized - def test_processed_variable_0D_no_sensitivity(self): + data_t2 = processed_var(sol.t) + + np.testing.assert_array_equal(data_t1, data_t2) + + # check that C++ and Python give the same result + if pybamm.has_idaklu(): + np.testing.assert_array_equal( + processed_var._observe_raw_cpp(), processed_var._observe_raw_python() + ) + + @pytest.mark.parametrize("hermite_interp", _hermite_args) + def test_processed_variable_0D_no_sensitivity(self, hermite_interp): # without space t = pybamm.t y = pybamm.StateVector(slice(0, 1)) @@ -103,12 +169,12 @@ def test_processed_variable_0D_no_sensitivity(self): var.mesh = None t_sol = np.linspace(0, 1) y_sol = np.array([np.linspace(0, 5)]) + yp_sol = self._get_yps(y_sol, hermite_interp) var_casadi = to_casadi(var, y_sol) - processed_var = pybamm.ProcessedVariable( + processed_var = pybamm.process_variable( [var], [var_casadi], - pybamm.Solution(t_sol, y_sol, pybamm.BaseModel(), {}), - warn=False, + self._sol_default(t_sol, y_sol, yp_sol, pybamm.BaseModel()), ) # test no inputs (i.e. 
no sensitivity) @@ -124,18 +190,18 @@ def test_processed_variable_0D_no_sensitivity(self): y_sol = np.array([np.linspace(0, 5)]) inputs = {"a": np.array([1.0])} var_casadi = to_casadi(var, y_sol, inputs=inputs) - processed_var = pybamm.ProcessedVariable( + processed_var = pybamm.process_variable( [var], [var_casadi], pybamm.Solution(t_sol, y_sol, pybamm.BaseModel(), inputs), - warn=False, ) # test no sensitivity raises error with pytest.raises(ValueError, match="Cannot compute sensitivities"): print(processed_var.sensitivities) - def test_processed_variable_1D(self): + @pytest.mark.parametrize("hermite_interp", _hermite_args) + def test_processed_variable_1D(self, hermite_interp): t = pybamm.t var = pybamm.Variable("var", domain=["negative electrode", "separator"]) x = pybamm.SpatialVariable("x", domain=["negative electrode", "separator"]) @@ -149,26 +215,21 @@ def test_processed_variable_1D(self): eqn_sol = disc.process_symbol(eqn) t_sol = np.linspace(0, 1) y_sol = np.ones_like(x_sol)[:, np.newaxis] * np.linspace(0, 5) + yp_sol = self._get_yps(y_sol, hermite_interp, values=5) var_casadi = to_casadi(var_sol, y_sol) - processed_var = pybamm.ProcessedVariable( + processed_var = pybamm.process_variable( [var_sol], [var_casadi], - pybamm.Solution( - t_sol, y_sol, tests.get_base_model_with_battery_geometry(), {} - ), - warn=False, + self._sol_default(t_sol, y_sol, yp_sol), ) np.testing.assert_array_equal(processed_var.entries, y_sol) np.testing.assert_array_almost_equal(processed_var(t_sol, x_sol), y_sol) eqn_casadi = to_casadi(eqn_sol, y_sol) - processed_eqn = pybamm.ProcessedVariable( + processed_eqn = pybamm.process_variable( [eqn_sol], [eqn_casadi], - pybamm.Solution( - t_sol, y_sol, tests.get_base_model_with_battery_geometry(), {} - ), - warn=False, + self._sol_default(t_sol, y_sol, yp_sol), ) np.testing.assert_array_almost_equal( processed_eqn(t_sol, x_sol), t_sol * y_sol + x_sol[:, np.newaxis] @@ -184,13 +245,10 @@ def test_processed_variable_1D(self): x_s_edge = pybamm.Matrix(disc.mesh["separator"].edges, domain="separator") x_s_edge.mesh = disc.mesh["separator"] x_s_casadi = to_casadi(x_s_edge, y_sol) - processed_x_s_edge = pybamm.ProcessedVariable( + processed_x_s_edge = pybamm.process_variable( [x_s_edge], [x_s_casadi], - pybamm.Solution( - t_sol, y_sol, tests.get_base_model_with_battery_geometry(), {} - ), - warn=False, + self._sol_default(t_sol, y_sol, yp_sol), ) np.testing.assert_array_equal( x_s_edge.entries[:, 0], processed_x_s_edge.entries[:, 0] @@ -201,20 +259,25 @@ def test_processed_variable_1D(self): eqn_sol = disc.process_symbol(eqn) t_sol = np.array([0]) y_sol = np.ones_like(x_sol)[:, np.newaxis] + yp_sol = self._get_yps(y_sol, hermite_interp, values=0) eqn_casadi = to_casadi(eqn_sol, y_sol) - processed_eqn2 = pybamm.ProcessedVariable( + processed_eqn2 = pybamm.process_variable( [eqn_sol], [eqn_casadi], - pybamm.Solution( - t_sol, y_sol, tests.get_base_model_with_battery_geometry(), {} - ), - warn=False, + self._sol_default(t_sol, y_sol, yp_sol), ) np.testing.assert_array_equal( processed_eqn2.entries, y_sol + x_sol[:, np.newaxis] ) - def test_processed_variable_1D_unknown_domain(self): + # check that C++ and Python give the same result + if pybamm.has_idaklu(): + np.testing.assert_array_equal( + processed_eqn2._observe_raw_cpp(), processed_eqn2._observe_raw_python() + ) + + @pytest.mark.parametrize("hermite_interp", _hermite_args) + def test_processed_variable_1D_unknown_domain(self, hermite_interp): x = pybamm.SpatialVariable("x", domain="SEI layer", 
coord_sys="cartesian") geometry = pybamm.Geometry( {"SEI layer": {x: {"min": pybamm.Scalar(0), "max": pybamm.Scalar(1)}}} @@ -227,6 +290,7 @@ def test_processed_variable_1D_unknown_domain(self): nt = 100 y_sol = np.zeros((var_pts[x], nt)) + yp_sol = self._get_yps(y_sol, hermite_interp) model = tests.get_base_model_with_battery_geometry() model._geometry = geometry solution = pybamm.Solution( @@ -237,14 +301,16 @@ def test_processed_variable_1D_unknown_domain(self): np.linspace(0, 1, 1), np.zeros(var_pts[x]), "test", + all_yps=yp_sol, ) c = pybamm.StateVector(slice(0, var_pts[x]), domain=["SEI layer"]) c.mesh = mesh["SEI layer"] c_casadi = to_casadi(c, y_sol) - pybamm.ProcessedVariable([c], [c_casadi], solution, warn=False) + pybamm.process_variable([c], [c_casadi], solution) - def test_processed_variable_2D_x_r(self): + @pytest.mark.parametrize("hermite_interp", _hermite_args) + def test_processed_variable_2D_x_r(self, hermite_interp): var = pybamm.Variable( "var", domain=["negative particle"], @@ -258,9 +324,12 @@ def test_processed_variable_2D_x_r(self): ) disc = tests.get_p2d_discretisation_for_testing() - process_and_check_2D_variable(var, r, x, disc=disc) + self._process_and_check_2D_variable( + var, r, x, disc=disc, hermite_interp=hermite_interp + ) - def test_processed_variable_2D_R_x(self): + @pytest.mark.parametrize("hermite_interp", _hermite_args) + def test_processed_variable_2D_R_x(self, hermite_interp): var = pybamm.Variable( "var", domain=["negative particle size"], @@ -274,15 +343,17 @@ def test_processed_variable_2D_R_x(self): x = pybamm.SpatialVariable("x", domain=["negative electrode"]) disc = tests.get_size_distribution_disc_for_testing() - process_and_check_2D_variable( + self._process_and_check_2D_variable( var, R, x, disc=disc, geometry_options={"options": {"particle size": "distribution"}}, + hermite_interp=hermite_interp, ) - def test_processed_variable_2D_R_z(self): + @pytest.mark.parametrize("hermite_interp", _hermite_args) + def test_processed_variable_2D_R_z(self, hermite_interp): var = pybamm.Variable( "var", domain=["negative particle size"], @@ -296,15 +367,17 @@ def test_processed_variable_2D_R_z(self): z = pybamm.SpatialVariable("z", domain=["current collector"]) disc = tests.get_size_distribution_disc_for_testing() - process_and_check_2D_variable( + self._process_and_check_2D_variable( var, R, z, disc=disc, geometry_options={"options": {"particle size": "distribution"}}, + hermite_interp=hermite_interp, ) - def test_processed_variable_2D_r_R(self): + @pytest.mark.parametrize("hermite_interp", _hermite_args) + def test_processed_variable_2D_r_R(self, hermite_interp): var = pybamm.Variable( "var", domain=["negative particle"], @@ -318,15 +391,17 @@ def test_processed_variable_2D_r_R(self): R = pybamm.SpatialVariable("R", domain=["negative particle size"]) disc = tests.get_size_distribution_disc_for_testing() - process_and_check_2D_variable( + self._process_and_check_2D_variable( var, r, R, disc=disc, geometry_options={"options": {"particle size": "distribution"}}, + hermite_interp=hermite_interp, ) - def test_processed_variable_2D_x_z(self): + @pytest.mark.parametrize("hermite_interp", _hermite_args) + def test_processed_variable_2D_x_z(self, hermite_interp): var = pybamm.Variable( "var", domain=["negative electrode", "separator"], @@ -340,7 +415,9 @@ def test_processed_variable_2D_x_z(self): z = pybamm.SpatialVariable("z", domain=["current collector"]) disc = tests.get_1p1d_discretisation_for_testing() - y_sol, x_sol, z_sol, t_sol = 
process_and_check_2D_variable(var, x, z, disc=disc) + y_sol, x_sol, z_sol, t_sol, yp_sol = self._process_and_check_2D_variable( + var, x, z, disc=disc, hermite_interp=hermite_interp + ) del x_sol # On edges @@ -352,19 +429,17 @@ def test_processed_variable_2D_x_z(self): x_s_edge.mesh = disc.mesh["separator"] x_s_edge.secondary_mesh = disc.mesh["current collector"] x_s_casadi = to_casadi(x_s_edge, y_sol) - processed_x_s_edge = pybamm.ProcessedVariable( + processed_x_s_edge = pybamm.process_variable( [x_s_edge], [x_s_casadi], - pybamm.Solution( - t_sol, y_sol, tests.get_base_model_with_battery_geometry(), {} - ), - warn=False, + self._sol_default(t_sol, y_sol, yp_sol), ) np.testing.assert_array_equal( x_s_edge.entries.flatten(), processed_x_s_edge.entries[:, :, 0].T.flatten() ) - def test_processed_variable_2D_space_only(self): + @pytest.mark.parametrize("hermite_interp", _hermite_args) + def test_processed_variable_2D_space_only(self, hermite_interp): var = pybamm.Variable( "var", domain=["negative particle"], @@ -386,22 +461,21 @@ def test_processed_variable_2D_space_only(self): var_sol = disc.process_symbol(var) t_sol = np.array([0]) y_sol = np.ones(len(x_sol) * len(r_sol))[:, np.newaxis] + yp_sol = self._get_yps(y_sol, hermite_interp) var_casadi = to_casadi(var_sol, y_sol) - processed_var = pybamm.ProcessedVariable( + processed_var = pybamm.process_variable( [var_sol], [var_casadi], - pybamm.Solution( - t_sol, y_sol, tests.get_base_model_with_battery_geometry(), {} - ), - warn=False, + self._sol_default(t_sol, y_sol, yp_sol), ) np.testing.assert_array_equal( processed_var.entries, np.reshape(y_sol, [len(r_sol), len(x_sol), len(t_sol)]), ) - def test_processed_variable_2D_scikit(self): + @pytest.mark.parametrize("hermite_interp", _hermite_args) + def test_processed_variable_2D_scikit(self, hermite_interp): var = pybamm.Variable("var", domain=["current collector"]) disc = tests.get_2p1d_discretisation_for_testing() @@ -412,21 +486,20 @@ def test_processed_variable_2D_scikit(self): var_sol.mesh = disc.mesh["current collector"] t_sol = np.linspace(0, 1) u_sol = np.ones(var_sol.shape[0])[:, np.newaxis] * np.linspace(0, 5) + yp_sol = self._get_yps(u_sol, hermite_interp) var_casadi = to_casadi(var_sol, u_sol) - processed_var = pybamm.ProcessedVariable( + processed_var = pybamm.process_variable( [var_sol], [var_casadi], - pybamm.Solution( - t_sol, u_sol, tests.get_base_model_with_battery_geometry(), {} - ), - warn=False, + self._sol_default(t_sol, u_sol, yp_sol), ) np.testing.assert_array_equal( processed_var.entries, np.reshape(u_sol, [len(y), len(z), len(t_sol)]) ) - def test_processed_variable_2D_fixed_t_scikit(self): + @pytest.mark.parametrize("hermite_interp", _hermite_args) + def test_processed_variable_2D_fixed_t_scikit(self, hermite_interp): var = pybamm.Variable("var", domain=["current collector"]) disc = tests.get_2p1d_discretisation_for_testing() @@ -437,22 +510,23 @@ def test_processed_variable_2D_fixed_t_scikit(self): var_sol.mesh = disc.mesh["current collector"] t_sol = np.array([0]) u_sol = np.ones(var_sol.shape[0])[:, np.newaxis] + yp_sol = self._get_yps(u_sol, hermite_interp) var_casadi = to_casadi(var_sol, u_sol) model = tests.get_base_model_with_battery_geometry( options={"dimensionality": 2} ) - processed_var = pybamm.ProcessedVariable( + processed_var = pybamm.process_variable( [var_sol], [var_casadi], - pybamm.Solution(t_sol, u_sol, model, {}), - warn=False, + pybamm.Solution(t_sol, u_sol, model, {}, all_yps=yp_sol), ) np.testing.assert_array_equal( processed_var.entries, 
np.reshape(u_sol, [len(y), len(z), len(t_sol)]) ) - def test_processed_var_0D_interpolation(self): + @pytest.mark.parametrize("hermite_interp", _hermite_args) + def test_processed_var_0D_interpolation(self, hermite_interp): # without spatial dependence t = pybamm.t y = pybamm.StateVector(slice(0, 1)) @@ -462,40 +536,39 @@ def test_processed_var_0D_interpolation(self): eqn.mesh = None t_sol = np.linspace(0, 1, 1000) - y_sol = np.array([np.linspace(0, 5, 1000)]) + y_sol = np.array([5 * t_sol]) + yp_sol = self._get_yps(y_sol, hermite_interp, values=5) var_casadi = to_casadi(var, y_sol) - processed_var = pybamm.ProcessedVariable( + processed_var = pybamm.process_variable( [var], [var_casadi], - pybamm.Solution( - t_sol, y_sol, tests.get_base_model_with_battery_geometry(), {} - ), - warn=False, + self._sol_default(t_sol, y_sol, yp_sol), ) # vector np.testing.assert_array_equal(processed_var(t_sol), y_sol[0]) # scalar - np.testing.assert_array_equal(processed_var(0.5), 2.5) - np.testing.assert_array_equal(processed_var(0.7), 3.5) + np.testing.assert_array_almost_equal(processed_var(0.5), 2.5) + np.testing.assert_array_almost_equal(processed_var(0.7), 3.5) eqn_casadi = to_casadi(eqn, y_sol) - processed_eqn = pybamm.ProcessedVariable( + processed_eqn = pybamm.process_variable( [eqn], [eqn_casadi], - pybamm.Solution( - t_sol, y_sol, tests.get_base_model_with_battery_geometry(), {} - ), - warn=False, + self._sol_default(t_sol, y_sol, yp_sol), ) np.testing.assert_array_equal(processed_eqn(t_sol), t_sol * y_sol[0]) - np.testing.assert_array_almost_equal(processed_eqn(0.5), 0.5 * 2.5) + assert processed_eqn(0.5).shape == () + + np.testing.assert_array_almost_equal(processed_eqn(0.5), 0.5 * 2.5) + np.testing.assert_array_equal(processed_eqn(2, fill_value=100), 100) # Suppress warning for this test pybamm.set_logging_level("ERROR") np.testing.assert_array_equal(processed_eqn(2), np.nan) pybamm.set_logging_level("WARNING") - def test_processed_var_0D_fixed_t_interpolation(self): + @pytest.mark.parametrize("hermite_interp", _hermite_args) + def test_processed_var_0D_fixed_t_interpolation(self, hermite_interp): y = pybamm.StateVector(slice(0, 1)) var = y eqn = 2 * y @@ -504,17 +577,18 @@ def test_processed_var_0D_fixed_t_interpolation(self): t_sol = np.array([10]) y_sol = np.array([[100]]) + yp_sol = self._get_yps(y_sol, hermite_interp) eqn_casadi = to_casadi(eqn, y_sol) - processed_var = pybamm.ProcessedVariable( + processed_var = pybamm.process_variable( [eqn], [eqn_casadi], - pybamm.Solution(t_sol, y_sol, pybamm.BaseModel(), {}), - warn=False, + self._sol_default(t_sol, y_sol, yp_sol, pybamm.BaseModel()), ) - np.testing.assert_array_equal(processed_var(), 200) + assert processed_var() == 200 - def test_processed_var_1D_interpolation(self): + @pytest.mark.parametrize("hermite_interp", _hermite_args) + def test_processed_var_1D_interpolation(self, hermite_interp): t = pybamm.t var = pybamm.Variable("var", domain=["negative electrode", "separator"]) x = pybamm.SpatialVariable("x", domain=["negative electrode", "separator"]) @@ -526,36 +600,33 @@ def test_processed_var_1D_interpolation(self): var_sol = disc.process_symbol(var) eqn_sol = disc.process_symbol(eqn) t_sol = np.linspace(0, 1) - y_sol = x_sol[:, np.newaxis] * np.linspace(0, 5) + y_sol = x_sol[:, np.newaxis] * (5 * t_sol) + yp_sol = self._get_yps(y_sol, hermite_interp, values=5) var_casadi = to_casadi(var_sol, y_sol) - processed_var = pybamm.ProcessedVariable( + processed_var = pybamm.process_variable( [var_sol], [var_casadi], - pybamm.Solution( - 
t_sol, y_sol, tests.get_base_model_with_battery_geometry(), {} - ), - warn=False, + self._sol_default(t_sol, y_sol, yp_sol), ) + # 2 vectors np.testing.assert_array_almost_equal(processed_var(t_sol, x_sol), y_sol) # 1 vector, 1 scalar np.testing.assert_array_almost_equal(processed_var(0.5, x_sol), 2.5 * x_sol) - np.testing.assert_array_equal( - processed_var(t_sol, x_sol[-1]), x_sol[-1] * np.linspace(0, 5) + np.testing.assert_array_almost_equal( + processed_var(t_sol, x_sol[-1]), + x_sol[-1] * np.linspace(0, 5), ) # 2 scalars np.testing.assert_array_almost_equal( processed_var(0.5, x_sol[-1]), 2.5 * x_sol[-1] ) eqn_casadi = to_casadi(eqn_sol, y_sol) - processed_eqn = pybamm.ProcessedVariable( + processed_eqn = pybamm.process_variable( [eqn_sol], [eqn_casadi], - pybamm.Solution( - t_sol, y_sol, tests.get_base_model_with_battery_geometry(), {} - ), - warn=False, + self._sol_default(t_sol, y_sol, yp_sol), ) # 2 vectors np.testing.assert_array_almost_equal( @@ -571,13 +642,10 @@ def test_processed_var_1D_interpolation(self): x_disc = disc.process_symbol(x) x_casadi = to_casadi(x_disc, y_sol) - processed_x = pybamm.ProcessedVariable( + processed_x = pybamm.process_variable( [x_disc], [x_casadi], - pybamm.Solution( - t_sol, y_sol, tests.get_base_model_with_battery_geometry(), {} - ), - warn=False, + self._sol_default(t_sol, y_sol, yp_sol), ) np.testing.assert_array_almost_equal(processed_x(t=0, x=x_sol), x_sol) @@ -587,13 +655,10 @@ def test_processed_var_1D_interpolation(self): ) r_n.mesh = disc.mesh["negative particle"] r_n_casadi = to_casadi(r_n, y_sol) - processed_r_n = pybamm.ProcessedVariable( + processed_r_n = pybamm.process_variable( [r_n], [r_n_casadi], - pybamm.Solution( - t_sol, y_sol, tests.get_base_model_with_battery_geometry(), {} - ), - warn=False, + self._sol_default(t_sol, y_sol, yp_sol), ) np.testing.assert_array_equal(r_n.entries[:, 0], processed_r_n.entries[:, 0]) r_test = np.linspace(0, 0.5) @@ -608,17 +673,17 @@ def test_processed_var_1D_interpolation(self): model = tests.get_base_model_with_battery_geometry( options={"particle size": "distribution"} ) - processed_R_n = pybamm.ProcessedVariable( + processed_R_n = pybamm.process_variable( [R_n], [R_n_casadi], pybamm.Solution(t_sol, y_sol, model, {}), - warn=False, ) np.testing.assert_array_equal(R_n.entries[:, 0], processed_R_n.entries[:, 0]) R_test = np.linspace(0, 1) np.testing.assert_array_almost_equal(processed_R_n(0, R=R_test), R_test) - def test_processed_var_1D_fixed_t_interpolation(self): + @pytest.mark.parametrize("hermite_interp", _hermite_args) + def test_processed_var_1D_fixed_t_interpolation(self, hermite_interp): var = pybamm.Variable("var", domain=["negative electrode", "separator"]) x = pybamm.SpatialVariable("x", domain=["negative electrode", "separator"]) eqn = var + x @@ -629,15 +694,13 @@ def test_processed_var_1D_fixed_t_interpolation(self): eqn_sol = disc.process_symbol(eqn) t_sol = np.array([1]) y_sol = x_sol[:, np.newaxis] + yp_sol = self._get_yps(y_sol, hermite_interp) eqn_casadi = to_casadi(eqn_sol, y_sol) - processed_var = pybamm.ProcessedVariable( + processed_var = pybamm.process_variable( [eqn_sol], [eqn_casadi], - pybamm.Solution( - t_sol, y_sol, tests.get_base_model_with_battery_geometry(), {} - ), - warn=False, + self._sol_default(t_sol, y_sol, yp_sol), ) # vector @@ -647,7 +710,8 @@ def test_processed_var_1D_fixed_t_interpolation(self): # scalar np.testing.assert_array_almost_equal(processed_var(x=0.5), 1) - def test_processed_var_wrong_spatial_variable_names(self): + 
@pytest.mark.parametrize("hermite_interp", _hermite_args) + def test_processed_var_wrong_spatial_variable_names(self, hermite_interp): var = pybamm.Variable( "var", domain=["domain A", "domain B"], @@ -677,6 +741,7 @@ def test_processed_var_wrong_spatial_variable_names(self): var_sol = disc.process_symbol(var) t_sol = np.linspace(0, 1) y_sol = np.ones(len(a_sol) * len(b_sol))[:, np.newaxis] * np.linspace(0, 5) + yp_sol = self._get_yps(y_sol, hermite_interp) var_casadi = to_casadi(var_sol, y_sol) model = pybamm.BaseModel() @@ -687,14 +752,14 @@ def test_processed_var_wrong_spatial_variable_names(self): } ) with pytest.raises(NotImplementedError, match="Spatial variable name"): - pybamm.ProcessedVariable( + pybamm.process_variable( [var_sol], [var_casadi], - pybamm.Solution(t_sol, y_sol, model, {}), - warn=False, - ) + pybamm.Solution(t_sol, y_sol, model, {}, all_yps=yp_sol), + ).initialise() - def test_processed_var_2D_interpolation(self): + @pytest.mark.parametrize("hermite_interp", _hermite_args) + def test_processed_var_2D_interpolation(self, hermite_interp): var = pybamm.Variable( "var", domain=["negative particle"], @@ -716,15 +781,13 @@ def test_processed_var_2D_interpolation(self): var_sol = disc.process_symbol(var) t_sol = np.linspace(0, 1) y_sol = np.ones(len(x_sol) * len(r_sol))[:, np.newaxis] * np.linspace(0, 5) + yp_sol = self._get_yps(y_sol, hermite_interp) var_casadi = to_casadi(var_sol, y_sol) - processed_var = pybamm.ProcessedVariable( + processed_var = pybamm.process_variable( [var_sol], [var_casadi], - pybamm.Solution( - t_sol, y_sol, tests.get_base_model_with_battery_geometry(), {} - ), - warn=False, + self._sol_default(t_sol, y_sol, yp_sol), ) # 3 vectors np.testing.assert_array_equal( @@ -768,20 +831,18 @@ def test_processed_var_2D_interpolation(self): y_sol = np.ones(len(x_sol) * len(r_sol))[:, np.newaxis] * np.linspace(0, 5) var_casadi = to_casadi(var_sol, y_sol) - processed_var = pybamm.ProcessedVariable( + processed_var = pybamm.process_variable( [var_sol], [var_casadi], - pybamm.Solution( - t_sol, y_sol, tests.get_base_model_with_battery_geometry(), {} - ), - warn=False, + self._sol_default(t_sol, y_sol, yp_sol), ) # 3 vectors np.testing.assert_array_equal( processed_var(t_sol, x_sol, r_sol).shape, (10, 35, 50) ) - def test_processed_var_2D_fixed_t_interpolation(self): + @pytest.mark.parametrize("hermite_interp", _hermite_args) + def test_processed_var_2D_fixed_t_interpolation(self, hermite_interp): var = pybamm.Variable( "var", domain=["negative particle"], @@ -803,15 +864,13 @@ def test_processed_var_2D_fixed_t_interpolation(self): var_sol = disc.process_symbol(var) t_sol = np.array([0]) y_sol = np.ones(len(x_sol) * len(r_sol))[:, np.newaxis] + yp_sol = self._get_yps(y_sol, hermite_interp) var_casadi = to_casadi(var_sol, y_sol) - processed_var = pybamm.ProcessedVariable( + processed_var = pybamm.process_variable( [var_sol], [var_casadi], - pybamm.Solution( - t_sol, y_sol, tests.get_base_model_with_battery_geometry(), {} - ), - warn=False, + self._sol_default(t_sol, y_sol, yp_sol), ) # 2 vectors np.testing.assert_array_equal( @@ -823,7 +882,8 @@ def test_processed_var_2D_fixed_t_interpolation(self): # 2 scalars np.testing.assert_array_equal(processed_var(t=0, x=0.2, r=0.2).shape, ()) - def test_processed_var_2D_secondary_broadcast(self): + @pytest.mark.parametrize("hermite_interp", _hermite_args) + def test_processed_var_2D_secondary_broadcast(self, hermite_interp): var = pybamm.Variable("var", domain=["negative particle"]) broad_var = 
pybamm.SecondaryBroadcast(var, "negative electrode") x = pybamm.SpatialVariable("x", domain=["negative electrode"]) @@ -836,15 +896,13 @@ def test_processed_var_2D_secondary_broadcast(self): var_sol = disc.process_symbol(broad_var) t_sol = np.linspace(0, 1) y_sol = np.ones(len(x_sol) * len(r_sol))[:, np.newaxis] * np.linspace(0, 5) + yp_sol = self._get_yps(y_sol, hermite_interp) var_casadi = to_casadi(var_sol, y_sol) - processed_var = pybamm.ProcessedVariable( + processed_var = pybamm.process_variable( [var_sol], [var_casadi], - pybamm.Solution( - t_sol, y_sol, tests.get_base_model_with_battery_geometry(), {} - ), - warn=False, + self._sol_default(t_sol, y_sol, yp_sol), ) # 3 vectors np.testing.assert_array_equal( @@ -877,22 +935,21 @@ def test_processed_var_2D_secondary_broadcast(self): var_sol = disc.process_symbol(broad_var) t_sol = np.linspace(0, 1) y_sol = np.ones(len(x_sol) * len(r_sol))[:, np.newaxis] * np.linspace(0, 5) + yp_sol = self._get_yps(y_sol, hermite_interp, values=5) var_casadi = to_casadi(var_sol, y_sol) - processed_var = pybamm.ProcessedVariable( + processed_var = pybamm.process_variable( [var_sol], [var_casadi], - pybamm.Solution( - t_sol, y_sol, tests.get_base_model_with_battery_geometry(), {} - ), - warn=False, + self._sol_default(t_sol, y_sol, yp_sol), ) # 3 vectors np.testing.assert_array_equal( processed_var(t_sol, x_sol, r_sol).shape, (10, 35, 50) ) - def test_processed_var_2_d_scikit_interpolation(self): + @pytest.mark.parametrize("hermite_interp", _hermite_args) + def test_processed_var_2_d_scikit_interpolation(self, hermite_interp): var = pybamm.Variable("var", domain=["current collector"]) disc = tests.get_2p1d_discretisation_for_testing() @@ -903,15 +960,13 @@ def test_processed_var_2_d_scikit_interpolation(self): var_sol.mesh = disc.mesh["current collector"] t_sol = np.linspace(0, 1) u_sol = np.ones(var_sol.shape[0])[:, np.newaxis] * np.linspace(0, 5) + yp_sol = self._get_yps(u_sol, hermite_interp) var_casadi = to_casadi(var_sol, u_sol) - processed_var = pybamm.ProcessedVariable( + processed_var = pybamm.process_variable( [var_sol], [var_casadi], - pybamm.Solution( - t_sol, u_sol, tests.get_base_model_with_battery_geometry(), {} - ), - warn=False, + self._sol_default(t_sol, u_sol, yp_sol), ) # 3 vectors np.testing.assert_array_equal( @@ -938,7 +993,8 @@ def test_processed_var_2_d_scikit_interpolation(self): # 3 scalars np.testing.assert_array_equal(processed_var(0.2, y=0.2, z=0.2).shape, ()) - def test_processed_var_2D_fixed_t_scikit_interpolation(self): + @pytest.mark.parametrize("hermite_interp", _hermite_args) + def test_processed_var_2D_fixed_t_scikit_interpolation(self, hermite_interp): var = pybamm.Variable("var", domain=["current collector"]) disc = tests.get_2p1d_discretisation_for_testing() @@ -949,15 +1005,13 @@ def test_processed_var_2D_fixed_t_scikit_interpolation(self): var_sol.mesh = disc.mesh["current collector"] t_sol = np.array([0]) u_sol = np.ones(var_sol.shape[0])[:, np.newaxis] + yp_sol = self._get_yps(u_sol, hermite_interp) var_casadi = to_casadi(var_sol, u_sol) - processed_var = pybamm.ProcessedVariable( + processed_var = pybamm.process_variable( [var_sol], [var_casadi], - pybamm.Solution( - t_sol, u_sol, tests.get_base_model_with_battery_geometry(), {} - ), - warn=False, + self._sol_default(t_sol, u_sol, yp_sol), ) # 2 vectors np.testing.assert_array_equal( @@ -969,7 +1023,8 @@ def test_processed_var_2D_fixed_t_scikit_interpolation(self): # 2 scalars np.testing.assert_array_equal(processed_var(t=0, y=0.2, z=0.2).shape, ()) - def 
test_processed_var_2D_unknown_domain(self): + @pytest.mark.parametrize("hermite_interp", _hermite_args) + def test_processed_var_2D_unknown_domain(self, hermite_interp): var = pybamm.Variable( "var", domain=["domain B"], @@ -1007,6 +1062,7 @@ def test_processed_var_2D_unknown_domain(self): var_sol = disc.process_symbol(var) t_sol = np.linspace(0, 1) y_sol = np.ones(len(x_sol) * len(z_sol))[:, np.newaxis] * np.linspace(0, 5) + yp_sol = self._get_yps(y_sol, hermite_interp, values=5) var_casadi = to_casadi(var_sol, y_sol) model = pybamm.BaseModel() @@ -1016,11 +1072,10 @@ def test_processed_var_2D_unknown_domain(self): "domain B": {z: {"min": 0, "max": 1}}, } ) - processed_var = pybamm.ProcessedVariable( + processed_var = pybamm.process_variable( [var_sol], [var_casadi], - pybamm.Solution(t_sol, y_sol, model, {}), - warn=False, + pybamm.Solution(t_sol, y_sol, model, {}, all_yps=yp_sol), ) # 3 vectors np.testing.assert_array_equal( @@ -1047,7 +1102,8 @@ def test_processed_var_2D_unknown_domain(self): # 3 scalars np.testing.assert_array_equal(processed_var(t=0.2, x=0.2, z=0.2).shape, ()) - def test_3D_raises_error(self): + @pytest.mark.parametrize("hermite_interp", _hermite_args) + def test_3D_raises_error(self, hermite_interp): var = pybamm.Variable( "var", domain=["negative electrode"], @@ -1059,16 +1115,14 @@ def test_3D_raises_error(self): var_sol = disc.process_symbol(var) t_sol = np.array([0, 1, 2]) u_sol = np.ones(var_sol.shape[0] * 3)[:, np.newaxis] + yp_sol = self._get_yps(u_sol, hermite_interp, values=0) var_casadi = to_casadi(var_sol, u_sol) with pytest.raises(NotImplementedError, match="Shape not recognized"): - pybamm.ProcessedVariable( + pybamm.process_variable( [var_sol], [var_casadi], - pybamm.Solution( - t_sol, u_sol, tests.get_base_model_with_battery_geometry(), {} - ), - warn=False, + self._sol_default(t_sol, u_sol, yp_sol), ) def test_process_spatial_variable_names(self): @@ -1080,11 +1134,10 @@ def test_process_spatial_variable_names(self): t_sol = np.linspace(0, 1) y_sol = np.array([np.linspace(0, 5)]) var_casadi = to_casadi(var, y_sol) - processed_var = pybamm.ProcessedVariable( + processed_var = pybamm.process_variable( [var], [var_casadi], pybamm.Solution(t_sol, y_sol, pybamm.BaseModel(), {}), - warn=False, ) # Test empty list returns None @@ -1108,3 +1161,110 @@ def test_process_spatial_variable_names(self): # Test error raised if spatial variable name not recognised with pytest.raises(NotImplementedError, match="Spatial variable name"): processed_var._process_spatial_variable_names(["var1", "var2"]) + + def test_hermite_interpolator(self): + if not pybamm.has_idaklu(): + pytest.skip("Cannot test Hermite interpolation without IDAKLU") + + # initialise dummy solution to access method + def solution_setup(t_sol, sign): + y_sol = np.array([sign * np.sin(t_sol)]) + yp_sol = np.array([sign * np.cos(t_sol)]) + sol = self._sol_default(t_sol, y_sol, yp_sol) + return sol + + # without spatial dependence + t = pybamm.t + y = pybamm.StateVector(slice(0, 1)) + var = y + eqn = t * y + var.mesh = None + eqn.mesh = None + + sign1 = +1 + t_sol1 = np.linspace(0, 1, 100) + sol1 = solution_setup(t_sol1, sign1) + + # Discontinuity in the solution + sign2 = -1 + t_sol2 = np.linspace(np.nextafter(t_sol1[-1], np.inf), t_sol1[-1] + 3, 99) + sol2 = solution_setup(t_sol2, sign2) + + sol = sol1 + sol2 + var_casadi = to_casadi(var, sol.all_ys[0]) + processed_var = pybamm.process_variable( + [var] * len(sol.all_ts), + [var_casadi] * len(sol.all_ts), + sol, + ) + + # Ground truth spline 
interpolants from scipy + spls = [ + CubicHermiteSpline(t, y, yp, axis=1) + for t, y, yp in zip(sol.all_ts, sol.all_ys, sol.all_yps) + ] + + def spl(t): + t = np.array(t) + out = np.zeros(len(t)) + for i, spl in enumerate(spls): + t0 = sol.all_ts[i][0] + tf = sol.all_ts[i][-1] + + mask = t >= t0 + # Extrapolation is allowed for the final solution + if i < len(spls) - 1: + mask &= t <= tf + + out[mask] = spl(t[mask]).flatten() + return out + + t0 = sol.t[0] + tf = sol.t[-1] + + # Test extrapolation before the first solution time + t_left_extrap = t0 - 1 + with pytest.raises( + ValueError, match="interpolation points must be greater than" + ): + processed_var(t_left_extrap) + + # Test extrapolation after the last solution time + t_right_extrap = [tf + 1] + np.testing.assert_almost_equal( + spl(t_right_extrap), + processed_var(t_right_extrap, fill_value="extrapolate"), + decimal=8, + ) + + t_dense = np.linspace(t0, tf + 1, 1000) + np.testing.assert_almost_equal( + spl(t_dense), + processed_var(t_dense, fill_value="extrapolate"), + decimal=8, + ) + + t_extended = np.union1d(sol.t, sol.t[-1] + 1) + np.testing.assert_almost_equal( + spl(t_extended), + processed_var(t_extended, fill_value="extrapolate"), + decimal=8, + ) + + ## Unsorted arrays + t_unsorted = np.array([0.5, 0.4, 0.6, 0, 1]) + idxs_sort = np.argsort(t_unsorted) + t_sorted = np.sort(t_unsorted) + + y_sorted = processed_var(t_sorted) + + idxs_unsort = np.zeros_like(idxs_sort) + idxs_unsort[idxs_sort] = np.arange(len(t_unsorted)) + + # Check that the unsorted and sorted arrays are the same + assert np.all(t_sorted == t_unsorted[idxs_sort]) + + y_unsorted = processed_var(t_unsorted) + + # Check that the unsorted and sorted arrays are the same + assert np.all(y_unsorted == y_sorted[idxs_unsort]) diff --git a/tests/unit/test_solvers/test_processed_variable_computed.py b/tests/unit/test_solvers/test_processed_variable_computed.py index 59a062b199..0fa46b4414 100644 --- a/tests/unit/test_solvers/test_processed_variable_computed.py +++ b/tests/unit/test_solvers/test_processed_variable_computed.py @@ -57,7 +57,6 @@ def process_and_check_2D_variable( [var_casadi], [y_sol], pybamm.Solution(t_sol, y_sol, model, {}), - warn=False, ) # NB: ProcessedVariableComputed does not interpret y in the same way as # ProcessedVariable; a better test of equivalence is to check that the @@ -82,7 +81,6 @@ def test_processed_variable_0D(self): [var_casadi], [y_sol], pybamm.Solution(t_sol, y_sol, pybamm.BaseModel(), {}), - warn=False, ) # Assert that the processed variable is the same as the solution np.testing.assert_array_equal(processed_var.entries, y_sol[0]) @@ -94,7 +92,7 @@ def test_processed_variable_0D(self): # Check cumtrapz workflow produces no errors processed_var.cumtrapz_ic = 1 - processed_var.initialise_0D() + processed_var.entries # check empty sensitivity works def test_processed_variable_0D_no_sensitivity(self): @@ -111,7 +109,6 @@ def test_processed_variable_0D_no_sensitivity(self): [var_casadi], [y_sol], pybamm.Solution(t_sol, y_sol, pybamm.BaseModel(), {}), - warn=False, ) # test no inputs (i.e. 
no sensitivity) @@ -132,7 +129,6 @@ def test_processed_variable_0D_no_sensitivity(self): [var_casadi], [y_sol], pybamm.Solution(t_sol, y_sol, pybamm.BaseModel(), inputs), - warn=False, ) # test no sensitivity raises error @@ -157,7 +153,6 @@ def test_processed_variable_1D(self): [var_casadi], [y_sol], sol, - warn=False, ) # Ordering from idaklu with output_variables set is different to @@ -175,7 +170,7 @@ def test_processed_variable_1D(self): processed_var.mesh.edges, processed_var.mesh.nodes, ) - processed_var.initialise_1D() + processed_var.entries processed_var.mesh.nodes, processed_var.mesh.edges = ( processed_var.mesh.edges, processed_var.mesh.nodes, @@ -192,7 +187,7 @@ def test_processed_variable_1D(self): ] for domain in domain_list: processed_var.domain[0] = domain - processed_var.initialise_1D() + processed_var.entries def test_processed_variable_1D_unknown_domain(self): x = pybamm.SpatialVariable("x", domain="SEI layer", coord_sys="cartesian") @@ -220,7 +215,7 @@ def test_processed_variable_1D_unknown_domain(self): c = pybamm.StateVector(slice(0, var_pts[x]), domain=["SEI layer"]) c.mesh = mesh["SEI layer"] c_casadi = to_casadi(c, y_sol) - pybamm.ProcessedVariableComputed([c], [c_casadi], [y_sol], solution, warn=False) + pybamm.ProcessedVariableComputed([c], [c_casadi], [y_sol], solution) def test_processed_variable_2D_x_r(self): var = pybamm.Variable( @@ -330,13 +325,12 @@ def test_processed_variable_2D_x_z(self): x_s_edge.mesh = disc.mesh["separator"] x_s_edge.secondary_mesh = disc.mesh["current collector"] x_s_casadi = to_casadi(x_s_edge, y_sol) - processed_x_s_edge = pybamm.ProcessedVariable( + processed_x_s_edge = pybamm.process_variable( [x_s_edge], [x_s_casadi], pybamm.Solution( t_sol, y_sol, tests.get_base_model_with_battery_geometry(), {} ), - warn=False, ) np.testing.assert_array_equal( x_s_edge.entries.flatten(), processed_x_s_edge.entries[:, :, 0].T.flatten() @@ -371,7 +365,6 @@ def test_processed_variable_2D_space_only(self): [var_casadi], [y_sol], pybamm.Solution(t_sol, y_sol, pybamm.BaseModel(), {}), - warn=False, ) np.testing.assert_array_equal( processed_var.entries, @@ -408,7 +401,6 @@ def test_processed_variable_2D_fixed_t_scikit(self): [var_casadi], [u_sol], pybamm.Solution(t_sol, u_sol, pybamm.BaseModel(), {}), - warn=False, ) np.testing.assert_array_equal( processed_var.entries, np.reshape(u_sol, [len(y), len(z), len(t_sol)]) @@ -434,5 +426,4 @@ def test_3D_raises_error(self): [var_casadi], [u_sol], pybamm.Solution(t_sol, u_sol, pybamm.BaseModel(), {}), - warn=False, ) diff --git a/tests/unit/test_solvers/test_solution.py b/tests/unit/test_solvers/test_solution.py index 3aff012d5b..4ac9312531 100644 --- a/tests/unit/test_solvers/test_solution.py +++ b/tests/unit/test_solvers/test_solution.py @@ -77,14 +77,16 @@ def test_add_solutions(self): # Set up first solution t1 = np.linspace(0, 1) y1 = np.tile(t1, (20, 1)) - sol1 = pybamm.Solution(t1, y1, pybamm.BaseModel(), {"a": 1}) + yp1 = np.tile(t1, (30, 1)) + sol1 = pybamm.Solution(t1, y1, pybamm.BaseModel(), {"a": 1}, all_yps=yp1) sol1.solve_time = 1.5 sol1.integration_time = 0.3 # Set up second solution t2 = np.linspace(1, 2) y2 = np.tile(t2, (20, 1)) - sol2 = pybamm.Solution(t2, y2, pybamm.BaseModel(), {"a": 2}) + yp2 = np.tile(t1, (30, 1)) + sol2 = pybamm.Solution(t2, y2, pybamm.BaseModel(), {"a": 2}, all_yps=yp2) sol2.solve_time = 1 sol2.integration_time = 0.5
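
Note (illustrative, not part of the patch): the `Solution.__add__` hunk above carries the new derivative arrays (`all_yps`) through concatenation in the same way as the states, dropping the repeated first time point when the second sub-solution starts exactly where the first one ends. A minimal numpy sketch of that bookkeeping, with hypothetical `t_a`/`t_b` arrays standing in for two sub-solutions, might look like:

```python
import numpy as np

# two hypothetical sub-solutions that share the boundary time t = 1.0
t_a, t_b = np.array([0.0, 0.5, 1.0]), np.array([1.0, 1.5, 2.0])
y_a, y_b = np.sin(t_a)[np.newaxis, :], np.sin(t_b)[np.newaxis, :]
yp_a, yp_b = np.cos(t_a)[np.newaxis, :], np.cos(t_b)[np.newaxis, :]  # dy/dt

if t_b[0] == t_a[-1]:
    # skip the repeated first time step of the second solution,
    # for times, states and state derivatives alike
    all_ts = [t_a, t_b[1:]]
    all_ys = [y_a, y_b[:, 1:]]
    all_yps = [yp_a, yp_b[:, 1:]]
else:
    all_ts, all_ys, all_yps = [t_a, t_b], [y_a, y_b], [yp_a, yp_b]
```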
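
Note (illustrative, not part of the patch): `test_hermite_interpolator` builds its ground truth with `scipy.interpolate.CubicHermiteSpline`, which fits a piecewise cubic through solver output `(t, y)` using the stored time derivatives `yp`. A self-contained sketch of that idea, using a sine wave as a stand-in for solver output:

```python
import numpy as np
from scipy.interpolate import CubicHermiteSpline

t = np.linspace(0, 1, 100)           # solver time points
y = np.sin(t)[np.newaxis, :]         # states, shape (n_states, n_times)
yp = np.cos(t)[np.newaxis, :]        # dy/dt at the same points

# interpolate along the time axis (axis=1), as in the test's ground truth
spl = CubicHermiteSpline(t, y, yp, axis=1)

t_fine = np.linspace(0, 1, 1000)
# the cubic Hermite interpolant reproduces sin(t) to high accuracy between knots
np.testing.assert_allclose(spl(t_fine)[0], np.sin(t_fine), atol=1e-8)
```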