From 1256ceb266311f365b42b5ce15b6e62fb8e20502 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Tue, 11 Jun 2024 06:59:16 -0700 Subject: [PATCH] [Mosaic GPU] Rearrange the pass pipeline (again) PiperOrigin-RevId: 642256145 --- jaxlib/mosaic/gpu/custom_call.cc | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/jaxlib/mosaic/gpu/custom_call.cc b/jaxlib/mosaic/gpu/custom_call.cc index 47072e659a76..d67270354643 100644 --- a/jaxlib/mosaic/gpu/custom_call.cc +++ b/jaxlib/mosaic/gpu/custom_call.cc @@ -105,9 +105,9 @@ mlir::FailureOr GetPassPipeline( mlir::registerConvertFuncToLLVMPass(); mlir::registerConvertAffineToStandard(); mlir::registerReconcileUnrealizedCasts(); - mlir::registerGpuToLLVMConversionPass(); // TODO(apaszke): Only register the passes we actually use. mlir::memref::registerMemRefPasses(); + mlir::registerConvertToLLVMPass(); mlir::registerGPUPasses(); mosaic::gpu::registerGpuLaunchLoweringPass(); mosaic::gpu::registerConvertGpuToLLVMPass(); @@ -140,11 +140,12 @@ mlir::FailureOr GetPassPipeline( convert-math-to-llvm{approximate-log1p=true}, canonicalize{max-iterations=10 max-num-rewrites=-1 region-simplify=true test-convergence=false top-down=true}, cse, - reconcile-unrealized-casts,)" + + )" + (target != mlir::gpu::CompilationTarget::Assembly ? "gpu-launch-lowering," : "") + R"( - convert-func-to-llvm{index-bitwidth=0 use-bare-ptr-memref-call-conv=false} + convert-to-llvm, + reconcile-unrealized-casts ) )"); } @@ -170,9 +171,9 @@ void InitContext(mlir::MLIRContext* context) { mlir::registerConvertFuncToLLVMInterface(registry); mlir::index::registerConvertIndexToLLVMInterface(registry); mlir::cf::registerConvertControlFlowToLLVMInterface(registry); - mlir::ub::registerConvertUBToLLVMInterface(registry); // Arith needs this + mlir::ub::registerConvertUBToLLVMInterface(registry); mlir::arith::registerConvertArithToLLVMInterface(registry); - mlir::registerFinalizeMemRefToLLVMConversionPass(); + mlir::registerConvertMemRefToLLVMInterface(registry); mlir::gpu::registerOffloadingLLVMTranslationInterfaceExternalModels(registry); mlir::NVVM::registerNVVMTargetInterfaceExternalModels(registry); mlir::registerBuiltinDialectTranslation(registry);