Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace #864

Merged
merged 5 commits into from
Mar 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
162 changes: 124 additions & 38 deletions nutils/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -778,6 +778,95 @@
return wrapper


class IDSetView:

def __init__(self, init=()):
self._dict = init._dict if isinstance(init, IDSetView) else {id(obj): obj for obj in init}

def __len__(self):
return len(self._dict)

def __bool__(self):
return bool(self._dict)

def __iter__(self):
return iter(self._dict.values())

def __and__(self, other):
return self.copy().__iand__(other)

def __or__(self, other):
return self.copy().__ior__(other)

def __sub__(self, other):
return self.copy().__isub__(other)

def isdisjoint(self, other):
return self._dict.isdisjoint(IDSetView(other))

Check warning on line 805 in nutils/_util.py

View check run for this annotation

Codecov / codecov/patch

nutils/_util.py#L805

Added line #L805 was not covered by tests

def intersection(self, other):
return self.__and__(IDSetView(other))

def difference(self, other):
return self.__sub__(IDSetView(other))

Check warning on line 811 in nutils/_util.py

View check run for this annotation

Codecov / codecov/patch

nutils/_util.py#L811

Added line #L811 was not covered by tests

def union(self, other):
return self.__or__(IDSetView(other))

def __repr__(self):
return '{' + ', '.join(map(repr, self)) + '}'

def copy(self):
return IDSet(self)


class IDSet(IDSetView):

def __init__(self, init=()):
self._dict = init._dict.copy() if isinstance(init, IDSetView) else {id(obj): obj for obj in init}

def __iand__(self, other):
if not isinstance(other, IDSetView):
return NotImplemented

Check warning on line 830 in nutils/_util.py

View check run for this annotation

Codecov / codecov/patch

nutils/_util.py#L830

Added line #L830 was not covered by tests
if not other._dict:
self._dict.clear()

Check warning on line 832 in nutils/_util.py

View check run for this annotation

Codecov / codecov/patch

nutils/_util.py#L832

Added line #L832 was not covered by tests
elif self._dict:
for k in set(self._dict) - set(other._dict):
del self._dict[k]
return self

def __ior__(self, other):
if not isinstance(other, IDSetView):
return NotImplemented

Check warning on line 840 in nutils/_util.py

View check run for this annotation

Codecov / codecov/patch

nutils/_util.py#L840

Added line #L840 was not covered by tests
self._dict.update(other._dict)
return self

def __isub__(self, other):
if not isinstance(other, IDSetView):
return NotImplemented

Check warning on line 846 in nutils/_util.py

View check run for this annotation

Codecov / codecov/patch

nutils/_util.py#L846

Added line #L846 was not covered by tests
for k in other._dict:
self._dict.pop(k, None)
return self

def add(self, obj):
self._dict[id(obj)] = obj

def pop(self):
return self._dict.popitem()[1]

def intersection_update(self, other):
self.__iand__(IDSetView(other))

def difference_update(self, other):
self.__isub__(IDSetView(other))

def update(self, other):
self.__ior__(IDSetView(other))

def view(self):
return IDSetView(self)


class IDDict:
'''Mapping from instance (is, not ==) to value. Keys need not be hashable.'''

Expand Down Expand Up @@ -821,11 +910,8 @@
def __contains__(self, key):
return self.__dict.__contains__(id(key))

def __str__(self):
return '{' + ', '.join(f'{k!r}: {v!r}' for k, v in self.items()) + '}'

def __repr__(self):
return self.__str__()
return '{' + ', '.join(f'{k!r}: {v!r}' for k, v in self.items()) + '}'


