From b4b188f81d3223c58f01d29e97d03ff00eecd111 Mon Sep 17 00:00:00 2001 From: kunkunlin Date: Tue, 14 Jan 2025 16:25:10 +0800 Subject: [PATCH 1/6] [A] add io_binding mode for onnxengine with cuda graph on --- capybara/onnxengine/__init__.py | 1 + capybara/onnxengine/engine_io_binding.py | 179 ++++++++++++++++++++ tests/onnxruntime/test_engine.py | 14 ++ tests/onnxruntime/test_engine_io_binding.py | 15 ++ tests/resources/model.onnx | Bin 0 -> 11493 bytes 5 files changed, 209 insertions(+) create mode 100644 capybara/onnxengine/engine_io_binding.py create mode 100644 tests/onnxruntime/test_engine.py create mode 100644 tests/onnxruntime/test_engine_io_binding.py create mode 100644 tests/resources/model.onnx diff --git a/capybara/onnxengine/__init__.py b/capybara/onnxengine/__init__.py index 65fc8fe..e6fc481 100644 --- a/capybara/onnxengine/__init__.py +++ b/capybara/onnxengine/__init__.py @@ -1,4 +1,5 @@ from .engine import Backend, ONNXEngine +from .engine_io_binding import ONNXEngineIOBinding from .metadata import get_onnx_metadata, write_metadata_into_onnx from .tools import get_onnx_input_infos, get_onnx_output_infos diff --git a/capybara/onnxengine/engine_io_binding.py b/capybara/onnxengine/engine_io_binding.py new file mode 100644 index 0000000..7833588 --- /dev/null +++ b/capybara/onnxengine/engine_io_binding.py @@ -0,0 +1,179 @@ +from enum import Enum +from pathlib import Path +from typing import Any, Dict, Union + +import colored +import numpy as np +import onnxruntime as ort + +from .metadata import get_onnx_metadata +from .tools import get_onnx_input_infos, get_onnx_output_infos + + +class ONNXEngineIOBinding: + + def __init__( + self, + model_path: Union[str, Path], + input_initializer: Dict[str, np.ndarray], + gpu_id: int = 0, + session_option: Dict[str, Any] = {}, + provider_option: Dict[str, Any] = {}, + ): + """ + Initialize an ONNX model inference engine. + + Args: + model_path (Union[str, Path]): + Filename or serialized ONNX or ORT format model in a byte string. + gpu_id (int, optional): + GPU ID. Defaults to 0. + session_option (Dict[str, Any], optional): + Session options. Defaults to {}. + provider_option (Dict[str, Any], optional): + Provider options. Defaults to {}. 
+ """ + self.device_id = gpu_id + providers = ['CUDAExecutionProvider'] + provider_options = [ + { + 'device_id': self.device_id, + 'cudnn_conv_use_max_workspace': '1', + 'enable_cuda_graph': '1', + **provider_option, + } + ] + + # setting session options + sess_options = self._get_session_info(session_option) + + # setting onnxruntime session + model_path = str(model_path) if isinstance(model_path, Path) else model_path + self.sess = ort.InferenceSession( + model_path, + sess_options=sess_options, + providers=providers, + provider_options=provider_options, + ) + + # setting onnxruntime session info + self.model_path = model_path + self.metadata = get_onnx_metadata(model_path) + self.providers = self.sess.get_providers() + self.provider_options = self.sess.get_provider_options() + + input_infos, output_infos = self._init_io_infos(model_path, input_initializer) + + io_binding, x_ortvalues, y_ortvalues = self._setup_io_binding(input_infos, output_infos) + self.io_binding = io_binding + self.x_ortvalues = x_ortvalues + self.y_ortvalues = y_ortvalues + self.input_infos = input_infos + self.output_infos = output_infos + # # Pass gpu_graph_id to RunOptions through RunConfigs + # ro = ort.RunOptions() + # # gpu_graph_id is optional if the session uses only one cuda graph + # ro.add_run_config_entry("gpu_graph_id", "1") + # self.run_option = ro + + def __call__(self, **xs) -> Dict[str, np.ndarray]: + self._update_x_ortvalues(xs) + # self.sess.run_with_iobinding(self.io_binding, self.run_option) + self.sess.run_with_iobinding(self.io_binding) + return {k: v.numpy() for k, v in self.y_ortvalues.items()} + + def _get_session_info( + self, + session_option: Dict[str, Any] = {}, + ) -> ort.SessionOptions: + """ + Ref: https://onnxruntime.ai/docs/api/python/api_summary.html#sessionoptions + """ + sess_opt = ort.SessionOptions() + session_option_default = { + 'graph_optimization_level': ort.GraphOptimizationLevel.ORT_ENABLE_ALL, + 'log_severity_level': 2, + } + session_option_default.update(session_option) + for k, v in session_option_default.items(): + setattr(sess_opt, k, v) + return sess_opt + + def _init_io_infos(self, model_path, input_initializer: dict): + sess = ort.InferenceSession( + model_path, + providers=['CPUExecutionProvider'], + ) + outs = sess.run(None, input_initializer) + input_shapes = {k: v.shape for k, v in input_initializer.items()} + output_shapes = {x.name: o.shape for x, o in zip(sess.get_outputs(), outs)} + input_infos = get_onnx_input_infos(model_path) + output_infos = get_onnx_output_infos(model_path) + for k, v in input_infos.items(): + v['shape'] = input_shapes[k] + for k, v in output_infos.items(): + v['shape'] = output_shapes[k] + del sess + return input_infos, output_infos + + def _setup_io_binding(self, input_infos, output_infos): + x_ortvalues = {} + y_ortvalues = {} + for k, v in input_infos.items(): + m = np.zeros(**v) + x_ortvalues[k] = ort.OrtValue.ortvalue_from_numpy(m, device_type='cuda', device_id=self.device_id) + for k, v in output_infos.items(): + m = np.zeros(**v) + y_ortvalues[k] = ort.OrtValue.ortvalue_from_numpy(m, device_type='cuda', device_id=self.device_id) + + io_binding = self.sess.io_binding() + for k, v in x_ortvalues.items(): + io_binding.bind_ortvalue_input(k, v) + for k, v in y_ortvalues.items(): + io_binding.bind_ortvalue_output(k, v) + + return io_binding, x_ortvalues, y_ortvalues + + def _update_x_ortvalues(self, xs: dict): + for k, v in self.x_ortvalues.items(): + v.update_inplace(xs[k]) + + def __repr__(self) -> str: + def 
format_nested_dict(dict_data, indent=0): + info = "" + for k, v in dict_data.items(): + prefix = " " * indent + if isinstance(v, dict): + info += f"{prefix}{k}:\n" + format_nested_dict(v, indent + 1) + elif isinstance(v, str) and v.startswith('{') and v.endswith('}'): + try: + nested_dict = eval(v) + if isinstance(nested_dict, dict): + info += f"{prefix}{k}:\n" + format_nested_dict(nested_dict, indent + 1) + else: + info += f"{prefix}{k}: {v}\n" + except: + info += f"{prefix}{k}: {v}\n" + else: + info += f"{prefix}{k}: {v}\n" + return info + + title = 'DOCSAID X ONNXRUNTIME' + styled_title = colored.stylize( + title, [colored.fg('blue'), colored.attr('bold')]) + divider_length = 50 + title_length = len(title) + left_padding = (divider_length - title_length) // 2 + right_padding = divider_length - title_length - left_padding + + path = f'Model Path: {self.model_path}' + input_info = format_nested_dict(self.input_infos) + output_info = format_nested_dict(self.output_infos) + metadata = format_nested_dict(self.metadata) + providers = f'Provider: {", ".join(self.providers)}' + provider_options = format_nested_dict(self.provider_options) + + divider = colored.stylize( + f"+{'-' * divider_length}+", [colored.fg('blue'), colored.attr('bold')]) + infos = f'\n\n{divider}\n|{" " * left_padding}{styled_title}{" " * right_padding}|\n{divider}\n\n{path}\n\n{input_info}\n{output_info}\n\n{metadata}\n\n{providers}\n\n{provider_options}\n{divider}' + return infos diff --git a/tests/onnxruntime/test_engine.py b/tests/onnxruntime/test_engine.py new file mode 100644 index 0000000..ec2637a --- /dev/null +++ b/tests/onnxruntime/test_engine.py @@ -0,0 +1,14 @@ +import numpy as np + +from capybara import ONNXEngine, Timer + + +def test_ONNXEngine(): + model_path = "tests/resources/model.onnx" + engine = ONNXEngine(model_path, backend='cuda') + for i in range(30): + xs = {'inputs': np.random.randn(32, 3, 640, 640).astype('float32')} + outs = engine(**xs) + if i: + assert not np.allclose(outs['outputs'], prev_outs['outputs']) + prev_outs = outs diff --git a/tests/onnxruntime/test_engine_io_binding.py b/tests/onnxruntime/test_engine_io_binding.py new file mode 100644 index 0000000..e3ba003 --- /dev/null +++ b/tests/onnxruntime/test_engine_io_binding.py @@ -0,0 +1,15 @@ +import numpy as np + +from capybara import ONNXEngine, ONNXEngineIOBinding, Timer + + +def test_ONNXEngineIOBinding(): + model_path = "tests/resources/model.onnx" + input_initializer = {'inputs': np.random.randn(32, 3, 640, 640).astype('float32')} + engine = ONNXEngineIOBinding(model_path, input_initializer) + for i in range(30): + xs = {'inputs': np.random.randn(32, 3, 640, 640).astype('float32')} + outs = engine(**xs) + if i: + assert not np.allclose(outs['outputs'], prev_outs['outputs']) + prev_outs = outs diff --git a/tests/resources/model.onnx b/tests/resources/model.onnx new file mode 100644 index 0000000000000000000000000000000000000000..a843355e53f947a12d4add6dbbd924bf20dea0e0 GIT binary patch literal 11493 zcmZ9S2{2c0^#ASqP6%aRvSeRB_dM2UFNGHEt5Q<5(Sm5Plu(wEq(#V5Nqp{kvb9OH zD5w~xL3J7xIb>h9uD{LqsCb)IauDyZElUIC(Q^X|vH}479ykl*U-}=DKPW(>% z{~5`O2?Va+^gna{cb<%xa8U4`ZJRa+Iq^C1O&ZCG?OeYtVC~jD>jV7^A9{&w65;>v zkl}UtzmGPz(90n+95?j}lSc-BkRl>3)_D33&|BPhFYZLm)ZHQ@0zla|tlxW(& zndD$rE~qV9jtZ;`E6ABZujoW{CAp8UBQqGpMco1d6xPA_wl$uh%c5C`=qdL8* zUCETJ&co?^W$g2N@9>Gi26kG`TNIj?1##z`Xh!*K+~j6WcGkUTZ#_SOY~57Mj{X3D ztL4bu&xxR!_8uIc)?q@J9bC}VC2#rl=)qHJ#E)+pS+dN42ya$H;cY{Z7x5hh1d!!h z+{_4nl}DE!_u*{g9}tkM2EJb!L)#K?>VCtBhy@KY>$NCuls|=eGn9z(nQy3UoynRv`r>gXZSu0+ 
z3`d)E5qe)soS>#PS+DCvSC&P>{Az2G<0DP#yE#y<>qzEJa`-2sGhorxXn34c%gjDv zMRapT=v+pd)Vg*+{y!UH@#HXHGJaOlbnjp(1Zp{x$dWghyzK^jN(IsSRbF zHTkIXw=9eg`8C=3>%`sFxiD)&obCy9p>v~E$lNj;^!2kPr$_x*C2d>!_337yjka{Y ztRr32?G2(aTj4GnjJLBzNr|->t(pw$3sYxuvX8og&UGacoF_xH=i5$lSYZ?%E(C)o zF<@8~4_Ay$XpDy*2^4f98NwFeJzfvH>m=x)Y6nbco71m~fmkx@JPut6#0~j=Ak1$_ zg;q*{XGEh{Rx{Gac4Ei6N&X@c2F2GSLGdlZFTD>iZkB?-c^lX~by>9TGokO8 zYgjgQ2yBk3&?GlC^7fA+tXeh(Z!Xq@;p|Ir?4lXb9iBnN7ATP}d3p9zhB(oTP@^M) zmpDAn8!}LtLW}D;lW! z5l(DqfhYZG%(;3gdLvVcQFF>b@usKPn4AQMKfUBESnG|k>I=BW0!}2fe-)aEXS3wi zc^F<;2ITS&tSs1pojc5lzQSZ5Q+Wmn&31*6gd%7hlcAybo3VYZf>KMTP{FZZm@<1N zWPDA6$Bdc)Np6F;Y=AUtsi@ZELm}dX&*2n{)1nb z&!7Hb;HtBbTP;f?PUkY-7iN-@Y14_zA}R9drYbY_mLqP}Pe2~=YxtSZ1q9Y}Xz6$r z`%Xw5c6`f68{Scv_oIq=`%Q~_A8@0t6Cq%eHW?$LZ0J$H2(;9H55HVGpzd%E`b;v< zWMOIYapg)T$~YYCW=^O4wp?(ZGlVC)k3qsBe)xJk0tw$!;Nt~IdYH&;F8#u8aTDa` zPPq)6mHH6c(~8#CQjm6KI$br9g&J;hSp12XJhIfqj%W6yxz!)oK_!|gZAIH2%%I!0 zyum%YZa|-62ODs?oee*C8~RG8lA=s2`fqz2bi7j|Y84Y`-lhsQJ6&ko+*X{olA&T( z+$fvz8X~DHmG&BChh$T*s#>3p@4O9jx7~)!n=-iTMFjJqoS!bqu^^lNPR7#gZ1$z2 zAXQ08hG54w+|{B&Q`RH`oSI7OT<@Tur4I}ZNuwULV%;1Erta7x(0m|*Z?=bk+KLV6 z{@@NqWU11%-7m2Ifil(GSPT1ugqd|_ee9MPG0yHC79{38A6+b|fV)RD(N~W{PqMS{ z!dMm7PE{pi2Xv{%DLu-wV>$5KRx{5$Cb{)x%1NvJ0Y|I&sOExX_Sse$I+k*uz0*2| zRuUV4_0vaRfp+%Wzqeo$n}q%$l~~|RXv8L7HkyAHczjSH5y2Adyr@$+)zgN`FPjS^ zhIcu+ZK8AqpEsK_&ljFr7Nc*)IB+kjFiRBH!CEjR;w%2ZK@AO>{E&~n$!=k`kLnPO z-GXGSr3r@y1&DvT6XxG|%EUAuMWumT_^2S23As52N|XD-d>=2A(K4YyGxuTAuW5A1 z_cHTne3EkoX~Lp^Mx+l3IU6ZQ|5!Z-QmIRBmKjsg0DXFNX)aiGkAkukheYvM(3t6t zbT@ky8hEBt#YrvM+RDpN6re*)W zapew%$OYk>rZldh_&9z{m7*fTCXlrCOpVNX1I&EE$u z2l-L{oj!GmJb_o{>(hRCvfPf~;UQyMwZV>ZXZ^5>5~aHDaLfM;qZeU9c}jhOz4-!ur5REqO--8l$DRaKyuiHmS@3%P8HnCb zxjV9};Q7`_;H&EenLzuM?V&2=QP}qBh?KEm-)BB`IbgDqs0mD*oR6LkoDey4mZyvfx9&*Uym+P5-w&H{#fD4US1+s;Q?y^ zh)#t({XOJ_E{q|m%NL;&RGn%{h2kocO#Jre33@HzfPbe1IBmI(Gt@0fk=`#Tm|}%B zMyD}E(tvjCb0A%tM$yG1kIB?FC$m-w;p-VWAm;uRU6Y5fKK~eM7*Uj%KdIZZyTk}Q z7NlN0{GeVYfDOXUsClatA>u9LSS(A%Hk@|@X5&jA7eySr2qf` literal 0 HcmV?d00001 From f6e3b782889844902bab47ddd6708d98e97b8066 Mon Sep 17 00:00:00 2001 From: zephyr-sh Date: Tue, 14 Jan 2025 20:03:16 +0800 Subject: [PATCH 2/6] [C] Update github workflow --- .github/workflows/pull_request.yml | 103 ++++++++++++++++++++--------- docker/pr.dockerfile | 17 +++++ 2 files changed, 88 insertions(+), 32 deletions(-) create mode 100644 docker/pr.dockerfile diff --git a/.github/workflows/pull_request.yml b/.github/workflows/pull_request.yml index 7b6b77a..a4724c7 100644 --- a/.github/workflows/pull_request.yml +++ b/.github/workflows/pull_request.yml @@ -1,57 +1,96 @@ name: Pull Request on: - workflow_dispatch: pull_request: - branches: - - main + branches: [main] + paths-ignore: + - "**/version.h" + - "doc/**" + - "**.md" + +env: + LOCAL_REGISTRY: localhost:5000 + DOCKERFILE: docker/pr.dockerfile jobs: - test: - name: Run Tests + get_runner_and_uid: + name: Get Runner runs-on: [self-hosted, unicorn] + steps: + - name: Get UID and GID + id: uid_gid + run: | + echo "uid_gid=$(id -u):$(id -g)" >> "$GITHUB_OUTPUT" + outputs: + runner: ${{ runner.name }} + uid: ${{ steps.uid_gid.outputs.uid_gid }} + + build_docker_image: + name: Build Docker Image + needs: [get_runner_and_uid] + runs-on: ${{ needs.get_runner_and_uid.outputs.runner }} + + steps: + - name: Checkout Repository + uses: actions/checkout@v4 + + - name: Build Docker Image + id: docker_build + uses: docker/build-push-action@v3 + with: + file: ${{ env.DOCKERFILE }} + no-cache: false + push: true + tags: ${{ env.LOCAL_REGISTRY }}/capybara-container:ci + + outputs: + 
image: ${{ env.LOCAL_REGISTRY }}/capybara-container:ci + + ci: + name: CI + needs: [get_runner_and_uid, build_docker_image] + runs-on: ${{ needs.get_runner_and_uid.outputs.runner }} strategy: matrix: python-version: - "3.10" + container: + image: ${{ needs.build_docker_image.outputs.image }} + options: --user ${{ needs.get_runner_and_uid.outputs.uid }} --gpus all + steps: - - name: Checkout code + - name: Checkout Repository uses: actions/checkout@v4 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v5 - with: - python-version: ${{ matrix.python-version }} - - - name: Install packages + - name: Install Dependencies run: | - python -m pip install pytest wheel pylint pylint-flask - python setup.py bdist_wheel - wheel_file=$(echo dist/*.whl) - python -m pip install ${wheel_file} --force-reinstall + python3 -m pip install pytest wheel pylint pylint-flask pytest-cov typeguard - - name: Lint with pylint + - name: Build and Install Package run: | - python -m pylint ${{ github.workspace }}/capybara \ - --rcfile=.github/workflows/.pylintrc \ - --load-plugins pylint_flask \ + python3 setup.py bdist_wheel && \ + wheel_file=$(ls dist/*.whl 2>/dev/null || echo '') && \ + if [ -z "$wheel_file" ]; then + echo 'Error: No wheel file found in dist directory.' && exit 1 + fi && \ + python3 -m pip install $wheel_file --force-reinstall - - name: Test with pytest + - name: Lint with Pylint run: | - python -m pip install pytest pytest-cov typeguard + python3 -m pylint capybara \ + --rcfile=.github/workflows/.pylintrc \ + --load-plugins pylint_flask - # Test all - python -m pytest tests - - # Report all - mkdir -p tests/coverage - python -m pytest -x \ - --junitxml=tests/coverage/cov-jumitxml.xml \ - --cov=capybara tests | tee tests/coverage/cov.txt + - name: Run Tests with Pytest + run: | + mkdir -p tests/coverage && \ + python3 -m pytest tests --junitxml=tests/coverage/cov-junitxml.xml \ + --cov=capybara | tee tests/coverage/cov.txt - - name: Pytest coverage comment + - name: Pytest Coverage Comment id: coverageComment uses: MishaKav/pytest-coverage-comment@main with: + github-token: ${{ secrets.GITHUB_TOKEN }} pytest-coverage-path: tests/coverage/cov.txt - junitxml-path: tests/coverage/cov-jumitxml.xml + junitxml-path: tests/coverage/cov-junitxml.xml diff --git a/docker/pr.dockerfile b/docker/pr.dockerfile new file mode 100644 index 0000000..14e2aa0 --- /dev/null +++ b/docker/pr.dockerfile @@ -0,0 +1,17 @@ +# syntax=docker/dockerfile:experimental +FROM nvcr.io/nvidia/cuda:12.4.1-cudnn-runtime-ubuntu22.04 + +ENV PYTHONDONTWRITEBYTECODE=1 \ + DEBIAN_FRONTEND=noninteractive \ + TZ=Asia/Taipei + +RUN apt-get update -y && apt-get upgrade -y && \ + apt-get install -y --no-install-recommends \ + tzdata wget git libturbojpeg exiftool ffmpeg poppler-utils libpng-dev \ + libtiff5-dev libjpeg8-dev libopenjp2-7-dev zlib1g-dev gcc \ + libfreetype6-dev liblcms2-dev libwebp-dev tcl8.6-dev tk8.6-dev python3-tk \ + python3-pip libharfbuzz-dev libfribidi-dev libxcb1-dev libfftw3-dev gosu \ + libpq-dev python3-dev && \ + ln -sf /usr/share/zoneinfo/$TZ /etc/localtime && \ + dpkg-reconfigure -f noninteractive tzdata && \ + apt-get clean && rm -rf /var/lib/apt/lists/* From b767d2342e11075e35e349bcec2229f8ccf5916a Mon Sep 17 00:00:00 2001 From: zephyr-sh Date: Wed, 15 Jan 2025 09:34:59 +0800 Subject: [PATCH 3/6] [C] Update `with-statement` in `ONNXEngineIOBinding._init_io_infos` --- capybara/onnxengine/engine_io_binding.py | 48 ++++++++++++++++-------- 1 file changed, 32 insertions(+), 16 deletions(-) 
diff --git a/capybara/onnxengine/engine_io_binding.py b/capybara/onnxengine/engine_io_binding.py index 7833588..06921e6 100644 --- a/capybara/onnxengine/engine_io_binding.py +++ b/capybara/onnxengine/engine_io_binding.py @@ -48,7 +48,8 @@ def __init__( sess_options = self._get_session_info(session_option) # setting onnxruntime session - model_path = str(model_path) if isinstance(model_path, Path) else model_path + model_path = str(model_path) if isinstance( + model_path, Path) else model_path self.sess = ort.InferenceSession( model_path, sess_options=sess_options, @@ -62,9 +63,11 @@ def __init__( self.providers = self.sess.get_providers() self.provider_options = self.sess.get_provider_options() - input_infos, output_infos = self._init_io_infos(model_path, input_initializer) + input_infos, output_infos = self._init_io_infos( + model_path, input_initializer) - io_binding, x_ortvalues, y_ortvalues = self._setup_io_binding(input_infos, output_infos) + io_binding, x_ortvalues, y_ortvalues = self._setup_io_binding( + input_infos, output_infos) self.io_binding = io_binding self.x_ortvalues = x_ortvalues self.y_ortvalues = y_ortvalues @@ -100,20 +103,28 @@ def _get_session_info( return sess_opt def _init_io_infos(self, model_path, input_initializer: dict): - sess = ort.InferenceSession( - model_path, - providers=['CPUExecutionProvider'], - ) - outs = sess.run(None, input_initializer) - input_shapes = {k: v.shape for k, v in input_initializer.items()} - output_shapes = {x.name: o.shape for x, o in zip(sess.get_outputs(), outs)} + + try: + with ort.InferenceSession(model_path, providers=['CPUExecutionProvider']) as cpu_sess: + outs = cpu_sess.run(None, input_initializer) + input_shapes = {k: v.shape for k, + v in input_initializer.items()} + output_shapes = { + x.name: o.shape + for x, o in zip(cpu_sess.get_outputs(), outs) + } + except Exception as e: + raise RuntimeError(f"Failed to run CPU check session: {str(e)}") + input_infos = get_onnx_input_infos(model_path) output_infos = get_onnx_output_infos(model_path) + for k, v in input_infos.items(): v['shape'] = input_shapes[k] + for k, v in output_infos.items(): v['shape'] = output_shapes[k] - del sess + return input_infos, output_infos def _setup_io_binding(self, input_infos, output_infos): @@ -121,10 +132,12 @@ def _setup_io_binding(self, input_infos, output_infos): y_ortvalues = {} for k, v in input_infos.items(): m = np.zeros(**v) - x_ortvalues[k] = ort.OrtValue.ortvalue_from_numpy(m, device_type='cuda', device_id=self.device_id) + x_ortvalues[k] = ort.OrtValue.ortvalue_from_numpy( + m, device_type='cuda', device_id=self.device_id) for k, v in output_infos.items(): m = np.zeros(**v) - y_ortvalues[k] = ort.OrtValue.ortvalue_from_numpy(m, device_type='cuda', device_id=self.device_id) + y_ortvalues[k] = ort.OrtValue.ortvalue_from_numpy( + m, device_type='cuda', device_id=self.device_id) io_binding = self.sess.io_binding() for k, v in x_ortvalues.items(): @@ -144,12 +157,14 @@ def format_nested_dict(dict_data, indent=0): for k, v in dict_data.items(): prefix = " " * indent if isinstance(v, dict): - info += f"{prefix}{k}:\n" + format_nested_dict(v, indent + 1) + info += f"{prefix}{k}:\n" + \ + format_nested_dict(v, indent + 1) elif isinstance(v, str) and v.startswith('{') and v.endswith('}'): try: nested_dict = eval(v) if isinstance(nested_dict, dict): - info += f"{prefix}{k}:\n" + format_nested_dict(nested_dict, indent + 1) + info += f"{prefix}{k}:\n" + \ + format_nested_dict(nested_dict, indent + 1) else: info += f"{prefix}{k}: {v}\n" except: @@ 
-175,5 +190,6 @@ def format_nested_dict(dict_data, indent=0): divider = colored.stylize( f"+{'-' * divider_length}+", [colored.fg('blue'), colored.attr('bold')]) - infos = f'\n\n{divider}\n|{" " * left_padding}{styled_title}{" " * right_padding}|\n{divider}\n\n{path}\n\n{input_info}\n{output_info}\n\n{metadata}\n\n{providers}\n\n{provider_options}\n{divider}' + infos = f'\n\n{divider}\n|{" " * left_padding}{styled_title}{" " * right_padding}|\n{divider}\n\n{ + path}\n\n{input_info}\n{output_info}\n\n{metadata}\n\n{providers}\n\n{provider_options}\n{divider}' return infos From 57bb7662810223a787ac3c5ce0a3d230d6a29279 Mon Sep 17 00:00:00 2001 From: zephyr-sh Date: Wed, 15 Jan 2025 09:35:16 +0800 Subject: [PATCH 4/6] [C] Update testing times in ort engine --- tests/onnxruntime/test_engine.py | 4 ++-- tests/onnxruntime/test_engine_io_binding.py | 7 ++++--- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/tests/onnxruntime/test_engine.py b/tests/onnxruntime/test_engine.py index ec2637a..ea9492a 100644 --- a/tests/onnxruntime/test_engine.py +++ b/tests/onnxruntime/test_engine.py @@ -1,12 +1,12 @@ import numpy as np -from capybara import ONNXEngine, Timer +from capybara import ONNXEngine def test_ONNXEngine(): model_path = "tests/resources/model.onnx" engine = ONNXEngine(model_path, backend='cuda') - for i in range(30): + for i in range(5): xs = {'inputs': np.random.randn(32, 3, 640, 640).astype('float32')} outs = engine(**xs) if i: diff --git a/tests/onnxruntime/test_engine_io_binding.py b/tests/onnxruntime/test_engine_io_binding.py index e3ba003..e1dcaa5 100644 --- a/tests/onnxruntime/test_engine_io_binding.py +++ b/tests/onnxruntime/test_engine_io_binding.py @@ -1,13 +1,14 @@ import numpy as np -from capybara import ONNXEngine, ONNXEngineIOBinding, Timer +from capybara import ONNXEngineIOBinding def test_ONNXEngineIOBinding(): model_path = "tests/resources/model.onnx" - input_initializer = {'inputs': np.random.randn(32, 3, 640, 640).astype('float32')} + input_initializer = {'inputs': np.random.randn( + 32, 3, 640, 640).astype('float32')} engine = ONNXEngineIOBinding(model_path, input_initializer) - for i in range(30): + for i in range(5): xs = {'inputs': np.random.randn(32, 3, 640, 640).astype('float32')} outs = engine(**xs) if i: From 225d50bc92a26bd95868da05c489d5541f71b6fc Mon Sep 17 00:00:00 2001 From: zephyr-sh Date: Wed, 15 Jan 2025 09:49:54 +0800 Subject: [PATCH 5/6] [F] Fixed pylint error --- .github/workflows/.pylintrc | 4 +- capybara/onnxengine/engine.py | 75 +++++++++++++++-------- capybara/onnxengine/engine_io_binding.py | 77 +++++++++++++++--------- 3 files changed, 102 insertions(+), 54 deletions(-) diff --git a/.github/workflows/.pylintrc b/.github/workflows/.pylintrc index 12ccc22..1abcb48 100644 --- a/.github/workflows/.pylintrc +++ b/.github/workflows/.pylintrc @@ -535,5 +535,5 @@ preferred-modules= # Exceptions that will emit a warning when being caught. Defaults to # "BaseException, Exception". 
-overgeneral-exceptions=BaseException, - Exception +overgeneral-exceptions=builtins.BaseException, + builtins.Exception diff --git a/capybara/onnxengine/engine.py b/capybara/onnxengine/engine.py index 5e77157..18c7393 100644 --- a/capybara/onnxengine/engine.py +++ b/capybara/onnxengine/engine.py @@ -118,43 +118,68 @@ def _get_provider_info( return providers, provider_option def __repr__(self) -> str: + import re + + def strip_ansi_codes(text): + """Remove ANSI escape codes from a string.""" + ansi_escape = re.compile(r'\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])') + return ansi_escape.sub('', text) + def format_nested_dict(dict_data, indent=0): - info = "" - for k, v in dict_data.items(): - prefix = " " * indent - if isinstance(v, dict): - info += f"{prefix}{k}:\n" + \ - format_nested_dict(v, indent + 1) - elif isinstance(v, str) and v.startswith('{') and v.endswith('}'): + """Recursively format nested dictionaries with indentation.""" + info = [] + prefix = " " * indent + for key, value in dict_data.items(): + if isinstance(value, dict): + info.append(f"{prefix}{key}:") + info.append(format_nested_dict(value, indent + 1)) + elif isinstance(value, str) and value.startswith('{') and value.endswith('}'): try: - nested_dict = eval(v) + nested_dict = eval(value) if isinstance(nested_dict, dict): - info += f"{prefix}{k}:\n" + \ - format_nested_dict(nested_dict, indent + 1) + info.append(f"{prefix}{key}:") + info.append(format_nested_dict( + nested_dict, indent + 1)) else: - info += f"{prefix}{k}: {v}\n" - except: - info += f"{prefix}{k}: {v}\n" + info.append(f"{prefix}{key}: {value}") + except Exception: + info.append(f"{prefix}{key}: {value}") else: - info += f"{prefix}{k}: {v}\n" - return info + info.append(f"{prefix}{key}: {value}") + return "\n".join(info) title = 'DOCSAID X ONNXRUNTIME' + divider_length = 50 + divider = f"+{'-' * divider_length}+" styled_title = colored.stylize( title, [colored.fg('blue'), colored.attr('bold')]) - divider_length = 50 - title_length = len(title) - left_padding = (divider_length - title_length) // 2 - right_padding = divider_length - title_length - left_padding - path = f'Model Path: {self.model_path}' + def center_text(text, width): + """Center text within a fixed width, handling ANSI escape codes.""" + plain_text = strip_ansi_codes(text) + text_length = len(plain_text) + left_padding = (width - text_length) // 2 + right_padding = width - text_length - left_padding + return f"|{' ' * left_padding}{text}{' ' * right_padding}|" + + path = f"Model Path: {self.model_path}" input_info = format_nested_dict(self.input_infos) output_info = format_nested_dict(self.output_infos) metadata = format_nested_dict(self.metadata) - providers = f'Provider: {", ".join(self.providers)}' + providers = f"Provider: {', '.join(self.providers)}" provider_options = format_nested_dict(self.provider_options) - divider = colored.stylize( - f"+{'-' * divider_length}+", [colored.fg('blue'), colored.attr('bold')]) - infos = f'\n\n{divider}\n|{" " * left_padding}{styled_title}{" " * right_padding}|\n{divider}\n\n{path}\n\n{input_info}\n{output_info}\n\n{metadata}\n\n{providers}\n\n{provider_options}\n{divider}' - return infos + sections = [ + divider, + center_text(styled_title, divider_length), + divider, + path, + input_info, + output_info, + metadata, + providers, + provider_options, + divider, + ] + + return "\n\n".join(sections) diff --git a/capybara/onnxengine/engine_io_binding.py b/capybara/onnxengine/engine_io_binding.py index 06921e6..a90c241 100644 --- 
a/capybara/onnxengine/engine_io_binding.py +++ b/capybara/onnxengine/engine_io_binding.py @@ -1,4 +1,3 @@ -from enum import Enum from pathlib import Path from typing import Any, Dict, Union @@ -152,44 +151,68 @@ def _update_x_ortvalues(self, xs: dict): v.update_inplace(xs[k]) def __repr__(self) -> str: + import re + + def strip_ansi_codes(text): + """Remove ANSI escape codes from a string.""" + ansi_escape = re.compile(r'\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])') + return ansi_escape.sub('', text) + def format_nested_dict(dict_data, indent=0): - info = "" - for k, v in dict_data.items(): - prefix = " " * indent - if isinstance(v, dict): - info += f"{prefix}{k}:\n" + \ - format_nested_dict(v, indent + 1) - elif isinstance(v, str) and v.startswith('{') and v.endswith('}'): + """Recursively format nested dictionaries with indentation.""" + info = [] + prefix = " " * indent + for key, value in dict_data.items(): + if isinstance(value, dict): + info.append(f"{prefix}{key}:") + info.append(format_nested_dict(value, indent + 1)) + elif isinstance(value, str) and value.startswith('{') and value.endswith('}'): try: - nested_dict = eval(v) + nested_dict = eval(value) if isinstance(nested_dict, dict): - info += f"{prefix}{k}:\n" + \ - format_nested_dict(nested_dict, indent + 1) + info.append(f"{prefix}{key}:") + info.append(format_nested_dict( + nested_dict, indent + 1)) else: - info += f"{prefix}{k}: {v}\n" - except: - info += f"{prefix}{k}: {v}\n" + info.append(f"{prefix}{key}: {value}") + except Exception: + info.append(f"{prefix}{key}: {value}") else: - info += f"{prefix}{k}: {v}\n" - return info + info.append(f"{prefix}{key}: {value}") + return "\n".join(info) title = 'DOCSAID X ONNXRUNTIME' + divider_length = 50 + divider = f"+{'-' * divider_length}+" styled_title = colored.stylize( title, [colored.fg('blue'), colored.attr('bold')]) - divider_length = 50 - title_length = len(title) - left_padding = (divider_length - title_length) // 2 - right_padding = divider_length - title_length - left_padding - path = f'Model Path: {self.model_path}' + def center_text(text, width): + """Center text within a fixed width, handling ANSI escape codes.""" + plain_text = strip_ansi_codes(text) + text_length = len(plain_text) + left_padding = (width - text_length) // 2 + right_padding = width - text_length - left_padding + return f"|{' ' * left_padding}{text}{' ' * right_padding}|" + + path = f"Model Path: {self.model_path}" input_info = format_nested_dict(self.input_infos) output_info = format_nested_dict(self.output_infos) metadata = format_nested_dict(self.metadata) - providers = f'Provider: {", ".join(self.providers)}' + providers = f"Provider: {', '.join(self.providers)}" provider_options = format_nested_dict(self.provider_options) - divider = colored.stylize( - f"+{'-' * divider_length}+", [colored.fg('blue'), colored.attr('bold')]) - infos = f'\n\n{divider}\n|{" " * left_padding}{styled_title}{" " * right_padding}|\n{divider}\n\n{ - path}\n\n{input_info}\n{output_info}\n\n{metadata}\n\n{providers}\n\n{provider_options}\n{divider}' - return infos + sections = [ + divider, + center_text(styled_title, divider_length), + divider, + path, + input_info, + output_info, + metadata, + providers, + provider_options, + divider, + ] + + return "\n\n".join(sections) From 4c2279f99bead763451db807d134a66b8fca7df1 Mon Sep 17 00:00:00 2001 From: zephyr-sh Date: Wed, 15 Jan 2025 09:53:58 +0800 Subject: [PATCH 6/6] [C] Revert `_init_io_infos` --- capybara/onnxengine/engine_io_binding.py | 25 +++++++++--------------- 1 file 
changed, 9 insertions(+), 16 deletions(-) diff --git a/capybara/onnxengine/engine_io_binding.py b/capybara/onnxengine/engine_io_binding.py index a90c241..79278a7 100644 --- a/capybara/onnxengine/engine_io_binding.py +++ b/capybara/onnxengine/engine_io_binding.py @@ -102,28 +102,21 @@ def _get_session_info( return sess_opt def _init_io_infos(self, model_path, input_initializer: dict): - - try: - with ort.InferenceSession(model_path, providers=['CPUExecutionProvider']) as cpu_sess: - outs = cpu_sess.run(None, input_initializer) - input_shapes = {k: v.shape for k, - v in input_initializer.items()} - output_shapes = { - x.name: o.shape - for x, o in zip(cpu_sess.get_outputs(), outs) - } - except Exception as e: - raise RuntimeError(f"Failed to run CPU check session: {str(e)}") - + sess = ort.InferenceSession( + model_path, + providers=['CPUExecutionProvider'], + ) + outs = sess.run(None, input_initializer) + input_shapes = {k: v.shape for k, v in input_initializer.items()} + output_shapes = {x.name: o.shape for x, + o in zip(sess.get_outputs(), outs)} input_infos = get_onnx_input_infos(model_path) output_infos = get_onnx_output_infos(model_path) - for k, v in input_infos.items(): v['shape'] = input_shapes[k] - for k, v in output_infos.items(): v['shape'] = output_shapes[k] - + del sess return input_infos, output_infos def _setup_io_binding(self, input_infos, output_infos):
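
The IO-binding engine added in PATCH 1/6 is driven much like the existing ONNXEngine, except that the constructor takes an input_initializer whose shapes and dtypes stay fixed for the lifetime of the session: the engine pre-allocates GPU-resident OrtValues from those shapes and turns on CUDA graph capture, so every later call must reuse them exactly. A minimal usage sketch, mirroring tests/onnxruntime/test_engine_io_binding.py (the model path, batch size, and the 'inputs'/'outputs' names are illustrative; onnxruntime-gpu with a CUDA-capable device is assumed):

import numpy as np

from capybara import ONNXEngineIOBinding

# The shapes/dtypes given here are baked into the pre-bound OrtValues and the
# captured CUDA graph; subsequent calls must use exactly the same shapes.
input_initializer = {'inputs': np.random.randn(32, 3, 640, 640).astype('float32')}
engine = ONNXEngineIOBinding('tests/resources/model.onnx', input_initializer, gpu_id=0)

for _ in range(5):
    xs = {'inputs': np.random.randn(32, 3, 640, 640).astype('float32')}
    outs = engine(**xs)           # inputs are copied into the pre-bound GPU buffers
    print(outs['outputs'].shape)  # results come back as host-side numpy arrays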
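
PATCH 1/6 also leaves the gpu_graph_id plumbing commented out in __init__ and __call__; with a single captured CUDA graph per session the id is optional, which is why the committed code calls run_with_iobinding(self.io_binding) without run options. If a session ever had to choose between several captured graphs, the run options would need to be passed explicitly. A hypothetical sketch of that variant, based on the commented-out lines (the helper name and graph id value are illustrative):

import onnxruntime as ort

# Hypothetical re-enabling of the commented-out RunOptions path: gpu_graph_id
# selects which captured CUDA graph to replay when a session holds several.
ro = ort.RunOptions()
ro.add_run_config_entry("gpu_graph_id", "1")

def run_with_graph(sess, io_binding):
    # sess is an ort.InferenceSession; io_binding comes from sess.io_binding()
    sess.run_with_iobinding(io_binding, ro)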