From 7eb847c88940a998260a581f0077786fdc690ca3 Mon Sep 17 00:00:00 2001 From: pranavm Date: Fri, 17 Jan 2025 17:19:41 -0800 Subject: [PATCH] Refactors quantization guide - Refactors quantization guide to be more concise - Updates `nvidia-modelopt` version so we don't need a special package index - Reduces threshold for summary mode when pretty printing tensors. - Adds a `# doc: no-output` tag to omit output from code examples. --- tripy/docs/README.md | 6 +- .../00-introduction-to-tripy.md | 2 +- .../docs/pre0_user_guides/01-quantization.md | 301 +++++++++++------- tripy/docs/pre0_user_guides/02-compiler.md | 2 +- tripy/examples/nanogpt/requirements.txt | 3 +- tripy/nvtripy/backend/mlir/memref.py | 12 +- tripy/nvtripy/frontend/trace/ops/shape.py | 2 +- tripy/nvtripy/frontend/utils.py | 4 +- tripy/pyproject.toml | 2 +- tripy/tests/helper.py | 23 +- tripy/tests/integration/test_sequential.py | 3 +- 11 files changed, 226 insertions(+), 134 deletions(-) diff --git a/tripy/docs/README.md b/tripy/docs/README.md index 13f9464d1..217ae6d7f 100644 --- a/tripy/docs/README.md +++ b/tripy/docs/README.md @@ -154,9 +154,11 @@ Code blocks in docstrings/guides are **preprocessed**: - Code is **executed** and any output is displayed in the docs. - If the code throws, doc generation will fail. Use `# doc: allow-exception` to allow exceptions. + - `# doc: allow-exception` allows exceptions to be thrown. By default, they are treated as failures. - - **Note:** `# doc: no-eval` disables execution but this means the code will be **untested**! + - `# doc: no-output` omits output from the docs (but still executes the code). + + - `# doc: no-eval` disables execution but this means the code will be **untested**! - Local variables are also displayed. You can customize this: diff --git a/tripy/docs/pre0_user_guides/00-introduction-to-tripy.md b/tripy/docs/pre0_user_guides/00-introduction-to-tripy.md index edd5b9bf4..4c223d12b 100644 --- a/tripy/docs/pre0_user_guides/00-introduction-to-tripy.md +++ b/tripy/docs/pre0_user_guides/00-introduction-to-tripy.md @@ -63,7 +63,7 @@ Usage: out = fast_mlp(inp) ``` -:::{note} +:::{important} There are **restrictions** on what can be compiled - see {func}`nvtripy.compile`. ::: diff --git a/tripy/docs/pre0_user_guides/01-quantization.md b/tripy/docs/pre0_user_guides/01-quantization.md index 5c75a899e..63496c09b 100644 --- a/tripy/docs/pre0_user_guides/01-quantization.md +++ b/tripy/docs/pre0_user_guides/01-quantization.md @@ -1,154 +1,235 @@ # Quantization +**Quantization** reduces memory and compute requirements by running operations in low precision: +- **Scaling** is required to translate to/from low precision. +- **Scaling factors** are chosen such that they minimize accuracy loss. +- They can be either: + - Loaded into quantization-enabled {class}`nvtripy.Module`s, or + - Used with {func}`nvtripy.quantize`/{func}`nvtripy.dequantize`. +:::{seealso} +The +[TensorRT developer guide](https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#working-with-int8) +explains quantization in more detail. +::: -## Using Quantized Modules -Various modules predefined by Tripy support quantization. For example, the {class}`nvtripy.Linear` -module includes two arguments to configure the quantization mode. 
Let's construct the following -quantized linear module: +## Post-Training Quantization With ModelOpt -```py -# doc: print-locals quant_linear -quant_linear = tp.Linear( - 4, - 2, - quant_dtype=tp.int8, - weight_quant_dim=None, -) -``` +If the model was not trained with quantization-aware training (QAT), we can use +[TensorRT ModelOpt](https://nvidia.github.io/TensorRT-Model-Optimizer/index.html) +to do **calibration** to determine scaling factors. -As described in {class}`nvtripy.Linear`, the quantized linear module has -2 additional parameters compared to a normal linear layer: +:::{admonition} Info +**Calibration** runs a model with a small set of input data to determine the +numerical distribution of each tensor. -1. `weight_scale`: The quantization scale for `weight`. +The **dynamic range** is the most important range within this distribution and +scales are chosen to target this range. +::: -2. `input_scale`: The quantization scale for the input. +Let's calibrate a GPT model: -`weight_scale` must always be provided while `input_scale` is optional. The input will be quantized -only if `input_scale` is provided. For a `Linear` module in this example, only "per-tensor" quantization -is allowed for the input. This is why there is no `input_quant_dim` argument. +1. Install ModelOpt: -Let's fill the scale parameters with dummy data: + ```bash + python3 -m pip install nvidia-modelopt==0.11.1 transformers==4.46.2 datasets==2.21.0 + ``` -```py -# doc: print-locals quant_linear -quant_linear.weight_scale = tp.Tensor(1.0) -quant_linear.input_scale = tp.Tensor(1.0) -``` +2. Download the model: -and run a forward pass to see the result: + ```py + # doc: no-print-locals + from transformers import GPT2LMHeadModel -```py -x = tp.iota((3, 4), dtype=tp.float32) -out = quant_linear(x) -assert tp.equal(out, tp.Tensor([[0.0000, 1.0000], [6.0000, 23.0000], [12.0000, 45.0000]])) # doc: omit -``` + model = GPT2LMHeadModel.from_pretrained("gpt2") + ``` -The result still has a data type of {class}`nvtripy.float32`, but internally, TensorRT quantized the -input and weight, executed the linear layer with {class}`nvtripy.int8` precision, and finally dequantized -the output back to the original precision. +3. Calibrate for `int8` precision: -## Running Quantized Models + 1. Define the forward pass: -Now that we have covered how quantization works in {class}`nvtripy.Linear`, we will walk through -the workflow of running a real-world quantized model: [nanoGPT](source:/examples/nanogpt/). + ```py + # doc: no-output + from transformers import AutoTokenizer + from modelopt.torch.utils.dataset_utils import create_forward_loop -### Calibration With Model Optimizer + MAX_SEQ_LEN = 512 + tokenizer = AutoTokenizer.from_pretrained( + "gpt2", + use_fast=True, + model_max_length=MAX_SEQ_LEN, + padding_side="left", + trust_remote_code=True, + ) + tokenizer.pad_token = tokenizer.eos_token - + forward_loop = create_forward_loop( + model=model, + dataset_name="cnn_dailymail", + tokenizer=tokenizer, + device=model.device, + num_samples=8, + ) + ``` -The quantization scales are not available unless the model was trained with QAT (quantization-aware training). -We need to perform another step called calibration to compute the correct scales for each quantized layer. -There are many ways to do calibration, one of which is using the `nvidia-modelopt` toolkit. To install it, run: + 2. 
Set up quantization configuration: -```bash -python3 -m pip install --extra-index-url https://pypi.nvidia.com nvidia-modelopt==0.11.0 transformers==4.46.2 datasets==2.21.0 -``` + ```py + import modelopt.torch.quantization as mtq -First, let's get the pre-trained GPT model from hugging face: + quant_cfg = mtq.INT8_DEFAULT_CFG + ``` -```py -# doc: no-print-locals -from transformers import GPT2LMHeadModel + 3. Run calibration to replace linear layers with + [`QuantLinear`](https://nvidia.github.io/TensorRT-Model-Optimizer/reference/generated/modelopt.torch.quantization.nn.modules.quant_linear.html#modelopt.torch.quantization.nn.modules.quant_linear.QuantLinear), + which contain calibration information: + + ```py + # doc: no-output + mtq.quantize(model, quant_cfg, forward_loop=forward_loop) + ``` + + +The `amax` attributes of `QuantLinear`'s quantizers specify **dynamic ranges**: -model = GPT2LMHeadModel.from_pretrained("gpt2") +```py +torch_qlinear = model.transformer.h[0].attn.c_attn +print(torch_qlinear) ``` -Then, we perform int8 weight-only quantization: +We must convert dynamic ranges to scaling factors to load them into Tripy: ```py -from transformers import AutoTokenizer -import modelopt.torch.quantization as mtq - -from modelopt.torch.utils.dataset_utils import create_forward_loop - -# define the modelopt quant configs -quant_cfg = mtq.INT8_DEFAULT_CFG -# disable input quantization for weight-only -# quantized linear modules -quant_cfg["quant_cfg"]["*input_quantizer"] = { - "enable": False, -} - -# define the forward loop for calibration -MAX_SEQ_LEN = 512 -tokenizer = AutoTokenizer.from_pretrained( - "gpt2", - use_fast=True, - model_max_length=MAX_SEQ_LEN, - padding_side="left", - trust_remote_code=True, -) -tokenizer.pad_token = tokenizer.eos_token - -forward_loop = create_forward_loop( - model=model, - dataset_name="cnn_dailymail", - tokenizer=tokenizer, - device=model.device, - num_samples=8, -) - -# call the api for calibration -mtq.quantize(model, quant_cfg, forward_loop=forward_loop) +def get_scale(quantizer): + amax = quantizer.export_amax() + # `maxbound` is the maximum value representible by the data type. + # For `int8`, this is 127. + scale = amax.float() / quantizer.maxbound + return tp.Tensor(scale.squeeze().contiguous()) + +input_scale = get_scale(torch_qlinear.input_quantizer) +weight_scale = get_scale(torch_qlinear.weight_quantizer) ``` -`mtq.quantize` replaces all linear layers specified in `quant_cfg` with `QuantLinear` -layers, which contain the calibrated parameters. -### Load Scales Into The Tripy Model +## Loading Scales Into Tripy + +### Using Modules + +Modules that support quantization usually: +- Expose additional model parameters for scales. +- Accept arguments that control how quantization is performed. -Let's take a look at one of the `QuantLinear` produced by model optimizer: +Let's load the scales into an {class}`nvtripy.Linear` module: ```py -print(model.transformer.h[0].attn.c_attn) +qlinear = tp.Linear( + 768, + 2304, + # The data type to quantize to: + quant_dtype=tp.int8, + # The dimension along which the weights are quantized: + weight_quant_dim=torch_qlinear.weight_quantizer.axis) + +# Load weights: +qlinear.weight = tp.Tensor(torch_qlinear.weight.detach().contiguous()) +qlinear.bias = tp.Tensor(torch_qlinear.bias.detach().contiguous()) + +# Load scaling factors: +qlinear.input_scale = input_scale +qlinear.weight_scale = weight_scale ``` -The `amax` attribute gives us the dynamic range of the tensor. 
Tripy requires scaling factors, so we can convert it like so: +:::{note} +We use scales from ModelOpt here, but scaling factors can come from anywhere. +::: + +We can run it just like a regular `float32` module. +Inputs/weights are quantized internally: ```py -def convert_to_scale(amax, maxbound): - return amax.float() / maxbound +input = tp.ones((1, 768), dtype=tp.float32) + +output = qlinear(input) ``` -Let's convert the `amax` to the scaling factor and load it to a compatible {class}`nvtripy.Linear` module: +:::{seealso} +`load_quant_weights_from_hf` in the [nanoGPT weight loader](source:/examples/nanogpt/weight_loader.py) +is an example of loading scaling factors for an entire model. +::: + + +### Manually +When using {func}`nvtripy.quantize`/{func}`nvtripy.dequantize`, +`dequantize` must **immediately follow** `quantize`. + +TensorRT will **rotate** `dequantize` over subsequent ops as needed. + +:::{seealso} +The +[TensorRT developer guide](https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#qdq-placement-recs) +includes recommendations on placement of quantization and dequantization ops. +::: + + + +To mimic the behavior of the {class}`nvtripy.Linear` module above, we can: + +1. Quantize the input: + + ```py + # doc: no-print-locals + input = tp.ones((1, 768), dtype=tp.float32) + + input = tp.quantize(input, input_scale, dtype=tp.int8) + # Note the placement of dequantize: + input = tp.dequantize(input, input_scale, dtype=tp.float32) + ``` + +2. Quantize the weights: + + ```py + # doc: no-print-locals + weight = tp.Tensor(torch_qlinear.weight.detach().contiguous()) + + dim = torch_qlinear.weight_quantizer.axis + weight = tp.quantize(weight, weight_scale, dtype=tp.int8, dim=dim) + weight = tp.dequantize(weight, weight_scale, dtype=tp.float32, dim=dim) + ``` + +3. Perform the computation (matrix multiply in this case): + + ```py + # doc: no-print-locals bias + bias = tp.Tensor(torch_qlinear.bias.detach().contiguous()) + + output = input @ tp.transpose(weight, 0, 1) + bias + ``` + +:::{warning} +**Evaluating** the tensor produced by `dequantize` will affect accuracy. + +- **Why:** Evaluation replaces the tensor with a constant, losing information + like which op produced it. + + So, TensorRT won't see `dequantize` when evaluating subsequent ops and + won't **rotate** it correctly. + +For example, **don't** do this: ```py -# doc: print-locals weight_only_qlinear -weight_only_qlinear = tp.Linear( - 768, - 2304, - quant_dtype=tp.int8, - weight_quant_dim=0, -) -quantizer = model.transformer.h[0].attn.c_attn.weight_quantizer -scale = convert_to_scale(quantizer.export_amax(), quantizer.maxbound) -scale = scale.squeeze().contiguous() -weight_only_qlinear.weight_scale = tp.Tensor(scale) -``` +# doc: no-eval +tensor = tp.ones(...) -For an example of how to load weights from a quantized model, refer to -[load_quant_weights_from_hf](source:/examples/nanogpt/weight_loader.py) from the nanoGPT example. +tensor = tp.quantize(tensor, ...) +tensor = tp.dequantize(tensor, ...) - +# The `print` below will trigger an evaluation of the tensor which will prevent +# TensorRT from rotating the dequantization node. This will affect accuracy! +print(tensor) + +# Rest of the program, including some computation involving tensor +... 
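+
+# Instead, leave the quantize/dequantize pair unevaluated and evaluate only the
+# final result of the computation, so that TensorRT still sees the `dequantize`
+# op when optimizing the subsequent operations.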
+``` +::: diff --git a/tripy/docs/pre0_user_guides/02-compiler.md b/tripy/docs/pre0_user_guides/02-compiler.md index 7afada02a..03093789a 100644 --- a/tripy/docs/pre0_user_guides/02-compiler.md +++ b/tripy/docs/pre0_user_guides/02-compiler.md @@ -2,7 +2,7 @@ Modules and functions can be compiled for better performance. -:::{note} +:::{important} There are **restrictions** on what can be compiled - see {func}`nvtripy.compile`. ::: diff --git a/tripy/examples/nanogpt/requirements.txt b/tripy/examples/nanogpt/requirements.txt index f841318fe..b37f504cf 100644 --- a/tripy/examples/nanogpt/requirements.txt +++ b/tripy/examples/nanogpt/requirements.txt @@ -4,6 +4,5 @@ transformers==4.46.2 tiktoken==0.5.2 --extra-index-url https://download.pytorch.org/whl/cu121 torch==2.3.0 ---extra-index-url https://pypi.nvidia.com -nvidia-modelopt==0.11.0 +nvidia-modelopt==0.11.1 datasets==2.18.0 diff --git a/tripy/nvtripy/backend/mlir/memref.py b/tripy/nvtripy/backend/mlir/memref.py index 3d4d1bf92..c9b831e73 100644 --- a/tripy/nvtripy/backend/mlir/memref.py +++ b/tripy/nvtripy/backend/mlir/memref.py @@ -79,13 +79,13 @@ def create_memref_view(data): def check_tensor_type_and_suggest_contiguous(obj): obj_type = str(type(obj)) if "torch.Tensor" in obj_type: - return "PyTorch Tensor", "tensor.contiguous() or tensor.clone()" + return "PyTorch tensors", "tensor.contiguous() or tensor.clone()" elif "jaxlib" in obj_type or "jax.numpy" in obj_type: - return "JAX Array", "jax.numpy.asarray(array) or jax.numpy.copy(array)" + return "JAX arrays", "jax.numpy.asarray(array) or jax.numpy.copy(array)" elif "numpy.ndarray" in obj_type: - return "NumPy Array", "np.ascontiguousarray(array) or array.copy(order='C')" + return "NumPy arrays", "np.ascontiguousarray(array) or array.copy(order='C')" elif "cupy.ndarray" in obj_type: - return "CuPy Array", "cp.ascontiguousarray(array) or array.copy(order='C')" + return "CuPy arrays", "cp.ascontiguousarray(array) or array.copy(order='C')" else: return None, None @@ -94,8 +94,8 @@ def check_tensor_type_and_suggest_contiguous(obj): error_message = ( f"Non-canonical strides detected:\n" f" Shape: {shape}\n" - f" Current stride: {given_strides}\n" - f" Expected canonical stride: {canonical_strides}\n" + f" Strides: {given_strides}\n" + f" Expected canonical strides: {canonical_strides}\n" f"Non-canonical strides are not supported for Tripy tensors. " f"This usually occurs when the tensor is not contiguous in memory. " + ( diff --git a/tripy/nvtripy/frontend/trace/ops/shape.py b/tripy/nvtripy/frontend/trace/ops/shape.py index fb29c44ea..08016f0c7 100644 --- a/tripy/nvtripy/frontend/trace/ops/shape.py +++ b/tripy/nvtripy/frontend/trace/ops/shape.py @@ -50,7 +50,7 @@ def shape(self: "nvtripy.Tensor") -> ShapeLike: Represents the shape of the tensor. Returns: - A shape tensor containing the shape of this tensor. + A sequence containing the shape of this tensor. .. code-block:: python :linenos: diff --git a/tripy/nvtripy/frontend/utils.py b/tripy/nvtripy/frontend/utils.py index d43746f97..a8f89c7cd 100644 --- a/tripy/nvtripy/frontend/utils.py +++ b/tripy/nvtripy/frontend/utils.py @@ -120,7 +120,7 @@ def process_dim(dim: int, input_rank: int) -> int: return new_dim -def pretty_print(data_list, shape, threshold=1000, linewidth=10, edgeitems=3): +def pretty_print(data_list, shape, threshold=40, linewidth=10, edgeitems=3): """ Returns a pretty-print string of list format data. 
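+
+    `threshold` is the total number of elements above which summary mode
+    ("...") is used, `linewidth` is the number of elements printed per line,
+    and `edgeitems` is the number of items shown at each edge when summarizing
+    (parameter descriptions inferred from their usage).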
""" @@ -131,7 +131,7 @@ def _data_str(data, summarize, linewidth, edgeitems, indent=0): if len(data) == 0 or isinstance(data[0], (float, int)): if summarize and len(data) > 2 * edgeitems: - data_lines = [data[:edgeitems] + [" ..."] + data[-edgeitems:]] + data_lines = [data[:edgeitems] + ["..."] + data[-edgeitems:]] else: data_lines = [data[i : i + linewidth] for i in range(0, len(data), linewidth)] lines = [", ".join([f"{e:.4f}" if isinstance(e, float) else str(e) for e in line]) for line in data_lines] diff --git a/tripy/pyproject.toml b/tripy/pyproject.toml index 01ce39469..08fc0f1b3 100644 --- a/tripy/pyproject.toml +++ b/tripy/pyproject.toml @@ -54,7 +54,7 @@ docs = [ "sphinxcontrib-mermaid==0.9.2", "nvtripy[doc_test_common]", # Needed for guides: - "nvidia-modelopt==0.11.0", + "nvidia-modelopt==0.11.1", "transformers==4.44.2", "datasets==2.21.0", ] diff --git a/tripy/tests/helper.py b/tripy/tests/helper.py index ef654f8a2..a5a174127 100644 --- a/tripy/tests/helper.py +++ b/tripy/tests/helper.py @@ -157,13 +157,20 @@ def exec_code(code, code_locals=None) -> Dict[str, Any]: @contextlib.contextmanager def capture_output(): + def reset_outfile(outfile): + outfile.flush() + outfile.seek(0) + try: outfile = io.StringIO() with contextlib.redirect_stdout(outfile), contextlib.redirect_stderr(outfile): yield outfile + except: + reset_outfile(outfile) + print(outfile.read()) + raise finally: - outfile.flush() - outfile.seek(0) + reset_outfile(outfile) def discover_modules(): @@ -388,15 +395,17 @@ def process_code_block_for_outputs_and_locals( TRIPY_CLASSES = [tripy_obj for tripy_obj in discover_tripy_objects() if inspect.isclass(tripy_obj)] # Special tags are documented under docs/README.md. NO_EVAL = "# doc: no-eval" + NO_OUTPUT = "# doc: no-output" NO_PRINT_LOCALS = "# doc: no-print-locals" PRINT_LOCALS = "# doc: print-locals" ALLOW_EXCEPTION = "# doc: allow-exception" - REMOVE_TAGS = [NO_PRINT_LOCALS, PRINT_LOCALS, NO_EVAL, ALLOW_EXCEPTION] + REMOVE_TAGS = [NO_PRINT_LOCALS, PRINT_LOCALS, NO_EVAL, NO_OUTPUT, ALLOW_EXCEPTION] if strip_assertions: REMOVE_TAGS.append("assert ") OMIT_COMMENT = "# doc: omit" should_append_locals = True + should_append_output = True should_eval = True allow_exception = False @@ -423,6 +432,9 @@ def process_code_block_for_outputs_and_locals( if block_line.strip() == NO_EVAL: should_eval = False + if block_line.strip() == NO_OUTPUT: + should_append_output = False + if block_line.strip() == ALLOW_EXCEPTION: allow_exception = True @@ -491,7 +503,7 @@ def process_code_block_for_outputs_and_locals( print(e) code_locals = local_vars else: - print(f"{err_msg}\n" f"Note: Code block was:\n\n{block}") + print(f"{err_msg}\n" f"Note: Code block was:\n\n{block}\n\nExtracted code was:\n\n{code}\n") raise new_locals = { @@ -556,8 +568,7 @@ def split_block_lines(kind, contents, lang="python"): # Add output as a separate code block. 
stdout = outfile.read() or "" - - if stdout: + if stdout and should_append_output: # Strip out ANSI control sequences from output: stdout = ANSI_ESCAPE.sub("", stdout) output_lines = split_block_lines(BlockKind.OUTPUT, stdout, lang="") diff --git a/tripy/tests/integration/test_sequential.py b/tripy/tests/integration/test_sequential.py index b6f5e9c71..755b95625 100644 --- a/tripy/tests/integration/test_sequential.py +++ b/tripy/tests/integration/test_sequential.py @@ -68,9 +68,8 @@ def test_dict_forward_pass_accuracy(self, eager_or_compiled): with torch.no_grad(): torch_output = torch_model(input_tensor) - rtol_ = 2e-6 assert torch.allclose( - torch.from_dlpack(tp_output), torch_output, rtol=rtol_ + torch.from_dlpack(tp_output), torch_output, rtol=2e-4, atol=1e-5 ), "Forward pass outputs do not match." def test_nested_forward_pass_accuracy(self, eager_or_compiled):
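For context on the relaxed tolerances in the hunk above: `torch.allclose(a, b, rtol=..., atol=...)` passes when `|a - b| <= atol + rtol * |b|` elementwise, so a purely relative tolerance is overly strict for outputs near zero. The sketch below uses illustrative values (not taken from the test) to show why adding a small `atol` helps:

```py
import torch

expected = torch.tensor([0.0, 1.0, 100.0])
# Small absolute errors, including one on the element that is exactly zero:
actual = expected + torch.tensor([5e-6, 5e-5, 5e-3])

# With only a relative tolerance, the near-zero element fails (rtol * |0.0| == 0):
print(torch.allclose(actual, expected, rtol=2e-4))             # False

# A small absolute tolerance accepts tiny errors around zero:
print(torch.allclose(actual, expected, rtol=2e-4, atol=1e-5))  # True
```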