Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Aug 5, 2024
1 parent 20f1b19 commit 92cda1b
Showing 1 changed file with 19 additions and 14 deletions.
33 changes: 19 additions & 14 deletions scripts/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,10 @@

try:
from monai.utils import TRTWrapper
TRT_AVAILABLE=True
except Exception as e:
TRT_AVAILABLE=False

TRT_AVAILABLE = True
except Exception:
TRT_AVAILABLE = False

rearrange, _ = optional_import("einops", name="rearrange")
sys.path.insert(0, os.path.abspath(os.path.dirname(__file__)))
Expand Down Expand Up @@ -137,19 +138,23 @@ def __init__(self, config_file="./configs/infer.yaml", **override):
self.prev_mask = None
self.batch_data = None
if self.trt and TRT_AVAILABLE:
ts=os.path.getmtime(config_file)
self.model.image_encoder.encoder = TRTWrapper("Encoder",
self.model.image_encoder.encoder,
input_names=["x"],
output_names=["x_out"],
timestamp=ts)
ts = os.path.getmtime(config_file)
self.model.image_encoder.encoder = TRTWrapper(
"Encoder",
self.model.image_encoder.encoder,
input_names=["x"],
output_names=["x_out"],
timestamp=ts,
)
self.model.image_encoder.encoder.load_engine()

self.model.class_head = TRTWrapper("ClassHead",
self.model.class_head,
input_names=["src", "class_vector"],
output_names=["masks", "class_embedding"],
timestamp=ts)
self.model.class_head = TRTWrapper(
"ClassHead",
self.model.class_head,
input_names=["src", "class_vector"],
output_names=["masks", "class_embedding"],
timestamp=ts,
)
self.model.class_head.load_engine()
return

Expand Down

0 comments on commit 92cda1b

Please sign in to comment.