diff --git a/monai/engines/trainer.py b/monai/engines/trainer.py index a0be86bae5..64fca6bcc7 100644 --- a/monai/engines/trainer.py +++ b/monai/engines/trainer.py @@ -46,7 +46,7 @@ class Trainer(Workflow): """ - def run(self) -> None: # type: ignore[override] + def run(self, *args) -> None: # type: ignore[override] """ Execute training based on Ignite Engine. If call this function multiple times, it will continuously run from the previous state. diff --git a/monai/handlers/__init__.py b/monai/handlers/__init__.py index c1fa448f25..f25ef680cc 100644 --- a/monai/handlers/__init__.py +++ b/monai/handlers/__init__.py @@ -29,6 +29,8 @@ from .metrics_reloaded_handler import MetricsReloadedBinaryHandler, MetricsReloadedCategoricalHandler from .metrics_saver import MetricsSaver from .mlflow_handler import MLFlowHandler +from .model_calibrator import ModelCalibrater +from .model_quantizer import ModelQuantizer from .nvtx_handlers import MarkHandler, RangeHandler, RangePopHandler, RangePushHandler from .panoptic_quality import PanopticQuality from .parameter_scheduler import ParamSchedulerHandler diff --git a/monai/handlers/model_calibrator.py b/monai/handlers/model_calibrator.py new file mode 100644 index 0000000000..76a3681b89 --- /dev/null +++ b/monai/handlers/model_calibrator.py @@ -0,0 +1,62 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from functools import partial +from typing import TYPE_CHECKING + +import modelopt.torch.quantization as mtq +import torch + +from monai.utils import IgniteInfo, min_version, optional_import + +Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events") +Checkpoint, _ = optional_import("ignite.handlers", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Checkpoint") +if TYPE_CHECKING: + from ignite.engine import Engine +else: + Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine") + + +class ModelCalibrater: + """ + Model quantizer is for model quantization. It takes a model as input and convert it to a quantized + model. + + Args: + model: the model to be quantized. + example_inputs: the example inputs for the model quantization. examples:: + (torch.randn(256,256,256),) + config: the calibration config. + + """ + + def __init__(self, model: torch.nn.Module, export_path: str, config: dict = mtq.INT8_SMOOTHQUANT_CFG) -> None: + self.model = model + self.export_path = export_path + self.config = config + + def attach(self, engine: Engine) -> None: + """ + Args: + engine: Ignite Engine, it can be a trainer, validator or evaluator. + """ + engine.add_event_handler(Events.STARTED, self) + + @staticmethod + def _model_wrapper(engine, model): + engine.run() + + def __call__(self, engine) -> None: + quant_fun = partial(self._model_wrapper, engine) + model = mtq.quantize(self.model, self.config, quant_fun) + torch.save(model.state_dict(), self.export_path) diff --git a/monai/handlers/model_quantizer.py b/monai/handlers/model_quantizer.py new file mode 100644 index 0000000000..4a6c69ad34 --- /dev/null +++ b/monai/handlers/model_quantizer.py @@ -0,0 +1,71 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from collections.abc import Sequence +from types import MethodType +from typing import TYPE_CHECKING + +import torch +from torch.ao.quantization.quantize_pt2e import prepare_qat_pt2e +from torch.ao.quantization.quantizer import Quantizer +from torch.ao.quantization.quantizer.xnnpack_quantizer import XNNPACKQuantizer, get_symmetric_quantization_config + +from monai.utils import IgniteInfo, min_version, optional_import + +Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events") +Checkpoint, _ = optional_import("ignite.handlers", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Checkpoint") +if TYPE_CHECKING: + from ignite.engine import Engine +else: + Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine") + + +class ModelQuantizer: + """ + Model quantizer is for model quantization. It takes a model as input and convert it to a quantized + model. + + Args: + model: the model to be quantized. + example_inputs: the example inputs for the model quantization. examples:: + (torch.randn(256,256,256),) + quantizer: quantizer for the quantization job. + + """ + + def __init__( + self, model: torch.nn.Module, example_inputs: Sequence, export_path: str, quantizer: Quantizer | None = None + ) -> None: + self.model = model + self.example_inputs = example_inputs + self.export_path = export_path + self.quantizer = ( + XNNPACKQuantizer().set_global(get_symmetric_quantization_config()) if quantizer is None else quantizer + ) + + def attach(self, engine: Engine) -> None: + """ + Args: + engine: Ignite Engine, it can be a trainer, validator or evaluator. + """ + engine.add_event_handler(Events.STARTED, self.start) + engine.add_event_handler(Events.ITERATION_COMPLETED, self.epoch) + + def start(self) -> None: + self.model = torch.export.export_for_training(self.model, self.example_inputs).module() + self.model = prepare_qat_pt2e(self.model, self.quantizer) + self.model.train = MethodType(torch.ao.quantization.move_exported_model_to_train, self.model) + self.model.eval = MethodType(torch.ao.quantization.move_exported_model_to_eval, self.model) + + def epoch(self) -> None: + torch.save(self.model.state_dict(), self.export_path) diff --git a/requirements-dev.txt b/requirements-dev.txt index 72654d3534..bf25c9f535 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -61,3 +61,4 @@ pyamg>=5.0.0 git+https://github.com/facebookresearch/segment-anything.git@6fdee8f2727f4506cfbbe553e23b895e27956588 onnx_graphsurgeon polygraphy +nvidia-modelopt>=0.19.0