From 2e04f52d11af346e919cbb7db1c7ed84cb86222e Mon Sep 17 00:00:00 2001
From: Christopher Bate
Date: Mon, 13 Jan 2025 11:32:06 -0700
Subject: [PATCH] Move internal changes (#475)

## [tensorrt] Fix various edge cases in 'tensorrt-broadcast-elimination'

This change fixes several edge cases in 'tensorrt-broadcast-elimination' for a
particular pattern that commutes 'tensorrt.broadcast' and
'tensorrt.collapse_rank'. Previously, the pattern did not correctly handle
multiple collapsed unit dimensions and triggered a spurious assertion failure
when the collapsed input was dynamic. Several additional regression tests are
added.

## [tensorrt] Fix nvinfer1::Dims max rank assertion

Fixes an assertion that limited MLIR-to-TensorRT translation to programs with
tensors of rank < 8; the correct bound is <= 8, which aligns with TensorRT op
verification and the actual nvinfer API. Adds an additional translation test.

## NFC: apply isort to Python files

## Fix options configuration to enable global flags in "compiler task" API

Co-authored-by: Copybara Bot
---
 mlir-tensorrt/CMakeLists.txt | 8 ++
 .../mlir-tensorrt/Compiler/OptionsProviders.h | 7 +-
 .../StablehloToExecutable/TensorRTExtension.h | 6 +-
 .../lib/Compiler/OptionsProviders.cpp | 40 ++++------
 .../TensorRTExtension.cpp | 10 ++-
 .../IntegrationTests/Torch/test_torch_add.py | 5 +-
 mlir-tensorrt/python/CompilerPackage.cmake | 3 -
 .../mlir_tensorrt/compiler/torch_bridge.py | 10 ++-
 .../NetworkEncoder.h | 2 +-
 .../Transforms/BroadcastElimination.cpp | 78 +++++++++++--------
 .../TensorRT/broadcast-elimination.mlir | 76 +++++++++++++++++-
 .../test/Target/TensorRT/max-tensor-rank.mlir | 14 ++++
 12 files changed, 181 insertions(+), 78 deletions(-)
 create mode 100644 mlir-tensorrt/tensorrt/test/Target/TensorRT/max-tensor-rank.mlir

diff --git a/mlir-tensorrt/CMakeLists.txt b/mlir-tensorrt/CMakeLists.txt
index 59bbc6102..02c709de3 100644
--- a/mlir-tensorrt/CMakeLists.txt
+++ b/mlir-tensorrt/CMakeLists.txt
@@ -253,6 +253,14 @@ include(AddMLIRPython)
 include(MLIRDetectPythonEnv)
 mlir_configure_python_dev_packages()
 
+# Declare the Python source targets ahead-of-time since the sources may be built
+# up across multiple sub-directories.
+if(MLIR_TRT_ENABLE_PYTHON)
+  declare_mlir_python_sources(MLIRTensorRTPythonCompiler)
+  declare_mlir_python_sources(MLIRTensorRTPythonCompiler.Dialects
+    ADD_TO_PARENT MLIRTensorRTPythonCompiler)
+endif()
+
 # Add a meta target for all documentation generation targets. You can generate
 # all documentation under `${buildDir}/docs` by building this target.
 add_custom_target("mlir-tensorrt-doc")
diff --git a/mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/OptionsProviders.h b/mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/OptionsProviders.h
index fabfb6d9e..846083f6b 100644
--- a/mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/OptionsProviders.h
+++ b/mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/OptionsProviders.h
@@ -164,6 +164,11 @@ struct DebugOptions : public OptionsProvider {
       &this->ctx, "debug-only", llvm::cl::ZeroOrMore,
       llvm::cl::CommaSeparated, OmitFromCLI{}};
 
+  /// If set to `true`, we populate the pass manager instrumentation using
+  /// global MLIR CL options rather than the local options contained here.
+  Option<bool> useGlobalCLPrintingOptions{&this->ctx, "use-global-cl-options",
+                                          llvm::cl::init(false), OmitFromCLI{}};
+
   /// Apply these options to the current pass manager.
void applyToPassManager(mlir::PassManager &pm) const; }; @@ -195,8 +200,6 @@ struct DeviceOptions : public OptionsProvider { llvm::cl::desc("whether to ignore `deviceX` options and instead infer " "them from the host GPU")}; - Status inferFromHost(); - public: DeviceOptions(mlir::OptionsContext &ctx) : OptionsProvider(ctx) { ctx.addOption( diff --git a/mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/StablehloToExecutable/TensorRTExtension.h b/mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/StablehloToExecutable/TensorRTExtension.h index 58be6a88a..9569a694d 100644 --- a/mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/StablehloToExecutable/TensorRTExtension.h +++ b/mlir-tensorrt/compiler/include/mlir-tensorrt/Compiler/StablehloToExecutable/TensorRTExtension.h @@ -51,6 +51,8 @@ class StablehloToExecutableTensorRTExtension void addToOptions(mlir::OptionsContext &context) final { context.addOption("disable-tensorrt-extension", disabled, llvm::cl::init(false)); + context.addOption("use-global-tensorrt-translation-flags", useGlobalCLFlags, + llvm::cl::init(false)); translationOptions.addToOptions(context); } @@ -63,8 +65,8 @@ class StablehloToExecutableTensorRTExtension /// Options for MLIR-to-TensorRT translation. mlir::tensorrt::TensorRTTranslationOptions translationOptions; - /// Path where we should persist the timing cache to storage. - std::string timingCachePath; + /// Whether to use global CL config for options. + bool useGlobalCLFlags{false}; }; } // namespace mlirtrt::compiler diff --git a/mlir-tensorrt/compiler/lib/Compiler/OptionsProviders.cpp b/mlir-tensorrt/compiler/lib/Compiler/OptionsProviders.cpp index 3d1ec70c9..2ce57468d 100644 --- a/mlir-tensorrt/compiler/lib/Compiler/OptionsProviders.cpp +++ b/mlir-tensorrt/compiler/lib/Compiler/OptionsProviders.cpp @@ -27,6 +27,7 @@ #include "mlir/Pass/PassManager.h" #include "mlir/Support/Timing.h" #include "llvm/Support/Error.h" +#include "llvm/Support/ErrorHandling.h" using namespace mlir; using namespace mlirtrt; @@ -37,6 +38,18 @@ using namespace mlirtrt::compiler; //===----------------------------------------------------------------------===// void DebugOptions::applyToPassManager(PassManager &pm) const { + // If the options specify to use global MLIR CL flags, then apply those + // options. Otherwise, use our local options. Using global options is only + // possible if the LLVM global command line flag environment is initialized + // correctly. 
+  if (useGlobalCLPrintingOptions) {
+    if (failed(applyPassManagerCLOptions(pm)))
+      llvm::report_fatal_error("failed to populate pass manager "
+                               "instrumentation from global CL options");
+    applyDefaultTimingPassManagerCLOptions(pm);
+    return;
+  }
+
   std::function<bool(Pass *, Operation *)> shouldPrintBeforePass;
   std::function<bool(Pass *, Operation *)> shouldPrintAfterPass;
 
@@ -78,7 +91,6 @@ void DebugOptions::applyToPassManager(PassManager &pm) const {
                         printAfterFailure, printTreeDir);
     return;
   }
-
   pm.enableIRPrinting(shouldPrintBeforePass, shouldPrintAfterPass,
                       printModuleScope, printAfterChange, printAfterFailure,
                       llvm::errs());
@@ -88,32 +100,6 @@ void DebugOptions::applyToPassManager(PassManager &pm) const {
 // DeviceOptions
 //===----------------------------------------------------------------------===//
 
-Status DeviceOptions::inferFromHost() {
-  cudaDeviceProp properties;
-  cudaError_t err = cudaGetDeviceProperties(&properties, 0);
-  if (err != cudaSuccess)
-    return getStatusWithMsg(StatusCode::InternalError,
-                            "failed to get cuda device properties");
-  int ccMajor = 0;
-  int ccMinor = 0;
-  err = cudaDeviceGetAttribute(
-      &ccMajor, cudaDeviceAttr::cudaDevAttrComputeCapabilityMajor, 0);
-  if (err != cudaSuccess)
-    return getStatusWithMsg(StatusCode::InternalError,
-                            "failed to get cuda device compute capability");
-  err = cudaDeviceGetAttribute(
-      &ccMinor, cudaDeviceAttr::cudaDevAttrComputeCapabilityMinor, 0);
-  if (err != cudaSuccess)
-    return getStatusWithMsg(StatusCode::InternalError,
-                            "failed to get cuda device compute capability");
-  // We want SM version as a single number.
-  int64_t smVersion = ccMajor * 10 + ccMinor;
-  info.computeCapability = smVersion;
-  info.maxSharedMemoryPerBlockKb = properties.sharedMemPerBlock / 1024;
-  info.maxRegistersPerBlock = properties.regsPerBlock;
-  return Status::getOk();
-}
-
 llvm::Error DeviceOptions::finalizeImpl() {
   if (shouldInferFromHost) {
     StatusOr deviceInfo = getDeviceInformationFromHost();
diff --git a/mlir-tensorrt/compiler/lib/Compiler/StablehloToExecutable/TensorRTExtension.cpp b/mlir-tensorrt/compiler/lib/Compiler/StablehloToExecutable/TensorRTExtension.cpp
index 6b8b2dd5d..c9418324d 100644
--- a/mlir-tensorrt/compiler/lib/Compiler/StablehloToExecutable/TensorRTExtension.cpp
+++ b/mlir-tensorrt/compiler/lib/Compiler/StablehloToExecutable/TensorRTExtension.cpp
@@ -44,8 +44,12 @@ void StablehloToExecutableTensorRTExtension::populatePasses(
   if (this->disabled)
     return;
 
+  tensorrt::TensorRTTranslationOptions translationOpts =
+      useGlobalCLFlags ? tensorrt::TensorRTTranslationOptions::fromCLFlags()
+                       : translationOptions;
+
   if (phase == Phase::PreClustering) {
-    // We must materialize TRT plugin shape regions prior to clustering.
+    // We must materialize TRT plugin shape regions prior to clustering.
     pm.addNestedPass(tensorrt::createInferPluginShapesPass());
     return;
   }
@@ -61,9 +65,9 @@ void StablehloToExecutableTensorRTExtension::populatePasses(
 
   // Simplify and translate functions nested in `tensorrt.module` ops.
auto &trtPM = pm.nest(); tensorrt::buildTensorRTModuleTransformationPipeline( - trtPM, translationOptions.enableStronglyTyped); + trtPM, translationOpts.enableStronglyTyped); trtPM.addPass( - tensorrt::createTranslateTensorRTPass(nullptr, translationOptions)); + tensorrt::createTranslateTensorRTPass(nullptr, translationOpts)); return; } diff --git a/mlir-tensorrt/compiler/test/python/IntegrationTests/Torch/test_torch_add.py b/mlir-tensorrt/compiler/test/python/IntegrationTests/Torch/test_torch_add.py index 62537dccc..c86565a12 100644 --- a/mlir-tensorrt/compiler/test/python/IntegrationTests/Torch/test_torch_add.py +++ b/mlir-tensorrt/compiler/test/python/IntegrationTests/Torch/test_torch_add.py @@ -2,12 +2,11 @@ import mlir_tensorrt.compiler.api as compiler import mlir_tensorrt.compiler.ir as ir -import mlir_tensorrt.runtime.api as runtime import mlir_tensorrt.compiler.torch_bridge as torch_bridge - +import mlir_tensorrt.runtime.api as runtime import numpy as np -import torch.nn as nn import torch +import torch.nn as nn class Model(nn.Module): diff --git a/mlir-tensorrt/python/CompilerPackage.cmake b/mlir-tensorrt/python/CompilerPackage.cmake index cb5133ab8..e8417f6a3 100644 --- a/mlir-tensorrt/python/CompilerPackage.cmake +++ b/mlir-tensorrt/python/CompilerPackage.cmake @@ -23,13 +23,10 @@ configure_file( # Structural groupings. ################################################################################ -declare_mlir_python_sources(MLIRTensorRTPythonCompiler) declare_mlir_python_sources(MLIRTensorRTPythonCompiler.Core ADD_TO_PARENT MLIRTensorRTPythonCompiler) declare_mlir_python_sources(MLIRTensorRTPythonCompiler.CompilerAPI ADD_TO_PARENT MLIRTensorRTPythonCompiler) -declare_mlir_python_sources(MLIRTensorRTPythonCompiler.Dialects - ADD_TO_PARENT MLIRTensorRTPythonCompiler) ################################################################################ # Pure python sources and generated code diff --git a/mlir-tensorrt/python/mlir_tensorrt_compiler/mlir_tensorrt/compiler/torch_bridge.py b/mlir-tensorrt/python/mlir_tensorrt_compiler/mlir_tensorrt/compiler/torch_bridge.py index 27f8ff9ff..ba4e69dc3 100644 --- a/mlir-tensorrt/python/mlir_tensorrt_compiler/mlir_tensorrt/compiler/torch_bridge.py +++ b/mlir-tensorrt/python/mlir_tensorrt_compiler/mlir_tensorrt/compiler/torch_bridge.py @@ -1,11 +1,13 @@ -from typing import Optional, Union, Dict, Tuple, Any +from dataclasses import dataclass, field +from typing import Any, Dict, Optional, Tuple, Union + from torch import nn from torch.export import ExportedProgram -from dataclasses import dataclass, field -from .compiler_utils import OutputType + from . 
import fx -from .extras.fx_importer import FxImporter +from .compiler_utils import OutputType from .dialects import torch as torch_d +from .extras.fx_importer import FxImporter __all__ = ["TorchInput", "get_mlir_module_from_torch_module"] diff --git a/mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/Target/TensorRTEncodingOpInterface/NetworkEncoder.h b/mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/Target/TensorRTEncodingOpInterface/NetworkEncoder.h index 6d1391310..08672510c 100644 --- a/mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/Target/TensorRTEncodingOpInterface/NetworkEncoder.h +++ b/mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/Target/TensorRTEncodingOpInterface/NetworkEncoder.h @@ -267,7 +267,7 @@ template || std::is_same_v, T *> = nullptr> static nvinfer1::Dims getNvInferDims(ArrayRef arrayRef) { - assert(arrayRef.size() < nvinfer1::Dims::MAX_DIMS && + assert(arrayRef.size() <= nvinfer1::Dims::MAX_DIMS && "input array exceeds max dims"); nvinfer1::Dims dims; dims.nbDims = arrayRef.size(); diff --git a/mlir-tensorrt/tensorrt/lib/TensorRT/Transforms/BroadcastElimination.cpp b/mlir-tensorrt/tensorrt/lib/TensorRT/Transforms/BroadcastElimination.cpp index de5f1f099..bc80b8eef 100644 --- a/mlir-tensorrt/tensorrt/lib/TensorRT/Transforms/BroadcastElimination.cpp +++ b/mlir-tensorrt/tensorrt/lib/TensorRT/Transforms/BroadcastElimination.cpp @@ -1,8 +1,7 @@ //===- BroadcastElimination.cpp----------------------------------*- c++ -*-===// // -// SPDX-FileCopyrightText: Copyright 2024 NVIDIA CORPORATION & AFFILIATES. -// All rights reserved. -// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright 2024 - 2025 NVIDIA CORPORATION & +// AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -19,8 +18,7 @@ //===----------------------------------------------------------------------===// #include "mlir-tensorrt-dialect/TensorRT/IR/TensorRTDialect.h" #include "mlir-tensorrt-dialect/TensorRT/Transforms/Passes.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/PatternMatch.h" +#include "mlir/Dialect/Utils/ReshapeOpsUtils.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" namespace mlir { @@ -55,40 +53,58 @@ static FailureOr exchangeCollapseRankAndBroadcast( collapseOp.getInputShapeDimIndicesOfRemovedDims(); if (removedDims.empty()) return failure(); + llvm::sort(removedDims, [](int64_t lhs, int64_t rhs) { return lhs > rhs; }); - // Let's just focus on the first removed dimension. - int64_t focusDim = removedDims.front(); SmallVector bcastInputShape(bcastOp.getInput().getType().getShape()); + SmallVector bcastResultShape(bcastOp.getType().getShape()); SmallVector broadcastDims(bcastOp.getBroadcastDims()); - // If it is broadcasted, then we must remove it at the input. Drop this dim - // from the list, and all indices higher than this must be decremented. - int64_t *bcastDimIter = llvm::find(broadcastDims, focusDim); - // TODO: can we handle this case? 
- if (bcastDimIter == broadcastDims.end()) - return failure(); + auto getBroadcastDimsIndex = [&](int64_t dim) -> std::optional { + auto it = llvm::find(broadcastDims, dim); + if (it != broadcastDims.end()) + return std::distance(broadcastDims.begin(), it); + + return {}; + }; + + bool changed = false; + for (int64_t removedDim : removedDims) { + std::optional inputShapeDimIdx = + getBroadcastDimsIndex(removedDim); + if (!inputShapeDimIdx) + continue; - // Find which input dimension this corresponds to. - unsigned inputShapeDimIdx = - std::distance(broadcastDims.begin(), bcastDimIter); - assert(bcastInputShape[inputShapeDimIdx] == 1); - - // Erase this broadcast dimension. - bcastInputShape.erase(bcastInputShape.begin() + inputShapeDimIdx); - broadcastDims.erase(bcastDimIter); - // Adjust all the other broadcast dimensions. - for (auto &bcastDim : broadcastDims) { - if (bcastDim > focusDim) - bcastDim--; + assert((bcastInputShape[*inputShapeDimIdx] == 1 || + ShapedType::isDynamic(bcastInputShape[*inputShapeDimIdx])) && + "expected size-1 dimension"); + + // Erase this broadcast dimension. + changed = true; + bcastInputShape.erase(bcastInputShape.begin() + *inputShapeDimIdx); + broadcastDims.erase(broadcastDims.begin() + *inputShapeDimIdx); + bcastResultShape.erase(bcastResultShape.begin() + removedDim); + // Adjust all the other broadcast dimensions. + for (int64_t &bcastDim : broadcastDims) { + if (bcastDim > removedDim) + bcastDim--; + } } - Type newCollapseShapeType = - RankedTensorType::Builder( - cast(bcastOp.getInput().getType())) - .setShape(bcastInputShape); + if (!changed) + return failure(); + + RankedTensorType newCollapseShapeType = + bcastOp.getInput().getType().clone(bcastInputShape); + + Value newBcastInput; + if (getReassociationIndicesForCollapse( + bcastOp.getInput().getType().getShape(), bcastInputShape)) + newBcastInput = rewriter.create( + bcastOp.getLoc(), newCollapseShapeType, bcastOp.getInput()); + else + newBcastInput = rewriter.create( + bcastOp.getLoc(), newCollapseShapeType, bcastOp.getInput()); - Value newBcastInput = rewriter.create( - bcastOp.getLoc(), newCollapseShapeType, bcastOp.getInput()); auto newBroadcastOp = rewriter.create( collapseOp.getLoc(), collapseOp.getType(), newBcastInput, broadcastDims); diff --git a/mlir-tensorrt/tensorrt/test/Dialect/TensorRT/broadcast-elimination.mlir b/mlir-tensorrt/tensorrt/test/Dialect/TensorRT/broadcast-elimination.mlir index bba942f64..ac3b1d6b6 100644 --- a/mlir-tensorrt/tensorrt/test/Dialect/TensorRT/broadcast-elimination.mlir +++ b/mlir-tensorrt/tensorrt/test/Dialect/TensorRT/broadcast-elimination.mlir @@ -13,6 +13,79 @@ func.func @pushdown_broadcast(%arg0: tensor<1x1x10xf32>, %arg1: tensor<100x10xf3 // ----- +func.func @pushdown_broadcast_collapse_shape_multiple_collapsed_dims() -> tensor<96x512x10x10xf32> { + %cst_f32 = tensorrt.constant dense<1.000000e+00> : tensor<1x96x1x512x1x1xf32> + %0 = tensorrt.broadcast %cst_f32 broadcast_dims<0, 1, 2, 3, 4, 5> : tensor<1x96x1x512x1x1xf32> to tensor<1x96x1x512x10x10xf32> + %1 = tensorrt.collapse_rank %0 : tensor<1x96x1x512x10x10xf32> to tensor<96x512x10x10xf32> + return %1 : tensor<96x512x10x10xf32> +} + +// CHECK-LABEL: func.func @pushdown_broadcast_collapse_shape_multiple_collapsed_dims +// CHECK-NEXT: %[[cst_f32:.+]] = tensorrt.constant {{.*}} : tensor<96x512x1x1xf32> +// CHECK-NEXT: %[[v0:.+]] = tensorrt.broadcast %[[cst_f32]] broadcast_dims<0, 1, 2, 3> : tensor<96x512x1x1xf32> to tensor<96x512x10x10xf32> +// CHECK-NEXT: return %[[v0]] : tensor<96x512x10x10xf32> + 
+// ----- + +// For this test case, the dimension removed by the collapse_rank (dim #1) is not +// part of the broadcast dimensions. Check that the 'PushDownBroadcastReduceRankOp' correctly +// exits, leaving the other patterns to simplify the IR. +func.func @pushdown_transposed_broadcast_collapse_3() -> tensor<1x96x1x1xf32> { + %cst_f32 = tensorrt.constant dense<1.000000e+00> : tensor<1x96x1xf32> + %0 = tensorrt.broadcast %cst_f32 broadcast_dims<0, 2, 4> : tensor<1x96x1xf32> to tensor<1x1x96x1x1xf32> + %1 = tensorrt.collapse_rank %0 : tensor<1x1x96x1x1xf32> to tensor<1x96x1x1xf32> + return %1 : tensor<1x96x1x1xf32> +} + +// CHECK-LABEL: func.func @pushdown_transposed_broadcast_collapse_3 +// CHECK-NEXT: %[[cst_f32:.+]] = tensorrt.constant {{.*}} : tensor<1x96x1x1xf32> +// CHECK-NEXT: return %[[cst_f32]] : tensor<1x96x1x1xf32> + +// ----- + +func.func @pushdown_transposed_broadcast_collapse_4() -> tensor<96x4xf32> { + %cst_f32 = tensorrt.constant dense<1.000000e+00> : tensor<1x96x1xf32> + %0 = tensorrt.broadcast %cst_f32 broadcast_dims<2, 1, 0> : tensor<1x96x1xf32> to tensor<1x96x1x4xf32> + %1 = tensorrt.collapse_rank %0 : tensor<1x96x1x4xf32> to tensor<96x4xf32> + return %1 : tensor<96x4xf32> +} + +// CHECK-LABEL: func.func @pushdown_transposed_broadcast_collapse_4 +// CHECK: %[[cst_f32:.+]] = tensorrt.constant {{.*}} : tensor<96x1xf32> +// CHECK: %[[v0:.+]] = tensorrt.broadcast %[[cst_f32]] broadcast_dims<0, 1> : tensor<96x1xf32> to tensor<96x4xf32> +// CHECK: return %[[v0]] : tensor<96x4xf32> + +// ----- + +func.func @pushdown_transposed_broadcast_collapse_transpose() -> tensor<96x96xf32> { + %cst_f32 = tensorrt.constant dense<1.000000e+00> : tensor<1x96x1x96xf32> + %0 = tensorrt.broadcast %cst_f32 broadcast_dims<2, 3, 0, 1> : tensor<1x96x1x96xf32> to tensor<1x96x1x96xf32> + %1 = tensorrt.collapse_rank %0 : tensor<1x96x1x96xf32> to tensor<96x96xf32> + return %1 : tensor<96x96xf32> +} + +// CHECK: #[[$map:.+]] = affine_map<(d0, d1) -> (d1, d0)> +// CHECK-LABEL: func.func @pushdown_transposed_broadcast_collapse_transpose +// CHECK-NEXT: %[[cst_f32:.+]] = tensorrt.constant {{.*}} : tensor<96x96xf32> +// CHECK-NEXT: %[[v0:.+]] = tensorrt.transpose {permutation = #[[$map]]} %[[cst_f32]] : +// CHECK-NEXT: return %[[v0]] : tensor<96x96xf32> + +// ----- + +func.func @pushdown_transposed_broadcast_collapse_dynamic(%arg0: tensor<1x96x?x96xf32>) -> tensor<96x96xf32> { + %0 = tensorrt.broadcast %arg0 broadcast_dims<2, 3, 0, 1> : tensor<1x96x?x96xf32> to tensor<1x96x1x96xf32> + %1 = tensorrt.collapse_rank %0 : tensor<1x96x1x96xf32> to tensor<96x96xf32> + return %1 : tensor<96x96xf32> +} + +// CHECK-LABEL: func.func @pushdown_transposed_broadcast_collapse_dynamic +// CHECK-SAME: (%[[arg0:.+]]: tensor<1x96x?x96xf32>) +// CHECK-NEXT: %[[v0:.+]] = tensorrt.reshape %[[arg0]] : tensor<1x96x?x96xf32> to tensor<96x96xf32> +// CHECK-NEXT: %[[v1:.+]] = tensorrt.transpose {permutation = #[[$map]]} %[[v0]] : +// CHECK-NEXT: return %[[v1]] : tensor<96x96xf32> + +// ----- + func.func @broadcast_ewise(%arg0: tensor<128x128xf32>, %arg1: tensor<1x128xf32>) -> tensor<128x128xf32> { %0 = tensorrt.broadcast %arg1 broadcast_dims<0, 1> : tensor<1x128xf32> to tensor<128x128xf32> %1 = tensorrt.element_wise (%arg0, %0 : tensor<128x128xf32>, tensor<128x128xf32>) -> tensor<128x128xf32> @@ -263,7 +336,6 @@ func.func @broadcast_dynamic_expand_shape_regression(%arg0: tensor, } // CHECK: #[[$map:.+]] = affine_map<(d0, d1, d2) -> (d2, d1, d0)> -// CHECK: module { // CHECK-LABEL: func.func @broadcast_dynamic_expand_shape_regression 
 // CHECK-SAME: (%[[arg0:.+]]: tensor, %[[arg1:.+]]: tensor, %[[arg2:.+]]: tensor, %[[arg3:.+]]: tensor<4xi32>) -> tensor {
 // CHECK: %[[cst_i32:.+]] = tensorrt.constant dense<1> : tensor<1xi32>
@@ -274,4 +346,4 @@ func.func @broadcast_dynamic_expand_shape_regression(%arg0: tensor,
 // CHECK: %[[v4:.+]] = tensorrt.concatenation {axis = 0 : i32} ins(%[[cst_i32]], %[[v2]], %[[cst_i32]], %[[v3]] : tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32>
 // CHECK: %[[v5:.+]] = tensorrt.reshape %[[v0]] shape(%[[v4]]: tensor<4xi32>) : tensor to tensor<1x?x1x?xf16>
 // CHECK: %[[v6:.+]] = tensorrt.select ins(%[[arg0]], %[[arg2]], %[[v5]] : tensor, tensor, tensor<1x?x1x?xf16>) -> tensor
-// CHECK: return %[[v6]] : tensor
\ No newline at end of file
+// CHECK: return %[[v6]] : tensor
diff --git a/mlir-tensorrt/tensorrt/test/Target/TensorRT/max-tensor-rank.mlir b/mlir-tensorrt/tensorrt/test/Target/TensorRT/max-tensor-rank.mlir
new file mode 100644
index 000000000..ab72d2f2f
--- /dev/null
+++ b/mlir-tensorrt/tensorrt/test/Target/TensorRT/max-tensor-rank.mlir
@@ -0,0 +1,14 @@
+// RUN: %pick-one-gpu tensorrt-opt -split-input-file -pass-pipeline="builtin.module(translate-tensorrt-to-engine)" \
+// RUN:   -mlir-elide-elementsattrs-if-larger=32 -tensorrt-builder-opt-level=0 %s | FileCheck %s
+
+!tensor_type = tensor<8x8x8x8x8x8x8x8xf32>
+
+// Check that we can translate a network whose tensors have the maximum rank allowed by TensorRT (rank 8).
+func.func @trt_max_rank(%arg1: !tensor_type, %arg2: !tensor_type) -> (!tensor_type) {
+  %1 = tensorrt.element_wise (%arg1, %arg2 : !tensor_type, !tensor_type)
+    -> !tensor_type
+  return %1 : !tensor_type
+}
+
+// CHECK-LABEL: @trt_max_rank
+// CHECK-SAME: tensorrt.engine
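
For reference, the 'tensorrt-broadcast-elimination' fix can be illustrated with the first regression test added to broadcast-elimination.mlir above. The sketch below mirrors that test's input IR and CHECK lines (after constant folding of the collapsed input); it is not the pattern implementation itself:

```mlir
// Before: a broadcast feeds a collapse_rank that removes multiple unit dims
// (dims 0 and 2 of the broadcast result).
%cst_f32 = tensorrt.constant dense<1.000000e+00> : tensor<1x96x1x512x1x1xf32>
%0 = tensorrt.broadcast %cst_f32 broadcast_dims<0, 1, 2, 3, 4, 5>
       : tensor<1x96x1x512x1x1xf32> to tensor<1x96x1x512x10x10xf32>
%1 = tensorrt.collapse_rank %0 : tensor<1x96x1x512x10x10xf32> to tensor<96x512x10x10xf32>

// After the pass: the broadcast is pushed below the collapse, so only the
// surviving size-1 dims are broadcast.
%cst_f32 = tensorrt.constant dense<1.000000e+00> : tensor<96x512x1x1xf32>
%0 = tensorrt.broadcast %cst_f32 broadcast_dims<0, 1, 2, 3>
       : tensor<96x512x1x1xf32> to tensor<96x512x10x10xf32>
```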