Skip to content

Commit

Permalink
[IFRT] Apply pass that merges multiple reshards into a single one whe…
Browse files Browse the repository at this point in the history
…n they have the same source and destination.

PiperOrigin-RevId: 717613221
  • Loading branch information
jupvfranco authored and Google-ML-Automation committed Jan 20, 2025
1 parent 8784fe8 commit 03260eb
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions xla/python/ifrt/ir/transforms/passes.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ void CreateIfrtToOutlinedAtomProgramsPipeline(

if (!options.propagate_shardings) {
pm.addPass(CreateIfrtVerifyShardingSpecifiedPass());
pm.addNestedPass<mlir::func::FuncOp>(
xla::ifrt::CreateIfrtMergeReshardsPass());
// We can split ifrt.Reshard to ifrt.CopyArrays because all the shardings
// are specified.
pm.addPass(CreateIfrtReshardToCopyArraysPass());
Expand Down

0 comments on commit 03260eb

Please sign in to comment.