Skip to content

Commit

Permalink
Address review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
jhalakpatel committed Nov 6, 2024
1 parent ac28e6c commit e75953d
Show file tree
Hide file tree
Showing 13 changed files with 45 additions and 86 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ struct ConvertEnqueueAllocToCall

// Create output memrefs from output descriptors
SmallVector<Value> results;
unsigned offset = 1;
unsigned offset = 1; // Skip num of results
for (unsigned i = 0; i < op->getNumResults(); ++i) {
unsigned rank = cast<MemRefType>(op->getResult(i).getType()).getRank();
Value rankOffset = b.create<executor::GetOffsetOp>(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,14 +113,11 @@ static LogicalResult removeUnusedArgs(SymbolTableCollection &collection,
call.getInputsMutable().erase(i);
else if (auto callAlloc = dyn_cast<tensorrt::CallAllocOp>(callOp))
callAlloc.getInputsMutable().erase(i);
else {
llvm::errs() << "Unexpected operation type in callOps\n";
callOp->dump();
return failure();
}
else
return emitError(funcOp->getLoc())
<< "Unexpected operation type in callOps";
}
}

return success();
}

Expand Down
30 changes: 10 additions & 20 deletions mlir-tensorrt/compiler/lib/Dialect/Plan/Transforms/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,31 +103,21 @@ struct ClusteringPipelineCliOpts
llvm::cl::init(NV_TENSORRT_MAJOR)};
};

struct PlanBufferizationPipelineCliOpts
: public PassPipelineOptions<PlanBufferizationPipelineCliOpts> {
Option<bool> enableNonDPSReturns{
*this, "enable-non-dps-returns",
llvm::cl::desc("allow backend clusters to directly allocate outputs"),
llvm::cl::init(false)};
};

} // namespace

// Register pipelines.

void plan::registerPlanDialectPipelines() {
PassPipelineRegistration<PlanBufferizationPipelineCliOpts>
executorBufferizationPipeline(
"plan-bufferize-pipeline",
"perform bufferization and standard pre/post processing passes",
[](OpPassManager &pm, const PlanBufferizationPipelineCliOpts &opts) {
PlanAllocTensorsPassOptions allocTensorOpts{};
allocTensorOpts.enableNonDPSReturns = opts.enableNonDPSReturns;
buildPlanBufferizationPipeline(pm, allocTensorOpts);
buildPlanBufferOptimizationPipeline(pm);
buildPlanBufferDeallocationPipeline(
pm, bufferization::DeallocationOptions{false});
});
PassPipelineRegistration<> executorBufferizationPipeline(
"plan-bufferize-pipeline",
"perform bufferization and standard pre/post processing passes",
[](OpPassManager &pm) {
PlanAllocTensorsPassOptions allocTensorOpts{};
buildPlanBufferizationPipeline(pm, allocTensorOpts);
buildPlanBufferOptimizationPipeline(pm);
buildPlanBufferDeallocationPipeline(
pm, bufferization::DeallocationOptions{false});
});

PassPipelineRegistration<> bufferOptPipeline(
"plan-buffer-opt-pipeline", "perform post-bufferization optimizations",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ extern "C" {
/// caller must be sure to delete errors via mtrtStatusDestroy.
//===----------------------------------------------------------------------===//

typedef struct MTRT_RuntimeClient MTRT_Runtimeclient;
struct MTRT_RuntimeClient; // Forward declaration

//===----------------------------------------------------------------------===//
// MTRT_GlobalDebug
Expand Down
31 changes: 11 additions & 20 deletions mlir-tensorrt/executor/test/lib/BufferizationTestPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,31 +54,22 @@ class ExecutorBufferizationTestPass
}
};

struct PlanBufferizationPipelineCliOpts
: public PassPipelineOptions<PlanBufferizationPipelineCliOpts> {
Option<bool> enableNonDPSReturns{
*this, "enable-non-dps-returns",
llvm::cl::desc("allow backend clusters to directly allocate outputs"),
llvm::cl::init(false)};
};

} // namespace

