Skip to content

Commit

Permalink
Introduce function.arguments_for
Browse files Browse the repository at this point in the history
This patch adds the function.arguments_for function, which can be used to
retrieve array arguments. Unlike that .arguments attribute (which will be
deprecated in the next commit) the returned dictionary maps to fully formed
Argument objects, that can directly be used in operations such as derivative.
Furthermore, the function is nutils-dispatched, so that if can be used even if
the array is wrapped in an SI quantity.
  • Loading branch information
gertjanvanzwieten committed Feb 4, 2025
1 parent 649029e commit 6bd65da
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 3 deletions.
2 changes: 1 addition & 1 deletion examples/cahnhilliard.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def main(size: Length = parse('10cm'),
system = System(nrg / tol, trial='φ,η')

numpy.random.seed(seed)
args = dict(φ=numpy.random.normal(0, .5, system.argshapes['φ'])) # initial condition
args = dict(φ=numpy.random.normal(0, .5, function.arguments_for(nrg)['φ'].shape)) # initial condition

with log.iter.fraction('timestep', range(round(endtime / timestep))) as steps:
for istep in steps:
Expand Down
4 changes: 2 additions & 2 deletions examples/drivencavity.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def main(nelems: int = 32,
# strong enforcement of non-penetrating boundary conditions
sqr = domain.boundary.integral('(u_k n_k)^2 dS' @ ns, degree=degree*2)
cons = System(sqr, trial='u').solve_constraints(droptol=1e-15)
cons['p'] = numpy.zeros(res.argshapes['p'], dtype=bool)
cons['p'] = numpy.zeros(function.arguments_for(res)['p'].shape, dtype=bool)
cons['p'].flat[0] = True # point constraint

if strongbc:
Expand Down Expand Up @@ -158,7 +158,7 @@ def postprocess(domain, ns, **arguments):

# reconstruct velocity streamlines
sqr = domain.integral('Σ_i (u_i - ε_ij ∇_j(ψ))^2 dV' @ ns, degree=4)
consψ = numpy.zeros(sqr.argshapes['ψ'], dtype=bool)
consψ = numpy.zeros(function.arguments_for(sqr)['ψ'].shape, dtype=bool)
consψ.flat[0] = True # point constraint
arguments = System(sqr, trial='ψ').solve(arguments=arguments, constrain={'ψ': consψ})

Expand Down
5 changes: 5 additions & 0 deletions nutils/SI.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,11 @@ def __field(op, *args, **kwargs):
dims, args = zip(*Quantity.__unpack(*args)) # we abuse the fact that unpack str returns dimensionless
return functools.reduce(operator.mul, dims).wrap(op(*args, **kwargs))

@register('arguments_for')
def __attribute(op, *args, **kwargs):
__dims, args = zip(*Quantity.__unpack(*args))
return op(*args, **kwargs)

del register

## DEFINE OPERATORS
Expand Down
24 changes: 24 additions & 0 deletions nutils/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -2417,6 +2417,30 @@ def lower(self, args: LowerArgs) -> evaluable.Array:
return evaluable.prependaxes(self._array, args.points_shape)


@nutils_dispatch
def arguments_for(*arrays) -> Dict[str, Argument]:
'''Get all arguments that array(s) depend on.
Given any number of arrays, return a dictionary of all arguments involved,
mapping the name to the :class:`Argument` object. Raise a ``ValueError`` if
arrays have conflicting arguments, i.e. sharing a name but differing in
shape and/or dtype.
'''

arguments = {}
for array in arrays:
if isinstance(array, Array):
for name, (shape, dtype) in array._arguments.items():
argument = arguments.get(name)
if argument is None:
arguments[name] = Argument(name, shape, dtype)
elif argument.shape != shape:
raise ValueError(f'inconsistent shapes for argument {name!r}')
elif argument.dtype != dtype:
raise ValueError(f'inconsistent dtypes for argument {name!r}')
return arguments


# BASES


Expand Down
27 changes: 27 additions & 0 deletions tests/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -1450,3 +1450,30 @@ def test_lower_spaces(self):
topo, geom = mesh.rectilinear([3])
with self.assertRaisesRegex(ValueError, 'cannot lower function with spaces \(.+\) - did you forget integral or sample?'):
function.factor(geom)


class arguments_for(TestCase):

def test_single(self):
x = function.field('x', numpy.array([1,2,3]), shape=(2,), dtype=int)
f = x**2
self.assertEqual({a.name: (a.shape, a.dtype) for a in function.arguments_for(f).values()},
{'x': ((3,2), int)})

def test_multiple(self):
x = function.field('x', numpy.array([1,2,3]), shape=(2,), dtype=int)
y = function.field('y', numpy.array([4,5]), dtype=float)
z = function.field('z', dtype=complex)
f = x * y
g = x**2 * z
self.assertEqual({a.name: (a.shape, a.dtype) for a in function.arguments_for(f, g).values()},
{'x': ((3,2), int), 'y': ((2,), float), 'z': ((), complex)})

def test_conflict(self):
x = function.field('x', numpy.array([1,2,3]), shape=(2,), dtype=int)
y1 = function.field('y', numpy.array([4,5]), dtype=float)
y2 = function.field('y', dtype=complex)
f = x * y1
g = x**2 * y2
with self.assertRaisesRegex(ValueError, "inconsistent shapes for argument 'y'"):
function.arguments_for(f, g)

0 comments on commit 6bd65da

Please sign in to comment.