Remove deprecated fbgemm operators (pytorch#104535)
These operators are not used and have been deprecated since pytorch#72690 (Feb 2022). Additionally, the `torch.jit.quantized` interface has been deprecated since pytorch#40102 (June 2020).
Pull Request resolved: pytorch#104535
Approved by: https://github.com/ezyang
peterbell10 authored and pytorchmergebot committed Oct 22, 2023
1 parent bf01a7b commit 57c7aa1
Showing 12 changed files with 55 additions and 1,808 deletions.
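
For anyone still calling the removed `aten::fbgemm_*` operators or the deprecated `torch.jit.quantized` wrappers, the supported replacement is dynamic quantization. The sketch below is illustrative only — the model, sizes, and names are not from this PR — and assumes a build with a quantized engine such as FBGEMM available:

```python
import torch
import torch.nn as nn

# Illustrative model only -- not part of this PR.
class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(input_size=16, hidden_size=32)
        self.fc = nn.Linear(32, 8)

    def forward(self, x):
        out, _ = self.lstm(x)
        return self.fc(out)

model = TinyModel().eval()

# Dynamic quantization replaces the removed fbgemm_* / torch.jit.quantized path:
# weights are stored as int8, activations stay fp32 and are quantized on the fly.
qmodel = torch.ao.quantization.quantize_dynamic(
    model, {nn.LSTM, nn.Linear}, dtype=torch.qint8
)

scripted = torch.jit.script(qmodel)  # the dynamic quantized modules remain scriptable
print(scripted(torch.randn(5, 3, 16)).shape)  # torch.Size([5, 3, 8])
```

This routes through the dynamic quantized kernels that this commit leaves in place rather than the removed static-quantized cell path.
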
585 changes: 0 additions & 585 deletions aten/src/ATen/native/QuantizedLinear.cpp

This file was deleted.

203 changes: 0 additions & 203 deletions aten/src/ATen/native/RNN.cpp
@@ -32,19 +32,12 @@
#include <ATen/ops/cat.h>
#include <ATen/ops/cudnn_is_acceptable.h>
#include <ATen/ops/dropout.h>
#include <ATen/ops/fbgemm_linear_int8_weight_fp32_activation.h>
#include <ATen/ops/fbgemm_linear_quantize_weight_native.h>
#include <ATen/ops/fbgemm_pack_quantized_matrix_native.h>
#include <ATen/ops/gru_cell_native.h>
#include <ATen/ops/gru_native.h>
#include <ATen/ops/linear.h>
#include <ATen/ops/lstm_cell_native.h>
#include <ATen/ops/lstm_native.h>
#include <ATen/ops/matmul.h>
#include <ATen/ops/quantized_gru_cell_native.h>
#include <ATen/ops/quantized_lstm_cell_native.h>
#include <ATen/ops/quantized_rnn_relu_cell_native.h>
#include <ATen/ops/quantized_rnn_tanh_cell_native.h>
#include <ATen/ops/relu.h>
#include <ATen/ops/rnn_relu_cell_native.h>
#include <ATen/ops/rnn_relu_native.h>
@@ -208,158 +201,6 @@ struct CellParams : public CellParamsBase {
}
};

c10::intrusive_ptr<CellParamsBase> make_quantized_cell_params(
const at::Tensor& w_ih,
const at::Tensor& w_hh,
at::Tensor bias_ih,
at::Tensor bias_hh);

struct QuantizedCellParams : public CellParamsBase {
QuantizedCellParams(
Tensor _w_ih,
Tensor _w_hh,
Tensor _b_ih,
Tensor _b_hh,
Tensor _packed_ih,
Tensor _packed_hh,
Tensor _col_offsets_ih,
Tensor _col_offsets_hh,
Scalar _scale_ih,
Scalar _scale_hh,
Scalar _zero_point_ih,
Scalar _zero_point_hh)
: w_ih(std::move(_w_ih)),
w_hh(std::move(_w_hh)),
b_ih_(std::move(_b_ih)),
b_hh_(std::move(_b_hh)),
packed_ih(std::move(_packed_ih)),
packed_hh(std::move(_packed_hh)),
col_offsets_ih(std::move(_col_offsets_ih)),
col_offsets_hh(std::move(_col_offsets_hh)),
scale_ih(std::move(_scale_ih)),
scale_hh(std::move(_scale_hh)),
zero_point_ih(std::move(_zero_point_ih)),
zero_point_hh(std::move(_zero_point_hh)) {}

const Tensor w_ih;
const Tensor w_hh;
const Tensor b_ih_;
const Tensor b_hh_;
const Tensor packed_ih;
const Tensor packed_hh;
const Tensor col_offsets_ih;
const Tensor col_offsets_hh;
const Scalar scale_ih;
const Scalar scale_hh;
const Scalar zero_point_ih;
const Scalar zero_point_hh;

Tensor matmul_ih(const Tensor& input) const override {
TORCH_CHECK(false, "matmul is not supported with quantized cell params");
}
Tensor matmul_hh(const Tensor& h) const override {
TORCH_CHECK(false, "matmul is not supported with quantized cell params");
}
Tensor linear_ih(const Tensor& input) const override {
return at::fbgemm_linear_int8_weight_fp32_activation(
input, w_ih, packed_ih, col_offsets_ih, scale_ih, zero_point_ih, b_ih_);
}
Tensor linear_hh(const Tensor& h) const override {
return at::fbgemm_linear_int8_weight_fp32_activation(
h, w_hh, packed_hh, col_offsets_hh, scale_hh, zero_point_hh, b_hh_);
}
const Tensor& b_ih() const override {
return b_ih_;
}
const Tensor& b_hh() const override {
return b_hh_;
}
CellParamsSerializationType __getstate__() const override {
std::vector<at::Tensor> tensors_to_serialize = {
w_ih, w_hh, b_ih_, b_hh_, col_offsets_ih, col_offsets_hh};
std::vector<double> doubles_to_serialize = {scale_ih.toDouble(),
scale_hh.toDouble()};
std::vector<int64_t> longs_to_serialize = {zero_point_ih.toLong(),
zero_point_hh.toLong()};
return CellParamsSerializationType(
"quantized",
std::move(tensors_to_serialize),
std::move(doubles_to_serialize),
std::move(longs_to_serialize),
{});
}
static c10::intrusive_ptr<CellParamsBase> __setstate__(
CellParamsSerializationType state) {
std::vector<at::Tensor> tensors;
std::vector<double> doubles;
std::vector<int64_t> longs;
std::tie(std::ignore, tensors, doubles, longs, std::ignore) =
std::move(state);
TORCH_INTERNAL_ASSERT(tensors.size() == 6);
TORCH_INTERNAL_ASSERT(doubles.size() == 2);
TORCH_INTERNAL_ASSERT(longs.size() == 2);

at::Tensor qw_ih = std::move(tensors[0]), qw_hh = std::move(tensors[1]),
b_ih = std::move(tensors[2]), b_hh = std::move(tensors[3]),
col_offsets_ih = std::move(tensors[4]),
col_offsets_hh = std::move(tensors[5]);
double scale_ih = doubles[0], scale_hh = doubles[1];
int64_t zero_point_ih = longs[0], zero_point_hh = longs[1];

at::Tensor packed_ih = at::native::fbgemm_pack_quantized_matrix(qw_ih);
at::Tensor packed_hh = at::native::fbgemm_pack_quantized_matrix(qw_hh);

return c10::make_intrusive<QuantizedCellParams>(
/*w_ih=*/std::move(qw_ih),
/*w_hh=*/std::move(qw_hh),
/*b_ih_=*/std::move(b_ih),
/*b_hh_=*/std::move(b_hh),
/*packed_ih=*/std::move(packed_ih),
/*packed_hh=*/std::move(packed_hh),
/*col_offsets_ih=*/std::move(col_offsets_ih),
/*col_offsets_hh=*/std::move(col_offsets_hh),
/*scale_ih=*/scale_ih,
/*scale_hh=*/scale_hh,
/*zero_point_ih=*/zero_point_ih,
/*zero_point_hh=*/zero_point_hh);
}
};

c10::intrusive_ptr<CellParamsBase> make_quantized_cell_params(
const at::Tensor& w_ih,
const at::Tensor& w_hh,
at::Tensor b_ih,
at::Tensor b_hh) {
auto make_vals = [&](const at::Tensor& W) {
auto params = at::native::fbgemm_linear_quantize_weight(W);
at::Tensor packed_weight =
at::native::fbgemm_pack_quantized_matrix(std::get<0>(params));
return std::tuple_cat(
std::make_tuple(std::move(packed_weight)), std::move(params));
};

at::Tensor qw_ih, qw_hh, packed_ih, packed_hh, col_offsets_ih, col_offsets_hh;
at::Scalar scale_ih, scale_hh, zero_point_ih, zero_point_hh;

std::tie(packed_ih, qw_ih, col_offsets_ih, scale_ih, zero_point_ih) =
make_vals(w_ih);
std::tie(packed_hh, qw_hh, col_offsets_hh, scale_hh, zero_point_hh) =
make_vals(w_hh);

return c10::make_intrusive<QuantizedCellParams>(
/*qw_ih=*/std::move(qw_ih),
/*qw_hh=*/std::move(qw_hh),
/*b_ih=*/std::move(b_ih),
/*b_hh=*/std::move(b_hh),
/*packed_ih=*/std::move(packed_ih),
/*packed_hh=*/std::move(packed_hh),
/*col_offsets_ih=*/std::move(col_offsets_ih),
/*col_offsets_hh=*/std::move(col_offsets_hh),
/*scale_ih=*/std::move(scale_ih),
/*scale_hh=*/std::move(scale_hh),
/*zero_point_ih=*/std::move(zero_point_ih),
/*zero_point_hh=*/std::move(zero_point_hh));
}

// QuantizedCellParams vs. QuantizedCellParamsDynamic
//
@@ -536,7 +377,6 @@ static std::unordered_map<
std::string,
c10::intrusive_ptr<CellParamsBase> (*)(CellParamsSerializationType)>
cell_params_deserializers = {
{"quantized", &QuantizedCellParams::__setstate__},
{"quantized_dynamic", &QuantizedCellParamsDynamic::__setstate__},
{"quantized_fp16", &QuantizedCellParamsFP16::__setstate__}};

@@ -1841,38 +1681,6 @@ static std::tuple<Tensor, Tensor, Tensor> quantized_lstm_data_legacy(
"using the newer definitions in torch.jit.quantized");
}

#define DEFINE_QUANTIZED_RNN_CELL(name, hx_type, cell_type, return_type, prepare_hx_fn) \
return_type name( \
const Tensor& input, \
hx_type hx, \
const Tensor& w_ih, \
const Tensor& w_hh, \
const Tensor& b_ih, \
const Tensor& b_hh, \
const Tensor& packed_ih, \
const Tensor& packed_hh, \
const Tensor& col_offsets_ih, \
const Tensor& col_offsets_hh, \
const Scalar& scale_ih, \
const Scalar& scale_hh, \
const Scalar& zero_point_ih, \
const Scalar& zero_point_hh) { \
QuantizedCellParams params( \
w_ih, \
w_hh, \
b_ih, \
b_hh, \
packed_ih, \
packed_hh, \
col_offsets_ih, \
col_offsets_hh, \
scale_ih, \
scale_hh, \
zero_point_ih, \
zero_point_hh); \
return cell_type{}( \
input, prepare_hx_fn(hx), params); \
}
// Set reduced range to be True for all RNN Cells by default. This flag is used only for FBGEMM kernels
// QNNPACK does not reduce range for activations
#define DEFINE_QUANTIZED_RNN_CELL_DYNAMIC(name, hx_type, cell_type, return_type, prepare_hx_fn) \
@@ -1895,7 +1703,6 @@ return_type name( \
}

// Quantized LSTM cell
using quantized_lstm_cell_type = LSTMCell<QuantizedCellParams>;
using quantized_lstm_return_type = std::tuple<Tensor, Tensor>;
static std::tuple<Tensor, Tensor> prepare_quantized_lstm_hx(TensorList hx) {
return std::make_tuple(hx[0], hx[1]);
@@ -1904,7 +1711,6 @@ static std::tuple<Tensor, Tensor> prepare_quantized_lstm_hx(TensorList hx) {
// Quantized LSTM cell
using quantized_lstm_cell_dynamic_type = LSTMCell<QuantizedCellParamsDynamic>;

DEFINE_QUANTIZED_RNN_CELL(quantized_lstm_cell, TensorList, quantized_lstm_cell_type, quantized_lstm_return_type, prepare_quantized_lstm_hx);

static DEFINE_QUANTIZED_RNN_CELL_DYNAMIC(quantized_lstm_cell_dynamic, TensorList, quantized_lstm_cell_dynamic_type, quantized_lstm_return_type, prepare_quantized_lstm_hx);

@@ -1915,22 +1721,15 @@ static simple_hx_type prepare_quantized_hx(simple_hx_type hx) {
}

// Quantized GRU cell
using quantized_gru_cell_type = GRUCell<QuantizedCellParams>;
using quantized_gru_cell_dynamic_type = GRUCell<QuantizedCellParamsDynamic>;

DEFINE_QUANTIZED_RNN_CELL(quantized_gru_cell, simple_hx_type, quantized_gru_cell_type, Tensor, prepare_quantized_hx);

static DEFINE_QUANTIZED_RNN_CELL_DYNAMIC(quantized_gru_cell_dynamic, simple_hx_type, quantized_gru_cell_dynamic_type, Tensor, prepare_quantized_hx);

// Quantized RNN w/ ReLU cell
using quantized_rnn_relu_cell_type = SimpleCell<relu_f, QuantizedCellParams>;
DEFINE_QUANTIZED_RNN_CELL(quantized_rnn_relu_cell, simple_hx_type, quantized_rnn_relu_cell_type, Tensor, prepare_quantized_hx);
using quantized_rnn_relu_cell_dynamic_type = SimpleCell<relu_f, QuantizedCellParamsDynamic>;
static DEFINE_QUANTIZED_RNN_CELL_DYNAMIC(quantized_rnn_relu_cell_dynamic, simple_hx_type, quantized_rnn_relu_cell_dynamic_type, Tensor, prepare_quantized_hx);

// Quantized RNN w/ tanh cell
using quantized_rnn_tanh_cell_type = SimpleCell<tanh_f, QuantizedCellParams>;
DEFINE_QUANTIZED_RNN_CELL(quantized_rnn_tanh_cell, simple_hx_type, quantized_rnn_tanh_cell_type, Tensor, prepare_quantized_hx);
using quantized_rnn_tanh_cell_dynamic_type = SimpleCell<tanh_f, QuantizedCellParamsDynamic>;
static DEFINE_QUANTIZED_RNN_CELL_DYNAMIC(quantized_rnn_tanh_cell_dynamic, simple_hx_type, quantized_rnn_tanh_cell_dynamic_type, Tensor, prepare_quantized_hx);

@@ -1972,7 +1771,6 @@ TORCH_LIBRARY_FRAGMENT(aten, m) {
TORCH_LIBRARY_FRAGMENT(quantized, m) {
m.def(TORCH_SELECTIVE_SCHEMA("quantized::make_quantized_cell_params_dynamic(__torch__.torch.classes.quantized.LinearPackedParamsBase w_ih, __torch__.torch.classes.quantized.LinearPackedParamsBase w_hh, Tensor bias_ih, Tensor bias_hh, bool reduce_range=False) -> __torch__.torch.classes.rnn.CellParamsBase"));
m.def(TORCH_SELECTIVE_SCHEMA("quantized::make_quantized_cell_params_fp16(__torch__.torch.classes.quantized.LinearPackedParamsBase w_ih, __torch__.torch.classes.quantized.LinearPackedParamsBase w_hh) -> __torch__.torch.classes.rnn.CellParamsBase"));
m.def(TORCH_SELECTIVE_SCHEMA("quantized::make_quantized_cell_params(Tensor w_ih, Tensor w_hh, Tensor b_ih, Tensor b_hh) -> __torch__.torch.classes.rnn.CellParamsBase"));
m.def(TORCH_SELECTIVE_SCHEMA("quantized::quantized_lstm_cell_dynamic(Tensor input, Tensor[] hx, __torch__.torch.classes.quantized.LinearPackedParamsBase w_ih, __torch__.torch.classes.quantized.LinearPackedParamsBase w_hh, Tensor bias_ih, Tensor bias_hh) -> (Tensor, Tensor)"));
m.def(TORCH_SELECTIVE_SCHEMA("quantized::quantized_gru_cell_dynamic(Tensor input, Tensor hx, __torch__.torch.classes.quantized.LinearPackedParamsBase w_ih, __torch__.torch.classes.quantized.LinearPackedParamsBase w_hh, Tensor b_ih, Tensor b_hh) -> Tensor"));
m.def(TORCH_SELECTIVE_SCHEMA("quantized::quantized_rnn_relu_cell_dynamic(Tensor input, Tensor hx, __torch__.torch.classes.quantized.LinearPackedParamsBase w_ih, __torch__.torch.classes.quantized.LinearPackedParamsBase w_hh, Tensor b_ih, Tensor b_hh) -> Tensor"));
@@ -1992,7 +1790,6 @@ TORCH_LIBRARY_IMPL(aten, CPU, m) {

TORCH_LIBRARY_IMPL(quantized, CPU, m) {
m.impl(TORCH_SELECTIVE_NAME("quantized::make_quantized_cell_params_dynamic"), TORCH_FN(make_quantized_cell_params_dynamic));
m.impl(TORCH_SELECTIVE_NAME("quantized::make_quantized_cell_params"), TORCH_FN(make_quantized_cell_params));
m.impl(TORCH_SELECTIVE_NAME("quantized::quantized_lstm_cell_dynamic"), TORCH_FN(quantized_lstm_cell_dynamic));
m.impl(TORCH_SELECTIVE_NAME("quantized::quantized_gru_cell_dynamic"), TORCH_FN(quantized_gru_cell_dynamic));
m.impl(TORCH_SELECTIVE_NAME("quantized::quantized_rnn_relu_cell_dynamic"), TORCH_FN(quantized_rnn_relu_cell_dynamic));
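
Only the non-dynamic `QuantizedCellParams` path is deleted above; `QuantizedCellParamsDynamic` and the `quantized::*_cell_dynamic` ops stay and continue to back the cell modules in `torch.ao.nn.quantized.dynamic`. A minimal sketch of that surviving path (sizes are arbitrary; assumes a quantized engine such as FBGEMM is available):

```python
import torch
import torch.ao.nn.quantized.dynamic as nnqd

# Dynamically quantized LSTM cell: int8 weights, fp32 activations.
cell = nnqd.LSTMCell(input_size=10, hidden_size=20)

x = torch.randn(6, 3, 10)   # (seq, batch, features)
hx = torch.randn(3, 20)
cx = torch.randn(3, 20)

outputs = []
for t in range(x.size(0)):
    # Uses the quantized::quantized_lstm_cell_dynamic kernel kept by this commit.
    hx, cx = cell(x[t], (hx, cx))
    outputs.append(hx)

print(torch.stack(outputs).shape)  # torch.Size([6, 3, 20])
```
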
25 changes: 0 additions & 25 deletions aten/src/ATen/native/native_functions.yaml
@@ -3285,22 +3285,6 @@
dispatch:
CUDA: _mixed_dtypes_linear

- func: fbgemm_linear_int8_weight_fp32_activation(Tensor input, Tensor weight, Tensor packed, Tensor col_offsets, Scalar weight_scale, Scalar weight_zero_point, Tensor bias) -> Tensor

- func: fbgemm_linear_int8_weight(Tensor input, Tensor weight, Tensor packed, Tensor col_offsets, Scalar weight_scale, Scalar weight_zero_point, Tensor bias) -> Tensor

- func: fbgemm_linear_quantize_weight(Tensor input) -> (Tensor, Tensor, float, int)

- func: fbgemm_pack_gemm_matrix_fp16(Tensor input) -> Tensor

- func: fbgemm_linear_fp16_weight_fp32_activation(Tensor input, Tensor packed_weight, Tensor bias) -> Tensor

- func: fbgemm_linear_fp16_weight(Tensor input, Tensor packed_weight, Tensor bias) -> Tensor

- func: fbgemm_pack_quantized_matrix(Tensor input) -> Tensor

- func: fbgemm_pack_quantized_matrix.KN(Tensor input, int K, int N) -> Tensor

- func: ldexp.Tensor(Tensor self, Tensor other) -> Tensor
variants: function, method

@@ -7611,15 +7595,6 @@
# - func: quantized_gru.data(Tensor data, Tensor batch_sizes, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional) -> (Tensor, Tensor)
#

# Quantized RNN cells
- func: quantized_lstm_cell(Tensor input, Tensor[] hx, Tensor w_ih, Tensor w_hh, Tensor b_ih, Tensor b_hh, Tensor packed_ih, Tensor packed_hh, Tensor col_offsets_ih, Tensor col_offsets_hh, Scalar scale_ih, Scalar scale_hh, Scalar zero_point_ih, Scalar zero_point_hh) -> (Tensor, Tensor)

- func: quantized_gru_cell(Tensor input, Tensor hx, Tensor w_ih, Tensor w_hh, Tensor b_ih, Tensor b_hh, Tensor packed_ih, Tensor packed_hh, Tensor col_offsets_ih, Tensor col_offsets_hh, Scalar scale_ih, Scalar scale_hh, Scalar zero_point_ih, Scalar zero_point_hh) -> Tensor

- func: quantized_rnn_relu_cell(Tensor input, Tensor hx, Tensor w_ih, Tensor w_hh, Tensor b_ih, Tensor b_hh, Tensor packed_ih, Tensor packed_hh, Tensor col_offsets_ih, Tensor col_offsets_hh, Scalar scale_ih, Scalar scale_hh, Scalar zero_point_ih, Scalar zero_point_hh) -> Tensor

- func: quantized_rnn_tanh_cell(Tensor input, Tensor hx, Tensor w_ih, Tensor w_hh, Tensor b_ih, Tensor b_hh, Tensor packed_ih, Tensor packed_hh, Tensor col_offsets_ih, Tensor col_offsets_hh, Scalar scale_ih, Scalar scale_hh, Scalar zero_point_ih, Scalar zero_point_hh) -> Tensor

# PackedSequence utilities
- func: _pack_padded_sequence(Tensor input, Tensor lengths, bool batch_first) -> (Tensor, Tensor)
dispatch:
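
Dropping the `- func:` entries above removes the generated `torch.*` and `torch.ops.aten.*` bindings entirely, so stale callers now fail at lookup time rather than deep inside fbgemm. Roughly what that looks like — the exact exception type and message can vary by build and version:

```python
import torch

x = torch.randn(4, 8)

# Before this commit these were generated aten bindings; afterwards the
# attribute lookup itself fails because the schema no longer exists.
try:
    torch.fbgemm_pack_quantized_matrix(x)
except (AttributeError, RuntimeError) as e:
    print("removed binding:", e)

try:
    torch.ops.aten.fbgemm_pack_quantized_matrix(x)
except (AttributeError, RuntimeError) as e:
    print("removed schema:", e)
```
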
1 change: 0 additions & 1 deletion build_variables.bzl
@@ -1324,7 +1324,6 @@ aten_native_source_non_codegen_list = [
"aten/src/ATen/native/PointwiseOps.cpp",
"aten/src/ATen/native/Pooling.cpp",
"aten/src/ATen/native/Pow.cpp",
"aten/src/ATen/native/QuantizedLinear.cpp",
"aten/src/ATen/native/RNN.cpp",
"aten/src/ATen/native/RangeFactories.cpp",
"aten/src/ATen/native/ReduceAllOps.cpp",
1 change: 0 additions & 1 deletion caffe2/CMakeLists.txt
@@ -1131,7 +1131,6 @@ set(ATen_CPU_INCLUDE
${CMAKE_BINARY_DIR}/aten/src)

if(CMAKE_CXX_COMPILER_ID MATCHES "Clang" OR CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
set_source_files_properties(${CMAKE_CURRENT_SOURCE_DIR}/../aten/src/ATen/native/QuantizedLinear.cpp PROPERTIES COMPILE_FLAGS -Wno-deprecated-declarations)
set_source_files_properties(${CMAKE_CURRENT_SOURCE_DIR}/../aten/src/ATen/native/RNN.cpp PROPERTIES COMPILE_FLAGS -Wno-deprecated-declarations)
set_source_files_properties(${CMAKE_CURRENT_SOURCE_DIR}/../aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp PROPERTIES COMPILE_FLAGS -Wno-deprecated-declarations)
set_source_files_properties(${CMAKE_CURRENT_SOURCE_DIR}/../aten/src/ATen/native/quantized/qlinear_unpack.cpp PROPERTIES COMPILE_FLAGS -Wno-deprecated-declarations)
@@ -312,6 +312,18 @@
("aten::batch_norm_backward_elemt.out", datetime.date(2023, 12, 31)),
("aten::batch_norm_backward_elemt", datetime.date(2023, 12, 31)),
("aten::sym_constrain_range", datetime.date(2023, 12, 31)),
("aten::fbgemm_linear_int8_weight_fp32_activation", datetime.date(2023, 12, 31)),
("aten::fbgemm_linear_int8_weight", datetime.date(2023, 12, 31)),
("aten::fbgemm_linear_quantize_weight", datetime.date(2023, 12, 31)),
("aten::fbgemm_pack_gemm_matrix_fp16", datetime.date(2023, 12, 31)),
("aten::fbgemm_linear_fp16_weight_fp32_activation", datetime.date(2023, 12, 31)),
("aten::fbgemm_linear_fp16_weight", datetime.date(2023, 12, 31)),
("aten::fbgemm_pack_quantized_matrix", datetime.date(2023, 12, 31)),
("aten::quantized_lstm_cell", datetime.date(2023, 12, 31)),
("aten::quantized_gru_cell", datetime.date(2023, 12, 31)),
("aten::quantized_rnn_relu_cell", datetime.date(2023, 12, 31)),
("aten::quantized_rnn_tanh_cell", datetime.date(2023, 12, 31)),
("quantized::make_quantized_cell_params", datetime.date(2023, 12, 31)),
]

ALLOW_LIST_COMPILED = [
13 changes: 0 additions & 13 deletions test/jit/test_models.py
@@ -16,7 +16,6 @@
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase, RUN_CUDA
from torch.testing._internal.common_utils import slowTest, suppress_warnings
from torch.testing._internal.common_quantization import skipIfNoFBGEMM

if __name__ == '__main__':
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
@@ -407,12 +406,6 @@ class Config:
def test_snli(self):
self._test_snli(self, device='cpu')

@skipIfNoFBGEMM
# Suppression: this exercises a deprecated API
@suppress_warnings
def test_snli_quantized(self):
self._test_snli(self, device='cpu', quantized=True)

@unittest.skipIf(not RUN_CUDA, "no CUDA")
def test_snli_cuda(self):
# XXX: export_import on CUDA modules doesn't work (#11480)
@@ -553,12 +546,6 @@ def forward(self, x):
def test_vae(self):
self._test_vae(self, device='cpu')

@skipIfNoFBGEMM
# Suppression: this exercises a deprecated API
@suppress_warnings
def test_vae_quantized(self):
self._test_vae(self, device='cpu', quantized=True)

@unittest.skipIf(not RUN_CUDA, "no CUDA")
def test_vae_cuda(self):
# XXX: export_import on CUDA modules doesn't work (#11480)