namespace mlir::executor {
void registerTestExecutorBufferizePass() {
PassRegistration<ExecutorBufferizationTestPass>();

PassPipelineRegistration<PlanBufferizationPipelineCliOpts>
executorBufferizationPipeline(
"test-executor-bufferization-pipeline",
"Run one-shot-bufferization and buffer deallocation pipelines",
[](OpPassManager &pm, const PlanBufferizationPipelineCliOpts &opts) {
pm.addPass(std::make_unique<ExecutorBufferizationTestPass>());
pm.addPass(bufferization::createDropEquivalentBufferResultsPass());
bufferization::BufferDeallocationPipelineOptions deallocOptions{};
bufferization::buildBufferDeallocationPipeline(pm, deallocOptions);
pm.addPass(createCSEPass());
pm.addPass(createCanonicalizerPass());
});
PassPipelineRegistration<> executorBufferizationPipeline(
"test-executor-bufferization-pipeline",
"Run one-shot-bufferization and buffer deallocation pipelines",
[](OpPassManager &pm) {
pm.addPass(std::make_unique<ExecutorBufferizationTestPass>());
pm.addPass(bufferization::createDropEquivalentBufferResultsPass());
bufferization::BufferDeallocationPipelineOptions deallocOptions{};
bufferization::buildBufferDeallocationPipeline(pm, deallocOptions);
pm.addPass(createCSEPass());
pm.addPass(createCanonicalizerPass());
});
}
} // namespace mlir::executor
10 changes: 6 additions & 4 deletions mlir-tensorrt/python/bindings/Runtime/RuntimePyBind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ class PyRuntimeClient
using Base::Base;
DECLARE_WRAPPER_CONSTRUCTORS(PyRuntimeClient);

static constexpr auto kMethodTable = CAPITable<MTRT_Runtimeclient>{
static constexpr auto kMethodTable = CAPITable<MTRT_RuntimeClient>{
mtrtRuntimeClientIsNull, mtrtRuntimeClientDestroy};
};

Expand Down Expand Up @@ -961,7 +961,8 @@ PYBIND11_MODULE(_api, m) {
[](PyRuntimeSession &self, std::string name,
std::vector<py::object> inArgs,
std::optional<std::vector<py::object>> outArgs,
std::optional<MTRT_Stream> stream, PyRuntimeClient &client) {
std::optional<MTRT_Stream> stream,
PyRuntimeClient *client = nullptr) {
MTRT_StringView nameRef{name.data(), name.size()};

int64_t numResults;
Expand All @@ -980,7 +981,8 @@ PYBIND11_MODULE(_api, m) {
self, nameRef, inArgsGeneric.data(), inArgsGeneric.size(),
outArgsGeneric.data(), outArgsGeneric.size(),
resultsGeneric.data(), stream ? *stream : mtrtStreamGetNull(),
client);
client ? MTRT_RuntimeClient(*client)
: mtrtRuntimeClientGetNull());
THROW_IF_MTRT_ERROR(s);

std::vector<py::object> resultPyObject;
Expand All @@ -992,7 +994,7 @@ PYBIND11_MODULE(_api, m) {
return resultPyObject;
},
py::arg("name"), py::arg("in_args"), py::arg("out_args") = py::none(),
py::arg("stream") = py::none(), py::arg("client"),
py::arg("stream") = py::none(), py::arg("client") = nullptr,
"Execute a function given input and optional output arguments. "
"Return optional results as a Python object if output arguments are "
"not present.");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -568,10 +568,10 @@ static LogicalResult serializeSplatElements(DenseIntOrFPElementsAttr values,
std::fill_n(reinterpret_cast<uint8_t *>(data.data()), data.size(), packed);
return llvm::success();
}
llvm::errs() << "Error: "
<< "unsupported data type to convert MLIR splat attribute to "
"TensorRT weights!";
return llvm::failure();

