From c259bb0e6046dc5fd591e39d487ba6cf340b2915 Mon Sep 17 00:00:00 2001 From: Alec Bills Date: Thu, 31 Oct 2024 12:36:51 -0700 Subject: [PATCH] add tests for coverage; valentin comments --- .../expression_tree/coupled_variable.py | 11 +++++----- src/pybamm/models/base_model.py | 2 +- .../test_coupled_variable.py | 20 +++++++++++++++++++ 3 files changed, 26 insertions(+), 7 deletions(-) diff --git a/src/pybamm/expression_tree/coupled_variable.py b/src/pybamm/expression_tree/coupled_variable.py index b85589cbe0..04d03d2792 100644 --- a/src/pybamm/expression_tree/coupled_variable.py +++ b/src/pybamm/expression_tree/coupled_variable.py @@ -13,7 +13,7 @@ class CoupledVariable(pybamm.Symbol): name : str name of the node domain : iterable of str - list of domains that this variable is valid over + list of domains that this coupled variable is valid over """ def __init__( @@ -31,11 +31,9 @@ def _evaluate_for_shape(self): return pybamm.evaluate_for_shape_using_domain(self.domains) def create_copy(self): - """See :meth:`pybamm.Symbol.new_copy()`.""" - new_input_parameter = CoupledVariable( - self.name, self.domain, expected_size=self._expected_size - ) - return new_input_parameter + """Creates a new copy of the coupled variable.""" + new_coupled_variable = CoupledVariable(self.name, self.domain) + return new_coupled_variable @property def children(self): @@ -46,6 +44,7 @@ def children(self, expr): self._children = expr def set_coupled_variable(self, symbol, expr): + """Sets the children of the coupled variable to the expression passed in expr. If the symbol is not the coupled variable, then it searches the children of the symbol for the coupled variable. The coupled variable will be replaced by its first child (symbol.children[0], which should be expr) in the discretisation step.""" if self == symbol: symbol.children = [ expr, diff --git a/src/pybamm/models/base_model.py b/src/pybamm/models/base_model.py index 85111af7f7..f7e8f70f32 100644 --- a/src/pybamm/models/base_model.py +++ b/src/pybamm/models/base_model.py @@ -206,7 +206,7 @@ def coupled_variables(self, coupled_variables): self._coupled_variables = coupled_variables def list_coupled_variables(self): - list(self._coupled_variables.keys()) + return list(self._coupled_variables.keys()) @property def variables(self): diff --git a/tests/unit/test_expression_tree/test_coupled_variable.py b/tests/unit/test_expression_tree/test_coupled_variable.py index 53056e9b25..3e60c412e5 100644 --- a/tests/unit/test_expression_tree/test_coupled_variable.py +++ b/tests/unit/test_expression_tree/test_coupled_variable.py @@ -7,6 +7,8 @@ import pybamm +import pytest + def combine_models(list_of_models): model = pybamm.BaseModel() @@ -72,3 +74,21 @@ def test_coupled_variable(self): np.testing.assert_almost_equal( solution["a"].entries, solution["b"].entries, decimal=10 ) + + assert set(model.list_coupled_variables()) == set(["a", "b"]) + + def test_create_copy(self): + a = pybamm.CoupledVariable("a") + b = a.create_copy() + assert a == b + + def test_setter(self): + model = pybamm.BaseModel() + a = pybamm.CoupledVariable("a") + coupled_variables = {"a": a} + model.coupled_variables = coupled_variables + assert model.coupled_variables == coupled_variables + + with pytest.raises(ValueError, match="Coupled variable with name"): + coupled_variables = {"b": a} + model.coupled_variables = coupled_variables