def _tuple(*args):
Expand Down Expand Up @@ -865,7 +951,7 @@
Args
----
func
Callable which maps an object onto a new object, or ``None`` if no
Callable which maps an object onto a new object, or onto itself if no
replacement is made. It must have precisely one positional argument for
the object.
'''
Expand All @@ -889,15 +975,15 @@
def __get__(self, obj, objtype=None):
fstack = [obj] # stack of unprocessed objects and command tokens
rstack = [] # stack of processed objects
ostack = [] # stack of original objects to cache new value into
ostack = IDSet() # stack of original objects to cache new value into

while fstack:
obj = fstack.pop()

if isinstance(obj, self.recreate): # recreate object from rstack
f, nargs = obj
r = f(*[rstack.pop() for _ in range(nargs)])
if isinstance(r, self.owner) and (newr := self.func(r)) is not None:
if isinstance(r, self.owner) and (newr := self.func(r)) is not r:
fstack.append(newr) # recursion
else:
rstack.append(r)
Expand All @@ -913,10 +999,9 @@
if (r := obj.__dict__.get(self.name)) is not None: # in cache
rstack.append(r if r is not self.identity else obj)
elif obj in ostack:
index = ostack.index(obj)
raise Exception(f'{type(obj).__name__}.{self.name} is caught in a loop of size {len(ostack)-index}')
raise Exception(f'{type(obj).__name__}.{self.name} is caught in a loop')

Check warning on line 1002 in nutils/_util.py

View check run for this annotation

Codecov / codecov/patch

nutils/_util.py#L1002

Added line #L1002 was not covered by tests
else:
ostack.append(obj)
ostack.add(obj)
fstack.append(ostack)
f, args = obj.__reduce__()
fstack.append(self.recreate(f, len(args)))
Expand All @@ -935,7 +1020,7 @@
return rstack[0]


def shallow_replace(func):
def shallow_replace(func, *funcargs, **funckwargs):
'''decorator for deep object replacement

