From cabd790a06d842f72905ae4949cdce29553beda2 Mon Sep 17 00:00:00 2001 From: Gertjan van Zwieten Date: Fri, 22 Mar 2024 13:26:12 +0100 Subject: [PATCH 1/5] remove superfluous implementation of IDDict.__str__ --- nutils/_util.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/nutils/_util.py b/nutils/_util.py index 138bb28e3..8f0016f6f 100644 --- a/nutils/_util.py +++ b/nutils/_util.py @@ -821,11 +821,8 @@ def __iter__(self): def __contains__(self, key): return self.__dict.__contains__(id(key)) - def __str__(self): - return '{' + ', '.join(f'{k!r}: {v!r}' for k, v in self.items()) + '}' - def __repr__(self): - return self.__str__() + return '{' + ', '.join(f'{k!r}: {v!r}' for k, v in self.items()) + '}' def _tuple(*args): From c9a9415732f8a177dfa5090ee457c8b0c3590f6c Mon Sep 17 00:00:00 2001 From: Gertjan van Zwieten Date: Fri, 22 Mar 2024 12:15:50 +0100 Subject: [PATCH 2/5] add _util.IDSet --- nutils/_util.py | 89 ++++++++++++++++++++++++++++++++++++++++++++++ tests/test_util.py | 75 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 164 insertions(+) diff --git a/nutils/_util.py b/nutils/_util.py index 8f0016f6f..453c3db47 100644 --- a/nutils/_util.py +++ b/nutils/_util.py @@ -778,6 +778,95 @@ def wrapper(*args, **kwargs): return wrapper +class IDSetView: + + def __init__(self, init=()): + self._dict = init._dict if isinstance(init, IDSetView) else {id(obj): obj for obj in init} + + def __len__(self): + return len(self._dict) + + def __bool__(self): + return bool(self._dict) + + def __iter__(self): + return iter(self._dict.values()) + + def __and__(self, other): + return self.copy().__iand__(other) + + def __or__(self, other): + return self.copy().__ior__(other) + + def __sub__(self, other): + return self.copy().__isub__(other) + + def isdisjoint(self, other): + return self._dict.isdisjoint(IDSetView(other)) + + def intersection(self, other): + return self.__and__(IDSetView(other)) + + def difference(self, other): + return self.__sub__(IDSetView(other)) + + def union(self, other): + return self.__or__(IDSetView(other)) + + def __repr__(self): + return '{' + ', '.join(map(repr, self)) + '}' + + def copy(self): + return IDSet(self) + + +class IDSet(IDSetView): + + def __init__(self, init=()): + self._dict = init._dict.copy() if isinstance(init, IDSetView) else {id(obj): obj for obj in init} + + def __iand__(self, other): + if not isinstance(other, IDSetView): + return NotImplemented + if not other._dict: + self._dict.clear() + elif self._dict: + for k in set(self._dict) - set(other._dict): + del self._dict[k] + return self + + def __ior__(self, other): + if not isinstance(other, IDSetView): + return NotImplemented + self._dict.update(other._dict) + return self + + def __isub__(self, other): + if not isinstance(other, IDSetView): + return NotImplemented + for k in other._dict: + self._dict.pop(k, None) + return self + + def add(self, obj): + self._dict[id(obj)] = obj + + def pop(self): + return self._dict.popitem()[1] + + def intersection_update(self, other): + self.__iand__(IDSetView(other)) + + def difference_update(self, other): + self.__isub__(IDSetView(other)) + + def update(self, other): + self.__ior__(IDSetView(other)) + + def view(self): + return IDSetView(self) + + class IDDict: '''Mapping from instance (is, not ==) to value. Keys need not be hashable.''' diff --git a/tests/test_util.py b/tests/test_util.py index 4ea82b926..db834d82a 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -451,6 +451,81 @@ def test_repr(self): self.assertEqual(repr(self.d), "{'a': 1, 'b': 2}") +class IDSet(TestCase): + + def setUp(self): + self.a, self.b, self.c = 'abc' + self.ab = util.IDSet([self.a, self.b]) + self.ac = util.IDSet([self.a, self.c]) + + def test_union(self): + union = self.ab | self.ac + self.assertEqual(list(union), ['a', 'b', 'c']) + union = self.ac.union([self.a, self.b]) + self.assertEqual(list(union), ['a', 'c', 'b']) + + def test_union_update(self): + self.ab |= self.ac + self.assertEqual(list(self.ab), ['a', 'b', 'c']) + self.ac.update([self.a, self.b]) + self.assertEqual(list(self.ac), ['a', 'c', 'b']) + + def test_intersection(self): + intersection = self.ab & self.ac + self.assertEqual(list(intersection), ['a']) + intersection = self.ab.intersection([self.a, self.c]) + self.assertEqual(list(intersection), ['a']) + + def test_intersection_update(self): + self.ab &= self.ac + self.assertEqual(list(self.ab), ['a']) + self.ac.intersection_update([self.a, self.b]) + self.assertEqual(list(self.ac), ['a']) + + def test_difference(self): + difference = self.ab - self.ac + self.assertEqual(list(difference), ['b']) + difference = self.ac - self.ab + self.assertEqual(list(difference), ['c']) + + def test_difference_update(self): + self.ab -= self.ac + self.assertEqual(list(self.ab), ['b']) + self.ac.difference_update([self.a, self.b]) + self.assertEqual(list(self.ac), ['c']) + + def test_add(self): + self.ab.add(self.a) + self.assertEqual(list(self.ab), ['a', 'b']) + self.ab.add(self.c) + self.assertEqual(list(self.ab), ['a', 'b', 'c']) + self.ac.add(self.b) + self.assertEqual(list(self.ac), ['a', 'c', 'b']) + + def test_pop(self): + self.assertEqual(self.ab.pop(), 'b') + self.assertEqual(list(self.ab), ['a']) + + def test_copy(self): + copy = self.ab.copy() + self.ab.pop() + self.assertEqual(list(self.ab), ['a']) + self.assertEqual(list(copy), ['a', 'b']) + + def test_view(self): + view = self.ab.view() + self.ab.pop() + self.assertEqual(list(view), ['a']) + with self.assertRaises(AttributeError): + view.pop() + + def test_str(self): + self.assertEqual(str(self.ab), "{'a', 'b'}") + + def test_repr(self): + self.assertEqual(repr(self.ab), "{'a', 'b'}") + + class replace(TestCase): class Base: From 292cad6f1c72191ee31dbba24a5f3de0347fc99a Mon Sep 17 00:00:00 2001 From: Gertjan van Zwieten Date: Fri, 22 Mar 2024 12:17:01 +0100 Subject: [PATCH 3/5] use IDSet in deep_replace_property This patch replaces the ostack list in deep_replace_property by an IDSet for faster contains checks. --- nutils/_util.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/nutils/_util.py b/nutils/_util.py index 453c3db47..c40bbb440 100644 --- a/nutils/_util.py +++ b/nutils/_util.py @@ -975,7 +975,7 @@ def __delete__(self, obj): def __get__(self, obj, objtype=None): fstack = [obj] # stack of unprocessed objects and command tokens rstack = [] # stack of processed objects - ostack = [] # stack of original objects to cache new value into + ostack = IDSet() # stack of original objects to cache new value into while fstack: obj = fstack.pop() @@ -999,10 +999,9 @@ def __get__(self, obj, objtype=None): if (r := obj.__dict__.get(self.name)) is not None: # in cache rstack.append(r if r is not self.identity else obj) elif obj in ostack: - index = ostack.index(obj) - raise Exception(f'{type(obj).__name__}.{self.name} is caught in a loop of size {len(ostack)-index}') + raise Exception(f'{type(obj).__name__}.{self.name} is caught in a loop') else: - ostack.append(obj) + ostack.add(obj) fstack.append(ostack) f, args = obj.__reduce__() fstack.append(self.recreate(f, len(args))) From 7f7cceb5b7295e92efbd6bf26a78619d73804c30 Mon Sep 17 00:00:00 2001 From: Gertjan van Zwieten Date: Fri, 22 Mar 2024 12:32:10 +0100 Subject: [PATCH 4/5] support non-decorator use of _util.shallow_replace --- nutils/_util.py | 57 +++++++++++++++++++++++----------------------- tests/test_util.py | 10 ++++++++ 2 files changed, 39 insertions(+), 28 deletions(-) diff --git a/nutils/_util.py b/nutils/_util.py index c40bbb440..a3658502e 100644 --- a/nutils/_util.py +++ b/nutils/_util.py @@ -1020,7 +1020,7 @@ def __get__(self, obj, objtype=None): return rstack[0] -def shallow_replace(func): +def shallow_replace(func, *funcargs, **funckwargs): '''decorator for deep object replacement Generates a deep replacement method for reduceable objects based on a @@ -1043,42 +1043,43 @@ def shallow_replace(func): The method that searches the object to perform the replacements. ''' - recreate = collections.namedtuple('recreate', ['f', 'nargs', 'orig']) + if not funcargs and not funckwargs: # decorator + # it would be nice to use partial here but then the decorator doesn't work with methods + return functools.wraps(func)(lambda *args, **kwargs: shallow_replace(func, *args, **kwargs)) - @functools.wraps(func) - def wrapped(target, *funcargs, **funckwargs): - fstack = [target] # stack of unprocessed objects and command tokens - rstack = [] # stack of processed objects - cache = IDDict() # cache of seen objects + target, *funcargs = funcargs + recreate = collections.namedtuple('recreate', ['f', 'nargs', 'orig']) - while fstack: - obj = fstack.pop() + fstack = [target] # stack of unprocessed objects and command tokens + rstack = [] # stack of processed objects + cache = IDDict() # cache of seen objects - if isinstance(obj, recreate): - f, nargs, orig = obj - r = f(*[rstack.pop() for _ in range(nargs)]) - cache[orig] = r - rstack.append(r) + while fstack: + obj = fstack.pop() - elif (r := cache.get(obj)) is not None: - rstack.append(r) + if isinstance(obj, recreate): + f, nargs, orig = obj + r = f(*[rstack.pop() for _ in range(nargs)]) + cache[orig] = r + rstack.append(r) - elif (r := func(obj, *funcargs, **funckwargs)) is not None: - cache[obj] = r - rstack.append(r) + elif (r := cache.get(obj)) is not None: + rstack.append(r) - elif reduced := _reduce(obj): - f, args = reduced - fstack.append(recreate(f, len(args), obj)) - fstack.extend(args) + elif (r := func(obj, *funcargs, **funckwargs)) is not None: + cache[obj] = r + rstack.append(r) - else: # obj cannot be reduced - rstack.append(obj) + elif reduced := _reduce(obj): + f, args = reduced + fstack.append(recreate(f, len(args), obj)) + fstack.extend(args) - assert len(rstack) == 1 - return rstack[0] + else: # obj cannot be reduced + rstack.append(obj) - return wrapped + assert len(rstack) == 1 + return rstack[0] # vim:sw=4:sts=4:et diff --git a/tests/test_util.py b/tests/test_util.py index db834d82a..ec05d7cbe 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -576,3 +576,13 @@ def test_shallow_nested(self): newobj = self.subs10(obj, 20) self.assertEqual(type(newobj), type(obj)) self.assertEqual(newobj.args, (5, {7, 20})) + + def test_shallow_direct(self): + ten = self.Ten() + obj = self.Base(5, {7, ten}) + def subs(arg): + if isinstance(arg, self.Ten): + return 20 + newobj = util.shallow_replace(subs, obj) + self.assertEqual(type(newobj), type(obj)) + self.assertEqual(newobj.args, (5, {7, 20})) From 052536d4ab5f6b3f9d206a77f33db8af06eca1c0 Mon Sep 17 00:00:00 2001 From: Gertjan van Zwieten Date: Fri, 22 Mar 2024 12:37:44 +0100 Subject: [PATCH 5/5] change _util.deep_replace_property return value This patch changes the definition of the function decorated by _util.deep_replace_property, from returning None if no replacement is made, to returning the object itself. The new definition is a more natural choice for a recursive procedure (recursion stops when a trivial loop is reached) and results in better readable code for the casual observer. --- nutils/_util.py | 4 ++-- nutils/evaluable.py | 8 ++++++-- tests/test_util.py | 4 +++- 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/nutils/_util.py b/nutils/_util.py index a3658502e..0854e6a77 100644 --- a/nutils/_util.py +++ b/nutils/_util.py @@ -951,7 +951,7 @@ class deep_replace_property: Args ---- func - Callable which maps an object onto a new object, or ``None`` if no + Callable which maps an object onto a new object, or onto itself if no replacement is made. It must have precisely one positional argument for the object. ''' @@ -983,7 +983,7 @@ def __get__(self, obj, objtype=None): if isinstance(obj, self.recreate): # recreate object from rstack f, nargs = obj r = f(*[rstack.pop() for _ in range(nargs)]) - if isinstance(r, self.owner) and (newr := self.func(r)) is not None: + if isinstance(r, self.owner) and (newr := self.func(r)) is not r: fstack.append(newr) # recursion else: rstack.append(r) diff --git a/nutils/evaluable.py b/nutils/evaluable.py index b3a417ec0..5093bcaba 100644 --- a/nutils/evaluable.py +++ b/nutils/evaluable.py @@ -293,7 +293,9 @@ def _format_stack(self, values, e): @util.deep_replace_property def simplified(obj): retval = obj._simplified() - if retval is not None and isinstance(obj, Array): + if retval is None: + return obj + if isinstance(obj, Array): assert isinstance(retval, Array) and equalshape(retval.shape, obj.shape) and retval.dtype == obj.dtype, '{} --simplify--> {}'.format(obj, retval) return retval @@ -310,7 +312,9 @@ def optimized_for_numpy(self): @util.deep_replace_property def _optimized_for_numpy1(obj): retval = obj._simplified() or obj._optimized_for_numpy() - if retval is not None and isinstance(obj, Array): + if retval is None: + return obj + if isinstance(obj, Array): assert isinstance(retval, Array) and equalshape(retval.shape, obj.shape), '{0}._optimized_for_numpy or {0}._simplified resulted in shape change'.format(type(obj).__name__) return retval diff --git a/tests/test_util.py b/tests/test_util.py index ec05d7cbe..07927f3e3 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -540,8 +540,10 @@ def simple(self): self.called = True if isinstance(self, replace.Ten): return replace.Intermediate() # to test recursion - if isinstance(self, replace.Intermediate): + elif isinstance(self, replace.Intermediate): return 10 + else: + return self class Ten(Base): pass class Intermediate(Base): pass