Skip to content

Commit

Permalink
rm function.replace, add deep and shallow variants
Browse files Browse the repository at this point in the history
This patch removes the general function.replace decorator and replaces
it with the specialized _util.deep_replace_property (used for simplified
and optimized_for_numpy) and _util.shallow_replace (used for
replace_arguments, _deep_flatten_constants and
_combine_loop_concatenates). The differences between the two constructs
are as follows:

@_util.deep_replace_property
- property
- intermediate values cached in object attribute
- depth first
- recursive

@_util.shallow_replace
- function
- intermediate values cached only during replacement
- depth last
- non recursive
  • Loading branch information
Gertjan van Zwieten committed Mar 20, 2024
1 parent 43b9439 commit 5d6cbca
Show file tree
Hide file tree
Showing 3 changed files with 236 additions and 163 deletions.
167 changes: 167 additions & 0 deletions nutils/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -778,4 +778,171 @@ def wrapper(*args, **kwargs):
return wrapper


def _reduce(obj):
'helper function for deep_replace_property and shallow_replace'

if obj.__class__ in (tuple, list, dict, set, frozenset):
if not obj: # empty containers need not be entered
raise Exception('refusing to enter empty container')
return lambda *x, T=type(obj): T(x), tuple(obj if not isinstance(obj, dict) else obj.items())
return obj.__reduce__()


class deep_replace_property:
'''decorator for deep object replacement
Generates a cached property for deep replacement of reduceable objects,
based on a callable that is applied depth first and recursively on
individual constructor arguments. Intermediate values are stored in the
attribute by the same name of any object that is a decendent of the class
that owns the property.
Args
----
func
Callable which maps an object onto a new object, or `None` if no
replacement is made. It must have precisely one positional argument for
the object.
'''

def __init__(self, func):
self.func = func

def __set_name__(self, owner, name):
self.owner = owner
self.name = name

def __set__(self, obj, value):
raise AttributeError("can't set attribute")

Check warning on line 816 in nutils/_util.py

View check run for this annotation

Codecov / codecov/patch

nutils/_util.py#L816

Added line #L816 was not covered by tests

def __delete__(self, obj):
raise AttributeError("can't delete attribute")

Check warning on line 819 in nutils/_util.py

View check run for this annotation

Codecov / codecov/patch

nutils/_util.py#L819

Added line #L819 was not covered by tests

def __get__(self, obj, objtype=None):
fstack = [obj] # stack of unprocessed objects and command tokens
rstack = [] # stack of processed objects
recreate = [] # stack of constructor functions
originals = [] # stack of original objects to cache new value into

while fstack:
obj = fstack.pop()

if obj is recreate: # recreate object from rstack
f, nargs = recreate.pop()
r = f(*[rstack.pop() for _ in range(nargs)])
if isinstance(r, self.owner) and (newr := self.func(r)) is not None:
fstack.append(newr) # recursion
else:
rstack.append(r)

elif obj is originals: # store new representation
orig = originals.pop()
r = rstack[-1]
if r is orig: # this may happen if obj is memoizing
r = None # prevent cyclic reference
orig.__dict__[self.name] = r

elif isinstance(obj, self.owner) and self.name in obj.__dict__: # in cache
r = obj.__dict__[self.name]
rstack.append(r if r is not None else obj)

else:
if isinstance(obj, self.owner):
if obj in originals:
index = originals.index(obj)
raise Exception(f'{type(instance).__name__}.{self.name} is caught in a loop of size {len(originals)-index}')

Check warning on line 853 in nutils/_util.py

View check run for this annotation

Codecov / codecov/patch

nutils/_util.py#L852-L853

Added lines #L852 - L853 were not covered by tests
originals.append(obj)
fstack.append(originals)
try:
f, args = _reduce(obj)
except: # obj cannot be reduced into a constructor and its arguments
rstack.append(obj)
else:
recreate.append((f, len(args)))
fstack.append(recreate)
fstack.extend(args)

assert not recreate
assert not originals
assert len(rstack) == 1
return rstack[0]


def shallow_replace(func):
'''decorator for deep object replacement
Generates a deep replacement method for reduceable objects based on a
callable that is applied on individual constructor arguments. The
replacement takes a shallow first approach and stops as soon as the
callable returns a value that is not `None`. Intermediate values are
flushed upon return.
Args
----
func
Callable which maps an object onto a new object, or `None` if no
replacement is made. It must have one positional argument for the object,
and may have any number of additional positional and/or keyword
arguments.
Returns
-------
:any:`callable`
The method that searches the object to perform the replacements.
'''

