Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable stablehlo-complex-math-expander pass. #20853

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions xla/pjrt/mlir_to_hlo.cc
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ absl::Status MlirToXlaComputation(mlir::ModuleOp module,
mlir::BaseScopedDiagnosticHandler diagnostic_handler(context);
{
mlir::PassManager pm(context);
// Expand stablehlo complex math functions such as log_plus_one, etc.
pm.addNestedPass<mlir::func::FuncOp>(
mlir::stablehlo::createStablehloComplexMathExpanderPass());
pm.addPass(mlir::mhlo::createStablehloLegalizeToHloPass());
pm.addNestedPass<mlir::func::FuncOp>(
mlir::mhlo::createChloLegalizeToHloPass());
Expand Down Expand Up @@ -223,6 +226,10 @@ absl::StatusOr<std::string> SerializeUsingVersionedStablehlo(
// Legalize CHLO -> [StableHLO+Shape] -> StableHLO
// Preserve higher-level ops with XLA support. To be replaced by composites.
mlir::PassManager pm(context);
// Expand stablehlo complex math functions such as log_plus_one, etc.
pm.addNestedPass<mlir::func::FuncOp>(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need this one here as well? Not expanding during serialize will keep serialized artifacts smaller. I think all paths should go through the other function so we likely don't need this.

Copy link
Contributor Author

@pearu pearu Dec 25, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, apparently we do need it for CUDA path: removing the pass from SerializeUsingVersionedStablehlo will lead to previous behavior.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The difference of CPU and CUDA paths is as follows:

CPU:
xla::ifrt::PjRtLoadedExecutable::Create 
  -> xla::TfrtCpuClient::Compile 
  -> xla::MlirToXlaComputation

CUDA:
xla::ifrt::PjRtLoadedExecutable::Create 
  -> xla::PjRtCApiClient::Compile 
  -> xla::Serialize 
  -> xla::SerializeUsingVersionedStablehlo

that is, MlirToXlaComputation is never called when on CUDA path.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a todo comment for fixing this? This isn't expected. There should be a single path to XLA HLO and ideally should be the other method

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@GleasonK, re the todo comment, see issue #20903

mlir::stablehlo::createStablehloComplexMathExpanderPass());

xla::sdy::addSdyRoundTripExportPipeline(pm);
pm.addNestedPass<mlir::func::FuncOp>(
mlir::mhlo::createChloLegalizeToHighLevelMhloPass());
Expand Down