// quantized_backward.cpp (forked from pytorch/pytorch)
#include <ATen/native/quantized/PackedParams.h>
#include <ATen/native/quantized/cpu/QuantUtils.h>
#include <torch/library.h>
#include <torch/torch.h>
namespace {
using namespace torch::autograd;
using namespace at;
// This class is a custom gradient function that enables a quantized tensor to
// pass the input gradient back to the previous layers. It can be used when the
// user is adopting mixed precision for training after quantization.
// From the torch layer we have no direct access to the linear_dynamic
// operator, so it has to be reached via the redispatching mechanism.
// TODO: currently only per-tensor quantization is supported; per-channel
// support will be added later.
class PackedLinearWeightDynamicBackward
: public Function<PackedLinearWeightDynamicBackward> {
public:
static torch::Tensor forward(
AutogradContext* ctx,
at::Tensor input,
const c10::intrusive_ptr<LinearPackedParamsBase>& packed_weight,
bool reduce_range) {
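// Resolve quantized::linear_dynamic once via the dispatcher; it is
// redispatched below with an explicit CPU key so the call reaches the CPU
// kernel directly instead of re-entering this Autograd wrapper.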
static auto op =
at::Dispatcher::singleton()
.findSchemaOrThrow("quantized::linear_dynamic", "")
.typed<at::Tensor(
at::Tensor,
c10::intrusive_ptr<
LinearPackedParamsBase,
c10::detail::intrusive_target_default_null_type<
LinearPackedParamsBase>> const&,
bool)>();
// Calculate statistics for quantization of input Tensor
float x_min = 0;
float x_max = 0;
if (input.numel() > 0) {
auto input_contig = input.contiguous();
x_min = input_contig.min().item<float>();
x_max = input_contig.max().item<float>();
}
auto output = op.redispatch(
DispatchKeySet({DispatchKey::CPU}),
std::move(input),
packed_weight,
reduce_range);
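// Choose (scale, zero_point) for quantizing the input over the unsigned
// 8-bit range [0, 255]; only the scale (s1 in the backward derivation) is
// needed and saved for the backward pass.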
auto q_params = quant_utils::ChooseQuantizationParams(
/*min=*/x_min,
/*max=*/x_max,
/*qmin=*/0,
/*qmax=*/255);
ctx->saved_data["weight"] = packed_weight;
// q_params.scale : shape [1] (per-tensor)
ctx->saved_data["input_scale"] = q_params.scale;
return output;
}
static tensor_list backward(AutogradContext* ctx, tensor_list grad_outputs) {
if (grad_outputs.empty()) {
return {torch::Tensor(), torch::Tensor(), torch::Tensor()};
}
auto packed_weight =
ctx->saved_data["weight"].toCustomClass<LinearPackedParamsBase>();
auto unpacked_parameters = packed_weight->unpack();
auto original_weight = std::get<0>(unpacked_parameters);
auto input_scale = ctx->saved_data["input_scale"].toDouble();
// Gradient for post-scaling.
// Rewrite this layer by separating the matmul from the output scaling:
//   y = (x * s1) @ W * s2 + b
// so we back-propagate through four operations: + b, * s2, @ W, and * s1.
// Starting from the gradient coming from the top (the adjoint,
// grad_outputs[0]):
//   gradient for + b:  a no-op.
//   gradient for * s2: scale by s2, the per-tensor/per-channel scale baked
//                      into W.
//   gradient for @ W:  matmul with W.t().
//   gradient for * s1: scale by s1, the input scale.
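// In one expression, the chain above is dL/dx = ((dL/dy) * s2) @ W.t() * s1.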
auto grad_output0 = grad_outputs[0];
const auto qtype = original_weight.qscheme();
if (qtype == at::kPerTensorAffine) {
grad_output0 *= original_weight.q_scale();
original_weight = at::permute(original_weight, {1, 0});
} else if (qtype == at::kPerChannelAffine) {
// The per-channel quantizer does not support transpose, so a manual
// transpose (after dequantizing) is necessary.
original_weight = original_weight.dequantize();
// kwanghoon(TODO): The quantized path below is meant to be the long-term
// solution applicable to every model. One issue with quantizing the gradient
// is that it is not accurate enough to improve model accuracy once the model
// becomes complicated, so it is disabled for now; revisit when a better
// solution is found.
#if 0
// Enable the kernel backend for quantized backpropagation matrix
// multiplication.
original_weight = at::permute(original_weight, {1, 0});
// Take advantage of QNNPACK for matrix multiplication
// Per channel scales & zero point computation
// Sources :
// https://github.com/pytorch/pytorch/blob/master/torch/ao/quantization/observer.py#L350-L353
auto [amin, amax] = at::aminmax(original_weight, /*dim=*/1);
// QInt8 type signed quantization
auto qmax = 127;
auto qmin = -128;
// Clamp with a small epsilon so that the scale does not collapse to zero
auto epsilon = 1e-9;
auto new_scales = (amax - amin) / float(qmax - qmin);
new_scales = at::clamp(new_scales, epsilon);
auto new_zero_point =
qmin - at::round(amin / new_scales).toType(c10::kInt);
new_zero_point = at::clamp(new_zero_point, qmin, qmax);
// TODO (BUGBUG)
// The backend kernel is designed for inference and is hard-coded to quantize
// along the output-channel axis. For mathematical correctness we should be
// able to run the kernel along the input-channel axis after the transpose.
// As a workaround, we simply explore either per-tensor quantization or
// per-channel quantization with axis = 0.
original_weight = at::quantize_per_channel(
original_weight,
new_scales,
new_zero_point,
/*axis would be 1 for the transposed weight, but it is forced to the
non-transposed case (axis = 0) due to the issue above*/
0,
c10::kQInt8);
#endif
} else {
TORCH_INTERNAL_ASSERT(false, "Unsupported quantization scheme.");
}
#if 1
// Pure FP32 computation, useful for debugging purposes
auto dLdX1 = torch::matmul(grad_output0, original_weight);
#else
// Take advantage of QNNPACK for matrix multiplication
static auto op = at::Dispatcher::singleton()
.findSchemaOrThrow("quantized::linear_prepack", "")
.typed<c10::intrusive_ptr<LinearPackedParamsBase>(
at::Tensor, c10::optional<at::Tensor>)>();
auto prepacked_weight = op.call(original_weight, nullopt);
auto dLdX1 =
prepacked_weight->apply_dynamic(grad_output0.toType(c10::kFloat));
#endif
auto input_grad0 = dLdX1 * input_scale;
return {input_grad0, torch::Tensor(), torch::Tensor()};
}
};
at::Tensor packed_linear_weight_grad(
c10::DispatchKeySet ks,
at::Tensor input,
const c10::intrusive_ptr<LinearPackedParamsBase>& packed_weight,
bool reduce_range) {
return PackedLinearWeightDynamicBackward::apply(
std::move(input), packed_weight, reduce_range);
}
} // namespace
namespace at {
namespace native {
namespace {
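// Register the gradient formula above as the Autograd implementation of
// quantized::linear_dynamic, so dispatcher calls to that op are routed
// through PackedLinearWeightDynamicBackward, which records the backward
// graph when the input requires grad.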
TORCH_LIBRARY_IMPL(quantized, Autograd, m) {
m.impl(
TORCH_SELECTIVE_NAME("quantized::linear_dynamic"),
TORCH_FN(packed_linear_weight_grad));
}
} // namespace
} // namespace native
} // namespace at
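// A minimal usage sketch (kept out of the build with #if 0), assuming this
// translation unit is linked into a libtorch build that provides the
// quantized CPU ops (FBGEMM/QNNPACK). The function name below is
// hypothetical; the dispatcher lookups mirror the pattern used above.
#if 0
void example_backward_through_linear_dynamic() {
  // Per-channel quantize an fp32 weight ([out, in] = [4, 8]) and prepack it
  // with quantized::linear_prepack.
  static auto prepack_op =
      at::Dispatcher::singleton()
          .findSchemaOrThrow("quantized::linear_prepack", "")
          .typed<c10::intrusive_ptr<LinearPackedParamsBase>(
              at::Tensor, c10::optional<at::Tensor>)>();
  auto scales = torch::full({4}, 0.1, torch::kDouble);
  auto zero_points = torch::zeros({4}, torch::kLong);
  auto weight = at::quantize_per_channel(
      torch::randn({4, 8}), scales, zero_points, /*axis=*/0, c10::kQInt8);
  auto packed_weight = prepack_op.call(weight, c10::nullopt);

  // Call quantized::linear_dynamic on an input that requires grad. With the
  // Autograd registration above, the call routes through
  // PackedLinearWeightDynamicBackward::forward.
  static auto linear_dynamic_op =
      at::Dispatcher::singleton()
          .findSchemaOrThrow("quantized::linear_dynamic", "")
          .typed<at::Tensor(
              at::Tensor,
              const c10::intrusive_ptr<LinearPackedParamsBase>&,
              bool)>();
  auto input = torch::randn({2, 8}, torch::requires_grad());
  auto output =
      linear_dynamic_op.call(input, packed_weight, /*reduce_range=*/true);

  // Backward fills input.grad() with the dL/dx computed in backward() above.
  output.sum().backward();
  auto input_grad = input.grad(); // shape [2, 8]
}
#endif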