Commit 2838b73 (parent: e0ca435)

- Refactors quantization guide to be more concise.
- Updates `nvidia-modelopt` version so we don't need a special package index.
- Reduces threshold for summary mode when pretty printing tensors.
- Adds a `# doc: no-output` tag to omit output from code examples.

Showing 11 changed files with 226 additions and 134 deletions.

# Quantization

**Quantization** reduces memory and compute requirements by running operations in low precision:

- **Scaling** is required to translate to/from low precision.
- **Scaling factors** are chosen such that they minimize accuracy loss.
- They can be either:
    - Loaded into quantization-enabled {class}`nvtripy.Module`s, or
    - Used with {func}`nvtripy.quantize`/{func}`nvtripy.dequantize`.

:::{seealso}
The
[TensorRT developer guide](https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#working-with-int8)
explains quantization in more detail.
:::
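
To make the scaling step concrete, here is a minimal, framework-agnostic sketch of symmetric
`int8` quantization (illustrative arithmetic only, not the Tripy/TensorRT API; the exact rounding
and clamping behavior may differ):

```py
# Symmetric int8 quantization sketch:
amax = 2.54                  # observed dynamic range of a tensor
scale = amax / 127.0         # 127 is the largest magnitude representable in int8

x = 1.37                     # a value to quantize
q = max(-128, min(127, round(x / scale)))  # quantize: float -> int8
x_approx = q * scale         # dequantize: int8 -> float (~1.36, a small error)
```
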
## Post-Training Quantization With ModelOpt

If the model was not trained with quantization-aware training (QAT), we can use
[TensorRT ModelOpt](https://nvidia.github.io/TensorRT-Model-Optimizer/index.html)
to do **calibration** to determine scaling factors.

:::{admonition} Info
**Calibration** runs a model with a small set of input data to determine the
numerical distribution of each tensor.

The **dynamic range** is the most important part of this distribution, and
scales are chosen to target it.
:::

Let's calibrate a GPT model:

1. Install ModelOpt:

    ```bash
    python3 -m pip install nvidia-modelopt==0.11.1 transformers==4.46.2 datasets==2.21.0
    ```

2. Download the model:

    ```py
    # doc: no-print-locals
    from transformers import GPT2LMHeadModel

    model = GPT2LMHeadModel.from_pretrained("gpt2")
    ```

3. Calibrate for `int8` precision:

    1. Define the forward pass:

        ```py
        # doc: no-output
        from transformers import AutoTokenizer
        from modelopt.torch.utils.dataset_utils import create_forward_loop

        MAX_SEQ_LEN = 512
        tokenizer = AutoTokenizer.from_pretrained(
            "gpt2",
            use_fast=True,
            model_max_length=MAX_SEQ_LEN,
            padding_side="left",
            trust_remote_code=True,
        )
        tokenizer.pad_token = tokenizer.eos_token

        forward_loop = create_forward_loop(
            model=model,
            dataset_name="cnn_dailymail",
            tokenizer=tokenizer,
            device=model.device,
            num_samples=8,
        )
        ```

    2. Set up the quantization configuration (a weight-only variant is sketched just after this list):

        ```py
        import modelopt.torch.quantization as mtq

        quant_cfg = mtq.INT8_DEFAULT_CFG
        ```

    3. Run calibration to replace linear layers with
        [`QuantLinear`](https://nvidia.github.io/TensorRT-Model-Optimizer/reference/generated/modelopt.torch.quantization.nn.modules.quant_linear.html#modelopt.torch.quantization.nn.modules.quant_linear.QuantLinear)
        layers, which contain calibration information:

        ```py
        # doc: no-output
        mtq.quantize(model, quant_cfg, forward_loop=forward_loop)
        ```

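The default configuration can also be customized before running `mtq.quantize` in step 3.
For example, a weight-only variant (shown only as a sketch; the rest of this guide keeps the
default config with input quantization enabled) disables the input quantizers:

```py
# doc: no-print-locals
import copy

import modelopt.torch.quantization as mtq

# Copy the default int8 config and disable input quantization so that only
# weights are quantized:
weight_only_cfg = copy.deepcopy(mtq.INT8_DEFAULT_CFG)
weight_only_cfg["quant_cfg"]["*input_quantizer"] = {"enable": False}
```
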
The `amax` attributes of `QuantLinear`'s quantizers specify **dynamic ranges**:

```py
torch_qlinear = model.transformer.h[0].attn.c_attn

print(torch_qlinear)
```

We must convert dynamic ranges to scaling factors to load them into Tripy:

```py
def get_scale(quantizer):
    amax = quantizer.export_amax()
    # `maxbound` is the maximum value representable by the data type.
    # For `int8`, this is 127.
    scale = amax.float() / quantizer.maxbound
    return tp.Tensor(scale.squeeze().contiguous())

input_scale = get_scale(torch_qlinear.input_quantizer)
weight_scale = get_scale(torch_qlinear.weight_quantizer)
```

## Loading Scales Into Tripy

### Using Modules

Modules that support quantization usually:

- Expose additional model parameters for scales.
- Accept arguments that control how quantization is performed.

For {class}`nvtripy.Linear`, `weight_scale` must always be provided, while `input_scale` is
optional: the input is quantized only if `input_scale` is set. Input quantization is always
"per-tensor", which is why there is no `input_quant_dim` argument.

Let's load the scales into an {class}`nvtripy.Linear` module:

```py
qlinear = tp.Linear(
    768,
    2304,
    # The data type to quantize to:
    quant_dtype=tp.int8,
    # The dimension along which the weights are quantized:
    weight_quant_dim=torch_qlinear.weight_quantizer.axis)

# Load weights:
qlinear.weight = tp.Tensor(torch_qlinear.weight.detach().contiguous())
qlinear.bias = tp.Tensor(torch_qlinear.bias.detach().contiguous())

# Load scaling factors:
qlinear.input_scale = input_scale
qlinear.weight_scale = weight_scale
```

:::{note}
We use scales from ModelOpt here, but scaling factors can come from anywhere.
:::

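For instance, a simple max-abs calibration computed directly from the weights would also yield a
usable per-tensor weight scale. This is only a sketch and is not used in the rest of this guide:

```py
# doc: no-print-locals
# Hypothetical per-tensor max-abs calibration: the dynamic range is the largest
# absolute weight value, and 127 is the int8 maxbound.
w = torch_qlinear.weight.detach()
manual_weight_scale = tp.Tensor(w.abs().amax() / 127.0)
```
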
We can run it just like a regular `float32` module. The input and weight are quantized and the
linear operation runs in `int8` internally; the output is dequantized back to `float32`:

```py
input = tp.ones((1, 768), dtype=tp.float32)

output = qlinear(input)
```

:::{seealso}
`load_quant_weights_from_hf` in the [nanoGPT weight loader](source:/examples/nanogpt/weight_loader.py)
is an example of loading scaling factors for an entire model.
:::

### Manually

When using {func}`nvtripy.quantize`/{func}`nvtripy.dequantize`,
`dequantize` must **immediately follow** `quantize`.

TensorRT will **rotate** `dequantize` over subsequent ops as needed.

:::{seealso}
The
[TensorRT developer guide](https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#qdq-placement-recs)
includes recommendations on placement of quantization and dequantization ops.
:::

<!-- We cannot print the quantized input/weight below since that would break Q/DQ rotation -->

To mimic the behavior of the {class}`nvtripy.Linear` module above, we can:

1. Quantize the input:

    ```py
    # doc: no-print-locals
    input = tp.ones((1, 768), dtype=tp.float32)

    input = tp.quantize(input, input_scale, dtype=tp.int8)
    # Note the placement of dequantize:
    input = tp.dequantize(input, input_scale, dtype=tp.float32)
    ```

2. Quantize the weights:

    ```py
    # doc: no-print-locals
    weight = tp.Tensor(torch_qlinear.weight.detach().contiguous())

    dim = torch_qlinear.weight_quantizer.axis
    weight = tp.quantize(weight, weight_scale, dtype=tp.int8, dim=dim)
    weight = tp.dequantize(weight, weight_scale, dtype=tp.float32, dim=dim)
    ```

3. Perform the computation (matrix multiply in this case):

    ```py
    # doc: no-print-locals bias
    bias = tp.Tensor(torch_qlinear.bias.detach().contiguous())

    output = input @ tp.transpose(weight, 0, 1) + bias
    ```

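The manual path above expresses the same quantized linear operation as the module-based path from
the previous section, so the two outputs should agree closely. A quick sanity check is to compare
them directly (a sketch; exact agreement may depend on TensorRT's kernel choices):

```py
# doc: no-print-locals
# Module-based result from the "Using Modules" section; the difference from the
# manual `output` should be (near) zero.
reference = qlinear(tp.ones((1, 768), dtype=tp.float32))

print(output - reference)
```
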
:::{warning}
**Evaluating** the tensor produced by `dequantize` will affect accuracy.

- **Why:** Evaluation replaces the tensor with a constant, losing information
    like which op produced it.

    So, TensorRT won't see `dequantize` when evaluating subsequent ops and
    won't **rotate** it correctly.

For example, **don't** do this:

```py
# doc: no-eval
tensor = tp.ones(...)

tensor = tp.quantize(tensor, ...)
tensor = tp.dequantize(tensor, ...)

# The `print` below will trigger an evaluation of the tensor which will prevent
# TensorRT from rotating the dequantization node. This will affect accuracy!
print(tensor)

# Rest of the program, including some computation involving tensor
...
```
:::