diff --git a/.gitignore b/.gitignore index be4dcdb..d5df6de 100644 --- a/.gitignore +++ b/.gitignore @@ -164,3 +164,4 @@ temp_image.jpg # .DS_Store +.python-version \ No newline at end of file diff --git a/capybara/onnxengine/engine.py b/capybara/onnxengine/engine.py index 18c7393..59171e0 100644 --- a/capybara/onnxengine/engine.py +++ b/capybara/onnxengine/engine.py @@ -7,7 +7,7 @@ import onnxruntime as ort from ..enums import EnumCheckMixin -from .metadata import get_onnx_metadata +from .metadata import parse_metadata_from_onnx from .tools import get_onnx_input_infos, get_onnx_output_infos @@ -17,7 +17,6 @@ class Backend(EnumCheckMixin, Enum): class ONNXEngine: - def __init__( self, model_path: Union[str, Path], @@ -43,18 +42,16 @@ def __init__( """ # setting device info backend = Backend.obj_to_enum(backend) - self.device_id = 0 if backend.name == 'cpu' else gpu_id + self.device_id = 0 if backend.name == "cpu" else gpu_id # setting provider options - providers, provider_options = self._get_provider_info( - backend, provider_option) + providers, provider_options = self._get_provider_info(backend, provider_option) # setting session options sess_options = self._get_session_info(session_option) # setting onnxruntime session - model_path = str(model_path) if isinstance( - model_path, Path) else model_path + model_path = str(model_path) if isinstance(model_path, Path) else model_path self.sess = ort.InferenceSession( model_path, sess_options=sess_options, @@ -64,7 +61,7 @@ def __init__( # setting onnxruntime session info self.model_path = model_path - self.metadata = get_onnx_metadata(model_path) + self.metadata = parse_metadata_from_onnx(model_path) self.providers = self.sess.get_providers() self.provider_options = self.sess.get_provider_options() @@ -86,8 +83,8 @@ def _get_session_info( """ sess_opt = ort.SessionOptions() session_option_default = { - 'graph_optimization_level': ort.GraphOptimizationLevel.ORT_ENABLE_ALL, - 'log_severity_level': 2, + "graph_optimization_level": ort.GraphOptimizationLevel.ORT_ENABLE_ALL, + "log_severity_level": 2, } session_option_default.update(session_option) for k, v in session_option_default.items(): @@ -103,18 +100,20 @@ def _get_provider_info( Ref: https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html#configuration-options """ if backend == Backend.cuda: - providers = ['CUDAExecutionProvider'] - provider_option = [{ - 'device_id': self.device_id, - 'cudnn_conv_use_max_workspace': '1', - **provider_option, - }] + providers = ["CUDAExecutionProvider"] + provider_option = [ + { + "device_id": self.device_id, + "cudnn_conv_use_max_workspace": "1", + **provider_option, + } + ] elif backend == Backend.cpu: - providers = ['CPUExecutionProvider'] + providers = ["CPUExecutionProvider"] # "CPUExecutionProvider" is different from everything else. provider_option = None else: - raise ValueError(f'backend={backend} is not supported.') + raise ValueError(f"backend={backend} is not supported.") return providers, provider_option def __repr__(self) -> str: @@ -122,8 +121,8 @@ def __repr__(self) -> str: def strip_ansi_codes(text): """Remove ANSI escape codes from a string.""" - ansi_escape = re.compile(r'\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])') - return ansi_escape.sub('', text) + ansi_escape = re.compile(r"\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])") + return ansi_escape.sub("", text) def format_nested_dict(dict_data, indent=0): """Recursively format nested dictionaries with indentation.""" @@ -133,13 +132,16 @@ def format_nested_dict(dict_data, indent=0): if isinstance(value, dict): info.append(f"{prefix}{key}:") info.append(format_nested_dict(value, indent + 1)) - elif isinstance(value, str) and value.startswith('{') and value.endswith('}'): + elif ( + isinstance(value, str) + and value.startswith("{") + and value.endswith("}") + ): try: nested_dict = eval(value) if isinstance(nested_dict, dict): info.append(f"{prefix}{key}:") - info.append(format_nested_dict( - nested_dict, indent + 1)) + info.append(format_nested_dict(nested_dict, indent + 1)) else: info.append(f"{prefix}{key}: {value}") except Exception: @@ -148,11 +150,12 @@ def format_nested_dict(dict_data, indent=0): info.append(f"{prefix}{key}: {value}") return "\n".join(info) - title = 'DOCSAID X ONNXRUNTIME' + title = "DOCSAID X ONNXRUNTIME" divider_length = 50 divider = f"+{'-' * divider_length}+" styled_title = colored.stylize( - title, [colored.fg('blue'), colored.attr('bold')]) + title, [colored.fg("blue"), colored.attr("bold")] + ) def center_text(text, width): """Center text within a fixed width, handling ANSI escape codes."""