Skip to content

Commit

Permalink
replace function.replace by util variants
Browse files Browse the repository at this point in the history
This patch removes the function.replace decorator and replaces it with
_util.deep_replace_property (for simplified and optimized_for_numpy) and
_util.shallow_replace (for replace_arguments, _deep_flatten_constants
and _combine_loop_concatenates). The specialized variants are faster,
use less memory, and offer better cache reuse.
  • Loading branch information
Gertjan van Zwieten committed Mar 21, 2024
1 parent f102f75 commit 507f0d6
Showing 1 changed file with 17 additions and 163 deletions.
180 changes: 17 additions & 163 deletions nutils/evaluable.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@
import collections.abc
import math
import treelog as log
import weakref
import time
import contextlib
import subprocess
Expand Down Expand Up @@ -129,149 +128,6 @@ class ExpensiveEvaluationWarning(warnings.NutilsInefficiencyWarning):
pass


def replace(func=None, depthfirst=False, recursive=False, lru=4):
'''decorator for deep object replacement
Generates a deep replacement method for general objects based on a callable
that is applied (recursively) on individual constructor arguments.
Args
----
func
Callable which maps an object onto a new object, or `None` if no
replacement is made. It must have one positional argument for the object,
and may have any number of additional positional and/or keyword
arguments.
depthfirst : :class:`bool`
If `True`, decompose each object as far a possible, then apply `func` to
all arguments as the objects are reconstructed. Otherwise apply `func`
directly on each new object that is encountered in the decomposition,
proceding only if the return value is `None`.
recursive : :class:`bool`
If `True`, repeat replacement for any object returned by `func` until it
returns `None`. Otherwise perform a single, non-recursive sweep.
lru : :class:`int`
Maximum size of the least-recently-used cache. A persistent weak-key
dictionary is maintained for every unique set of function arguments. When
the size of `lru` is reached, the least recently used cache is dropped.
Returns
-------
:any:`callable`
The method that searches the object to perform the replacements.
'''

if func is None:
return functools.partial(replace, depthfirst=depthfirst, recursive=recursive, lru=lru)

signature = inspect.signature(func)
arguments = [] # list of past function arguments, least recently used last
caches = [] # list of weak-key dictionaries matching arguments (above)

remember = object() # token to signal that rstack[-1] can be cached as the replacement of fstack[-1]
recreate = object() # token to signal that all arguments for object recreation are ready on rstack
pending = object() # token to hold the place of a cachable object pending creation
identity = object() # token to hold the place of the cache value in case it matches key, to avoid circular references

@functools.wraps(func)
def wrapped(target, *funcargs, **funckwargs):

# retrieve or create a weak-key dictionary
bound = signature.bind(None, *funcargs, **funckwargs)
bound.apply_defaults()
try:
index = arguments.index(bound.arguments) # by using index, arguments need not be hashable
except ValueError:
index = -1
cache = weakref.WeakKeyDictionary()
else:
cache = caches[index]
if index != 0: # function arguments are not the most recent (possibly new)
if index > 0 or len(arguments) >= lru:
caches.pop(index) # pop matching (or oldest) item
arguments.pop(index)
caches.insert(0, cache) # insert popped (or new) item to front
arguments.insert(0, bound.arguments)

fstack = [target] # stack of unprocessed objects and command tokens
rstack = [] # stack of processed objects
_stack = fstack if recursive else rstack

try:
while fstack:
obj = fstack.pop()

if obj is recreate:
args = [rstack.pop() for obj in range(fstack.pop())]
f = fstack.pop()
r = f(*args)
if depthfirst:
newr = func(r, *funcargs, **funckwargs)
if newr is not None:
_stack.append(newr)
continue
rstack.append(r)
continue

if obj is remember:
obj = fstack.pop()
cache[obj] = rstack[-1] if rstack[-1] is not obj else identity
continue

if obj.__class__ in (tuple, list, dict, set, frozenset):
if not obj:
rstack.append(obj) # shortcut to avoid recreation of empty container
else:
fstack.append(lambda *x, T=type(obj): T(x))
fstack.append(len(obj))
fstack.append(recreate)
fstack.extend(obj if not isinstance(obj, dict) else obj.items())
continue

