Skip to content

Commit

Permalink
Add pybindings for TensorRTToExecutableOptions
Browse files Browse the repository at this point in the history
Fix TensorRTOptions registration
  • Loading branch information
yizhuoz004 committed Jan 23, 2025
1 parent 9183619 commit 6e552ca
Show file tree
Hide file tree
Showing 4 changed files with 172 additions and 4 deletions.
34 changes: 34 additions & 0 deletions mlir-tensorrt/compiler/include/mlir-tensorrt-c/Compiler/Compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,40 @@ static inline bool mtrtStableHloToExecutableOptionsIsNull(
return !options.ptr;
}

//===----------------------------------------------------------------------===//
// MTRT_TensorRTToExecutableOptions
//===----------------------------------------------------------------------===//

/// Options for compiling TensorRT dialect MLIR to an Executable.
/// (Previously this comment said "StableHLO" — a copy-paste from the
/// MTRT_StableHLOToExecutableOptions section above.)
typedef struct MTRT_TensorRTToExecutableOptions {
  void *ptr;
} MTRT_TensorRTToExecutableOptions;

/// Creates a new options object with the given TensorRT builder optimization
/// level and strongly-typed mode, then finalizes it. On success, `*options`
/// owns the new object and must be released with
/// `mtrtTensorRTToExecutableOptionsDestroy`.
MLIR_CAPI_EXPORTED MTRT_Status mtrtTensorRTToExecutableOptionsCreate(
    MTRT_CompilerClient client, MTRT_TensorRTToExecutableOptions *options,
    int32_t tensorRTBuilderOptLevel, bool tensorRTStronglyTyped);

/// Creates a new options object by parsing `argc` command-line-style
/// arguments in `argv`, then finalizes it. On success, `*options` owns the
/// new object and must be released with
/// `mtrtTensorRTToExecutableOptionsDestroy`.
MLIR_CAPI_EXPORTED MTRT_Status mtrtTensorRTToExecutableOptionsCreateFromArgs(
    MTRT_CompilerClient client, MTRT_TensorRTToExecutableOptions *options,
    const MlirStringRef *argv, unsigned argc);

/// Specifies whether to enable the global LLVM debug flag for the duration of
/// the compilation process. If the flag is enabled then the debug types
/// specified in the array of literals are used as the global LLVM debug types
/// (equivalent to `-debug-only=[list]`).
/// Note: this is a C API header, so no C++ default arguments are allowed;
/// callers that do not want IR/engine dumping must pass NULL for
/// `dumpIrTreeDir` and `dumpTensorRTDir` explicitly.
MLIR_CAPI_EXPORTED MTRT_Status mtrtTensorRTToExecutableOptionsSetDebugOptions(
    MTRT_TensorRTToExecutableOptions options, bool enableDebugging,
    const char **debugTypes, size_t debugTypeSizes, const char *dumpIrTreeDir,
    const char *dumpTensorRTDir);

/// Destroys the options object. Passing a null-ptr options object is a no-op
/// only insofar as `delete nullptr` is; prefer checking with
/// `mtrtTensorRTToExecutableOptionsIsNull` first.
MLIR_CAPI_EXPORTED MTRT_Status mtrtTensorRTToExecutableOptionsDestroy(
    MTRT_TensorRTToExecutableOptions options);

/// Returns true if `options` wraps a null pointer.
static inline bool mtrtTensorRTToExecutableOptionsIsNull(
    MTRT_TensorRTToExecutableOptions options) {
  return !options.ptr;
}

//===----------------------------------------------------------------------===//
// PassManagerReference APIs
//===----------------------------------------------------------------------===//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,8 @@ struct TensorRTOptions : public OptionsProvider<TensorRTOptions> {
using OptionsProvider::OptionsProvider;
mlir::tensorrt::TensorRTTranslationOptions options;

TensorRTOptions(mlir::OptionsContext &ctx) : OptionsProvider(ctx) {}

void addToOptions(mlir::OptionsContext &context) {
options.addToOptions(context);
TensorRTOptions(mlir::OptionsContext &ctx) : OptionsProvider(ctx) {
options.addToOptions(ctx);
}
};

