From aea46ff26b39c0c88e3d00cb88cb03442df61dd5 Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Wed, 4 Sep 2024 01:31:28 -0700 Subject: [PATCH] Trt compiler fixes (#8064) Fixes https://github.com/Project-MONAI/MONAI/issues/8061. ### Description Post-merge fixes for trt_compile() ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Boris Fomitchev Signed-off-by: Yiheng Wang Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Co-authored-by: Yiheng Wang <68361391+yiheng-wang-nv@users.noreply.github.com> Co-authored-by: Yiheng Wang Co-authored-by: binliunls <107988372+binliunls@users.noreply.github.com> --- monai/networks/trt_compiler.py | 8 ++++++-- tests/test_trt_compile.py | 4 ++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/monai/networks/trt_compiler.py b/monai/networks/trt_compiler.py index a9dd0d9e9b..00d2eb61af 100644 --- a/monai/networks/trt_compiler.py +++ b/monai/networks/trt_compiler.py @@ -342,6 +342,7 @@ def forward(self, model, argv, kwargs): self._build_and_save(model, build_args) # This will reassign input_names from the engine self._load_engine() + assert self.engine is not None except Exception as e: if self.fallback: self.logger.info(f"Failed to build engine: {e}") @@ -403,8 +404,10 @@ def _onnx_to_trt(self, onnx_path): build_args = self.build_args.copy() build_args["tf32"] = self.precision != "fp32" - build_args["fp16"] = self.precision == "fp16" - build_args["bf16"] = self.precision == "bf16" + if self.precision == "fp16": + build_args["fp16"] = True + elif self.precision == "bf16": + build_args["bf16"] = True self.logger.info(f"Building TensorRT engine for {onnx_path}: {self.plan_path}") network = network_from_onnx_path(onnx_path, flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM]) @@ -502,6 +505,7 @@ def trt_compile( ) -> torch.nn.Module: """ Instruments model or submodule(s) with TrtCompiler and replaces its forward() with TRT hook. + Note: TRT 10.3 is recommended for best performance. Some nets may even fail to work with TRT 8.x Args: model: module to patch with TrtCompiler object. base_path: TRT plan(s) saved to f"{base_path}[.{submodule}].plan" path. diff --git a/tests/test_trt_compile.py b/tests/test_trt_compile.py index 21125d203f..2f9db8f0c2 100644 --- a/tests/test_trt_compile.py +++ b/tests/test_trt_compile.py @@ -20,10 +20,10 @@ from monai.handlers import TrtHandler from monai.networks import trt_compile from monai.networks.nets import UNet, cell_sam_wrapper, vista3d132 -from monai.utils import optional_import +from monai.utils import min_version, optional_import from tests.utils import skip_if_no_cuda, skip_if_quick, skip_if_windows -trt, trt_imported = optional_import("tensorrt") +trt, trt_imported = optional_import("tensorrt", "10.1.0", min_version) polygraphy, polygraphy_imported = optional_import("polygraphy") build_sam_vit_b, has_sam = optional_import("segment_anything.build_sam", name="build_sam_vit_b")