try:
r = cache[obj]
except KeyError: # object can be weakly cached, but isn't
cache[obj] = pending
fstack.append(obj)
fstack.append(remember)
except TypeError: # object cannot be referenced or is not hashable
pass
else: # object is in cache
if r is pending:
pending_objs = tuple(k for k, v in cache.items() if v is pending)
index = pending_objs.index(obj)
raise Exception('{}@replace caught in a circular dependence\n'.format(func.__name__) + Tuple(pending_objs[index:]).asciitree().split('\n', 1)[1])
rstack.append(r if r is not identity else obj)
continue

if not depthfirst:
newr = func(obj, *funcargs, **funckwargs)
if newr is not None:
_stack.append(newr)
continue

try:
f, args = obj.__reduce__()
except: # obj cannot be reduced into a constructor and its arguments
rstack.append(obj)
else:
fstack.append(f)
fstack.append(len(args))
fstack.append(recreate)
fstack.extend(args)

assert len(rstack) == 1

finally:
while fstack:
if fstack.pop() is remember:
assert cache.pop(fstack.pop()) is pending

return rstack[0]

return wrapped


class Evaluable(types.Singleton):
'Base class'

Expand Down Expand Up @@ -434,36 +290,34 @@ def _format_stack(self, values, e):
lines.append(f'{next(stack)} --> {e}')
return '\n '.join(lines)

@property
@replace(depthfirst=True, recursive=True)
@util.deep_replace_property
def simplified(obj):
if isinstance(obj, Evaluable):
retval = obj._simplified()
if retval is not None and isinstance(obj, Array):
assert isinstance(retval, Array) and equalshape(retval.shape, obj.shape) and retval.dtype == obj.dtype, '{} --simplify--> {}'.format(obj, retval)
return retval
retval = obj._simplified()
if retval is not None and isinstance(obj, Array):
assert isinstance(retval, Array) and equalshape(retval.shape, obj.shape) and retval.dtype == obj.dtype, '{} --simplify--> {}'.format(obj, retval)
return retval

def _simplified(self):
return

@cached_property
def optimized_for_numpy(self):
retval = self.simplified._optimized_for_numpy1() or self
retval = retval._deep_flatten_constants() or retval
return retval._combine_loop_concatenates(frozenset())
return self.simplified \
._optimized_for_numpy1 \
._deep_flatten_constants() \
._combine_loop_concatenates(frozenset())

@replace(depthfirst=True, recursive=True)
@util.deep_replace_property
def _optimized_for_numpy1(obj):
if isinstance(obj, Evaluable):
retval = obj._simplified() or obj._optimized_for_numpy()
if retval is not None and isinstance(obj, Array):
assert isinstance(retval, Array) and equalshape(retval.shape, obj.shape), '{0}._optimized_for_numpy or {0}._simplified resulted in shape change'.format(type(obj).__name__)
return retval
retval = obj._simplified() or obj._optimized_for_numpy()
if retval is not None and isinstance(obj, Array):
assert isinstance(retval, Array) and equalshape(retval.shape, obj.shape), '{0}._optimized_for_numpy or {0}._simplified resulted in shape change'.format(type(obj).__name__)
return retval

def _optimized_for_numpy(self):
return

@replace(depthfirst=False, recursive=False)
@util.shallow_replace
def _deep_flatten_constants(self):
if isinstance(self, Array):
return self._flatten_constant()
Expand Down Expand Up @@ -511,7 +365,7 @@ def _combine_loop_concatenates(self, outer_exclude):
intbounds = dict(zip(('_lower', '_upper'), lc._intbounds)) if lc.dtype == int else {}
replacements[lc] = ArrayFromTuple(combined, i, lc.shape, lc.dtype, **intbounds)
if replacements:
self = replace(lambda key: replacements.get(key) if isinstance(key, LoopConcatenate) else None, recursive=False, depthfirst=False)(self)
self = util.shallow_replace(lambda key: replacements.get(key) if isinstance(key, LoopConcatenate) else None)(self)
else:
return self

Expand Down Expand Up @@ -5087,7 +4941,7 @@ def loop_concatenate_combined(funcs, index):
return tuple(ArrayFromTuple(loop, unique_funcs.index(func), tuple(shape), func.dtype) for func, start, stop, *shape in unique_func_data)


@replace
@util.shallow_replace
def replace_arguments(value, arguments):
'''Replace :class:`Argument` objects in ``value``.
Expand Down

0 comments on commit 507f0d6

Please sign in to comment.