Expand Down
86 changes: 86 additions & 0 deletions mlir-tensorrt/compiler/lib/CAPI/Compiler/Compiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include "mlir-tensorrt/Compiler/OptionsRegistry.h"
#include "mlir-tensorrt/Compiler/StablehloToExecutable/StablehloToExecutable.h"
#include "mlir-tensorrt/Compiler/StablehloToExecutable/TensorRTExtension.h"
#include "mlir-tensorrt/Compiler/TensorRTToExecutable/TensorRTToExecutable.h"
#include "mlir-tensorrt/Dialect/Plan/IR/Plan.h"
#include "mlir/CAPI/IR.h"
#include "mlir/CAPI/Utils.h"
Expand All @@ -50,6 +51,8 @@ using namespace mlir;
DEFINE_C_API_PTR_METHODS(MTRT_CompilerClient, CompilerClient)
DEFINE_C_API_PTR_METHODS(MTRT_StableHLOToExecutableOptions,
StablehloToExecutableOptions)
DEFINE_C_API_PTR_METHODS(MTRT_TensorRTToExecutableOptions,
TensorRTToExecutableOptions)
DEFINE_C_API_PTR_METHODS(MTRT_OptionsContext, OptionsContext)
#if defined(__GNUC__) || defined(__clang__)
#pragma GCC diagnostic pop
Expand Down Expand Up @@ -271,6 +274,89 @@ MTRT_Status mtrtStableHloToExecutableOptionsDestroy(
return mtrtStatusGetOk();
}


//===----------------------------------------------------------------------===//
// MTRT_TensorRTToExecutableOptions
//===----------------------------------------------------------------------===//

MTRT_Status mtrtTensorRTToExecutableOptionsCreate(
    MTRT_CompilerClient client, MTRT_TensorRTToExecutableOptions *options,
    int32_t tensorRTBuilderOptLevel, bool tensorRTStronglyTyped) {
  auto result = std::make_unique<TensorRTToExecutableOptions>();

  // BUG FIX: bind a *reference* to the stored translation options. The
  // previous code copied the struct by value and then mutated the copy, so
  // `tensorRTBuilderOptLevel` and `tensorRTStronglyTyped` were silently
  // discarded before finalization.
  tensorrt::TensorRTTranslationOptions &translationOpts =
      result->get<TensorRTOptions>().options;
  translationOpts.tensorrtBuilderOptLevel = tensorRTBuilderOptLevel;
  translationOpts.enableStronglyTyped = tensorRTStronglyTyped;

  // Finalize the options; surface any llvm::StringError as an internal
  // error status.
  llvm::Error finalizeStatus = result->finalize();

  std::optional<std::string> errMsg{};
  llvm::handleAllErrors(
      std::move(finalizeStatus),
      [&errMsg](const llvm::StringError &err) { errMsg = err.getMessage(); });

  if (errMsg)
    return wrap(getInternalErrorStatus(errMsg->c_str()));

  // Transfer ownership of the finalized options to the caller.
  *options = wrap(result.release());
  return mtrtStatusGetOk();
}

MTRT_Status mtrtTensorRTToExecutableOptionsCreateFromArgs(
    MTRT_CompilerClient client, MTRT_TensorRTToExecutableOptions *options,
    const MlirStringRef *argv, unsigned argc) {
  auto opts = std::make_unique<TensorRTToExecutableOptions>();

  // Convert the C string refs into LLVM StringRefs for the option parser.
  std::vector<llvm::StringRef> args;
  args.reserve(argc);
  for (unsigned idx = 0; idx < argc; ++idx)
    args.emplace_back(argv[idx].data, argv[idx].length);

  // Parse the command-line-style arguments into the options object.
  std::string parseError;
  if (failed(opts->parse(args, parseError))) {
    std::string joined = llvm::join(args, " ");
    return wrap(getInternalErrorStatus(
        "failed to parse options string {0} due to error: {1}", joined,
        parseError));
  }

  // Finalize and convert any llvm::StringError into an internal error status.
  std::optional<std::string> finalizeErrMsg{};
  llvm::handleAllErrors(opts->finalize(),
                        [&finalizeErrMsg](const llvm::StringError &err) {
                          finalizeErrMsg = err.getMessage();
                        });

  if (finalizeErrMsg)
    return wrap(getInternalErrorStatus(finalizeErrMsg->c_str()));

  // Hand ownership of the finalized options to the caller.
  *options = wrap(opts.release());
  return mtrtStatusGetOk();
}

MTRT_Status mtrtTensorRTToExecutableOptionsSetDebugOptions(
    MTRT_TensorRTToExecutableOptions options, bool enableDebugging,
    const char **debugTypes, size_t debugTypeSizes, const char *dumpIrTreeDir,
    const char *dumpTensorRTDir) {

  TensorRTToExecutableOptions *cppOpts = unwrap(options);
  // Hoist the provider lookup out of the loop; the original re-evaluated
  // get<DebugOptions>() once per debug type.
  auto &debugOpts = cppOpts->get<DebugOptions>();
  debugOpts.enableLLVMDebugFlag = enableDebugging;
  // `size_t` index matches the `debugTypeSizes` count (was `unsigned`,
  // a signed-ness/width mismatch on LP64 targets).
  for (size_t i = 0; i < debugTypeSizes; ++i)
    debugOpts.llvmDebugTypes.emplace_back(debugTypes[i]);

  if (dumpIrTreeDir) {
    debugOpts.printTreeDir = std::string(dumpIrTreeDir);
    debugOpts.printAfterAll = true;
  }

  // NOTE(review): `dumpTensorRTDir` is accepted but never consulted in this
  // function — confirm whether TensorRT engine dumping should be wired up
  // here (as the parameter name and the Python binding suggest) or document
  // the parameter as reserved.
  (void)dumpTensorRTDir;

  return mtrtStatusGetOk();
}

