Skip to content

Commit

Permalink
Strip all info except for shape and dtype while matching input and ou…
Browse files Browse the repository at this point in the history
…tput avals for donation. For avals that are sharded, we match on shard_shape.

PiperOrigin-RevId: 722917186
  • Loading branch information
yashk2810 authored and Google-ML-Automation committed Feb 4, 2025
1 parent 6281b86 commit 654a2f6
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions jax/_src/interpreters/mlir.py
Original file line number Diff line number Diff line change
Expand Up @@ -935,7 +935,7 @@ def sharded_aval(aval: core.AbstractValue,
return aval
if not isinstance(aval, (core.ShapedArray, core.DShapedArray)):
raise NotImplementedError
return aval.update(sharding.shard_shape(aval.shape), sharding=None) # type: ignore
return aval.update(sharding.shard_shape(aval.shape)) # type: ignore


def eval_dynamic_shape(ctx: LoweringRuleContext,
Expand Down Expand Up @@ -1244,7 +1244,8 @@ def _set_up_aliases(input_output_aliases, avals_in, avals_out,
input_output_aliases = list(input_output_aliases)
# To match-up in-avals to out-avals we only care about the number of
# bytes, so we strip off unrelated aval metadata (eg. the named shape)
strip_metadata = lambda a: a.strip_weak_type()
strip_metadata = lambda a: (a if a is core.abstract_token else
core.ShapedArray(a.shape, a.dtype))
avals_in = map(strip_metadata, avals_in)
avals_out = map(strip_metadata, avals_out)

Expand Down

0 comments on commit 654a2f6

Please sign in to comment.