Skip to content

Commit

Permalink
Migrate internal changes (#12)
Browse files Browse the repository at this point in the history
- Fix most non-functional style casts in order to conform with upstream
  deprecations
- Fix minor build/documentation issues
- Add psutil and build packages

Co-authored-by: Sagar Shelke

Signed-off-by: Christopher Bate <[email protected]>
  • Loading branch information
christopherbate authored Aug 2, 2024
1 parent 31115b4 commit d532354
Show file tree
Hide file tree
Showing 48 changed files with 356 additions and 337 deletions.
2 changes: 1 addition & 1 deletion mlir-tensorrt/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ git apply ../build_tools/llvm-project.patch

# Do the build
cd ..
./build_tools/scripts/build_mlir.sh llvm-project build/llvm
./build_tools/scripts/build_mlir.sh llvm-project build/llvm-project
```

2. Build the project and run all tests
Expand Down
32 changes: 32 additions & 0 deletions mlir-tensorrt/build_tools/scripts/cicd_build.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
#!/usr/bin/env bash
# CI/CD build script: configures the MLIR-TensorRT project with CMake and
# runs the three lit-based check targets. All knobs are environment
# variables with sensible defaults so CI jobs can override them.
set -ex
set -o pipefail

REPO_ROOT=$(pwd)
BUILD_DIR="${BUILD_DIR:=${REPO_ROOT}/build/mlir-tensorrt}"

# NOTE: the original used `${ENABLE_NCCL:OFF}` (missing dash), which is a
# substring expansion at offset 0 — it yields "" when ENABLE_NCCL is unset
# instead of the intended default "OFF". `:-` applies the default correctly.
ENABLE_NCCL=${ENABLE_NCCL:-OFF}
RUN_LONG_TESTS=${RUN_LONG_TESTS:-False}
LLVM_LIT_ARGS=${LLVM_LIT_ARGS:-"-v --xunit-xml-output ${BUILD_DIR}/test-results.xml --timeout=1200 --time-tests -Drun_long_tests=${RUN_LONG_TESTS}"}
DOWNLOAD_TENSORRT_VERSION=${DOWNLOAD_TENSORRT_VERSION:-10.0.0.6}
ENABLE_ASAN=${ENABLE_ASAN:-OFF}

echo "Using DOWNLOAD_TENSORRT_VERSION=${DOWNLOAD_TENSORRT_VERSION}"
echo "Using LLVM_LIT_ARGS=${LLVM_LIT_ARGS}"

# Configure. Quote every path-valued expansion so builds under directories
# containing spaces do not break word-splitting.
cmake -GNinja -B "${BUILD_DIR}" -S "${REPO_ROOT}" \
  -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ \
  -DMLIR_TRT_USE_LINKER=lld -DCMAKE_BUILD_TYPE=RelWithDebInfo \
  -DMLIR_TRT_PACKAGE_CACHE_DIR="${PWD}/.cache.cpm" \
  -DMLIR_TRT_ENABLE_PYTHON=ON \
  -DMLIR_TRT_ENABLE_NCCL="${ENABLE_NCCL}" \
  -DMLIR_TRT_DOWNLOAD_TENSORRT_VERSION="${DOWNLOAD_TENSORRT_VERSION}" \
  -DLLVM_LIT_ARGS="${LLVM_LIT_ARGS}" \
  -DENABLE_ASAN="${ENABLE_ASAN}" \
  -DMLIR_DIR="${REPO_ROOT}/build/llvm-project/lib/cmake/mlir" \
  -DCMAKE_PLATFORM_NO_VERSIONED_SONAME=ON

echo "==== Running Build ===="
# -k 0: keep going after failures so one broken target does not hide others.
ninja -C "${BUILD_DIR}" -k 0 check-mlir-executor
ninja -C "${BUILD_DIR}" -k 0 check-mlir-tensorrt
ninja -C "${BUILD_DIR}" -k 0 check-mlir-tensorrt-dialect
Original file line number Diff line number Diff line change
Expand Up @@ -124,10 +124,10 @@ struct CudaBlasRunGemmOpConverter
SmallVector<Value> newOperands = {adaptor.getHandle(), adaptor.getStream()};
newOperands.push_back(adaptor.getAlgo());
auto createMemRefAndExractPtr = [&](Value oldVal, Value newVal) {
auto memrefType = oldVal.getType().cast<MemRefType>();
auto memrefType = cast<MemRefType>(oldVal.getType());
if (!memrefType)
return failure();
assert(newVal.getType().isa<TableType>());
assert(isa<TableType>(newVal.getType()));
executor::MemRefDescriptor memref(newVal, memrefType);
newOperands.push_back(memref.alignedPtr(b));
return success();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@ add_mlir_tensorrt_library(MLIRTensorRTPlanToExecutor
MLIRTensorRTExecutorDialect
MLIRTensorRTPlanDialect
MLIRTransforms
MLIRSCFTransforms
)
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,8 @@ struct ConstantOpConverter : public OpConversionPattern<arith::ConstantOp> {
LogicalResult
matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto resultType = getTypeConverter()
->convertType(op.getType())
.dyn_cast_or_null<RankedTensorType>();
auto resultType = dyn_cast_or_null<RankedTensorType>(
getTypeConverter()->convertType(op.getType()));
if (!resultType)
return failure();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ struct StablehloRewriteConcat
matchAndRewrite(stablehlo::ConcatenateOp op, OpAdaptor adaptor,
OneToNPatternRewriter &rewriter) const override {
if (!llvm::all_of(op->getOperandTypes(), [](Type t) {
return t.cast<RankedTensorType>().getRank() == 1;
return cast<RankedTensorType>(t).getRank() == 1;
}))
return failure();
rewriter.replaceOp(op, adaptor.getFlatOperands(),
Expand All @@ -133,13 +133,13 @@ struct StablehloRewriteConcat
/// scalar `type`.
static Attribute getScalarValue(RewriterBase &rewriter, Type type,
int64_t idx) {
if (type.isa<FloatType>())
if (isa<FloatType>(type))
return rewriter.getFloatAttr(type, static_cast<double>(idx));
if (type.isa<IndexType>())
if (isa<IndexType>(type))
return rewriter.getIndexAttr(idx);
if (auto integerType = type.dyn_cast<IntegerType>())
if (auto integerType = dyn_cast<IntegerType>(type))
return rewriter.getIntegerAttr(
type, APInt(type.cast<IntegerType>().getWidth(), idx));
type, APInt(cast<IntegerType>(type).getWidth(), idx));
return {};
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ static void inlineStablehloRegionIntoSCFRegion(PatternRewriter &rewriter,
static Value extractScalarFromTensorValue(OpBuilder &b, Value tensor) {
Location loc = tensor.getLoc();
// If ranked tensor, first collapse shape.
if (tensor.getType().cast<RankedTensorType>().getRank() != 0)
if (cast<RankedTensorType>(tensor.getType()).getRank() != 0)
tensor = b.create<tensor::CollapseShapeOp>(
loc, tensor, SmallVector<ReassociationIndices>());

Expand Down Expand Up @@ -129,10 +129,10 @@ static scf::IfOp createNestedCases(int currentIdx, stablehlo::CaseOp op,

// Determine if the current index matches the case index.
auto scalarType = idxValue.getType();
auto shapedType = scalarType.cast<ShapedType>();
auto shapedType = cast<ShapedType>(scalarType);
auto constAttr = DenseElementsAttr::get(
shapedType,
{outerBuilder.getI32IntegerAttr(currentIdx).cast<mlir::Attribute>()});
{cast<mlir::Attribute>(outerBuilder.getI32IntegerAttr(currentIdx))});
Value currentIdxVal = outerBuilder.create<stablehlo::ConstantOp>(
loc, idxValue.getType(), constAttr);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ struct ConvertChloErfToTensorRT
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
auto operand = adaptor.getOperand();
auto operandType = operand.getType().cast<RankedTensorType>();
auto operandType = cast<RankedTensorType>(operand.getType());
Type resultType = typeConverter->convertType(op.getType());
if (!resultType)
return failure();
Expand Down Expand Up @@ -74,7 +74,7 @@ struct ConvertChloTopKOpToTensorRT
matchAndRewrite(chlo::TopKOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto operand = adaptor.getOperand();
RankedTensorType operandType = operand.getType().cast<RankedTensorType>();
RankedTensorType operandType = cast<RankedTensorType>(operand.getType());

int64_t rank = operandType.getRank();
uint64_t axis = static_cast<uint64_t>(rank) - 1;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ struct ConvertCaseOp : public ConvertHloOpToTensorRTPattern<stablehlo::CaseOp> {
if (!isa_and_nonnull<tensorrt::IdentityOp, stablehlo::ConvertOp>(op))
return false;
RankedTensorType producerType =
op->getOperand(0).getType().cast<RankedTensorType>();
cast<RankedTensorType>(op->getOperand(0).getType());
return isa_and_nonnull<tensorrt::IdentityOp, stablehlo::ConvertOp>(op) &&
producerType.getElementType().isInteger(1) &&
producerType.getNumElements() == 1;
Expand Down
Loading

0 comments on commit d532354

Please sign in to comment.