From e75953d85ae416a13feeeb67fa5860225daa5744 Mon Sep 17 00:00:00 2001
From: Jhalak Patel
Date: Wed, 6 Nov 2024 10:17:29 -0800
Subject: [PATCH] Address review comments

- Report errors through emitError() diagnostics instead of printing to
  llvm::errs() in EliminateShapeOps and NetworkEncoder.
- Drop the unused "enable-non-dps-returns" pipeline option from the
  plan-bufferize-pipeline and test-executor-bufferization-pipeline
  registrations.
- Replace the misspelled MTRT_RuntimeClient typedef in the runtime C API
  header with a plain forward declaration.
- Make the client argument of RuntimeSession.execute_function optional in
  the Python bindings and update the integration tests accordingly.

---
 .../TensorRTRuntimeToExecutor.cpp             |  2 +-
 .../Plan/Transforms/EliminateShapeOps.cpp     |  9 ++----
 .../lib/Dialect/Plan/Transforms/Passes.cpp    | 30 ++++++------------
 .../include/mlir-executor-c/Runtime/Runtime.h |  2 +-
 .../test/lib/BufferizationTestPass.cpp        | 31 +++++++------------
 .../python/bindings/Runtime/RuntimePyBind.cpp | 10 +++---
 .../NetworkEncoder.cpp                        |  8 ++---
 .../TRT10/test_stablehlo_add.py               |  6 +---
 .../IntegrationTests/test_call_validation.py  |  6 +---
 .../test_executable_serialize.py              |  8 ++---
 .../IntegrationTests/test_stablehlo_add.py    |  8 ++---
 .../test_stablehlo_dynamic.py                 |  7 ++---
 .../test_runtime_debug_dump.py                |  4 +--
 13 files changed, 45 insertions(+), 86 deletions(-)

diff --git a/mlir-tensorrt/compiler/lib/Conversion/TensorRTRuntimeToExecutor/TensorRTRuntimeToExecutor.cpp b/mlir-tensorrt/compiler/lib/Conversion/TensorRTRuntimeToExecutor/TensorRTRuntimeToExecutor.cpp
index 43cc64276..ec947acc3 100644
--- a/mlir-tensorrt/compiler/lib/Conversion/TensorRTRuntimeToExecutor/TensorRTRuntimeToExecutor.cpp
+++ b/mlir-tensorrt/compiler/lib/Conversion/TensorRTRuntimeToExecutor/TensorRTRuntimeToExecutor.cpp
@@ -379,7 +379,7 @@ struct ConvertEnqueueAllocToCall
 
     // Create output memrefs from output descriptors
     SmallVector<Value> results;
-    unsigned offset = 1;
+    unsigned offset = 1; // Skip the leading field holding the number of results.
     for (unsigned i = 0; i < op->getNumResults(); ++i) {
       unsigned rank = cast<MemRefType>(op->getResult(i).getType()).getRank();
       Value rankOffset = b.create(
diff --git a/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/EliminateShapeOps.cpp b/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/EliminateShapeOps.cpp
index 21759d805..a1aed10b4 100644
--- a/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/EliminateShapeOps.cpp
+++ b/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/EliminateShapeOps.cpp
@@ -113,14 +113,11 @@ static LogicalResult removeUnusedArgs(SymbolTableCollection &collection,
         call.getInputsMutable().erase(i);
       else if (auto callAlloc = dyn_cast<CallAllocOp>(callOp))
        callAlloc.getInputsMutable().erase(i);
-      else {
-        llvm::errs() << "Unexpected operation type in callOps\n";
-        callOp->dump();
-        return failure();
-      }
+      else
+        return emitError(funcOp->getLoc())
+               << "unexpected operation type in callOps";
     }
   }
-
   return success();
 }
 
diff --git a/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/Passes.cpp b/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/Passes.cpp
index 753bea88e..c9ce415aa 100644
--- a/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/Passes.cpp
+++ b/mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/Passes.cpp
@@ -103,31 +103,21 @@ struct ClusteringPipelineCliOpts
       llvm::cl::init(NV_TENSORRT_MAJOR)};
 };
 
