diff --git a/OnnxBridge/utils/optimizations.py b/OnnxBridge/utils/optimizations.py index 39d761b4..b24f88f7 100644 --- a/OnnxBridge/utils/optimizations.py +++ b/OnnxBridge/utils/optimizations.py @@ -26,7 +26,17 @@ def numpy_float_array_to_float_val_str_nchw(input_array): def numpy_float_array_to_float_val_str_nhwc(input_array): chunk = [] - if len(input_array.shape) == 4: + if len(input_array.shape) == 5: + co, ci, d, h, w = input_array.shape + arr = np.zeros([co, d, h, w, ci]) + for i in range(co): + for j in range(ci): + for k in range(d): + for l in range(h): + for m in range(w): + arr[i][k][l][m][j] = input_array[i][j][k][l][m] + input_array = arr + elif len(input_array.shape) == 4: co, ci, h, w = input_array.shape arr = np.zeros([co, h, w, ci]) for i in range(co):