Skip to content

Commit

Permalink
[C] update onnx metadata functions
Browse files Browse the repository at this point in the history
  • Loading branch information
kunkunlin1221 committed Feb 11, 2025
1 parent b7beb8a commit de00db1
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 9 deletions.
9 changes: 6 additions & 3 deletions capybara/onnxengine/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from .engine import Backend, ONNXEngine
from .engine_io_binding import ONNXEngineIOBinding
from .metadata import get_onnx_metadata, write_metadata_into_onnx
from .tools import (get_onnx_input_infos, get_onnx_output_infos,
make_onnx_dynamic_axes)
from .metadata import (
get_onnx_metadata,
parse_metadata_from_onnx,
write_metadata_into_onnx,
)
from .tools import get_onnx_input_infos, get_onnx_output_infos, make_onnx_dynamic_axes

# 暫時無法使用
# from .quantize import quantize, quantize_static
27 changes: 21 additions & 6 deletions capybara/onnxengine/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def get_onnx_metadata(
onnx_path: Union[str, Path],
) -> dict:
onnx_path = str(onnx_path)
sess = ort.InferenceSession(onnx_path, providers=['CPUExecutionProvider'])
sess = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
metadata = sess.get_modelmeta().custom_metadata_map
del sess
return metadata
Expand All @@ -21,16 +21,31 @@ def write_metadata_into_onnx(
onnx_path: Union[str, Path],
out_path: Union[str, Path],
drop_old_meta: bool = False,
enable_json_dumps: bool = True,
**kwargs,
):
onnx_path = str(onnx_path)
onnx_model = onnx.load(onnx_path)
meta_data = get_onnx_metadata(onnx_path) if not drop_old_meta else {}

meta_data.update({
'Date': now(fmt='%Y-%m-%d %H:%M:%S'),
**kwargs
})
meta_data.update({"Date": now(fmt="%Y-%m-%d %H:%M:%S"), **kwargs})

onnx.helper.set_model_props(onnx_model, {'metadata': json.dumps(meta_data)})
onnx.helper.set_model_props(
onnx_model,
{k: json.dumps(v) if enable_json_dumps else v for k, v in meta_data.items()},
)
onnx.save(onnx_model, out_path)


def parse_metadata_from_onnx(
onnx_path: Union[str, Path],
enable_json_loads: bool = True,
) -> dict:
onnx_path = str(onnx_path)
sess = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
metadata = {
k: json.loads(v) if enable_json_loads else v
for k, v in sess.get_modelmeta().custom_metadata_map.items()
}
del sess
return metadata

0 comments on commit de00db1

Please sign in to comment.