Skip to content

Commit

Permalink
Move internal changes (#468)
Browse files Browse the repository at this point in the history
This PR moves the following internal changes to OSS,

**[tensorrt/test] Remove/modify incorrect INT4 test**
This MR removes INT4 unit test from `TensorRT/TRT10/convolution.mlir`
test. It worked before but doesn't work with TensorRT 10.7
anymore. Even though it worked before, INT4 Q/DQ (which
is block quantization) is supported for 2D weights only.
This MR also modifies matmul INT4 test from
`TensorRT/TRT10/matrix-multiply.mlir`
by making scale 2D to strictly follow block quantization specs.

**Don't use `executor.opaque<...>` in TensorRTRuntime-to-Executor
lowering**
Eliminates the final uses of the `!executor.opaque<...>` type, which can
be dropped from the Executor dialect in a future commit.

**NFC: run isort and formatter on all python files**

**[runtime] Improve Lua runtime organization**
This change:

- Removes 'sol' objects and header inclusions from Lua runtime interface
  headers.
- Creates registration mechanism for all the Lua runtime
"modules/extensions".
  This simplifies the interface for adding and maintaining the runtime
  extensions. In the future, it also allows for selectively loading
  extensions during session creation (not implemented in this change).
- Removes misc dead code.

**[compiler] Add support to compile PyTorch models**
This MR enables compiling PyTorch models via Torch-MLIR.
- Add `torch_bridge` module to the `mlir_tensorrt.compiler` package
which
takes in torch module along with necessary arguments and return
MLIR IR (Stablehlo OR Linalg on tensors).
- Add an example to e2e test Torch integration.

**[compiler] update compiler python type stub, remove dead APIs**
Updates the Python type stubs for PyBind11 modules. Corrects the naming
of the "PyPassManagerReference" object and removes some APIs which are
no longer needed.

**NFC: [runtime] Create runtime C implementation 'CoreModule'**
Creates an initial C implementation for functions under 'CoreModule'
which will be needed for exposure to other (non-Lua) backend
implementations, e.g. the LLVM JIT runtime.

Marked 'NFC' since it is pure code movement.

**[compiler] Improve pipeline setup and handling of printing/debug
options**
This commit makes the following changes in order to simplify the
debugging experience when working through python vs. CLI:

- Each compilation task options set now includes a full mirror of the
  debugging options available through MLIR CLI tools (e.g.
  `--mlir-print-ir-after-all`, --mlir-print-ir-tree-dir`, etc).
  This also includes options for display pass statistics and timing
  information (`--mlir-pass-statistics` and `--mlir-timing`).

- Introduce a pass in the StablehloToExecutable pipeline to set the
  default 'plan.cluster_kinds' attribute on the module if it is not
  set by the frontend. This eliminates the special step present in
  'stablehlo-to-executable' and allows us to remove the APIs specific
  to StablehloToExecutable at the frontend and replace them with use of
  `client.get_compilation_task`.

**[compiler] Enable omitting options from MLIR CLI options when
converting an OptionsContext**
Allows tagging options of a derived `OptionsContext` in order to omit
those options when converting to an `mlir::PassPipelineOptions` using
the template adaptor.

**[compiler] Improve organization of compiler tasks, reduce redundant
options objects**
This change:

- Improves organization of compiler tasks by moving
  'StablehloToExecutable' related files under its own sub directory.
  Adds a dedicated Tablegen file for the auxiliary passes needed by
  the StablehloToExecutable pipeline.

- Adds an adaptor that allows using classes derived from
  'OptionsContext' as `mlir::PassPipelineOptions`. This eliminates the
  need to keep redundant options definitions for MLIR pipelines
  registered for CLI tools.

---------

Co-authored-by: Copybara Bot <[email protected]>
  • Loading branch information
shelkesagar29 and Copybara Bot authored Jan 4, 2025
1 parent b0d43f3 commit d341395
Show file tree
Hide file tree
Showing 89 changed files with 1,611 additions and 863 deletions.
9 changes: 9 additions & 0 deletions mlir-tensorrt/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ mtrt_option(MLIR_TRT_ENABLE_TESTING "Enable building optional tests" ON)
mtrt_option(MLIR_TRT_TARGET_LUA "Enable translating MLIR to the Lua target" ON)
mtrt_option(MLIR_TRT_ENABLE_EXECUTOR "Build the Executor dialect and MLIR-TensorRT Execution Engine" ON)
mtrt_option(MLIR_TRT_ENABLE_NCCL "Enable the NCCL runtime module" ON)
mtrt_option(MLIR_TRT_ENABLE_TORCH "Whether to include torch-mlir features" OFF)

set(MLIR_TRT_TENSORRT_DIR "" CACHE STRING "Path to TensorRT install directory")
set(MLIR_TRT_DOWNLOAD_TENSORRT_VERSION "10.5" CACHE STRING
Expand Down Expand Up @@ -194,6 +195,14 @@ if(MLIR_TRT_ENABLE_HLO AND NOT TARGET StablehloOps)
)
endif()

if(MLIR_TRT_ENABLE_TORCH)
mtrt_add_torch_mlir(
GIT_TAG "30c519369ed7eabad0282d0f874500a9b41fcbbd"
PATCHES
"${CMAKE_CURRENT_LIST_DIR}/build_tools/patches/torch_mlir/torch_mlir.patch"
)
endif()

if(MLIR_TRT_ENABLE_TESTING)
# TODO: force the flag that makes LLVM build google benchmark.
endif()
Expand Down
22 changes: 21 additions & 1 deletion mlir-tensorrt/build_tools/cmake/Dependencies.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ function(download_tensorrt)
elseif(NOT (ARCH STREQUAL "x86_64"))
message(FATAL_ERROR "Direct download not available for architecture: ${ARCH}")
endif()
set(_url "https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/secure/${trt_short_version}/tars/TensorRT-${TRT_VERSION}.${OS}.${ARCH}-gnu.cuda-${CUDA_VERSION}.tar.gz")
set(_url "https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/secure/${trt_short_version}/tars/TensorRT-${TRT_VERSION}.${OS}.${ARCH}-gnu.cuda-${CUDA_VERSION}.tar.gz")
endif()

# Handle TensorRT 9 versions. These are publicly accessible download links.
Expand Down Expand Up @@ -285,3 +285,23 @@ function(mlir_tensorrt_find_dlpack)
add_library(DLPack::Headers ALIAS DLPackHeaderOnly)
endif()
endfunction()

#-------------------------------------------------------------------------------------
# Download Torch-MLIR
#-------------------------------------------------------------------------------------

function(mtrt_add_torch_mlir)
CPMAddPackage(
NAME torch_mlir
GIT_REPOSITORY https://github.com/llvm/torch-mlir.git
EXCLUDE_FROM_ALL TRUE
OPTIONS
"TORCH_MLIR_OUT_OF_TREE_BUILD ON"
"TORCH_MLIR_ENABLE_STABLEHLO ON"
"TORCH_MLIR_EXTERNAL_STABLEHLO_DIR ${stablehlo_SOURCE_DIR}"
"MLIR_DIR ${CMAKE_BINARY_DIR}/lib/cmake/mlir"
"LLVM_DIR ${llvm_project_BINARY_DIR}/lib/cmake/llvm"
${ARGN}
)
set(torch_mlir_SOURCE_DIR "${torch_mlir_SOURCE_DIR}" PARENT_SCOPE)
endfunction()
128 changes: 128 additions & 0 deletions mlir-tensorrt/build_tools/patches/torch_mlir/torch_mlir.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 822afa0a..987c8bd2 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -44,6 +44,14 @@ if(TORCH_MLIR_ENABLE_STABLEHLO)
add_definitions(-DTORCH_MLIR_ENABLE_STABLEHLO)
endif()

+# It is possible that both stablehlo and torch_mlir projects are used in some compiler project.
+# In this case, we don't want to use stablehlo that is downloaded by torch_mlir (in external/stablehlo)
+# folder but instead stablehlo that is part of top level compiler project.
+# TORCH_MLIR_EXTERNAL_STABLEHLO_DIR represents stablehlo directory (<some_path>/stablehlo)
+# that is included in torch_mlir. It is assumed that top level compiler project makes
+# stablehlo targets available (for example with `add_subdirectory`) and thus they are not added.
+set(TORCH_MLIR_EXTERNAL_STABLEHLO_DIR "" CACHE STRING "Path to stablehlo dir from super project")
+
option(TORCH_MLIR_OUT_OF_TREE_BUILD "Specifies an out of tree build" OFF)

# PyTorch native extension gate. If OFF, then no features which depend on
@@ -142,7 +150,8 @@ include_directories(${CMAKE_CURRENT_BINARY_DIR}/include)

function(torch_mlir_target_includes target)
set(_dirs
- $<BUILD_INTERFACE:${MLIR_INCLUDE_DIRS}>
+ $<BUILD_INTERFACE:${MLIR_INCLUDE_DIR}>
+ $<BUILD_INTERFACE:${MLIR_GENERATED_INCLUDE_DIR}>
$<BUILD_INTERFACE:${TORCH_MLIR_SOURCE_DIR}/include>
$<BUILD_INTERFACE:${TORCH_MLIR_BINARY_DIR}/include>
)
@@ -233,12 +242,16 @@ endif()
# project that we don't actually depend on. Further some of those parts
# do not even compile on all platforms.
if (TORCH_MLIR_ENABLE_STABLEHLO)
- set(STABLEHLO_BUILD_EMBEDDED ON)
- set(STABLEHLO_ENABLE_BINDINGS_PYTHON ON)
- add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/externals/stablehlo
- ${CMAKE_CURRENT_BINARY_DIR}/stablehlo
- EXCLUDE_FROM_ALL)
- include_directories(${CMAKE_CURRENT_SOURCE_DIR}/externals/stablehlo)
+ if (NOT "${TORCH_MLIR_EXTERNAL_STABLEHLO_DIR}" STREQUAL "")
+ include_directories(${TORCH_MLIR_EXTERNAL_STABLEHLO_DIR})
+ else()
+ set(STABLEHLO_BUILD_EMBEDDED ON)
+ set(STABLEHLO_ENABLE_BINDINGS_PYTHON ON)
+ add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/externals/stablehlo
+ ${CMAKE_CURRENT_BINARY_DIR}/stablehlo
+ EXCLUDE_FROM_ALL)
+ include_directories(${CMAKE_CURRENT_SOURCE_DIR}/externals/stablehlo)
+ endif()
endif()

#-------------------------------------------------------------------------------
diff --git a/lib/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.cpp b/lib/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.cpp
index ddb6e5a5..22b95c8a 100644
--- a/lib/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.cpp
+++ b/lib/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.cpp
@@ -59,6 +59,12 @@ public:
matchAndRewrite(GetNextSeedOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
+ // Find parent module to add global seed, if not present already.
+ auto module = op->getParentOfType<ModuleOp>();
+ OpBuilder b(module.getBodyRegion());
+ if (failed(getOrCreateGlobalVariableForSeed(b, module)))
+ return failure();
+
// Generate sequence for getting the next seed with LCG step:
// nextSeed = (multiplier * currentSeed + incrementStep) mod 2^64.
// Refer to https://en.wikipedia.org/wiki/Linear_congruential_generator.
@@ -115,11 +121,6 @@ public:
typeConverter.addConversion([](Type type) { return type; });
TorchConversion::setupBackendTypeConversion(target, typeConverter);

- auto module = getOperation();
- OpBuilder b(module.getBodyRegion());
- if (failed(getOrCreateGlobalVariableForSeed(b, module)))
- signalPassFailure();
-
RewritePatternSet patterns(context);
target.addIllegalOp<GetNextSeedOp>();
patterns.add<ConvertGetNextSeedOp>(typeConverter, context);
diff --git a/python/torch_mlir/compiler_utils.py b/python/torch_mlir/compiler_utils.py
index ecf129d7..cf07526e 100644
--- a/python/torch_mlir/compiler_utils.py
+++ b/python/torch_mlir/compiler_utils.py
@@ -10,8 +10,8 @@ import tempfile
from typing import Union, List

import torch
-from torch_mlir.passmanager import PassManager
-from torch_mlir.ir import StringAttr
+from .passmanager import PassManager
+from .ir import StringAttr


class TensorPlaceholder:
diff --git a/python/torch_mlir/fx.py b/python/torch_mlir/fx.py
index cfe87348..5309f573 100644
--- a/python/torch_mlir/fx.py
+++ b/python/torch_mlir/fx.py
@@ -13,11 +13,11 @@ import torch.export
import torch.nn as nn
from torch.export import ExportedProgram

-from torch_mlir.extras.fx_importer import FxImporter, FxImporterHooks
-from torch_mlir import ir
-from torch_mlir.dialects import torch as torch_d
-from torch_mlir.extras.fx_decomp_util import get_decomposition_table
-from torch_mlir.compiler_utils import (
+from .extras.fx_importer import FxImporter, FxImporterHooks
+from . import ir
+from .dialects import torch as torch_d
+from .extras.fx_decomp_util import get_decomposition_table
+from .compiler_utils import (
OutputType,
run_pipeline_with_repro_report,
lower_mlir_module,
diff --git a/test/Conversion/TorchConversionToMLProgram/multiple_functions.mlir b/test/Conversion/TorchConversionToMLProgram/multiple_functions.mlir
index 8ef04d95..da2424fc 100644
--- a/test/Conversion/TorchConversionToMLProgram/multiple_functions.mlir
+++ b/test/Conversion/TorchConversionToMLProgram/multiple_functions.mlir
@@ -11,5 +11,5 @@ module {
func.func private @f7() -> i64
}

-// CHECK: ml_program.global private mutable @global_seed(dense<0> : tensor<i64>) : tensor<i64>
+// CHECK-NOT: ml_program.global private mutable @global_seed(dense<0> : tensor<i64>) : tensor<i64>
// CHECK-NOT: @global_seed
12 changes: 2 additions & 10 deletions mlir-tensorrt/compiler/include/mlir-tensorrt-c/Compiler/Compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,25 +114,17 @@ static inline bool mtrtStableHloToExecutableOptionsIsNull(
}

//===----------------------------------------------------------------------===//
// StableHloPipeline APIs
// PassManagerReference APIs
//===----------------------------------------------------------------------===//

static inline bool mtrtStableHloPipelineIsNull(MlirPassManager pm) {
static inline bool mtrtPassManagerReferenceIsNull(MlirPassManager pm) {
return !pm.ptr;
}

MLIR_CAPI_EXPORTED MTRT_Status mtrtStableHloPipelineGetCached(
MTRT_CompilerClient client, MTRT_StableHLOToExecutableOptions options,
MlirPassManager *result);

//===----------------------------------------------------------------------===//
// Main StableHLO Compiler API Functions
//===----------------------------------------------------------------------===//

/// Get Executable using StableHloPassManager.
MLIR_CAPI_EXPORTED MTRT_Status mtrtCompilerGetExecutable(
MlirPassManager pm, MlirOperation module, MTRT_Executable *result);

/// Compiler StableHLO to Executable.
MLIR_CAPI_EXPORTED MTRT_Status mtrtCompilerStableHLOToExecutable(
MTRT_CompilerClient client, MlirOperation module,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
add_subdirectory(Compiler)
add_subdirectory(Dialect)
add_subdirectory(Conversion)
add_subdirectory(Transforms)
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
add_subdirectory(StablehloToExecutable)
Original file line number Diff line number Diff line change
Expand Up @@ -169,10 +169,6 @@ class CompilerClient {
/// Return the MLIRContext associated with the client.
mlir::MLIRContext *getContext() const { return context; }

/// Helper for setting the correct logging options on cached PassManagers.
static void setupPassManagerLogging(mlir::PassManager &pm,
const DebugOptions &options);

protected:
CompilerClient(mlir::MLIRContext *context);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

#include "mlir-executor/Support/DeviceInfo.h"
#include "mlir-tensorrt-dialect/Utils/Options.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LLVM.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Error.h"
Expand All @@ -47,6 +48,8 @@ constexpr bool has_finalize_impl_v<
// a default implementation otherwise.
template <typename Derived>
struct OptionsProvider {
using OmitFromCLI = mlir::OptionsContext::OmitFromCLI;

OptionsProvider(mlir::OptionsContext &ctx) : ctx(ctx) {}

// We don't allow move construction since the actual ptrs/locations of
Expand Down Expand Up @@ -81,21 +84,88 @@ struct OptionsProvider {
struct DebugOptions : public OptionsProvider<DebugOptions> {
public:
using OptionsProvider::OptionsProvider;
/// A directory path where the IR will be dumped during compilation
/// using the `mlir-print-ir-tree-dir` mechanism.
Option<std::string> dumpIRPath{&this->ctx, "mlir-print-ir-tree-dir",
llvm::cl::init("")};
//===--------------------------------------------------------------------===//
// Crash Reproducer Generator
//===--------------------------------------------------------------------===//
Option<std::string> reproducerFile{
&this->ctx, "mlir-pass-pipeline-crash-reproducer",
llvm::cl::desc("Generate a .mlir reproducer file at the given output path"
" if the pass manager crashes or fails"),
OmitFromCLI{}};
Option<bool> localReproducer{
&this->ctx, "mlir-pass-pipeline-local-reproducer",
llvm::cl::desc("When generating a crash reproducer, attempt to generated "
"a reproducer with the smallest pipeline."),
llvm::cl::init(false), OmitFromCLI{}};

//===--------------------------------------------------------------------===//
// IR Printing
//===--------------------------------------------------------------------===//

Option<bool> printBeforeAll{&this->ctx, "mlir-print-ir-before-all",
llvm::cl::desc("Print IR before each pass"),
llvm::cl::init(false), OmitFromCLI{}};
Option<bool> printAfterAll{&this->ctx, "mlir-print-ir-after-all",
llvm::cl::desc("Print IR after each pass"),
llvm::cl::init(false), OmitFromCLI{}};
Option<bool> printAfterChange{
&this->ctx, "mlir-print-ir-after-change",
llvm::cl::desc(
"When printing the IR after a pass, only print if the IR changed"),
llvm::cl::init(false), OmitFromCLI{}};
Option<bool> printAfterFailure{
&this->ctx, "mlir-print-ir-after-failure",
llvm::cl::desc(
"When printing the IR after a pass, only print if the pass failed"),
llvm::cl::init(false), OmitFromCLI{}};
Option<bool> printModuleScope{
&this->ctx, "mlir-print-ir-module-scope",
llvm::cl::desc("When printing IR for print-ir-[before|after]{-all} "
"always print the top-level operation"),
llvm::cl::init(false), OmitFromCLI{}};
Option<std::string> printTreeDir{
&this->ctx, "mlir-print-ir-tree-dir",
llvm::cl::desc("When printing the IR before/after a pass, print file "
"tree rooted at this directory. Use in conjunction with "
"mlir-print-ir-* flags"),
OmitFromCLI{}};

//===--------------------------------------------------------------------===//
// Pass Statistics
//===--------------------------------------------------------------------===//
Option<bool> passStatistics{
&this->ctx, "mlir-pass-statistics",
llvm::cl::desc("Display the statistics of each pass"),
llvm::cl::init(false), OmitFromCLI{}};

//===--------------------------------------------------------------------===//
// Pass Timing
//===--------------------------------------------------------------------===//
Option<bool> enableTiming{
&this->ctx, "mlir-timing",
llvm::cl::desc(
"Time each pass and print to stderr after the pipeline completes"),
llvm::cl::init(false), OmitFromCLI{}};

//===----------------------------------------------------------------------===//
// Debug Printing
//===----------------------------------------------------------------------===//

/// Whether the LLVM 'debug' flag that enables execution of code guarded by
/// the `LLVM_DEBUG` macro should be set to 'on'. This results in very verbose
/// output from the compiler dumped to stderr.
Option<bool> enableLLVMDebugFlag{&this->ctx, "debug", llvm::cl::init(false)};
Option<bool> enableLLVMDebugFlag{&this->ctx, "debug", llvm::cl::init(false),
OmitFromCLI{}};

/// A set of names to be given to the LLVM 'debug types' option, akin to
/// setting
/// `-debug-types=...` from the command line.
ListOption<std::string> llvmDebugTypes{
&this->ctx, "debug-only", llvm::cl::ZeroOrMore, llvm::cl::CommaSeparated};
&this->ctx, "debug-only", llvm::cl::ZeroOrMore, llvm::cl::CommaSeparated,
OmitFromCLI{}};

/// Apply these options to the current pass manager.
void applyToPassManager(mlir::PassManager &pm) const;
};

struct ExecutorOptions : public OptionsProvider<ExecutorOptions> {
Expand Down Expand Up @@ -131,7 +201,7 @@ struct DeviceOptions : public OptionsProvider<DeviceOptions> {
DeviceOptions(mlir::OptionsContext &ctx) : OptionsProvider(ctx) {
ctx.addOption(
"device-compute-capability", info.computeCapability, llvm::cl::init(60),
llvm::cl::desc("Sets the device compute capbility. Only relevant "
llvm::cl::desc("Sets the device compute capability. Only relevant "
"if '--device-infer-from-host=false'"));
ctx.addOption("device-max-shared-memory-per-block-kb",
info.maxSharedMemoryPerBlockKb, llvm::cl::init(48));
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
set(_TABLEGEN_ARGS )
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls -name StablehloToExecutable ${_TABLEGEN_ARGS})
add_public_tablegen_target(MLIRTensorRTStablehloToExecutableIncGen)
Loading

0 comments on commit d341395

Please sign in to comment.