@functools.wraps(func)
def wrapped(target, *funcargs, **funckwargs):
fstack = [target] # stack of unprocessed objects and command tokens
rstack = [] # stack of processed objects
recreate = [] # stack of constructor functions
originals = [] # stack of original objects to cache new value into
cache = {} # cache of seen objects

while fstack:
obj = fstack.pop()

if obj is recreate:
f, nargs = recreate.pop()
r = f(*[rstack.pop() for _ in range(nargs)])
rstack.append(r)

elif obj is originals:
orig = originals.pop()
cache[type(orig), orig] = rstack[-1]

else:
try:
r = cache[type(obj), obj] # including type to avoid 1 == 1. etc
except KeyError: # object can be cached but isn't
originals.append(obj)
fstack.append(originals)
except TypeError: # object is not hashable
pass
else: # object is in cache
rstack.append(r)
continue

newr = func(obj, *funcargs, **funckwargs)
if newr is not None:
rstack.append(newr)
continue

try:
f, args = _reduce(obj)
except: # obj cannot be reduced into a constructor and its arguments
rstack.append(obj)
else:
recreate.append((f, len(args)))
fstack.append(recreate)
fstack.extend(args)

assert not recreate
assert not originals
assert len(rstack) == 1
return rstack[0]

return wrapped


# vim:sw=4:sts=4:et
180 changes: 17 additions & 163 deletions nutils/evaluable.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@
import collections.abc
import math
import treelog as log
import weakref
import time
import contextlib
import subprocess
Expand Down Expand Up @@ -129,149 +128,6 @@ class ExpensiveEvaluationWarning(warnings.NutilsInefficiencyWarning):
pass


def replace(func=None, depthfirst=False, recursive=False, lru=4):
'''decorator for deep object replacement
Generates a deep replacement method for general objects based on a callable
that is applied (recursively) on individual constructor arguments.
Args
----
func
Callable which maps an object onto a new object, or `None` if no
replacement is made. It must have one positional argument for the object,
and may have any number of additional positional and/or keyword
arguments.
depthfirst : :class:`bool`
If `True`, decompose each object as far a possible, then apply `func` to
all arguments as the objects are reconstructed. Otherwise apply `func`
directly on each new object that is encountered in the decomposition,
proceding only if the return value is `None`.
recursive : :class:`bool`
If `True`, repeat replacement for any object returned by `func` until it
returns `None`. Otherwise perform a single, non-recursive sweep.
lru : :class:`int`
Maximum size of the least-recently-used cache. A persistent weak-key
dictionary is maintained for every unique set of function arguments. When
the size of `lru` is reached, the least recently used cache is dropped.
Returns
-------
:any:`callable`
The method that searches the object to perform the replacements.
'''

if func is None:
return functools.partial(replace, depthfirst=depthfirst, recursive=recursive, lru=lru)

signature = inspect.signature(func)
arguments = [] # list of past function arguments, least recently used last
caches = [] # list of weak-key dictionaries matching arguments (above)

remember = object() # token to signal that rstack[-1] can be cached as the replacement of fstack[-1]
recreate = object() # token to signal that all arguments for object recreation are ready on rstack
pending = object() # token to hold the place of a cachable object pending creation
identity = object() # token to hold the place of the cache value in case it matches key, to avoid circular references

@functools.wraps(func)
def wrapped(target, *funcargs, **funckwargs):

# retrieve or create a weak-key dictionary
bound = signature.bind(None, *funcargs, **funckwargs)
bound.apply_defaults()
try:
index = arguments.index(bound.arguments) # by using index, arguments need not be hashable
except ValueError:
index = -1
cache = weakref.WeakKeyDictionary()
else:
cache = caches[index]
if index != 0: # function arguments are not the most recent (possibly new)
if index > 0 or len(arguments) >= lru:
caches.pop(index) # pop matching (or oldest) item
arguments.pop(index)
caches.insert(0, cache) # insert popped (or new) item to front
arguments.insert(0, bound.arguments)

fstack = [target] # stack of unprocessed objects and command tokens
rstack = [] # stack of processed objects
_stack = fstack if recursive else rstack

try:
while fstack:
obj = fstack.pop()

if obj is recreate:
args = [rstack.pop() for obj in range(fstack.pop())]
f = fstack.pop()
r = f(*args)
if depthfirst:
newr = func(r, *funcargs, **funckwargs)
if newr is not None:
_stack.append(newr)
continue
rstack.append(r)
continue

if obj is remember:
obj = fstack.pop()
cache[obj] = rstack[-1] if rstack[-1] is not obj else identity
continue

if obj.__class__ in (tuple, list, dict, set, frozenset):
if not obj:
rstack.append(obj) # shortcut to avoid recreation of empty container
else:
fstack.append(lambda *x, T=type(obj): T(x))
fstack.append(len(obj))
fstack.append(recreate)
fstack.extend(obj if not isinstance(obj, dict) else obj.items())
continue

