Commit 2e9eb6b (parent: e0ca435), showing 8 changed files with 128 additions and 128 deletions:

- Refactors the quantization guide to be more concise.
- Updates the `nvidia-modelopt` version so we don't need a special package index.
- Reduces the threshold for summary mode when pretty-printing tensors.
- Adds a `# doc: no-output` tag to omit output from code examples.
# Quantization

**Quantization** can reduce a model's memory and compute requirements by running operations in a lower precision.

- **Scaling** is required to translate to/from low precision.
- **Scaling factors** are chosen such that they minimize accuracy loss.
- Scaling factors can be loaded into Tripy models just like weights.

:::{seealso}
The
[TensorRT developer guide](https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#working-with-int8)
explains quantization in more detail.
:::
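To make scaling concrete, here is a minimal sketch of symmetric `int8` quantization (an illustration only, not Tripy's or TensorRT's actual implementation), where the scaling factor is derived from the dynamic range `amax`:

```py
def int8_scale(amax: float) -> float:
    # Symmetric int8 quantization maps the dynamic range [-amax, amax]
    # onto the integer range [-127, 127].
    return amax / 127.0

def quantize(x: float, scale: float) -> int:
    # Values outside the dynamic range saturate.
    return int(max(-127, min(127, round(x / scale))))

def dequantize(q: int, scale: float) -> float:
    return q * scale

scale = int8_scale(amax=6.0)      # dynamic range of +/- 6
q = quantize(2.5, scale)          # -> 53
recovered = dequantize(q, scale)  # -> ~2.504 (small rounding error)
```

Choosing `amax` well is exactly what the calibration workflow below is for.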
## Using Quantized Modules

Various modules predefined by Tripy support quantization. For example, the {class}`nvtripy.Linear`
module includes two arguments to configure the quantization mode. Let's construct the following
quantized linear module:

```py
# doc: print-locals quant_linear
quant_linear = tp.Linear(
    4,
    2,
    quant_dtype=tp.int8,
    weight_quant_dim=None,
)
```

As described in {class}`nvtripy.Linear`, the quantized linear module has
two additional parameters compared to a normal linear layer:

1. `weight_scale`: The quantization scale for `weight`.

2. `input_scale`: The quantization scale for the input.

`weight_scale` must always be provided, while `input_scale` is optional: the input is quantized
only if `input_scale` is provided. For the `Linear` module in this example, only per-tensor quantization
is allowed for the input, which is why there is no `input_quant_dim` argument.

Let's fill the scale parameters with dummy data:

```py
# doc: print-locals quant_linear
quant_linear.weight_scale = tp.Tensor(1.0)
quant_linear.input_scale = tp.Tensor(1.0)
```

and run a forward pass to see the result:

```py
x = tp.iota((3, 4), dtype=tp.float32)
out = quant_linear(x)
assert tp.equal(out, tp.Tensor([[0.0000, 1.0000], [6.0000, 23.0000], [12.0000, 45.0000]])) # doc: omit
```

The result still has a data type of {class}`nvtripy.float32`, but internally, TensorRT quantized the
input and weight, executed the linear layer with {class}`nvtripy.int8` precision, and finally dequantized
the output back to the original precision.
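The following sketch emulates that quantize, integer-matmul, dequantize sequence with NumPy (an illustration under the stated assumptions, not how TensorRT actually executes the layer; the weight values here are arbitrary stand-ins):

```py
import numpy as np

def emulated_quant_linear(x, w, input_scale, weight_scale):
    # Quantize the input and weight to int8 using their per-tensor scales.
    x_q = np.clip(np.round(x / input_scale), -127, 127).astype(np.int8)
    w_q = np.clip(np.round(w / weight_scale), -127, 127).astype(np.int8)
    # Run the matmul in integer arithmetic, accumulating in int32.
    acc = x_q.astype(np.int32) @ w_q.T.astype(np.int32)
    # Dequantize the accumulator back to float32.
    return acc.astype(np.float32) * (input_scale * weight_scale)

x = np.arange(12, dtype=np.float32).reshape(3, 4)
w = np.ones((2, 4), dtype=np.float32)  # arbitrary stand-in for the module's weight
print(emulated_quant_linear(x, w, input_scale=1.0, weight_scale=1.0))
```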
## Post-Training Quantization With ModelOpt

If the model was not trained with quantization-aware training (QAT), we can use
[TensorRT ModelOpt](https://nvidia.github.io/TensorRT-Model-Optimizer/index.html)
to do **calibration** to determine scaling factors.

:::{admonition} Info
**Calibration** runs a model with a small set of input data to determine the
numerical distribution of each tensor.

The **dynamic range** is the most important range within this distribution.
:::

Let's calibrate a GPT model:

1. Install ModelOpt:

    ```bash
    python3 -m pip install nvidia-modelopt==0.11.1 transformers==4.46.2 datasets==2.21.0
    ```

2. Download the model:

    ```py
    # doc: no-print-locals
    from transformers import GPT2LMHeadModel

    model = GPT2LMHeadModel.from_pretrained("gpt2")
    ```

3. Calibrate for `int8` precision:

    1. Define the forward pass:

        ```py
        # doc: no-output
        from transformers import AutoTokenizer
        from modelopt.torch.utils.dataset_utils import create_forward_loop

        MAX_SEQ_LEN = 512
        tokenizer = AutoTokenizer.from_pretrained(
            "gpt2",
            use_fast=True,
            model_max_length=MAX_SEQ_LEN,
            padding_side="left",
            trust_remote_code=True,
        )
        tokenizer.pad_token = tokenizer.eos_token

        forward_loop = create_forward_loop(
            model=model,
            dataset_name="cnn_dailymail",
            tokenizer=tokenizer,
            device=model.device,
            num_samples=8,
        )
        ```

    2. Set up the quantization configuration:

        ```py
        import modelopt.torch.quantization as mtq

        quant_cfg = mtq.INT8_DEFAULT_CFG
        ```

        For weight-only quantization, input quantization can be disabled with
        `quant_cfg["quant_cfg"]["*input_quantizer"] = {"enable": False}`; this guide
        keeps it enabled so that an input scale is calibrated as well.

    3. Run calibration to replace linear layers with
        [`QuantLinear`](https://nvidia.github.io/TensorRT-Model-Optimizer/reference/generated/modelopt.torch.quantization.nn.modules.quant_linear.html#modelopt.torch.quantization.nn.modules.quant_linear.QuantLinear),
        which contain calibration information (inspected briefly after this list):

        ```py
        # doc: no-output
        mtq.quantize(model, quant_cfg, forward_loop=forward_loop)
        ```
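Calibration attaches quantizers to the model's linear layers and records their dynamic ranges. As a quick sanity check (an added inspection step, not part of the original guide, assuming the quantizers expose the recorded range via their `amax` attribute as described below), you can print a few of them:

```py
import itertools

# Print a handful of the quantizers inserted by calibration along with the
# dynamic ranges (`amax`) they recorded.
calibrated = (
    (name, module)
    for name, module in model.named_modules()
    if name.endswith("quantizer") and getattr(module, "amax", None) is not None
)
for name, quantizer in itertools.islice(calibrated, 3):
    print(name, quantizer.amax)
```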
## Load Scales Into The Tripy Model

Let's take a look at one of the `QuantLinear` modules produced by ModelOpt. The `amax` attribute(s)
of its quantizers specify the **dynamic range**(s):

```py
torch_qlinear = model.transformer.h[0].attn.c_attn
print(torch_qlinear)
```

We must convert dynamic ranges to scaling factors to load them into Tripy:

```py
def get_scale(quantizer):
    amax = quantizer.export_amax()
    # `maxbound` is the maximum value representable by the data type.
    # For `int8`, this is 127.
    scale = amax.float() / quantizer.maxbound
    return scale.squeeze().contiguous()

input_scale = get_scale(torch_qlinear.input_quantizer)
weight_scale = get_scale(torch_qlinear.weight_quantizer)
```
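The shapes of the converted factors reveal the quantization granularity that calibration chose for this layer (a small added check, not in the original guide): a 1-D `weight_scale` indicates per-channel weight quantization, while a scalar `input_scale` indicates per-tensor input quantization:

```py
# Inspect the granularity of the converted scaling factors.
print(f"weight_scale shape: {tuple(weight_scale.shape)}")
print(f"input_scale shape: {tuple(input_scale.shape)}")
```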
Let's load the scales into an {class}`nvtripy.Linear` module:

```py
# Set the quantization data type and the dimension
# along which the weights are quantized.
qlinear = tp.Linear(
    768,
    2304,
    quant_dtype=tp.int8,
    weight_quant_dim=torch_qlinear.weight_quantizer.axis,
)

qlinear.input_scale = tp.Tensor(input_scale)
qlinear.weight_scale = tp.Tensor(weight_scale)
```

:::{note}
We use scales from ModelOpt here, but scaling factors can come from anywhere.
:::
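For instance, a max-calibrated weight scale can be computed directly from a weight tensor without ModelOpt. Here is a hedged sketch in PyTorch; the random `weight` tensor and its `(out_features, in_features)` layout are assumptions for illustration:

```py
import torch

def max_calibrate_weight_scale(weight: torch.Tensor) -> torch.Tensor:
    # Treat the largest absolute value in each output channel as its dynamic
    # range, then divide by 127 (the int8 maxbound) to get per-channel scales.
    amax = weight.abs().amax(dim=1)
    return amax / 127.0

weight = torch.randn(2304, 768)  # illustrative (out_features, in_features) weight
weight_scale_alt = max_calibrate_weight_scale(weight)
```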
We run the module just like a non-quantized `float32` module;
inputs and weights are quantized internally:

```py
dummy_input = tp.ones((1, 768), dtype=tp.float32)

output = qlinear(dummy_input)
```

:::{seealso}
`load_quant_weights_from_hf` in the [nanoGPT weight loader](source:/examples/nanogpt/weight_loader.py)
is an example of loading scaling factors for an entire model.
:::
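As a rough illustration of what such a loader does (the `tripy_linears` mapping and module names below are hypothetical; the real logic, including copying the weights themselves and handling layout differences, lives in the linked example), it walks the calibrated model and assigns converted scales to the matching Tripy modules:

```py
# Hypothetical sketch: `tripy_linears` maps names of calibrated torch layers to
# matching tp.Linear modules constructed as shown above.
def load_scales(torch_model, tripy_linears):
    for name, module in torch_model.named_modules():
        if name not in tripy_linears:
            continue
        tp_linear = tripy_linears[name]
        tp_linear.weight_scale = tp.Tensor(get_scale(module.weight_quantizer))
        tp_linear.input_scale = tp.Tensor(get_scale(module.input_quantizer))
```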