From 70b4a4097f9ab8d6f0b02c8114336ea689e1d551 Mon Sep 17 00:00:00 2001 From: Gertjan van Zwieten Date: Mon, 3 Feb 2025 15:02:03 +0100 Subject: [PATCH] arguments_for wip --- examples/drivencavity.py | 4 ++-- nutils/SI.py | 5 +++++ nutils/function.py | 15 +++++++++++++++ 3 files changed, 22 insertions(+), 2 deletions(-) diff --git a/examples/drivencavity.py b/examples/drivencavity.py index d1d28f697..cf70d42d2 100644 --- a/examples/drivencavity.py +++ b/examples/drivencavity.py @@ -119,7 +119,7 @@ def main(nelems: int = 32, # strong enforcement of non-penetrating boundary conditions sqr = domain.boundary.integral('(u_k n_k)^2 dS' @ ns, degree=degree*2) cons = System(sqr, trial='u').solve_constraints(droptol=1e-15) - cons['p'] = numpy.zeros(res.argshapes['p'], dtype=bool) + cons['p'] = numpy.zeros(function.arguments_for(res)['p'].shape, dtype=bool) cons['p'].flat[0] = True # point constraint if strongbc: @@ -158,7 +158,7 @@ def postprocess(domain, ns, **arguments): # reconstruct velocity streamlines sqr = domain.integral('Σ_i (u_i - ε_ij ∇_j(ψ))^2 dV' @ ns, degree=4) - consψ = numpy.zeros(sqr.argshapes['ψ'], dtype=bool) + consψ = numpy.zeros(function.arguments_for(sqr)['ψ'].shape, dtype=bool) consψ.flat[0] = True # point constraint arguments = System(sqr, trial='ψ').solve(arguments=arguments, constrain={'ψ': consψ}) diff --git a/nutils/SI.py b/nutils/SI.py index 245a67499..cb19e6d6e 100644 --- a/nutils/SI.py +++ b/nutils/SI.py @@ -411,6 +411,11 @@ def __field(op, *args, **kwargs): dims, args = zip(*Quantity.__unpack(*args)) # we abuse the fact that unpack str returns dimensionless return functools.reduce(operator.mul, dims).wrap(op(*args, **kwargs)) + @register('arguments') + def __attribute(op, *args, **kwargs): + __dims, args = zip(*Quantity.__unpack(*args)) + return op(*args, **kwargs) + del register ## DEFINE OPERATORS diff --git a/nutils/function.py b/nutils/function.py index b6bdbbe1e..ba5b073fa 100644 --- a/nutils/function.py +++ b/nutils/function.py @@ -2414,6 +2414,21 @@ def lower(self, args: LowerArgs) -> evaluable.Array: return evaluable.prependaxes(self._array, args.points_shape) +@nutils_dispatch +def arguments_for(*arrays): + arguments = {} + for array in arrays: + for name, (shape, dtype) in array.arguments.items(): + argument = arguments.get(name) + if argument is None: + arguments[name] = Argument(name, shape, dtype) + elif argument.shape != shape: + raise ValueError(f'inconsistent shapes for argument {name!r}') + elif argument.dtype != dtype: + raise ValueError(f'inconsistent dtypes for argument {name!r}') + return arguments + + # BASES