Move internal changes (#475)
## [tensorrt] Fix various edge cases in 'tensorrt-broadcast-elimination'

This change fixes several edge cases in 'tensorrt-broadcast-elimination'
for the pattern that commutes 'tensorrt.broadcast' and
'tensorrt.collapse_rank'. Previously the pattern did not correctly handle
multiple collapsed unit dimensions and triggered an erroneous assertion
when the collapsed input has dynamic dimensions. Several additional
regression tests are added.
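
To make the handled cases concrete, below is a minimal, standalone sketch of
the index bookkeeping the rewritten pattern performs when several of the
collapsed unit dimensions come from the broadcast. The helper name, concrete
shapes, and `main` driver are illustrative only and are not the actual pass
code.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <functional>
#include <iostream>
#include <vector>

// Standalone model of the bookkeeping: every collapsed unit dimension that is
// produced by the broadcast is erased from the broadcast input and result
// shapes, and all broadcast dimension indices above it shift down by one.
static void dropCollapsedUnitDims(std::vector<int64_t> &inputShape,
                                  std::vector<int64_t> &resultShape,
                                  std::vector<int64_t> &broadcastDims,
                                  std::vector<int64_t> removedDims) {
  // Highest index first, so earlier erasures do not invalidate later indices.
  std::sort(removedDims.begin(), removedDims.end(), std::greater<int64_t>());
  for (int64_t removedDim : removedDims) {
    auto it = std::find(broadcastDims.begin(), broadcastDims.end(), removedDim);
    if (it == broadcastDims.end())
      continue; // No input dimension maps to this result dim; skip it here.
    size_t inputIdx = std::distance(broadcastDims.begin(), it);
    assert(inputShape[inputIdx] == 1 && "expected size-1 dimension");
    inputShape.erase(inputShape.begin() + inputIdx);
    broadcastDims.erase(it);
    resultShape.erase(resultShape.begin() + removedDim);
    for (int64_t &d : broadcastDims)
      if (d > removedDim)
        --d;
  }
}

int main() {
  // broadcast: tensor<1x1x4> -> tensor<1x8x1x4>, mapping input dims {0, 1, 2}
  // to result dims {0, 2, 3}; collapse_rank then removes result dims 0 and 2.
  std::vector<int64_t> inputShape{1, 1, 4};
  std::vector<int64_t> resultShape{1, 8, 1, 4};
  std::vector<int64_t> broadcastDims{0, 2, 3};
  dropCollapsedUnitDims(inputShape, resultShape, broadcastDims, {0, 2});
  // The rewritten broadcast becomes tensor<4> -> tensor<8x4> with dims {1}.
  assert((inputShape == std::vector<int64_t>{4}));
  assert((resultShape == std::vector<int64_t>{8, 4}));
  assert((broadcastDims == std::vector<int64_t>{1}));
  std::cout << "ok\n";
  return 0;
}
```

Processing the removed dimensions from highest to lowest keeps the remaining
indices valid after each erase; the previous implementation only looked at the
first removed dimension.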

## [tensorrt] Fix nvinfer1::Dims max rank assertion

Fixes an assertion that limited translation of MLIR to TensorRT to
programs with tensors of rank < 8; the bound should be <= 8 to align
with TensorRT op verification and the actual nvinfer API. Adds an
additional translation test.
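
To illustrate the boundary case, here is a minimal sketch; `kMaxDims` and
`fitsInNvInferDims` are illustrative stand-ins that mirror
`nvinfer1::Dims::MAX_DIMS`, not the TensorRT API itself.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// nvinfer1::Dims can hold up to MAX_DIMS (8) dimensions, so a rank-8 tensor
// is representable; the old assertion rejected exactly this boundary case.
constexpr int64_t kMaxDims = 8; // mirrors nvinfer1::Dims::MAX_DIMS

static bool fitsInNvInferDims(const std::vector<int64_t> &shape) {
  return static_cast<int64_t>(shape.size()) <= kMaxDims; // was '<' before the fix
}

int main() {
  assert(fitsInNvInferDims(std::vector<int64_t>(8, 1)));  // rank 8: now accepted
  assert(!fitsInNvInferDims(std::vector<int64_t>(9, 1))); // rank 9: still rejected
  return 0;
}
```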

## NFC: apply isort to Python files

## Fix options configuration to enable global flags in "compiler task" API
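
For the global flags to take effect, the hosting tool must initialize MLIR's
global command-line option registry before parsing its command line; otherwise
the pass manager has nothing to pick up. Below is a minimal sketch of such a
driver; the driver itself is hypothetical and only shows the standard MLIR
registration calls, so the exact setup inside the compiler task API may differ.

```cpp
#include "mlir/IR/MLIRContext.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/Timing.h"
#include "llvm/Support/CommandLine.h"

int main(int argc, char **argv) {
  // Register the global pass-manager and timing CL options so that
  // applyPassManagerCLOptions / applyDefaultTimingPassManagerCLOptions have
  // populated values to read when "use-global-cl-options" is enabled.
  mlir::registerPassManagerCLOptions();
  mlir::registerDefaultTimingManagerCLOptions();
  llvm::cl::ParseCommandLineOptions(argc, argv, "example driver\n");

  mlir::MLIRContext context;
  mlir::PassManager pm(&context);
  if (mlir::failed(mlir::applyPassManagerCLOptions(pm)))
    return 1;
  mlir::applyDefaultTimingPassManagerCLOptions(pm);
  // ... build and run the pass pipeline on a module ...
  return 0;
}
```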

Co-authored-by: Copybara Bot <[email protected]>
christopherbate and Copybara Bot authored Jan 13, 2025
1 parent 63ff483 commit 2e04f52
Showing 12 changed files with 181 additions and 78 deletions.
8 changes: 8 additions & 0 deletions mlir-tensorrt/CMakeLists.txt
@@ -253,6 +253,14 @@ include(AddMLIRPython)
include(MLIRDetectPythonEnv)
mlir_configure_python_dev_packages()

# Declare the Python source targets ahead-of-time since the sources may be built
# up across multiple sub-directories.
if(MLIR_TRT_ENABLE_PYTHON)
declare_mlir_python_sources(MLIRTensorRTPythonCompiler)
declare_mlir_python_sources(MLIRTensorRTPythonCompiler.Dialects
ADD_TO_PARENT MLIRTensorRTPythonCompiler)
endif()

# Add a meta target for all documentation generation targets. You can generate
# all documentation under `${buildDir}/docs` by building this target.
add_custom_target("mlir-tensorrt-doc")
@@ -164,6 +164,11 @@ struct DebugOptions : public OptionsProvider<DebugOptions> {
&this->ctx, "debug-only", llvm::cl::ZeroOrMore, llvm::cl::CommaSeparated,
OmitFromCLI{}};

/// If set to `true`, we populate the pass manager instrumentation using
/// global MLIR CL options rather than the local options contained here.
Option<bool> useGlobalCLPrintingOptions{&this->ctx, "use-global-cl-options",
llvm::cl::init(false), OmitFromCLI{}};

/// Apply these options to the current pass manager.
void applyToPassManager(mlir::PassManager &pm) const;
};
@@ -195,8 +200,6 @@ struct DeviceOptions : public OptionsProvider<DeviceOptions> {
llvm::cl::desc("whether to ignore `deviceX` options and instead infer "
"them from the host GPU")};

Status inferFromHost();

public:
DeviceOptions(mlir::OptionsContext &ctx) : OptionsProvider(ctx) {
ctx.addOption(
@@ -51,6 +51,8 @@ class StablehloToExecutableTensorRTExtension
void addToOptions(mlir::OptionsContext &context) final {
context.addOption("disable-tensorrt-extension", disabled,
llvm::cl::init(false));
context.addOption("use-global-tensorrt-translation-flags", useGlobalCLFlags,
llvm::cl::init(false));
translationOptions.addToOptions(context);
}

@@ -63,8 +65,8 @@ class StablehloToExecutableTensorRTExtension
/// Options for MLIR-to-TensorRT translation.
mlir::tensorrt::TensorRTTranslationOptions translationOptions;

/// Path where we should persist the timing cache to storage.
std::string timingCachePath;
/// Whether to use global CL config for options.
bool useGlobalCLFlags{false};
};

} // namespace mlirtrt::compiler
40 changes: 13 additions & 27 deletions mlir-tensorrt/compiler/lib/Compiler/OptionsProviders.cpp
@@ -27,6 +27,7 @@
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/Timing.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/ErrorHandling.h"

using namespace mlir;
using namespace mlirtrt;
@@ -37,6 +38,18 @@ using namespace mlirtrt::compiler;
//===----------------------------------------------------------------------===//

void DebugOptions::applyToPassManager(PassManager &pm) const {
// If the options specify to use global MLIR CL flags, then apply those
// options. Otherwise, use our local options. Using global options is only
// possible if the LLVM global command line flag environment is initialized
// correctly.
if (useGlobalCLPrintingOptions) {
if (failed(applyPassManagerCLOptions(pm)))
llvm::report_fatal_error("failed to populate pass manager "
"instrumentation from global CL options");
applyDefaultTimingPassManagerCLOptions(pm);
return;
}

std::function<bool(Pass *, Operation *)> shouldPrintBeforePass;
std::function<bool(Pass *, Operation *)> shouldPrintAfterPass;

@@ -78,7 +91,6 @@ void DebugOptions::applyToPassManager(PassManager &pm) const {
printAfterFailure, printTreeDir);
return;
}

pm.enableIRPrinting(shouldPrintBeforePass, shouldPrintAfterPass,
printModuleScope, printAfterChange, printAfterFailure,
llvm::errs());
@@ -88,32 +100,6 @@ void DebugOptions::applyToPassManager(PassManager &pm) const {
// DeviceOptions
//===----------------------------------------------------------------------===//

Status DeviceOptions::inferFromHost() {
cudaDeviceProp properties;
cudaError_t err = cudaGetDeviceProperties(&properties, 0);
if (err != cudaSuccess)
return getStatusWithMsg(StatusCode::InternalError,
"failed to get cuda device properties");
int ccMajor = 0;
int ccMinor = 0;
err = cudaDeviceGetAttribute(
&ccMajor, cudaDeviceAttr::cudaDevAttrComputeCapabilityMajor, 0);
if (err != cudaSuccess)
return getStatusWithMsg(StatusCode::InternalError,
"failed to get cuda device compute capability");
err = cudaDeviceGetAttribute(
&ccMinor, cudaDeviceAttr::cudaDevAttrComputeCapabilityMinor, 0);
if (err != cudaSuccess)
return getStatusWithMsg(StatusCode::InternalError,
"failed to get cuda device compute capability");
// We want SM version as a single number.
int64_t smVersion = ccMajor * 10 + ccMinor;
info.computeCapability = smVersion;
info.maxSharedMemoryPerBlockKb = properties.sharedMemPerBlock / 1024;
info.maxRegistersPerBlock = properties.regsPerBlock;
return Status::getOk();
}

llvm::Error DeviceOptions::finalizeImpl() {
if (shouldInferFromHost) {
StatusOr<DeviceInfo> deviceInfo = getDeviceInformationFromHost();
@@ -44,8 +44,12 @@ void StablehloToExecutableTensorRTExtension::populatePasses(
if (this->disabled)
return;

tensorrt::TensorRTTranslationOptions translationOpts =
useGlobalCLFlags ? tensorrt::TensorRTTranslationOptions::fromCLFlags()
: translationOptions;

if (phase == Phase::PreClustering) {
    // We must materialize TRT plugin shape regions prior to clustering.
pm.addNestedPass<func::FuncOp>(tensorrt::createInferPluginShapesPass());
return;
}
@@ -61,9 +65,9 @@ void StablehloToExecutableTensorRTExtension::populatePasses(
// Simplify and translate functions nested in `tensorrt.module` ops.
auto &trtPM = pm.nest<tensorrt::TensorRTModuleOp>();
tensorrt::buildTensorRTModuleTransformationPipeline(
trtPM, translationOptions.enableStronglyTyped);
trtPM, translationOpts.enableStronglyTyped);
trtPM.addPass(
tensorrt::createTranslateTensorRTPass(nullptr, translationOptions));
tensorrt::createTranslateTensorRTPass(nullptr, translationOpts));
return;
}

@@ -2,12 +2,11 @@

import mlir_tensorrt.compiler.api as compiler
import mlir_tensorrt.compiler.ir as ir
import mlir_tensorrt.runtime.api as runtime
import mlir_tensorrt.compiler.torch_bridge as torch_bridge

import mlir_tensorrt.runtime.api as runtime
import numpy as np
import torch.nn as nn
import torch
import torch.nn as nn


class Model(nn.Module):
3 changes: 0 additions & 3 deletions mlir-tensorrt/python/CompilerPackage.cmake
@@ -23,13 +23,10 @@ configure_file(
# Structural groupings.
################################################################################

declare_mlir_python_sources(MLIRTensorRTPythonCompiler)
declare_mlir_python_sources(MLIRTensorRTPythonCompiler.Core
ADD_TO_PARENT MLIRTensorRTPythonCompiler)
declare_mlir_python_sources(MLIRTensorRTPythonCompiler.CompilerAPI
ADD_TO_PARENT MLIRTensorRTPythonCompiler)
declare_mlir_python_sources(MLIRTensorRTPythonCompiler.Dialects
ADD_TO_PARENT MLIRTensorRTPythonCompiler)

################################################################################
# Pure python sources and generated code
@@ -1,11 +1,13 @@
from typing import Optional, Union, Dict, Tuple, Any
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Tuple, Union

from torch import nn
from torch.export import ExportedProgram
from dataclasses import dataclass, field
from .compiler_utils import OutputType

from . import fx
from .extras.fx_importer import FxImporter
from .compiler_utils import OutputType
from .dialects import torch as torch_d
from .extras.fx_importer import FxImporter

__all__ = ["TorchInput", "get_mlir_module_from_torch_module"]

@@ -267,7 +267,7 @@ template <typename T, std::enable_if_t<std::is_same_v<T, int32_t> ||
std::is_same_v<T, int64_t>,
T *> = nullptr>
static nvinfer1::Dims getNvInferDims(ArrayRef<T> arrayRef) {
assert(arrayRef.size() < nvinfer1::Dims::MAX_DIMS &&
assert(arrayRef.size() <= nvinfer1::Dims::MAX_DIMS &&
"input array exceeds max dims");
nvinfer1::Dims dims;
dims.nbDims = arrayRef.size();
@@ -1,8 +1,7 @@
//===- BroadcastElimination.cpp----------------------------------*- c++ -*-===//
//
// SPDX-FileCopyrightText: Copyright 2024 NVIDIA CORPORATION & AFFILIATES.
// All rights reserved.
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright 2024 - 2025 NVIDIA CORPORATION &
// AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -19,8 +18,7 @@
//===----------------------------------------------------------------------===//
#include "mlir-tensorrt-dialect/TensorRT/IR/TensorRTDialect.h"
#include "mlir-tensorrt-dialect/TensorRT/Transforms/Passes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
@@ -55,40 +53,58 @@ static FailureOr<BroadcastOp> exchangeCollapseRankAndBroadcast(
collapseOp.getInputShapeDimIndicesOfRemovedDims();
if (removedDims.empty())
return failure();
llvm::sort(removedDims, [](int64_t lhs, int64_t rhs) { return lhs > rhs; });

// Let's just focus on the first removed dimension.
int64_t focusDim = removedDims.front();
SmallVector<int64_t> bcastInputShape(bcastOp.getInput().getType().getShape());
SmallVector<int64_t> bcastResultShape(bcastOp.getType().getShape());
SmallVector<int64_t> broadcastDims(bcastOp.getBroadcastDims());

// If it is broadcasted, then we must remove it at the input. Drop this dim
// from the list, and all indices higher than this must be decremented.
int64_t *bcastDimIter = llvm::find(broadcastDims, focusDim);
// TODO: can we handle this case?
if (bcastDimIter == broadcastDims.end())
return failure();
auto getBroadcastDimsIndex = [&](int64_t dim) -> std::optional<unsigned> {
auto it = llvm::find(broadcastDims, dim);
if (it != broadcastDims.end())
return std::distance(broadcastDims.begin(), it);

return {};
};

bool changed = false;
for (int64_t removedDim : removedDims) {
std::optional<unsigned> inputShapeDimIdx =
getBroadcastDimsIndex(removedDim);
if (!inputShapeDimIdx)
continue;

// Find which input dimension this corresponds to.
unsigned inputShapeDimIdx =
std::distance(broadcastDims.begin(), bcastDimIter);
assert(bcastInputShape[inputShapeDimIdx] == 1);

// Erase this broadcast dimension.
bcastInputShape.erase(bcastInputShape.begin() + inputShapeDimIdx);
broadcastDims.erase(bcastDimIter);
// Adjust all the other broadcast dimensions.
for (auto &bcastDim : broadcastDims) {
if (bcastDim > focusDim)
bcastDim--;
assert((bcastInputShape[*inputShapeDimIdx] == 1 ||
ShapedType::isDynamic(bcastInputShape[*inputShapeDimIdx])) &&
"expected size-1 dimension");

// Erase this broadcast dimension.
changed = true;
bcastInputShape.erase(bcastInputShape.begin() + *inputShapeDimIdx);
broadcastDims.erase(broadcastDims.begin() + *inputShapeDimIdx);
bcastResultShape.erase(bcastResultShape.begin() + removedDim);
// Adjust all the other broadcast dimensions.
for (int64_t &bcastDim : broadcastDims) {
if (bcastDim > removedDim)
bcastDim--;
}
}

Type newCollapseShapeType =
RankedTensorType::Builder(
cast<RankedTensorType>(bcastOp.getInput().getType()))
.setShape(bcastInputShape);
if (!changed)
return failure();

RankedTensorType newCollapseShapeType =
bcastOp.getInput().getType().clone(bcastInputShape);

Value newBcastInput;
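// If the new input shape can be formed by collapsing contiguous dimension
// groups of the original input, keep using a rank-collapse; otherwise fall
// back to a generic reshape.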
if (getReassociationIndicesForCollapse(
bcastOp.getInput().getType().getShape(), bcastInputShape))
newBcastInput = rewriter.create<CollapseRankOp>(
bcastOp.getLoc(), newCollapseShapeType, bcastOp.getInput());
else
newBcastInput = rewriter.create<ReshapeOp>(
bcastOp.getLoc(), newCollapseShapeType, bcastOp.getInput());

Value newBcastInput = rewriter.create<CollapseRankOp>(
bcastOp.getLoc(), newCollapseShapeType, bcastOp.getInput());
auto newBroadcastOp = rewriter.create<BroadcastOp>(
collapseOp.getLoc(), collapseOp.getType(), newBcastInput, broadcastDims);

