From 802cc8497840fe3095babb3c008652da84452bc1 Mon Sep 17 00:00:00 2001
From: root
Date: Thu, 14 Nov 2024 08:27:36 +0000
Subject: [PATCH 1/4] add calibration and quantization

Signed-off-by: root
---
 monai/engines/trainer.py           |  2 +-
 monai/handlers/__init__.py         |  2 +
 monai/handlers/model_calibrator.py | 68 ++++++++++++++++++++++++++
 monai/handlers/model_quantizer.py  | 82 ++++++++++++++++++++++++++++++
 4 files changed, 153 insertions(+), 1 deletion(-)
 create mode 100644 monai/handlers/model_calibrator.py
 create mode 100644 monai/handlers/model_quantizer.py

diff --git a/monai/engines/trainer.py b/monai/engines/trainer.py
index a0be86bae5..64fca6bcc7 100644
--- a/monai/engines/trainer.py
+++ b/monai/engines/trainer.py
@@ -46,7 +46,7 @@ class Trainer(Workflow):
 
     """
 
-    def run(self) -> None:  # type: ignore[override]
+    def run(self, *args) -> None:  # type: ignore[override]
         """
         Execute training based on Ignite Engine. If call this function multiple times, it will continuously run from
         the previous state.
diff --git a/monai/handlers/__init__.py b/monai/handlers/__init__.py
index c1fa448f25..c915c281d0 100644
--- a/monai/handlers/__init__.py
+++ b/monai/handlers/__init__.py
@@ -29,6 +29,8 @@
 from .metrics_reloaded_handler import MetricsReloadedBinaryHandler, MetricsReloadedCategoricalHandler
 from .metrics_saver import MetricsSaver
 from .mlflow_handler import MLFlowHandler
+from .model_quantizer import ModelQuantizer
+from .model_calibrator import ModelCalibrater
 from .nvtx_handlers import MarkHandler, RangeHandler, RangePopHandler, RangePushHandler
 from .panoptic_quality import PanopticQuality
 from .parameter_scheduler import ParamSchedulerHandler
diff --git a/monai/handlers/model_calibrator.py b/monai/handlers/model_calibrator.py
new file mode 100644
index 0000000000..43484ea5e2
--- /dev/null
+++ b/monai/handlers/model_calibrator.py
@@ -0,0 +1,68 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import torch
+import modelopt.torch.quantization as mtq
+from functools import partial
+from monai.utils import IgniteInfo, min_version, optional_import
+
+
+Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events")
+Checkpoint, _ = optional_import("ignite.handlers", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Checkpoint")
+if TYPE_CHECKING:
+    from ignite.engine import Engine
+else:
+    Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine")
+
+
+class ModelCalibrater:
+    """
+    Model calibrater is for model calibration. It takes a model as input, calibrates and quantizes it with
+    the given config, then saves the quantized model.
+
+    Args:
+        model: the model to be calibrated and quantized.
+        export_path: the path to save the calibrated model checkpoint.
+        config: the calibration config for ``modelopt.torch.quantization``, defaulting to
+            ``mtq.INT8_SMOOTHQUANT_CFG``.
+
+    """
+
+    def __init__(
+        self,
+        model: torch.nn.Module,
+        export_path: str,
+        config: dict= mtq.INT8_SMOOTHQUANT_CFG,
+        
+    ) -> None:
+        self.model = model
+        self.export_path = export_path
+        self.config = config
+
+    def attach(self, engine: Engine) -> None:
+        """
+        Args:
+            engine: Ignite Engine, it can be a trainer, validator or evaluator.
+        """
+        engine.add_event_handler(Events.STARTED, self)
+        
+    @staticmethod
+    def _model_wrapper(engine, model):
+        engine.run()
+
+    def __call__(self, engine) -> None:
+        quant_fun = partial(self._model_wrapper, engine)
+        model = mtq.quantize(self.model, self.config, quant_fun)
+        torch.save(self.model.state_dict(), self.export_path)
diff --git a/monai/handlers/model_quantizer.py b/monai/handlers/model_quantizer.py
new file mode 100644
index 0000000000..a1b4430cba
--- /dev/null
+++ b/monai/handlers/model_quantizer.py
@@ -0,0 +1,82 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import logging
+import warnings
+from types import MethodType
+from typing import TYPE_CHECKING, Sequence
+
+import torch
+
+from monai.networks.utils import copy_model_state
+from monai.utils import IgniteInfo, min_version, optional_import
+from torch.ao.quantization.quantizer import Quantizer
+from torch.ao.quantization.quantizer.xnnpack_quantizer import (
+    XNNPACKQuantizer,
+    get_symmetric_quantization_config,
+)
+from torch.ao.quantization.quantize_pt2e import (
+    prepare_qat_pt2e,
+    convert_pt2e,
+)
+
+Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events")
+Checkpoint, _ = optional_import("ignite.handlers", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Checkpoint")
+if TYPE_CHECKING:
+    from ignite.engine import Engine
+else:
+    Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine")
+
+
+class ModelQuantizer:
+    """
+    Model quantizer is for model quantization. It takes a model as input and converts it to a quantized
+    model.
+
+    Args:
+        model: the model to be quantized.
+        example_inputs: the example inputs for the model quantization. examples::
+            (torch.randn(256,256,256),)
+        quantizer: quantizer for the quantization job.
+
+    """
+
+    def __init__(
+        self,
+        model: torch.nn.Module,
+        example_inputs: Sequence,
+        export_path: str,
+        quantizer: Quantizer | None = None,
+        
+    ) -> None:
+        self.model = model
+        self.example_inputs = example_inputs
+        self.export_path = export_path
+        self.quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config()) if quantizer is None else quantizer
+
+    def attach(self, engine: Engine) -> None:
+        """
+        Args:
+            engine: Ignite Engine, it can be a trainer, validator or evaluator.
+        """
+        engine.add_event_handler(Events.STARTED, self.start)
+        engine.add_event_handler(Events.ITERATION_COMPLETED, self.epoch)
+
+    def start(self) -> None:
+        self.model = torch.export.export_for_training(self.model, self.example_inputs).module()
+        self.model = prepare_qat_pt2e(self.model, self.quantizer)
+        self.model.train = MethodType(torch.ao.quantization.move_exported_model_to_train, self.model)
+        self.model.eval = MethodType(torch.ao.quantization.move_exported_model_to_eval, self.model)
+        
+    def epoch(self) -> None:
+        torch.save(self.model.state_dict(), self.export_path)

From 6ee190e8e6ca42091b1b1311959c2f46c7411d71 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Fri, 15 Nov 2024 00:51:58 +0000
Subject: [PATCH 2/4] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 monai/handlers/model_calibrator.py |  4 ++--
 monai/handlers/model_quantizer.py  | 11 ++++-------
 2 files changed, 6 insertions(+), 9 deletions(-)

diff --git a/monai/handlers/model_calibrator.py b/monai/handlers/model_calibrator.py
index 43484ea5e2..e63bb88178 100644
--- a/monai/handlers/model_calibrator.py
+++ b/monai/handlers/model_calibrator.py
@@ -45,7 +45,7 @@ def __init__(
         model: torch.nn.Module,
         export_path: str,
         config: dict= mtq.INT8_SMOOTHQUANT_CFG,
-        
+
     ) -> None:
         self.model = model
         self.export_path = export_path
         self.config = config
@@ -57,7 +57,7 @@ def attach(self, engine: Engine) -> None:
             engine: Ignite Engine, it can be a trainer, validator or evaluator.
         """
         engine.add_event_handler(Events.STARTED, self)
-        
+
     @staticmethod
     def _model_wrapper(engine, model):
         engine.run()
diff --git a/monai/handlers/model_quantizer.py b/monai/handlers/model_quantizer.py
index a1b4430cba..0d65e9ec40 100644
--- a/monai/handlers/model_quantizer.py
+++ b/monai/handlers/model_quantizer.py
@@ -11,14 +11,12 @@
 
 from __future__ import annotations
 
-import logging
-import warnings
 from types import MethodType
-from typing import TYPE_CHECKING, Sequence
+from typing import TYPE_CHECKING
+from collections.abc import Sequence
 
 import torch
 
-from monai.networks.utils import copy_model_state
 from monai.utils import IgniteInfo, min_version, optional_import
 from torch.ao.quantization.quantizer import Quantizer
 from torch.ao.quantization.quantizer.xnnpack_quantizer import (
@@ -27,7 +25,6 @@
 )
 from torch.ao.quantization.quantize_pt2e import (
     prepare_qat_pt2e,
-    convert_pt2e,
 )
 
 Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events")
 Checkpoint, _ = optional_import("ignite.handlers", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Checkpoint")
 if TYPE_CHECKING:
@@ -57,7 +54,7 @@ def __init__(
         example_inputs: Sequence,
         export_path: str,
         quantizer: Quantizer | None = None,
-        
+
     ) -> None:
         self.model = model
         self.example_inputs = example_inputs
         self.export_path = export_path
@@ -77,6 +74,6 @@ def start(self) -> None:
         self.model = prepare_qat_pt2e(self.model, self.quantizer)
         self.model.train = MethodType(torch.ao.quantization.move_exported_model_to_train, self.model)
         self.model.eval = MethodType(torch.ao.quantization.move_exported_model_to_eval, self.model)
-        
+
     def epoch(self) -> None:
         torch.save(self.model.state_dict(), self.export_path)

From 88d855c0a7e7acb9a603d42ad3a0a0391ecae733 Mon Sep 17 00:00:00 2001
From: binliu
Date: Fri, 15 Nov 2024 01:25:01 +0000
Subject: [PATCH 3/4] fix format

Signed-off-by: binliu
---
 monai/handlers/__init__.py         |  2 +-
 monai/handlers/model_calibrator.py | 16 +++++-----------
 monai/handlers/model_quantizer.py  | 24 ++++++++----------------
 3 files changed, 14 insertions(+), 28 deletions(-)

diff --git a/monai/handlers/__init__.py b/monai/handlers/__init__.py
index c915c281d0..f25ef680cc 100644
--- a/monai/handlers/__init__.py
+++ b/monai/handlers/__init__.py
@@ -29,8 +29,8 @@
 from .metrics_reloaded_handler import MetricsReloadedBinaryHandler, MetricsReloadedCategoricalHandler
 from .metrics_saver import MetricsSaver
 from .mlflow_handler import MLFlowHandler
-from .model_quantizer import ModelQuantizer
 from .model_calibrator import ModelCalibrater
+from .model_quantizer import ModelQuantizer
 from .nvtx_handlers import MarkHandler, RangeHandler, RangePopHandler, RangePushHandler
 from .panoptic_quality import PanopticQuality
 from .parameter_scheduler import ParamSchedulerHandler
diff --git a/monai/handlers/model_calibrator.py b/monai/handlers/model_calibrator.py
index e63bb88178..76a3681b89 100644
--- a/monai/handlers/model_calibrator.py
+++ b/monai/handlers/model_calibrator.py
@@ -11,13 +11,13 @@
 
 from __future__ import annotations
 
+from functools import partial
 from typing import TYPE_CHECKING
 
-import torch
 import modelopt.torch.quantization as mtq
-from functools import partial
-from monai.utils import IgniteInfo, min_version, optional_import
+import torch
 
+from monai.utils import IgniteInfo, min_version, optional_import
 
 Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events")
 Checkpoint, _ = optional_import("ignite.handlers", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Checkpoint")
@@ -40,13 +40,7 @@ class ModelCalibrater:
 
     """
 
-    def __init__(
-        self,
-        model: torch.nn.Module,
-        export_path: str,
-        config: dict= mtq.INT8_SMOOTHQUANT_CFG,
-
-    ) -> None:
+    def __init__(self, model: torch.nn.Module, export_path: str, config: dict = mtq.INT8_SMOOTHQUANT_CFG) -> None:
         self.model = model
         self.export_path = export_path
         self.config = config
@@ -65,4 +59,4 @@ def _model_wrapper(engine, model):
     def __call__(self, engine) -> None:
         quant_fun = partial(self._model_wrapper, engine)
         model = mtq.quantize(self.model, self.config, quant_fun)
-        torch.save(self.model.state_dict(), self.export_path)
+        torch.save(model.state_dict(), self.export_path)
diff --git a/monai/handlers/model_quantizer.py b/monai/handlers/model_quantizer.py
index 0d65e9ec40..4a6c69ad34 100644
--- a/monai/handlers/model_quantizer.py
+++ b/monai/handlers/model_quantizer.py
@@ -11,21 +11,16 @@
 
 from __future__ import annotations
 
+from collections.abc import Sequence
 from types import MethodType
 from typing import TYPE_CHECKING
-from collections.abc import Sequence
 
 import torch
+from torch.ao.quantization.quantize_pt2e import prepare_qat_pt2e
+from torch.ao.quantization.quantizer import Quantizer
+from torch.ao.quantization.quantizer.xnnpack_quantizer import XNNPACKQuantizer, get_symmetric_quantization_config
 
 from monai.utils import IgniteInfo, min_version, optional_import
-from torch.ao.quantization.quantizer import Quantizer
-from torch.ao.quantization.quantizer.xnnpack_quantizer import (
-    XNNPACKQuantizer,
-    get_symmetric_quantization_config,
-)
-from torch.ao.quantization.quantize_pt2e import (
-    prepare_qat_pt2e,
-)
 
 Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events")
 Checkpoint, _ = optional_import("ignite.handlers", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Checkpoint")
 if TYPE_CHECKING:
@@ -49,17 +44,14 @@ class ModelQuantizer:
     """
 
     def __init__(
-        self,
-        model: torch.nn.Module,
-        example_inputs: Sequence,
-        export_path: str,
-        quantizer: Quantizer | None = None,
-
+        self, model: torch.nn.Module, example_inputs: Sequence, export_path: str, quantizer: Quantizer | None = None
     ) -> None:
         self.model = model
         self.example_inputs = example_inputs
         self.export_path = export_path
-        self.quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config()) if quantizer is None else quantizer
+        self.quantizer = (
+            XNNPACKQuantizer().set_global(get_symmetric_quantization_config()) if quantizer is None else quantizer
+        )
 
     def attach(self, engine: Engine) -> None:
         """

From 5e1c3409ce4c4e79cc165b246176c60a9af31c62 Mon Sep 17 00:00:00 2001
From: binliu
Date: Fri, 15 Nov 2024 01:33:33 +0000
Subject: [PATCH 4/4] add dependencies

Signed-off-by: binliu
---
 requirements-dev.txt | 1 +
 1 file changed, 1 insertion(+)

diff --git a/requirements-dev.txt b/requirements-dev.txt
index 72654d3534..bf25c9f535 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -61,3 +61,4 @@ pyamg>=5.0.0
 git+https://github.com/facebookresearch/segment-anything.git@6fdee8f2727f4506cfbbe553e23b895e27956588
 onnx_graphsurgeon
 polygraphy
+nvidia-modelopt>=0.19.0
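
For reviewers, a minimal usage sketch of the two new handlers follows. It is only a sketch under stated assumptions: the UNet, DiceLoss, optimizer, file paths and the tiny random in-memory data loader are illustrative placeholders rather than anything taken from the patches, ModelQuantizer relies on the PyTorch PT2E export/QAT APIs, and ModelCalibrater additionally needs the nvidia-modelopt dependency added in PATCH 4/4.

    import torch
    from torch.utils.data import DataLoader

    from monai.engines import SupervisedTrainer
    from monai.handlers import ModelQuantizer
    from monai.losses import DiceLoss
    from monai.networks.nets import UNet

    # toy in-memory data; a real workflow would use MONAI datasets and transforms
    data = [{"image": torch.randn(1, 32, 32, 32), "label": (torch.rand(1, 32, 32, 32) > 0.5).float()}]
    loader = DataLoader(data, batch_size=1)

    net = UNet(spatial_dims=3, in_channels=1, out_channels=1, channels=(4, 8, 16), strides=(2, 2))
    trainer = SupervisedTrainer(
        device=torch.device("cpu"),
        max_epochs=1,
        train_data_loader=loader,
        network=net,
        optimizer=torch.optim.Adam(net.parameters(), 1e-3),
        loss_function=DiceLoss(sigmoid=True),
    )

    # quantization-aware training: at Events.STARTED the handler exports the model for
    # training and inserts fake-quant observers, then saves the state dict to export_path
    # after every iteration.
    ModelQuantizer(
        model=net, example_inputs=(torch.randn(1, 1, 32, 32, 32),), export_path="./model_qat.pt"
    ).attach(trainer)

    # post-training calibration could be attached analogously (requires nvidia-modelopt):
    # ModelCalibrater(model=net, export_path="./model_calib.pt").attach(trainer)

    trainer.run()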