Skip to content

Commit

Permalink
Merge pull request #22 from DocsaidLab/bugfix/fix_forget_update_parse…
Browse files Browse the repository at this point in the history
…_metadata_function_in_ONNXEngine

[F] fix forget updatind the parse function in ONNXEngine
  • Loading branch information
kunkunlin1221 authored Feb 11, 2025
2 parents a9b0eaa + d920d56 commit 9973bf8
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 25 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -164,3 +164,4 @@ temp_image.jpg

#
.DS_Store
.python-version
53 changes: 28 additions & 25 deletions capybara/onnxengine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import onnxruntime as ort

from ..enums import EnumCheckMixin
from .metadata import get_onnx_metadata
from .metadata import parse_metadata_from_onnx
from .tools import get_onnx_input_infos, get_onnx_output_infos


Expand All @@ -17,7 +17,6 @@ class Backend(EnumCheckMixin, Enum):


class ONNXEngine:

def __init__(
self,
model_path: Union[str, Path],
Expand All @@ -43,18 +42,16 @@ def __init__(
"""
# setting device info
backend = Backend.obj_to_enum(backend)
self.device_id = 0 if backend.name == 'cpu' else gpu_id
self.device_id = 0 if backend.name == "cpu" else gpu_id

# setting provider options
providers, provider_options = self._get_provider_info(
backend, provider_option)
providers, provider_options = self._get_provider_info(backend, provider_option)

# setting session options
sess_options = self._get_session_info(session_option)

# setting onnxruntime session
model_path = str(model_path) if isinstance(
model_path, Path) else model_path
model_path = str(model_path) if isinstance(model_path, Path) else model_path
self.sess = ort.InferenceSession(
model_path,
sess_options=sess_options,
Expand All @@ -64,7 +61,7 @@ def __init__(

# setting onnxruntime session info
self.model_path = model_path
self.metadata = get_onnx_metadata(model_path)
self.metadata = parse_metadata_from_onnx(model_path)
self.providers = self.sess.get_providers()
self.provider_options = self.sess.get_provider_options()

Expand All @@ -86,8 +83,8 @@ def _get_session_info(
"""
sess_opt = ort.SessionOptions()
session_option_default = {
'graph_optimization_level': ort.GraphOptimizationLevel.ORT_ENABLE_ALL,
'log_severity_level': 2,
"graph_optimization_level": ort.GraphOptimizationLevel.ORT_ENABLE_ALL,
"log_severity_level": 2,
}
session_option_default.update(session_option)
for k, v in session_option_default.items():
Expand All @@ -103,27 +100,29 @@ def _get_provider_info(
Ref: https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html#configuration-options
"""
if backend == Backend.cuda:
providers = ['CUDAExecutionProvider']
provider_option = [{
'device_id': self.device_id,
'cudnn_conv_use_max_workspace': '1',
**provider_option,
}]
providers = ["CUDAExecutionProvider"]
provider_option = [
{
"device_id": self.device_id,
"cudnn_conv_use_max_workspace": "1",
**provider_option,
}
]
elif backend == Backend.cpu:
providers = ['CPUExecutionProvider']
providers = ["CPUExecutionProvider"]
# "CPUExecutionProvider" is different from everything else.
provider_option = None
else:
raise ValueError(f'backend={backend} is not supported.')
raise ValueError(f"backend={backend} is not supported.")
return providers, provider_option

def __repr__(self) -> str:
import re

def strip_ansi_codes(text):
"""Remove ANSI escape codes from a string."""
ansi_escape = re.compile(r'\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])')
return ansi_escape.sub('', text)
ansi_escape = re.compile(r"\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])")
return ansi_escape.sub("", text)

def format_nested_dict(dict_data, indent=0):
"""Recursively format nested dictionaries with indentation."""
Expand All @@ -133,13 +132,16 @@ def format_nested_dict(dict_data, indent=0):
if isinstance(value, dict):
info.append(f"{prefix}{key}:")
info.append(format_nested_dict(value, indent + 1))
elif isinstance(value, str) and value.startswith('{') and value.endswith('}'):
elif (
isinstance(value, str)
and value.startswith("{")
and value.endswith("}")
):
try:
nested_dict = eval(value)
if isinstance(nested_dict, dict):
info.append(f"{prefix}{key}:")
info.append(format_nested_dict(
nested_dict, indent + 1))
info.append(format_nested_dict(nested_dict, indent + 1))
else:
info.append(f"{prefix}{key}: {value}")
except Exception:
Expand All @@ -148,11 +150,12 @@ def format_nested_dict(dict_data, indent=0):
info.append(f"{prefix}{key}: {value}")
return "\n".join(info)

title = 'DOCSAID X ONNXRUNTIME'
title = "DOCSAID X ONNXRUNTIME"
divider_length = 50
divider = f"+{'-' * divider_length}+"
styled_title = colored.stylize(
title, [colored.fg('blue'), colored.attr('bold')])
title, [colored.fg("blue"), colored.attr("bold")]
)

def center_text(text, width):
"""Center text within a fixed width, handling ANSI escape codes."""
Expand Down

0 comments on commit 9973bf8

Please sign in to comment.