MTRT_Status mtrtTensorRTToExecutableOptionsDestroy(
    MTRT_TensorRTToExecutableOptions options) {
  // Use the unwrap() generated by DEFINE_C_API_PTR_METHODS instead of a raw
  // reinterpret_cast, for consistency with every other conversion in this
  // file (and with the StableHLO options destroy path).
  delete unwrap(options);
  return mtrtStatusGetOk();
}

//===----------------------------------------------------------------------===//
// Main StableHLO Compiler API Functions
//===----------------------------------------------------------------------===//
Expand Down
50 changes: 50 additions & 0 deletions mlir-tensorrt/python/bindings/Compiler/CompilerPyBind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,19 @@ class PyStableHLOToExecutableOptions
mtrtStableHloToExecutableOptionsDestroy};
};

/// Python object type wrapper for `MTRT_TensorRTToExecutableOptions`.
/// Mirrors the PyStableHLOToExecutableOptions wrapper above: the base class
/// presumably uses `kMethodTable` to null-check and destroy the owned C API
/// handle — see PyMTRTWrapper for the ownership contract.
class PyTensorRTToExecutableOptions
    : public PyMTRTWrapper<PyTensorRTToExecutableOptions,
                           MTRT_TensorRTToExecutableOptions> {
public:
  using PyMTRTWrapper::PyMTRTWrapper;
  DECLARE_WRAPPER_CONSTRUCTORS(PyTensorRTToExecutableOptions);
  // C API function pair consumed by the wrapper base: is-null predicate and
  // destructor for the wrapped handle.
  static constexpr auto kMethodTable =
      CAPITable<MTRT_TensorRTToExecutableOptions>{
          mtrtTensorRTToExecutableOptionsIsNull,
          mtrtTensorRTToExecutableOptionsDestroy};
};

/// Python object type wrapper for `MlirPassManager`.
class PyPassManagerReference
: public PyMTRTWrapper<PyPassManagerReference, MlirPassManager> {
Expand Down Expand Up @@ -339,6 +352,43 @@ PYBIND11_MODULE(_api, m) {
py::arg("dump_ir_tree_dir") = py::none(),
py::arg("dump_tensorrt_dir") = py::none());

py::class_<PyTensorRTToExecutableOptions>(m, "TensorRTToExecutableOptions",
py::module_local())
.def(py::init<>([](PyCompilerClient &client,
const std::vector<std::string> &args)
-> PyTensorRTToExecutableOptions * {
std::vector<MlirStringRef> refs(args.size());
for (unsigned i = 0; i < args.size(); i++)
refs[i] = mlirStringRefCreate(args[i].data(), args[i].size());

MTRT_TensorRTToExecutableOptions options;
MTRT_Status s = mtrtTensorRTToExecutableOptionsCreateFromArgs(
client, &options, refs.data(), refs.size());
THROW_IF_MTRT_ERROR(s);
return new PyTensorRTToExecutableOptions(options);
}),
py::arg("client"), py::arg("args"))
.def(
"set_debug_options",
[](PyTensorRTToExecutableOptions &self, bool enabled,
std::vector<std::string> debugTypes,
std::optional<std::string> dumpIrTreeDir,
std::optional<std::string> dumpTensorRTDir) {
// The strings are copied by the CAPI call, so we just need to
// refence the C-strings temporarily.
std::vector<const char *> literals;
for (const std::string &str : debugTypes)
literals.push_back(str.c_str());
THROW_IF_MTRT_ERROR(mtrtTensorRTToExecutableOptionsSetDebugOptions(
self, enabled, literals.data(), literals.size(),
dumpIrTreeDir ? dumpIrTreeDir->c_str() : nullptr,
dumpTensorRTDir ? dumpTensorRTDir->c_str() : nullptr));
},
py::arg("enabled"),
py::arg("debug_types") = std::vector<std::string>{},
py::arg("dump_ir_tree_dir") = py::none(),
py::arg("dump_tensorrt_dir") = py::none());

py::class_<PyPassManagerReference>(m, "PassManagerReference",
py::module_local())
.def("run", [](PyPassManagerReference &self, MlirOperation op) {
Expand Down

0 comments on commit 6e552ca

Please sign in to comment.