try:
r = cache[obj]
except KeyError: # object can be weakly cached, but isn't
cache[obj] = pending
fstack.append(obj)
fstack.append(remember)
except TypeError: # object cannot be referenced or is not hashable
pass
else: # object is in cache
if r is pending:
pending_objs = tuple(k for k, v in cache.items() if v is pending)
index = pending_objs.index(obj)
raise Exception('{}@replace caught in a circular dependence\n'.format(func.__name__) + Tuple(pending_objs[index:]).asciitree().split('\n', 1)[1])
rstack.append(r if r is not identity else obj)
continue

if not depthfirst:
newr = func(obj, *funcargs, **funckwargs)
if newr is not None:
_stack.append(newr)
continue

try:
f, args = obj.__reduce__()
except: # obj cannot be reduced into a constructor and its arguments
rstack.append(obj)
else:
fstack.append(f)
fstack.append(len(args))
fstack.append(recreate)
fstack.extend(args)

assert len(rstack) == 1

finally:
while fstack:
if fstack.pop() is remember:
assert cache.pop(fstack.pop()) is pending

return rstack[0]

return wrapped


class Evaluable(types.Singleton):
'Base class'

Expand Down Expand Up @@ -434,36 +290,34 @@ def _format_stack(self, values, e):
lines.append(f'{next(stack)} --> {e}')
return '\n '.join(lines)

@property
@replace(depthfirst=True, recursive=True)
@util.deep_replace_property
def simplified(obj):
if isinstance(obj, Evaluable):
retval = obj._simplified()
if retval is not None and isinstance(obj, Array):
assert isinstance(retval, Array) and equalshape(retval.shape, obj.shape) and retval.dtype == obj.dtype, '{} --simplify--> {}'.format(obj, retval)
return retval
retval = obj._simplified()
if retval is not None and isinstance(obj, Array):
assert isinstance(retval, Array) and equalshape(retval.shape, obj.shape) and retval.dtype == obj.dtype, '{} --simplify--> {}'.format(obj, retval)
return retval

def _simplified(self):
return

@cached_property
def optimized_for_numpy(self):
retval = self.simplified._optimized_for_numpy1() or self
retval = retval._deep_flatten_constants() or retval
return retval._combine_loop_concatenates(frozenset())
return self.simplified \
._optimized_for_numpy1 \
._deep_flatten_constants() \
._combine_loop_concatenates(frozenset())

@replace(depthfirst=True, recursive=True)
@util.deep_replace_property
def _optimized_for_numpy1(obj):
if isinstance(obj, Evaluable):
retval = obj._simplified() or obj._optimized_for_numpy()
if retval is not None and isinstance(obj, Array):
assert isinstance(retval, Array) and equalshape(retval.shape, obj.shape), '{0}._optimized_for_numpy or {0}._simplified resulted in shape change'.format(type(obj).__name__)
return retval
retval = obj._simplified() or obj._optimized_for_numpy()
if retval is not None and isinstance(obj, Array):
assert isinstance(retval, Array) and equalshape(retval.shape, obj.shape), '{0}._optimized_for_numpy or {0}._simplified resulted in shape change'.format(type(obj).__name__)
return retval

def _optimized_for_numpy(self):
return

@replace(depthfirst=False, recursive=False)
@util.shallow_replace
def _deep_flatten_constants(self):
if isinstance(self, Array):
return self._flatten_constant()
Expand Down Expand Up @@ -511,7 +365,7 @@ def _combine_loop_concatenates(self, outer_exclude):
intbounds = dict(zip(('_lower', '_upper'), lc._intbounds)) if lc.dtype == int else {}
replacements[lc] = ArrayFromTuple(combined, i, lc.shape, lc.dtype, **intbounds)
if replacements:
self = replace(lambda key: replacements.get(key) if isinstance(key, LoopConcatenate) else None, recursive=False, depthfirst=False)(self)
self = util.shallow_replace(lambda key: replacements.get(key) if isinstance(key, LoopConcatenate) else None)(self)
else:
return self

Expand Down Expand Up @@ -5086,7 +4940,7 @@ def loop_concatenate_combined(funcs, index):
return tuple(ArrayFromTuple(loop, unique_funcs.index(func), tuple(shape), func.dtype) for func, start, stop, *shape in unique_func_data)


@replace
@util.shallow_replace
def replace_arguments(value, arguments):
'''Replace :class:`Argument` objects in ``value``.
Expand Down
Loading

0 comments on commit 5d6cbca

Please sign in to comment.