From 73023cbf6884d2dacc794a12f518a785a4c7d2d7 Mon Sep 17 00:00:00 2001 From: Mahmoud Abuzaina Date: Mon, 4 Mar 2024 13:25:16 -0800 Subject: [PATCH] Pattern matching for uniform_quantize/dequantize ops with convolution --- xla/service/cpu/cpu_compiler.cc | 36 +-- xla/service/cpu/onednn_convolution.cc | 108 +++++++- .../cpu/onednn_convolution_rewriter.cc | 245 +++++++++++++++++- xla/service/cpu/onednn_memory_util.cc | 20 +- xla/tests/onednn_convolution_test.cc | 138 ++++++++++ 5 files changed, 510 insertions(+), 37 deletions(-) diff --git a/xla/service/cpu/cpu_compiler.cc b/xla/service/cpu/cpu_compiler.cc index ba2ebb6db5346..4b89261aa1d4b 100644 --- a/xla/service/cpu/cpu_compiler.cc +++ b/xla/service/cpu/cpu_compiler.cc @@ -910,24 +910,6 @@ Status CpuCompiler::RunHloPassesAfterLayoutAssn( ? module->config().intra_op_parallelism_threads() : tsl::port::NumSchedulableCPUs(); -#if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3) - // AOT compiled code runs in single thread. - if (!is_aot_compile) { - // Run SimplifyFPConversions pass to simplify the BF16 pattern and make it - // easier to match. - pipeline.AddPass(); - pipeline.AddPass(); - pipeline.AddPass(max_parallelism, - compile_options.thread_pool); - // Run SimplifyFPConversions pass again to remove redundant Convert ops - // that may exist as a result of running OneDnnMatMulRewriter pass. - pipeline.AddPass(); - } -#endif // INTEL_MKL && ENABLE_ONEDNN_V3 - - // Add a fusion pass now that layout assignment is done. - pipeline.AddPass(); - // The LayoutAssignment pass may leave behind kCopy instructions which are // duplicate or NOPs, so remove them with algebraic simplification and CSE. // Run this to a fixed point. @@ -951,6 +933,24 @@ Status CpuCompiler::RunHloPassesAfterLayoutAssn( pipeline.AddPass(/*is_layout_sensitive=*/true); }(); +#if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3) + // AOT compiled code runs in single thread. 
+ if (!is_aot_compile) { + // Run SimplifyFPConversions pass to simplify the BF16 pattern and make it + // easier to match. + pipeline.AddPass(); + pipeline.AddPass(); + pipeline.AddPass(max_parallelism, + compile_options.thread_pool); + // Run SimplifyFPConversions pass again to remove redundant Convert ops + // that may exist as a result of running OneDnnMatMulRewriter pass. + pipeline.AddPass(); + } +#endif // INTEL_MKL && ENABLE_ONEDNN_V3 + + // Add a fusion pass now that layout assignment is done. + pipeline.AddPass(); + // Outline ops in the entry computation into calls to subcomputations. if (!is_aot_compile) { // Run ParallelTaskAssigner to assign parallel tasks to HLOs in module. diff --git a/xla/service/cpu/onednn_convolution.cc b/xla/service/cpu/onednn_convolution.cc index c6da68c550bf8..5e08b77876bd1 100644 --- a/xla/service/cpu/onednn_convolution.cc +++ b/xla/service/cpu/onednn_convolution.cc @@ -149,6 +149,17 @@ ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_OneDnnConvolution( MemrefInfo ker_minfo(args[arg_indx++]); MemrefInfo res_minfo(result); + memory::data_type res_dt = res_minfo.GetOneDnnDataType(); + bool quant_result = + (res_dt == memory::data_type::s8 || res_dt == memory::data_type::u8); + bool quant_operands = ker_minfo.GetOneDnnDataType() == memory::data_type::s8; + memory::data_type inp_dt = inp_minfo.GetOneDnnDataType(); + if (quant_operands) { + // Hybrid quantization is not currently supported. 
+ XLA_LIGHTWEIGHT_CHECK(inp_dt == memory::data_type::s8 || + inp_dt == memory::data_type::u8); + } + // Permute memory descriptors auto inp_md = inp_minfo.GetOneDnnMemDesc(); auto ker_md = ker_minfo.GetOneDnnMemDesc(); @@ -200,14 +211,94 @@ ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_OneDnnConvolution( << std::endl; } } - - XLA_LIGHTWEIGHT_CHECK(num_args == arg_indx); dnnl::primitive_attr attrs; if (post_ops.len() > 0) { - attrs.set_post_ops(post_ops); + attrs.set_post_ops(post_ops); + } + + auto src_scale_mem = memory(nullptr); + auto src_zp_mem = memory(nullptr); + auto wei_scale_mem = memory(nullptr); + auto wei_zp_mem = memory(nullptr); + auto dst_scale_mem = memory(nullptr); + auto dst_zp_mem = memory(nullptr); + + std::vector src_zp_vec(1); + std::vector dst_zp_vec(1); + std::vector dst_scale_vec(1); + + if (quant_operands) { + MemrefInfo src_scale_minfo(args[arg_indx++]); + MemrefInfo src_zp_minfo(args[arg_indx++]); + MemrefInfo wei_scale_minfo(args[arg_indx++]); + MemrefInfo wei_zp_minfo(args[arg_indx++]); + + auto src_scale_md = src_scale_minfo.GetOneDnnMemDesc(); + auto src_zp_md = src_zp_minfo.GetOneDnnMemDesc(); + auto wei_scale_md = wei_scale_minfo.GetOneDnnMemDesc(); + int wei_scale_size = wei_scale_md.get_dims()[0]; + auto wei_zp_md = wei_zp_minfo.GetOneDnnMemDesc(); + + // oneDNN only supports common scale/zp for src (no per-channel support). + XLA_LIGHTWEIGHT_CHECK(src_scale_md.get_dims()[0] == 1); + XLA_LIGHTWEIGHT_CHECK(src_zp_md.get_dims()[0] == + src_scale_md.get_dims()[0]); + + src_scale_mem = memory(src_scale_md, cpu_engine, src_scale_minfo.Data()); + int* src_zp_data = (int*)src_zp_minfo.Data(); + + // We need to negate the sign of the zp to get the original one because the + // hlo optimizer flips the zp sign in uniform_dequantize pattern. + // TODO (intel-tf): we need to do that based on some flag passed from the + // rewriter. 
+ src_zp_vec[0] = src_zp_data[0] * -1; + src_zp_mem = memory(src_zp_md, cpu_engine, src_zp_vec.data()); + wei_scale_mem = memory(wei_scale_md, cpu_engine, wei_scale_minfo.Data()); + wei_zp_mem = memory(wei_zp_md, cpu_engine, wei_zp_minfo.Data()); + + if (quant_result) { + MemrefInfo dst_scale_minfo(args[arg_indx++]); + MemrefInfo dst_zp_minfo(args[arg_indx++]); + + auto dst_scale_md = dst_scale_minfo.GetOneDnnMemDesc(); + auto dst_zp_md = dst_zp_minfo.GetOneDnnMemDesc(); + + // oneDNN only supports common scale/zp for dst (no per-channel support). + XLA_LIGHTWEIGHT_CHECK(dst_scale_md.get_dims()[0] == 1); + XLA_LIGHTWEIGHT_CHECK(dst_zp_md.get_dims()[0] == + dst_scale_md.get_dims()[0]); + + float* scale_data = (float*)dst_scale_minfo.Data(); + // We need to compute the reciprocal of scale to get the original one + // because the hlo optimizer changes it in uniform_quantize pattern. + // TODO (intel-tf): we need to do that based on some flag passed from the + // rewriter. + dst_scale_vec[0] = 1.0 / scale_data[0]; + if (dst_zp_md.get_data_type() == memory::data_type::f32) { + // oneDNN expects zp to be int32 not f32. + dst_zp_vec[0] = static_cast(((float*)dst_zp_minfo.Data())[0]); + } else { + dst_zp_vec[0] = ((int*)dst_zp_minfo.Data())[0]; + } + dst_scale_mem = memory(dst_scale_md, cpu_engine, dst_scale_vec.data()); + auto dst_zp_md_new = memory::desc( + dst_zp_md.get_dims(), memory::data_type::s32, memory::format_tag::x); + dst_zp_mem = memory(dst_zp_md_new, cpu_engine, dst_zp_vec.data()); + } + + attrs.set_scales_mask(DNNL_ARG_SRC, 0); + attrs.set_zero_points_mask(DNNL_ARG_SRC, 0); + const int wei_mask = (wei_scale_size == 1) ? 0 : (groups > 1) ? 
3 : 1; + attrs.set_scales_mask(DNNL_ARG_WEIGHTS, wei_mask); + if (quant_result) { + attrs.set_scales_mask(DNNL_ARG_DST, 0); + attrs.set_zero_points_mask(DNNL_ARG_DST, 0); + } } + XLA_LIGHTWEIGHT_CHECK(num_args == arg_indx); + memory::dims strides_dims = strds; memory::dims padding_dims_l = pad_l; memory::dims padding_dims_r = pad_r; @@ -240,6 +331,17 @@ ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_OneDnnConvolution( {DNNL_ARG_WEIGHTS, new_ker_mem}, {DNNL_ARG_BIAS, bias_mem}, {DNNL_ARG_DST, new_res_mem}}; + if (quant_operands) { + conv_args.insert( + {{DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, src_scale_mem}, + {DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC, src_zp_mem}, + {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, wei_scale_mem}}); + if (quant_result) { + conv_args.insert( + {{DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, dst_scale_mem}, + {DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_DST, dst_zp_mem}}); + } + } conv_prim.execute(onednn_stream, conv_args); diff --git a/xla/service/cpu/onednn_convolution_rewriter.cc b/xla/service/cpu/onednn_convolution_rewriter.cc index 90e0c82f8ca5d..baf8b7faf81e3 100644 --- a/xla/service/cpu/onednn_convolution_rewriter.cc +++ b/xla/service/cpu/onednn_convolution_rewriter.cc @@ -17,6 +17,7 @@ limitations under the License. 
#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/literal_util.h" #include "xla/service/cpu/backend_config.pb.h" #include "xla/service/cpu/onednn_memory_util.h" #include "xla/service/cpu/onednn_util.h" @@ -152,13 +153,165 @@ bool OneDnnConvolutionRewriter::ShouldRewrite(const HloInstruction* conv) { return true; } +inline bool S32ToF32(const HloInstruction* instr) { + return instr->shape().element_type() == F32 && + instr->operand(0)->shape().element_type() == S32; +} + +inline bool Int8ToS32(const HloInstruction* instr) { + auto input_type = instr->operand(0)->shape().element_type(); + return instr->shape().element_type() == S32 && + (input_type == S8 || input_type == U8); +} + +inline bool F32ToInt8(const HloInstruction* instr) { + auto output_type = instr->shape().element_type(); + return instr->operand(0)->shape().element_type() == F32 && + (output_type == S8 || output_type == U8); +} + +inline auto Int8ToS32Pattern(HloInstruction** quant_input) { + return m::Convert(m::Op(quant_input)).WithPredicate(Int8ToS32); +} + +inline auto AddZPPattern(HloInstruction** quant_input, HloInstruction** zp) { + return m::AddAnyOrder(Int8ToS32Pattern(quant_input), + m::Broadcast(m::Constant(zp))); +} + +inline auto S32ToF32WithOptionalAddZP(HloInstruction** quant_input, + HloInstruction** zp) { + return m::Convert(m::AnyOf(AddZPPattern(quant_input, zp), + Int8ToS32Pattern(quant_input))) + .WithPredicate(S32ToF32); +} + +template +auto OptionalCopyAndBitcastPattern(Pattern pattern, HloInstruction** copy, + HloInstruction** bitcast) { + return m::AnyOf(m::Copy(copy, m::Bitcast(bitcast, pattern)), + pattern); +} + +auto DequantizePattern(HloInstruction** quant_input, HloInstruction** scale, + HloInstruction** zp, HloInstruction** copy, + HloInstruction** bitcast) { + auto deq_pattern = m::AnyOf( + m::MultiplyAnyOrder(S32ToF32WithOptionalAddZP(quant_input, zp), + m::Broadcast(m::Constant(scale))), + 
S32ToF32WithOptionalAddZP(quant_input, zp)); + // Layout assignment pass may insert Transpose | Bitcast -> Copy pattern. + // For now we only handle Bitcast -> Copy assuming Transpose was replaced with + // Bitcast by AlgebraicSimplifier. + return OptionalCopyAndBitcastPattern(deq_pattern, copy, bitcast); +} + +auto OptionalAddZPAndMultiplyScale(HloInstruction** scale, HloInstruction** zp, + HloInstruction** input_custom_call) { + auto multiply_scale = + m::MultiplyAnyOrder(OneDnnConvolutionInstr(input_custom_call), + m::Broadcast(m::Constant(scale))); + auto add_zp = m::AddAnyOrder( + m::AnyOf(multiply_scale, + OneDnnConvolutionInstr(input_custom_call)), + m::Broadcast(m::Constant(zp))); + return m::AnyOf(add_zp, multiply_scale, + OneDnnConvolutionInstr(input_custom_call)); +} + +auto QuantizePattern(HloInstruction** scale, HloInstruction** zp, + HloInstruction** input_custom_call, + HloInstruction** clamp_min, HloInstruction** clamp_max) { + auto quant_pattern = + m::Convert( + m::Op() + .WithOpcode(HloOpcode::kRoundNearestEven) + .WithOperand(0, m::Clamp(m::Broadcast(m::Constant(clamp_min)), + OptionalAddZPAndMultiplyScale( + scale, zp, input_custom_call), + m::Broadcast(m::Constant(clamp_max))))) + .WithPredicate(F32ToInt8); + return quant_pattern; +} + +class OneDnnConvolutionRequantizeVisitor : public DfsHloRewriteVisitor { + public: + Status HandleCustomCall(HloInstruction* custom_call) override { + HloInstruction *conv = nullptr, *input_custom_call = nullptr, + *scale = nullptr, *zp = nullptr, *clamp_min = nullptr, + *clamp_max = nullptr; + if (Match(custom_call, OneDnnConvolutionInstr(&conv))) { + // Try to match the requantize case: + // onednn_custom_call[int8 in, f32 out] -> uniform_quantize_pattern -> + // onednn_custom_call[int8 in, f32 out]. 
+ // This will be replaced by + // onednn_custom_call[int8 in, int8 out] -> onednn_custom_call[int8 in, + // f32 out] + bool requant_conv = Match( + custom_call, + m::Op() + .WithOpcode(HloOpcode::kCustomCall) + .WithOperand(0, QuantizePattern(&scale, &zp, &input_custom_call, + &clamp_min, &clamp_max))); + if (requant_conv) { + if (input_custom_call != nullptr && scale != nullptr && zp != nullptr) { + std::vector requant_call_operands; + for (auto operand : input_custom_call->operands()) { + requant_call_operands.push_back(operand); + } + // Currently we don't pass clamp_min/clamp_max to the custom-call. + // We assume they have the default values which are the + // bounds of the range of the integer data type used. + requant_call_operands.push_back(scale); + requant_call_operands.push_back(zp); + auto requant_conv_call = + Cast(custom_call->AddInstruction( + input_custom_call->CloneWithNewOperands( + ShapeUtil::ChangeElementType( + input_custom_call->shape(), + custom_call->operands()[0]->shape().element_type()), + requant_call_operands))); + const int size = custom_call->operands().size(); + std::vector new_operands(size); + new_operands[0] = requant_conv_call; + for (int i = 1; i < size; ++i) { + new_operands[i] = custom_call->operands()[i]; + } + auto new_conv_call = Cast( + custom_call->AddInstruction(custom_call->CloneWithNewOperands( + custom_call->shape(), new_operands))); + TF_RETURN_IF_ERROR(ReplaceInstruction(custom_call, new_conv_call)); + } + } + } + return OkStatus(); + } +}; + class OneDnnConvolutionRewriterVisitor : public DfsHloRewriteVisitor { public: Status HandleConvolution(HloInstruction* conv) override { auto pattern = match::Op(&conv).WithOpcode(HloOpcode::kConvolution); + HloInstruction *quant_src = nullptr, *src_scale = nullptr, + *src_zp = nullptr, *quant_wei = nullptr, + *wei_scale = nullptr, *wei_zp = nullptr, *copy_src = nullptr, + *bitcast_src = nullptr, *copy_wei = nullptr, + *bitcast_wei = nullptr; + if (!Match(conv, pattern)) return 
OkStatus(); if (!OneDnnConvolutionRewriter::ShouldRewrite(conv)) return OkStatus(); + // Try to match uniform_dequantize_pattern -> convolution. + // This will be replaced with onednn_custom_call[int8 in, f32 out]. + bool quant_conv = Match( + conv, + m::Op(&conv) + .WithOpcode(HloOpcode::kConvolution) + .WithOperand(0, DequantizePattern(&quant_src, &src_scale, &src_zp, + &copy_src, &bitcast_src)) + .WithOperand(1, DequantizePattern(&quant_wei, &wei_scale, &wei_zp, + &copy_wei, &bitcast_wei))); + const Shape& conv_shape = conv->shape(); auto dims = conv->window().dimensions().size(); const ConvolutionDimensionNumbers& conv_ddata = @@ -206,10 +359,69 @@ class OneDnnConvolutionRewriterVisitor : public DfsHloRewriteVisitor { conv_ddata.output_spatial_dimensions()[i] + 1); } - HloInstruction* custom_call = - conv->AddInstruction(HloInstruction::CreateCustomCall( - output_shape, {conv->mutable_operand(0), conv->mutable_operand(1)}, - "__onednn$convolution")); + auto create_bitcast_copy = + [&](HloInstruction* quant_operand, HloInstruction* bitcast, + HloInstruction* copy, HloInstruction*& new_bitcast, + HloInstruction*& new_copy) { + auto quant_type = quant_operand->shape().element_type(); + new_bitcast = conv->AddInstruction(HloInstruction::CreateBitcast( + ShapeUtil::ChangeElementType(bitcast->shape(), quant_type), + quant_operand)); + new_copy = conv->AddInstruction(HloInstruction::CreateUnary( + ShapeUtil::ChangeElementType(copy->shape(), quant_type), + HloOpcode::kCopy, new_bitcast)); + }; + + HloInstruction* custom_call; + if (quant_conv) { + HloInstruction *new_bitcast_src, *new_copy_src, *new_bitcast_wei, + *new_copy_wei; + // We need to add bitcast and copy instructions that were in the original + // pattern after each operand, + if (copy_src != nullptr && bitcast_src != nullptr) { + create_bitcast_copy(quant_src, bitcast_src, copy_src, new_bitcast_src, + new_copy_src); + } else { + new_copy_src = quant_src; + } + if (copy_wei != nullptr && bitcast_wei != nullptr) { +
create_bitcast_copy(quant_wei, bitcast_wei, copy_wei, new_bitcast_wei, + new_copy_wei); + } else { + new_copy_wei = quant_wei; + } + + const float scale = 1.0; + const int zp = 0; + auto set_default_value = [&](HloInstruction*& instr, + T value) { + instr = conv->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(value))); + }; + // If scale and/or zp are not found. That means they had a default values + // that were optimized out. Hence, we set a default value of 1.0 and 0 for + // scale and zp respectively. + if (src_scale == nullptr) { + set_default_value(src_scale, scale); + } + if (src_zp == nullptr) { + set_default_value(src_zp, zp); + } + if (wei_scale == nullptr) { + set_default_value(wei_scale, scale); + } + if (wei_zp == nullptr) { + set_default_value(wei_zp, zp); + } + custom_call = conv->AddInstruction(HloInstruction::CreateCustomCall( + output_shape, + {new_copy_src, new_copy_wei, src_scale, src_zp, wei_scale, wei_zp}, + "__onednn$convolution")); + } else { + custom_call = conv->AddInstruction(HloInstruction::CreateCustomCall( + output_shape, {conv->mutable_operand(0), conv->mutable_operand(1)}, + "__onednn$convolution")); + } TF_RETURN_IF_ERROR(custom_call->set_backend_config(backend_config)); TF_RETURN_IF_ERROR(ReplaceInstruction(conv, custom_call)); @@ -239,10 +451,6 @@ class OneDnnConvolutionRewriterVisitor : public DfsHloRewriteVisitor { ->mutable_fusions()->ops(0) == OneDnnFusionConfig::BIAS) { return OkStatus(); } - std::vector new_operands; - for (auto operand : conv->operands()) { - new_operands.push_back(operand); - } HloInstruction* addend = nullptr; HloInstruction* optional_addend_broadcast = nullptr; @@ -255,8 +463,20 @@ class OneDnnConvolutionRewriterVisitor : public DfsHloRewriteVisitor { m::Op(&addend)); if (!Match(addend_intermediate, addend_pattern)) return OkStatus(); + // Make sure bias is always the third argument as opposed to adding to the + // end as the onednn_custom_call may have more than two operands (ex: 
+ // quantized custom_call). + std::vector new_operands(conv->operands().size() + 1); + int idx = 0; + const int kBiasIdx = 2; + for (auto operand : conv->operands()) { + // Skip bias index + if (idx == kBiasIdx) idx++; + new_operands[idx++] = operand; + } + if (CompatibleElementType(addend) && IsOperandFusible(addend, conv)) { - new_operands.push_back(addend); + new_operands[kBiasIdx] = addend; } else { return OkStatus(); } @@ -347,7 +567,12 @@ StatusOr OneDnnConvolutionRewriter::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { OneDnnConvolutionRewriterVisitor visitor; - return visitor.RunOnModule(module, execution_threads); + TF_ASSIGN_OR_RETURN(auto result, + visitor.RunOnModule(module, execution_threads)); + OneDnnConvolutionRequantizeVisitor visitor_requantize; + TF_ASSIGN_OR_RETURN(auto result_requantize, visitor_requantize.RunOnModule( + module, execution_threads)); + return (result || result_requantize); } } // namespace cpu diff --git a/xla/service/cpu/onednn_memory_util.cc b/xla/service/cpu/onednn_memory_util.cc index 372ce97c27893..e8ebc1d12d97d 100644 --- a/xla/service/cpu/onednn_memory_util.cc +++ b/xla/service/cpu/onednn_memory_util.cc @@ -74,14 +74,22 @@ MemrefInfoHandler CreateMemrefInfoFromLiteral(const Literal* literal) { StackAlloca GetAllocaAndEmitMemrefInfo(llvm::IRBuilder<>& builder, const llvm_ir::IrArray& ir_array) { const Shape& shape = ir_array.GetShape(); - int64_t rank = shape.rank(); - absl::Span dims = shape.dimensions(); + // oneDNN handles scalar as a vector of size 1. + int64_t rank = shape.rank() == 0 ? 1 : shape.rank(); + std::vector scalar_shape(1, 1); + absl::Span dims = + shape.dimensions().size() == 0 ? scalar_shape : shape.dimensions(); std::vector strides(rank); - int64_t stride = 1; - for (int i : shape.layout().minor_to_major()) { - strides.at(i) = stride; - stride *= dims.at(i); + if (shape.dimensions().size() == 0) { + // Scalar case. 
+ strides[0] = 1; + } else { + int64_t stride = 1; + for (int i : shape.layout().minor_to_major()) { + strides.at(i) = stride; + stride *= dims.at(i); + } } // Type of struct diff --git a/xla/tests/onednn_convolution_test.cc b/xla/tests/onednn_convolution_test.cc index 5a2fdcd77ebaa..3a24ee0c01c6e 100644 --- a/xla/tests/onednn_convolution_test.cc +++ b/xla/tests/onednn_convolution_test.cc @@ -123,6 +123,144 @@ TEST_F(ConvolutionTest, TestFusedConv3D) { MatchOptimizedHlo(convolution_module_str, conv_rewrite_bias_relu_str_); } +TEST_F(ConvolutionTest, DequantizeConv2D) { + const char* convolution_module_str = R"( + HloModule convolution.test.f32, entry_computation_layout={(s8[1,3,224,224]{3,2,1,0}, s8[64,3,7,7]{3,2,1,0})->f32[1,112,112,64]{3,2,1,0}} + + ENTRY convolution.test.f32 { + Arg_inp = s8[1,3,224,224]{3,2,1,0} parameter(0) + convert.194 = s32[1,3,224,224]{3,2,1,0} convert(Arg_inp) + constant.65 = s32[] constant(-4) + broadcast.1 = s32[1,3,224,224]{3,2,1,0} broadcast(constant.65), dimensions={} + add = s32[1,3,224,224]{3,2,1,0} add(convert.194, broadcast.1) + convert.196 = f32[1,3,224,224]{3,2,1,0} convert(add) + constant.48 = f32[] constant(0.5) + broadcast.186 = f32[1,3,224,224]{3,2,1,0} broadcast(constant.48), dimensions={} + multiply.197 = f32[1,3,224,224]{3,2,1,0} multiply(convert.196, broadcast.186) + transpose = f32[1,224,224,3]{3,2,1,0} transpose(multiply.197), dimensions={0,2,3,1} + Arg_9.10 = s8[64,3,7,7]{3,2,1,0} parameter(1) + convert.205 = s32[64,3,7,7]{3,2,1,0} convert(Arg_9.10) + constant.66 = s32[] constant(0) + broadcast.3 = s32[64,3,7,7]{3,2,1,0} broadcast(constant.66), dimensions={} + add.1 = s32[64,3,7,7]{3,2,1,0} add(convert.205, broadcast.3) + convert.207 = f32[64,3,7,7]{3,2,1,0} convert(add.1) + broadcast.163 = f32[64,3,7,7]{3,2,1,0} broadcast(constant.48), dimensions={} + multiply.208 = f32[64,3,7,7]{3,2,1,0} multiply(convert.207, broadcast.163) + transpose.1 = f32[7,7,3,64]{3,2,1,0} transpose(multiply.208), dimensions={2,3,1,0} + ROOT 
convolution = f32[1,112,112,64]{3,2,1,0} convolution(transpose, transpose.1), window={size=7x7 stride=2x2 pad=3_3x3_3}, dim_labels=b01f_01io->b01f + })"; + + EXPECT_TRUE(RunAndCompare(convolution_module_str, ErrorSpec{1e-4, 1e-4})); + // TODO(intel-tf): Check that the fusion has the expected quantized type. + MatchOptimizedHlo(convolution_module_str, conv_rewrite_str_); +} + +TEST_F(ConvolutionTest, DequantizeConv2DBiasRelu) { + const char* convolution_module_str = R"( + HloModule convolution.test.f32, entry_computation_layout={(s8[1,3,224,224]{3,2,1,0}, s8[64,3,7,7]{3,2,1,0}, f32[64]{0})->f32[1,112,112,64]{3,2,1,0}} + + ENTRY convolution.test.f32 { + Arg_inp = s8[1,3,224,224]{3,2,1,0} parameter(0) + convert.194 = s32[1,3,224,224]{3,2,1,0} convert(Arg_inp) + constant.65 = s32[] constant(-4) + broadcast.1 = s32[1,3,224,224]{3,2,1,0} broadcast(constant.65), dimensions={} + add = s32[1,3,224,224]{3,2,1,0} add(convert.194, broadcast.1) + convert.196 = f32[1,3,224,224]{3,2,1,0} convert(add) + constant.48 = f32[] constant(0.5) + broadcast.186 = f32[1,3,224,224]{3,2,1,0} broadcast(constant.48), dimensions={} + multiply.197 = f32[1,3,224,224]{3,2,1,0} multiply(convert.196, broadcast.186) + transpose = f32[1,224,224,3]{3,2,1,0} transpose(multiply.197), dimensions={0,2,3,1} + Arg_9.10 = s8[64,3,7,7]{3,2,1,0} parameter(1) + convert.205 = s32[64,3,7,7]{3,2,1,0} convert(Arg_9.10) + constant.66 = s32[] constant(0) + broadcast.3 = s32[64,3,7,7]{3,2,1,0} broadcast(constant.66), dimensions={} + add.1 = s32[64,3,7,7]{3,2,1,0} add(convert.205, broadcast.3) + convert.207 = f32[64,3,7,7]{3,2,1,0} convert(add.1) + broadcast.163 = f32[64,3,7,7]{3,2,1,0} broadcast(constant.48), dimensions={} + multiply.208 = f32[64,3,7,7]{3,2,1,0} multiply(convert.207, broadcast.163) + transpose.1 = f32[7,7,3,64]{3,2,1,0} transpose(multiply.208), dimensions={2,3,1,0} + convolution = f32[1,112,112,64]{3,2,1,0} convolution(transpose, transpose.1), window={size=7x7 stride=2x2 pad=3_3x3_3}, 
dim_labels=b01f_01io->b01f + Arg_8.9 = f32[64]{0} parameter(2) + broadcast.171 = f32[1,112,112,64]{2,1,3,0} broadcast(Arg_8.9), dimensions={3} + add.57 = f32[1,112,112,64]{3,2,1,0} add(convolution, broadcast.171) + constant.171 = f32[] constant(0) + broadcast.205 = f32[1,112,112,64]{2,1,3,0} broadcast(constant.171), dimensions={} + ROOT maximum.0 = f32[1,112,112,64]{3,2,1,0} maximum(add.57, broadcast.205) + })"; + + EXPECT_TRUE(RunAndCompare(convolution_module_str, ErrorSpec{1e-4, 1e-4})); + // TODO(intel-tf): Check that the fusion has the expected quantized type. + MatchOptimizedHlo(convolution_module_str, conv_rewrite_bias_relu_str_); +} + +TEST_F(ConvolutionTest, DequantizeConv2DBiasReluRequantize) { + const char* convolution_module_str = R"( + HloModule convolution.test.f32, entry_computation_layout={(s8[1,64,56,56]{3,2,1,0}, s8[64,64,3,3]{3,2,1,0}, f32[64]{0}, s8[64,64,3,3]{3,2,1,0})->f32[1,56,56,64]{3,2,1,0}} + + ENTRY convolution.test.f32 { + Arg_inp = s8[1,64,56,56]{3,2,1,0} parameter(0) + convert.194 = s32[1,64,56,56]{3,2,1,0} convert(Arg_inp) + constant.65 = s32[] constant(-4) + broadcast.1 = s32[1,64,56,56]{3,2,1,0} broadcast(constant.65), dimensions={} + add = s32[1,64,56,56]{3,2,1,0} add(convert.194, broadcast.1) + convert.196 = f32[1,64,56,56]{3,2,1,0} convert(add) + constant.48 = f32[] constant(0.5) + broadcast.186 = f32[1,64,56,56]{3,2,1,0} broadcast(constant.48), dimensions={} + multiply.197 = f32[1,64,56,56]{3,2,1,0} multiply(convert.196, broadcast.186) + transpose = f32[1,56,56,64]{3,2,1,0} transpose(multiply.197), dimensions={0,2,3,1} + Arg_9.10 = s8[64,64,3,3]{3,2,1,0} parameter(1) + convert.205 = s32[64,64,3,3]{3,2,1,0} convert(Arg_9.10) + constant.66 = s32[] constant(0) + broadcast.3 = s32[64,64,3,3]{3,2,1,0} broadcast(constant.66), dimensions={} + add.1 = s32[64,64,3,3]{3,2,1,0} add(convert.205, broadcast.3) + convert.207 = f32[64,64,3,3]{3,2,1,0} convert(add.1) + broadcast.163 = f32[64,64,3,3]{3,2,1,0} broadcast(constant.48), dimensions={} 
+ multiply.208 = f32[64,64,3,3]{3,2,1,0} multiply(convert.207, broadcast.163) + transpose.1 = f32[3,3,64,64]{3,2,1,0} transpose(multiply.208), dimensions={2,3,1,0} + convolution = f32[1,56,56,64]{3,2,1,0} convolution(transpose, transpose.1), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f + Arg_8.9 = f32[64]{0} parameter(2) + broadcast.171 = f32[1,56,56,64]{2,1,3,0} broadcast(Arg_8.9), dimensions={3} + add.57 = f32[1,56,56,64]{3,2,1,0} add(convolution, broadcast.171) + constant.171 = f32[] constant(0) + broadcast.205 = f32[1,56,56,64]{2,1,3,0} broadcast(constant.171), dimensions={} + maximum.0 = f32[1,56,56,64]{3,2,1,0} maximum(add.57, broadcast.205) + constant = f32[] constant(2) + broadcast.266 = f32[1,56,56,64]{3,2,1,0} broadcast(constant), dimensions={} + multiply.76 = f32[1,56,56,64]{3,2,1,0} multiply(maximum.0, broadcast.266) + constant.46 = f32[] constant(4) + broadcast.336 = f32[1,56,56,64]{3,2,1,0} broadcast(constant.46), dimensions={} + add.93 = f32[1,56,56,64]{3,2,1,0} add(multiply.76, broadcast.336) + constant.184 = f32[] constant(127) + broadcast.413 = f32[1,56,56,64]{3,2,1,0} broadcast(constant.184), dimensions={} + constant.183 = f32[] constant(-128) + broadcast.328 = f32[1,56,56,64]{3,2,1,0} broadcast(constant.183), dimensions={} + clamp.15 = f32[1,56,56,64]{3,2,1,0} clamp(broadcast.328, add.93, broadcast.413) + round-nearest-even.15 = f32[1,56,56,64]{3,2,1,0} round-nearest-even(clamp.15) + convert.17 = s8[1,56,56,64]{3,2,1,0} convert(round-nearest-even.15) + Arg_12.13 = s8[64,64,3,3]{3,2,1,0} parameter(3) + convert.38 = s32[1,56,56,64]{3,2,1,0} convert(convert.17) + broadcast.468 = s32[1,56,56,64]{2,1,3,0} broadcast(constant.65), dimensions={} + add.116 = s32[1,56,56,64]{3,2,1,0} add(convert.38, broadcast.468) + convert.61 = f32[1,56,56,64]{3,2,1,0} convert(add.116) + broadcast.515 = f32[1,56,56,64]{2,1,3,0} broadcast(constant.48), dimensions={} + multiply.92 = f32[1,56,56,64]{3,2,1,0} multiply(convert.61, broadcast.515) + convert.274 = 
s32[64,64,3,3]{3,2,1,0} convert(Arg_12.13) + constant.67 = s32[] constant(0) + broadcast.9 = s32[64,64,3,3]{3,2,1,0} broadcast(constant.67), dimensions={} + add.6 = s32[64,64,3,3]{3,2,1,0} add(convert.274, broadcast.9) + convert.276 = f32[64,64,3,3]{3,2,1,0} convert(add.6) + broadcast.145 = f32[64,64,3,3]{3,2,1,0} broadcast(constant.48), dimensions={} + multiply.277 = f32[64,64,3,3]{3,2,1,0} multiply(convert.276, broadcast.145) + transpose.7 = f32[3,3,64,64]{3,2,1,0} transpose(multiply.277), dimensions={2,3,1,0} + ROOT convolution.2 = f32[1,56,56,64]{3,2,1,0} convolution(multiply.92, transpose.7), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f + })"; + + EXPECT_TRUE(RunAndCompare(convolution_module_str, ErrorSpec{1e-4, 1e-4})); + // TODO(intel-tf): Check that the fusion has the expected quantized type. + MatchOptimizedHlo(convolution_module_str, conv_rewrite_bias_relu_str_); +} + + } // namespace cpu } // namespace xla