return emitError(UnknownLoc::get(values.getContext()))
<< "unsupported data type to convert MLIR splat attribute to TensorRT "
"weights!";
}

FailureOr<nvinfer1::Weights>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,7 @@ def test_stablehlo_add(
session = runtime.RuntimeSession(session_options, exe)

session.execute_function(
"main",
in_args=test.in_args,
out_args=test.out_args,
stream=stream,
client=runtime_client,
"main", in_args=test.in_args, out_args=test.out_args, stream=stream
)
output = [
(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,7 @@ def execute(self, arg: runtime.RuntimeValue):
session = runtime.RuntimeSession(self.session_options, self.exe)
try:
session.execute_function(
"main",
in_args=[arg],
out_args=[arg],
stream=self.stream,
client=self.client,
"main", in_args=[arg], out_args=[arg], stream=self.stream
)
print("Test passed succesfully")
except runtime.MTRTException as e:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,7 @@ def test_serialize(ASM):
device=devices[0],
stream=stream,
)
session0.execute_function(
"main", in_args=[arg0], out_args=[arg1], stream=stream, client=client
)
session0.execute_function("main", in_args=[arg0], out_args=[arg1], stream=stream)
output0 = np.asarray(client.copy_to_host(arg1, stream=stream))
stream.sync()

Expand All @@ -59,9 +57,7 @@ def test_serialize(ASM):
exe_reconstructed = compiler.Executable(serialized_exe)

session1 = runtime.RuntimeSession(session_options, exe_reconstructed)
session1.execute_function(
"main", in_args=[arg0], out_args=[arg1], stream=stream, client=client
)
session1.execute_function("main", in_args=[arg0], out_args=[arg1], stream=stream)
output1 = np.asarray(client.copy_to_host(arg1, stream=stream))
stream.sync()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,7 @@ def stablehlo_add():
device=devices[0],
stream=stream,
)
session.execute_function(
"main", in_args=[arg0], out_args=[arg1], stream=stream, client=client
)
session.execute_function("main", in_args=[arg0], out_args=[arg1], stream=stream)

data = np.asarray(client.copy_to_host(arg1, stream=stream))
stream.sync()
Expand All @@ -63,9 +61,7 @@ def stablehlo_add():
num_iter = 5
start_time = time.time()
for _ in range(0, num_iter):
session.execute_function(
"main", in_args=[arg0], out_args=[arg0], stream=stream, client=client
)
session.execute_function("main", in_args=[arg0], out_args=[arg0], stream=stream)
data = np.asarray(client.copy_to_host(arg1, stream=stream))
stream.sync()
end_time = time.time()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,7 @@ def infer_output_shape(client, session, exe, input_shape):
outs = [client.create_memref(out_0, shape=shape, dtype=runtime.ScalarTypeCode.i64)]

session.execute_function(
exe.get_signature("main").get_shape_func_name(),
in_args=ins,
out_args=outs,
client=client,
exe.get_signature("main").get_shape_func_name(), in_args=ins, out_args=outs
)

# Copy output shape from device to host. Also, convert to int32 type since shape calculation uses int64 type.
Expand Down Expand Up @@ -138,7 +135,7 @@ def test_program(program: str, input_shape: Iterable[int], debug: bool = True):
)

session.execute_function(
"main", in_args=[arg0, arg1], out_args=[arg2], stream=stream, client=client
"main", in_args=[arg0, arg1], out_args=[arg2], stream=stream
)
data = np.asarray(client.copy_to_host(arg2, stream=stream))
stream.sync()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,7 @@ def stablehlo_add():
device=devices[0],
stream=stream,
)
session.execute_function(
"main", in_args=[arg0], out_args=[arg1], stream=stream, client=client
)
session.execute_function("main", in_args=[arg0], out_args=[arg1], stream=stream)

data = np.asarray(client.copy_to_host(arg1, stream=stream))
stream.sync()
Expand Down

0 comments on commit e75953d

Please sign in to comment.