Skip to content

Commit

Permalink
Enable end to end non-DPS testing
Browse files Browse the repository at this point in the history
Implement python binding changes to allow execute function return
multiple returns. Update tests to use non-DPS style calling convention.

Also, enable end to end lowering by enabling conversion of closed alloc
group op to tensorrt dialect.

Miscellaneous fixes:
1. Add missing handling of `CallAllocOp` in EliminateShapeOps pass.
2. Skip non ranked tensor type function arguments while collecting host
   tensor arguments.
3. Temporarily add a pass to remove clone operation in MemRefToExecutor
   dialect conversion.
4. Relax memref creation for empty shape tensors.
5. Fix memref life returned from Lua function results. This required
   session allocator to track returned memref.
  • Loading branch information
jhalakpatel committed Nov 4, 2024
1 parent 3ac751b commit fcb0964
Show file tree
Hide file tree
Showing 20 changed files with 477 additions and 144 deletions.
15 changes: 14 additions & 1 deletion mlir-tensorrt/compiler/lib/Compiler/StableHloToExecutable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,10 @@ StableHLOToExecutableOptions::StableHLOToExecutableOptions(
disallowHostTensorsInTensorRTClusters, llvm::cl::init(false),
llvm::cl::desc("Don't allow TensorRt clusters to contain host tensor "
"calculations (but they can still be inputs)"));
addOption(
"enable-non-dps-returns", enableNonDPSReturns, llvm::cl::init(false),
llvm::cl::desc(
"allow tensorrt based output allocations using output allocator"));
addOption("executor-index-bitwidth", executorIndexBitwidth,
llvm::cl::init(64));
addOption("device-compute-capability", deviceComputeCapability,
Expand Down Expand Up @@ -306,6 +310,7 @@ void StableHloToExecutableTask::buildStablehloClusteringPipeline(
plan::StablehloClusteringPassOptions clusteringOpts{};
clusteringOpts.disallowHostTensorsInTensorRTClusters =
opts.disallowHostTensorsInTensorRTClusters;
clusteringOpts.enableNonDPSReturns = opts.enableNonDPSReturns;
clusteringOpts.entrypoint = opts.entrypoint;
plan::buildPlanSegmentationPipeline(pm, clusteringOpts);

Expand Down Expand Up @@ -339,7 +344,9 @@ void StableHloToExecutableTask::buildPostClusteringPipeline(

// Perform bufferization.
pm.addPass(createMemRefCastEliminationPass());
pm.addPass(plan::createPlanAllocTensorsPass());
plan::PlanAllocTensorsPassOptions allocTensorsOpts{};
allocTensorsOpts.enableNonDPSReturns = opts.enableNonDPSReturns;
pm.addPass(plan::createPlanAllocTensorsPass(allocTensorsOpts));
pm.addPass(plan::createPlanBufferizePass());
pm.addPass(createMemRefCastEliminationPass());
pm.addPass(createCanonicalizerPass());
Expand Down Expand Up @@ -525,6 +532,11 @@ struct ClusteringPipelineCliOpts
*this, "device-compute-capability",
llvm::cl::desc("target device compute capability (SM version)"),
llvm::cl::init(60)};
Option<bool> enableNonDPSReturns{
*this, "enable-non-dps-returns",
llvm::cl::desc(
"allow tensorrt based output allocations using output allocator"),
llvm::cl::init(false)};
Option<int64_t> deviceMaxSharedMemoryPerBlockKb{
*this, "device-max-smem-per-block",
llvm::cl::desc("max shared memory per block (in kilobytes)"),
Expand Down Expand Up @@ -552,6 +564,7 @@ static StableHLOToExecutableOptions populateStablehloClusteringPipelineOpts(
opts.deviceComputeCapability = cliOpts.deviceComputeCapability;
opts.deviceMaxSharedMemoryPerBlockKb =
cliOpts.deviceMaxSharedMemoryPerBlockKb;
opts.enableNonDPSReturns = cliOpts.enableNonDPSReturns;
opts.shouldInferDeviceOptionsFromHost = cliOpts.inferDeviceOptionsFromHost;
opts.entrypoint = cliOpts.entrypoint;
return opts;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ convertCallOp(Operation *op, IRRewriter &rewriter,
SmallVector<int64_t> hostTensorArgs;
for (auto [idx, arg] : llvm::enumerate(trtFunc.getArguments())) {
const TensorKindLattice *kind = solver.lookupState<TensorKindLattice>(arg);
if (!isa<RankedTensorType>(arg.getType()))
continue;
RankedTensorType rtt = cast<RankedTensorType>(arg.getType());
// To be conservative, we only do this if type is i32 and num elements
// <= 8.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,17 +65,26 @@ struct RemoveWithValuesRewriter : public OpRewritePattern<plan::WithValuesOp> {
} // namespace

/// Get a map from `tensorrt.func` functions to associated `tensorrt.call`
/// operations.
static llvm::DenseMap<func::FuncOp, SmallVector<tensorrt::CallOp>>
/// and `tensorrt.call_alloc` operations.
static llvm::DenseMap<func::FuncOp, SmallVector<Operation *>>
getTensorRTFunctionCallMap(ModuleOp op, SymbolTableCollection &collection) {
llvm::DenseMap<func::FuncOp, SmallVector<tensorrt::CallOp>> map;
op->walk([&](tensorrt::CallOp callOp) {
func::FuncOp func = callOp.getFuncCallee(collection);
if (map.contains(func)) {
map[func].push_back(callOp);
llvm::DenseMap<func::FuncOp, SmallVector<Operation *>> map;
op->walk([&](Operation *callOp) {
if (!isa<tensorrt::CallOp, tensorrt::CallAllocOp>(callOp))
return;
}
map.insert(std::make_pair(func, SmallVector<tensorrt::CallOp>{callOp}));

func::FuncOp func;
if (auto call = dyn_cast<tensorrt::CallOp>(callOp))
func = call.getFuncCallee(collection);
else if (auto callAlloc = dyn_cast<tensorrt::CallAllocOp>(callOp))
func = callAlloc.getFuncCallee(collection);
else
return;

if (map.count(func))
map[func].push_back(callOp);
else
map.insert({func, SmallVector<Operation *>{callOp}});
});
return map;
}
Expand All @@ -84,7 +93,7 @@ getTensorRTFunctionCallMap(ModuleOp op, SymbolTableCollection &collection) {
/// `tensorrt.call` operations.
static LogicalResult removeUnusedArgs(SymbolTableCollection &collection,
ModuleOp op, func::FuncOp funcOp,
ArrayRef<tensorrt::CallOp> callOps) {
ArrayRef<Operation *> callOps) {
llvm::SmallBitVector unusedArgs(funcOp.getNumArguments(), 0);
for (BlockArgument arg : funcOp.getArguments()) {
if (arg.use_empty())
Expand All @@ -99,8 +108,16 @@ static LogicalResult removeUnusedArgs(SymbolTableCollection &collection,
funcOp.eraseArgument(i);

// Update the call ops.
for (tensorrt::CallOp callOp : callOps)
callOp.getInputsMutable().erase(i);
for (Operation *callOp : callOps) {
if (auto call = dyn_cast<tensorrt::CallOp>(callOp))
call.getInputsMutable().erase(i);
else if (auto callAlloc = dyn_cast<tensorrt::CallAllocOp>(callOp))
callAlloc.getInputsMutable().erase(i);
else {
llvm::errs() << "Unexpected operation type in callOps\n";
callOp->dump();
}
}
}

return success();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -268,8 +268,82 @@ static LogicalResult outlineTensorRTRegion(RewriterBase &rewriter,

static LogicalResult outlineTensorRTRegion(RewriterBase &rewriter,
plan::InlineClosedAllocGroupOp op) {
return op.emitError("outlinining inline closed alloc group ops to tensorrt "
"dialect is not yet implemented");
tensorrt::TensorRTModuleOp trtModuleOp = getOrCreateTensorRTModuleOp(op);
auto funcArgTypes = llvm::to_vector(TypeRange(op.getInputs()));
FailureOr<FunctionOpInterface> func = createOutlinedFunc(
rewriter, op.getLoc(), op, trtModuleOp, "tensorrt_cluster",
"cluster.tensorrt", TypeRange(op.getInputs()),
op.getYield()->getOperandTypes());
if (failed(func))
return failure();
assert(func->getFunctionBody().getBlocks().size() == 1 &&
"expected body with one block");
func->setPublic();

rewriter.setInsertionPoint(op);

auto callOp = rewriter.create<tensorrt::CallAllocOp>(
op.getLoc(), op.getResultTypes(), op.getInputs(),
SymbolRefAttr::get(trtModuleOp.getNameAttr(),
{FlatSymbolRefAttr::get(*func)}));

// Populate the function arguments attributes.
for (unsigned i = 0; i < (*func).getNumArguments(); i++) {
BoundsAttr srcAttr = cast<BoundsAttr>(op.getInputAttrs()[i]);
// We may have scalar (index|signless int)-typed values since we haven't
// eliminated `plan.(with_shape|with_values)` ops yet.
if (!op.argHasTensorType(i) || srcAttr.isNone())
continue;
FailureOr<tensorrt::ShapeProfileAttr> boundAttr =
getTensorRTShapeProfile(srcAttr, op.getInputs()[i]);
if (failed(boundAttr))
return op->emitOpError("failed to create TensorRT shape profile "
"attribute from Plan BoundsAttr for argument #")
<< i << " (" << srcAttr << ")";
if (srcAttr.isShapeBound()) {
func->setArgAttr(i,
tensorrt::TensorRTDialect::getShapeProfileArgAttrName(),
*boundAttr);
continue;
}
assert(srcAttr.isValueBound() && "expected value bound or shape bound");
func->setArgAttr(
i, tensorrt::TensorRTDialect::getShapeTensorValueBoundsArgAttrName(),
*boundAttr);
func->setArgAttr(i, mlir::getHostTensorArgAttrName(),
rewriter.getUnitAttr());
}

// Populate the function entry block.
rewriter.eraseBlock(&func->getFunctionBody().front());

// Move private decomposition funcs associated with all `stablehlo.composite`
// ops to the `tensorrt.module` op. This is needed since `tensorrt.module` op
// has its own symbol table.
SymbolTableCollection symbolTable;
for (auto compositeOp : op.getBody().getOps<stablehlo::CompositeOp>()) {
auto decompositionFunc = dyn_cast_if_present<func::FuncOp>(
symbolTable.lookupSymbolIn(op->getParentOfType<ModuleOp>(),
compositeOp.getDecompositionAttr()));
if (!decompositionFunc)
return emitError(compositeOp.getLoc())
<< "failed to lookup stablehlo.composite decomposition "
"function: "
<< compositeOp.getDecompositionAttr();
rewriter.moveOpAfter(decompositionFunc, func->getOperation());
}

// Move region op operations to the func body.
Operation *regionYieldOp = op.getYield();
rewriter.inlineRegionBefore(op.getRegion(), func->getFunctionBody(),
func->getFunctionBody().end());
rewriter.setInsertionPoint(regionYieldOp);
rewriter.replaceOpWithNewOp<func::ReturnOp>(regionYieldOp,
regionYieldOp->getOperands());

// replace the original region results.
rewriter.replaceOp(op, callOp);
return success();
}

/// Create outlined functions for each `scf.execute_region` operation within
Expand Down
30 changes: 25 additions & 5 deletions mlir-tensorrt/executor/include/mlir-executor-c/Runtime/Runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,11 @@ static inline bool mtrtRuntimeClientIsNull(MTRT_RuntimeClient client) {
return !client.ptr;
}

/// Returns null client.
static inline MTRT_RuntimeClient mtrtRuntimeClientGetNull() {
return MTRT_RuntimeClient{nullptr};
}

/// Creates a `MTRT_RuntimeClient`. Client must be alive for the lifetime of the
/// program execution.
/// The `stream` passed to the client is used by all underlying CUDA methods
Expand Down Expand Up @@ -308,6 +313,12 @@ static inline bool mtrtRuntimeValueIsNull(MTRT_RuntimeValue value) {
return !value.ptr;
}

// Returns whether the RuntimeValue is MemRef.
MLIR_CAPI_EXPORTED bool mtrtRuntimeValueIsMemRef(MTRT_RuntimeValue value);

// Returns whether the RuntimeValue is Scalar.
MLIR_CAPI_EXPORTED bool mtrtRuntimeValueIsScalar(MTRT_RuntimeValue value);

/// Cast a MTRT_MemRefValue to a generic MTRT_RuntimeValue.
MLIR_CAPI_EXPORTED MTRT_RuntimeValue
mtrtMemRefCastToRuntimeValue(MTRT_MemRefValue memref);
Expand Down Expand Up @@ -391,16 +402,25 @@ static inline bool mtrtRuntimeSessionIsNull(MTRT_RuntimeSession session) {
return !session.ptr;
}

/// Using `session`, execute the pubic function with the specified name.
/// The `inArgs` and `outArgs` are arrays for input arguments and destination
/// arguments, respectively. Input arguments may be MemRefs or scalars, but
/// destination arguments must be MemRefs.
/// Using `session`, execute the public function with the specified name.
/// The `inArgs`, `outArgs`, and `results` are arrays for input arguments,
/// output arguments, and return values, respectively. Arguments and results
/// can be MemRefs, scalars, or other supported types. Both `outArgs` and
/// `results` can be used simultaneously, allowing for functions that both
/// modify arguments and return values.
/// A stream may optionally be specified, otherwise pass the result of
/// `mtrtStreamGetNull()`.
///
/// The `results` array must point to an array with at least the number of
/// elements returned by mtrtRuntimeSessionGetNumResults for the given function.
MLIR_CAPI_EXPORTED MTRT_Status mtrtRuntimeSessionExecuteFunction(
MTRT_RuntimeSession session, MTRT_StringView name,
const MTRT_RuntimeValue *inArgs, size_t numInArgs,
const MTRT_RuntimeValue *outArgs, size_t numOutArgs, MTRT_Stream stream);
const MTRT_RuntimeValue *outArgs, size_t numOutArgs,
MTRT_RuntimeValue *results, MTRT_Stream stream, MTRT_RuntimeClient client);

MLIR_CAPI_EXPORTED MTRT_Status mtrtRuntimeSessionGetNumResults(
MTRT_RuntimeSession session, MTRT_StringView name, int64_t *numResults);

//===----------------------------------------------------------------------===//
// DLPack
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,13 +104,6 @@ executeFunctionWithLuaBackend(LuaRuntimeSession &session, std::string_view name,
std::optional<CudaStream> stream = {},
std::optional<RuntimeClient *> client = {});

// Parses the results of a function call, handling both scalar and MemRef return
// types
StatusOr<llvm::SmallVector<std::unique_ptr<RuntimeValue>>>
parseResults(const sol::protected_function_result &pfr,
const FunctionSignatureView &sig,
std::optional<RuntimeClient *> client);

} // namespace mlirtrt::runtime

#endif // MLIR_TENSORRT_RUNTIME_BACKEND_LUA_LUARUNTIME_H
40 changes: 34 additions & 6 deletions mlir-tensorrt/executor/lib/CAPI/Runtime/Runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -675,6 +675,16 @@ MTRT_ScalarValue mtrtRuntimeValueDynCastToScalar(MTRT_RuntimeValue v) {
return wrap(static_cast<ScalarValue *>(x));
}

bool mtrtRuntimeValueIsMemRef(MTRT_RuntimeValue value) {
RuntimeValue *x = unwrap(value);
return x->getKind() == RuntimeValue::Kind::MemRef;
}

bool mtrtRuntimeValueIsScalar(MTRT_RuntimeValue value) {
RuntimeValue *x = unwrap(value);
return x->getKind() == RuntimeValue::Kind::Scalar;
}

//===----------------------------------------------------------------------===//
// MTRT_RuntimeSessionOptions
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -721,7 +731,8 @@ MTRT_Status mtrtRuntimeSessionDestroy(MTRT_RuntimeSession session) {
MTRT_Status mtrtRuntimeSessionExecuteFunction(
MTRT_RuntimeSession session, MTRT_StringView name,
const MTRT_RuntimeValue *inArgs, size_t numInArgs,
const MTRT_RuntimeValue *outArgs, size_t numOutArgs, MTRT_Stream stream) {
const MTRT_RuntimeValue *outArgs, size_t numOutArgs,
MTRT_RuntimeValue *results, MTRT_Stream stream, MTRT_RuntimeClient client) {
LuaRuntimeSession *cppSession =
static_cast<LuaRuntimeSession *>(unwrap(session));

Expand All @@ -731,19 +742,36 @@ MTRT_Status mtrtRuntimeSessionExecuteFunction(
llvm::SmallVector<RuntimeValue *> outArgValues =
llvm::map_to_vector(llvm::ArrayRef(outArgs, numOutArgs),
[](MTRT_RuntimeValue arg) { return unwrap(arg); });

StatusOr<llvm::SmallVector<std::unique_ptr<RuntimeValue>>> result =
StatusOr<llvm::SmallVector<std::unique_ptr<RuntimeValue>>> resultValues =
executeFunctionWithLuaBackend(
*cppSession, std::string_view(name.data, name.length), inArgValues,
outArgValues,
!mtrtStreamIsNull(stream)
? std::optional(unwrap(stream)->getRawStream())
: std::nullopt);
if (!result.isOk())
return wrap(result.getStatus());
: std::nullopt,
!mtrtRuntimeClientIsNull(client) ? std::optional(unwrap(client))
: std::nullopt);
if (!resultValues.isOk())
return wrap(resultValues.getStatus());

for (size_t i = 0; i < resultValues->size(); ++i)
results[i] = wrap((*resultValues)[i].release());

return mtrtStatusGetOk();
}

MTRT_Status mtrtRuntimeSessionGetNumResults(MTRT_RuntimeSession session,
MTRT_StringView name,
int64_t *numResults) {
LuaRuntimeSession *cppSession =
static_cast<LuaRuntimeSession *>(unwrap(session));
*numResults = cppSession->getExecutable()
.getFunction(std::string_view(name.data, name.length))
.getSignature()
.getNumResults();
return mtrtStatusGetOk();
}

//===----------------------------------------------------------------------===//
// MTRT_RuntimeClient
//===----------------------------------------------------------------------===//
Expand Down
20 changes: 20 additions & 0 deletions mlir-tensorrt/executor/lib/Conversion/MemRefToExecutor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "mlir-executor/Conversion/ConvertToExecutorCommon.h"
#include "mlir-executor/Conversion/Passes.h"
#include "mlir-executor/Executor/IR/Executor.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Matchers.h"
Expand Down Expand Up @@ -548,6 +549,21 @@ void executor::populateMemRefToExecutorPatterns(
}

namespace {

class RemoveNoOpClonePattern : public OpRewritePattern<bufferization::CloneOp> {
public:
using OpRewritePattern<bufferization::CloneOp>::OpRewritePattern;

LogicalResult matchAndRewrite(bufferization::CloneOp op,
PatternRewriter &rewriter) const override {
if (op.getInput().getType() == op.getOutput().getType()) {
rewriter.replaceOp(op, op.getInput());
return success();
}
return failure();
}
};

/// Pass to convert `memref` to `executor` dialect operrations.
class ConvertMemRefToExecutorPass
: public mlir::executor::impl::ConvertMemRefToExecutorPassBase<
Expand Down Expand Up @@ -579,6 +595,10 @@ class ConvertMemRefToExecutorPass
RewritePatternSet patterns(ctx);
executor::populateMemRefToExecutorPatterns(
patterns, typeConverter, allowUncheckedMemrefCastConversion);

// Remove unrealized cast and redundant clone operations.
patterns.add<RemoveNoOpClonePattern>(ctx);

if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
return signalPassFailure();
Expand Down
Loading

0 comments on commit fcb0964

Please sign in to comment.