From 5ce8002d24d357a259598937d13e1b8ab3bcdcab Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Wed, 25 Oct 2023 16:34:16 +0000 Subject: [PATCH] Revert "Remove deprecated fbgemm operators (#104535)" This reverts commit 57c7aa12dbf71617bd21fe7e076df8e823b5b7bb. Reverted https://github.com/pytorch/pytorch/pull/104535 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/104535#issuecomment-1779650412)) --- aten/src/ATen/native/QuantizedLinear.cpp | 585 ++++++++++++++ aten/src/ATen/native/RNN.cpp | 203 +++++ aten/src/ATen/native/native_functions.yaml | 25 + build_variables.bzl | 1 + caffe2/CMakeLists.txt | 1 + .../check_forward_backward_compatibility.py | 12 - test/jit/test_models.py | 13 + test/quantization/core/test_quantized_op.py | 75 ++ .../jit/test_deprecated_jit_quant.py | 147 +++- test/test_nn.py | 19 + torch/jit/quantized.py | 765 +++++++++++++++++- torch/overrides.py | 17 + 12 files changed, 1808 insertions(+), 55 deletions(-) create mode 100644 aten/src/ATen/native/QuantizedLinear.cpp diff --git a/aten/src/ATen/native/QuantizedLinear.cpp b/aten/src/ATen/native/QuantizedLinear.cpp new file mode 100644 index 00000000000000..002bb1adc43861 --- /dev/null +++ b/aten/src/ATen/native/QuantizedLinear.cpp @@ -0,0 +1,585 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include + +#include +#include +#include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#include +#include +#include +#include +#include +#include +#endif + +#include + +#ifdef USE_FBGEMM +#include +#include +#include +#endif // USE_FBGEMM + +namespace caffe2 { +CAFFE_KNOWN_TYPE(c10::intrusive_ptr); +} // namespace caffe2 + +#ifdef USE_FBGEMM +namespace caffe2 { +// Required for cpp_custom_type_hack to work +CAFFE_KNOWN_TYPE(fbgemm::PackBMatrix); +CAFFE_KNOWN_TYPE(c10::intrusive_ptr); +} // namespace caffe2 +#endif // USE_FBGEMM + +namespace at { +namespace native { + +#ifdef USE_FBGEMM + +Tensor fbgemm_linear_int8_weight_fp32_activation( + const Tensor& input, + const Tensor& weight, + const Tensor& packed, + const Tensor& col_offsets, + const Scalar& weight_scale, + const Scalar& weight_zero_point, + const Tensor& bias) { + // We make a strong guarantee that models using these operators will have the + // same numerics across different machines. Therefore, we do not provide a + // fallback path and rather fail loudly if we cannot run FBGEMM. 
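[Editor's note, not part of the diff] For orientation before the kernel details below, this is a hedged sketch of the end-to-end flow that QuantizedLinear.cpp restores, written against the ATen ops re-added by this revert (fbgemm_linear_quantize_weight, fbgemm_pack_quantized_matrix, fbgemm_linear_int8_weight_fp32_activation). It assumes an FBGEMM-capable CPU and is illustration only:

import torch

X = torch.randn(4, 16)   # fp32 activations
W = torch.randn(8, 16)   # fp32 weights
b = torch.randn(8)

# Quantize and pack the weight once, ahead of time.
W_q, col_offsets, w_scale, w_zp = torch.fbgemm_linear_quantize_weight(W)
packed = torch.fbgemm_pack_quantized_matrix(W_q)

# Dynamic int8 linear: activations are quantized on the fly inside the op.
Y = torch.fbgemm_linear_int8_weight_fp32_activation(
    X, W_q, packed, col_offsets, w_scale, w_zp, b)
# Y roughly matches torch.nn.functional.linear(X, W, b), up to quantization error.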
+ TORCH_CHECK(fbgemm::fbgemmSupportedCPU(), "Your CPU doesn't support FBGEMM."); + + TORCH_WARN_ONCE("fbgemm_linear_int8_weight_fp32_activation is deprecated " + "and will be removed in a future PyTorch release.") + + const Tensor input_contig = input.contiguous(); + const float* input_ptr = input_contig.data_ptr(); + + TORCH_CHECK(input.dim() >= 2); + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + const int64_t M = size_to_dim_(input.dim() - 1, input.sizes()); + const int64_t K = input.size(input.dim() - 1); + TORCH_CHECK(weight.dim() == 2); + TORCH_CHECK(K == weight.size(1)); + const int64_t N = weight.size(0); + TORCH_CHECK(bias.dim() == 1); + TORCH_CHECK(bias.size(0) == N); + TORCH_CHECK(weight_scale.isFloatingPoint()); + TORCH_CHECK(weight_zero_point.isIntegral(false)); + + // Calculate statistics for quantization of the input Tensor + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + float x_min; + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + float x_max; + fbgemm::FindMinMax( + /*m=*/input_ptr, + /*min=*/&x_min, + /*max=*/&x_max, + /*len=*/input.numel()); + + // Input tensor is quantized as 8-bit unsigned values + constexpr int kPrecision = 8; + constexpr bool kIsSigned = false; + constexpr int kBound = (1 << (kPrecision - 1)); + + // Calculate scale and zero point for quantization of input tensor + auto q_params = fbgemm::ChooseQuantizationParams( + /*min=*/x_min, + /*max=*/x_max, + /*qmin=*/kIsSigned ? -kBound : 0, + /*qmax=*/kIsSigned ? (kBound - 1) : (1 << kPrecision) - 1, + /*preserve_sparsity=*/false); + q_params.precision = kPrecision; + + // ReQuantizeForFloat requires pointers to the scale and zero point values, + // since in the case of rowwise quantization these will be arrays rather than + // scalars. But in this case, we're doing whole-tensor quantization so we just + // pass a pointer to the scale values (and internally ReQuantizeFor Float + // won't index past 0 + const float weight_scale_float = + static_cast(weight_scale.to()); + const int32_t weight_zero_point_int32 = + static_cast(weight_zero_point.to()); + + const Tensor bias_contig = bias.contiguous(); + + // Allocate output Tensor and a buffer for fbgemmPacked to use + std::vector output_size = input.sizes().vec(); + output_size.back() = N; + Tensor output = at::empty(output_size, input.options().dtype(at::kFloat), LEGACY_CONTIGUOUS_MEMORY_FORMAT); + Tensor buffer = at::empty(output_size, input.options().dtype(at::kInt), LEGACY_CONTIGUOUS_MEMORY_FORMAT); + + // Pull out the PackBMatrix instance from the owning tensor + auto& pack_b = + cpp_custom_type_hack::cast>(packed); + + const int num_tasks = at::get_num_threads(); + at::parallel_for(0, num_tasks, 1, [&](int64_t begin, int64_t end) { + // This operation does the following: + // 1) Quantizes the input matrix given the statistics we've calculated + // above. + // 2) Creates a "row buffer" vector with offset values that must be added + // to the integer matrix multiplication operation to ensure correctness. + // 3) Packs the resulting quantized matrix into vector-register and cache + // friendly tiles. + // + // Note this is not executed eagerly, but rather within the fbgemmPacked + // call below. 
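[Editor's note, not part of the diff] The activation-quantization step above maps the observed [x_min, x_max] onto the unsigned 8-bit range before the packed GEMM runs. A minimal Python sketch of that parameter choice, assuming a plain affine mapping (fbgemm::ChooseQuantizationParams additionally nudges the zero point onto an exactly representable integer and handles the preserve_sparsity case), is:

def choose_uint8_qparams(x_min: float, x_max: float):
    # Keep 0 inside the representable range so it is exactly encodable.
    lo, hi = min(x_min, 0.0), max(x_max, 0.0)
    if hi == lo:                        # degenerate range (all zeros)
        return 1.0, 0
    scale = (hi - lo) / 255.0           # qmax - qmin = 255 for uint8
    zero_point = round(-lo / scale)     # integer code assigned to real 0.0
    return scale, zero_point

For example, x_min = -1.0 and x_max = 3.0 give scale of roughly 0.0157 and zero_point 64.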
+ fbgemm::PackAWithQuantRowOffset pack_a( + /*trans=*/fbgemm::matrix_op_t::NoTranspose, + /*nRow=*/M, + /*nCol=*/K, + /*smat=*/input_ptr, + /*ld=*/K, + /*pmat=*/nullptr, // pack_a manages ownership of `pmat` + /*scale=*/q_params.scale, + /*zero_pt=*/q_params.zero_point); + + // This is the end of the pipeline, pass the resulting matrix through + fbgemm::DoNothing kDoNothingObj{}; + for (const auto task_id : c10::irange(begin, end)) { + // After the uint8 * int8 matrix multiplication is performed, this + // operation does: + // 1) Add in row and column offsets to the rows and columns, respectively + // 2) Dequantize the results into floating point + // 3) Add in the bias term + fbgemm::ReQuantizeForFloat output_proc_obj( + /*nextop=*/kDoNothingObj, + /*Aq_scale=*/q_params.scale, + /*Bq_scale=*/&weight_scale_float, + /*Aq_zero_point=*/q_params.zero_point, + /*Bq_zero_point=*/&weight_zero_point_int32, + /*row_offsets=*/pack_a.getRowOffsetBuffer(), + /*col_offsets=*/col_offsets.data_ptr(), + /*bias=*/bias_contig.data_ptr(), + /*nCol=*/N); + // Do the GEMM + fbgemm::fbgemmPacked( + /*packA=*/pack_a, + /*packB=*/pack_b, + /*C=*/output.data_ptr(), + /*C_buffer=*/buffer.data_ptr(), + /*ldc=*/N, + /*outProcess=*/output_proc_obj, + /*thread_id=*/task_id, + /*num_threads=*/num_tasks); + } + }); + + return output; +} + +Tensor fbgemm_linear_int8_weight( + const Tensor& input, + const Tensor& weight, + const Tensor& packed, + const Tensor& col_offsets, + const Scalar& weight_scale, + const Scalar& weight_zero_point, + const Tensor& bias) { + return at::native::fbgemm_linear_int8_weight_fp32_activation( + input, + weight, + packed, + col_offsets, + weight_scale, + weight_zero_point, + bias); +} + +namespace { + +// Calculate the column offsets +// Note this includes the sum of the columns as well as the scalar term +// B_zero_point * K, whereas the row_offsets created by +// PackAWithQuantRowOffset is only the sum of the A rows. +void CalcColOffsetsTranspose( + int K, + int N, + const int8_t* Bint8, + int32_t B_zero_point, + int32_t* col_offsets) { + for (const auto i : c10::irange(N)) { + int32_t sum = 0; + for (const auto j : c10::irange(K)) { + sum += Bint8[i * K + j]; + } + col_offsets[i] = sum - B_zero_point * K; + } +} + +} // namespace + +std::tuple fbgemm_linear_quantize_weight( + const Tensor& weight) { + TORCH_WARN_ONCE("fbgemm_linear_quantize_weight is deprecated " + "and will be removed in a future PyTorch release.") + + // We make a strong guarantee that models using these operators will have the + // same numerics across different machines. Therefore, we do not provide a + // fallback path and rather fail loudly if we cannot run FBGEMM. + TORCH_CHECK(fbgemm::fbgemmSupportedCPU(), "Your CPU doesn't support FBGEMM."); + const Tensor weight_contig = weight.contiguous(); + + // Calculate weight statistics + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + float w_min; + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + float w_max; + fbgemm::FindMinMax( + /*m=*/weight_contig.data_ptr(), + /*min=*/&w_min, + /*max=*/&w_max, + /*len=*/weight_contig.numel()); + + // Choose parameters for quantizing the weight as 8-bit signed integer + constexpr bool kIsSigned = true; + constexpr int kPrecision = 8; + constexpr int kBound = (1 << (kPrecision - 1)); + auto q_params = fbgemm::ChooseQuantizationParams( + /*min=*/w_min, + /*max=*/w_max, + /*qmin=*/kIsSigned ? -kBound : 0, + /*qmax=*/kIsSigned ? 
(kBound - 1) : (1 << kPrecision) - 1, + /*preserve_sparsity=*/false); + q_params.precision = kPrecision; + + Tensor quantized = at::native::empty_like( + weight_contig, + at::kChar, + weight_contig.options().layout_opt(), + weight_contig.options().device_opt(), + weight_contig.options().pinned_memory_opt(), + LEGACY_CONTIGUOUS_MEMORY_FORMAT); + // Tensor quantized = at::native::empty_cpu( + // weight_contig.sizes(), weight_contig.options().dtype(at::kChar)); + fbgemm::Quantize( + /*src=*/weight_contig.data_ptr(), + /*dst=*/quantized.data_ptr(), + /*len=*/weight_contig.numel(), + /*qparams=*/q_params); + + // Calculate column offsets of the weight and store them away in a tensor. + // Similarly to quantization, this can be done once and cached. + Tensor col_offsets = at::empty( + {weight_contig.size(0)}, + at::kInt, + weight_contig.options().layout_opt(), + weight_contig.options().device_opt(), + weight_contig.options().pinned_memory_opt(), + LEGACY_CONTIGUOUS_MEMORY_FORMAT); + CalcColOffsetsTranspose( + /*K=*/quantized.size(1), + /*N=*/quantized.size(0), + /*Bint8=*/quantized.data_ptr(), + /*B_zero_point=*/q_params.zero_point, + /*col_offsets=*/col_offsets.data_ptr()); + + return std::make_tuple( + quantized, col_offsets, q_params.scale, q_params.zero_point); +} + +Tensor fbgemm_pack_quantized_matrix(const Tensor& weight) { + TORCH_WARN_ONCE("fbgemm_pack_quantized_matrix is deprecated " + "and will be removed in a future PyTorch release.") + + // We make a strong guarantee that models using these operators will have the + // same numerics across different machines. Therefore, we do not provide a + // fallback path and rather fail loudly if we cannot run FBGEMM. + TORCH_CHECK(fbgemm::fbgemmSupportedCPU(), "Your CPU doesn't support FBGEMM."); + const int64_t K = weight.size(1); + const int64_t N = weight.size(0); + const Tensor weight_contig = weight.contiguous(); + const int8_t* weight_ptr = weight_contig.data_ptr(); + auto ptr = std::make_unique>( + /*trans=*/fbgemm::matrix_op_t::Transpose, + /*nRow=*/K, + /*nCol=*/N, + /*smat=*/weight_ptr, + /*ld=*/K, + /*pmat=*/nullptr, // PackBMatrix manages ownership of pmat + /*groups=*/1); + return cpp_custom_type_hack::create(std::move(ptr), weight.options()); +} + +Tensor fbgemm_pack_quantized_matrix( + const Tensor& weight, + int64_t K, + int64_t N) { + // Replace after https://github.com/pytorch/pytorch/issues/24354 is fixed + // TORCH_WARN( + // "fbgemm_pack_quantized_matrix(weight, K, N) will be deprecated soon." + // "Please use fbgemm_pack_quantized_matrix(weight) instead."); + return at::native::fbgemm_pack_quantized_matrix(weight); +} + +namespace { + +float RawUint16ToFp16(unsigned short value) { + // Convert raw 16 bits half precision floating point number + // to single precision floating point number. + const unsigned short sign_bits = value >> 15; + const unsigned short exponent_bits = value >> 10 & 0x1f; + const unsigned short significand_bits = value & 0x3ff; + + const float sign = sign_bits ? 
-1 : 1; + const float significand = + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + 1 + significand_bits * 0.0009765625f; // 0.0009765625f = 0x1p-10 = 2^-10 + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + const float exponent = exponent_bits - 0xf; + + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + return sign * std::ldexp(significand, exponent); +} + +template +bool CheckAndSaturate(T max_val, T* element) { + if (*element > max_val) { + *element = max_val; + return true; + } + if (*element < -max_val) { + *element = -max_val; + return true; + } + return false; +} + +// The range for using FP16 quantization of weights requires that the elements +// should be in the range of [5.96e-8, 65504]. If it is out of range, then the +// number will be saturated to max or min representable values by FP16. +void HandleWeightsSaturation(int64_t N, float* weight) { + const float kFp16Max = RawUint16ToFp16(0x7BFF); + bool found_out_of_range = false; + for (const auto i : c10::irange(N)) { + if (CheckAndSaturate(kFp16Max, weight + i)) { + found_out_of_range = true; + } + } + if (found_out_of_range) { + TORCH_WARN("FOUND weight out of range "); + } +} + +} // namespace + +Tensor fbgemm_pack_gemm_matrix_fp16(const Tensor& weight) { + TORCH_WARN_ONCE("fbgemm_pack_gemm_matrix_fp16 is deprecated " + "and will be removed in a future PyTorch release.") + + // We make a strong guarantee that models using these operators will have the + // same numerics across different machines. Therefore, we do not provide a + // fallback path and rather fail loudly if we cannot run FBGEMM. + TORCH_CHECK(fbgemm::fbgemmSupportedCPU(), "Your CPU doesn't support FBGEMM."); + + const int64_t K = weight.size(1); + const int64_t N = weight.size(0); + Tensor weight_contig = weight.contiguous(); + float* weight_contig_ptr = weight_contig.data_ptr(); + HandleWeightsSaturation(K * N, weight_contig_ptr); + + // TODO(mingzhe09088): + // Consider using a functor here in PackedGemmMatrixFP16 + // Comments from (XQ): Not entirely sure this make_unique is safe. make_unique + // is created with regular "new", and freed through TypeMetaData::deleteFn in + // this function. This is perfectly fine if the tensors are created and freed + // within this translation unit. It might be very problematic if that tensor + // flows across dll boundaries. + auto ptr = std::make_unique( + fbgemm::matrix_op_t::Transpose, K, N, 1, weight_contig_ptr); + c10::intrusive_ptr packed_weight = + c10::make_intrusive(std::move(ptr), c10::nullopt); + auto unique_ptr_wrapper = + std::make_unique(std::move(packed_weight)); + return cpp_custom_type_hack::create( + std::move(unique_ptr_wrapper), weight.options()); +} + +Tensor fbgemm_linear_fp16_weight_fp32_activation( + const Tensor& input, + const Tensor& packed_weight, + const Tensor& bias) { + TORCH_WARN_ONCE("fbgemm_linear_fp16_weight_fp32_activation is deprecated " + "and will be removed in a future PyTorch release.") + + // We make a strong guarantee that models using these operators will have the + // same numerics across different machines. Therefore, we do not provide a + // fallback path and rather fail loudly if we cannot run FBGEMM. 
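[Editor's note, not part of the diff] As a worked check on the saturation bound used by HandleWeightsSaturation above: 0x7BFF is the largest finite half-precision encoding (sign 0, exponent bits 0x1e, significand bits 0x3ff), and decoding it with the formula in RawUint16ToFp16 gives 65504.

import numpy as np

sign_bits, exponent_bits, significand_bits = 0, 0x1E, 0x3FF
value = (1 + significand_bits / 1024) * 2.0 ** (exponent_bits - 15)
assert value == 65504.0 == float(np.finfo(np.float16).max)

Weights outside [-65504, 65504] are clamped before packing, and a warning is emitted whenever any element was clamped.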
+ TORCH_CHECK(fbgemm::fbgemmSupportedCPU(), "Your CPU doesn't support FBGEMM."); + + const Tensor input_contig = input.contiguous(); + const float* input_ptr = input_contig.data_ptr(); + + // Pull out the PackedGemmMatrixFP16 instance from the owning tensor + const fbgemm::PackedGemmMatrixFP16& packed_weight_fp16 = + *c10::dynamic_intrusive_pointer_cast( + cpp_custom_type_hack::cast< + c10::intrusive_ptr>(packed_weight)) + ->w; + + TORCH_CHECK(input.size(input.dim() - 1) == packed_weight_fp16.numRows()) + TORCH_CHECK(input.dim() >= 2); + TORCH_CHECK(bias.dim() == 1); + + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + const int64_t M = size_to_dim_(input.dim() - 1, input.sizes()); + const int64_t N = packed_weight_fp16.numCols(); + std::vector output_size = input.sizes().vec(); + output_size.back() = N; + Tensor output = at::empty(output_size, input.options().dtype(at::kFloat)); + + // Call the fp16 gemm interface + fbgemm::cblas_gemm_compute( + fbgemm::matrix_op_t::NoTranspose, + M, + input_ptr, + packed_weight_fp16, + 0.0f, + output.data_ptr()); + + // Add bias term + output.add_(bias); + + return output; +} + +Tensor fbgemm_linear_fp16_weight( + const Tensor& input, + const Tensor& packed_weight, + const Tensor& bias) { + return at::native::fbgemm_linear_fp16_weight_fp32_activation( + input, packed_weight, bias); +} + +#else // USE_FBGEMM + +Tensor fbgemm_linear_int8_weight_fp32_activation( + const Tensor& /*input*/, + const Tensor& /*weight*/, + const Tensor& /*packed*/, + const Tensor& /*col_offsets*/, + const Scalar& /*weight_scale*/, + const Scalar& /*weight_zero_point*/, + const Tensor& /*bias*/) { + TORCH_WARN_ONCE("fbgemm_linear_int8_weight_fp32_activation is deprecated " + "and will be removed in a future PyTorch release.") + + // We make a strong guarantee that models using these operators will have the + // same numerics across different machines. Therefore, we do not provide a + // fallback path and rather fail loudly if we cannot run FBGEMM. + TORCH_CHECK( + false, "This PyTorch installation was not built with FBGEMM operators"); +} + +Tensor fbgemm_linear_int8_weight( + const Tensor& /*input*/, + const Tensor& /*weight*/, + const Tensor& /*packed*/, + const Tensor& /*col_offsets*/, + const Scalar& /*weight_scale*/, + const Scalar& /*weight_zero_point*/, + const Tensor& /*bias*/) { + TORCH_WARN_ONCE("fbgemm_linear_int8_weight is deprecated " + "and will be removed in a future PyTorch release.") + + // We make a strong guarantee that models using these operators will have the + // same numerics across different machines. Therefore, we do not provide a + // fallback path and rather fail loudly if we cannot run FBGEMM. + TORCH_CHECK( + false, "This PyTorch installation was not built with FBGEMM operators"); +} + +std::tuple fbgemm_linear_quantize_weight( + const Tensor& /*weight*/) { + TORCH_WARN_ONCE("fbgemm_linear_quantize_weight is deprecated " + "and will be removed in a future PyTorch release.") + + // We make a strong guarantee that models using these operators will have the + // same numerics across different machines. Therefore, we do not provide a + // fallback path and rather fail loudly if we cannot run FBGEMM. 
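[Editor's note, not part of the diff] Pulling the pieces of fbgemm_linear_fp16_weight_fp32_activation above together: the packed matrix stores the weight rounded to half precision, the GEMM accumulates in fp32, and the fp32 bias is added afterwards. A hedged numpy reference (illustration only; the real kernel's rounding and blocking may differ slightly, which is why the test added in test_nn.py below compares with a 1e-3 tolerance):

import numpy as np

X = np.random.rand(4, 16).astype(np.float32)   # fp32 activations
W = np.random.rand(8, 16).astype(np.float32)   # fp32 weights, packed as fp16
b = np.random.rand(8).astype(np.float32)

ref = X @ W.astype(np.float16).astype(np.float32).T + b   # approximates the op output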
+ TORCH_CHECK( + false, "This PyTorch installation was not built with FBGEMM operators"); +} + +Tensor fbgemm_pack_quantized_matrix(const Tensor& /*input*/) { + TORCH_WARN_ONCE("fbgemm_pack_quantized_matrix is deprecated " + "and will be removed in a future PyTorch release.") + + // We make a strong guarantee that models using these operators will have the + // same numerics across different machines. Therefore, we do not provide a + // fallback path and rather fail loudly if we cannot run FBGEMM. + TORCH_CHECK( + false, "This PyTorch installation was not built with FBGEMM operators"); +} + +Tensor fbgemm_pack_quantized_matrix( + const Tensor& /*input*/, + int64_t /*K*/, + int64_t /*N*/) { + TORCH_WARN_ONCE("fbgemm_pack_quantized_matrix is deprecated " + "and will be removed in a future PyTorch release.") + + // We make a strong guarantee that models using these operators will have the + // same numerics across different machines. Therefore, we do not provide a + // fallback path and rather fail loudly if we cannot run FBGEMM. + TORCH_CHECK( + false, "This PyTorch installation was not built with FBGEMM operators"); +} + +Tensor fbgemm_pack_gemm_matrix_fp16(const Tensor& weight) { + TORCH_WARN_ONCE("fbgemm_pack_gemm_matrix_fp16 is deprecated " + "and will be removed in a future PyTorch release.") + + // We make a strong guarantee that models using these operators will have the + // same numerics across different machines. Therefore, we do not provide a + // fallback path and rather fail loudly if we cannot run FBGEMM. + TORCH_CHECK( + false, "This PyTorch installation was not built with FBGEMM operators"); +} + +Tensor fbgemm_linear_fp16_weight_fp32_activation( + const Tensor& input, + const Tensor& packed_weight, + const Tensor& bias) { + TORCH_WARN_ONCE("fbgemm_linear_fp16_weight_fp32_activation is deprecated " + "and will be removed in a future PyTorch release.") + + // We make a strong guarantee that models using these operators will have the + // same numerics across different machines. Therefore, we do not provide a + // fallback path and rather fail loudly if we cannot run FBGEMM. + TORCH_CHECK( + false, "This PyTorch installation was not built with FBGEMM operators"); +} + +Tensor fbgemm_linear_fp16_weight( + const Tensor& input, + const Tensor& packed_weight, + const Tensor& bias) { + TORCH_WARN_ONCE("fbgemm_linear_fp16_weight is deprecated " + "and will be removed in a future PyTorch release.") + + // We make a strong guarantee that models using these operators will have the + // same numerics across different machines. Therefore, we do not provide a + // fallback path and rather fail loudly if we cannot run FBGEMM. 
+ TORCH_CHECK( + false, "This PyTorch installation was not built with FBGEMM operators"); +} + +#endif // USE_FBGEMM + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/RNN.cpp b/aten/src/ATen/native/RNN.cpp index ad8eb018f6787e..9b9b2c5edbf59f 100644 --- a/aten/src/ATen/native/RNN.cpp +++ b/aten/src/ATen/native/RNN.cpp @@ -32,12 +32,19 @@ #include #include #include +#include +#include +#include #include #include #include #include #include #include +#include +#include +#include +#include #include #include #include @@ -201,6 +208,158 @@ struct CellParams : public CellParamsBase { } }; +c10::intrusive_ptr make_quantized_cell_params( + const at::Tensor& w_ih, + const at::Tensor& w_hh, + at::Tensor bias_ih, + at::Tensor bias_hh); + +struct QuantizedCellParams : public CellParamsBase { + QuantizedCellParams( + Tensor _w_ih, + Tensor _w_hh, + Tensor _b_ih, + Tensor _b_hh, + Tensor _packed_ih, + Tensor _packed_hh, + Tensor _col_offsets_ih, + Tensor _col_offsets_hh, + Scalar _scale_ih, + Scalar _scale_hh, + Scalar _zero_point_ih, + Scalar _zero_point_hh) + : w_ih(std::move(_w_ih)), + w_hh(std::move(_w_hh)), + b_ih_(std::move(_b_ih)), + b_hh_(std::move(_b_hh)), + packed_ih(std::move(_packed_ih)), + packed_hh(std::move(_packed_hh)), + col_offsets_ih(std::move(_col_offsets_ih)), + col_offsets_hh(std::move(_col_offsets_hh)), + scale_ih(std::move(_scale_ih)), + scale_hh(std::move(_scale_hh)), + zero_point_ih(std::move(_zero_point_ih)), + zero_point_hh(std::move(_zero_point_hh)) {} + + const Tensor w_ih; + const Tensor w_hh; + const Tensor b_ih_; + const Tensor b_hh_; + const Tensor packed_ih; + const Tensor packed_hh; + const Tensor col_offsets_ih; + const Tensor col_offsets_hh; + const Scalar scale_ih; + const Scalar scale_hh; + const Scalar zero_point_ih; + const Scalar zero_point_hh; + + Tensor matmul_ih(const Tensor& input) const override { + TORCH_CHECK(false, "matmul is not supported with quantized cell params"); + } + Tensor matmul_hh(const Tensor& h) const override { + TORCH_CHECK(false, "matmul is not supported with quantized cell params"); + } + Tensor linear_ih(const Tensor& input) const override { + return at::fbgemm_linear_int8_weight_fp32_activation( + input, w_ih, packed_ih, col_offsets_ih, scale_ih, zero_point_ih, b_ih_); + } + Tensor linear_hh(const Tensor& h) const override { + return at::fbgemm_linear_int8_weight_fp32_activation( + h, w_hh, packed_hh, col_offsets_hh, scale_hh, zero_point_hh, b_hh_); + } + const Tensor& b_ih() const override { + return b_ih_; + } + const Tensor& b_hh() const override { + return b_hh_; + } + CellParamsSerializationType __getstate__() const override { + std::vector tensors_to_serialize = { + w_ih, w_hh, b_ih_, b_hh_, col_offsets_ih, col_offsets_hh}; + std::vector doubles_to_serialize = {scale_ih.toDouble(), + scale_hh.toDouble()}; + std::vector longs_to_serialize = {zero_point_ih.toLong(), + zero_point_hh.toLong()}; + return CellParamsSerializationType( + "quantized", + std::move(tensors_to_serialize), + std::move(doubles_to_serialize), + std::move(longs_to_serialize), + {}); + } + static c10::intrusive_ptr __setstate__( + CellParamsSerializationType state) { + std::vector tensors; + std::vector doubles; + std::vector longs; + std::tie(std::ignore, tensors, doubles, longs, std::ignore) = + std::move(state); + TORCH_INTERNAL_ASSERT(tensors.size() == 6); + TORCH_INTERNAL_ASSERT(doubles.size() == 2); + TORCH_INTERNAL_ASSERT(longs.size() == 2); + + at::Tensor qw_ih = std::move(tensors[0]), qw_hh = std::move(tensors[1]), + b_ih = 
std::move(tensors[2]), b_hh = std::move(tensors[3]), + col_offsets_ih = std::move(tensors[4]), + col_offsets_hh = std::move(tensors[5]); + double scale_ih = doubles[0], scale_hh = doubles[1]; + int64_t zero_point_ih = longs[0], zero_point_hh = longs[1]; + + at::Tensor packed_ih = at::native::fbgemm_pack_quantized_matrix(qw_ih); + at::Tensor packed_hh = at::native::fbgemm_pack_quantized_matrix(qw_hh); + + return c10::make_intrusive( + /*w_ih=*/std::move(qw_ih), + /*w_hh=*/std::move(qw_hh), + /*b_ih_=*/std::move(b_ih), + /*b_hh_=*/std::move(b_hh), + /*packed_ih=*/std::move(packed_ih), + /*packed_hh=*/std::move(packed_hh), + /*col_offsets_ih=*/std::move(col_offsets_ih), + /*col_offsets_hh=*/std::move(col_offsets_hh), + /*scale_ih=*/scale_ih, + /*scale_hh=*/scale_hh, + /*zero_point_ih=*/zero_point_ih, + /*zero_point_hh=*/zero_point_hh); + } +}; + +c10::intrusive_ptr make_quantized_cell_params( + const at::Tensor& w_ih, + const at::Tensor& w_hh, + at::Tensor b_ih, + at::Tensor b_hh) { + auto make_vals = [&](const at::Tensor& W) { + auto params = at::native::fbgemm_linear_quantize_weight(W); + at::Tensor packed_weight = + at::native::fbgemm_pack_quantized_matrix(std::get<0>(params)); + return std::tuple_cat( + std::make_tuple(std::move(packed_weight)), std::move(params)); + }; + + at::Tensor qw_ih, qw_hh, packed_ih, packed_hh, col_offsets_ih, col_offsets_hh; + at::Scalar scale_ih, scale_hh, zero_point_ih, zero_point_hh; + + std::tie(packed_ih, qw_ih, col_offsets_ih, scale_ih, zero_point_ih) = + make_vals(w_ih); + std::tie(packed_hh, qw_hh, col_offsets_hh, scale_hh, zero_point_hh) = + make_vals(w_hh); + + return c10::make_intrusive( + /*qw_ih=*/std::move(qw_ih), + /*qw_hh=*/std::move(qw_hh), + /*b_ih=*/std::move(b_ih), + /*b_hh=*/std::move(b_hh), + /*packed_ih=*/std::move(packed_ih), + /*packed_hh=*/std::move(packed_hh), + /*col_offsets_ih=*/std::move(col_offsets_ih), + /*col_offsets_hh=*/std::move(col_offsets_hh), + /*scale_ih=*/std::move(scale_ih), + /*scale_hh=*/std::move(scale_hh), + /*zero_point_ih=*/std::move(zero_point_ih), + /*zero_point_hh=*/std::move(zero_point_hh)); +} // QuantizedCellParams vs. QuantizedCellParamsDynamic // @@ -377,6 +536,7 @@ static std::unordered_map< std::string, c10::intrusive_ptr (*)(CellParamsSerializationType)> cell_params_deserializers = { + {"quantized", &QuantizedCellParams::__setstate__}, {"quantized_dynamic", &QuantizedCellParamsDynamic::__setstate__}, {"quantized_fp16", &QuantizedCellParamsFP16::__setstate__}}; @@ -1681,6 +1841,38 @@ static std::tuple quantized_lstm_data_legacy( "using the newer definitions in torch.jit.quantized"); } +#define DEFINE_QUANTIZED_RNN_CELL(name, hx_type, cell_type, return_type, prepare_hx_fn) \ +return_type name( \ + const Tensor& input, \ + hx_type hx, \ + const Tensor& w_ih, \ + const Tensor& w_hh, \ + const Tensor& b_ih, \ + const Tensor& b_hh, \ + const Tensor& packed_ih, \ + const Tensor& packed_hh, \ + const Tensor& col_offsets_ih, \ + const Tensor& col_offsets_hh, \ + const Scalar& scale_ih, \ + const Scalar& scale_hh, \ + const Scalar& zero_point_ih, \ + const Scalar& zero_point_hh) { \ + QuantizedCellParams params( \ + w_ih, \ + w_hh, \ + b_ih, \ + b_hh, \ + packed_ih, \ + packed_hh, \ + col_offsets_ih, \ + col_offsets_hh, \ + scale_ih, \ + scale_hh, \ + zero_point_ih, \ + zero_point_hh); \ + return cell_type{}( \ + input, prepare_hx_fn(hx), params); \ +} // Set reduced range to be True for all RNN Cells by default. 
This flag is used only for FBGEMM kernels // QNNPACK does not reduce range for activations #define DEFINE_QUANTIZED_RNN_CELL_DYNAMIC(name, hx_type, cell_type, return_type, prepare_hx_fn) \ @@ -1703,6 +1895,7 @@ return_type name( \ } // Quantized LSTM cell +using quantized_lstm_cell_type = LSTMCell; using quantized_lstm_return_type = std::tuple; static std::tuple prepare_quantized_lstm_hx(TensorList hx) { return std::make_tuple(hx[0], hx[1]); @@ -1711,6 +1904,7 @@ static std::tuple prepare_quantized_lstm_hx(TensorList hx) { // Quantized LSTM cell using quantized_lstm_cell_dynamic_type = LSTMCell; +DEFINE_QUANTIZED_RNN_CELL(quantized_lstm_cell, TensorList, quantized_lstm_cell_type, quantized_lstm_return_type, prepare_quantized_lstm_hx); static DEFINE_QUANTIZED_RNN_CELL_DYNAMIC(quantized_lstm_cell_dynamic, TensorList, quantized_lstm_cell_dynamic_type, quantized_lstm_return_type, prepare_quantized_lstm_hx); @@ -1721,15 +1915,22 @@ static simple_hx_type prepare_quantized_hx(simple_hx_type hx) { } // Quantized GRU cell +using quantized_gru_cell_type = GRUCell; using quantized_gru_cell_dynamic_type = GRUCell; +DEFINE_QUANTIZED_RNN_CELL(quantized_gru_cell, simple_hx_type, quantized_gru_cell_type, Tensor, prepare_quantized_hx); + static DEFINE_QUANTIZED_RNN_CELL_DYNAMIC(quantized_gru_cell_dynamic, simple_hx_type, quantized_gru_cell_dynamic_type, Tensor, prepare_quantized_hx); // Quantized RNN w/ ReLU cell +using quantized_rnn_relu_cell_type = SimpleCell; +DEFINE_QUANTIZED_RNN_CELL(quantized_rnn_relu_cell, simple_hx_type, quantized_rnn_relu_cell_type, Tensor, prepare_quantized_hx); using quantized_rnn_relu_cell_dynamic_type = SimpleCell; static DEFINE_QUANTIZED_RNN_CELL_DYNAMIC(quantized_rnn_relu_cell_dynamic, simple_hx_type, quantized_rnn_relu_cell_dynamic_type, Tensor, prepare_quantized_hx); // Quantized RNN w/ tanh cell +using quantized_rnn_tanh_cell_type = SimpleCell; +DEFINE_QUANTIZED_RNN_CELL(quantized_rnn_tanh_cell, simple_hx_type, quantized_rnn_tanh_cell_type, Tensor, prepare_quantized_hx); using quantized_rnn_tanh_cell_dynamic_type = SimpleCell; static DEFINE_QUANTIZED_RNN_CELL_DYNAMIC(quantized_rnn_tanh_cell_dynamic, simple_hx_type, quantized_rnn_tanh_cell_dynamic_type, Tensor, prepare_quantized_hx); @@ -1771,6 +1972,7 @@ TORCH_LIBRARY_FRAGMENT(aten, m) { TORCH_LIBRARY_FRAGMENT(quantized, m) { m.def(TORCH_SELECTIVE_SCHEMA("quantized::make_quantized_cell_params_dynamic(__torch__.torch.classes.quantized.LinearPackedParamsBase w_ih, __torch__.torch.classes.quantized.LinearPackedParamsBase w_hh, Tensor bias_ih, Tensor bias_hh, bool reduce_range=False) -> __torch__.torch.classes.rnn.CellParamsBase")); m.def(TORCH_SELECTIVE_SCHEMA("quantized::make_quantized_cell_params_fp16(__torch__.torch.classes.quantized.LinearPackedParamsBase w_ih, __torch__.torch.classes.quantized.LinearPackedParamsBase w_hh) -> __torch__.torch.classes.rnn.CellParamsBase")); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::make_quantized_cell_params(Tensor w_ih, Tensor w_hh, Tensor b_ih, Tensor b_hh) -> __torch__.torch.classes.rnn.CellParamsBase")); m.def(TORCH_SELECTIVE_SCHEMA("quantized::quantized_lstm_cell_dynamic(Tensor input, Tensor[] hx, __torch__.torch.classes.quantized.LinearPackedParamsBase w_ih, __torch__.torch.classes.quantized.LinearPackedParamsBase w_hh, Tensor bias_ih, Tensor bias_hh) -> (Tensor, Tensor)")); m.def(TORCH_SELECTIVE_SCHEMA("quantized::quantized_gru_cell_dynamic(Tensor input, Tensor hx, __torch__.torch.classes.quantized.LinearPackedParamsBase w_ih, 
__torch__.torch.classes.quantized.LinearPackedParamsBase w_hh, Tensor b_ih, Tensor b_hh) -> Tensor")); m.def(TORCH_SELECTIVE_SCHEMA("quantized::quantized_rnn_relu_cell_dynamic(Tensor input, Tensor hx, __torch__.torch.classes.quantized.LinearPackedParamsBase w_ih, __torch__.torch.classes.quantized.LinearPackedParamsBase w_hh, Tensor b_ih, Tensor b_hh) -> Tensor")); @@ -1790,6 +1992,7 @@ TORCH_LIBRARY_IMPL(aten, CPU, m) { TORCH_LIBRARY_IMPL(quantized, CPU, m) { m.impl(TORCH_SELECTIVE_NAME("quantized::make_quantized_cell_params_dynamic"), TORCH_FN(make_quantized_cell_params_dynamic)); + m.impl(TORCH_SELECTIVE_NAME("quantized::make_quantized_cell_params"), TORCH_FN(make_quantized_cell_params)); m.impl(TORCH_SELECTIVE_NAME("quantized::quantized_lstm_cell_dynamic"), TORCH_FN(quantized_lstm_cell_dynamic)); m.impl(TORCH_SELECTIVE_NAME("quantized::quantized_gru_cell_dynamic"), TORCH_FN(quantized_gru_cell_dynamic)); m.impl(TORCH_SELECTIVE_NAME("quantized::quantized_rnn_relu_cell_dynamic"), TORCH_FN(quantized_rnn_relu_cell_dynamic)); diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 2176882644d9f0..2b969ebd6b2b94 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -3314,6 +3314,22 @@ dispatch: CUDA: _mixed_dtypes_linear +- func: fbgemm_linear_int8_weight_fp32_activation(Tensor input, Tensor weight, Tensor packed, Tensor col_offsets, Scalar weight_scale, Scalar weight_zero_point, Tensor bias) -> Tensor + +- func: fbgemm_linear_int8_weight(Tensor input, Tensor weight, Tensor packed, Tensor col_offsets, Scalar weight_scale, Scalar weight_zero_point, Tensor bias) -> Tensor + +- func: fbgemm_linear_quantize_weight(Tensor input) -> (Tensor, Tensor, float, int) + +- func: fbgemm_pack_gemm_matrix_fp16(Tensor input) -> Tensor + +- func: fbgemm_linear_fp16_weight_fp32_activation(Tensor input, Tensor packed_weight, Tensor bias) -> Tensor + +- func: fbgemm_linear_fp16_weight(Tensor input, Tensor packed_weight, Tensor bias) -> Tensor + +- func: fbgemm_pack_quantized_matrix(Tensor input) -> Tensor + +- func: fbgemm_pack_quantized_matrix.KN(Tensor input, int K, int N) -> Tensor + - func: ldexp.Tensor(Tensor self, Tensor other) -> Tensor variants: function, method @@ -7624,6 +7640,15 @@ # - func: quantized_gru.data(Tensor data, Tensor batch_sizes, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional) -> (Tensor, Tensor) # +# Quantized RNN cells +- func: quantized_lstm_cell(Tensor input, Tensor[] hx, Tensor w_ih, Tensor w_hh, Tensor b_ih, Tensor b_hh, Tensor packed_ih, Tensor packed_hh, Tensor col_offsets_ih, Tensor col_offsets_hh, Scalar scale_ih, Scalar scale_hh, Scalar zero_point_ih, Scalar zero_point_hh) -> (Tensor, Tensor) + +- func: quantized_gru_cell(Tensor input, Tensor hx, Tensor w_ih, Tensor w_hh, Tensor b_ih, Tensor b_hh, Tensor packed_ih, Tensor packed_hh, Tensor col_offsets_ih, Tensor col_offsets_hh, Scalar scale_ih, Scalar scale_hh, Scalar zero_point_ih, Scalar zero_point_hh) -> Tensor + +- func: quantized_rnn_relu_cell(Tensor input, Tensor hx, Tensor w_ih, Tensor w_hh, Tensor b_ih, Tensor b_hh, Tensor packed_ih, Tensor packed_hh, Tensor col_offsets_ih, Tensor col_offsets_hh, Scalar scale_ih, Scalar scale_hh, Scalar zero_point_ih, Scalar zero_point_hh) -> Tensor + +- func: quantized_rnn_tanh_cell(Tensor input, Tensor hx, Tensor w_ih, Tensor w_hh, Tensor b_ih, Tensor b_hh, Tensor packed_ih, Tensor packed_hh, Tensor col_offsets_ih, Tensor 
col_offsets_hh, Scalar scale_ih, Scalar scale_hh, Scalar zero_point_ih, Scalar zero_point_hh) -> Tensor + # PackedSequence utilities - func: _pack_padded_sequence(Tensor input, Tensor lengths, bool batch_first) -> (Tensor, Tensor) dispatch: diff --git a/build_variables.bzl b/build_variables.bzl index 11152d69c09171..2a306c2c2e9bd1 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -1323,6 +1323,7 @@ aten_native_source_non_codegen_list = [ "aten/src/ATen/native/PointwiseOps.cpp", "aten/src/ATen/native/Pooling.cpp", "aten/src/ATen/native/Pow.cpp", + "aten/src/ATen/native/QuantizedLinear.cpp", "aten/src/ATen/native/RNN.cpp", "aten/src/ATen/native/RangeFactories.cpp", "aten/src/ATen/native/ReduceAllOps.cpp", diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index a44dc552a8a4ee..078531c0776c1f 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -1132,6 +1132,7 @@ set(ATen_CPU_INCLUDE ${CMAKE_BINARY_DIR}/aten/src) if(CMAKE_CXX_COMPILER_ID MATCHES "Clang" OR CMAKE_CXX_COMPILER_ID STREQUAL "GNU") + set_source_files_properties(${CMAKE_CURRENT_SOURCE_DIR}/../aten/src/ATen/native/QuantizedLinear.cpp PROPERTIES COMPILE_FLAGS -Wno-deprecated-declarations) set_source_files_properties(${CMAKE_CURRENT_SOURCE_DIR}/../aten/src/ATen/native/RNN.cpp PROPERTIES COMPILE_FLAGS -Wno-deprecated-declarations) set_source_files_properties(${CMAKE_CURRENT_SOURCE_DIR}/../aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp PROPERTIES COMPILE_FLAGS -Wno-deprecated-declarations) set_source_files_properties(${CMAKE_CURRENT_SOURCE_DIR}/../aten/src/ATen/native/quantized/qlinear_unpack.cpp PROPERTIES COMPILE_FLAGS -Wno-deprecated-declarations) diff --git a/test/forward_backward_compatibility/check_forward_backward_compatibility.py b/test/forward_backward_compatibility/check_forward_backward_compatibility.py index 1b7dd2c2aa6441..6b1bdfe37db85e 100644 --- a/test/forward_backward_compatibility/check_forward_backward_compatibility.py +++ b/test/forward_backward_compatibility/check_forward_backward_compatibility.py @@ -312,18 +312,6 @@ ("aten::batch_norm_backward_elemt.out", datetime.date(2023, 12, 31)), ("aten::batch_norm_backward_elemt", datetime.date(2023, 12, 31)), ("aten::sym_constrain_range", datetime.date(2023, 12, 31)), - ("aten::fbgemm_linear_int8_weight_fp32_activation", datetime.date(2023, 12, 31)), - ("aten::fbgemm_linear_int8_weight", datetime.date(2023, 12, 31)), - ("aten::fbgemm_linear_quantize_weight", datetime.date(2023, 12, 31)), - ("aten::fbgemm_pack_gemm_matrix_fp16", datetime.date(2023, 12, 31)), - ("aten::fbgemm_linear_fp16_weight_fp32_activation", datetime.date(2023, 12, 31)), - ("aten::fbgemm_linear_fp16_weight", datetime.date(2023, 12, 31)), - ("aten::fbgemm_pack_quantized_matrix", datetime.date(2023, 12, 31)), - ("aten::quantized_lstm_cell", datetime.date(2023, 12, 31)), - ("aten::quantized_gru_cell", datetime.date(2023, 12, 31)), - ("aten::quantized_rnn_relu_cell", datetime.date(2023, 12, 31)), - ("aten::quantized_rnn_tanh_cell", datetime.date(2023, 12, 31)), - ("quantized::make_quantized_cell_params", datetime.date(2023, 12, 31)), ] ALLOW_LIST_COMPILED = [ diff --git a/test/jit/test_models.py b/test/jit/test_models.py index 3ed0acbdb27858..252fb3dcb8d2f3 100644 --- a/test/jit/test_models.py +++ b/test/jit/test_models.py @@ -16,6 +16,7 @@ sys.path.append(pytorch_test_dir) from torch.testing._internal.jit_utils import JitTestCase, RUN_CUDA from torch.testing._internal.common_utils import slowTest, suppress_warnings +from torch.testing._internal.common_quantization 
import skipIfNoFBGEMM if __name__ == '__main__': raise RuntimeError("This test file is not meant to be run directly, use:\n\n" @@ -406,6 +407,12 @@ class Config: def test_snli(self): self._test_snli(self, device='cpu') + @skipIfNoFBGEMM + # Suppression: this exercises a deprecated API + @suppress_warnings + def test_snli_quantized(self): + self._test_snli(self, device='cpu', quantized=True) + @unittest.skipIf(not RUN_CUDA, "no CUDA") def test_snli_cuda(self): # XXX: export_import on CUDA modules doesn't work (#11480) @@ -546,6 +553,12 @@ def forward(self, x): def test_vae(self): self._test_vae(self, device='cpu') + @skipIfNoFBGEMM + # Suppression: this exercises a deprecated API + @suppress_warnings + def test_vae_quantized(self): + self._test_vae(self, device='cpu', quantized=True) + @unittest.skipIf(not RUN_CUDA, "no CUDA") def test_vae_cuda(self): # XXX: export_import on CUDA modules doesn't work (#11480) diff --git a/test/quantization/core/test_quantized_op.py b/test/quantization/core/test_quantized_op.py index 30f8ba72dce8a1..7f0c48c36a6d55 100644 --- a/test/quantization/core/test_quantized_op.py +++ b/test/quantization/core/test_quantized_op.py @@ -3200,6 +3200,81 @@ def test_qlinear(self, batch_size, input_channels, output_channels, self.assertEqual(Y_fp32, Y_fp32_ref, msg="torch.ops.quantized.linear_dynamic results are off") + @skipIfNoFBGEMM + @given( + batch_size=st.integers(1, 4), + input_channels=st.integers(16, 32), + output_channels=st.integers(4, 8), + ) + def test_qlinear_legacy(self, batch_size, input_channels, output_channels): + X_scale = 1.0 + X_zp = 0 + X_value_min = 0 + X_value_max = 255 + X_q0 = np.round(np.random.rand(batch_size, input_channels) * ( + X_value_max - X_value_min) + X_value_min + ).astype(np.uint8) + X_q0[0, 0] = X_value_min + X_q0[0, 1] = X_value_max + + W_scale = 1.0 + W_zp = 0 + W_value_min = -128 + W_value_max = 127 + W_q0 = np.round( + np.random.rand(output_channels, input_channels) + * (W_value_max - W_value_min) + + W_value_min + ).astype(np.int8) + W_q0[0, 0] = W_value_min + W_q0[1, 0] = W_value_max + + b_value_min = -10 + b_value_max = 10 + b_q0 = np.round( + np.random.rand(output_channels) * (b_value_max - b_value_min) + + b_value_min + ).astype(np.int32) + + avoid_vpmaddubsw_overflow_linear( + batch_size, + input_channels, + output_channels, + X_q0, + X_value_min, + X_value_max, + W_q0, + W_value_min, + W_value_max, + ) + + X_fp32 = torch.from_numpy(_dequantize(X_q0, X_scale, X_zp)).to(dtype=torch.float) + W_fp32 = torch.from_numpy(_dequantize(W_q0, W_scale, W_zp)).to(dtype=torch.float) + b_fp32 = torch.from_numpy( + _dequantize(b_q0, X_scale * W_scale, 0) + ).to(dtype=torch.float) + + W_scale, W_zp = _calculate_dynamic_qparams(W_fp32, torch.qint8) + W_q = torch.quantize_per_tensor(W_fp32, scale=W_scale, zero_point=W_zp, dtype=torch.qint8) + + # Observe X_fp32 and determine X_scale and X_zero_point, this should match + # internals of dynamic linear. 
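[Editor's note, not part of the diff] The test above builds its reference data with the private helper _dequantize from torch.testing._internal. For readability, here is the usual affine mapping that name suggests, written as a hypothetical stand-alone helper (the real implementation lives in the test utilities and may differ in detail):

import numpy as np

def dequantize(q, scale, zero_point):
    # Affine dequantization: real_value = (q - zero_point) * scale
    return (np.asarray(q).astype(np.float32) - zero_point) * scale

With X_scale = 1.0 and X_zp = 0 as above, dequantize(X_q0, 1.0, 0) simply reinterprets the uint8 codes as floats.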
+ X_scale, X_zp = _calculate_dynamic_qparams(X_fp32, torch.quint8) + X_q = torch.quantize_per_tensor(X_fp32, scale=X_scale, zero_point=X_zp, dtype=torch.quint8) + + W_int8, col_offsets, W_scale, W_zp = torch.fbgemm_linear_quantize_weight(W_q.dequantize()) + W_prepack = torch.fbgemm_pack_quantized_matrix(W_int8.clone(), W_int8.size(1), W_int8.size(0)) + # Quantized Linear operator with prepacked weight + Y_fp32 = torch.fbgemm_linear_int8_weight( + X_q.dequantize(), W_q.dequantize(), W_prepack, col_offsets, + W_scale, W_zp, b_fp32) + + Y_fp32_ref = F.linear(X_q.dequantize(), W_q.dequantize(), b_fp32) + # Y_fp32_ref = F.linear(X_fp32, W_fp32, b_fp32) + + self.assertEqual(Y_fp32, Y_fp32_ref, + msg="torch.ops.quantized.fbgemm_linear_dynamic results are off") + @skipIfNoFBGEMM @given( input_channels=st.integers(16, 32), diff --git a/test/quantization/jit/test_deprecated_jit_quant.py b/test/quantization/jit/test_deprecated_jit_quant.py index ec9d54fe7c0d8a..806cff230fe4a8 100644 --- a/test/quantization/jit/test_deprecated_jit_quant.py +++ b/test/quantization/jit/test_deprecated_jit_quant.py @@ -4,8 +4,11 @@ from torch.testing._internal.common_quantization import ( skipIfNoFBGEMM ) +from torch.testing._internal.common_utils import suppress_warnings from torch.testing._internal.jit_utils import JitTestCase +from typing import Tuple +import copy class TestDeprecatedJitQuantized(JitTestCase): @skipIfNoFBGEMM @@ -51,8 +54,54 @@ def test_rnn_cell_quantized(self): torch.tensor(vals, dtype=torch.float), requires_grad=False) - with self.assertRaisesRegex(RuntimeError, "quantize_rnn_cell_modules function is no longer supported"): - cell = torch.jit.quantized.quantize_rnn_cell_modules(cell) + ref = copy.deepcopy(cell) + + cell = torch.jit.quantized.quantize_rnn_cell_modules(cell) + x = torch.tensor([[100, -155], + [-155, 100], + [100, -155]], dtype=torch.float) + h0_vals = [[-155, 100], + [-155, 155], + [100, -155]] + hx = torch.tensor(h0_vals, dtype=torch.float) + if isinstance(cell, torch.jit.quantized.QuantizedLSTMCell): + cx = torch.tensor(h0_vals, dtype=torch.float) + hiddens = (hx, cx) + else: + hiddens = hx + + if isinstance(cell, torch.jit.quantized.QuantizedLSTMCell): + class ScriptWrapper(torch.jit.ScriptModule): + def __init__(self, cell): + super().__init__() + self.cell = cell + + @torch.jit.script_method + def forward(self, x: torch.Tensor, + hiddens: Tuple[torch.Tensor, torch.Tensor] + ) -> Tuple[torch.Tensor, torch.Tensor]: + return self.cell(x, hiddens) + else: + + class ScriptWrapper(torch.jit.ScriptModule): + def __init__(self, cell): + super().__init__() + self.cell = cell + + @torch.jit.script_method + def forward(self, x: torch.Tensor, hiddens: torch.Tensor) -> torch.Tensor: + return self.cell(x, hiddens) + + cell = ScriptWrapper(cell) + outs = cell(x, hiddens) + cell = self.getExportImportCopyWithPacking(cell) + + outs = cell(x, hiddens) + ref_outs = ref(x, hiddens) + + self.assertEqual(len(outs), len(ref_outs)) + for out, ref_out in zip(outs, ref_outs): + torch.testing.assert_close(out, ref_out) @skipIfNoFBGEMM def test_rnn_quantized(self): @@ -94,14 +143,85 @@ def test_rnn_quantized(self): torch.tensor(vals, dtype=torch.float), requires_grad=False) - with self.assertRaisesRegex(RuntimeError, "quantize_rnn_modules function is no longer supported"): - cell_int8 = torch.jit.quantized.quantize_rnn_modules(cell, dtype=torch.int8) + ref = copy.deepcopy(cell) + cell_int8 = torch.jit.quantized.quantize_rnn_modules(cell, dtype=torch.int8) + cell_fp16 = 
torch.jit.quantized.quantize_rnn_modules(cell, dtype=torch.float16) + + niter = 10 + x = torch.tensor([[100, -155], + [-155, 100], + [100, -155]], dtype=torch.float).unsqueeze(0).repeat(niter, 1, 1) + h0_vals = [[-155, 100], + [-155, 155], + [100, -155]] + hx = torch.tensor(h0_vals, dtype=torch.float).unsqueeze(0) + cx = torch.tensor(h0_vals, dtype=torch.float).unsqueeze(0) + + if isinstance(ref, torch.nn.LSTM): + hiddens = (hx, cx) + elif isinstance(ref, torch.nn.GRU): + hiddens = hx + + ref_out, ref_hid = ref(x, hiddens) - with self.assertRaisesRegex(RuntimeError, "quantize_rnn_modules function is no longer supported"): - cell_fp16 = torch.jit.quantized.quantize_rnn_modules(cell, dtype=torch.float16) + # Compare int8 quantized to unquantized + output_int8, final_hiddens_int8 = cell_int8(x, hiddens) + torch.testing.assert_close(output_int8, ref_out) + for out, ref in zip(final_hiddens_int8, ref_hid): + torch.testing.assert_close(out, ref) + + # Compare fp16 quantized to unquantized + output_fp16, final_hiddens_fp16 = cell_fp16(x, hiddens) + + torch.testing.assert_close(output_fp16, ref_out) + for out, ref in zip(final_hiddens_fp16, ref_hid): + torch.testing.assert_close(out, ref) + + def compare_quantized_unquantized(ScriptWrapper, cell): + wrapper = ScriptWrapper(cell) + + # Compare quantize scripted module to unquantized + script_out, script_hid = wrapper(x, hiddens) + torch.testing.assert_close(script_out, ref_out) + for out, ref in zip(script_hid, ref_hid): + torch.testing.assert_close(out, ref) + + # Compare export/import to unquantized + export_import_wrapper = self.getExportImportCopyWithPacking(wrapper) + ei_out, ei_hid = export_import_wrapper(x, hiddens) + torch.testing.assert_close(ei_out, ref_out) + for out, ref in zip(ei_hid, ref_hid): + torch.testing.assert_close(out, ref) + + if isinstance(cell, torch.jit.quantized.QuantizedGRU): + class ScriptWrapper(torch.jit.ScriptModule): + def __init__(self, cell): + super().__init__() + self.cell = cell + + @torch.jit.script_method + def forward(self, x: torch.Tensor, hiddens: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + return self.cell(x, hiddens) + + compare_quantized_unquantized(ScriptWrapper, cell) + elif isinstance(cell, torch.jit.quantized.QuantizedLSTM): + for cell in [cell_int8, cell_fp16]: + class ScriptWrapper(torch.jit.ScriptModule): + def __init__(self, cell): + super().__init__() + self.cell = cell + + @torch.jit.script_method + def forward(self, x, hiddens): + # type: (torch.Tensor, Tuple[torch.Tensor, torch.Tensor]) + # -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] + return self.cell(x, hiddens) + compare_quantized_unquantized(ScriptWrapper, cell) if 'fbgemm' in torch.backends.quantized.supported_engines: + # Suppression: using deprecated quant api + @suppress_warnings def test_quantization_modules(self): K1, N1 = 2, 2 @@ -124,11 +244,18 @@ def forward(self, x): y_ref = fb(value) - with self.assertRaisesRegex(RuntimeError, "quantize_linear_modules function is no longer supported"): - fb_int8 = torch.jit.quantized.quantize_linear_modules(fb) + fb_int8 = torch.jit.quantized.quantize_linear_modules(fb) + traced_int8 = torch.jit.trace(fb_int8, (x,)) + fb_int8 = self.getExportImportCopyWithPacking(traced_int8) + y_int8 = fb_int8(value) + + fb_fp16 = torch.jit.quantized.quantize_linear_modules(fb, torch.float16) + traced_fp16 = torch.jit.trace(fb_fp16, (x,)) + fb_fp16 = self.getExportImportCopyWithPacking(traced_fp16) + y_fp16 = fb_fp16(value) - with self.assertRaisesRegex(RuntimeError, 
"quantize_linear_modules function is no longer supported"): - fb_fp16 = torch.jit.quantized.quantize_linear_modules(fb, torch.float16) + torch.testing.assert_close(y_int8, y_ref, rtol=0.0001, atol=1e-3) + torch.testing.assert_close(y_fp16, y_ref, rtol=0.0001, atol=1e-3) @skipIfNoFBGEMM def test_erase_class_tensor_shapes(self): diff --git a/test/test_nn.py b/test/test_nn.py index 2815ddcf2d4d5f..115a70485990ee 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -2112,6 +2112,25 @@ def test_threshold_bfloat16_half(self): res_bf16 = F.threshold(x.to(dtype=dtype), threshold, 0).float() self.assertEqual(res_bf16, expected) + @unittest.skipUnless('fbgemm' in torch.backends.quantized.supported_engines, + 'Linear_FP16_weight requires FBGEMM. FBGEMM is only optimized for CPUs' + ' with instruction set support avx2 or newer.') + def test_fb_fc_packed(self): + X = np.random.rand(16, 16).astype(np.float32) - 0.5 + W = np.random.rand(16, 16).astype(np.float32) - 0.5 + b = np.random.rand(16).astype(np.float32) - 0.5 + + def fc_op(X, W, b): + return np.dot(X, W.T) + b + + x_tensor = torch.tensor(X) + w_tensor = torch.tensor(W) + b_tensor = torch.tensor(b) + packed_w_tensor = torch.fbgemm_pack_gemm_matrix_fp16(w_tensor) + actual_output = torch.fbgemm_linear_fp16_weight(x_tensor, packed_w_tensor, b_tensor) + expected_output = fc_op(X, W, b) + torch.testing.assert_close(torch.from_numpy(expected_output), actual_output.cpu(), atol=1e-3, rtol=1e-3) + def test_pad_scalar_error(self): inputs = torch.tensor(0., requires_grad=True) self.assertRaises(RuntimeError, lambda: F.pad(inputs, (1, 1))) diff --git a/torch/jit/quantized.py b/torch/jit/quantized.py index c7c679c7945697..63de5c5bb4632e 100644 --- a/torch/jit/quantized.py +++ b/torch/jit/quantized.py @@ -1,99 +1,798 @@ +import warnings + +from typing import List, Optional, Tuple + import torch +from torch import _VF, Tensor # noqa: F401 +from torch.nn.utils.rnn import PackedSequence class QuantizedLinear(torch.jit.ScriptModule): + __constants__ = ["scale", "zero_point"] + def __init__(self, other): - raise RuntimeError( - "torch.jit.QuantizedLinear is no longer supported. Please use " - "torch.ao.nn.quantized.dynamic.Linear instead." + super().__init__() + warnings.warn( + "torch.jit.QuantizedLinear is deprecated and will be removed in an upcoming " + "PyTorch release. Please use the torch.ao.nn.quantized.dynamic.Linear instead." 
+ ) + + self.in_features = other.in_features + self.out_features = other.out_features + # Quantize weight and discard the original + ( + self.weight, + self.col_offsets, + self.scale, + self.zero_point, + ) = torch.fbgemm_linear_quantize_weight( + other.weight.clone(memory_format=torch.contiguous_format).float() + ) + self.weight = torch.nn.Parameter(self.weight, requires_grad=False) + self.col_offsets = torch.nn.Parameter(self.col_offsets, requires_grad=False) + assert other.bias is not None, "QuantizedLinear requires a bias" + self.bias = torch.nn.Parameter( + other.bias.clone(memory_format=torch.contiguous_format).float(), + requires_grad=False, + ) + + self.register_buffer( + "packed_tensor_ptr", + torch.fbgemm_pack_quantized_matrix( + self.weight.clone(memory_format=torch.contiguous_format) + ), + ) + + @torch.jit.script_method + def _unpack(self): + self.packed_tensor_ptr.set_(torch.fbgemm_pack_quantized_matrix(self.weight)) + + @torch.jit.script_method + def _pack(self): + self.packed_tensor_ptr.set_( + torch.zeros(torch.jit.annotate(List[int], []), dtype=torch.uint8).detach() + ) + + @torch.jit.script_method + def forward(self, input): + out = torch.fbgemm_linear_int8_weight_fp32_activation( + input.float(), + self.weight, + self.packed_tensor_ptr, + self.col_offsets, + self.scale, + self.zero_point, + self.bias, ) + return out.to(input.dtype) + + def extra_repr(self): + repr = ( + "in_features={in_features}, out_features={out_features}, " + "scale={scale}, zero_point={zero_point}".format(**self.__dict__) + ) + return repr # FP16 weights class QuantizedLinearFP16(torch.jit.ScriptModule): def __init__(self, other): super().__init__() - raise RuntimeError( - "torch.jit.QuantizedLinearFP16 is no longer supported. " - "Please use the torch.ao.nn.quantized.dynamic.Linear instead." + warnings.warn( + "torch.jit.QuantizedLinearFP16 is deprecated and will be removed in an upcoming " + "PyTorch release. Please use the torch.ao.nn.quantized.dynamic.Linear instead." + ) + self.in_features = other.in_features + self.out_features = other.out_features + self.original_weight = other.weight + self.weight = torch.fbgemm_pack_gemm_matrix_fp16( + other.weight.clone(memory_format=torch.contiguous_format).float() + ) + assert other.bias is not None, "QuantizedLinearFP16 requires a bias" + self.bias = torch.nn.Parameter( + other.bias.clone(memory_format=torch.contiguous_format).float(), + requires_grad=False, + ) + self.register_buffer("packed_weight", self.weight) + + @torch.jit.script_method + def _unpack(self): + self.packed_weight.set_( + torch.fbgemm_pack_gemm_matrix_fp16(self.original_weight) + ) + + @torch.jit.script_method + def _pack(self): + self.packed_weight.set_( + torch.zeros(torch.jit.annotate(List[int], []), dtype=torch.uint8).detach() + ) + + @torch.jit.script_method + def forward(self, input): + out = torch.fbgemm_linear_fp16_weight_fp32_activation( + input.float(), self.packed_weight, self.bias ) + return out + + def extra_repr(self): + repr = "in_features={in_features}, out_features={out_features}, ".format( + **self.__dict__ + ) + return repr # Quantized RNN cell implementations class QuantizedRNNCellBase(torch.jit.ScriptModule): + __constants__ = [ + "input_size", + "hidden_size", + "bias", + "scale_hh", + "scale_ih", + "zero_point_ih", + "zero_point_hh", + ] + def __init__(self, other): - raise RuntimeError( - "torch.jit.QuantizedRNNCellBase is no longer supported. " - "Please use the torch.ao.nn.quantized.dynamic.RNNCell instead." 
+ super().__init__() + warnings.warn( + "torch.jit.QuantizedRNNCellBase is deprecated and will be removed in an upcoming " + "PyTorch release. Please use the torch.ao.nn.quantized.dynamic.RNNCell instead." + ) + + self.input_size = other.input_size + self.hidden_size = other.hidden_size + self.bias = other.bias + if not self.bias: + raise ValueError("Quantized RNN cells require bias terms") + + ( + weight_ih, + col_offsets_ih, + self.scale_ih, + self.zero_point_ih, + ) = torch.fbgemm_linear_quantize_weight( + other.weight_ih.clone(memory_format=torch.contiguous_format).float() + ) + self.register_buffer("weight_ih", weight_ih) + self.register_buffer("col_offsets_ih", col_offsets_ih) + ( + weight_hh, + col_offsets_hh, + self.scale_hh, + self.zero_point_hh, + ) = torch.fbgemm_linear_quantize_weight( + other.weight_hh.clone(memory_format=torch.contiguous_format).float() + ) + self.register_buffer("weight_hh", weight_hh) + self.register_buffer("col_offsets_hh", col_offsets_hh) + + packed_ih = torch.fbgemm_pack_quantized_matrix(self.weight_ih) + self.register_buffer("packed_ih", packed_ih) + packed_hh = torch.fbgemm_pack_quantized_matrix(self.weight_hh) + self.register_buffer("packed_hh", packed_hh) + + self.bias_ih = torch.nn.Parameter( + other.bias_ih.clone(memory_format=torch.contiguous_format).float(), + requires_grad=False, + ) + self.bias_hh = torch.nn.Parameter( + other.bias_hh.clone(memory_format=torch.contiguous_format).float(), + requires_grad=False, + ) + + def extra_repr(self): + s = "{input_size}, {hidden_size}" + if "bias" in self.__dict__ and self.bias is not True: + s += ", bias={bias}" + if "nonlinearity" in self.__dict__ and self.nonlinearity != "tanh": + s += ", nonlinearity={nonlinearity}" + return s.format(**self.__dict__) + + @torch.jit.script_method + def check_forward_input(self, input): + if input.size(1) != self.input_size: + raise RuntimeError( + f"input has inconsistent input_size: got {input.size(1)}, expected {self.input_size}" + ) + + @torch.jit.script_method + def check_forward_hidden( + self, input: Tensor, hx: Tensor, hidden_label: str = "" + ) -> None: + if input.size(0) != hx.size(0): + raise RuntimeError( + f"Input batch size {input.size(0)} doesn't match hidden{hidden_label} batch size {hx.size(0)}" + ) + + if hx.size(1) != self.hidden_size: + raise RuntimeError( + f"hidden{hidden_label} has inconsistent hidden_size: got {hx.size(1)}, expected {self.hidden_size}" + ) + + # TODO: for some reason weak_script_method causes a destruction of the + # module to occur, which in turn frees the packed_ih object via its DataPtr + # deleter. This is bizarre and should probably get fixed. + # @torch._jit_internal.weak_script_method + @torch.jit.script_method + def _unpack(self): + self.packed_ih.set_(torch.fbgemm_pack_quantized_matrix(self.weight_ih)) + self.packed_hh.set_(torch.fbgemm_pack_quantized_matrix(self.weight_hh)) + + # @torch._jit_internal.weak_script_method + @torch.jit.script_method + def _pack(self): + self.packed_ih.set_( + torch.zeros(torch.jit.annotate(List[int], []), dtype=torch.uint8).detach() + ) + self.packed_hh.set_( + torch.zeros(torch.jit.annotate(List[int], []), dtype=torch.uint8).detach() ) class QuantizedRNNCell(QuantizedRNNCellBase): + __constants__ = [ + "input_size", + "hidden_size", + "bias", + "scale_hh", + "scale_ih", + "zero_point_ih", + "zero_point_hh", + "nonlinearity", + ] + def __init__(self, other): - raise RuntimeError( - "torch.jit.QuantizedRNNCell is no longer supported. 
" - "Please use the torch.ao.nn.quantized.dynamic.RNNCell instead." + super().__init__(other) + warnings.warn( + "torch.jit.QuantizedRNNCell is deprecated and will be removed in an upcoming " + "PyTorch release. Please use the torch.ao.nn.quantized.dynamic.RNNCell instead." ) + self.nonlinearity = other.nonlinearity + + @torch.jit.script_method + def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor: + self.check_forward_input(input) + if hx is None: + hx = torch.zeros( + input.size(0), self.hidden_size, dtype=input.dtype, device=input.device + ) + self.check_forward_hidden(input, hx, "") + if self.nonlinearity == "tanh": + ret = _VF.quantized_rnn_tanh_cell( + input, + hx, + self.weight_ih, + self.weight_hh, + self.bias_ih, + self.bias_hh, + self.packed_ih, + self.packed_hh, + self.col_offsets_ih, + self.col_offsets_hh, + self.scale_ih, + self.scale_hh, + self.zero_point_ih, + self.zero_point_hh, + ) + elif self.nonlinearity == "relu": + ret = _VF.quantized_rnn_relu_cell( + input, + hx, + self.weight_ih, + self.weight_hh, + self.bias_ih, + self.bias_hh, + self.packed_ih, + self.packed_hh, + self.col_offsets_ih, + self.col_offsets_hh, + self.scale_ih, + self.scale_hh, + self.zero_point_ih, + self.zero_point_hh, + ) + else: + ret = input # TODO: remove when jit supports exception flow + raise RuntimeError(f"Unknown nonlinearity: {self.nonlinearity}") + return ret class QuantizedLSTMCell(QuantizedRNNCellBase): def __init__(self, other): super().__init__(other) - raise RuntimeError( - "torch.jit.QuantizedLSTMCell is no longer supported. " - "Please use the torch.ao.nn.quantized.dynamic.LSTMCell instead." + warnings.warn( + "torch.jit.QuantizedLSTMCell is deprecated and will be removed in an upcoming " + "PyTorch release. Please use the torch.ao.nn.quantized.dynamic.LSTMCell instead." + ) + + @torch.jit.script_method + def forward( + self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None + ) -> Tuple[Tensor, Tensor]: + self.check_forward_input(input) + if hx is None: + zeros = torch.zeros( + input.size(0), self.hidden_size, dtype=input.dtype, device=input.device + ) + hx = (zeros, zeros) + self.check_forward_hidden(input, hx[0], "[0]") + self.check_forward_hidden(input, hx[1], "[1]") + return _VF.quantized_lstm_cell( + input, + hx, + self.weight_ih, + self.weight_hh, + self.bias_ih, + self.bias_hh, + self.packed_ih, + self.packed_hh, + self.col_offsets_ih, + self.col_offsets_hh, + self.scale_ih, + self.scale_hh, + self.zero_point_ih, + self.zero_point_hh, ) class QuantizedGRUCell(QuantizedRNNCellBase): def __init__(self, other): super().__init__(other) - raise RuntimeError( - "torch.jit.QuantizedGRUCell is no longer supported. " - "Please use the torch.ao.nn.quantized.dynamic.GRUCell instead." + warnings.warn( + "torch.jit.QuantizedGRUCell is deprecated and will be removed in an upcoming " + "PyTorch release. Please use the torch.ao.nn.quantized.dynamic.GRUCell instead." 
         )
 
+    @torch.jit.script_method
+    def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor:
+        self.check_forward_input(input)
+        if hx is None:
+            hx = torch.zeros(
+                input.size(0), self.hidden_size, dtype=input.dtype, device=input.device
+            )
+        self.check_forward_hidden(input, hx, "")
+        return _VF.quantized_gru_cell(
+            input,
+            hx,
+            self.weight_ih,
+            self.weight_hh,
+            self.bias_ih,
+            self.bias_hh,
+            self.packed_ih,
+            self.packed_hh,
+            self.col_offsets_ih,
+            self.col_offsets_hh,
+            self.scale_ih,
+            self.scale_hh,
+            self.zero_point_ih,
+            self.zero_point_hh,
+        )
+
+
+def apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tensor:
+    return tensor.index_select(dim, permutation)
+
 
 class QuantizedRNNBase(torch.jit.ScriptModule):
+    __constants__ = [
+        "mode",
+        "input_size",
+        "hidden_size",
+        "num_layers",
+        "bias",
+        "batch_first",
+        "dropout",
+        "bidirectional",
+        "dtype",
+    ]
+
     def __init__(self, other, dtype=torch.int8):
-        raise RuntimeError(
-            "torch.jit.QuantizedRNNBase is no longer supported. "
-            "Please use the torch.ao.nn.quantized.dynamic instead."
+        super().__init__()
+        warnings.warn(
+            "torch.jit.QuantizedRNNBase is deprecated and will be removed in an upcoming "
+            "PyTorch release. Please use the torch.ao.nn.quantized.dynamic instead."
         )
+        self.mode = other.mode
+        self.input_size = other.input_size
+        self.hidden_size = other.hidden_size
+        self.num_layers = other.num_layers
+        self.bias = other.bias
+        self.batch_first = other.batch_first
+        if self.mode != "GRU":
+            assert not self.batch_first
+        self.dropout = other.dropout
+        self.bidirectional = other.bidirectional
+        num_directions = 2 if self.bidirectional else 1
+        self.dtype = dtype
+
+        assert self.bias
+
+        # TODO: support more than just LSTM
+        if self.mode != "LSTM" and self.mode != "GRU":
+            raise RuntimeError("Only LSTM or GRU is supported for QuantizedRNN")
+
+        if dtype != torch.int8 and dtype != torch.float16:
+            raise RuntimeError(f"Unsupported dtype: {dtype}")
+
+        self.all_weights = []
+        for layer in range(self.num_layers):
+            for direction in range(num_directions):
+                layer_input_size = (
+                    self.input_size if layer == 0 else self.hidden_size * num_directions
+                )
+
+                suffix = "_reverse" if direction == 1 else ""
+
+                def get_weight_bias(ihhh):
+                    weight_name = f"weight_{ihhh}_l{layer}{suffix}"
+                    bias_name = f"bias_{ihhh}_l{layer}{suffix}"
+
+                    weight = getattr(other, weight_name)
+                    bias = getattr(other, bias_name)
+                    return weight, bias
+
+                weight_ih, bias_ih = get_weight_bias("ih")
+                weight_hh, bias_hh = get_weight_bias("hh")
+
+                if dtype == torch.int8:
+                    cell_params = torch.ops.quantized.make_quantized_cell_params(
+                        weight_ih, weight_hh, bias_ih, bias_hh
+                    )
+                else:
+                    packed_ih = torch.ops.quantized.linear_prepack_fp16(
+                        weight_ih.float(), bias_ih
+                    )
+                    packed_hh = torch.ops.quantized.linear_prepack_fp16(
+                        weight_hh.float(), bias_hh
+                    )
+
+                    cell_params = torch.ops.quantized.make_quantized_cell_params_fp16(
+                        packed_ih, packed_hh
+                    )
+
+                setattr(self, f"cell_params_{layer}_{suffix}", cell_params)
+                self.all_weights.append(cell_params)
+
+    @torch.jit.script_method
+    def check_input(self, input: Tensor, batch_sizes: Optional[Tensor]) -> None:
+        expected_input_dim = 2 if batch_sizes is not None else 3
+        if input.dim() != expected_input_dim:
+            raise RuntimeError(
+                f"input must have {expected_input_dim} dimensions, got {input.dim()}"
+            )
+        if self.input_size != input.size(-1):
+            raise RuntimeError(
+                f"input.size(-1) must be equal to input_size. Expected {self.input_size}, got {input.size(-1)}"
+            )
+
+    @torch.jit.script_method
+    def get_expected_hidden_size(
+        self, input: Tensor, batch_sizes: Optional[Tensor]
+    ) -> Tuple[int, int, int]:
+        if batch_sizes is not None:
+            mini_batch = int(batch_sizes[0])
+        else:
+            mini_batch = input.size(0) if self.batch_first else input.size(1)
+        num_directions = 2 if self.bidirectional else 1
+        expected_hidden_size = (
+            self.num_layers * num_directions,
+            mini_batch,
+            self.hidden_size,
+        )
+        return expected_hidden_size
+
+    @torch.jit.script_method
+    def check_hidden_size(
+        self,
+        hx: Tensor,
+        expected_hidden_size: Tuple[int, int, int],
+        msg: str = "Expected hidden size {}, got {}",
+    ) -> None:
+        if hx.size() != expected_hidden_size:
+            raise RuntimeError(msg.format(expected_hidden_size, list(hx.size())))
+
+    @torch.jit.script_method
+    def check_forward_args(
+        self, input: Tensor, hidden: Tensor, batch_sizes: Optional[Tensor]
+    ) -> None:
+        self.check_input(input, batch_sizes)
+        expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes)
+        self.check_hidden_size(
+            hidden, expected_hidden_size, msg="Expected hidden size {}, got {}"
+        )
+
+    @torch.jit.script_method
+    def permute_hidden(self, hx: Tensor, permutation: Optional[Tensor]) -> Tensor:
+        if permutation is None:
+            return hx
+        return apply_permutation(hx, permutation)
 
 
 class QuantizedLSTM(QuantizedRNNBase):
+    __overloads__ = {"forward": ["forward_packed", "forward_tensor"]}
+
     def __init__(self, other, dtype):
-        raise RuntimeError(
-            "torch.jit.QuantizedLSTM is no longer supported. "
-            "Please use the torch.ao.nn.quantized.dynamic.LSTM instead."
+        super().__init__(other, dtype)
+        warnings.warn(
+            "torch.jit.QuantizedLSTM is deprecated and will be removed in an upcoming "
+            "PyTorch release. Please use the torch.ao.nn.quantized.dynamic.LSTM instead."
+        )
+
+    @torch.jit.script_method
+    def forward_impl(
+        self,
+        input: Tensor,
+        hx: Optional[Tuple[Tensor, Tensor]],
+        batch_sizes: Optional[Tensor],
+        max_batch_size: int,
+        sorted_indices: Optional[Tensor],
+    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
+        if hx is None:
+            num_directions = 2 if self.bidirectional else 1
+            zeros = torch.zeros(
+                self.num_layers * num_directions,
+                max_batch_size,
+                self.hidden_size,
+                dtype=input.dtype,
+                device=input.device,
+            )
+            hx = (zeros, zeros)
+        else:
+            # Each batch of the hidden state should match the input sequence that
+            # the user believes he/she is passing in.
+            hx = self.permute_hidden(hx, sorted_indices)
+
+        self.check_forward_args(input, hx, batch_sizes)
+        assert batch_sizes is None
+        result = torch.quantized_lstm(
+            input,
+            hx,
+            self.all_weights,
+            self.bias,
+            self.num_layers,
+            float(self.dropout),
+            self.training,
+            self.bidirectional,
+            self.batch_first,
+            dtype=self.dtype,
+            use_dynamic=False,
+        )
+        output = result[0]
+        hidden = result[1:]
+
+        return output, hidden
+
+    @torch.jit.script_method
+    def forward_tensor(
+        self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None
+    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
+        batch_sizes = None
+        max_batch_size = input.size(0) if self.batch_first else input.size(1)
+        sorted_indices = None
+        unsorted_indices = None
+
+        output, hidden = self.forward_impl(
+            input, hx, batch_sizes, max_batch_size, sorted_indices
+        )
+
+        return output, self.permute_hidden(hidden, unsorted_indices)
+
+    @torch.jit.script_method
+    def forward_packed(
+        self, input: PackedSequence, hx: Optional[Tuple[Tensor, Tensor]] = None
+    ) -> Tuple[PackedSequence, Tuple[Tensor, Tensor]]:
+        input_, batch_sizes, sorted_indices, unsorted_indices = input
+        max_batch_size = int(batch_sizes[0])
+
+        output, hidden = self.forward_impl(
+            input_, hx, batch_sizes, max_batch_size, sorted_indices
+        )
+
+        output = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices)
+        return output, self.permute_hidden(hidden, unsorted_indices)
+
+    @torch.jit.script_method
+    def permute_hidden(
+        self, hx: Tuple[Tensor, Tensor], permutation: Optional[Tensor]
+    ) -> Tuple[Tensor, Tensor]:
+        if permutation is None:
+            return hx
+        return apply_permutation(hx[0], permutation), apply_permutation(
+            hx[1], permutation
+        )
+
+    @torch.jit.script_method
+    def check_forward_args(
+        self,
+        input: Tensor,
+        hidden: Tuple[Tensor, Tensor],
+        batch_sizes: Optional[Tensor],
+    ) -> None:
+        self.check_input(input, batch_sizes)
+        expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes)
+
+        self.check_hidden_size(
+            hidden[0], expected_hidden_size, "Expected hidden[0] size {}, got {}"
         )
+        self.check_hidden_size(
+            hidden[1], expected_hidden_size, "Expected hidden[1] size {}, got {}"
+        )
+
+    def forward(self, input, hx=None):
+        if isinstance(input, PackedSequence):
+            return self.forward_packed(input, hx)
+        else:
+            return self.forward_tensor(input, hx)
 
 
 class QuantizedGRU(QuantizedRNNBase):
+    __overloads__ = {"forward": ["forward_packed", "forward_tensor"]}
+
     def __init__(self, *args, **kwargs):
-        raise RuntimeError(
-            "torch.jit.QuantizedGRU is no longer supported. "
-            "Please use the torch.ao.nn.quantized.dynamic.GRU instead."
+        super().__init__(*args, **kwargs)
+        warnings.warn(
+            "torch.jit.QuantizedGRU is deprecated and will be removed in an upcoming "
+            "PyTorch release. Please use the torch.ao.nn.quantized.dynamic.GRU instead."
         )
 
+    @torch.jit.script_method
+    def forward_impl(
+        self,
+        input: Tensor,
+        hx: Optional[Tensor],
+        batch_sizes: Optional[Tensor],
+        max_batch_size: int,
+        sorted_indices: Optional[Tensor],
+    ) -> Tuple[Tensor, Tensor]:
+        if hx is None:
+            num_directions = 2 if self.bidirectional else 1
+            hx = torch.zeros(
+                self.num_layers * num_directions,
+                max_batch_size,
+                self.hidden_size,
+                dtype=input.dtype,
+                device=input.device,
+            )
+        else:
+            # Each batch of the hidden state should match the input sequence that
+            # the user believes he/she is passing in.
+            hx = self.permute_hidden(hx, sorted_indices)
+
+        self.check_forward_args(input, hx, batch_sizes)
+        if batch_sizes is None:
+            result = torch.quantized_gru(
+                input,
+                hx,
+                self.all_weights,
+                self.bias,
+                self.num_layers,
+                float(self.dropout),
+                self.training,
+                self.bidirectional,
+                self.batch_first,
+            )
+        else:
+            result = torch.quantized_gru(
+                input,
+                batch_sizes,
+                hx,
+                self.all_weights,
+                self.bias,
+                self.num_layers,
+                float(self.dropout),
+                self.training,
+                self.bidirectional,
+            )
+
+        output = result[0]
+        hidden = result[1]
+
+        return output, hidden
+
+    @torch.jit.script_method
+    def forward_tensor(
+        self, input: Tensor, hx: Optional[Tensor] = None
+    ) -> Tuple[Tensor, Tensor]:
+        batch_sizes = None
+        max_batch_size = input.size(0) if self.batch_first else input.size(1)
+        sorted_indices = None
+        unsorted_indices = None
+
+        output, hidden = self.forward_impl(
+            input, hx, batch_sizes, max_batch_size, sorted_indices
+        )
+        return output, self.permute_hidden(hidden, unsorted_indices)
+
+    @torch.jit.script_method
+    def forward_packed(
+        self, input: PackedSequence, hx: Optional[Tensor] = None
+    ) -> Tuple[PackedSequence, Tensor]:
+        input_, batch_sizes, sorted_indices, unsorted_indices = input
+        max_batch_size = int(batch_sizes[0])
+
+        output, hidden = self.forward_impl(
+            input_, hx, batch_sizes, max_batch_size, sorted_indices
+        )
+
+        output = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices)
+        return output, self.permute_hidden(hidden, unsorted_indices)
+
+    def forward(self, input, hx=None):
+        if isinstance(input, PackedSequence):
+            return self.forward_packed(input, hx)
+        else:
+            return self.forward_tensor(input, hx)
+
 
 def quantize_rnn_cell_modules(module):
-    raise RuntimeError(
-        "quantize_rnn_cell_modules function is no longer supported. "
+    warnings.warn(
+        "quantize_rnn_cell_modules function has been deprecated. "
         "Please use torch.ao.quantization.quantize_dynamic API instead."
     )
+    reassign = {}
+    for name, mod in module.named_modules():
+        if mod is module:
+            continue
+        new_mod = quantize_rnn_cell_modules(mod)
+        if new_mod is not mod:
+            reassign[name] = new_mod
+    for name, mod in reassign.items():
+        setattr(module, name, mod)
+    if isinstance(module, torch.nn.LSTMCell):
+        return QuantizedLSTMCell(module)
+    if isinstance(module, torch.nn.GRUCell):
+        return QuantizedGRUCell(module)
+    if isinstance(module, torch.nn.RNNCell):
+        return QuantizedRNNCell(module)
+    return module
 
 
 def quantize_linear_modules(module, dtype=torch.int8):
-    raise RuntimeError(
-        "quantize_linear_modules function is no longer supported. "
+    warnings.warn(
+        "quantize_linear_modules function has been deprecated. "
         "Please use torch.ao.quantization.quantize_dynamic API instead."
     )
+    reassign = {}
+    for name, mod in module.named_modules():
+        if mod is module:
+            continue
+        new_mod = quantize_linear_modules(mod, dtype)
+        if new_mod is not mod:
+            reassign[name] = new_mod
+
+    for name, mod in reassign.items():
+        setattr(module, name, mod)
+    if isinstance(module, torch.nn.Linear):
+        if dtype == torch.int8:
+            return QuantizedLinear(module)
+        elif dtype == torch.float16:
+            return QuantizedLinearFP16(module)
+        else:
+            raise RuntimeError(f"Unsupported dtype: {dtype}")
+    return module
+
 
 def quantize_rnn_modules(module, dtype=torch.int8):
-    raise RuntimeError(
-        "quantize_rnn_modules function is no longer supported. "
+    warnings.warn(
+        "quantize_rnn_modules function has been deprecated. "
         "Please use torch.ao.quantization.quantize_dynamic API instead."
     )
+    reassign = {}
+    for name, mod in module.named_modules():
+        if mod is module:
+            continue
+        new_mod = quantize_rnn_modules(mod, dtype)
+        if new_mod is not mod:
+            reassign[name] = new_mod
+
+    for name, mod in reassign.items():
+        setattr(module, name, mod)
+    if isinstance(module, torch.nn.LSTM):
+        if dtype != torch.int8 and dtype != torch.float16:
+            raise RuntimeError(f"Unsupported dtype: {dtype}")
+        return QuantizedLSTM(module, dtype)
+    if isinstance(module, torch.nn.GRU):
+        return QuantizedGRU(module)
+    return module
diff --git a/torch/overrides.py b/torch/overrides.py
index 2e219160e5b4cf..fd97a61878bda0 100644
--- a/torch/overrides.py
+++ b/torch/overrides.py
@@ -586,6 +586,14 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
     torch.fused_moving_avg_obs_fake_quant: (lambda x, observer_on, fake_quant_on, averaging_const, running_min, running_max, scale, zero_point, quant_min, quant_max, ch_axis, per_row_fake_quant=False, symmetric_quant=False: -1),
+    torch.fbgemm_linear_fp16_weight: lambda input, packed_weight, bias: -1,
+    torch.fbgemm_linear_fp16_weight_fp32_activation: lambda input, packed_weight, bias: -1,
+    torch.fbgemm_linear_int8_weight: lambda input, weight, packed, col_offsets, weight_scale, weight_zero_point, bias: -1,
+    torch.fbgemm_linear_int8_weight_fp32_activation: (lambda input, weight, packed, col_offsets, weight_scale,
+                                                      weight_zero_point, bias: -1),
+    torch.fbgemm_linear_quantize_weight: lambda input: -1,
+    torch.fbgemm_pack_gemm_matrix_fp16: lambda input: -1,
+    torch.fbgemm_pack_quantized_matrix: lambda input, a, b: -1,
     torch.feature_alpha_dropout: lambda input, p, train: -1,
     torch.feature_dropout: lambda input, p, train: -1,
     torch.fft.ifft: lambda input, n=None, dim=-1, norm=None: -1,
@@ -970,12 +978,21 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
     torch.quantize_per_tensor: lambda input, scale, zero_point, dtype: -1,
     torch.quantize_per_tensor_dynamic: lambda input, dtype, reduce_range: -1,
     torch.quantized_batch_norm: lambda input, weight, bias, mean, var, eps, output_scale, output_zero_point: -1,
+    torch.quantized_gru_cell: (lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih,
+                               col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1),
+
+    torch.quantized_lstm_cell: (lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih,
+                                col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1),
     torch.quantized_max_pool1d: (lambda input, kernel_size, stride=tuple(), padding=(0,), dilation=(1,), ceil_mode=False: -1),
     torch.quantized_max_pool2d: (lambda input, kernel_size, stride=tuple(), padding=(0, 0), dilation=(1, 1), ceil_mode=False: -1),
     torch.quantized_max_pool3d: (lambda input, kernel_size, stride=tuple(), padding=(0, 0, 0), dilation=(1, 1, 1), ceil_mode=False: -1),
+    torch.quantized_rnn_relu_cell: (lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih,
+                                    col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1),
+    torch.quantized_rnn_tanh_cell: (lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih,
+                                    col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1),
     torch.rad2deg: lambda input, out=None: -1,
     torch.rand_like: lambda input, dtype=None, layout=None, device=None, requires_grad=False: -1,
     torch.randint_like: lambda input, high, dtype=None, layout=torch.strided, device=None, requires_grad=False: -1,