diff --git a/scripts/utils/trt_utils.py b/scripts/utils/trt_utils.py index af95821..3632119 100644 --- a/scripts/utils/trt_utils.py +++ b/scripts/utils/trt_utils.py @@ -224,7 +224,7 @@ def try_set_inputs(): left = ctx.infer_shapes() assert len(left) == 0 - + def infer(self, stream, use_cuda_graph=False): if use_cuda_graph: if self.cuda_graph_instance is not None: @@ -433,7 +433,7 @@ def forward_trt_runner(self, trt_inputs): ret = list(ret.values()) ret = [r.cuda() for r in ret] # check = [check_m(r) for r in ret] - if len(ret)==1: + if len(ret) == 1: ret = ret[0] return ret diff --git a/vista3d/modeling/segresnetds.py b/vista3d/modeling/segresnetds.py index 72bf586..ebd1d9c 100644 --- a/vista3d/modeling/segresnetds.py +++ b/vista3d/modeling/segresnetds.py @@ -12,7 +12,7 @@ from __future__ import annotations from collections.abc import Callable -from typing import Union, List, Tuple +from typing import List, Tuple, Union import numpy as np import torch