Skip to content

Commit

Permalink
replace _deep_flatten_consts
Browse files Browse the repository at this point in the history
  • Loading branch information
joostvanzwieten committed Mar 2, 2024
1 parent b557cc4 commit 520fd9a
Showing 1 changed file with 16 additions and 26 deletions.
42 changes: 16 additions & 26 deletions nutils/evaluable.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ def dependencies(self):
for func in self.__args:
funcdeps = func.dependencies
deps.update(funcdeps)
deps[func] = len(funcdeps)
deps[func] = not func.isconstant, len(funcdeps)
return types.frozendict(deps)

@cached_property
Expand Down Expand Up @@ -336,11 +336,9 @@ def serialized(self):
# property should be treated as if it is immutable.
@cached_property
def _serialized_evalf_head(self):
return tuple(op.evalf for op in self.ordereddeps[1:])

@property
def _serialized_evalf(self):
return zip(itertools.chain(self._serialized_evalf_head, (self.evalf,)), self.dependencytree[1:])
serialized = tuple(zip((op.evalf for op in self.ordereddeps[1:]), self.dependencytree[1:-1]))
nconsts = builtins.sum(itertools.takewhile(lambda v: v, map(operator.attrgetter('isconstant'), self.ordereddeps[1:])))
return serialized, nconsts

def _node(self, cache, subgraph, times):
if self in cache:
Expand All @@ -366,14 +364,25 @@ def eval(self, **evalargs):
'''Evaluate function on a specified element, point set.'''

values = [evalargs]
serialized_head, nconsts = self._serialized_evalf_head
if nconsts and (consts := getattr(self, '_eval_consts', None)):
assert len(consts) == nconsts
if len(serialized_head) + 1 == nconsts:
# self is constant
return consts[-1]

Check warning on line 372 in nutils/evaluable.py

View check run for this annotation

Codecov / codecov/patch

nutils/evaluable.py#L372

Added line #L372 was not covered by tests
values.extend(consts)
serialized_head = serialized_head[nconsts:]
try:
values.extend(op_evalf(*[values[i] for i in indices]) for op_evalf, indices in self._serialized_evalf)
values.extend(op_evalf(*[values[i] for i in indices]) for op_evalf, indices in serialized_head)
values.append(self.evalf(*[values[i] for i in self.dependencytree[-1]]))
except KeyboardInterrupt:
raise
except Exception as e:
log.error(self._format_stack(values, e))
raise
else:
if nconsts and consts is None:
self._eval_consts = values[1:1 + nconsts]
return values[-1]

def eval_withtimes(self, times, **evalargs):
Expand Down Expand Up @@ -449,7 +458,6 @@ def _simplified(self):
@cached_property
def optimized_for_numpy(self):
retval = self.simplified._optimized_for_numpy1() or self
retval = retval._deep_flatten_constants() or retval
return retval._combine_loop_concatenates(frozenset())

@replace(depthfirst=True, recursive=True)
Expand All @@ -463,11 +471,6 @@ def _optimized_for_numpy1(obj):
def _optimized_for_numpy(self):
return

@replace(depthfirst=False, recursive=False)
def _deep_flatten_constants(self):
if isinstance(self, Array):
return self._flatten_constant()

@cached_property
def _loop_concatenate_deps(self):
deps = []
Expand Down Expand Up @@ -1003,10 +1006,6 @@ def _const_uniform(self):
lower, upper = self._intbounds
return lower if lower == upper else None

def _flatten_constant(self):
if self.isconstant:
return constant(self.eval())


class Orthonormal(Array):
'make a vector orthonormal to a subspace'
Expand Down Expand Up @@ -1187,9 +1186,6 @@ def _const_uniform(self):
if self.ndim == 0:
return self.dtype(self.value[()])

def _flatten_constant(self):
pass


class InsertAxis(Array):

Expand Down Expand Up @@ -1309,9 +1305,6 @@ def _intbounds_impl(self):
def _const_uniform(self):
return self.func._const_uniform

def _flatten_constant(self):
pass


class Transpose(Array):

Expand Down Expand Up @@ -1507,9 +1500,6 @@ def _intbounds_impl(self):
def _const_uniform(self):
return self.func._const_uniform

def _flatten_constant(self):
pass


class Product(Array):

Expand Down

0 comments on commit 520fd9a

Please sign in to comment.