-struct PlanBufferizationPipelineCliOpts
-    : public PassPipelineOptions<PlanBufferizationPipelineCliOpts> {
-  Option<bool> enableNonDPSReturns{
-      *this, "enable-non-dps-returns",
-      llvm::cl::desc("allow backend clusters to directly allocate outputs"),
-      llvm::cl::init(false)};
-};
-
 } // namespace
 
 // Register pipelines.
 void plan::registerPlanDialectPipelines() {
-  PassPipelineRegistration<PlanBufferizationPipelineCliOpts>
-      executorBufferizationPipeline(
-          "plan-bufferize-pipeline",
-          "perform bufferization and standard pre/post processing passes",
-          [](OpPassManager &pm, const PlanBufferizationPipelineCliOpts &opts) {
-            PlanAllocTensorsPassOptions allocTensorOpts{};
-            allocTensorOpts.enableNonDPSReturns = opts.enableNonDPSReturns;
-            buildPlanBufferizationPipeline(pm, allocTensorOpts);
-            buildPlanBufferOptimizationPipeline(pm);
-            buildPlanBufferDeallocationPipeline(
-                pm, bufferization::DeallocationOptions{false});
-          });
+  PassPipelineRegistration<> executorBufferizationPipeline(
+      "plan-bufferize-pipeline",
+      "perform bufferization and standard pre/post processing passes",
+      [](OpPassManager &pm) {
+        PlanAllocTensorsPassOptions allocTensorOpts{};
+        buildPlanBufferizationPipeline(pm, allocTensorOpts);
+        buildPlanBufferOptimizationPipeline(pm);
+        buildPlanBufferDeallocationPipeline(
+            pm, bufferization::DeallocationOptions{false});
+      });
 
   PassPipelineRegistration<> bufferOptPipeline(
       "plan-buffer-opt-pipeline", "perform post-bufferization optimizations",
diff --git a/mlir-tensorrt/executor/include/mlir-executor-c/Runtime/Runtime.h b/mlir-tensorrt/executor/include/mlir-executor-c/Runtime/Runtime.h
index 6bdf72ad6..99a5686b0 100644
--- a/mlir-tensorrt/executor/include/mlir-executor-c/Runtime/Runtime.h
+++ b/mlir-tensorrt/executor/include/mlir-executor-c/Runtime/Runtime.h
@@ -53,7 +53,7 @@ extern "C" {
 /// caller must be sure to delete errors via mtrtStatusDestroy.
 //===----------------------------------------------------------------------===//
 
-typedef struct MTRT_RuntimeClient MTRT_Runtimeclient;
+struct MTRT_RuntimeClient; // Forward declaration
 
 //===----------------------------------------------------------------------===//
 // MTRT_GlobalDebug
diff --git a/mlir-tensorrt/executor/test/lib/BufferizationTestPass.cpp b/mlir-tensorrt/executor/test/lib/BufferizationTestPass.cpp
index 437acdc3f..869373d44 100644
--- a/mlir-tensorrt/executor/test/lib/BufferizationTestPass.cpp
+++ b/mlir-tensorrt/executor/test/lib/BufferizationTestPass.cpp
@@ -54,31 +54,22 @@ class ExecutorBufferizationTestPass
   }
 };
 
-struct PlanBufferizationPipelineCliOpts
-    : public PassPipelineOptions<PlanBufferizationPipelineCliOpts> {
-  Option<bool> enableNonDPSReturns{
-      *this, "enable-non-dps-returns",
-      llvm::cl::desc("allow backend clusters to directly allocate outputs"),
-      llvm::cl::init(false)};
-};
-
 } // namespace
 
 namespace mlir::executor {
 void registerTestExecutorBufferizePass() {
   PassRegistration<ExecutorBufferizationTestPass>();
-  PassPipelineRegistration<PlanBufferizationPipelineCliOpts>
-      executorBufferizationPipeline(
-          "test-executor-bufferization-pipeline",
-          "Run one-shot-bufferization and buffer deallocation pipelines",
-          [](OpPassManager &pm, const PlanBufferizationPipelineCliOpts &opts) {
-            pm.addPass(std::make_unique<ExecutorBufferizationTestPass>());
-            pm.addPass(bufferization::createDropEquivalentBufferResultsPass());
-            bufferization::BufferDeallocationPipelineOptions deallocOptions{};
-            bufferization::buildBufferDeallocationPipeline(pm, deallocOptions);
-            pm.addPass(createCSEPass());
-            pm.addPass(createCanonicalizerPass());
-          });
+  PassPipelineRegistration<> executorBufferizationPipeline(
+      "test-executor-bufferization-pipeline",
+      "Run one-shot-bufferization and buffer deallocation pipelines",
+      [](OpPassManager &pm) {
+        pm.addPass(std::make_unique<ExecutorBufferizationTestPass>());
+        pm.addPass(bufferization::createDropEquivalentBufferResultsPass());
+        bufferization::BufferDeallocationPipelineOptions deallocOptions{};
+        bufferization::buildBufferDeallocationPipeline(pm, deallocOptions);
+        pm.addPass(createCSEPass());
+        pm.addPass(createCanonicalizerPass());
+      });
 }
 } // namespace mlir::executor
