From b129f799aa75227963440c6604afe5a6ece6bd17 Mon Sep 17 00:00:00 2001 From: Abhinav Gunjal Date: Wed, 29 Jan 2025 17:23:13 -0800 Subject: [PATCH] Integrate StableHLO at openxla/stablehlo@48a1e14e PiperOrigin-RevId: 721167917 --- third_party/stablehlo/temporary.patch | 4814 ------------------------- third_party/stablehlo/workspace.bzl | 4 +- 2 files changed, 2 insertions(+), 4816 deletions(-) diff --git a/third_party/stablehlo/temporary.patch b/third_party/stablehlo/temporary.patch index 90d4d84f5e8a7..340477167c84b 100755 --- a/third_party/stablehlo/temporary.patch +++ b/third_party/stablehlo/temporary.patch @@ -1,15 +1,3 @@ -diff --ruN a/stablehlo/BUILD.bazel b/stablehlo/BUILD.bazel ---- stablehlo/BUILD.bazel -+++ stablehlo/BUILD.bazel -@@ -1547,7 +1547,7 @@ - ), - ], - tblgen = "@llvm-project//mlir:mlir-tblgen", -- td_file = "stablehlo/dialect/VhloAttrs.td", -+ td_file = "stablehlo/dialect/VhloEnums.td", - deps = [ - ":vhlo_ops_td_files", - ], diff --ruN a/stablehlo/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.pdll b/stablehlo/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.pdll --- stablehlo/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.pdll +++ stablehlo/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.pdll @@ -79,4806 +67,4 @@ diff --ruN a/stablehlo/stablehlo/conversions/tosa/transforms/StablehloLegalizeTo Pattern => replace op(input0 : Value<_: Tosa_Tensor>, input1 : Value<_: Tosa_Tensor>) -diff --ruN a/stablehlo/stablehlo/dialect/AssemblyFormat.cpp b/stablehlo/stablehlo/dialect/AssemblyFormat.cpp ---- stablehlo/stablehlo/dialect/AssemblyFormat.cpp -+++ stablehlo/stablehlo/dialect/AssemblyFormat.cpp -@@ -860,6 +860,29 @@ - return parser.parseSymbolName(target); - } - -+void printResultAccuracyAttr(AsmPrinter& odsPrinter, APFloat atol, APFloat rtol, -+ int64_t ulps, Attribute mode) { -+ odsPrinter << "<"; -+ if (!atol.isZero()) { -+ odsPrinter << "atol = "; -+ odsPrinter.printFloat(atol); -+ odsPrinter << ", "; -+ } -+ if (!rtol.isZero()) { -+ odsPrinter << "rtol = "; -+ odsPrinter.printFloat(rtol); -+ odsPrinter << ", "; -+ } -+ if (ulps != 0) { -+ odsPrinter << "ulps = "; -+ odsPrinter << ulps; -+ odsPrinter << ", "; -+ } -+ odsPrinter << "mode = "; -+ odsPrinter.printAttribute(mode); -+ odsPrinter << ">"; -+} -+ - void printTypeExtensions(BoundedAttrInterface attr, DialectAsmPrinter& os) { - os << "bounds<"; - llvm::interleaveComma(attr.getBounds(), os, -diff --ruN a/stablehlo/stablehlo/dialect/AssemblyFormat.h b/stablehlo/stablehlo/dialect/AssemblyFormat.h ---- stablehlo/stablehlo/dialect/AssemblyFormat.h -+++ stablehlo/stablehlo/dialect/AssemblyFormat.h -@@ -378,6 +378,65 @@ - return success(); - } - -+// ResultAccuracyAttr - Custom printing and parsing for ResultAccuracyAttr. -+// -+// ResultAccuractAttr ::= `<` OptAtolAccuracy OptRtolAccuracy -+// OptUlpAccuracy ModeAccuracy `>` -+// OptAtolAccuracy ::= `atol` `=` APFloat `, ` | eps -+// OptRtolAccuracy ::= `rtol` `=` APFloat `, ` | eps -+// OptUlpAccuracy ::= `ulps` `=` int64_t `, ` | eps -+// ModeAccuracy ::= `mode` `=` ResultAccuracyModeAttr -+void printResultAccuracyAttr(AsmPrinter& odsPrinter, APFloat atol, APFloat rtol, -+ int64_t ulps, Attribute mode); -+ -+template -+Attribute parseResultAccuracyAttr(AsmParser& parser, Type type) { -+ APFloat resultAtol = APFloat::getZero(APFloat::IEEEdouble()); -+ APFloat resultRtol = APFloat::getZero(APFloat::IEEEdouble()); -+ int64_t resultUlps = 0; -+ -+ // Parse literal '<' -+ if (parser.parseLess()) return {}; -+ -+ // OptAtolAccuracy -+ if (succeeded(parser.parseOptionalKeyword("atol"))) { -+ double value; -+ if (parser.parseEqual() || parser.parseFloat(value) || parser.parseComma()) -+ return {}; -+ resultAtol = APFloat(value); -+ } -+ -+ // OptRtolAccuracy -+ if (succeeded(parser.parseOptionalKeyword("rtol"))) { -+ double value; -+ if (parser.parseEqual() || parser.parseFloat(value) || parser.parseComma()) -+ return {}; -+ resultRtol = APFloat(value); -+ } -+ -+ // OptUlpAccuracy -+ if (succeeded(parser.parseOptionalKeyword("ulps"))) { -+ int64_t value; -+ if (parser.parseEqual() || parser.parseInteger(value) || -+ parser.parseComma()) -+ return {}; -+ resultUlps = value; -+ } -+ -+ // ModeAccuracy -+ ModeTy modeAttr; -+ if (parser.parseKeyword("mode") || parser.parseEqual() || -+ parser.parseAttribute(modeAttr)) { -+ return {}; -+ } -+ -+ // Parse literal '>' -+ if (parser.parseGreater()) return {}; -+ return parser.getChecked( -+ parser.getCurrentLocation(), parser.getContext(), resultAtol, resultRtol, -+ resultUlps, modeAttr); -+} -+ - } // namespace hlo - } // namespace mlir - -diff --ruN a/stablehlo/stablehlo/dialect/Base.cpp b/stablehlo/stablehlo/dialect/Base.cpp ---- stablehlo/stablehlo/dialect/Base.cpp -+++ stablehlo/stablehlo/dialect/Base.cpp -@@ -780,5 +780,22 @@ - numScales == rankedType.getDimSize(quantDim)); - } - -+bool hasSingleBoundedDimension(Type type) { -+ RankedTensorType rankedType = dyn_cast(type); -+ auto boundedAttr = -+ dyn_cast_or_null(rankedType.getEncoding()); -+ if (!boundedAttr) return false; -+ -+ // Count if bounded attr size is not kDynamic -+ int64_t numBoundedDims = llvm::count_if( -+ boundedAttr.getBounds(), -+ [](int64_t bound) { return !ShapedType::isDynamic(bound); }); -+ // Also check that there are only bounded dims and no unbounded dims. -+ int64_t numDynamicDims = llvm::count_if( -+ rankedType.getShape(), -+ [](int64_t bound) { return ShapedType::isDynamic(bound); }); -+ return numBoundedDims == 1 && numDynamicDims == 1; -+} -+ - } // namespace hlo - } // namespace mlir -diff --ruN a/stablehlo/stablehlo/dialect/Base.h b/stablehlo/stablehlo/dialect/Base.h ---- stablehlo/stablehlo/dialect/Base.h -+++ stablehlo/stablehlo/dialect/Base.h -@@ -101,6 +101,9 @@ - // mentioned in the StableHLO specification. - bool isValidQuantizedDimension(Type type); - -+// Returns true if the given type has a single bounded dimension. -+bool hasSingleBoundedDimension(Type type); -+ - // TODO(zhouxin) Move type inference related methods to TypeInference.cpp - - std::pair inferConcatenatedDimAndBound(int64_t leftSize, -diff --ruN a/stablehlo/stablehlo/dialect/Base.td b/stablehlo/stablehlo/dialect/Base.td ---- stablehlo/stablehlo/dialect/Base.td -+++ stablehlo/stablehlo/dialect/Base.td -@@ -29,6 +29,20 @@ - def I32RankedTensor : RankedTensorOf<[I32]>; - - def UI32RankedTensor : RankedTensorOf<[UI32]>; -+ -+//===----------------------------------------------------------------------===// -+// HLO type constraints. -+//===----------------------------------------------------------------------===// -+ -+// Note: Bounded dynamisms is largely unspecced and this feature needs more -+// thoguht as it is adopted to modern frameworks. The current support is -+// designed to allow existing TF programs to be representable in StableHLO and -+// is subject to change as a formal design for boudned dynamism is developed. -+def HLO_HasSingleBoundedDimensionPred -+ : CPred<"mlir::hlo::hasSingleBoundedDimension($_self)">; -+ -+def HLO_HasStaticOrSingleBoundedShapePred -+ : Or<[HasStaticShapePred, HLO_HasSingleBoundedDimensionPred]>; - - //===----------------------------------------------------------------------===// - // HLO type definitions. -@@ -267,6 +281,9 @@ - def HLO_StaticShapeTensorOrPerAxisQuantizedTensor : RankedTensorOf<[HLO_Float, HLO_Pred, HLO_Int, HLO_Complex, HLO_QuantizedInt, HLO_PerAxisQuantizedInt], - [IsValidQuantizedDimension, HasStaticShapePred], "statically shaped tensor">; - -+def HLO_StaticShapeTensorPerAxisQuantizedTensorOrBoundedTensor : RankedTensorOf<[HLO_Float, HLO_Pred, HLO_Int, HLO_Complex, HLO_QuantizedInt, HLO_PerAxisQuantizedInt], -+ [IsValidQuantizedDimension, HLO_HasStaticOrSingleBoundedShapePred], "statically shaped or single bounded dimension tensor">; -+ - def HLO_StaticShapeTensorOrPerAxisQuantizedTensorOrToken : AnyTypeOf<[HLO_StaticShapeTensor, HLO_StaticShapeTensorOrPerAxisQuantizedTensor, HLO_Token]>; - - def HLO_StaticShapeIntOrFpTensor : StaticShapeTensorOf<[HLO_Int, HLO_Float]>; -diff --ruN a/stablehlo/stablehlo/dialect/CMakeLists.txt b/stablehlo/stablehlo/dialect/CMakeLists.txt ---- stablehlo/stablehlo/dialect/CMakeLists.txt -+++ stablehlo/stablehlo/dialect/CMakeLists.txt -@@ -190,7 +190,7 @@ - set(LLVM_TARGET_DEFINITIONS VhloOps.td) - mlir_tablegen(VhloAttrs.h.inc -gen-attrdef-decls) - mlir_tablegen(VhloAttrs.cpp.inc -gen-attrdef-defs) --set(LLVM_TARGET_DEFINITIONS VhloAttrs.td) -+set(LLVM_TARGET_DEFINITIONS VhloEnums.td) - mlir_tablegen(VhloAttrInterfaces.h.inc -gen-attr-interface-decls) - mlir_tablegen(VhloAttrInterfaces.cpp.inc -gen-attr-interface-defs) - set(LLVM_TARGET_DEFINITIONS VhloTypes.td) -diff --ruN a/stablehlo/stablehlo/dialect/StablehloAttrs.td b/stablehlo/stablehlo/dialect/StablehloAttrs.td ---- stablehlo/stablehlo/dialect/StablehloAttrs.td -+++ stablehlo/stablehlo/dialect/StablehloAttrs.td -@@ -19,6 +19,7 @@ - - include "mlir/IR/OpBase.td" - include "mlir/IR/TensorEncoding.td" -+include "stablehlo/dialect/StablehloTypes.td" - - def StableHLO_Dims : ArrayRefParameter<"int64_t", "Dimension"> { - let parser = "parseDimSizes($_parser)"; -@@ -209,4 +210,18 @@ - let hasCustomAssemblyFormat = 1; - } - -+def StableHLO_ResultAccuracyAttr : AttrDef { -+ let mnemonic = "result_accuracy"; -+ let summary = "The requested accuracy for transcendental unary ops."; -+ let parameters = (ins -+ "APFloat":$atol, -+ "APFloat":$rtol, -+ "int64_t":$ulps, -+ StableHLO_ResultAccuracyModeAttr:$mode -+ ); -+ let hasCustomAssemblyFormat = 1; -+ let genVerifyDecl = 1; -+ let constBuilderCall = "ResultAccuracyAttr::get($_builder.getContext(), APFloat(0.0), APFloat(0.0), 0, ResultAccuracyModeAttr::get($_builder.getContext(), $0))"; -+} -+ - #endif // STABLEHLO_DIALECT_STABLEHLO_ATTRS -diff --ruN a/stablehlo/stablehlo/dialect/StablehloBytecode.cpp b/stablehlo/stablehlo/dialect/StablehloBytecode.cpp ---- stablehlo/stablehlo/dialect/StablehloBytecode.cpp -+++ stablehlo/stablehlo/dialect/StablehloBytecode.cpp -@@ -18,6 +18,7 @@ - #include - #include - -+#include "llvm/ADT/APFloat.h" - #include "llvm/ADT/SmallVector.h" - #include "llvm/ADT/StringRef.h" - #include "llvm/ADT/TypeSwitch.h" -@@ -180,6 +181,18 @@ - /// allowImpreciseAccumulation : svarint - /// } - kDotAlgorithmAttr = 15, -+ -+ // ResultAccuracyModeAttr { -+ // mode: varint (encoded enum) -+ // } -+ kResultAccuracyModeAttr = 16, -+ -+ // ResultAccuracyAttr { -+ // atol: APFloat -+ // rtol: APFloat -+ // ulps: svarint -+ // } -+ kResultAccuracyAttr = 17, - }; - - /// This enum contains marker codes used to indicate which type is -@@ -241,6 +254,10 @@ - OutputOperandAliasAttr readOutputOperandAliasAttr( - DialectBytecodeReader &reader) const; - PrecisionAttr readPrecisionAttr(DialectBytecodeReader &reader) const; -+ ResultAccuracyAttr readResultAccuracyAttr( -+ DialectBytecodeReader &reader) const; -+ ResultAccuracyModeAttr readResultAccuracyModeAttr( -+ DialectBytecodeReader &reader) const; - RngAlgorithmAttr readRngAlgorithmAttr(DialectBytecodeReader &reader) const; - RngDistributionAttr readRngDistributionAttr( - DialectBytecodeReader &reader) const; -@@ -264,6 +281,8 @@ - DialectBytecodeWriter &writer) const; - void write(OutputOperandAliasAttr attr, DialectBytecodeWriter &writer) const; - void write(PrecisionAttr attr, DialectBytecodeWriter &writer) const; -+ void write(ResultAccuracyAttr attr, DialectBytecodeWriter &writer) const; -+ void write(ResultAccuracyModeAttr attr, DialectBytecodeWriter &writer) const; - void write(RngAlgorithmAttr attr, DialectBytecodeWriter &writer) const; - void write(RngDistributionAttr attr, DialectBytecodeWriter &writer) const; - void write(ScatterDimensionNumbersAttr attr, -@@ -327,6 +346,10 @@ - return readOutputOperandAliasAttr(reader); - case stablehlo_encoding::kPrecisionAttr: - return readPrecisionAttr(reader); -+ case stablehlo_encoding::kResultAccuracyAttr: -+ return readResultAccuracyAttr(reader); -+ case stablehlo_encoding::kResultAccuracyModeAttr: -+ return readResultAccuracyModeAttr(reader); - case stablehlo_encoding::kRngAlgorithmAttr: - return readRngAlgorithmAttr(reader); - case stablehlo_encoding::kRngDistributionAttr: -@@ -352,13 +375,13 @@ - .Case( -- [&](auto attr) { -- LOG_WRITE_CALL; -- write(attr, writer); -- return success(); -- }) -+ PrecisionAttr, ResultAccuracyAttr, ResultAccuracyModeAttr, -+ RngAlgorithmAttr, RngDistributionAttr, ScatterDimensionNumbersAttr, -+ TransposeAttr, TypeExtensionsAttr>([&](auto attr) { -+ LOG_WRITE_CALL; -+ write(attr, writer); -+ return success(); -+ }) - .Default([&](Attribute) { - LOG_NOT_IMPLEMENTED; - return failure(); -@@ -806,6 +829,55 @@ - } - } - -+//===----------------------------------------------------------------------===// -+// ResultAccuracyModeAttr -+ -+ResultAccuracyModeAttr StablehloBytecodeInterface::readResultAccuracyModeAttr( -+ DialectBytecodeReader &reader) const { -+ LOG_READ_CALL; -+ return hlo::bytecode::readEnumAttribute( -+ reader, getContext(), -+ [](uint32_t val) { return symbolizeResultAccuracyMode(val); }); -+} -+ -+void StablehloBytecodeInterface::write(ResultAccuracyModeAttr attr, -+ DialectBytecodeWriter &writer) const { -+ writer.writeVarInt(stablehlo_encoding::kResultAccuracyModeAttr); -+ hlo::bytecode::writeEnumAttribute(attr, writer); -+} -+ -+//===----------------------------------------------------------------------===// -+// ResultAccuracyAttr -+ -+ResultAccuracyAttr StablehloBytecodeInterface::readResultAccuracyAttr( -+ DialectBytecodeReader &reader) const { -+ LOG_READ_CALL; -+ FailureOr atol; -+ FailureOr rtol; -+ int64_t ulps; -+ ResultAccuracyModeAttr mode; -+ if (failed(atol = -+ reader.readAPFloatWithKnownSemantics(APFloat::IEEEdouble())) || -+ failed(rtol = -+ reader.readAPFloatWithKnownSemantics(APFloat::IEEEdouble())) || -+ failed(reader.readSignedVarInt(ulps)) || -+ failed(reader.readAttribute(mode))) { -+ mlir::emitWarning(mlir::UnknownLoc::get(getContext())) -+ << "failed to read APFloat for atol"; -+ return ResultAccuracyAttr(); -+ } -+ return ResultAccuracyAttr::get(getContext(), *atol, *rtol, ulps, mode); -+} -+ -+void StablehloBytecodeInterface::write(ResultAccuracyAttr attr, -+ DialectBytecodeWriter &writer) const { -+ writer.writeVarInt(stablehlo_encoding::kResultAccuracyAttr); -+ writer.writeAPFloatWithKnownSemantics(attr.getAtol()); -+ writer.writeAPFloatWithKnownSemantics(attr.getRtol()); -+ writer.writeSignedVarInt(attr.getUlps()); -+ writer.writeAttribute(attr.getMode()); -+} -+ - } // namespace - - void addBytecodeInterface(StablehloDialect *dialect) { -diff --ruN a/stablehlo/stablehlo/dialect/StablehloEnums.td b/stablehlo/stablehlo/dialect/StablehloEnums.td ---- stablehlo/stablehlo/dialect/StablehloEnums.td -+++ stablehlo/stablehlo/dialect/StablehloEnums.td -@@ -45,6 +45,29 @@ - // TODO(b/129153247) See if it's possible to also validate the size. - def StableHLO_PrecisionConfigAttr: - TypedArrayAttrBase; -+ -+//===----------------------------------------------------------------------===// -+// Result Accuracy enum definitions. -+//===----------------------------------------------------------------------===// -+ -+def STABLEHLO_RESULT_ACCURACY_DEFAULT : I32EnumAttrCase<"DEFAULT", 0>; -+def STABLEHLO_RESULT_ACCURACY_HIGHEST : I32EnumAttrCase<"HIGHEST", 1>; -+def STABLEHLO_RESULT_ACCURACY_TOLERANCE: I32EnumAttrCase<"TOLERANCE", 2>; -+ -+def StableHLO_ResultAccuracyMode : I32EnumAttr<"ResultAccuracyMode", -+ "XLA result accuracy mode.", -+ [ -+ STABLEHLO_RESULT_ACCURACY_DEFAULT, -+ STABLEHLO_RESULT_ACCURACY_HIGHEST, -+ STABLEHLO_RESULT_ACCURACY_TOLERANCE -+ ]> { -+ let genSpecializedAttr = 0; -+ let cppNamespace = "::mlir::stablehlo"; -+} -+ -+def StableHLO_ResultAccuracyModeAttr : EnumAttr { -+ let assemblyFormat = "`<` $value `>`"; -+} - - //===----------------------------------------------------------------------===// - // Fast Fourier Transform Type enum definitions. -diff --ruN a/stablehlo/stablehlo/dialect/StablehloOps.cpp b/stablehlo/stablehlo/dialect/StablehloOps.cpp ---- stablehlo/stablehlo/dialect/StablehloOps.cpp -+++ stablehlo/stablehlo/dialect/StablehloOps.cpp -@@ -792,6 +792,29 @@ - allowImpreciseAccumulation); - } - -+// ===----------------------------------------------------------------------===// -+// ExpOp -+//===----------------------------------------------------------------------===// -+ -+LogicalResult ResultAccuracyAttr::verify( -+ ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, APFloat atol, -+ APFloat rtol, int64_t ulps, ResultAccuracyModeAttr mode) { -+ return hlo::verifyResultAccuracyAttr( -+ emitError, atol, rtol, ulps, -+ stringifyResultAccuracyMode(mode.getValue())); -+} -+ -+LogicalResult ExpOp::verify() { -+ if (auto attr = getResultAccuracyAttr()) { -+ if (failed(ResultAccuracyAttr::verify([&] { return emitError(); }, -+ attr.getAtol(), attr.getRtol(), -+ attr.getUlps(), attr.getMode()))) { -+ return failure(); -+ } -+ } -+ return success(); -+} -+ - //===----------------------------------------------------------------------===// - // FftOp - //===----------------------------------------------------------------------===// -@@ -3127,6 +3150,20 @@ - lhsContractingDimensions, rhsContractingDimensions); - } - -+// ===----------------------------------------------------------------------===// -+// Custom unary op -+// ===----------------------------------------------------------------------===// -+ -+void ResultAccuracyAttr::print(AsmPrinter& odsPrinter) const { -+ hlo::printResultAccuracyAttr(odsPrinter, getAtol(), getRtol(), getUlps(), -+ getMode()); -+} -+ -+Attribute ResultAccuracyAttr::parse(AsmParser& parser, Type type) { -+ return hlo::parseResultAccuracyAttr(parser, type); -+} -+ - namespace { - enum NonSpatialDim : int64_t { - IOBatch = -1, // Input or output batch dimension -diff --ruN a/stablehlo/stablehlo/dialect/StablehloOps.td b/stablehlo/stablehlo/dialect/StablehloOps.td ---- stablehlo/stablehlo/dialect/StablehloOps.td -+++ stablehlo/stablehlo/dialect/StablehloOps.td -@@ -328,6 +328,23 @@ - %result = stablehlo.exponential %operand : tensor<2x2xf64> - ``` - }]; -+ let arguments = (ins HLO_FpComplexOrQuantizedIntTensor:$operand, -+ DefaultValuedOptionalAttr:$result_accuracy); -+ let results = (outs HLO_FpComplexOrQuantizedIntTensor:$result); -+ let extraClassDeclaration = commonClassDeclaration # [{ -+ LogicalResult reifyReturnTypeShapes( -+ OpBuilder& builder, ValueRange operands, -+ SmallVectorImpl& reifiedReturnShapes) { -+ return ::mlir::hlo::deriveShapeFromOperand(&builder, getOperation(), -+ operands.front(), -+ &reifiedReturnShapes); -+ } -+ }]; -+ let hasVerifier = 1; -+ -+ let assemblyFormat = [{ -+ $operand attr-dict `:` custom(type($operand), type($result)) -+ }]; - } - - def StableHLO_Expm1Op: StableHLO_UnaryElementwiseOp<"exponential_minus_one", -@@ -1963,7 +1980,7 @@ - DenseI64ArrayAttr:$broadcast_dimensions /*broadcast_in_dim_i2*/ - ); - -- let results = (outs HLO_StaticShapeTensorOrPerAxisQuantizedTensor); -+ let results = (outs HLO_StaticShapeTensorPerAxisQuantizedTensorOrBoundedTensor); - - let hasVerifier = 1; - -@@ -2715,7 +2732,7 @@ - - let arguments = (ins HLO_TensorOrPerAxisQuantizedTensor:$operand); - -- let results = (outs HLO_StaticShapeTensorOrPerAxisQuantizedTensor); -+ let results = (outs HLO_StaticShapeTensorPerAxisQuantizedTensorOrBoundedTensor); - let hasVerifier = 1; - - let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)"; -diff --ruN a/stablehlo/stablehlo/dialect/TypeInference.cpp b/stablehlo/stablehlo/dialect/TypeInference.cpp ---- stablehlo/stablehlo/dialect/TypeInference.cpp -+++ stablehlo/stablehlo/dialect/TypeInference.cpp -@@ -3724,9 +3724,8 @@ - Value operand, - ArrayRef broadcastDimensions, - Value result) { -+ // broadcast_in_dim_c1 - auto operandType = cast(operand.getType()); -- -- // broadcast_in_dim_c1 - if (failed(verifyQPerTensorScaleAndZeroPointConstraints(location, operandType, - result.getType()))) - return failure(); -@@ -4658,11 +4657,12 @@ - Value result) { - // If the operand type is dynamically shaped there is nothing to verify. - auto operandTy = cast(operand.getType()); -- if (!operandTy.hasStaticShape()) return success(); -+ auto resultTy = cast(result.getType()); -+ if (!operandTy.hasStaticShape() || !resultTy.hasStaticShape()) -+ return success(); - - // If the operand type is statically shaped (not required) the number of - // elements must match that of the result type. -- auto resultTy = cast(result.getType()); - int64_t numResultElements = resultTy.getNumElements(); - int64_t numOperandElements = operandTy.getNumElements(); - if (numResultElements != numOperandElements) -@@ -5057,5 +5057,30 @@ - return success(); - } - -+LogicalResult verifyResultAccuracyCombination( -+ ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, APFloat atol, -+ APFloat rtol, int64_t ulps, StringRef mode) { -+ if (mode == "DEFAULT" || mode == "HIGHEST") { -+ bool all_zero = atol.isZero() && rtol.isZero() && ulps == 0; -+ if (!all_zero) { -+ return emitError() -+ << "Invalid tolerances for ResultAccuracyAttr with mode " << mode -+ << ", must be all zero."; -+ } -+ } -+ return success(); -+} -+ -+LogicalResult verifyResultAccuracyAttr( -+ ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, APFloat atol, -+ APFloat rtol, int64_t ulps, StringRef mode) { -+ if (atol.isNegative() || rtol.isNegative() || ulps < 0) -+ return emitError() << "Negative tolerance"; -+ if (failed( -+ verifyResultAccuracyCombination(emitError, atol, rtol, ulps, mode))) -+ return failure(); -+ return success(); -+} -+ - } // end namespace hlo - } // end namespace mlir -diff --ruN a/stablehlo/stablehlo/dialect/TypeInference.h b/stablehlo/stablehlo/dialect/TypeInference.h ---- stablehlo/stablehlo/dialect/TypeInference.h -+++ stablehlo/stablehlo/dialect/TypeInference.h -@@ -26,6 +26,7 @@ - #include "mlir/IR/SymbolTable.h" - #include "mlir/IR/Types.h" - #include "mlir/Interfaces/InferTypeOpInterface.h" -+#include "mlir/Support/LLVM.h" - #include "mlir/Support/LogicalResult.h" - #include "stablehlo/dialect/Base.h" - -@@ -596,6 +597,14 @@ - - LogicalResult verifyWhileOp(std::optional location, - ValueRange operand, Region& cond, Region& body); -+ -+LogicalResult verifyResultAccuracyCombination( -+ ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, APFloat atol, -+ APFloat rtol, int64_t ulps, StringRef mode); -+ -+LogicalResult verifyResultAccuracyAttr( -+ ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, APFloat atol, -+ APFloat rtol, int64_t ulps, StringRef mode); - } // end namespace hlo - } // end namespace mlir - -diff --ruN a/stablehlo/stablehlo/dialect/Version.cpp b/stablehlo/stablehlo/dialect/Version.cpp ---- stablehlo/stablehlo/dialect/Version.cpp -+++ stablehlo/stablehlo/dialect/Version.cpp -@@ -75,7 +75,7 @@ - Version Version::fromCompatibilityRequirement( - CompatibilityRequirement requirement) { - // Compatibility requirement versions can be updated as needed, as long as the -- // version satisifies the requirement. -+ // version satisfies the requirement. - // The time frames used are from the date that the release was tagged on, not - // merged. The tag date is when the version has been verified and exported to - // XLA. See: https://github.com/openxla/stablehlo/tags -diff --ruN a/stablehlo/stablehlo/dialect/Version.h b/stablehlo/stablehlo/dialect/Version.h ---- stablehlo/stablehlo/dialect/Version.h -+++ stablehlo/stablehlo/dialect/Version.h -@@ -38,7 +38,7 @@ - static FailureOr fromString(llvm::StringRef versionRef); - - /// Return a Version representing the current VHLO dialect version. -- static Version getCurrentVersion() { return Version(1, 8, 11); } -+ static Version getCurrentVersion() { return Version(1, 9, 1); } - - /// Return a Version representing the minimum supported VHLO dialect version. - static Version getMinimumVersion() { return Version(0, 9, 0); } -diff --ruN a/stablehlo/stablehlo/dialect/VhloAttrs.td b/stablehlo/stablehlo/dialect/VhloAttrs.td ---- stablehlo/stablehlo/dialect/VhloAttrs.td -+++ stablehlo/stablehlo/dialect/VhloAttrs.td -@@ -21,18 +21,8 @@ - include "stablehlo/dialect/VhloBase.td" - include "stablehlo/dialect/VhloDialect.td" - include "stablehlo/dialect/VhloTypes.td" -- --def VHLO_VersionedAttrInterface : AttrInterface<"VersionedAttrInterface"> { -- let cppNamespace = "::mlir::vhlo"; -- let methods = [ -- InterfaceMethod< -- "Returns the minimum version of the VHLO dialect an attribute is supported in.", -- "mlir::vhlo::Version", "getMinVersion">, -- InterfaceMethod< -- "Returns the maximum version (inclusive) of the VHLO dialect an attribute is supported in.", -- "mlir::vhlo::Version", "getMaxVersion">, -- ]; --} -+include "stablehlo/dialect/VhloEnums.td" -+ - - class VHLO_AttrDef - : AttrDef { -@@ -190,4 +180,27 @@ - let assemblyFormat = "`<` struct(params) `>`"; - } - -+ -+def VHLO_ResultAccuracyAttrV1 : VHLO_AttrDef<"ResultAccuracyV1", "1.9.0", "current"> { -+ let mnemonic = "result_accuracy_v1"; -+ let summary = "The requested accuracy for transcendental unary ops."; -+ let parameters = (ins -+ VHLO_APFloatV1:$atol, -+ VHLO_APFloatV1:$rtol, -+ "int64_t":$ulps, -+ "mlir::Attribute":$mode -+ ); -+ let assemblyFormat = "`<` struct(params) `>`"; -+ let genVerifyDecl = 1; -+ let extraClassDefinition = [{ -+ LogicalResult ResultAccuracyV1Attr::verify( -+ llvm::function_ref errFn, -+ APFloat atol, APFloat rtol, int64_t ulps, -+ mlir::Attribute mode) { -+ if (!isFromVhlo(mode)) return errFn() << "expected VHLO result accuracy mode"; -+ return success(); -+ } -+ }]; -+} -+ - #endif // STABLEHLO_DIALECT_VHLO_ATTRS -diff --ruN a/stablehlo/stablehlo/dialect/VhloBytecode.cpp b/stablehlo/stablehlo/dialect/VhloBytecode.cpp ---- stablehlo/stablehlo/dialect/VhloBytecode.cpp -+++ stablehlo/stablehlo/dialect/VhloBytecode.cpp -@@ -178,6 +178,18 @@ - /// bounds : svarint[] - /// } - kTypeExtensionsV1Attr = 18, -+ -+ // ResultAccuracyModeV1Attr { -+ // mode: varint (encoded enum) -+ // } -+ kResultAccuracyModeV1Attr = 19, -+ -+ // ResultAccuracyV1Attr { -+ // atol: APFloat -+ // rtol: APFloat -+ // ulps: svarint -+ // } -+ kResultAccuracyV1Attr = 20, - }; - - /// This enum contains marker codes used to indicate which type is -@@ -433,6 +445,10 @@ - TypeV1Attr readTypeV1Attr(DialectBytecodeReader &reader) const; - TypeExtensionsV1Attr readTypeExtensionsV1Attr( - DialectBytecodeReader &reader) const; -+ ResultAccuracyModeV1Attr readResultAccuracyModeV1Attr( -+ DialectBytecodeReader &reader) const; -+ ResultAccuracyV1Attr readResultAccuracyV1Attr( -+ DialectBytecodeReader &reader) const; - - // TO ADD ATTRIBUTE: Include a write method for each attribute in VHLO - // Ex: void write(SomeAttr attr, DialectBytecodeWriter &writer) const; -@@ -457,6 +473,9 @@ - void write(TransposeV1Attr attr, DialectBytecodeWriter &writer) const; - void write(TypeV1Attr attr, DialectBytecodeWriter &writer) const; - void write(TypeExtensionsV1Attr attr, DialectBytecodeWriter &writer) const; -+ void write(ResultAccuracyModeV1Attr attr, -+ DialectBytecodeWriter &writer) const; -+ void write(ResultAccuracyV1Attr attr, DialectBytecodeWriter &writer) const; - - //===--------------------------------------------------------------------===// - // Types -@@ -541,6 +560,10 @@ - return readTypeV1Attr(reader); - case vhlo_encoding::kTypeExtensionsV1Attr: - return readTypeExtensionsV1Attr(reader); -+ case vhlo_encoding::kResultAccuracyModeV1Attr: -+ return readResultAccuracyModeV1Attr(reader); -+ case vhlo_encoding::kResultAccuracyV1Attr: -+ return readResultAccuracyV1Attr(reader); - default: - reader.emitError() << "unknown vhlo attribute code: " << code; - return Attribute(); -@@ -558,7 +581,8 @@ - FftTypeV1Attr, FloatV1Attr, IntegerV1Attr, OutputOperandAliasV1Attr, - PrecisionV1Attr, RngAlgorithmV1Attr, RngDistributionV1Attr, - StringV1Attr, TensorV1Attr, TransposeV1Attr, TypeV1Attr, -- TypeExtensionsV1Attr>([&](auto attr) { -+ TypeExtensionsV1Attr, ResultAccuracyV1Attr, -+ ResultAccuracyModeV1Attr>([&](auto attr) { - LOG_WRITE_CALL; - write(attr, writer); - return success(); -@@ -1450,6 +1474,55 @@ - writer.writeType(type.getElementType()); - } - -+//===----------------------------------------------------------------------===// -+// ResultAccuracyModeAttr -+ -+ResultAccuracyModeV1Attr VhloBytecodeInterface::readResultAccuracyModeV1Attr( -+ DialectBytecodeReader &reader) const { -+ LOG_READ_CALL; -+ return hlo::bytecode::readEnumAttribute( -+ reader, getContext(), -+ [](uint32_t val) { return symbolizeResultAccuracyModeV1(val); }); -+} -+ -+void VhloBytecodeInterface::write(ResultAccuracyModeV1Attr attr, -+ DialectBytecodeWriter &writer) const { -+ writer.writeVarInt(vhlo_encoding::kResultAccuracyModeV1Attr); -+ hlo::bytecode::writeEnumAttribute(attr, writer); -+} -+ -+//===----------------------------------------------------------------------===// -+// ResultAccuracyAttr -+ -+ResultAccuracyV1Attr VhloBytecodeInterface::readResultAccuracyV1Attr( -+ DialectBytecodeReader &reader) const { -+ LOG_READ_CALL; -+ FailureOr atol; -+ FailureOr rtol; -+ int64_t ulps; -+ ResultAccuracyModeV1Attr mode; -+ if (failed(atol = -+ reader.readAPFloatWithKnownSemantics(APFloat::IEEEdouble())) || -+ failed(rtol = -+ reader.readAPFloatWithKnownSemantics(APFloat::IEEEdouble())) || -+ failed(reader.readSignedVarInt(ulps)) || -+ failed(reader.readAttribute(mode))) { -+ mlir::emitWarning(mlir::UnknownLoc::get(getContext())) -+ << "failed to read APFloat for atol"; -+ return ResultAccuracyV1Attr(); -+ } -+ return ResultAccuracyV1Attr::get(getContext(), *atol, *rtol, ulps, mode); -+} -+ -+void VhloBytecodeInterface::write(ResultAccuracyV1Attr attr, -+ DialectBytecodeWriter &writer) const { -+ writer.writeVarInt(vhlo_encoding::kResultAccuracyV1Attr); -+ writer.writeAPFloatWithKnownSemantics(attr.getAtol()); -+ writer.writeAPFloatWithKnownSemantics(attr.getRtol()); -+ writer.writeSignedVarInt(attr.getUlps()); -+ writer.writeAttribute(attr.getMode()); -+} -+ - } // namespace - - void addBytecodeInterface(VhloDialect *dialect) { -diff --ruN a/stablehlo/stablehlo/dialect/VhloDialect.td b/stablehlo/stablehlo/dialect/VhloDialect.td ---- stablehlo/stablehlo/dialect/VhloDialect.td -+++ stablehlo/stablehlo/dialect/VhloDialect.td -@@ -47,6 +47,7 @@ - 1.6.0: Add DotAlgorithm specificaiton to `dot_general`. - 1.7.0: Introduce `f8E4M3` and `f8E3M4` types. - 1.8.0: Introduce `f4E2M1FN`, `f6E2M3FN`, `f6E3M2FN` and `f8E8M0FNU` types. -+ 1.9.0: Add `ResultAccuracy` attribute to `exp` op. - }]; - - let useDefaultAttributePrinterParser = 0; -diff --ruN a/stablehlo/stablehlo/dialect/VhloEnums.td b/stablehlo/stablehlo/dialect/VhloEnums.td ---- stablehlo/stablehlo/dialect/VhloEnums.td -+++ stablehlo/stablehlo/dialect/VhloEnums.td -@@ -20,7 +20,20 @@ - include "mlir/IR/EnumAttr.td" - include "mlir/IR/PatternBase.td" - include "stablehlo/dialect/VhloBase.td" --include "stablehlo/dialect/VhloAttrs.td" -+include "stablehlo/dialect/VhloDialect.td" -+include "mlir/IR/AttrTypeBase.td" -+ -+def VHLO_VersionedAttrInterface : AttrInterface<"VersionedAttrInterface"> { -+ let cppNamespace = "::mlir::vhlo"; -+ let methods = [ -+ InterfaceMethod< -+ "Returns the minimum version of the VHLO dialect an attribute is supported in.", -+ "mlir::vhlo::Version", "getMinVersion">, -+ InterfaceMethod< -+ "Returns the maximum version (inclusive) of the VHLO dialect an attribute is supported in.", -+ "mlir::vhlo::Version", "getMaxVersion">, -+ ]; -+} - - class VHLO_I32EnumAttr cases> : - I32EnumAttr { -@@ -198,4 +211,23 @@ - def VHLO_TransposeAttrV1 - : VHLO_EnumAttr; - -+//===----------------------------------------------------------------------===// -+// ResultAccuracyMode -+//===----------------------------------------------------------------------===// -+ -+def VHLO_RESULT_V1_ACCURACY_DEFAULT : I32EnumAttrCase<"DEFAULT", 0>; -+def VHLO_RESULT_V1_ACCURACY_HIGHEST : I32EnumAttrCase<"HIGHEST", 1>; -+def VHLO_RESULT_V1_ACCURACY_TOLERANCE: I32EnumAttrCase<"TOLERANCE", 2>; -+ -+def VHLO_ResultAccuracyModeV1 : VHLO_I32EnumAttr<"ResultAccuracyModeV1", -+ [ -+ VHLO_RESULT_V1_ACCURACY_DEFAULT, -+ VHLO_RESULT_V1_ACCURACY_HIGHEST, -+ VHLO_RESULT_V1_ACCURACY_TOLERANCE -+ ]> {} -+ -+def VHLO_ResultAccuracyModeV1Attr -+ : VHLO_EnumAttr; -+ -+ - #endif // STABLEHLO_DIALECT_VHLO_ENUMS -diff --ruN a/stablehlo/stablehlo/dialect/VhloOps.cpp b/stablehlo/stablehlo/dialect/VhloOps.cpp ---- stablehlo/stablehlo/dialect/VhloOps.cpp -+++ stablehlo/stablehlo/dialect/VhloOps.cpp -@@ -25,7 +25,7 @@ - #include "llvm/ADT/SmallVectorExtras.h" - #include "llvm/ADT/StringExtras.h" - #include "llvm/ADT/StringRef.h" --#include "llvm/ADT/TypeSwitch.h" -+#include "llvm/ADT/TypeSwitch.h" // IWYU pragma: keep - #include "llvm/Support/Casting.h" - #include "mlir/Dialect/Shape/IR/Shape.h" - #include "mlir/IR/Attributes.h" -@@ -40,7 +40,7 @@ - #include "mlir/Support/LLVM.h" - #include "mlir/Support/LogicalResult.h" - #include "mlir/Support/TypeID.h" --#include "stablehlo/dialect/AssemblyFormat.h" -+#include "stablehlo/dialect/AssemblyFormat.h" // IWYU pragma: keep - #include "stablehlo/dialect/Version.h" - #include "stablehlo/dialect/VhloBytecode.h" - #include "stablehlo/dialect/VhloTypes.h" -@@ -184,12 +184,13 @@ - return success(); - } - --void TensorV1Attr::print(mlir::AsmPrinter& p) const { -- p << '<' -- << DenseIntOrFPElementsAttr::getFromRawBuffer( -- llvm::cast(convertTypeToBuiltinForPrint(getType())), -- getData()) -- << '>'; -+void TensorV1Attr::print(mlir::AsmPrinter& odsPrinter) const { -+ odsPrinter << '<' -+ << DenseIntOrFPElementsAttr::getFromRawBuffer( -+ llvm::cast( -+ convertTypeToBuiltinForPrint(getType())), -+ getData()) -+ << '>'; - } - - // Parse tensor elements using DenseIntOrFPElementsAttr printing. -diff --ruN a/stablehlo/stablehlo/dialect/VhloOps.td b/stablehlo/stablehlo/dialect/VhloOps.td ---- stablehlo/stablehlo/dialect/VhloOps.td -+++ stablehlo/stablehlo/dialect/VhloOps.td -@@ -618,8 +618,15 @@ - let results = (outs VHLO_AnyType:$result); - } - --def VHLO_ExpOpV1 : VHLO_Op<"exponential_v1", "0.9.0", "current"> { -- let arguments = (ins VHLO_AnyType:$operand); -+def VHLO_ExpOpV1 : VHLO_Op<"exponential_v1", "0.9.0", "1.8.0"> { -+ let arguments = (ins VHLO_AnyType:$operand); -+ let results = (outs VHLO_AnyType:$result); -+} -+ -+def VHLO_ExpOpV2 : VHLO_Op<"exponential_v2", "1.9.0", "current"> { -+ let arguments = (ins -+ VHLO_AnyType:$operand, -+ VHLO_AnyAttr:$result_accuracy); - let results = (outs VHLO_AnyType:$result); - } - -diff --ruN a/stablehlo/stablehlo/integrations/c/StablehloAttributes.cpp b/stablehlo/stablehlo/integrations/c/StablehloAttributes.cpp ---- stablehlo/stablehlo/integrations/c/StablehloAttributes.cpp -+++ stablehlo/stablehlo/integrations/c/StablehloAttributes.cpp -@@ -16,6 +16,7 @@ - #include - #include - -+#include "llvm/ADT/APFloat.h" - #include "llvm/ADT/ArrayRef.h" - #include "llvm/Support/Casting.h" - #include "llvm/Support/ErrorHandling.h" -@@ -687,3 +688,69 @@ - return llvm::cast(unwrap(attr)) - .getBounds()[pos]; - } -+ -+//===----------------------------------------------------------------------===// -+// ResultAccuracyModeAttr -+//===----------------------------------------------------------------------===// -+ -+MlirAttribute stablehloResultAccuracyModeAttrGet(MlirContext ctx, -+ MlirStringRef value) { -+ std::optional accuracyMode = -+ mlir::stablehlo::symbolizeResultAccuracyMode(unwrap(value)); -+ if (!accuracyMode) llvm::report_fatal_error("Invalid value."); -+ return wrap(mlir::stablehlo::ResultAccuracyModeAttr::get( -+ unwrap(ctx), accuracyMode.value())); -+} -+ -+bool stablehloAttributeIsAResultAccuracyModeAttr(MlirAttribute attr) { -+ return llvm::isa(unwrap(attr)); -+} -+ -+MlirStringRef stablehloResultAccuracyModeAttrGetValue(MlirAttribute attr) { -+ return wrap(mlir::stablehlo::stringifyResultAccuracyMode( -+ llvm::cast(unwrap(attr)) -+ .getValue())); -+} -+//===----------------------------------------------------------------------===// -+// ResultAccuracyAttr -+//===----------------------------------------------------------------------===// -+ -+MlirAttribute stablehloResultAccuracyAttrGet(MlirContext ctx, double atol, -+ double rtol, int64_t ulps, -+ MlirStringRef mode) { -+ std::optional accuracyMode = -+ mlir::stablehlo::symbolizeResultAccuracyMode(unwrap(mode)); -+ if (!accuracyMode) llvm::report_fatal_error("Invalid value."); -+ mlir::stablehlo::ResultAccuracyModeAttr modeAttr = -+ mlir::stablehlo::ResultAccuracyModeAttr::get(unwrap(ctx), -+ accuracyMode.value()); -+ return wrap(mlir::stablehlo::ResultAccuracyAttr::get( -+ unwrap(ctx), llvm::APFloat(atol), llvm::APFloat(rtol), ulps, modeAttr)); -+} -+ -+bool stablehloAttributeIsAResultAccuracyAttr(MlirAttribute attr) { -+ return llvm::isa(unwrap(attr)); -+} -+ -+double stablehloResultAccuracyAttrGetAtol(MlirAttribute attr) { -+ llvm::APFloat result = -+ llvm::cast(unwrap(attr)).getAtol(); -+ return result.convertToDouble(); -+} -+ -+double stablehloResultAccuracyAttrGetRtol(MlirAttribute attr) { -+ llvm::APFloat result = -+ llvm::cast(unwrap(attr)).getRtol(); -+ return result.convertToDouble(); -+} -+ -+int64_t stablehloResultAccuracyAttrGetUlps(MlirAttribute attr) { -+ return llvm::cast(unwrap(attr)) -+ .getUlps(); -+} -+ -+MlirAttribute stablehloResultAccuracyAttrGetMode(MlirAttribute attr) { -+ mlir::stablehlo::ResultAccuracyModeAttr modeAttr = -+ llvm::cast(unwrap(attr)).getMode(); -+ return wrap(modeAttr); -+} -diff --ruN a/stablehlo/stablehlo/integrations/c/StablehloAttributes.h b/stablehlo/stablehlo/integrations/c/StablehloAttributes.h ---- stablehlo/stablehlo/integrations/c/StablehloAttributes.h -+++ stablehlo/stablehlo/integrations/c/StablehloAttributes.h -@@ -13,6 +13,7 @@ - #ifndef STABLEHLO_INTEGRATIONS_C_STABLEHLO_ATTRIBUTES_H - #define STABLEHLO_INTEGRATIONS_C_STABLEHLO_ATTRIBUTES_H - -+#include - #include - #include - -@@ -376,6 +377,42 @@ - MLIR_CAPI_EXPORTED int64_t - stablehloTypeExtensionsGetBoundsElem(MlirAttribute attr, intptr_t pos); - -+// ===---------------------------------------------------------------------===// -+// ResultAccuracyModeAttr -+//===----------------------------------------------------------------------===// -+ -+MLIR_CAPI_EXPORTED MlirAttribute -+stablehloResultAccuracyModeAttrGet(MlirContext ctx, MlirStringRef value); -+ -+MLIR_CAPI_EXPORTED bool stablehloAttributeIsAResultAccuracyModeAttr( -+ MlirAttribute attr); -+ -+MLIR_CAPI_EXPORTED MlirStringRef -+stablehloResultAccuracyModeAttrGetValue(MlirAttribute attr); -+ -+// ===---------------------------------------------------------------------===// -+// ResultAccuracyAttr -+//===----------------------------------------------------------------------===// -+ -+MLIR_CAPI_EXPORTED MlirAttribute -+stablehloResultAccuracyAttrGet(MlirContext ctx, double atol, double rtol, -+ int64_t ulps, MlirStringRef value); -+ -+MLIR_CAPI_EXPORTED bool stablehloAttributeIsAResultAccuracyAttr( -+ MlirAttribute attr); -+ -+MLIR_CAPI_EXPORTED double stablehloResultAccuracyAttrGetAtol( -+ MlirAttribute attr); -+ -+MLIR_CAPI_EXPORTED double stablehloResultAccuracyAttrGetRtol( -+ MlirAttribute attr); -+ -+MLIR_CAPI_EXPORTED int64_t -+stablehloResultAccuracyAttrGetUlps(MlirAttribute attr); -+ -+MLIR_CAPI_EXPORTED MlirAttribute -+stablehloResultAccuracyAttrGetMode(MlirAttribute attr); -+ - #ifdef __cplusplus - } - #endif -diff --ruN a/stablehlo/stablehlo/integrations/python/StablehloModule.cpp b/stablehlo/stablehlo/integrations/python/StablehloModule.cpp ---- stablehlo/stablehlo/integrations/python/StablehloModule.cpp -+++ stablehlo/stablehlo/integrations/python/StablehloModule.cpp -@@ -599,6 +599,50 @@ - stablehloTypeExtensionsGetBoundsElem); - }); - -+ mlir::python::nanobind_adaptors::mlir_attribute_subclass( -+ m, "ResultAccuracyAttr", stablehloAttributeIsAResultAccuracyAttr) -+ .def_classmethod( -+ "get", -+ [](nb::object cls, double atol, double rtol, int64_t ulps, -+ const std::string &mode, MlirContext ctx) { -+ return cls(stablehloResultAccuracyAttrGet( -+ ctx, atol, rtol, ulps, -+ mlirStringRefCreate(mode.c_str(), mode.size()))); -+ }, -+ nb::arg("cls"), nb::arg("atol"), nb::arg("rtol"), nb::arg("ulps"), -+ nb::arg("mode"), nb::arg("context") = nb::none(), -+ "Creates a ResultAccuracyAttr with the given values.") -+ .def_property_readonly("atol", -+ [](MlirAttribute self) { -+ return stablehloResultAccuracyAttrGetAtol(self); -+ }) -+ .def_property_readonly("rtol", -+ [](MlirAttribute self) { -+ return stablehloResultAccuracyAttrGetRtol(self); -+ }) -+ .def_property_readonly("ulps", -+ [](MlirAttribute self) { -+ return stablehloResultAccuracyAttrGetUlps(self); -+ }) -+ .def_property_readonly("mode", [](MlirAttribute self) { -+ return toPyString(stablehloResultAccuracyModeAttrGetValue( -+ stablehloResultAccuracyAttrGetMode(self))); -+ }); -+ -+ mlir::python::nanobind_adaptors::mlir_attribute_subclass( -+ m, "ResultAccuracyModeAttr", stablehloAttributeIsAResultAccuracyModeAttr) -+ .def_classmethod( -+ "get", -+ [](nb::object cls, const std::string &value, MlirContext ctx) { -+ return cls(stablehloResultAccuracyModeAttrGet( -+ ctx, mlirStringRefCreate(value.c_str(), value.size()))); -+ }, -+ nb::arg("cls"), nb::arg("value"), nb::arg("context") = nb::none(), -+ "Creates a ResultAccuracyModeAttr with the given values.") -+ .def_property_readonly("value", [](MlirAttribute self) { -+ return toPyString(stablehloResultAccuracyModeAttrGetValue(self)); -+ }); -+ - // - // StableHLO APIs - // -diff --ruN a/stablehlo/stablehlo/integrations/python/tests/stablehlo.py b/stablehlo/stablehlo/integrations/python/tests/stablehlo.py ---- stablehlo/stablehlo/integrations/python/tests/stablehlo.py -+++ stablehlo/stablehlo/integrations/python/tests/stablehlo.py -@@ -386,3 +386,23 @@ - cloned_module = module.operation.clone() - pipeline.run(cloned_module.operation) - assert str(module) == str(cloned_module) -+ -+ -+@run -+def test_result_accuracy_attr_default(): -+ attr = stablehlo.ResultAccuracyAttr.get(atol=0, rtol=0, ulps=0, mode="DEFAULT") -+ assert attr is not None -+ assert attr.mode == "DEFAULT" -+ assert attr.atol == 0 -+ assert attr.rtol == 0 -+ assert attr.ulps == 0 -+ -+@run -+def test_result_accuracy_attr_tolerance(): -+ attr = stablehlo.ResultAccuracyAttr.get(atol=1e-5, rtol=1.0, -+ ulps=2, mode="TOLERANCE") -+ assert attr is not None -+ assert attr.mode == "TOLERANCE" -+ assert attr.atol == 1e-5 -+ assert attr.rtol == 1.0 -+ assert attr.ulps == 2 -diff --ruN a/stablehlo/stablehlo/reference/Types.cpp b/stablehlo/stablehlo/reference/Types.cpp ---- stablehlo/stablehlo/reference/Types.cpp -+++ stablehlo/stablehlo/reference/Types.cpp -@@ -48,13 +48,12 @@ - } - - bool isSupportedFloatType(Type type) { -- return type.isFloat4E2M1FN() || type.isFloat6E2M3FN() || -- type.isFloat6E3M2FN() || type.isFloat8E3M4() || -- type.isFloat8E4M3B11FNUZ() || type.isFloat8E4M3() || -- type.isFloat8E4M3FN() || type.isFloat8E4M3FNUZ() || -- type.isFloat8E5M2() || type.isFloat8E5M2FNUZ() || -- type.isFloat8E8M0FNU() || type.isF16() || type.isBF16() || -- type.isF32() || type.isF64(); -+ return llvm::isa< -+ mlir::Float4E2M1FNType, mlir::Float6E2M3FNType, mlir::Float6E3M2FNType, -+ mlir::Float8E3M4Type, mlir::Float8E4M3B11FNUZType, mlir::Float8E4M3Type, -+ mlir::Float8E4M3FNType, mlir::Float8E4M3FNUZType, mlir::Float8E5M2Type, -+ mlir::Float8E5M2FNUZType, mlir::Float8E8M0FNUType, mlir::Float16Type, -+ mlir::BFloat16Type, mlir::Float32Type, mlir::Float64Type>(type); - } - - bool isSupportedComplexType(Type type) { -diff --ruN a/stablehlo/stablehlo/tests/ops_stablehlo.mlir b/stablehlo/stablehlo/tests/ops_stablehlo.mlir ---- stablehlo/stablehlo/tests/ops_stablehlo.mlir -+++ stablehlo/stablehlo/tests/ops_stablehlo.mlir -@@ -1274,6 +1274,22 @@ - - // ----- - -+// CHECK-LABEL: func @broadcast_in_dim_dynamic_i1 -+func.func @broadcast_in_dim_dynamic_i1(%arg0: tensor) -> tensor<1x3xi32> { -+ %0 = stablehlo.broadcast_in_dim %arg0, dims = [1] : (tensor) -> tensor<1x3xi32> -+ return %0 : tensor<1x3xi32> -+} -+ -+// ----- -+ -+func.func @broadcast_in_dim_dynamic_result(%arg0: tensor<3xi32>) -> tensor { -+ // expected-error@+1 {{must be statically shaped or single bounded dimension tensor}} -+ %0 = "stablehlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = array} : (tensor<3xi32>) -> tensor -+ func.return %0 : tensor -+} -+ -+// ----- -+ - // Regression test for b/180052624, where this was improperly marked as an - // invalid stablehlo.broadcast_in_dim op. - // CHECK-LABEL: func @broadcast_in_dim_dynamic_shaped_operand -@@ -1775,6 +1791,30 @@ - // expected-error@+1 {{'precision_config' failed to satisfy constraint}} - %0 = "stablehlo.dot"(%arg0, %arg1) {precision_config = ["FOO", #stablehlo]} : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32> - func.return %0: tensor<2x2xi32> -+} -+ -+// ----- -+ -+// CHECK-LABEL: func @exponential_result_accuracy -+func.func @exponential_result_accuracy(%arg0: tensor) -> tensor { -+ %0 = "stablehlo.exponential"(%arg0) {result_accuracy = #stablehlo.result_accuracy>} : (tensor) -> tensor -+ func.return %0: tensor -+} -+ -+// ----- -+ -+// CHECK-LABEL: func @exponential_result_accuracy_tol -+func.func @exponential_result_accuracy_tol(%arg0: tensor) -> tensor { -+ %0 = "stablehlo.exponential"(%arg0) {result_accuracy = #stablehlo.result_accuracy>} : (tensor) -> tensor -+ func.return %0: tensor -+} -+ -+// ----- -+ -+func.func @exponential_result_accuracy_tol(%arg0: tensor) -> tensor { -+ // expected-error@+1 {{Invalid tolerances for ResultAccuracyAttr with mode HIGHEST, must be all zero.}} -+ %0 = "stablehlo.exponential"(%arg0) {result_accuracy = #stablehlo.result_accuracy>} : (tensor) -> tensor -+ func.return %0: tensor - } - - // ----- -diff --ruN a/stablehlo/stablehlo/tests/ops_stablehlo_bounded_dynamism.mlir b/stablehlo/stablehlo/tests/ops_stablehlo_bounded_dynamism.mlir ---- stablehlo/stablehlo/tests/ops_stablehlo_bounded_dynamism.mlir -+++ stablehlo/stablehlo/tests/ops_stablehlo_bounded_dynamism.mlir -@@ -0,0 +1,63 @@ -+// RUN: stablehlo-opt %s -verify-diagnostics -split-input-file -allow-unregistered-dialect | FileCheck %s -+ -+// This file captures some quirks to bounded dynamism in StableHLO that are -+// included to allow StableHLO to repersent existing TF programs. -+ -+// CHECK-LABEL: reshape_with_single_bounded_dimension -+func.func @reshape_with_single_bounded_dimension(%arg0: tensor>) -> tensor<2x?xf32, #stablehlo.bounds> { -+ %0 = stablehlo.reshape %arg0 : (tensor>) -> tensor<2x?xf32, #stablehlo.bounds> -+ // CHECK: return {{.*}} #stablehlo.bounds -+ return %0 : tensor<2x?xf32, #stablehlo.bounds> -+} -+ -+// ----- -+ -+// CHECK-LABEL: reshape_scalar_with_single_bounded_dimension -+func.func @reshape_scalar_with_single_bounded_dimension(%arg0: tensor>) -> tensor<1x?xf32, #stablehlo.bounds> { -+ %0 = stablehlo.reshape %arg0 : (tensor>) -> tensor<1x?xf32, #stablehlo.bounds> -+ // CHECK: return {{.*}} #stablehlo.bounds -+ return %0 : tensor<1x?xf32, #stablehlo.bounds> -+} -+ -+// ----- -+ -+func.func @reshape_with_multiple_bounded_dimensions(%arg0: tensor>) -> tensor> { -+ // expected-error@+1 {{result #0 must be statically shaped or single bounded dimension tensor}} -+ %0 = stablehlo.reshape %arg0 : (tensor>) -> tensor> -+ return %0 : tensor> -+} -+ -+// ----- -+ -+// CHECK-LABEL: broadcast_in_dim_with_single_bounded_dimension -+func.func @broadcast_in_dim_with_single_bounded_dimension(%arg0: tensor<1x?xf32, #stablehlo.bounds>) -> tensor<2x1x?xf32, #stablehlo.bounds> { -+ %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1] : (tensor<1x?xf32, #stablehlo.bounds>) -> tensor<2x1x?xf32, #stablehlo.bounds> -+ // CHECK: return {{.*}} #stablehlo.bounds -+ return %0 : tensor<2x1x?xf32, #stablehlo.bounds> -+} -+ -+// ----- -+ -+func.func @broadcast_in_dim_with_multiple_bounded_dimensions(%arg0: tensor>) -> tensor<2x?x?xf32, #stablehlo.bounds> { -+ // expected-error@+1 {{result #0 must be statically shaped or single bounded dimension tensor}} -+ %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1] : (tensor>) -> tensor<2x?x?xf32, #stablehlo.bounds> -+ return %0 : tensor<2x?x?xf32, #stablehlo.bounds> -+} -+ -+// ----- -+ -+// CHECK-LABEL: constant_splat_broadcast -+func.func @constant_splat_broadcast() -> tensor<1x?xf32, #stablehlo.bounds> { -+ %0 = stablehlo.constant dense<1.0> : tensor -+ %1 = stablehlo.broadcast_in_dim %0, dims = [] : (tensor) -> tensor<1x?xf32, #stablehlo.bounds> -+ // CHECK: tensor<1x?xf32, #stablehlo.bounds> -+ return %1 : tensor<1x?xf32, #stablehlo.bounds> -+} -+ -+// ----- -+ -+func.func @constant_with_dynamic_shape() -> tensor<1x?xf32, #stablehlo.bounds> { -+ // expected-error@+2 {{elements literal type must have static shape}} -+ %c = stablehlo.constant dense<1> : tensor<1x?xf32, #stablehlo.bounds> -+ return %c : tensor<1x?xf32, #stablehlo.bounds> -+} -diff --ruN a/stablehlo/stablehlo/tests/ops_stablehlo_roundtrip.mlir b/stablehlo/stablehlo/tests/ops_stablehlo_roundtrip.mlir ---- stablehlo/stablehlo/tests/ops_stablehlo_roundtrip.mlir -+++ stablehlo/stablehlo/tests/ops_stablehlo_roundtrip.mlir -@@ -766,6 +766,11 @@ - func.return %0 : tensor<3x4xf32> - } - -+func.func @test_unary_result_accuracy(%arg0: tensor<2xf32>) -> tensor<2xf32> { -+ %exp = "stablehlo.exponential"(%arg0) {result_accuracy = #stablehlo.result_accuracy>} : (tensor<2xf32>) -> tensor<2xf32> -+ func.return %exp : tensor<2xf32> -+} -+ - func.func @test_unary_round_nearest_even(%arg0: tensor<2xf32>) -> tensor<2xf32> { - %0 = "stablehlo.round_nearest_even"(%arg0) {} : (tensor<2xf32>) -> tensor<2xf32> - func.return %0 : tensor<2xf32> -diff --ruN a/stablehlo/stablehlo/tests/print_stablehlo.mlir b/stablehlo/stablehlo/tests/print_stablehlo.mlir ---- stablehlo/stablehlo/tests/print_stablehlo.mlir -+++ stablehlo/stablehlo/tests/print_stablehlo.mlir -@@ -406,3 +406,16 @@ - %slice6 = stablehlo.slice %arg0 [1:3:1, 4:8:2] : (tensor<3x8xf32>) -> tensor<2x2xf32> - return %slice1, %slice2, %slice3, %slice4, %slice5, %slice6 : tensor<1xf32>, tensor<2xf32>, tensor<1xf32>, tensor<1xf32>, tensor<2x2xf32>, tensor<2x2xf32> - } -+ -+func.func @result_accuracy_default() -> () attributes { -+ // CHECK: mode.default = #stablehlo.result_accuracy> -+ // CHECK: mode.highest = #stablehlo.result_accuracy> -+ // CHECK: mode.tolerance_full = #stablehlo.result_accuracy> -+ // CHECK: mode.tolerance_partial = #stablehlo.result_accuracy> -+ mode.default = #stablehlo.result_accuracy>, -+ mode.highest = #stablehlo.result_accuracy>, -+ mode.tolerance_full = #stablehlo.result_accuracy>, -+ mode.tolerance_partial = #stablehlo.result_accuracy> -+} { -+ func.return -+} -diff --ruN a/stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir b/stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir ---- stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir -+++ stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir -@@ -1940,6 +1940,17 @@ - return %1 : tensor<12xi64> - } - -+// ----- -+ -+// CHECK-LABEL: @reorder_invalid_with_dynamic_shape -+func.func @reorder_invalid_with_dynamic_shape(%arg0: tensor<1x3x4xf32>) -> (tensor) { -+ // CHECK: %[[RESHAPE:.+]] = stablehlo.reshape %arg0 : (tensor<1x3x4xf32>) -> tensor<3x4xf32> -+ // CHECK-NEXT: %[[CONVERT:.+]] = stablehlo.convert %[[RESHAPE]] : (tensor<3x4xf32>) -> tensor -+ // CHECK: return %[[CONVERT]] -+ %0 = stablehlo.reshape %arg0 : (tensor<1x3x4xf32>) -> tensor<3x4xf32> -+ %1 = stablehlo.convert %0 : (tensor<3x4xf32>) -> tensor -+ return %1 : tensor -+} - - // ----- - -diff --ruN a/stablehlo/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.1_9_0.mlir b/stablehlo/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.1_9_0.mlir ---- stablehlo/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.1_9_0.mlir -+++ stablehlo/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.1_9_0.mlir -@@ -0,0 +1,2966 @@ -+// RUN: stablehlo-opt --mlir-print-op-generic %s.bc | FileCheck %s -+// RUN: stablehlo-translate --deserialize %s.bc | stablehlo-translate --serialize --target=1.9.0 | stablehlo-opt --mlir-print-op-generic | FileCheck %s -+// RUN: stablehlo-translate --deserialize %s.bc | stablehlo-opt > %t.0 -+// RUN: stablehlo-opt --strip-debuginfo %s > %t.1 -+// RUN: diff %t.0 %t.1 -+// RUN: stablehlo-translate --serialize --target=1.9.0 --strip-debuginfo %s > %t.2 -+// RUN: diff %s.bc %t.2 -+// RUN: stablehlo-opt --stablehlo-legalize-to-vhlo -emit-bytecode -debug-only=vhlo-bytecode %s 2>&1 | FileCheck --check-prefix=CHECK-WARN %s -+// RUN: stablehlo-opt --stablehlo-legalize-to-vhlo -emit-bytecode %s | stablehlo-opt -debug-only=vhlo-bytecode 2>&1 | FileCheck --check-prefix=CHECK-WARN %s -+ -+// CHECK-WARN-NOT: Not Implemented -+ -+// ============ ATTRIBUTES ============ -+ -+// CHECK-LABEL: "attr_comparison_direction_eq" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) -+func.func @attr_comparison_direction_eq(%arg0: tensor, %arg1: tensor) -> tensor { -+ %0 = "stablehlo.compare"(%arg0, %arg1) { -+ // CHECK: comparison_direction = #vhlo -+ comparison_direction = #stablehlo -+ } : (tensor, tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "attr_comparison_direction_ne" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) -+func.func @attr_comparison_direction_ne(%arg0: tensor, %arg1: tensor) -> tensor { -+ %0 = "stablehlo.compare"(%arg0, %arg1) { -+ // CHECK: comparison_direction = #vhlo -+ comparison_direction = #stablehlo -+ } : (tensor, tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "attr_comparison_direction_ge" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) -+func.func @attr_comparison_direction_ge(%arg0: tensor, %arg1: tensor) -> tensor { -+ %0 = "stablehlo.compare"(%arg0, %arg1) { -+ // CHECK: comparison_direction = #vhlo -+ comparison_direction = #stablehlo -+ } : (tensor, tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "attr_comparison_direction_gt" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) -+func.func @attr_comparison_direction_gt(%arg0: tensor, %arg1: tensor) -> tensor { -+ %0 = "stablehlo.compare"(%arg0, %arg1) { -+ // CHECK: comparison_direction = #vhlo -+ comparison_direction = #stablehlo -+ } : (tensor, tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "attr_comparison_direction_le" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) -+func.func @attr_comparison_direction_le(%arg0: tensor, %arg1: tensor) -> tensor { -+ %0 = "stablehlo.compare"(%arg0, %arg1) { -+ // CHECK: comparison_direction = #vhlo -+ comparison_direction = #stablehlo -+ } : (tensor, tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "attr_comparison_direction_lt" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) -+func.func @attr_comparison_direction_lt(%arg0: tensor, %arg1: tensor) -> tensor { -+ %0 = "stablehlo.compare"(%arg0, %arg1) { -+ // CHECK: comparison_direction = #vhlo -+ comparison_direction = #stablehlo -+ } : (tensor, tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "attr_comparison_type_notype" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) -+func.func @attr_comparison_type_notype(%arg0: tensor, %arg1: tensor) -> tensor { -+ %0 = "stablehlo.compare"(%arg0, %arg1) { -+ comparison_direction = #stablehlo -+ // CHECK: compare_type = #vhlo -+ } : (tensor, tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "attr_comparison_type_float" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) -+func.func @attr_comparison_type_float(%arg0: tensor, %arg1: tensor) -> tensor { -+ %0 = "stablehlo.compare"(%arg0, %arg1) { -+ comparison_direction = #stablehlo, -+ // CHECK: compare_type = #vhlo, -+ compare_type = #stablehlo -+ } : (tensor, tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "attr_comparison_type_totalorder" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) -+func.func @attr_comparison_type_totalorder(%arg0: tensor, %arg1: tensor) -> tensor { -+ %0 = "stablehlo.compare"(%arg0, %arg1) { -+ comparison_direction = #stablehlo, -+ // CHECK: compare_type = #vhlo, -+ compare_type = #stablehlo -+ } : (tensor, tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "attr_comparison_type_signed" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) -+func.func @attr_comparison_type_signed(%arg0: tensor, %arg1: tensor) -> tensor { -+ %0 = "stablehlo.compare"(%arg0, %arg1) { -+ comparison_direction = #stablehlo, -+ // CHECK: compare_type = #vhlo, -+ compare_type = #stablehlo -+ } : (tensor, tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "attr_comparison_type_unsigned" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) -+func.func @attr_comparison_type_unsigned(%arg0: tensor, %arg1: tensor) -> tensor { -+ %0 = "stablehlo.compare"(%arg0, %arg1) { -+ comparison_direction = #stablehlo, -+ // CHECK: compare_type = #vhlo, -+ compare_type = #stablehlo -+ } : (tensor, tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// ConvDimensionNumbers aka #stablehlo.conv is covered below. -+ -+// CHECK-LABEL: "attr_custom_call_api_version_unspecified" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) -+func.func @attr_custom_call_api_version_unspecified(%arg0: tensor) -> tensor { -+ %0 = "stablehlo.custom_call"(%arg0) { -+ call_target_name = "foo", -+ // CHECK: api_version = #vhlo -+ api_version = 0 : i32 -+ } : (tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "attr_custom_call_api_version_original" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) -+func.func @attr_custom_call_api_version_original(%arg0: tensor) -> tensor { -+ %0 = "stablehlo.custom_call"(%arg0) { -+ call_target_name = "foo", -+ // CHECK: api_version = #vhlo -+ api_version = 1 : i32 -+ } : (tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "attr_custom_call_api_version_status_returning" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) -+func.func @attr_custom_call_api_version_status_returning(%arg0: tensor) -> tensor { -+ %0 = "stablehlo.custom_call"(%arg0) { -+ call_target_name = "foo", -+ // CHECK: api_version = #vhlo -+ api_version = 2 : i32 -+ } : (tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "attr_custom_call_api_version_status_returning_unified" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) -+func.func @attr_custom_call_api_version_status_returning_unified(%arg0: tensor) -> tensor { -+ %0 = "stablehlo.custom_call"(%arg0) { -+ call_target_name = "foo", -+ // CHECK: api_version = #vhlo -+ api_version = 3 : i32 -+ } : (tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "attr_dict" -+// CHECK: #vhlo.dict_v1<{#vhlo.string_v1<"attr1"> = #vhlo.integer_v1<1 : i32>, #vhlo.string_v1<"attr2"> = #vhlo.integer_v1<2 : i32>} -+func.func @attr_dict() attributes {stablehlo.attr = {attr1 = 1 : i32, attr2 = 2 : i32}} { -+ return -+} -+ -+// CHECK-LABEL: "attr_custom_call_api_version_typed_ffi" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) -+// CHECK: api_version = #vhlo -+// CHECK-SAME: backend_config = #vhlo.dict_v1<{#vhlo.string_v1<"bar"> = #vhlo.integer_v1<42 : i32>}> -+func.func @attr_custom_call_api_version_typed_ffi(%arg0: tensor) -> tensor { -+ %0 = "stablehlo.custom_call"(%arg0) { -+ call_target_name = "foo", -+ backend_config= {bar = 42 : i32}, -+ api_version = 4 : i32 -+ } : (tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+ -+// CHECK-LABEL: "attr_custom_call_api_version_typed_ffi_no_backend_config" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) -+// CHECK: api_version = #vhlo -+// CHECK-SAME: backend_config = #vhlo.dict_v1<{}> -+func.func @attr_custom_call_api_version_typed_ffi_no_backend_config(%arg0: tensor) -> tensor { -+ %0 = "stablehlo.custom_call"(%arg0) { -+ call_target_name = "foo", -+ api_version = 4 : i32 -+ } : (tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// DotDimensionNumbers aka #stablehlo.dot is covered below. -+ -+// CHECK-LABEL: "attr_fft_type_fft" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) -+func.func @attr_fft_type_fft(%arg0: tensor<16xcomplex>) -> tensor<16xcomplex> { -+ %0 = "stablehlo.fft"(%arg0) { -+ // CHECK: fft_type = #vhlo -+ fft_type = #stablehlo, -+ fft_length = array -+ } : (tensor<16xcomplex>) -> tensor<16xcomplex> -+ func.return %0 : tensor<16xcomplex> -+} -+ -+// CHECK-LABEL: "attr_fft_type_ifft" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) -+func.func @attr_fft_type_ifft(%arg0: tensor<16xcomplex>) -> tensor<16xcomplex> { -+ %0 = "stablehlo.fft"(%arg0) { -+ // CHECK: fft_type = #vhlo -+ fft_type = #stablehlo, -+ fft_length = array -+ } : (tensor<16xcomplex>) -> tensor<16xcomplex> -+ func.return %0 : tensor<16xcomplex> -+} -+ -+// CHECK-LABEL: "attr_fft_type_rfft" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) -+func.func @attr_fft_type_rfft(%arg0: tensor<16xf32>) -> tensor<9xcomplex> { -+ %0 = "stablehlo.fft"(%arg0) { -+ // CHECK: fft_type = #vhlo -+ fft_type = #stablehlo, -+ fft_length = array -+ } : (tensor<16xf32>) -> tensor<9xcomplex> -+ func.return %0 : tensor<9xcomplex> -+} -+ -+// CHECK-LABEL: "attr_fft_type_irfft" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) -+func.func @attr_fft_type_irfft(%arg0: tensor<9xcomplex>) -> tensor<16xf32> { -+ %0 = "stablehlo.fft"(%arg0) { -+ // CHECK: fft_type = #vhlo -+ fft_type = #stablehlo, -+ fft_length = array -+ } : (tensor<9xcomplex>) -> tensor<16xf32> -+ func.return %0 : tensor<16xf32> -+} -+ -+// CHECK-LABEL: "attr_result_accuracy_HIGHEST" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}} -+func.func @attr_result_accuracy_HIGHEST(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> { -+ %0 = "stablehlo.exponential"(%arg0) { -+ // CHECK: result_accuracy = #vhlo.result_accuracy_v1> -+ result_accuracy = #stablehlo.result_accuracy> -+ } : (tensor<8x16xf32>) -> tensor<8x16xf32> -+ func.return %0 : tensor<8x16xf32> -+} -+ -+// CHECK-LABEL: "attr_result_accuracy_TOLERANCE" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}} -+func.func @attr_result_accuracy_TOLERANCE(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> { -+ %0 = "stablehlo.exponential"(%arg0) { -+ // CHECK: result_accuracy = #vhlo.result_accuracy_v1> -+ result_accuracy = #stablehlo.result_accuracy> -+ } : (tensor<8x16xf32>) -> tensor<8x16xf32> -+ func.return %0 : tensor<8x16xf32> -+} -+ -+// CHECK-LABEL: "attr_result_accuracy_DEFAULT" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}} -+func.func @attr_result_accuracy_DEFAULT(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> { -+ %0 = "stablehlo.exponential"(%arg0) { -+ // CHECK: result_accuracy = #vhlo.result_accuracy_v1> -+ result_accuracy = #stablehlo.result_accuracy> -+ } : (tensor<8x16xf32>) -> tensor<8x16xf32> -+ func.return %0 : tensor<8x16xf32> -+} -+ -+// GatherDimensionNumbers aka #stablehlo.gather is covered below. -+ -+// CHECK-LABEL: "attr_precision_config_default" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) -+func.func @attr_precision_config_default(%arg0: tensor<8x16xf32>, %arg1: tensor<16x8xf32>) -> tensor<8x8xf32> { -+ %0 = "stablehlo.dot"(%arg0, %arg1) { -+ // CHECK: precision_config = #vhlo.array_v1<[#vhlo, #vhlo]> -+ } : (tensor<8x16xf32>, tensor<16x8xf32>) -> tensor<8x8xf32> -+ func.return %0 : tensor<8x8xf32> -+} -+ -+// CHECK-LABEL: "attr_precision_config_high" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) -+func.func @attr_precision_config_high(%arg0: tensor<8x16xf32>, %arg1: tensor<16x8xf32>) -> tensor<8x8xf32> { -+ %0 = "stablehlo.dot"(%arg0, %arg1) { -+ // CHECK: precision_config = #vhlo.array_v1<[#vhlo, #vhlo]> -+ precision_config = [#stablehlo, #stablehlo] -+ } : (tensor<8x16xf32>, tensor<16x8xf32>) -> tensor<8x8xf32> -+ func.return %0 : tensor<8x8xf32> -+} -+ -+// CHECK-LABEL: "attr_precision_config_highest" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) -+func.func @attr_precision_config_highest(%arg0: tensor<8x16xf32>, %arg1: tensor<16x8xf32>) -> tensor<8x8xf32> { -+ %0 = "stablehlo.dot"(%arg0, %arg1) { -+ // CHECK: precision_config = #vhlo.array_v1<[#vhlo, #vhlo]> -+ precision_config = [#stablehlo, #stablehlo] -+ } : (tensor<8x16xf32>, tensor<16x8xf32>) -> tensor<8x8xf32> -+ func.return %0 : tensor<8x8xf32> -+} -+ -+// CHECK-LABEL: "attr_rng_algorithm_default" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) -+func.func @attr_rng_algorithm_default(%arg0: tensor) -> (tensor, tensor) { -+ %0:2 = "stablehlo.rng_bit_generator"(%arg0) { -+ // CHECK: rng_algorithm = #vhlo -+ rng_algorithm = #stablehlo -+ } : (tensor) -> (tensor, tensor) -+ func.return %0#0, %0#1 : tensor, tensor -+} -+ -+// CHECK-LABEL: "attr_rng_algorithm_three_fry" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) -+func.func @attr_rng_algorithm_three_fry(%arg0: tensor) -> (tensor, tensor) { -+ %0:2 = "stablehlo.rng_bit_generator"(%arg0) { -+ // CHECK: rng_algorithm = #vhlo -+ rng_algorithm = #stablehlo -+ } : (tensor) -> (tensor, tensor) -+ func.return %0#0, %0#1 : tensor, tensor -+} -+ -+// CHECK-LABEL: "attr_rng_algorithm_philox" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) -+func.func @attr_rng_algorithm_philox(%arg0: tensor) -> (tensor, tensor) { -+ %0:2 = "stablehlo.rng_bit_generator"(%arg0) { -+ // CHECK: rng_algorithm = #vhlo -+ rng_algorithm = #stablehlo -+ } : (tensor) -> (tensor, tensor) -+ func.return %0#0, %0#1 : tensor, tensor -+} -+ -+// CHECK-LABEL: "attr_rng_distribution_uniform" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}) -+func.func @attr_rng_distribution_uniform(%arg0: tensor, %arg1: tensor, %arg2: tensor<0xindex>) -> tensor { -+ %0 = "stablehlo.rng"(%arg0, %arg1, %arg2) { -+ // CHECK: rng_distribution = #vhlo -+ rng_distribution = #stablehlo -+ } : (tensor, tensor, tensor<0xindex>) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "attr_rng_distribution_normal" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}) -+func.func @attr_rng_distribution_normal(%arg0: tensor, %arg1: tensor, %arg2: tensor<0xindex>) -> tensor { -+ %0 = "stablehlo.rng"(%arg0, %arg1, %arg2) { -+ // CHECK: rng_distribution = #vhlo -+ rng_distribution = #stablehlo -+ } : (tensor, tensor, tensor<0xindex>) -> tensor -+ func.return %0 : tensor -+} -+ -+// ScatterDimensionNumbers aka #stablehlo.scatter is covered below. -+ -+// CHECK-LABEL: "attr_transpose_no_transpose" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) -+func.func @attr_transpose_no_transpose(%arg0: tensor<16x16xf32>, %arg1: tensor<16x16xf32>) -> tensor<16x16xf32> { -+ %0 = "stablehlo.triangular_solve"(%arg0, %arg1) { -+ left_side = true, -+ lower = true, -+ unit_diagonal = true, -+ // transpose_a = #vhlo, -+ transpose_a = #stablehlo -+ } : (tensor<16x16xf32>, tensor<16x16xf32>) -> tensor<16x16xf32> -+ func.return %0 : tensor<16x16xf32> -+} -+ -+// CHECK-LABEL: "attr_transpose_transpose" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) -+func.func @attr_transpose_transpose(%arg0: tensor<16x16xf32>, %arg1: tensor<16x16xf32>) -> tensor<16x16xf32> { -+ %0 = "stablehlo.triangular_solve"(%arg0, %arg1) { -+ left_side = true, -+ lower = true, -+ unit_diagonal = true, -+ // transpose_a = #vhlo, -+ transpose_a = #stablehlo -+ } : (tensor<16x16xf32>, tensor<16x16xf32>) -> tensor<16x16xf32> -+ func.return %0 : tensor<16x16xf32> -+} -+ -+// CHECK-LABEL: "attr_transpose_adjoint" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) -+func.func @attr_transpose_adjoint(%arg0: tensor<16x16xf32>, %arg1: tensor<16x16xf32>) -> tensor<16x16xf32> { -+ %0 = "stablehlo.triangular_solve"(%arg0, %arg1) { -+ left_side = true, -+ lower = true, -+ unit_diagonal = true, -+ // transpose_a = #vhlo, -+ transpose_a = #stablehlo -+ } : (tensor<16x16xf32>, tensor<16x16xf32>) -> tensor<16x16xf32> -+ func.return %0 : tensor<16x16xf32> -+} -+ -+// TypeExtensionsAttr aka #stablehlo.type_extensions is covered below. -+ -+// CHECK-LABEL: "attr_type_extensions_bounds" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) -+func.func @attr_type_extensions_bounds(%arg0: tensor>) -> tensor> { -+ // CHECK: "vhlo.return_v1"(%[[ARG0]]) : (!vhlo.tensor_v1>) -> () -+ func.return %arg0 : tensor> -+} -+ -+ -+// ============ DEFAULTS ============ -+ -+// CHECK-LABEL: "default_all_gather" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) -+func.func @default_all_gather(%arg0: tensor<16x8xf32>) -> tensor<16x16xf32> { -+ // CHECK: "vhlo.all_gather_v2"(%[[ARG0]]) <{ -+ // CHECK-SAME: all_gather_dim = #vhlo.integer_v1<1 : i64> -+ // CHECK-SAME: channel_id = #vhlo.integer_v1<0 : i64>, -+ // CHECK-SAME{LITERAL}: replica_groups = #vhlo.tensor_v1 : tensor<2x1xi64>>, -+ // CHECK-SAME: use_global_device_ids = #vhlo.bool_v1 -+ // CHECK-SAME: }> : (!vhlo.tensor_v1<16x8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x16x!vhlo.f32_v1> -+ %0 = "stablehlo.all_gather"(%arg0) { -+ all_gather_dim = 1 : i64, -+ replica_groups = dense<[[0], [1]]> : tensor<2x1xi64> -+ } : (tensor<16x8xf32>) -> tensor<16x16xf32> -+ func.return %0 : tensor<16x16xf32> -+} -+ -+// CHECK-LABEL: "default_all_gather_variadic" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) -+func.func @default_all_gather_variadic(%arg0: tensor<16x8xf32>, %arg1: tensor<16x8xf32>) -> (tensor<16x16xf32>, tensor<16x16xf32>) { -+ %0:2 = "stablehlo.all_gather"(%arg0, %arg1) { -+ all_gather_dim = 1 : i64, -+ replica_groups = dense<[[0], [1]]> : tensor<2x1xi64> -+ } : (tensor<16x8xf32>, tensor<16x8xf32>) -> (tensor<16x16xf32>, tensor<16x16xf32>) -+ func.return %0#0, %0#1 : tensor<16x16xf32>, tensor<16x16xf32> -+} -+ -+// CHECK-LABEL: "default_all_reduce" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) -+func.func @default_all_reduce(%arg0: tensor) -> tensor { -+ // CHECK: "vhlo.all_reduce_v2"(%[[ARG0]]) -+ // CHECK-SAME: <{ -+ // CHECK-SAME: channel_id = #vhlo.integer_v1<0 : i64>, -+ // CHECK-SAME{LITERAL}: replica_groups = #vhlo.tensor_v1 : tensor<2x1xi64>>, -+ // CHECK-SAME: use_global_device_ids = #vhlo.bool_v1 -+ // CHECK-SAME: }> ({ -+ // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1, %[[ARG2:arg.*]]: !vhlo.tensor_v1): -+ // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.add_v1"(%[[ARG1]], %[[ARG2]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ // CHECK-NEXT: "vhlo.return_v1"(%[[VAL1]]) : (!vhlo.tensor_v1) -> () -+ // CHECK-NEXT: }) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ -+ %0 = "stablehlo.all_reduce"(%arg0) ({ -+ ^bb0(%arg1: tensor, %arg2: tensor): -+ %1 = "stablehlo.add"(%arg1, %arg2) : (tensor, tensor) -> tensor -+ "stablehlo.return"(%1) : (tensor) -> () -+ }) { -+ replica_groups = dense<[[0], [1]]> : tensor<2x1xi64> -+ } : (tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "default_all_to_all" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) -+func.func @default_all_to_all(%arg0: tensor<4x16xf32>) -> tensor<16x4xf32> { -+ // CHECK: "vhlo.all_to_all_v2"(%[[ARG0]]) <{ -+ // CHECK-SAME: channel_id = #vhlo.integer_v1<0 : i64>, -+ // CHECK-SAME: concat_dimension = #vhlo.integer_v1<0 : i64>, -+ // CHECK-SAME{LITERAL}: replica_groups = #vhlo.tensor_v1 : tensor<1x4xi64>>, -+ // CHECK-SAME: split_count = #vhlo.integer_v1<4 : i64> -+ // CHECK-SAME: split_dimension = #vhlo.integer_v1<1 : i64> -+ // CHECK-SAME: }> : (!vhlo.tensor_v1<4x16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x4x!vhlo.f32_v1> -+ %0 = "stablehlo.all_to_all"(%arg0) { -+ split_dimension = 1 : i64, -+ concat_dimension = 0 : i64, -+ split_count = 4 : i64, -+ replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64> -+ } : (tensor<4x16xf32>) -> tensor<16x4xf32> -+ func.return %0 : tensor<16x4xf32> -+} -+ -+// CHECK-LABEL: "default_all_to_all_variadic" -+func.func @default_all_to_all_variadic(%arg0: tensor<4x16xf32>, %arg1: tensor<5x16xf32>) -> (tensor<16x4xf32>, tensor<20x4xf32>) { -+ %0:2 = "stablehlo.all_to_all"(%arg0, %arg1) { -+ split_dimension = 1 : i64, -+ concat_dimension = 0 : i64, -+ split_count = 4 : i64, -+ replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>, -+ channel_handle = #stablehlo.channel_handle -+ } : (tensor<4x16xf32>, tensor<5x16xf32>) -> (tensor<16x4xf32>, tensor<20x4xf32>) -+ func.return %0#0, %0#1 : tensor<16x4xf32>, tensor<20x4xf32> -+} -+ -+// CHECK-LABEL: "default_cholesky" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) -+func.func @default_cholesky(%arg0: tensor<1x16x16xf32>) -> tensor<1x16x16xf32> { -+ // CHECK: "vhlo.cholesky_v1"(%[[ARG0]]) <{ -+ // CHECK-SAME: lower = #vhlo.bool_v1 -+ // CHECK-SAME: }> : (!vhlo.tensor_v1<1x16x16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<1x16x16x!vhlo.f32_v1> -+ %0 = "stablehlo.cholesky"(%arg0) : (tensor<1x16x16xf32>) -> tensor<1x16x16xf32> -+ func.return %0 : tensor<1x16x16xf32> -+} -+ -+// CHECK-LABEL: "default_collective_permute" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) -+func.func @default_collective_permute(%arg0: tensor<16x8xf32>) -> tensor<16x8xf32> { -+ // CHECK: "vhlo.collective_permute_v1"(%[[ARG0]]) <{ -+ // CHECK-SAME: channel_id = #vhlo.integer_v1<0 : i64>, -+ // CHECK-SAME{LITERAL}: source_target_pairs = #vhlo.tensor_v1 : tensor<3x2xi64>> -+ // CHECK-SAME: }> : (!vhlo.tensor_v1<16x8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x8x!vhlo.f32_v1> -+ %0 = "stablehlo.collective_permute"(%arg0) { -+ source_target_pairs = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64> -+ } : (tensor<16x8xf32>) -> tensor<16x8xf32> -+ func.return %0 : tensor<16x8xf32> -+} -+ -+// CHECK-LABEL: "default_collective_broadcast" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) -+func.func @default_collective_broadcast(%arg0: tensor<16x8xf32>) -> tensor<16x8xf32> { -+ // CHECK: "vhlo.collective_broadcast_v1"(%[[ARG0]]) <{ -+ // CHECK-SAME: channel_id = #vhlo.integer_v1<0 : i64>, -+ // CHECK-SAME{LITERAL}: replica_groups = #vhlo.tensor_v1 : tensor<1x2xi64>> -+ // CHECK-SAME: }> : (!vhlo.tensor_v1<16x8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x8x!vhlo.f32_v1> -+ %0 = "stablehlo.collective_broadcast"(%arg0) { -+ replica_groups = dense<[[0, 1]]> : tensor<1x2xi64> -+ } : (tensor<16x8xf32>) -> tensor<16x8xf32> -+ func.return %0 : tensor<16x8xf32> -+} -+ -+// CHECK-LABEL: "default_compare" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) -+func.func @default_compare(%arg0: tensor, %arg1: tensor) -> tensor { -+ // CHECK: "vhlo.compare_v1"(%[[ARG0]], %[[ARG1]]) <{ -+ // CHECK-SAME: compare_type = #vhlo, -+ // CHECK-SAME: comparison_direction = #vhlo -+ // CHECK-SAME: }> : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ %0 = "stablehlo.compare"(%arg0, %arg1) { -+ comparison_direction = #stablehlo -+ } : (tensor, tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "default_composite" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) -+func.func @default_composite(%arg0: tensor) -> tensor { -+ // CHECK: "vhlo.composite_v1"(%[[ARG0]]) <{ -+ // CHECK-SAME: composite_attributes = #vhlo.dict_v1<{}> -+ // CHECK-SAME: decomposition = #vhlo.string_v1<"composite_target"> -+ // CHECK-SAME: name = #vhlo.string_v1<"stablehlo.composite_target"> -+ // CHECK-SAME: version = #vhlo.integer_v1<0 : i64> -+ // CHECK-SAME: }> : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ %0 = "stablehlo.composite"(%arg0) { -+ name = "stablehlo.composite_target", -+ decomposition = @composite_target -+ } : (tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "default_convolution" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) -+func.func @default_convolution(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<1x6x6x16xf32> { -+ // CHECK: "vhlo.convolution_v1"(%[[ARG0]], %[[ARG1]]) <{ -+ // CHECK-SAME: batch_group_count = #vhlo.integer_v1<1 : i64>, -+ // CHECK-SAME: feature_group_count = #vhlo.integer_v1<1 : i64>, -+ // CHECK-SAME: input_batch_dimension = #vhlo.integer_v1<0 : i64>, -+ // CHECK-SAME: input_feature_dimension = #vhlo.integer_v1<3 : i64>, -+ // CHECK-SAME: input_spatial_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, -+ // CHECK-SAME: kernel_input_feature_dimension = #vhlo.integer_v1<2 : i64>, -+ // CHECK-SAME: kernel_output_feature_dimension = #vhlo.integer_v1<3 : i64>, -+ // CHECK-SAME: kernel_spatial_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, -+ // CHECK-SAME: lhs_dilation = #vhlo.tensor_v1 : tensor<2xi64>>, -+ // CHECK-SAME: output_batch_dimension = #vhlo.integer_v1<0 : i64>, -+ // CHECK-SAME: output_feature_dimension = #vhlo.integer_v1<3 : i64>, -+ // CHECK-SAME: output_spatial_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, -+ // CHECK-SAME: padding = #vhlo.tensor_v1 : tensor<2x2xi64>>, -+ // CHECK-SAME: precision_config = #vhlo.array_v1<[#vhlo, #vhlo]>, -+ // CHECK-SAME: rhs_dilation = #vhlo.tensor_v1 : tensor<2xi64>>, -+ // CHECK-SAME: window_reversal = #vhlo.tensor_v1 : tensor<2xi1>>, -+ // CHECK-SAME: window_strides = #vhlo.tensor_v1 : tensor<2xi64>> -+ // CHECK-SAME: }> : (!vhlo.tensor_v1<1x8x8x207x!vhlo.f32_v1>, !vhlo.tensor_v1<3x3x207x16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<1x6x6x16x!vhlo.f32_v1> -+ %0 = "stablehlo.convolution"(%arg0, %arg1) { -+ dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>, -+ feature_group_count = 1 : i64, -+ batch_group_count = 1 : i64 -+ } : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x6x6x16xf32> -+ func.return %0 : tensor<1x6x6x16xf32> -+} -+ -+// CHECK-LABEL: "default_custom_call" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) -+func.func @default_custom_call(%arg0: tensor) -> tensor { -+ // CHECK: "vhlo.custom_call_v1"(%[[ARG0]]) <{ -+ // CHECK-SAME: api_version = #vhlo, -+ // CHECK-SAME: backend_config = #vhlo.string_v1<"">, -+ // CHECK-SAME: call_target_name = #vhlo.string_v1<"foo">, -+ // CHECK-SAME: called_computations = #vhlo.array_v1<[]>, -+ // CHECK-SAME: has_side_effect = #vhlo.bool_v1, -+ // CHECK-SAME: operand_layouts = #vhlo.array_v1<[]>, -+ // CHECK-SAME: output_operand_aliases = #vhlo.array_v1<[]> -+ // CHECK-SAME: result_layouts = #vhlo.array_v1<[]> -+ // CHECK-SAME: }> : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ %0 = "stablehlo.custom_call"(%arg0) { -+ call_target_name = "foo" -+ } : (tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "default_dot_general" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) -+func.func @default_dot_general(%arg0: tensor<8x8x16xf32>, %arg1: tensor<8x16x8xf32>) -> tensor<8x8x8xf32> { -+ // CHECK: "vhlo.dot_general_v2"(%[[ARG0]], %[[ARG1]]) <{ -+ // CHECK-SAME: accumulation_type = #vhlo.type_v1, -+ // CHECK-SAME: allow_imprecise_accumulation = #vhlo.type_v1, -+ // CHECK-SAME: lhs_batching_dimensions = #vhlo.tensor_v1 : tensor<1xi64>>, -+ // CHECK-SAME: lhs_component_count = #vhlo.type_v1, -+ // CHECK-SAME: lhs_contracting_dimensions = #vhlo.tensor_v1 : tensor<1xi64>>, -+ // CHECK-SAME: lhs_precision_type = #vhlo.type_v1, -+ // CHECK-SAME: num_primitive_operations = #vhlo.type_v1, -+ // CHECK-SAME: precision_config = #vhlo.array_v1<[#vhlo, #vhlo]>, -+ // CHECK-SAME: rhs_batching_dimensions = #vhlo.tensor_v1 : tensor<1xi64>>, -+ // CHECK-SAME: rhs_component_count = #vhlo.type_v1, -+ // CHECK-SAME: rhs_contracting_dimensions = #vhlo.tensor_v1 : tensor<1xi64>>, -+ // CHECK-SAME: rhs_precision_type = #vhlo.type_v1 -+ // CHECK-SAME: }> : (!vhlo.tensor_v1<8x8x16x!vhlo.f32_v1>, !vhlo.tensor_v1<8x16x8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<8x8x8x!vhlo.f32_v1> -+ %0 = "stablehlo.dot_general"(%arg0, %arg1) { -+ dot_dimension_numbers = #stablehlo.dot< -+ lhs_batching_dimensions = [0], -+ lhs_contracting_dimensions = [2], -+ rhs_batching_dimensions = [0], -+ rhs_contracting_dimensions = [1] -+ > -+ } : (tensor<8x8x16xf32>, tensor<8x16x8xf32>) -> tensor<8x8x8xf32> -+ func.return %0 : tensor<8x8x8xf32> -+} -+ -+// CHECK-LABEL: "dot_general_algorithm" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) -+func.func @dot_general_algorithm(%arg0: tensor<8x8x16xf32>, %arg1: tensor<8x16x8xf32>) -> tensor<8x8x8xf32> { -+// CHECK: "vhlo.dot_general_v2"(%[[ARG0]], %[[ARG1]]) <{ -+// CHECK-SAME: accumulation_type = #vhlo.type_v1, -+// CHECK-SAME: allow_imprecise_accumulation = #vhlo.bool_v1, -+// CHECK-SAME: lhs_batching_dimensions = #vhlo.tensor_v1 : tensor<1xi64>>, -+// CHECK-SAME: lhs_component_count = #vhlo.integer_v1<1 : i64>, -+// CHECK-SAME: lhs_contracting_dimensions = #vhlo.tensor_v1 : tensor<1xi64>>, -+// CHECK-SAME: lhs_precision_type = #vhlo.type_v1, -+// CHECK-SAME: num_primitive_operations = #vhlo.integer_v1<1 : i64>, -+// CHECK-SAME: precision_config = #vhlo.array_v1<[#vhlo, #vhlo]>, -+// CHECK-SAME: rhs_batching_dimensions = #vhlo.tensor_v1 : tensor<1xi64>>, -+// CHECK-SAME: rhs_component_count = #vhlo.integer_v1<1 : i64>, -+// CHECK-SAME: rhs_contracting_dimensions = #vhlo.tensor_v1 : tensor<1xi64>>, -+// CHECK-SAME: rhs_precision_type = #vhlo.type_v1 -+// CHECK-SAME: }> : (!vhlo.tensor_v1<8x8x16x!vhlo.f32_v1>, !vhlo.tensor_v1<8x16x8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<8x8x8x!vhlo.f32_v1> -+ %0 = "stablehlo.dot_general"(%arg0, %arg1) { -+ dot_dimension_numbers = #stablehlo.dot< -+ lhs_batching_dimensions = [0], -+ lhs_contracting_dimensions = [2], -+ rhs_batching_dimensions = [0], -+ rhs_contracting_dimensions = [1] -+ >, -+ algorithm = #stablehlo.dot_algorithm< -+ lhs_precision_type = tf32, -+ rhs_precision_type = tf32, -+ accumulation_type = f32, -+ lhs_component_count = 1, -+ rhs_component_count = 1, -+ num_primitive_operations = 1, -+ allow_imprecise_accumulation = false -+ > -+ } : (tensor<8x8x16xf32>, tensor<8x16x8xf32>) -> tensor<8x8x8xf32> -+ func.return %0 : tensor<8x8x8xf32> -+} -+ -+// CHECK-LABEL: "default_dynamic_broadcast_in_dim" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) -+func.func @default_dynamic_broadcast_in_dim(%arg0: tensor, %arg1: tensor<2xindex>) -> tensor { -+ // CHECK: "vhlo.dynamic_broadcast_in_dim_v1"(%[[ARG0]], %[[ARG1]]) <{ -+ // CHECK-SAME: broadcast_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, -+ // CHECK-SAME: known_expanding_dimensions = #vhlo.tensor_v1 : tensor<0xi64>>, -+ // CHECK-SAME: known_nonexpanding_dimensions = #vhlo.tensor_v1 : tensor<0xi64>> -+ // CHECK-SAME: }> : (!vhlo.tensor_v1, !vhlo.tensor_v1<2x!vhlo.index_v1>) -> !vhlo.tensor_v1 -+ %0 = "stablehlo.dynamic_broadcast_in_dim"(%arg0, %arg1) { -+ broadcast_dimensions = array -+ } : (tensor, tensor<2xindex>) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "default_dynamic_conv" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}) -+func.func @default_dynamic_conv(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>, %arg2: tensor<2x2xi64>) -> tensor<1x?x?x16xf32> { -+ // CHECK: "vhlo.dynamic_conv_v2"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) <{ -+ // CHECK-SAME: batch_group_count = #vhlo.integer_v1<1 : i64>, -+ // CHECK-SAME: feature_group_count = #vhlo.integer_v1<1 : i64>, -+ // CHECK-SAME: input_batch_dimension = #vhlo.integer_v1<0 : i64>, -+ // CHECK-SAME: input_feature_dimension = #vhlo.integer_v1<3 : i64>, -+ // CHECK-SAME: input_spatial_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, -+ // CHECK-SAME: kernel_input_feature_dimension = #vhlo.integer_v1<2 : i64>, -+ // CHECK-SAME: kernel_output_feature_dimension = #vhlo.integer_v1<3 : i64>, -+ // CHECK-SAME: kernel_spatial_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, -+ // CHECK-SAME: lhs_dilation = #vhlo.tensor_v1 : tensor<2xi64>>, -+ // CHECK-SAME: output_batch_dimension = #vhlo.integer_v1<0 : i64>, -+ // CHECK-SAME: output_feature_dimension = #vhlo.integer_v1<3 : i64>, -+ // CHECK-SAME: output_spatial_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, -+ // CHECK-SAME: precision_config = #vhlo.array_v1<[#vhlo, #vhlo]>, -+ // CHECK-SAME: rhs_dilation = #vhlo.tensor_v1 : tensor<2xi64>>, -+ // CHECK-SAME: window_reversal = #vhlo.tensor_v1 : tensor<2xi1>>, -+ // CHECK-SAME: window_strides = #vhlo.tensor_v1 : tensor<2xi64>> -+ // CHECK-SAME: }> : (!vhlo.tensor_v1<1x8x8x207x!vhlo.f32_v1>, !vhlo.tensor_v1<3x3x207x16x!vhlo.f32_v1>, !vhlo.tensor_v1<2x2x!vhlo.i64_v1>) -> !vhlo.tensor_v1<1x?x?x16x!vhlo.f32_v1> -+ %0 = "stablehlo.dynamic_conv"(%arg0, %arg1, %arg2) { -+ dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>, -+ feature_group_count = 1 : i64, -+ batch_group_count = 1 : i64 -+ } : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>, tensor<2x2xi64>) -> tensor<1x?x?x16xf32> -+ func.return %0 : tensor<1x?x?x16xf32> -+} -+ -+// CHECK-LABEL: "default_dynamic_gather" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}) -+func.func @default_dynamic_gather(%arg0 : tensor<2x4x9xf32>, %arg1 : tensor<1x5x2xi32>, %arg2 : tensor<3xi32>) -> tensor<1x5x8xf32> { -+ // CHECK: "vhlo.dynamic_gather_v2"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) <{ -+ // CHECK-SAME: collapsed_slice_dims = #vhlo.tensor_v1 : tensor<2xi64>>, -+ // CHECK-SAME: index_vector_dim = #vhlo.integer_v1<2 : i64>, -+ // CHECK-SAME: indices_are_sorted = #vhlo.bool_v1, -+ // CHECK-SAME: offset_dims = #vhlo.tensor_v1 : tensor<1xi64>>, -+ // CHECK-SAME: operand_batching_dims = #vhlo.tensor_v1 : tensor<0xi64>>, -+ // CHECK-SAME: start_index_map = #vhlo.tensor_v1 : tensor<2xi64>>, -+ // CHECK-SAME: start_indices_batching_dims = #vhlo.tensor_v1 : tensor<0xi64>> -+ // CHECK-SAME: }> : (!vhlo.tensor_v1<2x4x9x!vhlo.f32_v1>, !vhlo.tensor_v1<1x5x2x!vhlo.i32_v1>, !vhlo.tensor_v1<3x!vhlo.i32_v1>) -> !vhlo.tensor_v1<1x5x8x!vhlo.f32_v1> -+ %0 = "stablehlo.dynamic_gather"(%arg0, %arg1, %arg2) { -+ dimension_numbers = #stablehlo.gather< -+ offset_dims = [2], -+ collapsed_slice_dims = [0, 1], -+ start_index_map = [0, 1], -+ index_vector_dim = 2 -+ > -+ } : (tensor<2x4x9xf32>, tensor<1x5x2xi32>, tensor<3xi32>) -> tensor<1x5x8xf32> -+ func.return %0 : tensor<1x5x8xf32> -+} -+ -+func.func @default_func(%arg0: tensor) -> tensor { -+ // CHECK: "vhlo.func_v1"() <{ -+ // CHECK-SAME: arg_attrs = #vhlo.array_v1<[]>, -+ // CHECK-SAME: function_type = #vhlo.type_v1) -> !vhlo.tensor_v1>>, -+ // CHECK-SAME: res_attrs = #vhlo.array_v1<[]>, -+ // CHECK-SAME: sym_name = #vhlo.string_v1<"default_func">, -+ // CHECK-SAME: sym_visibility = #vhlo.string_v1<""> -+ // CHECK-SAME: }> ({ -+ // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG0:.*]]: !vhlo.tensor_v1): -+ // CHECK-NEXT: "vhlo.return_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> () -+ // CHECK-NEXT: }) : () -> () -+ func.return %arg0 : tensor -+} -+ -+// CHECK-LABEL: "default_gather" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) -+func.func @default_gather(%arg0 : tensor<2x4x9xf32>, %arg1 : tensor<1x5x2xi32>) -> tensor<1x5x1xf32> { -+ // CHECK: "vhlo.gather_v2"(%[[ARG0]], %[[ARG1]]) <{ -+ // CHECK-SAME: collapsed_slice_dims = #vhlo.tensor_v1 : tensor<2xi64>>, -+ // CHECK-SAME: index_vector_dim = #vhlo.integer_v1<2 : i64>, -+ // CHECK-SAME: indices_are_sorted = #vhlo.bool_v1, -+ // CHECK-SAME: offset_dims = #vhlo.tensor_v1 : tensor<1xi64>>, -+ // CHECK-SAME: operand_batching_dims = #vhlo.tensor_v1 : tensor<0xi64>>, -+ // CHECK-SAME: slice_sizes = #vhlo.tensor_v1 : tensor<3xi64>>, -+ // CHECK-SAME: start_index_map = #vhlo.tensor_v1 : tensor<2xi64>>, -+ // CHECK-SAME: start_indices_batching_dims = #vhlo.tensor_v1 : tensor<0xi64>> -+ // CHECK-SAME: }> : (!vhlo.tensor_v1<2x4x9x!vhlo.f32_v1>, !vhlo.tensor_v1<1x5x2x!vhlo.i32_v1>) -> !vhlo.tensor_v1<1x5x1x!vhlo.f32_v1> -+ %0 = "stablehlo.gather"(%arg0, %arg1) { -+ dimension_numbers = #stablehlo.gather< -+ offset_dims = [2], -+ collapsed_slice_dims = [0, 1], -+ start_index_map = [0, 1], -+ index_vector_dim = 2 -+ >, -+ slice_sizes = array -+ } : (tensor<2x4x9xf32>, tensor<1x5x2xi32>) -> tensor<1x5x1xf32> -+ func.return %0 : tensor<1x5x1xf32> -+} -+ -+// CHECK-LABEL: "default_infeed" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) -+func.func @default_infeed(%arg0: !stablehlo.token) -> (tensor, !stablehlo.token) { -+ // CHECK: "vhlo.infeed_v1"(%[[ARG0]]) <{ -+ // CHECK-SAME: infeed_config = #vhlo.string_v1<"">, -+ // CHECK-SAME{LITERAL}: layout = #vhlo.array_v1<[]> -+ // CHECK-SAME: }> : (!vhlo.token_v1) -> (!vhlo.tensor_v1, !vhlo.token_v1) -+ %0:2 = "stablehlo.infeed"(%arg0) : (!stablehlo.token) -> (tensor, !stablehlo.token) -+ func.return %0#0, %0#1 : tensor, !stablehlo.token -+} -+ -+// CHECK-LABEL: "default_outfeed" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) -+func.func @default_outfeed(%arg0: tensor, %arg1: !stablehlo.token) -> !stablehlo.token { -+ // CHECK: "vhlo.outfeed_v1"(%[[ARG0]], %[[ARG1]]) <{ -+ // CHECK-SAME: outfeed_config = #vhlo.string_v1<""> -+ // CHECK-SAME: }> : (!vhlo.tensor_v1, !vhlo.token_v1) -> !vhlo.token_v1 -+ %0 = "stablehlo.outfeed"(%arg0, %arg1) : (tensor, !stablehlo.token) -> !stablehlo.token -+ func.return %0 : !stablehlo.token -+} -+ -+// CHECK-LABEL: "default_recv" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) -+func.func @default_recv(%arg0: !stablehlo.token) -> (tensor, !stablehlo.token) { -+ // CHECK: "vhlo.recv_v1"(%[[ARG0]]) <{ -+ // CHECK-SAME: channel_id = #vhlo.integer_v1<0 : i64>, -+ // CHECK-SAME: channel_type = #vhlo.integer_v1<1 : i64>, -+ // CHECK-SAME: is_host_transfer = #vhlo.bool_v1 -+ // CHECK-SAME: }> : (!vhlo.token_v1) -> (!vhlo.tensor_v1, !vhlo.token_v1) -+ %0:2 = "stablehlo.recv"(%arg0) { -+ channel_handle = #stablehlo.channel_handle -+ } : (!stablehlo.token) -> (tensor, !stablehlo.token) -+ func.return %0#0, %0#1 : tensor, !stablehlo.token -+} -+ -+// CHECK-LABEL: "default_send" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) -+func.func @default_send(%arg0: tensor, %arg1: !stablehlo.token) -> !stablehlo.token { -+ // CHECK: "vhlo.send_v1"(%[[ARG0]], %[[ARG1]]) <{ -+ // CHECK-SAME: channel_id = #vhlo.integer_v1<0 : i64>, -+ // CHECK-SAME: channel_type = #vhlo.integer_v1<1 : i64>, -+ // CHECK-SAME: is_host_transfer = #vhlo.bool_v1 -+ // CHECK-SAME: }> : (!vhlo.tensor_v1, !vhlo.token_v1) -> !vhlo.token_v1 -+ %0 = "stablehlo.send"(%arg0, %arg1) { -+ channel_handle = #stablehlo.channel_handle -+ } : (tensor, !stablehlo.token) -> !stablehlo.token -+ func.return %0 : !stablehlo.token -+} -+ -+// CHECK-LABEL: "default_reduce_scatter" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) -+func.func @default_reduce_scatter(%arg0: tensor<16xf32>) -> tensor<16xf32> { -+ // CHECK: "vhlo.reduce_scatter_v1"(%[[ARG0]]) <{ -+ // CHECK-SAME: channel_id = #vhlo.integer_v1<0 : i64>, -+ // CHECK-SAME{LITERAL}: replica_groups = #vhlo.tensor_v1 : tensor<2x1xi64>>, -+ // CHECK-SAME: scatter_dimension = #vhlo.integer_v1<0 : i64> -+ // CHECK-SAME: use_global_device_ids = #vhlo.bool_v1 -+ // CHECK-SAME: }> ({ -+ // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1, %[[ARG2:arg.*]]: !vhlo.tensor_v1): -+ // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.add_v1"(%[[ARG1]], %[[ARG2]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ // CHECK-NEXT: "vhlo.return_v1"(%[[VAL1]]) : (!vhlo.tensor_v1) -> () -+ // CHECK-NEXT: }) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x!vhlo.f32_v1> -+ %0 = "stablehlo.reduce_scatter"(%arg0) ({ -+ ^bb0(%arg1: tensor, %arg2: tensor): -+ %1 = "stablehlo.add"(%arg1, %arg2) : (tensor, tensor) -> tensor -+ "stablehlo.return"(%1) : (tensor) -> () -+ }) { -+ scatter_dimension = 0 : i64, -+ replica_groups = dense<[[0], [1]]> : tensor<2x1xi64> -+ } : (tensor<16xf32>) -> tensor<16xf32> -+ func.return %0 : tensor<16xf32> -+} -+ -+// CHECK-LABEL: "default_reduce_window" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) -+func.func @default_reduce_window(%arg0: tensor<2x17x31x7xf32>, %arg1: tensor) -> tensor<2x16x30x7xf32> { -+ // CHECK: "vhlo.reduce_window_v1"(%[[ARG0]], %[[ARG1]]) <{ -+ // CHECK-SAME: base_dilations = #vhlo.tensor_v1 : tensor<4xi64>>, -+ // CHECK-SAME{LITERAL}: padding = #vhlo.tensor_v1 : tensor<4x2xi64>>, -+ // CHECK-SAME: window_dilations = #vhlo.tensor_v1 : tensor<4xi64>>, -+ // CHECK-SAME: window_dimensions = #vhlo.tensor_v1 : tensor<4xi64>>, -+ // CHECK-SAME: window_strides = #vhlo.tensor_v1 : tensor<4xi64>> -+ // CHECK-SAME: }> ({ -+ // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG2:arg.*]]: !vhlo.tensor_v1, %[[ARG3:arg.*]]: !vhlo.tensor_v1): -+ // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.maximum_v1"(%[[ARG2]], %[[ARG3]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ // CHECK-NEXT: "vhlo.return_v1"(%[[VAL1]]) : (!vhlo.tensor_v1) -> () -+ // CHECK-NEXT: }) : (!vhlo.tensor_v1<2x17x31x7x!vhlo.f32_v1>, !vhlo.tensor_v1) -> !vhlo.tensor_v1<2x16x30x7x!vhlo.f32_v1> -+ %0 = "stablehlo.reduce_window"(%arg0, %arg1) ({ -+ ^bb0(%arg2: tensor, %arg3: tensor): -+ %1 = "stablehlo.maximum"(%arg2, %arg3) : (tensor, tensor) -> tensor -+ "stablehlo.return"(%1) : (tensor) -> () -+ }) { -+ window_dimensions = array -+ } : (tensor<2x17x31x7xf32>, tensor) -> tensor<2x16x30x7xf32> -+ func.return %0 : tensor<2x16x30x7xf32> -+} -+ -+// CHECK-LABEL: "default_scatter" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}) -+func.func @default_scatter(%arg0: tensor<200x100x300xf32>, %arg1: tensor<10x2xi32>, %arg2: tensor<10x300xf32>) -> tensor<200x100x300xf32> { -+ // CHECK: "vhlo.scatter_v2"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) <{ -+ // CHECK-SAME: index_vector_dim = #vhlo.integer_v1<1 : i64>, -+ // CHECK-SAME: indices_are_sorted = #vhlo.bool_v1, -+ // CHECK-SAME: input_batching_dims = #vhlo.tensor_v1 : tensor<0xi64>>, -+ // CHECK-SAME: inserted_window_dims = #vhlo.tensor_v1 : tensor<2xi64>>, -+ // CHECK-SAME: scatter_dims_to_operand_dims = #vhlo.tensor_v1 : tensor<2xi64>>, -+ // CHECK-SAME: scatter_indices_batching_dims = #vhlo.tensor_v1 : tensor<0xi64>>, -+ // CHECK-SAME: unique_indices = #vhlo.bool_v1, -+ // CHECK-SAME: update_window_dims = #vhlo.tensor_v1 : tensor<1xi64>> -+ // CHECK-SAME: }> ({ -+ // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG3:arg.*]]: !vhlo.tensor_v1, %[[ARG4:arg.*]]: !vhlo.tensor_v1): -+ // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.add_v1"(%[[ARG3]], %[[ARG4]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ // CHECK-NEXT: "vhlo.return_v1"(%[[VAL1]]) : (!vhlo.tensor_v1) -> () -+ // CHECK-NEXT: }) : (!vhlo.tensor_v1<200x100x300x!vhlo.f32_v1>, !vhlo.tensor_v1<10x2x!vhlo.i32_v1>, !vhlo.tensor_v1<10x300x!vhlo.f32_v1>) -> !vhlo.tensor_v1<200x100x300x!vhlo.f32_v1> -+ %0 = "stablehlo.scatter"(%arg0, %arg1, %arg2) ({ -+ ^bb0(%arg3: tensor, %arg4: tensor): -+ %1 = "stablehlo.add"(%arg3, %arg4) : (tensor, tensor) -> tensor -+ "stablehlo.return"(%1) : (tensor) -> () -+ }) { -+ scatter_dimension_numbers = #stablehlo.scatter< -+ update_window_dims = [1], -+ inserted_window_dims = [0, 1], -+ scatter_dims_to_operand_dims = [0, 1], -+ index_vector_dim = 1 -+ > -+ } : (tensor<200x100x300xf32>, tensor<10x2xi32>, tensor<10x300xf32>) -> tensor<200x100x300xf32> -+ func.return %0 : tensor<200x100x300xf32> -+} -+ -+// CHECK-LABEL: "default_select_and_scatter" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}) -+func.func @default_select_and_scatter(%arg0: tensor<10x24x24x64xf32>, %arg1: tensor<10x23x23x64xf32>, %arg2: tensor) -> tensor<10x24x24x64xf32> { -+ // CHECK: "vhlo.select_and_scatter_v1"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) <{ -+ // CHECK-SAME: padding = #vhlo.tensor_v1 : tensor<4x2xi64>>, -+ // CHECK-SAME: window_dimensions = #vhlo.tensor_v1 : tensor<4xi64>>, -+ // CHECK-SAME: window_strides = #vhlo.tensor_v1 : tensor<4xi64>> -+ // CHECK-SAME: }> ({ -+ // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG31:arg.*]]: !vhlo.tensor_v1, %[[ARG41:arg.*]]: !vhlo.tensor_v1): -+ // CHECK-NEXT: %[[VAL11:.*]] = "vhlo.compare_v1"(%[[ARG31]], %[[ARG41]]) <{compare_type = #vhlo, comparison_direction = #vhlo}> -+ // CHECK-NEXT: "vhlo.return_v1"(%[[VAL11]]) : (!vhlo.tensor_v1) -> () -+ // CHECK-NEXT: }, { -+ // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG32:arg.*]]: !vhlo.tensor_v1, %[[ARG42:arg.*]]: !vhlo.tensor_v1): -+ // CHECK-NEXT: %[[VAL12:.*]] = "vhlo.add_v1"(%[[ARG32]], %[[ARG42]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ // CHECK-NEXT: "vhlo.return_v1"(%[[VAL12]]) : (!vhlo.tensor_v1) -> () -+ // CHECK-NEXT: }) : (!vhlo.tensor_v1<10x24x24x64x!vhlo.f32_v1>, !vhlo.tensor_v1<10x23x23x64x!vhlo.f32_v1>, !vhlo.tensor_v1) -> !vhlo.tensor_v1<10x24x24x64x!vhlo.f32_v1> -+ %0 = "stablehlo.select_and_scatter"(%arg0, %arg1, %arg2) ({ -+ ^bb0(%arg3: tensor, %arg4: tensor): -+ %1 = "stablehlo.compare"(%arg3, %arg4) {compare_type = #stablehlo, comparison_direction = #stablehlo} : (tensor, tensor) -> tensor -+ "stablehlo.return"(%1) : (tensor) -> () -+ }, { -+ ^bb0(%arg3: tensor, %arg4: tensor): -+ %1 = "stablehlo.add"(%arg3, %arg4) : (tensor, tensor) -> tensor -+ "stablehlo.return"(%1) : (tensor) -> () -+ }) { -+ window_dimensions = array -+ } : (tensor<10x24x24x64xf32>, tensor<10x23x23x64xf32>, tensor) -> tensor<10x24x24x64xf32> -+ func.return %0 : tensor<10x24x24x64xf32> -+} -+ -+// CHECK-LABEL: "default_sort" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) -+func.func @default_sort(%arg0: tensor<16xf32>) -> tensor<16xf32> { -+ // CHECK: "vhlo.sort_v1"(%[[ARG0]]) <{ -+ // CHECK-SAME: dimension = #vhlo.integer_v1<-1 : i64> -+ // CHECK-SAME: is_stable = #vhlo.bool_v1 -+ // CHECK-SAME: }> ({ -+ // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1, %[[ARG2:arg.*]]: !vhlo.tensor_v1): -+ // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.compare_v1"(%[[ARG1]], %[[ARG2]]) <{compare_type = #vhlo, comparison_direction = #vhlo}> -+ // CHECK-NEXT: "vhlo.return_v1"(%[[VAL1]]) : (!vhlo.tensor_v1) -> () -+ // CHECK-NEXT: }) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x!vhlo.f32_v1> -+ %0 = "stablehlo.sort"(%arg0) ({ -+ ^bb0(%arg1: tensor, %arg2: tensor): -+ %1 = "stablehlo.compare"(%arg1, %arg2) {compare_type = #stablehlo, comparison_direction = #stablehlo} : (tensor, tensor) -> tensor -+ "stablehlo.return"(%1) : (tensor) -> () -+ }) : (tensor<16xf32>) -> tensor<16xf32> -+ func.return %0 : tensor<16xf32> -+} -+ -+// ============ OPS ============ -+ -+// CHECK-LABEL: "op_abs" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) -+func.func @op_abs(%arg0: tensor) -> tensor { -+ // CHECK: "vhlo.abs_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ %0 = "stablehlo.abs"(%arg0) : (tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "op_add" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) -+func.func @op_add(%arg0: tensor, %arg1: tensor) -> tensor { -+ // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "op_after_all" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) -+func.func @op_after_all(%arg0: !stablehlo.token) -> !stablehlo.token { -+ // CHECK: "vhlo.after_all_v1"(%[[ARG0]]) : (!vhlo.token_v1) -> !vhlo.token_v1 -+ %0 = "stablehlo.after_all"(%arg0) : (!stablehlo.token) -> !stablehlo.token -+ func.return %0 : !stablehlo.token -+} -+ -+// CHECK-LABEL: "op_all_gather" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) -+func.func @op_all_gather(%arg0: tensor<16x8xf32>) -> tensor<16x16xf32> { -+ // CHECK: "vhlo.all_gather_v2"(%[[ARG0]]) <{ -+ // CHECK-SAME: all_gather_dim = #vhlo.integer_v1<1 : i64> -+ // CHECK-SAME: channel_id = #vhlo.integer_v1<1 : i64>, -+ // CHECK-SAME{LITERAL}: replica_groups = #vhlo.tensor_v1 : tensor<2x1xi64>>, -+ // CHECK-SAME: use_global_device_ids = #vhlo.bool_v1 -+ // CHECK-SAME: }> : (!vhlo.tensor_v1<16x8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x16x!vhlo.f32_v1> -+ %0 = "stablehlo.all_gather"(%arg0) { -+ all_gather_dim = 1 : i64, -+ replica_groups = dense<[[0], [1]]> : tensor<2x1xi64>, -+ channel_handle = #stablehlo.channel_handle, -+ use_global_device_ids -+ } : (tensor<16x8xf32>) -> tensor<16x16xf32> -+ func.return %0 : tensor<16x16xf32> -+} -+ -+// CHECK-LABEL: "op_all_reduce" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) -+func.func @op_all_reduce(%arg0: tensor) -> tensor { -+ // CHECK: "vhlo.all_reduce_v2"(%[[ARG0]]) <{ -+ // CHECK-SAME: channel_id = #vhlo.integer_v1<1 : i64>, -+ // CHECK-SAME{LITERAL}: replica_groups = #vhlo.tensor_v1 : tensor<2x1xi64>>, -+ // CHECK-SAME: use_global_device_ids = #vhlo.bool_v1 -+ // CHECK-SAME: }> ({ -+ // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1, %[[ARG2:arg.*]]: !vhlo.tensor_v1): -+ // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.add_v1"(%[[ARG1]], %[[ARG2]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ // CHECK-NEXT: "vhlo.return_v1"(%[[VAL1]]) : (!vhlo.tensor_v1) -> () -+ // CHECK-NEXT: }) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ %0 = "stablehlo.all_reduce"(%arg0) ({ -+ ^bb0(%arg1: tensor, %arg2: tensor): -+ %1 = "stablehlo.add"(%arg1, %arg2) : (tensor, tensor) -> tensor -+ "stablehlo.return"(%1) : (tensor) -> () -+ }) { -+ replica_groups = dense<[[0], [1]]> : tensor<2x1xi64>, -+ channel_handle = #stablehlo.channel_handle, -+ use_global_device_ids -+ } : (tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "op_all_reduce_with_promotable_types" -+func.func @op_all_reduce_with_promotable_types(%operand: tensor) -> tensor { -+ // CHECK: "vhlo.all_reduce_v2"(%[[ARG0:.*]]) -+ // CHECK: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1, %[[ARG2:arg.*]]: !vhlo.tensor_v1): -+ // CHECK: "vhlo.return_v1"(%[[VAL1:.*]]) : (!vhlo.tensor_v1) -> () -+ // CHECK: }) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ %result = "stablehlo.all_reduce"(%operand) ({ -+ ^bb0(%arg0: tensor, %arg1: tensor): -+ %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor -+ "stablehlo.return"(%0) : (tensor) -> () -+ }) { -+ replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>, -+ channel_handle = #stablehlo.channel_handle, -+ use_global_device_ids -+ } : (tensor) -> tensor -+ -+ func.return %result : tensor -+} -+ -+// CHECK-LABEL: "default_all_reduce_variadic" -+func.func @default_all_reduce_variadic(%arg0: tensor, %arg1: tensor) -> (tensor, tensor) { -+ %0:2 = "stablehlo.all_reduce"(%arg0, %arg1) ({ -+ ^bb0(%arg2: tensor, %arg3: tensor): -+ %1 = "stablehlo.add"(%arg2, %arg3) : (tensor, tensor) -> (tensor) -+ "stablehlo.return"(%1) : (tensor) -> () -+ }) { -+ replica_groups = dense<[[0], [1]]> : tensor<2x1xi64> -+ } : (tensor, tensor) -> (tensor, tensor) -+ func.return %0#0, %0#1 : tensor, tensor -+} -+ -+// CHECK-LABEL: "op_all_to_all" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) -+func.func @op_all_to_all(%arg0: tensor<4x16xf32>) -> tensor<16x4xf32> { -+ // CHECK: "vhlo.all_to_all_v2"(%[[ARG0]]) <{ -+ // CHECK-SAME: channel_id = #vhlo.integer_v1<1 : i64>, -+ // CHECK-SAME: concat_dimension = #vhlo.integer_v1<0 : i64>, -+ // CHECK-SAME{LITERAL}: replica_groups = #vhlo.tensor_v1 : tensor<1x4xi64>>, -+ // CHECK-SAME: split_count = #vhlo.integer_v1<4 : i64> -+ // CHECK-SAME: split_dimension = #vhlo.integer_v1<1 : i64> -+ // CHECK-SAME: }> : (!vhlo.tensor_v1<4x16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x4x!vhlo.f32_v1> -+ %0 = "stablehlo.all_to_all"(%arg0) { -+ split_dimension = 1 : i64, -+ concat_dimension = 0 : i64, -+ split_count = 4 : i64, -+ replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>, -+ channel_handle = #stablehlo.channel_handle -+ } : (tensor<4x16xf32>) -> tensor<16x4xf32> -+ func.return %0 : tensor<16x4xf32> -+} -+ -+// CHECK-LABEL: "op_and" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) -+func.func @op_and(%arg0: tensor, %arg1: tensor) -> tensor { -+ // CHECK: "vhlo.and_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ %0 = "stablehlo.and"(%arg0, %arg1) : (tensor, tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "op_atan2" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) -+func.func @op_atan2(%arg0: tensor, %arg1: tensor) -> tensor { -+ // CHECK: "vhlo.atan2_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ %0 = "stablehlo.atan2"(%arg0, %arg1) : (tensor, tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "op_batch_norm_grad" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}, %[[ARG3:.*]]: {{.*}}, %[[ARG4:.*]]: {{.*}}) -+func.func @op_batch_norm_grad(%arg0: tensor<16x16x16x16xf32>, %arg1: tensor<16xf32>, %arg2: tensor<16xf32>, %arg3: tensor<16xf32>, %arg4: tensor<16x16x16x16xf32>) -> (tensor<16x16x16x16xf32>, tensor<16xf32>, tensor<16xf32>) { -+ // CHECK: "vhlo.batch_norm_grad_v1"(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]]) <{ -+ // CHECK-SAME: epsilon = #vhlo.float_v1<1.000000e-03 : !vhlo.f32_v1>, -+ // CHECK-SAME: feature_index = #vhlo.integer_v1<0 : i64> -+ // CHECK-SAME: }> : (!vhlo.tensor_v1<16x16x16x16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x16x16x16x!vhlo.f32_v1>) -> (!vhlo.tensor_v1<16x16x16x16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>) -+ %0:3 = "stablehlo.batch_norm_grad"(%arg0, %arg1, %arg2, %arg3, %arg4) { -+ epsilon = 0.001 : f32, -+ feature_index = 0 : i64 -+ } : (tensor<16x16x16x16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16x16x16x16xf32>) -> (tensor<16x16x16x16xf32>, tensor<16xf32>, tensor<16xf32>) -+ func.return %0#0, %0#1, %0#2 : tensor<16x16x16x16xf32>, tensor<16xf32>, tensor<16xf32> -+} -+ -+// CHECK-LABEL: "op_batch_norm_inference" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}, %[[ARG3:.*]]: {{.*}}, %[[ARG4:.*]]: {{.*}}) -+func.func @op_batch_norm_inference(%arg0: tensor<16x16x16x16xf32>, %arg1: tensor<16xf32>, %arg2: tensor<16xf32>, %arg3: tensor<16xf32>, %arg4: tensor<16xf32>) -> tensor<16x16x16x16xf32> { -+ // CHECK: "vhlo.batch_norm_inference_v1"(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]]) <{ -+ // CHECK-SAME: epsilon = #vhlo.float_v1<1.000000e-03 : !vhlo.f32_v1>, -+ // CHECK-SAME: feature_index = #vhlo.integer_v1<0 : i64> -+ // CHECK-SAME: }> : (!vhlo.tensor_v1<16x16x16x16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x16x16x16x!vhlo.f32_v1> -+ %0 = "stablehlo.batch_norm_inference"(%arg0, %arg1, %arg2, %arg3, %arg4) { -+ epsilon = 0.001 : f32, -+ feature_index = 0 : i64 -+ } : (tensor<16x16x16x16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>) -> tensor<16x16x16x16xf32> -+ func.return %0 : tensor<16x16x16x16xf32> -+} -+ -+// CHECK-LABEL: "op_batch_norm_training" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}) -+func.func @op_batch_norm_training(%arg0: tensor<16x16x16x16xf32>, %arg1: tensor<16xf32>, %arg2: tensor<16xf32>) -> (tensor<16x16x16x16xf32>, tensor<16xf32>, tensor<16xf32>) { -+ // CHECK: "vhlo.batch_norm_training_v1"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) <{ -+ // CHECK-SAME: epsilon = #vhlo.float_v1<1.000000e-03 : !vhlo.f32_v1>, -+ // CHECK-SAME: feature_index = #vhlo.integer_v1<0 : i64> -+ // CHECK-SAME: }> : (!vhlo.tensor_v1<16x16x16x16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>) -> (!vhlo.tensor_v1<16x16x16x16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>) -+ %0:3 = "stablehlo.batch_norm_training"(%arg0, %arg1, %arg2) { -+ epsilon = 0.001 : f32, -+ feature_index = 0 : i64 -+ } : (tensor<16x16x16x16xf32>, tensor<16xf32>, tensor<16xf32>) -> (tensor<16x16x16x16xf32>, tensor<16xf32>, tensor<16xf32>) -+ func.return %0#0, %0#1, %0#2 : tensor<16x16x16x16xf32>, tensor<16xf32>, tensor<16xf32> -+} -+ -+// CHECK-LABEL: "op_bitcast_convert" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) -+func.func @op_bitcast_convert(%arg0: tensor) -> tensor { -+ // CHECK: "vhlo.bitcast_convert_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ %0 = "stablehlo.bitcast_convert"(%arg0) : (tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "op_broadcast_in_dim" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) -+func.func @op_broadcast_in_dim(%arg0: tensor<16xf32>) -> tensor<16x16xf32> { -+ // CHECK: "vhlo.broadcast_in_dim_v1"(%[[ARG0]]) <{ -+ // CHECK-SAME: broadcast_dimensions = #vhlo.tensor_v1 : tensor<1xi64>> -+ // CHECK-SAME: }> : (!vhlo.tensor_v1<16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x16x!vhlo.f32_v1> -+ %0 = "stablehlo.broadcast_in_dim"(%arg0) { -+ broadcast_dimensions = array -+ } : (tensor<16xf32>) -> tensor<16x16xf32> -+ func.return %0 : tensor<16x16xf32> -+} -+ -+// CHECK-LABEL: "op_broadcast" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) -+func.func @op_broadcast(%arg0: tensor<16xf32>) -> tensor<16x16xf32> { -+ // CHECK: "vhlo.broadcast_v1"(%[[ARG0]]) <{ -+ // CHECK-SAME: broadcast_sizes = #vhlo.tensor_v1 : tensor<1xi64>> -+ // CHECK-SAME: }> : (!vhlo.tensor_v1<16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x16x!vhlo.f32_v1> -+ %0 = "stablehlo.broadcast"(%arg0) { -+ broadcast_sizes = array -+ } : (tensor<16xf32>) -> tensor<16x16xf32> -+ func.return %0 : tensor<16x16xf32> -+} -+ -+// CHECK-LABEL: "op_case" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) -+func.func @op_case(%arg0: tensor, %arg1: tensor) -> tensor { -+ // CHECK: "vhlo.case_v1"(%[[ARG0]]) ({ -+ // CHECK-NEXT: "vhlo.return_v1"(%[[ARG1]]) : (!vhlo.tensor_v1) -> () -+ // CHECK-NEXT: }) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ %0 = "stablehlo.case"(%arg0) ({ -+ "stablehlo.return"(%arg1) : (tensor) -> () -+ }) : (tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "op_cbrt" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) -+func.func @op_cbrt(%arg0: tensor) -> tensor { -+ // CHECK: "vhlo.cbrt_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ %0 = "stablehlo.cbrt"(%arg0) : (tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "op_ceil" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) -+func.func @op_ceil(%arg0: tensor) -> tensor { -+ // CHECK: "vhlo.ceil_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ %0 = "stablehlo.ceil"(%arg0) : (tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "op_cholesky" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) -+func.func @op_cholesky(%arg0: tensor<1x16x16xf32>) -> tensor<1x16x16xf32> { -+ // CHECK: "vhlo.cholesky_v1"(%[[ARG0]]) <{ -+ // CHECK-SAME: lower = #vhlo.bool_v1 -+ // CHECK-SAME: }> : (!vhlo.tensor_v1<1x16x16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<1x16x16x!vhlo.f32_v1> -+ %0 = "stablehlo.cholesky"(%arg0) { -+ lower = true -+ } : (tensor<1x16x16xf32>) -> tensor<1x16x16xf32> -+ func.return %0 : tensor<1x16x16xf32> -+} -+ -+// CHECK-LABEL: "op_clamp" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}) -+func.func @op_clamp(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { -+ // CHECK: "vhlo.clamp_v1"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ %0 = "stablehlo.clamp"(%arg0, %arg1, %arg2) : (tensor, tensor, tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "op_count_leading_zeros" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) -+func.func @op_count_leading_zeros(%arg0: tensor) -> tensor { -+ // CHECK: "vhlo.count_leading_zeros_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ %0 = "stablehlo.count_leading_zeros"(%arg0) : (tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "op_collective_permute" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) -+func.func @op_collective_permute(%arg0: tensor<16x8xf32>) -> tensor<16x8xf32> { -+ // CHECK: "vhlo.collective_permute_v1"(%[[ARG0]]) <{ -+ // CHECK-SAME: channel_id = #vhlo.integer_v1<1 : i64>, -+ // CHECK-SAME{LITERAL}: source_target_pairs = #vhlo.tensor_v1 : tensor<3x2xi64>> -+ // CHECK-SAME: }> : (!vhlo.tensor_v1<16x8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x8x!vhlo.f32_v1> -+ %0 = "stablehlo.collective_permute"(%arg0) { -+ source_target_pairs = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64>, -+ channel_handle = #stablehlo.channel_handle -+ } : (tensor<16x8xf32>) -> tensor<16x8xf32> -+ func.return %0 : tensor<16x8xf32> -+} -+ -+// CHECK-LABEL: "op_compare" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) -+func.func @op_compare(%arg0: tensor, %arg1: tensor) -> tensor { -+ // CHECK: "vhlo.compare_v1"(%[[ARG0]], %[[ARG1]]) <{ -+ // CHECK-SAME: compare_type = #vhlo, -+ // CHECK-SAME: comparison_direction = #vhlo -+ // CHECK-SAME: }> : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ %0 = "stablehlo.compare"(%arg0, %arg1) { -+ comparison_direction = #stablehlo, -+ compare_type = #stablehlo -+ } : (tensor, tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "op_complex" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) -+func.func @op_complex(%arg0: tensor, %arg1: tensor) -> tensor> { -+ // CHECK: "vhlo.complex_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1> -+ %0 = "stablehlo.complex"(%arg0, %arg1) : (tensor, tensor) -> tensor> -+ func.return %0 : tensor> -+} -+ -+// CHECK-LABEL: "op_composite" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) -+func.func @op_composite(%arg0: tensor) -> tensor { -+ // CHECK: "vhlo.composite_v1"(%[[ARG0]]) <{ -+ // CHECK-SAME: composite_attributes = #vhlo.dict_v1<{#vhlo.string_v1<"my_int"> = #vhlo.integer_v1<1 : i64>, #vhlo.string_v1<"my_string"> = #vhlo.string_v1<"foo">}> -+ // CHECK-SAME: decomposition = #vhlo.string_v1<"composite_target"> -+ // CHECK-SAME: name = #vhlo.string_v1<"stablehlo.composite_target"> -+ // CHECK-SAME: version = #vhlo.integer_v1<1 : i32> -+ // CHECK-SAME: }> : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ %0 = "stablehlo.composite"(%arg0) { -+ name = "stablehlo.composite_target", -+ decomposition = @composite_target, -+ version = 1 : i32, -+ composite_attributes = { -+ my_string = "foo", -+ my_int = 1 : i64 -+ } -+ } : (tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "op_concatenate" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) -+func.func @op_concatenate(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>) -> tensor<16xf32> { -+ // CHECK: "vhlo.concatenate_v1"(%[[ARG0]], %[[ARG1]]) <{ -+ // CHECK-SAME: dimension = #vhlo.integer_v1<0 : i64> -+ // CHECK-SAME: }> : (!vhlo.tensor_v1<8x!vhlo.f32_v1>, !vhlo.tensor_v1<8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x!vhlo.f32_v1> -+ %0 = "stablehlo.concatenate"(%arg0, %arg1) { -+ dimension = 0 : i64 -+ } : (tensor<8xf32>, tensor<8xf32>) -> tensor<16xf32> -+ func.return %0 : tensor<16xf32> -+} -+ -+// CHECK-LABEL: "op_constant" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) -+func.func @op_constant(%arg0: tensor) -> tensor { -+ // CHECK: "vhlo.constant_v1"() <{ -+ // CHECK-SAME: value = #vhlo.tensor_v1 : tensor> -+ // CHECK-SAME: }> : () -> !vhlo.tensor_v1 -+ %0 = "stablehlo.constant"() { -+ value = dense<0.0> : tensor -+ } : () -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "op_convert" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) -+func.func @op_convert(%arg0: tensor) -> tensor { -+ // CHECK: "vhlo.convert_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ %0 = "stablehlo.convert"(%arg0) : (tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "op_convolution" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) -+func.func @op_convolution(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<1x7x7x16xf32> { -+ // CHECK: "vhlo.convolution_v1"(%[[ARG0]], %[[ARG1]]) <{ -+ // CHECK-SAME: batch_group_count = #vhlo.integer_v1<1 : i64>, -+ // CHECK-SAME: feature_group_count = #vhlo.integer_v1<1 : i64>, -+ // CHECK-SAME: input_batch_dimension = #vhlo.integer_v1<0 : i64>, -+ // CHECK-SAME: input_feature_dimension = #vhlo.integer_v1<3 : i64>, -+ // CHECK-SAME: input_spatial_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, -+ // CHECK-SAME: kernel_input_feature_dimension = #vhlo.integer_v1<2 : i64>, -+ // CHECK-SAME: kernel_output_feature_dimension = #vhlo.integer_v1<3 : i64>, -+ // CHECK-SAME: kernel_spatial_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, -+ // CHECK-SAME: lhs_dilation = #vhlo.tensor_v1 : tensor<2xi64>>, -+ // CHECK-SAME: output_batch_dimension = #vhlo.integer_v1<0 : i64>, -+ // CHECK-SAME: output_feature_dimension = #vhlo.integer_v1<3 : i64>, -+ // CHECK-SAME: output_spatial_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, -+ // CHECK-SAME: padding = #vhlo.tensor_v1 : tensor<2x2xi64>>, -+ // CHECK-SAME: precision_config = #vhlo.array_v1<[#vhlo, #vhlo]>, -+ // CHECK-SAME: rhs_dilation = #vhlo.tensor_v1 : tensor<2xi64>>, -+ // CHECK-SAME: window_reversal = #vhlo.tensor_v1 : tensor<2xi1>>, -+ // CHECK-SAME: window_strides = #vhlo.tensor_v1 : tensor<2xi64>> -+ // CHECK-SAME: }> : (!vhlo.tensor_v1<1x8x8x207x!vhlo.f32_v1>, !vhlo.tensor_v1<3x3x207x16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<1x7x7x16x!vhlo.f32_v1> -+ %0 = "stablehlo.convolution"(%arg0, %arg1) { -+ window_strides = array, -+ padding = dense<1> : tensor<2x2xi64>, -+ lhs_dilation = array, -+ rhs_dilation = array, -+ window_reversal = array, -+ dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>, -+ feature_group_count = 1 : i64, -+ batch_group_count = 1 : i64, -+ precision_config = [#stablehlo, #stablehlo] -+ } : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x7x7x16xf32> -+ func.return %0 : tensor<1x7x7x16xf32> -+} -+ -+// CHECK-LABEL: "op_cosine" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) -+func.func @op_cosine(%arg0: tensor) -> tensor { -+ // CHECK: "vhlo.cosine_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ %0 = "stablehlo.cosine"(%arg0) : (tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "op_create_token" -+func.func @op_create_token() -> !stablehlo.token { -+ // CHECK: "vhlo.create_token_v1"() : () -> !vhlo.token_v1 -+ %0 = "stablehlo.create_token"() : () -> !stablehlo.token -+ func.return %0 : !stablehlo.token -+} -+ -+// CHECK-LABEL: "op_cross_replica_sum" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) -+func.func @op_cross_replica_sum(%arg0: tensor) -> tensor { -+ // CHECK: "vhlo.cross-replica-sum_v1"(%[[ARG0]]) <{ -+ // CHECK-SAME{LITERAL}: replica_groups = #vhlo.tensor_v1 : tensor<2x1xi64>> -+ // CHECK-SAME: }> : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ %0 = "stablehlo.cross-replica-sum"(%arg0) { -+ replica_groups = dense<[[0], [1]]> : tensor<2x1xi64> -+ } : (tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "op_custom_call" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) -+func.func @op_custom_call(%arg0: tensor) -> tensor { -+ // CHECK: "vhlo.custom_call_v1"(%[[ARG0]]) <{ -+ // CHECK-SAME: api_version = #vhlo, -+ // CHECK-SAME: backend_config = #vhlo.string_v1<"\08\03\1A\02">, -+ // CHECK-SAME: call_target_name = #vhlo.string_v1<"foo">, -+ // CHECK-SAME: called_computations = #vhlo.array_v1<[#vhlo.string_v1<"foo">]>, -+ // CHECK-SAME: has_side_effect = #vhlo.bool_v1, -+ // CHECK-SAME: operand_layouts = #vhlo.array_v1<[#vhlo.tensor_v1 : tensor<0xindex>>]>, -+ // CHECK-SAME: output_operand_aliases = #vhlo.array_v1<[ -+ // CHECK-SAME: #vhlo.output_operand_alias_v1< -+ // CHECK-SAME: outputTupleIndices = [], -+ // CHECK-SAME: operandIndex = 0, -+ // CHECK-SAME: operandTupleIndices = []>]> -+ // CHECK-SAME: result_layouts = #vhlo.array_v1<[#vhlo.tensor_v1 : tensor<0xindex>>]> -+ // CHECK-SAME: }> : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ %0 = "stablehlo.custom_call"(%arg0) { -+ call_target_name = "foo", -+ has_side_effect = true, -+ backend_config = "\08\03\1A\02", -+ api_version = 2 : i32, -+ called_computations = [@foo], -+ operand_layouts = [dense<> : tensor<0xindex>], -+ output_operand_aliases = [ -+ #stablehlo.output_operand_alias], -+ result_layouts = [dense<> : tensor<0xindex>] -+ } : (tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "op_custom_call_empty_result_layout" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) -+func.func public @op_custom_call_empty_result_layout(%arg0: tensor) -> tensor { -+ // %0 = "vhlo.custom_call_v1"(%arg0) <{>}> : (!vhlo.tensor_v1) -> !vhlo.tuple_v1<> -+ // CHECK: "vhlo.custom_call_v1"(%[[ARG0]]) <{ -+ // CHECK-SAME: api_version = #vhlo, -+ // CHECK-SAME: backend_config = #vhlo.string_v1<"">, -+ // CHECK-SAME: call_target_name = #vhlo.string_v1<"empty_output">, -+ // CHECK-SAME: called_computations = #vhlo.array_v1<[]>, -+ // CHECK-SAME: has_side_effect = #vhlo.bool_v1, -+ // CHECK-SAME: operand_layouts = #vhlo.array_v1<[#vhlo.tensor_v1 : tensor<0xindex>>]>, -+ // CHECK-SAME: output_operand_aliases = #vhlo.array_v1<[]>, -+ // CHECK-SAME: result_layouts = #vhlo.array_v1<[]> -+ // CHECK-SAME: }> : (!vhlo.tensor_v1) -> !vhlo.tuple_v1<> -+ %0 = "stablehlo.custom_call"(%arg0) <{ -+ api_version = 2 : i32, -+ call_target_name = "empty_output", -+ has_side_effect = true, -+ operand_layouts = [dense<> : tensor<0xindex>], -+ result_layouts = [] -+ }> : (tensor) -> tuple<> -+ return %arg0 : tensor -+} -+ -+// CHECK-LABEL: "op_divide" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) -+func.func @op_divide(%arg0: tensor, %arg1: tensor) -> tensor { -+ // CHECK: "vhlo.divide_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ %0 = "stablehlo.divide"(%arg0, %arg1) : (tensor, tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "op_dot_general" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) -+func.func @op_dot_general(%arg0: tensor<8x8x16xf32>, %arg1: tensor<8x16x8xf32>) -> tensor<8x8x8xf32> { -+ // CHECK: "vhlo.dot_general_v2"(%[[ARG0]], %[[ARG1]]) <{ -+ // CHECK-SAME: accumulation_type = #vhlo.type_v1, -+ // CHECK-SAME: allow_imprecise_accumulation = #vhlo.type_v1, -+ // CHECK-SAME: lhs_batching_dimensions = #vhlo.tensor_v1 : tensor<1xi64>>, -+ // CHECK-SAME: lhs_component_count = #vhlo.type_v1, -+ // CHECK-SAME: lhs_contracting_dimensions = #vhlo.tensor_v1 : tensor<1xi64>>, -+ // CHECK-SAME: lhs_precision_type = #vhlo.type_v1, -+ // CHECK-SAME: num_primitive_operations = #vhlo.type_v1, -+ // CHECK-SAME: precision_config = #vhlo.array_v1<[#vhlo, #vhlo]>, -+ // CHECK-SAME: rhs_batching_dimensions = #vhlo.tensor_v1 : tensor<1xi64>>, -+ // CHECK-SAME: rhs_component_count = #vhlo.type_v1, -+ // CHECK-SAME: rhs_contracting_dimensions = #vhlo.tensor_v1 : tensor<1xi64>>, -+ // CHECK-SAME: rhs_precision_type = #vhlo.type_v1 -+ // CHECK-SAME: }> : (!vhlo.tensor_v1<8x8x16x!vhlo.f32_v1>, !vhlo.tensor_v1<8x16x8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<8x8x8x!vhlo.f32_v1> -+ %0 = "stablehlo.dot_general"(%arg0, %arg1) { -+ dot_dimension_numbers = #stablehlo.dot< -+ lhs_batching_dimensions = [0], -+ lhs_contracting_dimensions = [2], -+ rhs_batching_dimensions = [0], -+ rhs_contracting_dimensions = [1] -+ >, -+ precision_config = [#stablehlo, #stablehlo] -+ } : (tensor<8x8x16xf32>, tensor<8x16x8xf32>) -> tensor<8x8x8xf32> -+ func.return %0 : tensor<8x8x8xf32> -+} -+ -+// CHECK-LABEL: "op_dot" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) -+func.func @op_dot(%arg0: tensor<8x16xf32>, %arg1: tensor<16x8xf32>) -> tensor<8x8xf32> { -+ // CHECK: "vhlo.dot_v1"(%[[ARG0]], %[[ARG1]]) <{ -+ // CHECK-SAME: precision_config = #vhlo.array_v1<[#vhlo, #vhlo]> -+ // CHECK-SAME: }> : (!vhlo.tensor_v1<8x16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<8x8x!vhlo.f32_v1> -+ %0 = "stablehlo.dot"(%arg0, %arg1) { -+ precision_config = [#stablehlo, #stablehlo] -+ } : (tensor<8x16xf32>, tensor<16x8xf32>) -> tensor<8x8xf32> -+ func.return %0 : tensor<8x8xf32> -+} -+ -+// CHECK-LABEL: "op_dynamic_broadcast_in_dim" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) -+func.func @op_dynamic_broadcast_in_dim(%arg0: tensor, %arg1: tensor<2xindex>) -> tensor { -+ // CHECK: "vhlo.dynamic_broadcast_in_dim_v1"(%[[ARG0]], %[[ARG1]]) <{ -+ // CHECK-SAME: broadcast_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, -+ // CHECK-SAME: known_expanding_dimensions = #vhlo.tensor_v1 : tensor<1xi64>>, -+ // CHECK-SAME: known_nonexpanding_dimensions = #vhlo.tensor_v1 : tensor<1xi64>> -+ // CHECK-SAME: }> : (!vhlo.tensor_v1, !vhlo.tensor_v1<2x!vhlo.index_v1>) -> !vhlo.tensor_v1 -+ %0 = "stablehlo.dynamic_broadcast_in_dim"(%arg0, %arg1) { -+ broadcast_dimensions = array, -+ known_expanding_dimensions = array, -+ known_nonexpanding_dimensions = array -+ } : (tensor, tensor<2xindex>) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "op_dynamic_conv" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}) -+func.func @op_dynamic_conv(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>, %arg2: tensor<2x2xi64>) -> tensor<1x?x?x16xf32> { -+ // CHECK: "vhlo.dynamic_conv_v2"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) <{ -+ // CHECK-SAME: batch_group_count = #vhlo.integer_v1<1 : i64>, -+ // CHECK-SAME: feature_group_count = #vhlo.integer_v1<1 : i64>, -+ // CHECK-SAME: input_batch_dimension = #vhlo.integer_v1<0 : i64>, -+ // CHECK-SAME: input_feature_dimension = #vhlo.integer_v1<3 : i64>, -+ // CHECK-SAME: input_spatial_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, -+ // CHECK-SAME: kernel_input_feature_dimension = #vhlo.integer_v1<2 : i64>, -+ // CHECK-SAME: kernel_output_feature_dimension = #vhlo.integer_v1<3 : i64>, -+ // CHECK-SAME: kernel_spatial_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, -+ // CHECK-SAME: lhs_dilation = #vhlo.tensor_v1 : tensor<2xi64>>, -+ // CHECK-SAME: output_batch_dimension = #vhlo.integer_v1<0 : i64>, -+ // CHECK-SAME: output_feature_dimension = #vhlo.integer_v1<3 : i64>, -+ // CHECK-SAME: output_spatial_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, -+ // CHECK-SAME: precision_config = #vhlo.array_v1<[#vhlo, #vhlo]>, -+ // CHECK-SAME: rhs_dilation = #vhlo.tensor_v1 : tensor<2xi64>>, -+ // CHECK-SAME: window_reversal = #vhlo.tensor_v1 : tensor<2xi1>>, -+ // CHECK-SAME: window_strides = #vhlo.tensor_v1 : tensor<2xi64>> -+ // CHECK-SAME: }> : (!vhlo.tensor_v1<1x8x8x207x!vhlo.f32_v1>, !vhlo.tensor_v1<3x3x207x16x!vhlo.f32_v1>, !vhlo.tensor_v1<2x2x!vhlo.i64_v1>) -> !vhlo.tensor_v1<1x?x?x16x!vhlo.f32_v1> -+ %0 = "stablehlo.dynamic_conv"(%arg0, %arg1, %arg2) { -+ window_strides = array, -+ lhs_dilation = array, -+ rhs_dilation = array, -+ window_reversal = array, -+ dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>, -+ feature_group_count = 1 : i64, -+ batch_group_count = 1 : i64, -+ precision_config = [#stablehlo, #stablehlo] -+ } : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>, tensor<2x2xi64>) -> tensor<1x?x?x16xf32> -+ func.return %0 : tensor<1x?x?x16xf32> -+} -+ -+// CHECK-LABEL: "op_dynamic_gather" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}) -+func.func @op_dynamic_gather(%arg0 : tensor<2x4x9xf32>, %arg1 : tensor<1x5x2xi32>, %arg2 : tensor<3xi32>) -> tensor<1x5x8xf32> { -+ // CHECK: "vhlo.dynamic_gather_v2"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) <{ -+ // CHECK-SAME: collapsed_slice_dims = #vhlo.tensor_v1 : tensor<2xi64>>, -+ // CHECK-SAME: index_vector_dim = #vhlo.integer_v1<2 : i64>, -+ // CHECK-SAME: indices_are_sorted = #vhlo.bool_v1, -+ // CHECK-SAME: offset_dims = #vhlo.tensor_v1 : tensor<1xi64>>, -+ // CHECK-SAME: operand_batching_dims = #vhlo.tensor_v1 : tensor<0xi64>>, -+ // CHECK-SAME: start_index_map = #vhlo.tensor_v1 : tensor<2xi64>>, -+ // CHECK-SAME: start_indices_batching_dims = #vhlo.tensor_v1 : tensor<0xi64>> -+ // CHECK-SAME: }> : (!vhlo.tensor_v1<2x4x9x!vhlo.f32_v1>, !vhlo.tensor_v1<1x5x2x!vhlo.i32_v1>, !vhlo.tensor_v1<3x!vhlo.i32_v1>) -> !vhlo.tensor_v1<1x5x8x!vhlo.f32_v1> -+ %0 = "stablehlo.dynamic_gather"(%arg0, %arg1, %arg2) { -+ dimension_numbers = #stablehlo.gather< -+ offset_dims = [2], -+ collapsed_slice_dims = [0, 1], -+ start_index_map = [0, 1], -+ index_vector_dim = 2 -+ >, -+ indices_are_sorted = true -+ } : (tensor<2x4x9xf32>, tensor<1x5x2xi32>, tensor<3xi32>) -> tensor<1x5x8xf32> -+ func.return %0 : tensor<1x5x8xf32> -+} -+ -+// CHECK-LABEL: "op_dynamic_gather_with_batching_dims" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}) -+func.func @op_dynamic_gather_with_batching_dims(%arg0 : tensor<5x2x4x9xf32>, %arg1 : tensor<1x5x2xi32>, %arg2 : tensor<4xi32>) -> tensor<1x5x8xf32> { -+ // CHECK: "vhlo.dynamic_gather_v2"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) <{ -+ // CHECK-SAME: collapsed_slice_dims = #vhlo.tensor_v1 : tensor<2xi64>>, -+ // CHECK-SAME: index_vector_dim = #vhlo.integer_v1<2 : i64>, -+ // CHECK-SAME: indices_are_sorted = #vhlo.bool_v1, -+ // CHECK-SAME: offset_dims = #vhlo.tensor_v1 : tensor<1xi64>>, -+ // CHECK-SAME: operand_batching_dims = #vhlo.tensor_v1 : tensor<1xi64>>, -+ // CHECK-SAME: start_index_map = #vhlo.tensor_v1 : tensor<2xi64>>, -+ // CHECK-SAME: start_indices_batching_dims = #vhlo.tensor_v1 : tensor<1xi64>> -+ // CHECK-SAME: }> : (!vhlo.tensor_v1<5x2x4x9x!vhlo.f32_v1>, !vhlo.tensor_v1<1x5x2x!vhlo.i32_v1>, !vhlo.tensor_v1<4x!vhlo.i32_v1>) -> !vhlo.tensor_v1<1x5x8x!vhlo.f32_v1> -+ %0 = "stablehlo.dynamic_gather"(%arg0, %arg1, %arg2) { -+ dimension_numbers = #stablehlo.gather< -+ offset_dims = [2], -+ collapsed_slice_dims = [1, 2], -+ operand_batching_dims = [0], -+ start_indices_batching_dims = [1], -+ start_index_map = [1, 2], -+ index_vector_dim = 2 -+ >, -+ indices_are_sorted = true -+ } : (tensor<5x2x4x9xf32>, tensor<1x5x2xi32>, tensor<4xi32>) -> tensor<1x5x8xf32> -+ func.return %0 : tensor<1x5x8xf32> -+} -+ -+// CHECK-LABEL: "op_dynamic_iota" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) -+func.func @op_dynamic_iota(%arg0: tensor<1xindex>) -> tensor { -+ // CHECK: "vhlo.dynamic_iota_v1"(%[[ARG0]]) <{ -+ // CHECK-SAME: iota_dimension = #vhlo.integer_v1<0 : i64> -+ // CHECK-SAME: }> : (!vhlo.tensor_v1<1x!vhlo.index_v1>) -> !vhlo.tensor_v1 -+ %0 = "stablehlo.dynamic_iota"(%arg0) { -+ iota_dimension = 0 : i64 -+ } : (tensor<1xindex>) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "op_dynamic_pad" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}, %[[ARG3:.*]]: {{.*}}, %[[ARG4:.*]]: {{.*}}) -+func.func @op_dynamic_pad(%arg0: tensor, %arg1: tensor, %arg2: tensor<1xindex>, %arg3: tensor<1xindex>, %arg4: tensor<1xindex>) -> tensor { -+ // CHECK: "vhlo.dynamic_pad_v1"(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1, !vhlo.tensor_v1<1x!vhlo.index_v1>, !vhlo.tensor_v1<1x!vhlo.index_v1>, !vhlo.tensor_v1<1x!vhlo.index_v1>) -> !vhlo.tensor_v1 -+ %0 = "stablehlo.dynamic_pad"(%arg0, %arg1, %arg2, %arg3, %arg4) : (tensor, tensor, tensor<1xindex>, tensor<1xindex>, tensor<1xindex>) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "op_dynamic_reshape" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) -+func.func @op_dynamic_reshape(%arg0: tensor<16xf32>, %arg1: tensor<2xindex>) -> tensor { -+ // CHECK: "vhlo.dynamic_reshape_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<2x!vhlo.index_v1>) -> !vhlo.tensor_v1 -+ %0 = "stablehlo.dynamic_reshape"(%arg0, %arg1) : (tensor<16xf32>, tensor<2xindex>) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "op_dynamic_slice" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) -+func.func @op_dynamic_slice(%arg0: tensor<16xf32>, %arg1: tensor) -> tensor<4xf32> { -+ // CHECK: "vhlo.dynamic_slice_v1"(%[[ARG0]], %[[ARG1]]) <{ -+ // CHECK-SAME: slice_sizes = #vhlo.tensor_v1 : tensor<1xi64>> -+ // CHECK-SAME: }> : (!vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1) -> !vhlo.tensor_v1<4x!vhlo.f32_v1> -+ %0 = "stablehlo.dynamic_slice"(%arg0, %arg1) { -+ slice_sizes = array -+ } : (tensor<16xf32>, tensor) -> tensor<4xf32> -+ func.return %0 : tensor<4xf32> -+} -+ -+// CHECK-LABEL: "op_dynamic_update_slice" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}) -+func.func @op_dynamic_update_slice(%arg0: tensor<16xf32>, %arg1: tensor<4xf32>, %arg2: tensor) -> tensor<16xf32> { -+ // CHECK: "vhlo.dynamic_update_slice_v1"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<4x!vhlo.f32_v1>, !vhlo.tensor_v1) -> !vhlo.tensor_v1<16x!vhlo.f32_v1> -+ %0 = "stablehlo.dynamic_update_slice"(%arg0, %arg1, %arg2) : (tensor<16xf32>, tensor<4xf32>, tensor) -> tensor<16xf32> -+ func.return %0 : tensor<16xf32> -+} -+ -+// CHECK-LABEL: "op_einsum" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) -+func.func @op_einsum(%arg0: tensor<8x16xf32>, %arg1: tensor<16x8xf32>) -> tensor<8x8xf32> { -+ // CHECK: "vhlo.einsum_v1"(%[[ARG0]], %[[ARG1]]) <{ -+ // CHECK-SAME: einsum_config = #vhlo.string_v1<"ab,bc->ac"> -+ // CHECK-SAME: }> : (!vhlo.tensor_v1<8x16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<8x8x!vhlo.f32_v1> -+ %0 = "stablehlo.einsum"(%arg0, %arg1) { -+ einsum_config = "ab,bc->ac" -+ } : (tensor<8x16xf32>, tensor<16x8xf32>) -> tensor<8x8xf32> -+ func.return %0 : tensor<8x8xf32> -+} -+ -+// CHECK-LABEL: "op_exponential_minus_one" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) -+func.func @op_exponential_minus_one(%arg0: tensor) -> tensor { -+ // CHECK: "vhlo.exponential_minus_one_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ %0 = "stablehlo.exponential_minus_one"(%arg0) : (tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "op_exponential" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) -+func.func @op_exponential(%arg0: tensor) -> tensor { -+ // CHECK: "vhlo.exponential_v2"(%[[ARG0]]) <{result_accuracy = #vhlo.result_accuracy_v1>}> : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ %0 = "stablehlo.exponential"(%arg0) : (tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "op_fft" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) -+func.func @op_fft(%arg0: tensor<16xcomplex>) -> tensor<16xcomplex> { -+ // CHECK: "vhlo.fft_v1"(%[[ARG0]]) <{ -+ // CHECK-SAME: fft_length = #vhlo.tensor_v1 : tensor<1xi64>>, -+ // CHECK-SAME: fft_type = #vhlo -+ // CHECK-SAME: }> : (!vhlo.tensor_v1<16x!vhlo.complex_v1>) -> !vhlo.tensor_v1<16x!vhlo.complex_v1> -+ %0 = "stablehlo.fft"(%arg0) { -+ fft_type = #stablehlo, -+ fft_length = array -+ } : (tensor<16xcomplex>) -> tensor<16xcomplex> -+ func.return %0 : tensor<16xcomplex> -+} -+ -+// CHECK-LABEL: "op_floor" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) -+func.func @op_floor(%arg0: tensor) -> tensor { -+ // CHECK: "vhlo.floor_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ %0 = "stablehlo.floor"(%arg0) : (tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+func.func private @op_func(%arg0: tensor {stablehlo.arg = "0"}) -> (tensor {stablehlo.result = "0"}) { -+ // CHECK: "vhlo.func_v1"() <{ -+ // CHECK-SAME: arg_attrs = #vhlo.array_v1<[#vhlo.dict_v1<{#vhlo.string_v1<"stablehlo.arg"> = #vhlo.string_v1<"0">}>]>, -+ // CHECK-SAME: function_type = #vhlo.type_v1) -> !vhlo.tensor_v1>>, -+ // CHECK-SAME: res_attrs = #vhlo.array_v1<[#vhlo.dict_v1<{#vhlo.string_v1<"stablehlo.result"> = #vhlo.string_v1<"0">}>]>, -+ // CHECK-SAME: sym_name = #vhlo.string_v1<"op_func">, -+ // CHECK-SAME: sym_visibility = #vhlo.string_v1<"private"> -+ // CHECK-SAME: }> ({ -+ // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG0:.*]]: !vhlo.tensor_v1): -+ // CHECK-NEXT: "vhlo.return_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> () -+ // CHECK-NEXT: }) : () -> () -+ -+ func.return %arg0 : tensor -+} -+ -+// CHECK-LABEL: "op_gather" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) -+func.func @op_gather(%arg0 : tensor<2x4x9xf32>, %arg1 : tensor<1x5x2xi32>) -> tensor<1x5x1xf32> { -+ // CHECK: "vhlo.gather_v2"(%[[ARG0]], %[[ARG1]]) <{ -+ // CHECK-SAME: collapsed_slice_dims = #vhlo.tensor_v1 : tensor<2xi64>>, -+ // CHECK-SAME: index_vector_dim = #vhlo.integer_v1<2 : i64>, -+ // CHECK-SAME: indices_are_sorted = #vhlo.bool_v1, -+ // CHECK-SAME: offset_dims = #vhlo.tensor_v1 : tensor<1xi64>>, -+ // CHECK-SAME: operand_batching_dims = #vhlo.tensor_v1 : tensor<0xi64>>, -+ // CHECK-SAME: slice_sizes = #vhlo.tensor_v1 : tensor<3xi64>>, -+ // CHECK-SAME: start_index_map = #vhlo.tensor_v1 : tensor<2xi64>>, -+ // CHECK-SAME: start_indices_batching_dims = #vhlo.tensor_v1 : tensor<0xi64>> -+ // CHECK-SAME: }> : (!vhlo.tensor_v1<2x4x9x!vhlo.f32_v1>, !vhlo.tensor_v1<1x5x2x!vhlo.i32_v1>) -> !vhlo.tensor_v1<1x5x1x!vhlo.f32_v1> -+ %0 = "stablehlo.gather"(%arg0, %arg1) { -+ dimension_numbers = #stablehlo.gather< -+ offset_dims = [2], -+ collapsed_slice_dims = [0, 1], -+ start_index_map = [0, 1], -+ index_vector_dim = 2 -+ >, -+ slice_sizes = array, -+ indices_are_sorted = true -+ } : (tensor<2x4x9xf32>, tensor<1x5x2xi32>) -> tensor<1x5x1xf32> -+ func.return %0 : tensor<1x5x1xf32> -+} -+ -+// CHECK-LABEL: "op_gather_with_batching_dims" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) -+func.func @op_gather_with_batching_dims(%arg0 : tensor<5x2x4x9xf32>, %arg1 : tensor<1x5x2xi32>) -> tensor<1x5x1xf32> { -+ // CHECK: "vhlo.gather_v2"(%[[ARG0]], %[[ARG1]]) <{ -+ // CHECK-SAME: collapsed_slice_dims = #vhlo.tensor_v1 : tensor<2xi64>>, -+ // CHECK-SAME: index_vector_dim = #vhlo.integer_v1<2 : i64>, -+ // CHECK-SAME: indices_are_sorted = #vhlo.bool_v1, -+ // CHECK-SAME: offset_dims = #vhlo.tensor_v1 : tensor<1xi64>>, -+ // CHECK-SAME: operand_batching_dims = #vhlo.tensor_v1 : tensor<1xi64>>, -+ // CHECK-SAME: slice_sizes = #vhlo.tensor_v1 : tensor<4xi64>>, -+ // CHECK-SAME: start_index_map = #vhlo.tensor_v1 : tensor<2xi64>>, -+ // CHECK-SAME: start_indices_batching_dims = #vhlo.tensor_v1 : tensor<1xi64>> -+ // CHECK-SAME: }> : (!vhlo.tensor_v1<5x2x4x9x!vhlo.f32_v1>, !vhlo.tensor_v1<1x5x2x!vhlo.i32_v1>) -> !vhlo.tensor_v1<1x5x1x!vhlo.f32_v1> -+ %0 = "stablehlo.gather"(%arg0, %arg1) { -+ dimension_numbers = #stablehlo.gather< -+ offset_dims = [2], -+ collapsed_slice_dims = [1, 2], -+ operand_batching_dims = [0], -+ start_indices_batching_dims = [1], -+ start_index_map = [1, 2], -+ index_vector_dim = 2 -+ >, -+ slice_sizes = array, -+ indices_are_sorted = true -+ } : (tensor<5x2x4x9xf32>, tensor<1x5x2xi32>) -> tensor<1x5x1xf32> -+ func.return %0 : tensor<1x5x1xf32> -+} -+ -+// CHECK-LABEL: "op_get_dimension_size" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) -+func.func @op_get_dimension_size(%arg0: tensor) -> tensor { -+ // CHECK: "vhlo.get_dimension_size_v1"(%[[ARG0]]) <{ -+ // CHECK-SAME: dimension = #vhlo.integer_v1<0 : i64> -+ // CHECK-SAME: }> : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ %0 = "stablehlo.get_dimension_size"(%arg0) { -+ dimension = 0 : i64 -+ } : (tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "op_get_tuple_element" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) -+func.func @op_get_tuple_element(%arg0: tuple, tensor>) -> tensor { -+ // CHECK: "vhlo.get_tuple_element_v1"(%[[ARG0]]) <{ -+ // CHECK-SAME: index = #vhlo.integer_v1<0 : i32> -+ // CHECK-SAME: }> : (!vhlo.tuple_v1, !vhlo.tensor_v1>) -> !vhlo.tensor_v1 -+ %0 = "stablehlo.get_tuple_element"(%arg0) { -+ index = 0 : i32 -+ } : (tuple, tensor>) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "op_if" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}) -+func.func @op_if(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { -+ // CHECK: "vhlo.if_v1"(%[[ARG0]]) ({ -+ // CHECK-NEXT: "vhlo.return_v1"(%[[ARG1]]) : (!vhlo.tensor_v1) -> () -+ // CHECK-NEXT: }, { -+ // CHECK-NEXT: "vhlo.return_v1"(%[[ARG2]]) : (!vhlo.tensor_v1) -> () -+ // CHECK-NEXT: }) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ %0 = "stablehlo.if"(%arg0) ({ -+ "stablehlo.return"(%arg1) : (tensor) -> () -+ }, { -+ "stablehlo.return"(%arg2) : (tensor) -> () -+ }) : (tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "op_imag" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) -+func.func @op_imag(%arg0: tensor>) -> tensor { -+ // CHECK: "vhlo.imag_v1"(%[[ARG0]]) : (!vhlo.tensor_v1>) -> !vhlo.tensor_v1 -+ %0 = "stablehlo.imag"(%arg0) : (tensor>) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "op_infeed" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) -+func.func @op_infeed(%arg0: !stablehlo.token) -> (tensor, !stablehlo.token) { -+ // CHECK: "vhlo.infeed_v1"(%[[ARG0]]) <{ -+ // CHECK-SAME: infeed_config = #vhlo.string_v1<"foo">, -+ // CHECK-SAME{LITERAL}: layout = #vhlo.array_v1<[#vhlo.array_v1<[]>]> -+ // CHECK-SAME: }> : (!vhlo.token_v1) -> (!vhlo.tensor_v1, !vhlo.token_v1) -+ %0:2 = "stablehlo.infeed"(%arg0) { -+ infeed_config = "foo", -+ layout = [[]] -+ } : (!stablehlo.token) -> (tensor, !stablehlo.token) -+ func.return %0#0, %0#1 : tensor, !stablehlo.token -+} -+ -+// CHECK-LABEL: "op_iota" -+func.func @op_iota() -> tensor<16xf32> { -+ // CHECK: "vhlo.iota_v1"() <{ -+ // CHECK-SAME: iota_dimension = #vhlo.integer_v1<0 : i64> -+ // CHECK-SAME: }> : () -> !vhlo.tensor_v1<16x!vhlo.f32_v1> -+ %0 = "stablehlo.iota"() { -+ iota_dimension = 0 : i64 -+ } : () -> tensor<16xf32> -+ func.return %0 : tensor<16xf32> -+} -+ -+// CHECK-LABEL: "op_is_finite" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) -+func.func @op_is_finite(%arg0: tensor) -> tensor { -+ // CHECK: "vhlo.is_finite_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ %0 = "stablehlo.is_finite"(%arg0) : (tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "op_log" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) -+func.func @op_log(%arg0: tensor) -> tensor { -+ // CHECK: "vhlo.log_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ %0 = "stablehlo.log"(%arg0) : (tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "op_log_plus_one" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) -+func.func @op_log_plus_one(%arg0: tensor) -> tensor { -+ // CHECK: "vhlo.log_plus_one_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ %0 = "stablehlo.log_plus_one"(%arg0) : (tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "op_logistic" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) -+func.func @op_logistic(%arg0: tensor) -> tensor { -+ // CHECK: "vhlo.logistic_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ %0 = "stablehlo.logistic"(%arg0) : (tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "op_map" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) -+func.func @op_map(%arg0: tensor<16xf32>) -> tensor<16xf32> { -+ // CHECK: "vhlo.map_v1"(%[[ARG0]]) <{ -+ // CHECK-SAME: dimensions = #vhlo.tensor_v1 : tensor<1xi64>> -+ // CHECK-SAME: }> ({ -+ // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1): -+ // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.abs_v1"(%[[ARG1]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ // CHECK-NEXT: "vhlo.return_v1"(%[[VAL1]]) : (!vhlo.tensor_v1) -> () -+ // CHECK-NEXT: }) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x!vhlo.f32_v1> -+ %0 = "stablehlo.map"(%arg0) ({ -+ ^bb0(%arg1: tensor): -+ %1 = "stablehlo.abs"(%arg1) : (tensor) -> tensor -+ "stablehlo.return"(%1) : (tensor) -> () -+ }) { -+ dimensions = array -+ } : (tensor<16xf32>) -> tensor<16xf32> -+ func.return %0 : tensor<16xf32> -+} -+ -+// CHECK-LABEL: "op_maximum" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) -+func.func @op_maximum(%arg0: tensor, %arg1: tensor) -> tensor { -+ // CHECK: "vhlo.maximum_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ %0 = "stablehlo.maximum"(%arg0, %arg1) : (tensor, tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "op_minimum" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) -+func.func @op_minimum(%arg0: tensor, %arg1: tensor) -> tensor { -+ // CHECK: "vhlo.minimum_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ %0 = "stablehlo.minimum"(%arg0, %arg1) : (tensor, tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "op_multiply" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) -+func.func @op_multiply(%arg0: tensor, %arg1: tensor) -> tensor { -+ // CHECK: "vhlo.multiply_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ %0 = "stablehlo.multiply"(%arg0, %arg1) : (tensor, tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "op_negate" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) -+func.func @op_negate(%arg0: tensor) -> tensor { -+ // CHECK: "vhlo.negate_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ %0 = "stablehlo.negate"(%arg0) : (tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "op_not" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) -+func.func @op_not(%arg0: tensor) -> tensor { -+ // CHECK: "vhlo.not_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ %0 = "stablehlo.not"(%arg0) : (tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "op_optimization_barrier" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) -+func.func @op_optimization_barrier(%arg0: tensor) -> tensor { -+ // CHECK: "vhlo.optimization_barrier_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ %0 = "stablehlo.optimization_barrier"(%arg0) : (tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "op_or" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) -+func.func @op_or(%arg0: tensor, %arg1: tensor) -> tensor { -+ // CHECK: "vhlo.or_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ %0 = "stablehlo.or"(%arg0, %arg1) : (tensor, tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "op_outfeed" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) -+func.func @op_outfeed(%arg0: tensor, %arg1: !stablehlo.token) -> !stablehlo.token { -+ // CHECK: "vhlo.outfeed_v1"(%[[ARG0]], %[[ARG1]]) <{ -+ // CHECK-SAME: outfeed_config = #vhlo.string_v1<"foo"> -+ // CHECK-SAME: }> : (!vhlo.tensor_v1, !vhlo.token_v1) -> !vhlo.token_v1 -+ %0 = "stablehlo.outfeed"(%arg0, %arg1) { -+ outfeed_config = "foo" -+ } : (tensor, !stablehlo.token) -> !stablehlo.token -+ func.return %0 : !stablehlo.token -+} -+ -+// CHECK-LABEL: "op_pad" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) -+func.func @op_pad(%arg0: tensor<8xf32>, %arg1: tensor) -> tensor<16xf32> { -+ // CHECK: "vhlo.pad_v1"(%[[ARG0]], %[[ARG1]]) <{ -+ // CHECK-SAME: edge_padding_high = #vhlo.tensor_v1 : tensor<1xi64>>, -+ // CHECK-SAME: edge_padding_low = #vhlo.tensor_v1 : tensor<1xi64>>, -+ // CHECK-SAME: interior_padding = #vhlo.tensor_v1 : tensor<1xi64>> -+ // CHECK-SAME: }> : (!vhlo.tensor_v1<8x!vhlo.f32_v1>, !vhlo.tensor_v1) -> !vhlo.tensor_v1<16x!vhlo.f32_v1> -+ %0 = "stablehlo.pad"(%arg0, %arg1) { -+ edge_padding_high = array, -+ edge_padding_low = array, -+ interior_padding = array -+ } : (tensor<8xf32>, tensor) -> tensor<16xf32> -+ func.return %0 : tensor<16xf32> -+} -+ -+// CHECK-LABEL: "op_popcnt" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) -+func.func @op_popcnt(%arg0: tensor) -> tensor { -+ // CHECK: "vhlo.popcnt_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ %0 = "stablehlo.popcnt"(%arg0) : (tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "op_power" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) -+func.func @op_power(%arg0: tensor, %arg1: tensor) -> tensor { -+ // CHECK: "vhlo.power_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ %0 = "stablehlo.power"(%arg0, %arg1) : (tensor, tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "op_real_dynamic_slice" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}, %[[ARG3:.*]]: {{.*}}) -+func.func @op_real_dynamic_slice(%arg0: tensor, %arg1: tensor<1xindex>, %arg2: tensor<1xindex>, %arg3: tensor<1xindex>) -> tensor { -+ // CHECK: "vhlo.real_dynamic_slice_v1"(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1<1x!vhlo.index_v1>, !vhlo.tensor_v1<1x!vhlo.index_v1>, !vhlo.tensor_v1<1x!vhlo.index_v1>) -> !vhlo.tensor_v1 -+ %0 = "stablehlo.real_dynamic_slice"(%arg0, %arg1, %arg2, %arg3) : (tensor, tensor<1xindex>, tensor<1xindex>, tensor<1xindex>) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "op_real" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) -+func.func @op_real(%arg0: tensor>) -> tensor { -+ // CHECK: "vhlo.real_v1"(%[[ARG0]]) : (!vhlo.tensor_v1>) -> !vhlo.tensor_v1 -+ %0 = "stablehlo.real"(%arg0) : (tensor>) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "op_recv" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) -+func.func @op_recv(%arg0: !stablehlo.token) -> (tensor, !stablehlo.token) { -+ // CHECK: "vhlo.recv_v1"(%[[ARG0]]) <{ -+ // CHECK-SAME: channel_id = #vhlo.integer_v1<0 : i64>, -+ // CHECK-SAME: channel_type = #vhlo.integer_v1<3 : i64>, -+ // CHECK-SAME: is_host_transfer = #vhlo.bool_v1 -+ // CHECK-SAME: }> : (!vhlo.token_v1) -> (!vhlo.tensor_v1, !vhlo.token_v1) -+ %0:2 = "stablehlo.recv"(%arg0) { -+ channel_handle = #stablehlo.channel_handle, -+ is_host_transfer = true -+ } : (!stablehlo.token) -> (tensor, !stablehlo.token) -+ func.return %0#0, %0#1 : tensor, !stablehlo.token -+} -+ -+// CHECK-LABEL: "op_reduce" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) -+func.func @op_reduce(%arg0: tensor<16xf32>, %arg1: tensor) -> tensor { -+ // CHECK: "vhlo.reduce_v1"(%[[ARG0]], %[[ARG1]]) -+ // CHECK: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1, %[[ARG2:arg.*]]: !vhlo.tensor_v1): -+ // CHECK: "vhlo.return_v1"(%[[VAL1:.*]]) : (!vhlo.tensor_v1) -> () -+ // CHECK: }) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ %0 = "stablehlo.reduce"(%arg0, %arg1) ({ -+ ^bb0(%arg2: tensor, %arg3: tensor): -+ %1 = "stablehlo.add"(%arg2, %arg3) : (tensor, tensor) -> tensor -+ "stablehlo.return"(%1) : (tensor) -> () -+ }) { -+ dimensions = array -+ } : (tensor<16xf32>, tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "op_reduce_precision" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) -+func.func @op_reduce_precision(%arg0: tensor) -> tensor { -+ // CHECK: "vhlo.reduce_precision_v1"(%[[ARG0]]) <{ -+ // CHECK-SAME: exponent_bits = #vhlo.integer_v1<8 : i32> -+ // CHECK-SAME: mantissa_bits = #vhlo.integer_v1<10 : i32> -+ // CHECK-SAME: }> : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ %0 = "stablehlo.reduce_precision"(%arg0) { -+ exponent_bits = 8 : i32, -+ mantissa_bits = 10 : i32 -+ } : (tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK_lABEL: "op_reduce_with_promotable_types" -+func.func @op_reduce_with_promotable_types(%arg0: tensor<4x4xf32>, %arg1 : tensor) -+ -> (tensor<4xf64>) { -+ // CHECK: "vhlo.reduce_v1"(%[[ARG0:.*]], %[[ARG1:.*]]) -+ // CHECK: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1, %[[ARG2:arg.*]]: !vhlo.tensor_v1): -+ // CHECK: "vhlo.return_v1"(%[[VAL1:.*]]) : (!vhlo.tensor_v1) -> () -+ // CHECK: }) : (!vhlo.tensor_v1<4x4x!vhlo.f32_v1>, !vhlo.tensor_v1) -> !vhlo.tensor_v1<4x!vhlo.f64_v1> -+ %0 = "stablehlo.reduce"(%arg0, %arg1) ({ -+ ^bb0(%arg2: tensor, %arg3: tensor ): -+ %1 = "stablehlo.add"(%arg2, %arg3) : (tensor, tensor) -> tensor -+ "stablehlo.return"(%1) : (tensor) -> () -+ -+ }) {dimensions = array} : (tensor<4x4xf32>, tensor) -> tensor<4xf64> -+ -+ func.return %0: tensor<4xf64> -+} -+ -+// CHECK-LABEL: "op_reduce_scatter" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) -+func.func @op_reduce_scatter(%arg0: tensor<16xf32>) -> tensor<16xf32> { -+ // CHECK: "vhlo.reduce_scatter_v1"(%[[ARG0]]) <{ -+ // CHECK-SAME: channel_id = #vhlo.integer_v1<1 : i64>, -+ // CHECK-SAME{LITERAL}: replica_groups = #vhlo.tensor_v1 : tensor<2x1xi64>>, -+ // CHECK-SAME: scatter_dimension = #vhlo.integer_v1<0 : i64> -+ // CHECK-SAME: use_global_device_ids = #vhlo.bool_v1 -+ // CHECK-SAME: }> ({ -+ // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1, %[[ARG2:arg.*]]: !vhlo.tensor_v1): -+ // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.add_v1"(%[[ARG1]], %[[ARG2]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ // CHECK-NEXT: "vhlo.return_v1"(%[[VAL1]]) : (!vhlo.tensor_v1) -> () -+ // CHECK-NEXT: }) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x!vhlo.f32_v1> -+ %0 = "stablehlo.reduce_scatter"(%arg0) ({ -+ ^bb0(%arg1: tensor, %arg2: tensor): -+ %1 = "stablehlo.add"(%arg1, %arg2) : (tensor, tensor) -> tensor -+ "stablehlo.return"(%1) : (tensor) -> () -+ }) { -+ scatter_dimension = 0 : i64, -+ replica_groups = dense<[[0], [1]]> : tensor<2x1xi64>, -+ channel_handle = #stablehlo.channel_handle, -+ use_global_device_ids -+ } : (tensor<16xf32>) -> tensor<16xf32> -+ func.return %0 : tensor<16xf32> -+} -+ -+// CHECK_lABEL: "op_reduce_scatter_with_promotable_types" -+func.func @op_reduce_scatter_with_promotable_types(%data: tensor<4x16xf32>) -> tensor<4x4xf64> { -+ // CHECK: "vhlo.reduce_scatter_v1"(%[[ARG0:.*]]) -+ // CHECK: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1, %[[ARG2:arg.*]]: !vhlo.tensor_v1): -+ // CHECK: "vhlo.return_v1"(%[[VAL1:.*]]) : (!vhlo.tensor_v1) -> () -+ // CHECK: }) : (!vhlo.tensor_v1<4x16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<4x4x!vhlo.f64_v1> -+ %0 = "stablehlo.reduce_scatter"(%data) ({ -+ ^bb0(%arg2: tensor, %arg3: tensor): -+ %1 = stablehlo.add %arg2, %arg3 : tensor -+ "stablehlo.return"(%1) : (tensor) -> () -+ }) {replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>, -+ scatter_dimension = 1 : i64, -+ channel_handle = #stablehlo.channel_handle, -+ use_global_device_ids} : (tensor<4x16xf32>) -> tensor<4x4xf64> -+ func.return %0 : tensor<4x4xf64> -+} -+ -+ -+// CHECK-LABEL: "op_reduce_window" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) -+func.func @op_reduce_window(%arg0: tensor<2x17x31x7xf32>, %arg1: tensor) -> tensor<2x9x16x7xf32> { -+ // CHECK: "vhlo.reduce_window_v1"(%[[ARG0]], %[[ARG1]]) <{ -+ // CHECK-SAME: base_dilations = #vhlo.tensor_v1 : tensor<4xi64>>, -+ // CHECK-SAME{LITERAL}: padding = #vhlo.tensor_v1 : tensor<4x2xi64>>, -+ // CHECK-SAME: window_dilations = #vhlo.tensor_v1 : tensor<4xi64>>, -+ // CHECK-SAME: window_dimensions = #vhlo.tensor_v1 : tensor<4xi64>>, -+ // CHECK-SAME: window_strides = #vhlo.tensor_v1 : tensor<4xi64>> -+ // CHECK-SAME: }> ({ -+ // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG2:arg.*]]: !vhlo.tensor_v1, %[[ARG3:arg.*]]: !vhlo.tensor_v1): -+ // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.maximum_v1"(%[[ARG2]], %[[ARG3]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ // CHECK-NEXT: "vhlo.return_v1"(%[[VAL1]]) : (!vhlo.tensor_v1) -> () -+ // CHECK-NEXT: }) : (!vhlo.tensor_v1<2x17x31x7x!vhlo.f32_v1>, !vhlo.tensor_v1) -> !vhlo.tensor_v1<2x9x16x7x!vhlo.f32_v1> -+ %0 = "stablehlo.reduce_window"(%arg0, %arg1) ({ -+ ^bb0(%arg2: tensor, %arg3: tensor): -+ %1 = "stablehlo.maximum"(%arg2, %arg3) : (tensor, tensor) -> tensor -+ "stablehlo.return"(%1) : (tensor) -> () -+ }) { -+ window_dimensions = array, -+ window_strides = array, -+ base_dilations = array, -+ window_dilations = array, -+ padding = dense<[[0, 0], [2, 0], [0, 2], [0, 0]]> : tensor<4x2xi64> -+ } : (tensor<2x17x31x7xf32>, tensor) -> tensor<2x9x16x7xf32> -+ func.return %0 : tensor<2x9x16x7xf32> -+} -+ -+// CHECK-LABEL: "op_reduce_window_with_promotable_types" -+func.func @op_reduce_window_with_promotable_types(%arg0: tensor<4x2xf32>, -+ %arg1: tensor<4x2xf32>, %init0: tensor, %init1: tensor) -> -+ (tensor<2x2xf64>, tensor<2x2xf32>) { -+ // CHECK: "vhlo.reduce_window_v1"(%[[ARG0:.*]], %[[ARG1:.*]], %[[ARG2:.*]], %[[ARG3:.*]]) -+ // CHECK: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1, %[[ARG2:arg.*]]: !vhlo.tensor_v1, %[[ARG3:arg.*]]: !vhlo.tensor_v1, %[[ARG4:arg.*]]: !vhlo.tensor_v1): -+ // CHECK: "vhlo.return_v1"(%[[VAL1:.*]], %[[VAL2:.*]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> () -+ // CHECK: }) : (!vhlo.tensor_v1<4x2x!vhlo.f32_v1>, !vhlo.tensor_v1<4x2x!vhlo.f32_v1>, !vhlo.tensor_v1, !vhlo.tensor_v1) -> (!vhlo.tensor_v1<2x2x!vhlo.f64_v1>, !vhlo.tensor_v1<2x2x!vhlo.f32_v1>) -+ %0:2 = "stablehlo.reduce_window"(%arg0, %arg1, %init0, %init1) ({ -+ ^bb0(%a0: tensor, %a1: tensor, %b0: tensor, -+ %b1: tensor): -+ %2 = stablehlo.add %a0, %b0 : tensor -+ %3 = stablehlo.add %a1, %b1 : tensor -+ "stablehlo.return"(%2,%3) : (tensor, tensor) -> () -+ }) -+ { padding = dense<[[2, 2], [0, 0]]> : tensor<2x2xi64>, -+ window_dimensions = array, -+ window_strides = array } -+ : (tensor<4x2xf32>, tensor<4x2xf32>, tensor, tensor) -> -+ (tensor<2x2xf64>, tensor<2x2xf32>) -+ func.return %0#0, %0#1 : tensor<2x2xf64>, tensor<2x2xf32> -+} -+ -+// CHECK-LABEL: "op_remainder" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) -+func.func @op_remainder(%arg0: tensor, %arg1: tensor) -> tensor { -+ // CHECK: "vhlo.remainder_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ %0 = "stablehlo.remainder"(%arg0, %arg1) : (tensor, tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "op_replica_id" -+func.func @op_replica_id() -> tensor { -+ // CHECK: "vhlo.replica_id_v1"() : () -> !vhlo.tensor_v1 -+ %0 = "stablehlo.replica_id"() : () -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "op_partition_id" -+func.func @op_partition_id() -> tensor { -+ // CHECK: "vhlo.partition_id_v1"() : () -> !vhlo.tensor_v1 -+ %0 = "stablehlo.partition_id"() : () -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "op_reshape" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) -+func.func @op_reshape(%arg0: tensor<16xf32>) -> tensor<4x4xf32> { -+ // CHECK: "vhlo.reshape_v1"(%[[ARG0]]) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<4x4x!vhlo.f32_v1> -+ %0 = "stablehlo.reshape"(%arg0) : (tensor<16xf32>) -> tensor<4x4xf32> -+ func.return %0 : tensor<4x4xf32> -+} -+ -+// CHECK-LABEL: "op_return" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) -+func.func @op_return(%arg0: tensor, %arg1: tensor) -> tensor { -+ // CHECK: "vhlo.case_v1"(%[[ARG0]]) ({ -+ // CHECK-NEXT: "vhlo.return_v1"(%[[ARG1]]) : (!vhlo.tensor_v1) -> () -+ // CHECK-NEXT: }) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ %0 = "stablehlo.case"(%arg0) ({ -+ "stablehlo.return"(%arg1) : (tensor) -> () -+ }) : (tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "op_reverse" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) -+func.func @op_reverse(%arg0: tensor<16xf32>) -> tensor<16xf32> { -+ // CHECK: "vhlo.reverse_v1"(%[[ARG0]]) <{ -+ // CHECK-SAME: dimensions = #vhlo.tensor_v1 : tensor<1xi64>> -+ // CHECK-SAME: }> : (!vhlo.tensor_v1<16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x!vhlo.f32_v1> -+ %0 = "stablehlo.reverse"(%arg0) { -+ dimensions = array -+ } : (tensor<16xf32>) -> tensor<16xf32> -+ func.return %0 : tensor<16xf32> -+} -+ -+// CHECK-LABEL: "op_rng_bit_generator" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) -+func.func @op_rng_bit_generator(%arg0: tensor) -> (tensor, tensor) { -+ // CHECK: "vhlo.rng_bit_generator_v1"(%[[ARG0]]) <{ -+ // CHECK-SAME: rng_algorithm = #vhlo -+ // CHECK-SAME: }> : (!vhlo.tensor_v1) -> (!vhlo.tensor_v1, !vhlo.tensor_v1) -+ %0:2 = "stablehlo.rng_bit_generator"(%arg0) { -+ rng_algorithm = #stablehlo -+ } : (tensor) -> (tensor, tensor) -+ func.return %0#0, %0#1 : tensor, tensor -+} -+ -+// CHECK-LABEL: "op_rng" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}) -+func.func @op_rng(%arg0: tensor, %arg1: tensor, %arg2: tensor<0xindex>) -> tensor { -+ // CHECK: "vhlo.rng_v1"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) <{ -+ // CHECK-SAME: rng_distribution = #vhlo -+ // CHECK-SAME: }> : (!vhlo.tensor_v1, !vhlo.tensor_v1, !vhlo.tensor_v1<0x!vhlo.index_v1>) -> !vhlo.tensor_v1 -+ %0 = "stablehlo.rng"(%arg0, %arg1, %arg2) { -+ rng_distribution = #stablehlo -+ } : (tensor, tensor, tensor<0xindex>) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "op_round_nearest_afz" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) -+func.func @op_round_nearest_afz(%arg0: tensor) -> tensor { -+ // CHECK: "vhlo.round_nearest_afz_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ %0 = "stablehlo.round_nearest_afz"(%arg0) : (tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "op_round_nearest_even" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) -+func.func @op_round_nearest_even(%arg0: tensor) -> tensor { -+ // CHECK: "vhlo.round_nearest_even_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ %0 = "stablehlo.round_nearest_even"(%arg0) : (tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "op_rsqrt" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) -+func.func @op_rsqrt(%arg0: tensor) -> tensor { -+ // CHECK: "vhlo.rsqrt_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ %0 = "stablehlo.rsqrt"(%arg0) : (tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "op_scatter" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}) -+func.func @op_scatter(%arg0: tensor<200x100x300xf32>, %arg1: tensor<10x2xi32>, %arg2: tensor<10x300xf32>) -> tensor<200x100x300xf32> { -+ // CHECK: "vhlo.scatter_v2"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) <{ -+ // CHECK-SAME: index_vector_dim = #vhlo.integer_v1<1 : i64>, -+ // CHECK-SAME: indices_are_sorted = #vhlo.bool_v1, -+ // CHECK-SAME: input_batching_dims = #vhlo.tensor_v1 : tensor<0xi64>>, -+ // CHECK-SAME: inserted_window_dims = #vhlo.tensor_v1 : tensor<2xi64>>, -+ // CHECK-SAME: scatter_dims_to_operand_dims = #vhlo.tensor_v1 : tensor<2xi64>>, -+ // CHECK-SAME: scatter_indices_batching_dims = #vhlo.tensor_v1 : tensor<0xi64>>, -+ // CHECK-SAME: unique_indices = #vhlo.bool_v1, -+ // CHECK-SAME: update_window_dims = #vhlo.tensor_v1 : tensor<1xi64>> -+ // CHECK-SAME: }> ({ -+ // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG3:arg.*]]: !vhlo.tensor_v1, %[[ARG4:arg.*]]: !vhlo.tensor_v1): -+ // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.add_v1"(%[[ARG3]], %[[ARG4]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ // CHECK-NEXT: "vhlo.return_v1"(%[[VAL1]]) : (!vhlo.tensor_v1) -> () -+ // CHECK-NEXT: }) : (!vhlo.tensor_v1<200x100x300x!vhlo.f32_v1>, !vhlo.tensor_v1<10x2x!vhlo.i32_v1>, !vhlo.tensor_v1<10x300x!vhlo.f32_v1>) -> !vhlo.tensor_v1<200x100x300x!vhlo.f32_v1> -+ %0 = "stablehlo.scatter"(%arg0, %arg1, %arg2) ({ -+ ^bb0(%arg3: tensor, %arg4: tensor): -+ %1 = "stablehlo.add"(%arg3, %arg4) : (tensor, tensor) -> tensor -+ "stablehlo.return"(%1) : (tensor) -> () -+ }) { -+ scatter_dimension_numbers = #stablehlo.scatter< -+ update_window_dims = [1], -+ inserted_window_dims = [0, 1], -+ scatter_dims_to_operand_dims = [0, 1], -+ index_vector_dim = 1 -+ >, -+ indices_are_sorted = true, -+ unique_indices = true -+ } : (tensor<200x100x300xf32>, tensor<10x2xi32>, tensor<10x300xf32>) -> tensor<200x100x300xf32> -+ func.return %0 : tensor<200x100x300xf32> -+} -+ -+// CHECK-LABEL: "op_scatter_with_batching_dims" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}) -+func.func @op_scatter_with_batching_dims(%arg0: tensor<10x200x100x300xf32>, %arg1: tensor<10x2xi32>, %arg2: tensor<10x300xf32>) -> tensor<10x200x100x300xf32> { -+ // CHECK: "vhlo.scatter_v2"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) <{ -+ // CHECK-SAME: index_vector_dim = #vhlo.integer_v1<1 : i64>, -+ // CHECK-SAME: indices_are_sorted = #vhlo.bool_v1, -+ // CHECK-SAME: input_batching_dims = #vhlo.tensor_v1 : tensor<1xi64>>, -+ // CHECK-SAME: inserted_window_dims = #vhlo.tensor_v1 : tensor<2xi64>>, -+ // CHECK-SAME: scatter_dims_to_operand_dims = #vhlo.tensor_v1 : tensor<2xi64>>, -+ // CHECK-SAME: scatter_indices_batching_dims = #vhlo.tensor_v1 : tensor<1xi64>>, -+ // CHECK-SAME: unique_indices = #vhlo.bool_v1, -+ // CHECK-SAME: update_window_dims = #vhlo.tensor_v1 : tensor<1xi64>> -+ // CHECK-SAME: }> ({ -+ // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG3:arg.*]]: !vhlo.tensor_v1, %[[ARG4:arg.*]]: !vhlo.tensor_v1): -+ // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.add_v1"(%[[ARG3]], %[[ARG4]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ // CHECK-NEXT: "vhlo.return_v1"(%[[VAL1]]) : (!vhlo.tensor_v1) -> () -+ // CHECK-NEXT: }) : (!vhlo.tensor_v1<10x200x100x300x!vhlo.f32_v1>, !vhlo.tensor_v1<10x2x!vhlo.i32_v1>, !vhlo.tensor_v1<10x300x!vhlo.f32_v1>) -> !vhlo.tensor_v1<10x200x100x300x!vhlo.f32_v1> -+ %0 = "stablehlo.scatter"(%arg0, %arg1, %arg2) ({ -+ ^bb0(%arg3: tensor, %arg4: tensor): -+ %1 = "stablehlo.add"(%arg3, %arg4) : (tensor, tensor) -> tensor -+ "stablehlo.return"(%1) : (tensor) -> () -+ }) { -+ scatter_dimension_numbers = #stablehlo.scatter< -+ update_window_dims = [1], -+ inserted_window_dims = [1, 2], -+ input_batching_dims = [0], -+ scatter_dims_to_operand_dims = [1, 2], -+ scatter_indices_batching_dims = [0], -+ index_vector_dim = 1 -+ >, -+ indices_are_sorted = true, -+ unique_indices = true -+ } : (tensor<10x200x100x300xf32>, tensor<10x2xi32>, tensor<10x300xf32>) -> tensor<10x200x100x300xf32> -+ func.return %0 : tensor<10x200x100x300xf32> -+} -+ -+// CHECK_lABEL: "op_scatter_with_promotable_types" -+func.func @op_scatter_with_promotable_types(%input_tensor: tensor<200x100x300xf32>, -+ %scatter_indices: tensor<10x2xi32>, %updates: tensor<10x300xf32>) -> -+ tensor<200x100x300xf64> { -+ // CHECK: "vhlo.scatter_v2"(%[[ARG0:.*]], %[[ARG1:.*]], %[[ARG2:.*]]) -+ // CHECK: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1, %[[ARG2:arg.*]]: !vhlo.tensor_v1): -+ // CHECK: "vhlo.return_v1"(%[[VAL1:.*]]) : (!vhlo.tensor_v1) -> () -+ // CHECK: }) : (!vhlo.tensor_v1<200x100x300x!vhlo.f32_v1>, !vhlo.tensor_v1<10x2x!vhlo.i32_v1>, !vhlo.tensor_v1<10x300x!vhlo.f32_v1>) -> !vhlo.tensor_v1<200x100x300x!vhlo.f64_v1> -+ %0 = "stablehlo.scatter" (%input_tensor, %scatter_indices, %updates) ({ -+ ^bb0(%lhs: tensor, %rhs: tensor): -+ %add = stablehlo.add %lhs, %rhs : tensor -+ "stablehlo.return"(%add) : (tensor) -> () -+ }) { -+ scatter_dimension_numbers = #stablehlo.scatter< -+ update_window_dims = [1], -+ inserted_window_dims = [0, 1], -+ scatter_dims_to_operand_dims = [0, 1], -+ index_vector_dim = 1 -+ >, -+ indices_are_sorted = true, -+ unique_indices = true -+ } : (tensor<200x100x300xf32>, tensor<10x2xi32>, tensor<10x300xf32>) -> -+ tensor<200x100x300xf64> -+ func.return %0 : tensor<200x100x300xf64> -+} -+ -+// CHECK-LABEL: "op_select_and_scatter" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}) -+func.func @op_select_and_scatter(%arg0: tensor<10x24x24x64xf32>, %arg1: tensor<12x13x13x66xf32>, %arg2: tensor) -> tensor<10x24x24x64xf32> { -+ // CHECK: "vhlo.select_and_scatter_v1"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) <{ -+ // CHECK-SAME: padding = #vhlo.tensor_v1 : tensor<4x2xi64>>, -+ // CHECK-SAME: window_dimensions = #vhlo.tensor_v1 : tensor<4xi64>>, -+ // CHECK-SAME: window_strides = #vhlo.tensor_v1 : tensor<4xi64>> -+ // CHECK-SAME: }> ({ -+ // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG31:arg.*]]: !vhlo.tensor_v1, %[[ARG41:arg.*]]: !vhlo.tensor_v1): -+ // CHECK-NEXT: %[[VAL11:.*]] = "vhlo.compare_v1"(%[[ARG31]], %[[ARG41]]) <{compare_type = #vhlo, comparison_direction = #vhlo}> : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ // CHECK-NEXT: "vhlo.return_v1"(%[[VAL11]]) : (!vhlo.tensor_v1) -> () -+ // CHECK-NEXT: }, { -+ // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG32:arg.*]]: !vhlo.tensor_v1, %[[ARG42:arg.*]]: !vhlo.tensor_v1): -+ // CHECK-NEXT: %[[VAL12:.*]] = "vhlo.add_v1"(%[[ARG32]], %[[ARG42]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ // CHECK-NEXT: "vhlo.return_v1"(%[[VAL12]]) : (!vhlo.tensor_v1) -> () -+ // CHECK-NEXT: }) : (!vhlo.tensor_v1<10x24x24x64x!vhlo.f32_v1>, !vhlo.tensor_v1<12x13x13x66x!vhlo.f32_v1>, !vhlo.tensor_v1) -> !vhlo.tensor_v1<10x24x24x64x!vhlo.f32_v1> -+ %0 = "stablehlo.select_and_scatter"(%arg0, %arg1, %arg2) ({ -+ ^bb0(%arg3: tensor, %arg4: tensor): -+ %1 = "stablehlo.compare"(%arg3, %arg4) {compare_type = #stablehlo, comparison_direction = #stablehlo} : (tensor, tensor) -> tensor -+ "stablehlo.return"(%1) : (tensor) -> () -+ }, { -+ ^bb0(%arg3: tensor, %arg4: tensor): -+ %1 = "stablehlo.add"(%arg3, %arg4) : (tensor, tensor) -> tensor -+ "stablehlo.return"(%1) : (tensor) -> () -+ }) { -+ window_dimensions = array, -+ window_strides = array, -+ padding = dense<1> : tensor<4x2xi64> -+ } : (tensor<10x24x24x64xf32>, tensor<12x13x13x66xf32>, tensor) -> tensor<10x24x24x64xf32> -+ func.return %0 : tensor<10x24x24x64xf32> -+} -+ -+// CHECK-LABEL: "op_select_and_scatter_with_promotable_types" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}) -+func.func @op_select_and_scatter_with_promotable_types(%arg0: tensor<10x24x24x64xf32>, %arg1: tensor<12x13x13x66xf32>, %arg2: tensor) -> tensor<10x24x24x64xf64> { -+ // CHECK: "vhlo.select_and_scatter_v1"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) -+ // CHECK: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1, %[[ARG2:arg.*]]: !vhlo.tensor_v1): -+ // CHECK: %[[VAL:.*]] = "vhlo.add_v1"(%[[ARG1]], %[[ARG2]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ // CHECK: "vhlo.return_v1"(%[[VAL]]) : (!vhlo.tensor_v1) -> () -+ // CHECK: }) : (!vhlo.tensor_v1<10x24x24x64x!vhlo.f32_v1>, !vhlo.tensor_v1<12x13x13x66x!vhlo.f32_v1>, !vhlo.tensor_v1) -> !vhlo.tensor_v1<10x24x24x64x!vhlo.f64_v1> -+ %0 = "stablehlo.select_and_scatter"(%arg0, %arg1, %arg2) ({ -+ ^bb0(%arg3: tensor, %arg4: tensor): -+ %1 = "stablehlo.compare"(%arg3, %arg4) {compare_type = #stablehlo, comparison_direction = #stablehlo} : (tensor, tensor) -> tensor -+ "stablehlo.return"(%1) : (tensor) -> () -+ }, { -+ ^bb0(%arg3: tensor, %arg4: tensor): -+ %1 = "stablehlo.add"(%arg3, %arg4) : (tensor, tensor) -> tensor -+ "stablehlo.return"(%1) : (tensor) -> () -+ }) { -+ window_dimensions = array, -+ window_strides = array, -+ padding = dense<1> : tensor<4x2xi64> -+ } : (tensor<10x24x24x64xf32>, tensor<12x13x13x66xf32>, tensor) -> tensor<10x24x24x64xf64> -+ func.return %0 : tensor<10x24x24x64xf64> -+} -+ -+// CHECK-LABEL: "op_select" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}) -+func.func @op_select(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { -+ // CHECK: "vhlo.select_v1"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ %0 = "stablehlo.select"(%arg0, %arg1, %arg2) : (tensor, tensor, tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "op_send" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) -+func.func @op_send(%arg0: tensor, %arg1: !stablehlo.token) -> !stablehlo.token { -+ // CHECK: "vhlo.send_v1"(%[[ARG0]], %[[ARG1]]) <{ -+ // CHECK-SAME: channel_id = #vhlo.integer_v1<0 : i64>, -+ // CHECK-SAME: channel_type = #vhlo.integer_v1<2 : i64>, -+ // CHECK-SAME: is_host_transfer = #vhlo.bool_v1 -+ // CHECK-SAME: }> : (!vhlo.tensor_v1, !vhlo.token_v1) -> !vhlo.token_v1 -+ %0 = "stablehlo.send"(%arg0, %arg1) { -+ channel_handle = #stablehlo.channel_handle, -+ is_host_transfer = true -+ } : (tensor, !stablehlo.token) -> !stablehlo.token -+ func.return %0 : !stablehlo.token -+} -+ -+// CHECK-LABEL: "op_set_dimension_size" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) -+func.func @op_set_dimension_size(%arg0: tensor, %arg1: tensor) -> tensor<16xf32> { -+ // CHECK: "vhlo.set_dimension_size_v1"(%[[ARG0]], %[[ARG1]]) <{ -+ // CHECK-SAME: dimension = #vhlo.integer_v1<0 : i64> -+ // CHECK-SAME: }> : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1<16x!vhlo.f32_v1> -+ %0 = "stablehlo.set_dimension_size"(%arg0, %arg1) { -+ dimension = 0 : i64 -+ } : (tensor, tensor) -> tensor<16xf32> -+ func.return %0 : tensor<16xf32> -+} -+ -+// CHECK-LABEL: "op_shift_left" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) -+func.func @op_shift_left(%arg0: tensor, %arg1: tensor) -> tensor { -+ // CHECK: "vhlo.shift_left_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ %0 = "stablehlo.shift_left"(%arg0, %arg1) : (tensor, tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "op_shift_right_arithmetic" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) -+func.func @op_shift_right_arithmetic(%arg0: tensor, %arg1: tensor) -> tensor { -+ // CHECK: "vhlo.shift_right_arithmetic_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ %0 = "stablehlo.shift_right_arithmetic"(%arg0, %arg1) : (tensor, tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "op_shift_right_logical" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) -+func.func @op_shift_right_logical(%arg0: tensor, %arg1: tensor) -> tensor { -+ // CHECK: "vhlo.shift_right_logical_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ %0 = "stablehlo.shift_right_logical"(%arg0, %arg1) : (tensor, tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "op_sign" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) -+func.func @op_sign(%arg0: tensor) -> tensor { -+ // CHECK: "vhlo.sign_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ %0 = "stablehlo.sign"(%arg0) : (tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "op_sine" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) -+func.func @op_sine(%arg0: tensor) -> tensor { -+ // CHECK: "vhlo.sine_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ %0 = "stablehlo.sine"(%arg0) : (tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "op_slice" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) -+func.func @op_slice(%arg0: tensor<16xf32>) -> tensor<4xf32> { -+ // CHECK: "vhlo.slice_v1"(%[[ARG0]]) <{ -+ // CHECK-SAME: limit_indices = #vhlo.tensor_v1 : tensor<1xi64>>, -+ // CHECK-SAME: start_indices = #vhlo.tensor_v1 : tensor<1xi64>>, -+ // CHECK-SAME: strides = #vhlo.tensor_v1 : tensor<1xi64>> -+ // CHECK-SAME: }> : (!vhlo.tensor_v1<16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<4x!vhlo.f32_v1> -+ %0 = "stablehlo.slice"(%arg0) { -+ start_indices = array, -+ limit_indices = array, -+ strides = array -+ } : (tensor<16xf32>) -> tensor<4xf32> -+ func.return %0 : tensor<4xf32> -+} -+ -+// CHECK-LABEL: "op_sort" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) -+func.func @op_sort(%arg0: tensor<16xf32>) -> tensor<16xf32> { -+ // CHECK: "vhlo.sort_v1"(%[[ARG0]]) <{ -+ // CHECK-SAME: dimension = #vhlo.integer_v1<0 : i64> -+ // CHECK-SAME: is_stable = #vhlo.bool_v1 -+ // CHECK-SAME: }> ({ -+ // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1, %[[ARG2:arg.*]]: !vhlo.tensor_v1): -+ // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.compare_v1"(%[[ARG1]], %[[ARG2]]) <{compare_type = #vhlo, comparison_direction = #vhlo}> -+ // CHECK-NEXT: "vhlo.return_v1"(%[[VAL1]]) : (!vhlo.tensor_v1) -> () -+ // CHECK-NEXT: }) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x!vhlo.f32_v1> -+ %0 = "stablehlo.sort"(%arg0) ({ -+ ^bb0(%arg1: tensor, %arg2: tensor): -+ %1 = "stablehlo.compare"(%arg1, %arg2) {compare_type = #stablehlo, comparison_direction = #stablehlo} : (tensor, tensor) -> tensor -+ "stablehlo.return"(%1) : (tensor) -> () -+ }) { -+ dimension = 0 : i64, -+ is_stable = true -+ } : (tensor<16xf32>) -> tensor<16xf32> -+ func.return %0 : tensor<16xf32> -+} -+ -+// CHECK-LABEL: "op_sqrt" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) -+func.func @op_sqrt(%arg0: tensor) -> tensor { -+ // CHECK: "vhlo.sqrt_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ %0 = "stablehlo.sqrt"(%arg0) : (tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "op_subtract" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) -+func.func @op_subtract(%arg0: tensor, %arg1: tensor) -> tensor { -+ // CHECK: "vhlo.subtract_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ %0 = "stablehlo.subtract"(%arg0, %arg1) : (tensor, tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "op_tan" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) -+func.func @op_tan(%arg0: tensor) -> tensor { -+ // CHECK: "vhlo.tan_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ %0 = "stablehlo.tan"(%arg0) : (tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "op_tanh" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) -+func.func @op_tanh(%arg0: tensor) -> tensor { -+ // CHECK: "vhlo.tanh_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ %0 = "stablehlo.tanh"(%arg0) : (tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "op_torch_index_select" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) -+func.func @op_torch_index_select(%arg0: tensor<5x1x5xf32>, %arg1: tensor<2xi32>) -> tensor<2x1x5xf32> { -+ // CHECK: "vhlo.torch_index_select_v1"(%[[ARG0]], %[[ARG1]]) <{ -+ // CHECK-SAME: batch_dims = #vhlo.integer_v1<0 : i64> -+ // CHECK-SAME: dim = #vhlo.integer_v1<0 : i64> -+ // CHECK-SAME: }> : (!vhlo.tensor_v1<5x1x5x!vhlo.f32_v1>, !vhlo.tensor_v1<2x!vhlo.i32_v1>) -> !vhlo.tensor_v1<2x1x5x!vhlo.f32_v1> -+ %0 = "stablehlo.torch_index_select"(%arg0, %arg1) { -+ dim = 0 : i64, -+ batch_dims = 0 : i64 -+ } : (tensor<5x1x5xf32>, tensor<2xi32>) -> tensor<2x1x5xf32> -+ func.return %0 : tensor<2x1x5xf32> -+} -+ -+// CHECK-LABEL: "op_transpose" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) -+func.func @op_transpose(%arg0: tensor<16x8xf32>) -> tensor<8x16xf32> { -+ // CHECK: "vhlo.transpose_v1"(%[[ARG0]]) <{ -+ // CHECK-SAME: permutation = #vhlo.tensor_v1 : tensor<2xi64>> -+ // CHECK-SAME: }> : (!vhlo.tensor_v1<16x8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<8x16x!vhlo.f32_v1> -+ %0 = "stablehlo.transpose"(%arg0) { -+ permutation = array -+ } : (tensor<16x8xf32>) -> tensor<8x16xf32> -+ func.return %0 : tensor<8x16xf32> -+} -+ -+// CHECK-LABEL: "op_triangular_solve" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) -+func.func @op_triangular_solve(%arg0: tensor<16x16xf32>, %arg1: tensor<16x16xf32>) -> tensor<16x16xf32> { -+ // CHECK: "vhlo.triangular_solve_v1"(%[[ARG0]], %[[ARG1]]) <{ -+ // CHECK-SAME: left_side = #vhlo.bool_v1, -+ // CHECK-SAME: lower = #vhlo.bool_v1, -+ // CHECK-SAME: transpose_a = #vhlo, -+ // CHECK-SAME: unit_diagonal = #vhlo.bool_v1 -+ // CHECK-SAME: }> : (!vhlo.tensor_v1<16x16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x16x!vhlo.f32_v1> -+ %0 = "stablehlo.triangular_solve"(%arg0, %arg1) { -+ left_side = true, -+ lower = true, -+ unit_diagonal = true, -+ transpose_a = #stablehlo -+ } : (tensor<16x16xf32>, tensor<16x16xf32>) -> tensor<16x16xf32> -+ func.return %0 : tensor<16x16xf32> -+} -+ -+// CHECK-LABEL: "op_tuple" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) -+func.func @op_tuple(%arg0: tensor) -> tuple> { -+ // CHECK: "vhlo.tuple_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tuple_v1> -+ %0 = "stablehlo.tuple"(%arg0) : (tensor) -> tuple> -+ func.return %0 : tuple> -+} -+ -+// CHECK-LABEL: "op_unary_einsum" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) -+func.func @op_unary_einsum(%arg0: tensor<8x16xf32>) -> tensor<8xf32> { -+ // CHECK: "vhlo.unary_einsum_v1"(%[[ARG0]]) <{ -+ // CHECK-SAME: einsum_config = #vhlo.string_v1<"ab->a"> -+ // CHECK-SAME: }> : (!vhlo.tensor_v1<8x16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<8x!vhlo.f32_v1> -+ %0 = "stablehlo.unary_einsum"(%arg0) { -+ einsum_config = "ab->a" -+ } : (tensor<8x16xf32>) -> tensor<8xf32> -+ func.return %0 : tensor<8xf32> -+} -+ -+// CHECK-LABEL: "op_uniform_dequantize" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) -+func.func @op_uniform_dequantize(%arg0: tensor>) -> tensor { -+ // CHECK: "vhlo.uniform_dequantize_v1"(%[[ARG0]]) : (!vhlo.tensor_v1>) -> !vhlo.tensor_v1 -+ %0 = "stablehlo.uniform_dequantize"(%arg0) : (tensor>) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "op_uniform_quantize" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) -+func.func @op_uniform_quantize(%arg0: tensor) -> tensor> { -+ // CHECK: "vhlo.uniform_quantize_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1> -+ %0 = "stablehlo.uniform_quantize"(%arg0) : (tensor) -> tensor> -+ func.return %0 : tensor> -+} -+ -+// CHECK-LABEL: "op_while" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) -+func.func @op_while(%arg0: tensor) -> tensor { -+ // CHECK: "vhlo.while_v1"(%[[ARG0]]) ({ -+ // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1): -+ // CHECK-NEXT: "vhlo.return_v1"(%[[ARG1]]) : (!vhlo.tensor_v1) -> () -+ // CHECK-NEXT: }, { -+ // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1) -+ // CHECK-NEXT: "vhlo.return_v1"(%[[ARG1]]) : (!vhlo.tensor_v1) -> () -+ // CHECK-NEXT: }) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ %0 = "stablehlo.while"(%arg0) ({ -+ ^bb0(%arg1: tensor): -+ "stablehlo.return"(%arg1) : (tensor) -> () -+ }, { -+ ^bb0(%arg1: tensor): -+ "stablehlo.return"(%arg1) : (tensor) -> () -+ }) : (tensor) -> tensor -+ func.return %0: tensor -+} -+ -+// CHECK-LABEL: "op_xor" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) -+func.func @op_xor(%arg0: tensor, %arg1: tensor) -> tensor { -+ // CHECK: "vhlo.xor_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ %0 = "stablehlo.xor"(%arg0, %arg1) : (tensor, tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// ============ TYPES ============ -+ -+// CHECK-LABEL: "type_i1" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) -+func.func @type_i1(%arg0: tensor, %arg1: tensor) -> tensor { -+ // CHECK: "vhlo.and_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ %0 = "stablehlo.and"(%arg0, %arg1) : (tensor, tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "type_i2" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) -+func.func @type_i2(%arg0: tensor, %arg1: tensor) -> tensor { -+ // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "type_i4" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) -+func.func @type_i4(%arg0: tensor, %arg1: tensor) -> tensor { -+ // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "type_i8" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) -+func.func @type_i8(%arg0: tensor, %arg1: tensor) -> tensor { -+ // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "type_i16" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) -+func.func @type_i16(%arg0: tensor, %arg1: tensor) -> tensor { -+ // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "type_i32" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) -+func.func @type_i32(%arg0: tensor, %arg1: tensor) -> tensor { -+ // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "type_i64" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) -+func.func @type_i64(%arg0: tensor, %arg1: tensor) -> tensor { -+ // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "type_ui2" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) -+func.func @type_ui2(%arg0: tensor, %arg1: tensor) -> tensor { -+ // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "type_ui4" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) -+func.func @type_ui4(%arg0: tensor, %arg1: tensor) -> tensor { -+ // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "type_ui8" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) -+func.func @type_ui8(%arg0: tensor, %arg1: tensor) -> tensor { -+ // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "type_ui16" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) -+func.func @type_ui16(%arg0: tensor, %arg1: tensor) -> tensor { -+ // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "type_ui32" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) -+func.func @type_ui32(%arg0: tensor, %arg1: tensor) -> tensor { -+ // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "type_ui64" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) -+func.func @type_ui64(%arg0: tensor, %arg1: tensor) -> tensor { -+ // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "type_f4E2M1FN" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) -+func.func @type_f4E2M1FN(%arg0: tensor, %arg1: tensor) -> tensor { -+ // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "type_f6E2M3FN" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) -+func.func @type_f6E2M3FN(%arg0: tensor, %arg1: tensor) -> tensor { -+ // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "type_f6E3M2FN" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) -+func.func @type_f6E3M2FN(%arg0: tensor, %arg1: tensor) -> tensor { -+ // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "type_f8E3M4" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) -+func.func @type_f8E3M4(%arg0: tensor, %arg1: tensor) -> tensor { -+ // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "type_f8E4M3" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) -+func.func @type_f8E4M3(%arg0: tensor, %arg1: tensor) -> tensor { -+ // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "type_f8E4M3FN" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) -+func.func @type_f8E4M3FN(%arg0: tensor, %arg1: tensor) -> tensor { -+ // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "type_f8E5M2" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) -+func.func @type_f8E5M2(%arg0: tensor, %arg1: tensor) -> tensor { -+ // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "type_f8E4M3FNUZ" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) -+func.func @type_f8E4M3FNUZ(%arg0: tensor, %arg1: tensor) -> tensor { -+ // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "type_f8E4M3B11FNUZ" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) -+func.func @type_f8E4M3B11FNUZ(%arg0: tensor, %arg1: tensor) -> tensor { -+ // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "type_f8E5M2FNUZ" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) -+func.func @type_f8E5M2FNUZ(%arg0: tensor, %arg1: tensor) -> tensor { -+ // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "type_f8E8M0FNU" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) -+func.func @type_f8E8M0FNU(%arg0: tensor, %arg1: tensor) -> tensor { -+ // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "type_bf16" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) -+func.func @type_bf16(%arg0: tensor, %arg1: tensor) -> tensor { -+ // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "type_f16" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) -+func.func @type_f16(%arg0: tensor, %arg1: tensor) -> tensor { -+ // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "type_f32" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) -+func.func @type_f32(%arg0: tensor, %arg1: tensor) -> tensor { -+ // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "type_f64" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) -+func.func @type_f64(%arg0: tensor, %arg1: tensor) -> tensor { -+ // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "type_complex_f32" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) -+func.func @type_complex_f32(%arg0: tensor>, %arg1: tensor>) -> tensor> { -+ // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1>, !vhlo.tensor_v1>) -> !vhlo.tensor_v1> -+ %0 = "stablehlo.add"(%arg0, %arg1) : (tensor>, tensor>) -> tensor> -+ func.return %0 : tensor> -+} -+ -+// CHECK-LABEL: "type_complex_f64" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) -+func.func @type_complex_f64(%arg0: tensor>, %arg1: tensor>) -> tensor> { -+ // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1>, !vhlo.tensor_v1>) -> !vhlo.tensor_v1> -+ %0 = "stablehlo.add"(%arg0, %arg1) : (tensor>, tensor>) -> tensor> -+ func.return %0 : tensor> -+} -+ -+// CHECK-LABEL: "type_tf32" -+// CHECK: #vhlo.type_v1 -+func.func @type_tf32() attributes {stablehlo.attr = tf32 } { -+ return -+} -+ -+// CHECK-LABEL: "type_none" -+// CHECK: #vhlo.type_v1 -+func.func @type_none() attributes {stablehlo.attr = none } { -+ return -+} -+ -+// CHECK-LABEL: "type_dynamism_ranked" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) -+func.func @type_dynamism_ranked(%arg0: tensor) -> tensor { -+ // CHECK: "vhlo.abs_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ %0 = "stablehlo.abs"(%arg0) : (tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// CHECK-LABEL: "type_per_tensor_quantization" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) -+func.func @type_per_tensor_quantization(%arg0: tensor>, %arg1: tensor>) -> tensor> { -+ // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1>, !vhlo.tensor_v1>) -> !vhlo.tensor_v1> -+ %0 = "stablehlo.add"(%arg0, %arg1) : (tensor>, tensor>) -> tensor> -+ func.return %0 : tensor> -+} -+ -+// CHECK-LABEL: "type_per_axis_quantization" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) -+func.func @type_per_axis_quantization(%arg0: tensor<2x!quant.uniform>) -> tensor<2x!quant.uniform> { -+ // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG0]]) : (!vhlo.tensor_v1<2x!vhlo.quant_per_axis_v1>, !vhlo.tensor_v1<2x!vhlo.quant_per_axis_v1>) -> !vhlo.tensor_v1<2x!vhlo.quant_per_axis_v1> -+ %0 = stablehlo.add %arg0, %arg0 : tensor<2x!quant.uniform> -+ func.return %0 : tensor<2x!quant.uniform> -+} -+ -+// CHECK: function_type = #vhlo.type_v1 !vhlo.token_v1>> -+// CHECK-LABEL: "type_token_callee" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) -+func.func @type_token_callee(%arg0: !stablehlo.token) -> !stablehlo.token { -+ // CHECK: "vhlo.return_v1"(%[[ARG0]]) : (!vhlo.token_v1) -> () -+ return %arg0 : !stablehlo.token -+} -+ -+// CHECK: function_type = #vhlo.type_v1 !vhlo.token_v1>> -+// CHECK-LABEL: "type_token_caller" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) -+func.func @type_token_caller(%arg0: !stablehlo.token) -> !stablehlo.token { -+ // CHECK: "vhlo.call_v1"(%[[ARG0]]) <{callee = #vhlo.string_v1<"type_token_callee">} -+ // CHECK-SAME: (!vhlo.token_v1) -> !vhlo.token_v1 -+ %0 = func.call @type_token_callee(%arg0) : (!stablehlo.token) -> !stablehlo.token -+ return %0 : !stablehlo.token -+} -+ -+// CHECK-LABEL: "type_tuple" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) -+func.func @type_tuple(%arg0: tuple>) -> tuple { -+ %0 = "stablehlo.custom_call"(%arg0) { -+ call_target_name = "foo" -+ // CHECK: (!vhlo.tuple_v1>) -> !vhlo.tuple_v1 -+ } : (tuple>) -> tuple -+ return %0 : tuple -+} -+ -+// ============ DEPENDENCIES ============ -+ -+func.func @composite_target(%arg0: tensor) -> tensor { -+ return %arg0: tensor -+} -diff --ruN a/stablehlo/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.mlir b/stablehlo/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.mlir ---- stablehlo/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.mlir -+++ stablehlo/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.mlir -@@ -248,6 +248,36 @@ - fft_length = array - } : (tensor<9xcomplex>) -> tensor<16xf32> - func.return %0 : tensor<16xf32> -+} -+ -+// CHECK-LABEL: "attr_result_accuracy_HIGHEST" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}} -+func.func @attr_result_accuracy_HIGHEST(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> { -+ %0 = "stablehlo.exponential"(%arg0) { -+ // CHECK: result_accuracy = #vhlo.result_accuracy_v1> -+ result_accuracy = #stablehlo.result_accuracy> -+ } : (tensor<8x16xf32>) -> tensor<8x16xf32> -+ func.return %0 : tensor<8x16xf32> -+} -+ -+// CHECK-LABEL: "attr_result_accuracy_TOLERANCE" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}} -+func.func @attr_result_accuracy_TOLERANCE(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> { -+ %0 = "stablehlo.exponential"(%arg0) { -+ // CHECK: result_accuracy = #vhlo.result_accuracy_v1> -+ result_accuracy = #stablehlo.result_accuracy> -+ } : (tensor<8x16xf32>) -> tensor<8x16xf32> -+ func.return %0 : tensor<8x16xf32> -+} -+ -+// CHECK-LABEL: "attr_result_accuracy_DEFAULT" -+// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}} -+func.func @attr_result_accuracy_DEFAULT(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> { -+ %0 = "stablehlo.exponential"(%arg0) { -+ // CHECK: result_accuracy = #vhlo.result_accuracy_v1> -+ result_accuracy = #stablehlo.result_accuracy> -+ } : (tensor<8x16xf32>) -> tensor<8x16xf32> -+ func.return %0 : tensor<8x16xf32> - } - - // GatherDimensionNumbers aka #stablehlo.gather is covered below. -@@ -1621,7 +1651,7 @@ - // CHECK-LABEL: "op_exponential" - // CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) - func.func @op_exponential(%arg0: tensor) -> tensor { -- // CHECK: "vhlo.exponential_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 -+ // CHECK: "vhlo.exponential_v2"(%[[ARG0]]) <{result_accuracy = #vhlo.result_accuracy_v1>}> : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 - %0 = "stablehlo.exponential"(%arg0) : (tensor) -> tensor - func.return %0 : tensor - } -diff --ruN a/stablehlo/stablehlo/tests/vhlo/vhlo_attributes_invalid.mlir b/stablehlo/stablehlo/tests/vhlo/vhlo_attributes_invalid.mlir ---- stablehlo/stablehlo/tests/vhlo/vhlo_attributes_invalid.mlir -+++ stablehlo/stablehlo/tests/vhlo/vhlo_attributes_invalid.mlir -@@ -0,0 +1,26 @@ -+// RUN: stablehlo-opt --vhlo-to-version=target=1.9.0 -verify-diagnostics --split-input-file %s -+ -+func.func @invalid_array_element() -> () attributes { -+ // expected-error @+1 {{expected array of VHLO attriutes}} -+ vhlo.attr = #vhlo.array_v1<[#stablehlo]> -+} { -+ return -+} -+ -+// ----- -+ -+func.func @invalid_dict_element_value() -> () attributes { -+ // expected-error @+1 {{expected VHLO attribute}} -+ vhlo.attr = #vhlo.dict_v1<{#vhlo.string_v1<"attr1"> = 3 : i32}> -+} { -+ return -+} -+ -+// ----- -+ -+func.func @invalid_result_accuracy() -> () attributes { -+ // expected-error @+1 {{expected VHLO result accuracy mode}} -+ vhlo.attr = #vhlo.result_accuracy_v1> -+} { -+ return -+} -diff --ruN a/stablehlo/stablehlo/tests/vhlo/vhlo_to_version_downgrade.1_8_0.mlir b/stablehlo/stablehlo/tests/vhlo/vhlo_to_version_downgrade.1_8_0.mlir ---- stablehlo/stablehlo/tests/vhlo/vhlo_to_version_downgrade.1_8_0.mlir -+++ stablehlo/stablehlo/tests/vhlo/vhlo_to_version_downgrade.1_8_0.mlir -@@ -0,0 +1,24 @@ -+// RUN: stablehlo-opt --stablehlo-legalize-to-vhlo --vhlo-to-version='target=1.8.0' %s | FileCheck %s -+ -+// ExpOp was changed in v1.9.0 to have -+// result_accuracy attribute. Ensure that serializing for 1.8.0 is valid and targets the -+// v1.8.0 opset. -+// -+// This will catch issues in op `isLegal` checks: -+// op.minVersion() <= target <= op.maxVersion() -+ -+// CHECK-LABEL: vhlo.func_v1 @exp_op -+func.func public @exp_op(%arg0: tensor) -> tensor { -+ // CHECK: vhlo.exponential_v1 -+ %0 = "stablehlo.exponential"(%arg0) : (tensor) -> tensor -+ return %0 : tensor -+} -+ -+// CHECK-LABEL: vhlo.func_v1 @exp_op_default -+func.func @exp_op_default(%arg0: tensor) -> tensor { -+ %0 = "stablehlo.exponential"(%arg0) { -+ // CHECK: vhlo.exponential_v1 -+ result_accuracy = #stablehlo.result_accuracy> -+ } : (tensor) -> tensor -+ func.return %0 : tensor -+} -diff --ruN a/stablehlo/stablehlo/tests/vhlo/vhlo_to_version_downgrade_invalid.1_8_0.mlir b/stablehlo/stablehlo/tests/vhlo/vhlo_to_version_downgrade_invalid.1_8_0.mlir ---- stablehlo/stablehlo/tests/vhlo/vhlo_to_version_downgrade_invalid.1_8_0.mlir -+++ stablehlo/stablehlo/tests/vhlo/vhlo_to_version_downgrade_invalid.1_8_0.mlir -@@ -0,0 +1,22 @@ -+// RUN: stablehlo-opt --stablehlo-legalize-to-vhlo --vhlo-to-version='target=1.8.0' --verify-diagnostics --split-input-file %s -+ -+ -+func.func @attr_result_accuracy_default(%arg0: tensor) -> tensor { -+ %0 = "stablehlo.exponential"(%arg0) { -+ // CHECK: vhlo.exponential_v1 -+ result_accuracy = #stablehlo.result_accuracy> -+ } : (tensor) -> tensor -+ func.return %0 : tensor -+} -+ -+// ----- -+ -+// expected-error @-3 {{failed to convert VHLO to v1.8.0}} -+func.func @attr_result_accuracy_highest(%arg0: tensor) -> tensor { -+ // expected-error @+1 {{failed to legalize operation 'vhlo.exponential_v2' that was explicitly marked illegal}} -+ %0 = "stablehlo.exponential"(%arg0) { -+ result_accuracy = #stablehlo.result_accuracy> -+ } : (tensor) -> tensor -+ func.return %0 : tensor -+} -+ -diff --ruN a/stablehlo/stablehlo/transforms/MapStablehloToVhlo.h b/stablehlo/stablehlo/transforms/MapStablehloToVhlo.h ---- stablehlo/stablehlo/transforms/MapStablehloToVhlo.h -+++ stablehlo/stablehlo/transforms/MapStablehloToVhlo.h -@@ -94,7 +94,7 @@ - MAP_STABLEHLO_TO_VHLO(DynamicUpdateSliceOp, V1) - MAP_STABLEHLO_TO_VHLO(EinsumOp, V1) - MAP_STABLEHLO_TO_VHLO(Expm1Op, V1) --MAP_STABLEHLO_TO_VHLO(ExpOp, V1) -+MAP_STABLEHLO_TO_VHLO(ExpOp, V2) - MAP_STABLEHLO_TO_VHLO(FftOp, V1) - MAP_STABLEHLO_TO_VHLO(FloorOp, V1) - MAP_STABLEHLO_TO_VHLO(GatherOp, V2) -diff --ruN a/stablehlo/stablehlo/transforms/Passes.h b/stablehlo/stablehlo/transforms/Passes.h ---- stablehlo/stablehlo/transforms/Passes.h -+++ stablehlo/stablehlo/transforms/Passes.h -@@ -25,12 +25,17 @@ - #include "mlir/Pass/Pass.h" - #include "mlir/Support/LogicalResult.h" - #include "mlir/Transforms/DialectConversion.h" -+#include "mlir/Transforms/GreedyPatternRewriteDriver.h" - #include "stablehlo/dialect/Version.h" - - namespace mlir { - namespace stablehlo { - - #define GEN_PASS_DECL -+ -+std::unique_ptr<::mlir::Pass> createStablehloAggressiveSimplificationPass( -+ GreedyRewriteConfig config); -+ - #define GEN_PASS_REGISTRATION - #include "stablehlo/transforms/Passes.h.inc" - -diff --ruN a/stablehlo/stablehlo/transforms/StablehloAggressiveSimplification.cpp b/stablehlo/stablehlo/transforms/StablehloAggressiveSimplification.cpp ---- stablehlo/stablehlo/transforms/StablehloAggressiveSimplification.cpp -+++ stablehlo/stablehlo/transforms/StablehloAggressiveSimplification.cpp -@@ -11,6 +11,7 @@ - #include - #include - #include -+#include - #include - #include - -@@ -21,6 +22,7 @@ - #include "llvm/ADT/SmallVector.h" - #include "llvm/ADT/SmallVectorExtras.h" - #include "llvm/Support/ErrorHandling.h" -+#include "mlir/Dialect/Arith/IR/Arith.h" - #include "mlir/IR/Attributes.h" - #include "mlir/IR/Block.h" - #include "mlir/IR/Builders.h" -@@ -38,6 +40,7 @@ - #include "mlir/IR/TypeUtilities.h" - #include "mlir/IR/Value.h" - #include "mlir/IR/ValueRange.h" -+#include "mlir/Pass/Pass.h" - #include "mlir/Rewrite/FrozenRewritePatternSet.h" - #include "mlir/Support/LLVM.h" - #include "mlir/Support/LogicalResult.h" -@@ -1447,12 +1450,18 @@ - return rewriter.notifyMatchFailure( - op, "defining operation of unexpected type"); - -+ // Reshape and broadcast are not allowed to have dynamic shape. -+ Value result = op->getResult(0); -+ if (isa(definingOp) && -+ !cast(result.getType()).hasStaticShape()) -+ return rewriter.notifyMatchFailure( -+ op, "cannot reorder around reshape/broadcast with dynamic shape"); -+ - // Only reorder if the defining op has no other uses. - if (!llvm::hasSingleElement(definingOp->getResult(0).getUses())) - return rewriter.notifyMatchFailure(op, "operation has more than one use"); - - Value input = definingOp->getOperand(0); -- Value result = op->getResult(0); - auto intermediateType = cast(input.getType()) - .clone(getElementTypeOrSelf(result.getType())); - -@@ -1470,6 +1479,9 @@ - struct StablehloAggressiveSimplificationPass final - : impl::StablehloAggressiveSimplificationPassBase< - StablehloAggressiveSimplificationPass> { -+ StablehloAggressiveSimplificationPass() = default; -+ StablehloAggressiveSimplificationPass(GreedyRewriteConfig config) -+ : config(config) {} - LogicalResult initialize(MLIRContext *context) override { - RewritePatternSet patterns_(context); - populateStablehloCanonicalizationPatterns(context, &patterns_); -@@ -1478,11 +1490,12 @@ - } - - void runOnOperation() override { -- if (failed(applyPatternsGreedily(getOperation(), patterns))) -+ if (failed(applyPatternsGreedily(getOperation(), patterns, config))) - signalPassFailure(); - } - - private: -+ GreedyRewriteConfig config; - FrozenRewritePatternSet patterns; - }; - -@@ -1515,5 +1528,10 @@ - DynamicReshapeOpIsStatic, DynamicIotaIsStatic>(context); - } - -+std::unique_ptr createStablehloAggressiveSimplificationPass( -+ GreedyRewriteConfig config) { -+ return std::make_unique(config); -+} -+ - } // namespace stablehlo - } // namespace mlir -diff --ruN a/stablehlo/stablehlo/transforms/StablehloComplexMathExpanderPatterns.td b/stablehlo/stablehlo/transforms/StablehloComplexMathExpanderPatterns.td ---- stablehlo/stablehlo/transforms/StablehloComplexMathExpanderPatterns.td -+++ stablehlo/stablehlo/transforms/StablehloComplexMathExpanderPatterns.td -@@ -683,12 +683,15 @@ - // Notice that for `y != 0`, neither `cos(y)` nor `sin(y)` is never - // zero on the set of floating point numbers. - // --def ExpOp_ComplexElementType_ComplexMathExpander: Pat<(StableHLO_ExpOp ComplexElementType:$z), -+def ConstDefaultResultAccuracyAttr : -+ ConstantAttr; -+ -+def ExpOp_ComplexElementType_ComplexMathExpander: Pat<(StableHLO_ExpOp ComplexElementType:$z, ConstDefaultResultAccuracyAttr), - (StableHLO_ComplexOp - (StableHLO_SelectOp - (StableHLO_CompareOp:$eq_e_constant_posinf - (StableHLO_ExpOp:$e -- (StableHLO_RealOp:$x $z)), -+ (StableHLO_RealOp:$x $z), ConstDefaultResultAccuracyAttr), - (StableHLO_ConstantLikePosInfValue $x), - StableHLO_ComparisonDirectionValue<"EQ">, - (STABLEHLO_DEFAULT_COMPARISON_TYPE)), -@@ -697,7 +700,7 @@ - (StableHLO_ExpOp:$e2 - (StableHLO_MulOp - $x, -- (StableHLO_ConstantLike<"0.5"> $x))), -+ (StableHLO_ConstantLike<"0.5"> $x)), ConstDefaultResultAccuracyAttr), - (StableHLO_CosineOp:$cs - (StableHLO_ImagOp:$y $z))), - $e2), -diff --ruN a/stablehlo/stablehlo/transforms/StablehloLegalizeToVhlo.cpp b/stablehlo/stablehlo/transforms/StablehloLegalizeToVhlo.cpp ---- stablehlo/stablehlo/transforms/StablehloLegalizeToVhlo.cpp -+++ stablehlo/stablehlo/transforms/StablehloLegalizeToVhlo.cpp -@@ -19,6 +19,7 @@ - - #include "llvm/Support/Casting.h" - #include "llvm/Support/Debug.h" -+#include "llvm/Support/ErrorHandling.h" - #include "mlir/Dialect/Func/IR/FuncOps.h" - #include "mlir/IR/Attributes.h" - #include "mlir/IR/Builders.h" -@@ -129,6 +130,16 @@ - } - if (auto attr = dyn_cast(stablehloAttr)) { - RETURN_CONVERTED_ENUM_ATTR(Transpose, V1); -+ } -+ if (auto attr = dyn_cast(stablehloAttr)) { -+ RETURN_CONVERTED_ENUM_ATTR(ResultAccuracyMode, V1); -+ } -+ if (auto attr = dyn_cast(stablehloAttr)) { -+ auto modeAttr = convertGeneric(attr.getMode(), typeConverter); -+ if (!modeAttr) return {}; -+ return vhlo::ResultAccuracyV1Attr::get(attr.getContext(), attr.getAtol(), -+ attr.getRtol(), attr.getUlps(), -+ modeAttr); - } - if (stablehloAttr.getDialect().getNamespace() == - stablehlo::StablehloDialect::getDialectNamespace()) { -@@ -815,6 +826,19 @@ - } - } - } -+ if constexpr (std::is_same::value) { -+ if (!stablehloOp.getResultAccuracyAttr()) -+ addDefaultAttr("result_accuracy", -+ stablehlo::ResultAccuracyAttr::get( -+ pattern.getContext(), -+ /*atol=*/APFloat(0.0), -+ /*rtol=*/APFloat(0.0), -+ /*ulps=*/0, -+ /*mode=*/ -+ stablehlo::ResultAccuracyModeAttr::get( -+ pattern.getContext(), -+ stablehlo::ResultAccuracyMode::DEFAULT))); -+ } - if constexpr (std::is_same::value) { - if (!stablehloOp.getKnownExpandingDimensionsAttr()) -diff --ruN a/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp b/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp ---- stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp -+++ stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp -@@ -109,12 +109,14 @@ - // their operands and results. Any operand type in these ops can change - // within what's supported by `inferMostSpecificType` without breaking - // verification of the op. -- if (isa(user->getDialect())) -+ if (isa( -+ user->getDialect())) - continue; - // TODO(bartchr): Consider if the dialect allow-listing approach is too - // strict. In the meantime, allow some shape interop with the shardy - // dialect. -- if (user->getDialect()->getNamespace() == "sdy") continue; -+ if (user->getDialect()->getNamespace() == "sdy") -+ continue; - - // Simply changing operand type of `func.return` won't work because - // that won't update the FunctionType of the enclosing `func.func`. -diff --ruN a/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp ---- stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp -+++ stablehlo/stablehlo/transforms/VhloLegalizeToStablehlo.cpp -@@ -23,6 +23,7 @@ - #include "llvm/Support/AllocatorBase.h" - #include "llvm/Support/Casting.h" - #include "llvm/Support/Debug.h" -+#include "llvm/Support/ErrorHandling.h" - #include "mlir/Dialect/Func/IR/FuncOps.h" - #include "mlir/IR/Attributes.h" - #include "mlir/IR/BuiltinAttributes.h" -@@ -168,6 +169,17 @@ - auto builtinType = typeConverter->convertType(attr.getValue()); - if (!builtinType) return {}; - return TypeAttr::get(builtinType); -+ } -+ if (auto attr = dyn_cast(vhloAttr)) { -+ RETURN_CONVERTED_ENUM_ATTR(ResultAccuracyMode, V1); -+ } -+ if (auto attr = dyn_cast(vhloAttr)) { -+ auto modeAttr = dyn_cast_or_null( -+ convertGeneric(attr.getMode(), typeConverter)); -+ if (!modeAttr) return {}; -+ return stablehlo::ResultAccuracyAttr::get(attr.getContext(), attr.getAtol(), -+ attr.getRtol(), attr.getUlps(), -+ modeAttr); - } - - // All VHLO Attributes must be converted by now. -@@ -737,6 +749,13 @@ - }); - } - -+bool isDefaultResultAccuracyAttribute(Attribute vhloAttr) { -+ auto attr = dyn_cast_or_null(vhloAttr); -+ return attr.getAtol().isZero() && attr.getRtol().isZero() && -+ attr.getUlps() == 0 && -+ dyn_cast(attr.getMode()).getValue() == -+ vhlo::ResultAccuracyModeV1::DEFAULT; -+} - template - bool isSplatTensor(const ConversionPattern& pattern, Attribute vhloAttr, - T splatValue) { -@@ -897,6 +916,11 @@ - eraseAttrs(vhloAttrs, "dimension"); - if (isBoolean(vhloOp.getIsStableAttr(), false)) - eraseAttrs(vhloAttrs, "is_stable"); -+ } -+ if constexpr (std::is_same::value) { -+ if (isDefaultResultAccuracyAttribute(vhloOp.getResultAccuracyAttr())) { -+ eraseAttrs(vhloAttrs, "result_accuracy"); -+ } - } - return success(); - } -diff --ruN a/stablehlo/stablehlo/transforms/VhloToVersion.cpp b/stablehlo/stablehlo/transforms/VhloToVersion.cpp ---- stablehlo/stablehlo/transforms/VhloToVersion.cpp -+++ stablehlo/stablehlo/transforms/VhloToVersion.cpp -@@ -139,6 +139,8 @@ - return isLegalType(tensorAttr.getType(), targetVersion); - if (auto typeAttr = dyn_cast(attr)) - return isLegalType(typeAttr.getValue(), targetVersion); -+ if (auto resultAccuracyAttr = dyn_cast(attr)) -+ return isLegalAttribute(resultAccuracyAttr.getMode(), targetVersion); - - // Is VHLO and valid version, success. - return success(); -@@ -324,6 +326,22 @@ - denseElements.getRawData()); - } - -+bool isDefaultResultAccuracy(Attribute attr) { -+ auto resultAccuracy = dyn_cast(attr); -+ auto default_mode = ResultAccuracyModeV1Attr::get( -+ attr.getContext(), ResultAccuracyModeV1::DEFAULT); -+ return resultAccuracy.getAtol().isZero() && -+ resultAccuracy.getRtol().isZero() && resultAccuracy.getUlps() == 0 && -+ resultAccuracy.getMode() == default_mode; -+} -+ -+ResultAccuracyV1Attr getDefaultResultAccuracy(OpBuilder& builder) { -+ return ResultAccuracyV1Attr::get( -+ builder.getContext(), APFloat(0.0), APFloat(0.0), 0, -+ ResultAccuracyModeV1Attr::get(builder.getContext(), -+ ResultAccuracyModeV1::DEFAULT)); -+} -+ - // DRR has limited support for ops with regions - struct ScatterOpV2ToV1 : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; -@@ -393,6 +411,40 @@ - } - }; - -+struct ExpOpV1ToV2 : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ -+ LogicalResult matchAndRewrite(ExpOpV1 op, -+ PatternRewriter& rewriter) const override { -+ ResultAccuracyV1Attr defaultResultAccuracy = ResultAccuracyV1Attr::get( -+ rewriter.getContext(), APFloat(0.0), APFloat(0.0), 0, -+ ResultAccuracyModeV1Attr::get(rewriter.getContext(), -+ ResultAccuracyModeV1::DEFAULT)); -+ rewriter.replaceOpWithNewOp( -+ op, op->getResultTypes(), op.getOperand(), defaultResultAccuracy); -+ return success(); -+ } -+}; -+ -+struct ExpOpV2ToV1 : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ -+ LogicalResult matchAndRewrite(ExpOpV2 op, -+ PatternRewriter& rewriter) const override { -+ auto defaultResultAccuracy = ResultAccuracyV1Attr::get( -+ rewriter.getContext(), APFloat(0.0), APFloat(0.0), 0, -+ ResultAccuracyModeV1Attr::get(rewriter.getContext(), -+ ResultAccuracyModeV1::DEFAULT)); -+ if (op.getResultAccuracy() != defaultResultAccuracy) { -+ return rewriter.notifyMatchFailure(op, -+ "non-default result accuracy attr"); -+ } -+ rewriter.replaceOpWithNewOp(op, op->getResultTypes(), -+ op.getOperand()); -+ return success(); -+ } -+}; -+ - #include "stablehlo/transforms/VhloToVersionPatterns.h.inc" - - } // namespace -@@ -405,6 +457,7 @@ - vhlo::populateWithGenerated(*patterns); - patterns->add(context); - patterns->add(context); -+ patterns->add(context); - } - - } // namespace stablehlo -diff --ruN a/stablehlo/stablehlo/transforms/VhloToVersionPatterns.td b/stablehlo/stablehlo/transforms/VhloToVersionPatterns.td ---- stablehlo/stablehlo/transforms/VhloToVersionPatterns.td -+++ stablehlo/stablehlo/transforms/VhloToVersionPatterns.td -@@ -15,6 +15,9 @@ - - include "mlir/IR/OpBase.td" - include "stablehlo/dialect/VhloOps.td" -+include "mlir/IR/CommonAttrConstraints.td" -+include "stablehlo/dialect/VhloEnums.td" -+include "stablehlo/dialect/VhloAttrs.td" - - def VHLO_GetEmptyDims : NativeCodeCall<"getEmptyI64Tensor($_builder)">; - -@@ -31,6 +34,11 @@ - def VHLO_GetFirstOperand : NativeCodeCall<"$0.front()">; - - def VHLO_WrapInVector : NativeCodeCall<"{$0}">; -+ -+def VHLO_GetDefaultResultAccuracyAttr : NativeCodeCall<"getDefaultResultAccuracy($_builder)">; -+ -+ -+def VHLO_DefaultResultAccuracy : AttrConstraint, "Default result accuracy">; - - def DynamicConvUpgradeV1ToV2: - Pat<(VHLO_DynamicConvOpV1 $lhs, $rhs, $d_padding, $window_strides, $padding, $lhs_dilation, $rhs_dilation, $window_reversal, $input_batch_dimension, $input_feature_dimension, $input_spatial_dimensions, $kernel_input_feature_dimension, $kernel_output_feature_dimension, $kernel_spatial_dimensions, $output_batch_dimension, $output_feature_dimension, $output_spatial_dimensions, $feature_group_count, $batch_group_count, $precision_config), -@@ -83,3 +91,11 @@ - Pat<(VHLO_DotGeneralOpV1 $lhs, $rhs, $lhs_batching_dimensions, $rhs_batching_dimensions, $lhs_contracting_dimensions, $rhs_contracting_dimensions, $precision_config), - (VHLO_DotGeneralOpV2 $lhs, $rhs, $lhs_batching_dimensions, $rhs_batching_dimensions, $lhs_contracting_dimensions, $rhs_contracting_dimensions, $precision_config, - (VHLO_GetNoneType), (VHLO_GetNoneType), (VHLO_GetNoneType), (VHLO_GetNoneType), (VHLO_GetNoneType), (VHLO_GetNoneType), (VHLO_GetNoneType))>; -+ -+def ExpOpDowngradeV2ToV1 : -+ Pat<(VHLO_ExpOpV2 $operand, VHLO_DefaultResultAccuracy:$result_accuracy), -+ (VHLO_ExpOpV1 $operand)>; -+ -+def ExpOpUpgradeV1ToV2 : -+ Pat<(VHLO_ExpOpV1 $operand), -+ (VHLO_ExpOpV2 $operand, (VHLO_GetDefaultResultAccuracyAttr))>; diff --git a/third_party/stablehlo/workspace.bzl b/third_party/stablehlo/workspace.bzl index ace3c825e4b13..c146ccae5ca13 100644 --- a/third_party/stablehlo/workspace.bzl +++ b/third_party/stablehlo/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): # LINT.IfChange - STABLEHLO_COMMIT = "c27ba678a712a401e4a6db75ec0ef9e1ce9e1777" - STABLEHLO_SHA256 = "af3ade86200a10ef75d816147db1b5151aa2788da99289ef133e49453aee3f14" + STABLEHLO_COMMIT = "48a1e14edc8219577fcad53de1924876f855f431" + STABLEHLO_SHA256 = "b35e16723afe3ea142c4fe6a44e56885985b28e1b036945f4c2d230e1a8907cb" # LINT.ThenChange(Google-internal path) tf_http_archive(