diff --git a/src/pybamm/__init__.py b/src/pybamm/__init__.py index 68529156e3..3de52e5724 100644 --- a/src/pybamm/__init__.py +++ b/src/pybamm/__init__.py @@ -39,6 +39,7 @@ from .expression_tree.parameter import Parameter, FunctionParameter from .expression_tree.scalar import Scalar from .expression_tree.variable import * +from .expression_tree.coupled_variable import * from .expression_tree.independent_variable import * from .expression_tree.independent_variable import t from .expression_tree.vector import Vector diff --git a/src/pybamm/discretisations/discretisation.py b/src/pybamm/discretisations/discretisation.py index af4bd2edd6..2ca8c87649 100644 --- a/src/pybamm/discretisations/discretisation.py +++ b/src/pybamm/discretisations/discretisation.py @@ -938,6 +938,11 @@ def _process_symbol(self, symbol): if symbol._expected_size is None: symbol._expected_size = expected_size return symbol.create_copy() + + elif isinstance(symbol, pybamm.CoupledVariable): + new_symbol = self.process_symbol(symbol.children[0]) + return new_symbol + else: # Backup option: return the object return symbol diff --git a/src/pybamm/expression_tree/coupled_variable.py b/src/pybamm/expression_tree/coupled_variable.py new file mode 100644 index 0000000000..04d03d2792 --- /dev/null +++ b/src/pybamm/expression_tree/coupled_variable.py @@ -0,0 +1,55 @@ +import pybamm + +from pybamm.type_definitions import DomainType + + +class CoupledVariable(pybamm.Symbol): + """ + A node in the expression tree representing a variable whose equation is set by a different model or submodel. + + + Parameters + ---------- + name : str + name of the node + domain : iterable of str + list of domains that this coupled variable is valid over + """ + + def __init__( + self, + name: str, + domain: DomainType = None, + ) -> None: + super().__init__(name, domain=domain) + + def _evaluate_for_shape(self): + """ + Returns the scalar 'NaN' to represent the shape of a parameter. + See :meth:`pybamm.Symbol.evaluate_for_shape()` + """ + return pybamm.evaluate_for_shape_using_domain(self.domains) + + def create_copy(self): + """Creates a new copy of the coupled variable.""" + new_coupled_variable = CoupledVariable(self.name, self.domain) + return new_coupled_variable + + @property + def children(self): + return self._children + + @children.setter + def children(self, expr): + self._children = expr + + def set_coupled_variable(self, symbol, expr): + """Sets the children of the coupled variable to the expression passed in expr. If the symbol is not the coupled variable, then it searches the children of the symbol for the coupled variable. The coupled variable will be replaced by its first child (symbol.children[0], which should be expr) in the discretisation step.""" + if self == symbol: + symbol.children = [ + expr, + ] + else: + for child in symbol.children: + self.set_coupled_variable(child, expr) + symbol.set_id() diff --git a/src/pybamm/models/base_model.py b/src/pybamm/models/base_model.py index f6f47acc55..f7e8f70f32 100644 --- a/src/pybamm/models/base_model.py +++ b/src/pybamm/models/base_model.py @@ -56,6 +56,7 @@ def __init__(self, name="Unnamed model"): self._boundary_conditions = {} self._variables_by_submodel = {} self._variables = pybamm.FuzzyDict({}) + self._coupled_variables = {} self._summary_variables = [] self._events = [] self._concatenated_rhs = None @@ -182,6 +183,31 @@ def boundary_conditions(self): def boundary_conditions(self, boundary_conditions): self._boundary_conditions = BoundaryConditionsDict(boundary_conditions) + @property + def coupled_variables(self): + """Returns a dictionary mapping strings to expressions representing variables needed by the model but whose equations were set by other models.""" + return self._coupled_variables + + @coupled_variables.setter + def coupled_variables(self, coupled_variables): + for name, var in coupled_variables.items(): + if ( + isinstance(var, pybamm.CoupledVariable) + and var.name != name + # Exception if the variable is also there under its own name + and not ( + var.name in coupled_variables and coupled_variables[var.name] == var + ) + ): + raise ValueError( + f"Coupled variable with name '{var.name}' is in coupled variables dictionary with " + f"name '{name}'. Names must match." + ) + self._coupled_variables = coupled_variables + + def list_coupled_variables(self): + return list(self._coupled_variables.keys()) + @property def variables(self): """Returns a dictionary mapping strings to expressions representing the model's useful variables.""" diff --git a/tests/unit/test_expression_tree/test_coupled_variable.py b/tests/unit/test_expression_tree/test_coupled_variable.py new file mode 100644 index 0000000000..3e60c412e5 --- /dev/null +++ b/tests/unit/test_expression_tree/test_coupled_variable.py @@ -0,0 +1,94 @@ +# +# Tests for the CoupledVariable class +# + + +import numpy as np + +import pybamm + +import pytest + + +def combine_models(list_of_models): + model = pybamm.BaseModel() + + for submodel in list_of_models: + model.coupled_variables.update(submodel.coupled_variables) + model.variables.update(submodel.variables) + model.rhs.update(submodel.rhs) + model.algebraic.update(submodel.algebraic) + model.initial_conditions.update(submodel.initial_conditions) + model.boundary_conditions.update(submodel.boundary_conditions) + + for name, coupled_variable in model.coupled_variables.items(): + if name in model.variables: + for sym in model.rhs.values(): + coupled_variable.set_coupled_variable(sym, model.variables[name]) + for sym in model.algebraic.values(): + coupled_variable.set_coupled_variable(sym, model.variables[name]) + return model + + +class TestCoupledVariable: + def test_coupled_variable(self): + model_1 = pybamm.BaseModel() + model_1_var_1 = pybamm.CoupledVariable("a") + model_1_var_2 = pybamm.Variable("b") + model_1.rhs[model_1_var_2] = -0.2 * model_1_var_1 + model_1.variables["b"] = model_1_var_2 + model_1.coupled_variables["a"] = model_1_var_1 + model_1.initial_conditions[model_1_var_2] = 1.0 + + model_2 = pybamm.BaseModel() + model_2_var_1 = pybamm.Variable("a") + model_2_var_2 = pybamm.CoupledVariable("b") + model_2.rhs[model_2_var_1] = -0.2 * model_2_var_2 + model_2.variables["a"] = model_2_var_1 + model_2.coupled_variables["b"] = model_2_var_2 + model_2.initial_conditions[model_2_var_1] = 1.0 + + model = combine_models([model_1, model_2]) + + params = pybamm.ParameterValues({}) + geometry = {} + + # Process parameters + params.process_model(model) + params.process_geometry(geometry) + + # mesh and discretise + submesh_types = {} + var_pts = {} + mesh = pybamm.Mesh(geometry, submesh_types, var_pts) + + spatial_methods = {} + disc = pybamm.Discretisation(mesh, spatial_methods) + disc.process_model(model) + + # solve + solver = pybamm.CasadiSolver() + t = np.linspace(0, 10, 1000) + solution = solver.solve(model, t) + + np.testing.assert_almost_equal( + solution["a"].entries, solution["b"].entries, decimal=10 + ) + + assert set(model.list_coupled_variables()) == set(["a", "b"]) + + def test_create_copy(self): + a = pybamm.CoupledVariable("a") + b = a.create_copy() + assert a == b + + def test_setter(self): + model = pybamm.BaseModel() + a = pybamm.CoupledVariable("a") + coupled_variables = {"a": a} + model.coupled_variables = coupled_variables + assert model.coupled_variables == coupled_variables + + with pytest.raises(ValueError, match="Coupled variable with name"): + coupled_variables = {"b": a} + model.coupled_variables = coupled_variables