Skip to content

Commit

Permalink
Deprecate Array.arguments and Array.argshapes
Browse files Browse the repository at this point in the history
This patch deprecates the .arguments and .argshapes attributes of the
function.Array object in favour of the newly introduced function.arguments_for
function.
  • Loading branch information
gertjanvanzwieten committed Feb 4, 2025
1 parent 6bd65da commit be8bc53
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 42 deletions.
64 changes: 35 additions & 29 deletions nutils/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ def __init__(self, shape: Shape, dtype: DType, spaces: FrozenSet[str], arguments
self.shape = tuple(sh.__index__() for sh in shape)
self.dtype = dtype
self.spaces = frozenset(spaces)
self.arguments = dict(arguments)
self._arguments = dict(arguments)

def lower(self, args: LowerArgs) -> evaluable.Array:
raise NotImplementedError
Expand All @@ -236,8 +236,8 @@ def as_evaluable_array(self) -> evaluable.Array:
return self.lower(LowerArgs((), {}, {}))

def __index__(self):
if self.arguments or self.spaces:
raise ValueError('cannot convert non-constant array to index: arguments={}'.format(','.join(self.arguments)))
if self._arguments or self.spaces:
raise ValueError('cannot convert non-constant array to index: arguments={}'.format(','.join(self._arguments)))
elif self.ndim:
raise ValueError('cannot convert non-scalar array to index: shape={}'.format(self.shape))
elif self.dtype != int:
Expand Down Expand Up @@ -483,11 +483,17 @@ def replace(self, __arguments: Mapping[str, IntoArray]) -> 'Array':

def contains(self, __name: str) -> bool:
'Test if target occurs in this function.'
return __name in self.arguments
return __name in self._arguments

@property
def arguments(self) -> Mapping[str, Tuple[Shape, DType]]:
warnings.deprecation('array.arguments is deprecated and will be removed in Nutils 10, please use function.arguments_for(array) instead')
return self._arguments.copy()

Check warning on line 491 in nutils/function.py

View workflow job for this annotation

GitHub Actions / Test coverage

Lines not covered

Lines 490–491 of `nutils/function.py` are not covered by tests.

@property
def argshapes(self) -> Mapping[str, Tuple[int, ...]]:
return {name: shape for name, (shape, dtype) in self.arguments.items()}
warnings.deprecation("array.argshapes[...] is deprecated and will be removed in Nutils 10, please use function.arguments_for(array)[...].shape instead")
return {name: shape for name, (shape, dtype) in self._arguments.items()}

Check warning on line 496 in nutils/function.py

View workflow job for this annotation

GitHub Actions / Test coverage

Lines not covered

Lines 495–496 of `nutils/function.py` are not covered by tests.

def conjugate(self):
'''Return the complex conjugate, elementwise.
Expand Down Expand Up @@ -684,15 +690,15 @@ def __init__(self, args: Iterable[Any], shape: Tuple[int], dtype: DType, npointw
self._args = args
self._npointwise = npointwise
spaces = frozenset(space for arg in args if isinstance(arg, Array) for space in arg.spaces)
arguments = _join_arguments(arg.arguments for arg in args if isinstance(arg, Array))
arguments = _join_arguments(arg._arguments for arg in args if isinstance(arg, Array))
super().__init__(shape=(*points_shape, *shape), dtype=dtype, spaces=spaces, arguments=arguments)

def lower(self, args: LowerArgs) -> evaluable.Array:
evalargs = tuple(arg.lower(args) if isinstance(arg, Array) else arg for arg in self._args)
add_points_shape = tuple(map(evaluable.asarray, self.shape[:self._npointwise]))
points_shape = args.points_shape + add_points_shape
coordinates = {space: evaluable.Transpose.to_end(evaluable.appendaxes(coords, add_points_shape), coords.ndim-1) for space, coords in args.coordinates.items()}
return _CustomEvaluable(type(self).__name__, self.evalf, self.partial_derivative, evalargs, self.shape[self._npointwise:], self.dtype, self.spaces, types.frozendict(self.arguments), LowerArgs(points_shape, types.frozendict(args.transform_chains), types.frozendict(coordinates)))
return _CustomEvaluable(type(self).__name__, self.evalf, self.partial_derivative, evalargs, self.shape[self._npointwise:], self.dtype, self.spaces, types.frozendict(self._arguments), LowerArgs(points_shape, types.frozendict(args.transform_chains), types.frozendict(coordinates)))

@classmethod
def evalf(cls, *args: Any) -> numpy.ndarray:
Expand Down Expand Up @@ -833,7 +839,7 @@ class _WithoutPoints:
def __init__(self, __arg: Array) -> None:
self._arg = __arg
self.spaces = __arg.spaces
self.arguments = __arg.arguments
self._arguments = __arg._arguments

def lower(self, args: LowerArgs) -> evaluable.Array:
return self._arg.lower(LowerArgs((), args.transform_chains, {}))
Expand All @@ -851,7 +857,7 @@ def __init__(self, lower: Callable[..., evaluable.Array], *args: Lowerable, shap
self._args = args
assert all(hasattr(arg, 'lower') for arg in self._args)
spaces = frozenset(space for arg in args for space in arg.spaces)
arguments = _join_arguments(arg.arguments for arg in self._args)
arguments = _join_arguments(arg._arguments for arg in self._args)
super().__init__(shape, dtype, spaces, arguments)

def lower(self, args: LowerArgs) -> evaluable.Array:
Expand Down Expand Up @@ -916,8 +922,8 @@ def __init__(self, arg: Array, replacements: Dict[str, Array]) -> None:
raise ValueError(f'replacement functions cannot be bound to a space, but replacement for Argument {old.name!r} is bound to {", ".join(new.spaces)}.')
self._replacements[old.name] = new
# Build arguments map with replacements.
unreplaced = {name: shape_dtype for name, shape_dtype in arg.arguments.items() if name not in replacements}
arguments = _join_arguments([unreplaced] + [replacement.arguments for replacement in self._replacements.values()])
unreplaced = {name: shape_dtype for name, shape_dtype in arg._arguments.items() if name not in replacements}
arguments = _join_arguments([unreplaced] + [replacement._arguments for replacement in self._replacements.values()])
super().__init__(arg.shape, arg.dtype, arg.spaces, arguments)

def lower(self, args: LowerArgs) -> evaluable.Array:
Expand Down Expand Up @@ -950,7 +956,7 @@ def to_end(cls, array: Array, *axes: int) -> Array:
def __init__(self, arg: Array, axes: Tuple[int, ...]) -> None:
self._arg = arg
self._axes = tuple(n.__index__() for n in axes)
super().__init__(tuple(arg.shape[axis] for axis in axes), arg.dtype, arg.spaces, arg.arguments)
super().__init__(tuple(arg.shape[axis] for axis in axes), arg.dtype, arg.spaces, arg._arguments)

def lower(self, args: LowerArgs) -> evaluable.Array:
arg = self._arg.lower(args)
Expand All @@ -964,7 +970,7 @@ class _Opposite(Array):
def __init__(self, arg: Array, space: str) -> None:
self._arg = arg
self._space = space
super().__init__(arg.shape, arg.dtype, arg.spaces, arg.arguments)
super().__init__(arg.shape, arg.dtype, arg.spaces, arg._arguments)

def lower(self, args: LowerArgs) -> evaluable.Array:
oppargs = LowerArgs(args.points_shape, dict(args.transform_chains), args.coordinates)
Expand Down Expand Up @@ -1031,7 +1037,7 @@ def __init__(self, arg: Array, var: Argument) -> None:
self._arg = arg
self._var = var
self._eval_var = evaluable.Argument(var.name, tuple(evaluable.constant(n) for n in var.shape), var.dtype)
arguments = _join_arguments((arg.arguments, var.arguments))
arguments = _join_arguments((arg._arguments, var._arguments))
super().__init__(arg.shape+var.shape, complex if var.dtype == complex else arg.dtype, arg.spaces | var.spaces, arguments)

def lower(self, args: LowerArgs) -> evaluable.Array:
Expand All @@ -1056,7 +1062,7 @@ def __init__(self, func: Array, geom: Array) -> None:
common_shape = broadcast_shapes(func.shape, geom.shape[:-1])
self._func = numpy.broadcast_to(func, common_shape)
self._geom = numpy.broadcast_to(geom, (*common_shape, geom.shape[-1]))
arguments = _join_arguments((func.arguments, geom.arguments))
arguments = _join_arguments((func._arguments, geom._arguments))
super().__init__(self._geom.shape, complex if func.dtype == complex else float, func.spaces | geom.spaces, arguments)

def lower(self, args: LowerArgs) -> evaluable.Array:
Expand All @@ -1082,7 +1088,7 @@ def __init__(self, func: Array, geom: Array) -> None:
common_shape = broadcast_shapes(func.shape, geom.shape[:-1])
self._func = numpy.broadcast_to(func, common_shape)
self._geom = numpy.broadcast_to(geom, (*common_shape, geom.shape[-1]))
arguments = _join_arguments((func.arguments, geom.arguments))
arguments = _join_arguments((func._arguments, geom._arguments))
super().__init__(self._geom.shape, complex if func.dtype == complex else float, func.spaces | geom.spaces, arguments)

def lower(self, args: LowerArgs) -> evaluable.Array:
Expand Down Expand Up @@ -1112,7 +1118,7 @@ def __init__(self, geom: Array, tip_dim: Optional[int] = None) -> None:
'not greater than the dimension of the geometry.')
self._tip_dim = tip_dim
self._geom = geom
super().__init__((), float, geom.spaces, geom.arguments)
super().__init__((), float, geom.spaces, geom._arguments)

def lower(self, args: LowerArgs) -> evaluable.Array:
geom = self._geom.lower(args)
Expand All @@ -1133,7 +1139,7 @@ class _Normal(Array):
def __init__(self, geom: Array) -> None:
self._geom = geom
assert geom.dtype == float
super().__init__(geom.shape, float, geom.spaces, geom.arguments)
super().__init__(geom.shape, float, geom.spaces, geom._arguments)

def lower(self, args: LowerArgs) -> evaluable.Array:
geom = self._geom.lower(args)
Expand Down Expand Up @@ -1168,7 +1174,7 @@ class _ExteriorNormal(Array):
def __init__(self, rgrad: Array) -> None:
assert rgrad.dtype == float and rgrad.shape[-2] == rgrad.shape[-1] + 1
self._rgrad = rgrad
super().__init__(rgrad.shape[:-1], float, rgrad.spaces, rgrad.arguments)
super().__init__(rgrad.shape[:-1], float, rgrad.spaces, rgrad._arguments)

def lower(self, args: LowerArgs) -> evaluable.Array:
rgrad = self._rgrad.lower(args)
Expand All @@ -1195,7 +1201,7 @@ def __init__(self, __arrays: Sequence[IntoArray], axis: int) -> None:
shape=(*shape0[:self.axis], builtins.sum(array.shape[self.axis] for array in self.arrays), *shape0[self.axis+1:]),
dtype=self.arrays[0].dtype,
spaces=functools.reduce(operator.or_, (array.spaces for array in self.arrays)),
arguments=_join_arguments(array.arguments for array in self.arrays))
arguments=_join_arguments(array._arguments for array in self.arrays))

def lower(self, args: LowerArgs) -> evaluable.Array:
return util.sum(evaluable._inflate(array.lower(args), evaluable.Range(evaluable.constant(array.shape[self.axis])) + offset, evaluable.constant(self.shape[self.axis]), self.axis-self.ndim)
Expand Down Expand Up @@ -1700,14 +1706,14 @@ def _argument_to_array(d: Any, array: Array) -> Iterable[Tuple[Argument, Array]]
arg, new = item.split(':', 1) if isinstance(item, str) else item

if isinstance(arg, str):
if arg not in array.arguments:
if arg not in array._arguments:
continue
arg = Argument(arg, *array.arguments[arg])
arg = Argument(arg, *array._arguments[arg])
elif not isinstance(arg, Argument):
raise ValueError('Key must be string or argument')
elif arg.name not in arguments:
continue
elif array.arguments[arg.name] != (arg.shape, arg.dtype):
elif array._arguments[arg.name] != (arg.shape, arg.dtype):
raise ValueError(f'Argument {arg.name!r} has wrong shape or dtype')

Check warning on line 1717 in nutils/function.py

View workflow job for this annotation

GitHub Actions / Test coverage

Lines not covered

Lines 1712–1717 of `nutils/function.py` are not covered by tests.

if isinstance(new, str):
Expand Down Expand Up @@ -1852,14 +1858,14 @@ def derivative(__arg: IntoArray, __var: Union[str, 'Argument']) -> Array:

arg = Array.cast(__arg)
if isinstance(__var, str):
if __var not in arg.arguments:
if __var not in arg._arguments:
raise ValueError('no such argument: {}'.format(__var))
shape, dtype = arg.arguments[__var]
shape, dtype = arg._arguments[__var]
__var = Argument(__var, shape, dtype=dtype)
elif not isinstance(__var, Argument):
raise ValueError('Expected an instance of `Argument` as second argument of `derivative` but got a `{}.{}`.'.format(type(__var).__module__, type(__var).__qualname__))
if __var.name in arg.arguments:
shape, dtype = arg.arguments[__var.name]
if __var.name in arg._arguments:
shape, dtype = arg._arguments[__var.name]
if __var.shape != shape:
raise ValueError('Argument {!r} has shape {} in the function, but the derivative to {!r} with shape {} was requested.'.format(__var.name, shape, __var.name, __var.shape))
if __var.dtype != dtype:
Expand Down Expand Up @@ -2411,7 +2417,7 @@ class factor(Array):

def __init__(self, array: Array) -> None:
self._array = evaluable.factor(array)
super().__init__(shape=array.shape, dtype=array.dtype, spaces=set(), arguments=array.arguments)
super().__init__(shape=array.shape, dtype=array.dtype, spaces=set(), arguments=array._arguments)

def lower(self, args: LowerArgs) -> evaluable.Array:
return evaluable.prependaxes(self._array, args.points_shape)
Expand Down Expand Up @@ -2502,7 +2508,7 @@ def __init__(self, ndofs: int, nelems: int, index: Array, coords: Array) -> None
self.nelems = nelems
self.index = Array.cast(index, dtype=int, ndim=0)
self.coords = coords
arguments = _join_arguments((index.arguments, coords.arguments))
arguments = _join_arguments((index._arguments, coords._arguments))
super().__init__((ndofs,), float, spaces=index.spaces | coords.spaces, arguments=arguments)

_index = evaluable.Argument('_index', shape=(), dtype=int)
Expand Down
6 changes: 3 additions & 3 deletions nutils/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -913,7 +913,7 @@ class _Integral(function.Array):
def __init__(self, integrand: function.Array, sample: Sample) -> None:
self._integrand = integrand
self._sample = sample
super().__init__(shape=integrand.shape, dtype=float if integrand.dtype in (bool, int) else integrand.dtype, spaces=integrand.spaces - frozenset(sample.spaces), arguments=integrand.arguments)
super().__init__(shape=integrand.shape, dtype=float if integrand.dtype in (bool, int) else integrand.dtype, spaces=integrand.spaces - frozenset(sample.spaces), arguments=integrand._arguments)

def lower(self, args: function.LowerArgs) -> evaluable.Array:
ielem = evaluable.loop_index('_sample_' + '_'.join(self._sample.spaces), self._sample.nelems)
Expand All @@ -928,7 +928,7 @@ class _ConcatenatePoints(function.Array):
def __init__(self, func: function.Array, sample: _TransformChainsSample) -> None:
self._func = func
self._sample = sample
super().__init__(shape=(sample.npoints, *func.shape), dtype=func.dtype, spaces=func.spaces - frozenset(sample.spaces), arguments=func.arguments)
super().__init__(shape=(sample.npoints, *func.shape), dtype=func.dtype, spaces=func.spaces - frozenset(sample.spaces), arguments=func._arguments)

def lower(self, args: function.LowerArgs) -> evaluable.Array:
axis = len(args.points_shape)
Expand All @@ -948,7 +948,7 @@ def __init__(self, func: function.Array, indices: evaluable.Array) -> None:
self._func = func
self._indices = indices
assert indices.ndim == 1 and func.shape[0] == indices.shape[0].__index__()
super().__init__(shape=func.shape, dtype=func.dtype, spaces=func.spaces, arguments=func.arguments)
super().__init__(shape=func.shape, dtype=func.dtype, spaces=func.spaces, arguments=func._arguments)

def lower(self, args: function.LowerArgs) -> evaluable.Array:
func = self._func.lower(args)
Expand Down
5 changes: 2 additions & 3 deletions nutils/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -1002,9 +1002,8 @@ def _split_trial_test(target):
def _target_helper(target, *args):
trial, test = _split_trial_test(target)
if test is not None:
arguments = function._join_arguments(arg.arguments for arg in args)
testargs = [function.Argument(t, *arguments[t]) for t in test]
args = [map(arg.derivative, testargs) for arg in args]
arguments = function.arguments_for(*args)
args = [[arg.derivative(arguments[t]) for t in test] for arg in args]
elif len(args) > 1:
shapes = [{f.shape for f in ziparg if f is not None} for ziparg in zip(*args)]
if any(len(arg) != len(shapes) for arg in args) or any(len(shape) != 1 for shape in shapes):
Expand Down
10 changes: 5 additions & 5 deletions tests/test_expression_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,27 +560,27 @@ def test_define_for_3d(self):
def test_add_single_field(self):
ns = expression_v2.Namespace()
ns.add_field('u', numpy.array([1,2,3]))
self.assertEqual(ns.u.argshapes, dict(u=(3,)))
self.assertEqual({a.name: a.shape for a in function.arguments_for(ns.u).values()}, dict(u=(3,)))
self.assertEqual(ns.u.shape, ())

def test_add_multiple_fields(self):
ns = expression_v2.Namespace()
ns.add_field(('u', 'v'), numpy.array([1,2,3]))
self.assertEqual(ns.u.argshapes, dict(u=(3,)))
self.assertEqual({a.name: a.shape for a in function.arguments_for(ns.u).values()}, dict(u=(3,)))
self.assertEqual(ns.u.shape, ())
self.assertEqual(ns.v.argshapes, dict(v=(3,)))
self.assertEqual({a.name: a.shape for a in function.arguments_for(ns.v).values()}, dict(v=(3,)))
self.assertEqual(ns.v.shape, ())

def test_add_single_field_multiple_bases(self):
ns = expression_v2.Namespace()
ns.add_field('u', numpy.array([1,2,3]), numpy.array([4,5,6,7]))
self.assertEqual(ns.u.argshapes, dict(u=(3,4)))
self.assertEqual({a.name: a.shape for a in function.arguments_for(ns.u).values()}, dict(u=(3,4)))
self.assertEqual(ns.u.shape, ())

def test_add_single_field_with_shape(self):
ns = expression_v2.Namespace()
ns.add_field('u', numpy.array([1,2,3]), shape=(2,))
self.assertEqual(ns.u.argshapes, dict(u=(3,2)))
self.assertEqual({a.name: a.shape for a in function.arguments_for(ns.u).values()}, dict(u=(3,2)))
self.assertEqual(ns.u.shape, (2,))

def test_copy(self):
Expand Down
4 changes: 2 additions & 2 deletions tests/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def test_argshapes(self):
a = function.Argument('a', (2, 3), dtype=int)
b = function.Argument('b', (3,), dtype=int)
f = (a * b[None]).sum(-1)
self.assertEqual(dict(f.argshapes), dict(a=(2, 3), b=(3,)))
self.assertEqual({a.name: a.shape for a in function.arguments_for(f).values()}, dict(a=(2, 3), b=(3,)))

def test_argshapes_shape_mismatch(self):
with self.assertRaises(Exception):
Expand Down Expand Up @@ -425,7 +425,7 @@ def test(self):
f = function._Unlower(e, frozenset(), arguments, function.LowerArgs((2, 3), {}, {}))
self.assertEqual(f.shape, (4, 5))
self.assertEqual(f.dtype, int)
self.assertEqual(f.arguments, arguments)
self.assertEqual(f._arguments, arguments)
self.assertEqual(f.lower(function.LowerArgs((2, 3), {}, {})), e)
with self.assertRaises(ValueError):
f.lower(function.LowerArgs((3, 4), {}, {}))
Expand Down

0 comments on commit be8bc53

Please sign in to comment.