diff --git a/mlir-tensorrt/python/bindings/Runtime/RuntimePyBind.cpp b/mlir-tensorrt/python/bindings/Runtime/RuntimePyBind.cpp
index 1654549de..1c2d99039 100644
--- a/mlir-tensorrt/python/bindings/Runtime/RuntimePyBind.cpp
+++ b/mlir-tensorrt/python/bindings/Runtime/RuntimePyBind.cpp
@@ -244,7 +244,7 @@ class PyRuntimeClient
   using Base::Base;
   DECLARE_WRAPPER_CONSTRUCTORS(PyRuntimeClient);
 
-  static constexpr auto kMethodTable = CAPITable<MTRT_Runtimeclient>{
+  static constexpr auto kMethodTable = CAPITable<MTRT_RuntimeClient>{
       mtrtRuntimeClientIsNull, mtrtRuntimeClientDestroy};
 };
 
@@ -961,7 +961,8 @@ PYBIND11_MODULE(_api, m) {
       [](PyRuntimeSession &self, std::string name,
         std::vector inArgs,
         std::optional> outArgs,
-         std::optional<MTRT_Stream> stream, PyRuntimeClient &client) {
+         std::optional<MTRT_Stream> stream,
+         PyRuntimeClient *client = nullptr) {
         MTRT_StringView nameRef{name.data(), name.size()};
 
         int64_t numResults;
@@ -980,7 +981,8 @@ PYBIND11_MODULE(_api, m) {
             self, nameRef, inArgsGeneric.data(), inArgsGeneric.size(),
             outArgsGeneric.data(), outArgsGeneric.size(), resultsGeneric.data(),
             stream ? *stream : mtrtStreamGetNull(),
-            client);
+            client ? MTRT_RuntimeClient(*client)
+                   : mtrtRuntimeClientGetNull());
         THROW_IF_MTRT_ERROR(s);
 
         std::vector<py::object> resultPyObject;
@@ -992,7 +994,7 @@ PYBIND11_MODULE(_api, m) {
         return resultPyObject;
       },
      py::arg("name"), py::arg("in_args"), py::arg("out_args") = py::none(),
-      py::arg("stream") = py::none(), py::arg("client"),
+      py::arg("stream") = py::none(), py::arg("client") = nullptr,
      "Execute a function given input and optional output arguments. "
      "Return optional results as a Python object if output arguments are "
      "not present.");