Generates a deep replacement method for reduceable objects based on a
Expand All @@ -958,42 +1043,43 @@
The method that searches the object to perform the replacements.
'''

recreate = collections.namedtuple('recreate', ['f', 'nargs', 'orig'])
if not funcargs and not funckwargs: # decorator
# it would be nice to use partial here but then the decorator doesn't work with methods
return functools.wraps(func)(lambda *args, **kwargs: shallow_replace(func, *args, **kwargs))

@functools.wraps(func)
def wrapped(target, *funcargs, **funckwargs):
fstack = [target] # stack of unprocessed objects and command tokens
rstack = [] # stack of processed objects
cache = IDDict() # cache of seen objects
target, *funcargs = funcargs
recreate = collections.namedtuple('recreate', ['f', 'nargs', 'orig'])

while fstack:
obj = fstack.pop()
fstack = [target] # stack of unprocessed objects and command tokens
rstack = [] # stack of processed objects
cache = IDDict() # cache of seen objects

if isinstance(obj, recreate):
f, nargs, orig = obj
r = f(*[rstack.pop() for _ in range(nargs)])
cache[orig] = r
rstack.append(r)
while fstack:
obj = fstack.pop()

elif (r := cache.get(obj)) is not None:
rstack.append(r)
if isinstance(obj, recreate):
f, nargs, orig = obj
r = f(*[rstack.pop() for _ in range(nargs)])
cache[orig] = r
rstack.append(r)

elif (r := func(obj, *funcargs, **funckwargs)) is not None:
cache[obj] = r
rstack.append(r)
elif (r := cache.get(obj)) is not None:
rstack.append(r)

elif reduced := _reduce(obj):
f, args = reduced
fstack.append(recreate(f, len(args), obj))
fstack.extend(args)
elif (r := func(obj, *funcargs, **funckwargs)) is not None:
cache[obj] = r
rstack.append(r)

else: # obj cannot be reduced
rstack.append(obj)
elif reduced := _reduce(obj):
f, args = reduced
fstack.append(recreate(f, len(args), obj))
fstack.extend(args)

assert len(rstack) == 1
return rstack[0]
else: # obj cannot be reduced
rstack.append(obj)

return wrapped
assert len(rstack) == 1
return rstack[0]


# vim:sw=4:sts=4:et
8 changes: 6 additions & 2 deletions nutils/evaluable.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,9 @@ def _format_stack(self, values, e):
@util.deep_replace_property
def simplified(obj):
retval = obj._simplified()
if retval is not None and isinstance(obj, Array):
if retval is None:
return obj
if isinstance(obj, Array):
assert isinstance(retval, Array) and equalshape(retval.shape, obj.shape) and retval.dtype == obj.dtype, '{} --simplify--> {}'.format(obj, retval)
return retval

Expand All @@ -310,7 +312,9 @@ def optimized_for_numpy(self):
@util.deep_replace_property
def _optimized_for_numpy1(obj):
retval = obj._simplified() or obj._optimized_for_numpy()
if retval is not None and isinstance(obj, Array):
if retval is None:
return obj
if isinstance(obj, Array):
assert isinstance(retval, Array) and equalshape(retval.shape, obj.shape), '{0}._optimized_for_numpy or {0}._simplified resulted in shape change'.format(type(obj).__name__)
return retval

Expand Down
89 changes: 88 additions & 1 deletion tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,81 @@ def test_repr(self):
self.assertEqual(repr(self.d), "{'a': 1, 'b': 2}")


class IDSet(TestCase):

def setUp(self):
self.a, self.b, self.c = 'abc'
self.ab = util.IDSet([self.a, self.b])
self.ac = util.IDSet([self.a, self.c])

def test_union(self):
union = self.ab | self.ac
self.assertEqual(list(union), ['a', 'b', 'c'])
union = self.ac.union([self.a, self.b])
self.assertEqual(list(union), ['a', 'c', 'b'])

def test_union_update(self):
self.ab |= self.ac
self.assertEqual(list(self.ab), ['a', 'b', 'c'])
self.ac.update([self.a, self.b])
self.assertEqual(list(self.ac), ['a', 'c', 'b'])

def test_intersection(self):
intersection = self.ab & self.ac
self.assertEqual(list(intersection), ['a'])
intersection = self.ab.intersection([self.a, self.c])
self.assertEqual(list(intersection), ['a'])

def test_intersection_update(self):
self.ab &= self.ac
self.assertEqual(list(self.ab), ['a'])
self.ac.intersection_update([self.a, self.b])
self.assertEqual(list(self.ac), ['a'])

def test_difference(self):
difference = self.ab - self.ac
self.assertEqual(list(difference), ['b'])
difference = self.ac - self.ab
self.assertEqual(list(difference), ['c'])

def test_difference_update(self):
self.ab -= self.ac
self.assertEqual(list(self.ab), ['b'])
self.ac.difference_update([self.a, self.b])
self.assertEqual(list(self.ac), ['c'])

def test_add(self):
self.ab.add(self.a)
self.assertEqual(list(self.ab), ['a', 'b'])
self.ab.add(self.c)
self.assertEqual(list(self.ab), ['a', 'b', 'c'])
self.ac.add(self.b)
self.assertEqual(list(self.ac), ['a', 'c', 'b'])

def test_pop(self):
self.assertEqual(self.ab.pop(), 'b')
self.assertEqual(list(self.ab), ['a'])

def test_copy(self):
copy = self.ab.copy()
self.ab.pop()
self.assertEqual(list(self.ab), ['a'])
self.assertEqual(list(copy), ['a', 'b'])

def test_view(self):
view = self.ab.view()
self.ab.pop()
self.assertEqual(list(view), ['a'])
with self.assertRaises(AttributeError):
view.pop()

def test_str(self):
self.assertEqual(str(self.ab), "{'a', 'b'}")

def test_repr(self):
self.assertEqual(repr(self.ab), "{'a', 'b'}")


class replace(TestCase):

class Base:
Expand All @@ -465,8 +540,10 @@ def simple(self):
self.called = True
if isinstance(self, replace.Ten):
return replace.Intermediate() # to test recursion
if isinstance(self, replace.Intermediate):
elif isinstance(self, replace.Intermediate):
return 10
else:
return self

class Ten(Base): pass
class Intermediate(Base): pass
Expand Down Expand Up @@ -501,3 +578,13 @@ def test_shallow_nested(self):
newobj = self.subs10(obj, 20)
self.assertEqual(type(newobj), type(obj))
self.assertEqual(newobj.args, (5, {7, 20}))

def test_shallow_direct(self):
ten = self.Ten()
obj = self.Base(5, {7, ten})
def subs(arg):
if isinstance(arg, self.Ten):
return 20
newobj = util.shallow_replace(subs, obj)
self.assertEqual(type(newobj), type(obj))
self.assertEqual(newobj.args, (5, {7, 20}))
Loading