diff --git a/mlir-tensorrt/tensorrt/lib/Target/TensorRTEncodingOpInterface/NetworkEncoder.cpp b/mlir-tensorrt/tensorrt/lib/Target/TensorRTEncodingOpInterface/NetworkEncoder.cpp
index 502e63a17..bea4391b1 100644
--- a/mlir-tensorrt/tensorrt/lib/Target/TensorRTEncodingOpInterface/NetworkEncoder.cpp
+++ b/mlir-tensorrt/tensorrt/lib/Target/TensorRTEncodingOpInterface/NetworkEncoder.cpp
@@ -568,10 +568,10 @@ static LogicalResult serializeSplatElements(DenseIntOrFPElementsAttr values,
     std::fill_n(reinterpret_cast(data.data()), data.size(), packed);
     return llvm::success();
   }
-  llvm::errs() << "Error: "
-               << "unsupported data type to convert MLIR splat attribute to "
-                  "TensorRT weights!";
-  return llvm::failure();
+
+  return emitError(UnknownLoc::get(values.getContext()))
+         << "unsupported data type to convert MLIR splat attribute to TensorRT "
+            "weights!";
 }
 
 FailureOr
diff --git a/mlir-tensorrt/test/python/IntegrationTests/TRT10/test_stablehlo_add.py b/mlir-tensorrt/test/python/IntegrationTests/TRT10/test_stablehlo_add.py
index 9535aef79..480ce74d4 100644
--- a/mlir-tensorrt/test/python/IntegrationTests/TRT10/test_stablehlo_add.py
+++ b/mlir-tensorrt/test/python/IntegrationTests/TRT10/test_stablehlo_add.py
@@ -36,11 +36,7 @@ def test_stablehlo_add(
     session = runtime.RuntimeSession(session_options, exe)
 
     session.execute_function(
-        "main",
-        in_args=test.in_args,
-        out_args=test.out_args,
-        stream=stream,
-        client=runtime_client,
+        "main", in_args=test.in_args, out_args=test.out_args, stream=stream
     )
     output = [
         (
diff --git a/mlir-tensorrt/test/python/IntegrationTests/test_call_validation.py b/mlir-tensorrt/test/python/IntegrationTests/test_call_validation.py
index ee3c784f8..415767d70 100644
--- a/mlir-tensorrt/test/python/IntegrationTests/test_call_validation.py
+++ b/mlir-tensorrt/test/python/IntegrationTests/test_call_validation.py
@@ -73,11 +73,7 @@ def execute(self, arg: runtime.RuntimeValue):
         session = runtime.RuntimeSession(self.session_options, self.exe)
         try:
             session.execute_function(
-                "main",
-                in_args=[arg],
-                out_args=[arg],
-                stream=self.stream,
-                client=self.client,
+                "main", in_args=[arg], out_args=[arg], stream=self.stream
             )
             print("Test passed successfully")
         except runtime.MTRTException as e:
diff --git a/mlir-tensorrt/test/python/IntegrationTests/test_executable_serialize.py b/mlir-tensorrt/test/python/IntegrationTests/test_executable_serialize.py
index c841a334e..e4bb3cba5 100644
--- a/mlir-tensorrt/test/python/IntegrationTests/test_executable_serialize.py
+++ b/mlir-tensorrt/test/python/IntegrationTests/test_executable_serialize.py
@@ -47,9 +47,7 @@ def test_serialize(ASM):
         device=devices[0],
         stream=stream,
     )
-    session0.execute_function(
-        "main", in_args=[arg0], out_args=[arg1], stream=stream, client=client
-    )
+    session0.execute_function("main", in_args=[arg0], out_args=[arg1], stream=stream)
     output0 = np.asarray(client.copy_to_host(arg1, stream=stream))
     stream.sync()
 
@@ -59,9 +57,7 @@ def test_serialize(ASM):
 
     exe_reconstructed = compiler.Executable(serialized_exe)
     session1 = runtime.RuntimeSession(session_options, exe_reconstructed)
-    session1.execute_function(
-        "main", in_args=[arg0], out_args=[arg1], stream=stream, client=client
-    )
+    session1.execute_function("main", in_args=[arg0], out_args=[arg1], stream=stream)
     output1 = np.asarray(client.copy_to_host(arg1, stream=stream))
     stream.sync()
diff --git a/mlir-tensorrt/test/python/IntegrationTests/test_stablehlo_add.py b/mlir-tensorrt/test/python/IntegrationTests/test_stablehlo_add.py
index d81bedeff..2c95a3081 100644
--- a/mlir-tensorrt/test/python/IntegrationTests/test_stablehlo_add.py
+++ b/mlir-tensorrt/test/python/IntegrationTests/test_stablehlo_add.py
@@ -49,9 +49,7 @@ def stablehlo_add():
         device=devices[0],
         stream=stream,
     )
-    session.execute_function(
-        "main", in_args=[arg0], out_args=[arg1], stream=stream, client=client
-    )
+    session.execute_function("main", in_args=[arg0], out_args=[arg1], stream=stream)
     data = np.asarray(client.copy_to_host(arg1, stream=stream))
     stream.sync()
 
@@ -63,9 +61,7 @@ def stablehlo_add():
     num_iter = 5
     start_time = time.time()
     for _ in range(0, num_iter):
-        session.execute_function(
-            "main", in_args=[arg0], out_args=[arg0], stream=stream, client=client
-        )
+        session.execute_function("main", in_args=[arg0], out_args=[arg0], stream=stream)
     data = np.asarray(client.copy_to_host(arg1, stream=stream))
     stream.sync()
     end_time = time.time()
diff --git a/mlir-tensorrt/test/python/IntegrationTests/test_stablehlo_dynamic.py b/mlir-tensorrt/test/python/IntegrationTests/test_stablehlo_dynamic.py
index 0429e20f1..35515e054 100644
--- a/mlir-tensorrt/test/python/IntegrationTests/test_stablehlo_dynamic.py
+++ b/mlir-tensorrt/test/python/IntegrationTests/test_stablehlo_dynamic.py
@@ -77,10 +77,7 @@ def infer_output_shape(client, session, exe, input_shape):
     outs = [client.create_memref(out_0, shape=shape, dtype=runtime.ScalarTypeCode.i64)]
 
     session.execute_function(
-        exe.get_signature("main").get_shape_func_name(),
-        in_args=ins,
-        out_args=outs,
-        client=client,
+        exe.get_signature("main").get_shape_func_name(), in_args=ins, out_args=outs
     )
 
     # Copy output shape from device to host. Also, convert to int32 type since shape calculation uses int64 type.
@@ -138,7 +135,7 @@ def test_program(program: str, input_shape: Iterable[int], debug: bool = True):
     )
 
     session.execute_function(
-        "main", in_args=[arg0, arg1], out_args=[arg2], stream=stream, client=client
+        "main", in_args=[arg0, arg1], out_args=[arg2], stream=stream
     )
     data = np.asarray(client.copy_to_host(arg2, stream=stream))
     stream.sync()
diff --git a/mlir-tensorrt/test/python/mlir_tensorrt_runtime/test_runtime_debug_dump.py b/mlir-tensorrt/test/python/mlir_tensorrt_runtime/test_runtime_debug_dump.py
index 4d379c3e6..68ca684e8 100644
--- a/mlir-tensorrt/test/python/mlir_tensorrt_runtime/test_runtime_debug_dump.py
+++ b/mlir-tensorrt/test/python/mlir_tensorrt_runtime/test_runtime_debug_dump.py
@@ -49,9 +49,7 @@ def stablehlo_add():
         device=devices[0],
         stream=stream,
     )
-    session.execute_function(
-        "main", in_args=[arg0], out_args=[arg1], stream=stream, client=client
-    )
+    session.execute_function("main", in_args=[arg0], out_args=[arg1], stream=stream)
     data = np.asarray(client.copy_to_host(arg1, stream=stream))
     stream.sync()