From 7d967b9dbfbae401d2723ed4dd40835b5ff08171 Mon Sep 17 00:00:00 2001 From: WolframRhodium Date: Wed, 12 Apr 2023 08:15:31 +0800 Subject: [PATCH] vsncnn/vs_ncnn.cpp: add params `use_ncnn_network_format` and `ncnn_shape_hint` https://github.com/AmusementClub/vs-mlrt/issues/39 --- vsncnn/vs_ncnn.cpp | 188 +++++++++++++++++++++++++++++++++------------ 1 file changed, 141 insertions(+), 47 deletions(-) diff --git a/vsncnn/vs_ncnn.cpp b/vsncnn/vs_ncnn.cpp index 8b06821..c37c111 100644 --- a/vsncnn/vs_ncnn.cpp +++ b/vsncnn/vs_ncnn.cpp @@ -35,6 +35,18 @@ extern std::variant loadONNX( ) noexcept; +#ifdef _WIN32 +#include +#include +static inline std::wstring translateName(const char *name) noexcept { + std::wstring_convert> converter; + return converter.from_bytes(name); +} +#else +#define translateName(n) (n) +#endif + + static const VSPlugin * myself = nullptr; @@ -497,6 +509,31 @@ static void VS_CC vsNcnnCreate( path_is_serialization = false; } + bool use_ncnn_network_format = !!vsapi->propGetInt(in, "use_ncnn_network_format", 0, &error); + if (error) { + use_ncnn_network_format = false; + } + if (use_ncnn_network_format && path_is_serialization) { + return set_error( + "\"use_ncnn_network_format\" and \"path_is_serialization\" " + "should not be enabled at the same time" + ); + } + + // ncnn related code + if (auto device = ncnn::get_gpu_device(device_id); device != nullptr) { + d->device = device; + } else { + return set_error("get_gpu_device failed"); + } + + d->net.opt.num_threads = 1; + d->net.opt.use_vulkan_compute = true; + d->net.opt.use_fp16_packed = d->fp16; + d->net.opt.use_fp16_storage = d->fp16; + d->net.opt.use_int8_storage = false; + d->net.set_vulkan_device(d->device); + std::string_view path_view; std::string path; if (path_is_serialization) { @@ -518,22 +555,109 @@ static void VS_CC vsNcnnCreate( path_view = path; } - auto result = loadONNX(path_view, tile_w, tile_h, path_is_serialization); - if (std::holds_alternative(result)) { - return set_error(std::get(result)); - } + if (use_ncnn_network_format) { + if (vsapi->propNumElements(in, "ncnn_shape_hint") != 6) { + return set_error("\"ncnn_shape_hint\" must be specified as [in_c, in_h, in_w, out_c, out_h, out_w]"); + } + + auto ncnn_shape_hint = vsapi->propGetIntArray(in, "ncnn_shape_hint", nullptr); + d->in_tile_c = int64ToIntS(ncnn_shape_hint[0]); + d->in_tile_h = int64ToIntS(ncnn_shape_hint[1]); + d->in_tile_w = int64ToIntS(ncnn_shape_hint[2]); + d->out_tile_c = int64ToIntS(ncnn_shape_hint[3]); + d->out_tile_h = int64ToIntS(ncnn_shape_hint[4]); + d->out_tile_w = int64ToIntS(ncnn_shape_hint[5]); + + auto dot_index = static_cast(path_view.size() - 1); + while (dot_index >= 0 && path_view[dot_index] != '.') { + dot_index--; + } + if (dot_index < 0) { + return set_error("invalid \"network_path\""); + } + + auto temp = std::string{path_view.substr(0, dot_index + 1)} + "param"; + char * ncnn_bin; + + { + std::ifstream param_stream( + translateName(temp.c_str()), + std::ios::binary | std::ios::ate + ); + + if (!param_stream.good()) { + return set_error("open param failed"); + } + + auto size = param_stream.tellg(); + ncnn_bin = reinterpret_cast(vs_aligned_malloc(size, sizeof(void *))); + param_stream.seekg(0); + param_stream.read(ncnn_bin, size); + } + + if (d->net.load_param_mem(ncnn_bin) != 0) { + vs_aligned_free(ncnn_bin); + return set_error("load param failed"); + } + + vs_aligned_free(ncnn_bin); + + temp = std::string{path_view.substr(0, dot_index + 1)} + "bin"; + + { + std::ifstream bin_stream( + translateName(temp.c_str()), + std::ios::binary | std::ios::ate + ); + + if (!bin_stream.good()) { + return set_error("open weights failed"); + } + + auto size = bin_stream.tellg(); + ncnn_bin = reinterpret_cast(vs_aligned_malloc(size, sizeof(void *))); + bin_stream.seekg(0); + bin_stream.read(ncnn_bin, size); - auto onnx_model = std::move(std::get(result)); - { - const auto & input_shape = onnx_model.graph().input(0).type().tensor_type().shape(); - d->in_tile_c = int64ToIntS(input_shape.dim(1).dim_value()); - d->in_tile_h = int64ToIntS(input_shape.dim(2).dim_value()); - d->in_tile_w = int64ToIntS(input_shape.dim(3).dim_value()); - - const auto & output_shape = onnx_model.graph().output(0).type().tensor_type().shape(); - d->out_tile_c = int64ToIntS(output_shape.dim(1).dim_value()); - d->out_tile_h = int64ToIntS(output_shape.dim(2).dim_value()); - d->out_tile_w = int64ToIntS(output_shape.dim(3).dim_value()); + d->net.load_model(reinterpret_cast(ncnn_bin)); + } + + vs_aligned_free(ncnn_bin); + } else { + auto result = loadONNX(path_view, tile_w, tile_h, path_is_serialization); + if (std::holds_alternative(result)) { + return set_error(std::get(result)); + } + + auto onnx_model = std::move(std::get(result)); + { + const auto & input_shape = onnx_model.graph().input(0).type().tensor_type().shape(); + d->in_tile_c = int64ToIntS(input_shape.dim(1).dim_value()); + d->in_tile_h = int64ToIntS(input_shape.dim(2).dim_value()); + d->in_tile_w = int64ToIntS(input_shape.dim(3).dim_value()); + + const auto & output_shape = onnx_model.graph().output(0).type().tensor_type().shape(); + d->out_tile_c = int64ToIntS(output_shape.dim(1).dim_value()); + d->out_tile_h = int64ToIntS(output_shape.dim(2).dim_value()); + d->out_tile_w = int64ToIntS(output_shape.dim(3).dim_value()); + } + + auto ncnn_result = onnx2ncnn(onnx_model); + if (!ncnn_result.has_value()) { + return set_error("onnx2ncnn failed"); + } + + const auto & [ncnn_param, ncnn_model_bin] = ncnn_result.value(); + + if (d->net.load_param_mem(ncnn_param) != 0) { + vs_aligned_free(ncnn_param); + vs_aligned_free(ncnn_model_bin); + return set_error("load param failed"); + } + vs_aligned_free(ncnn_param); + // TODO: here returns the number of bytes read successfully + d->net.load_model(ncnn_model_bin); + vs_aligned_free(ncnn_model_bin); } d->out_vi = std::make_unique(*in_vis.front()); // mutable @@ -546,38 +670,6 @@ static void VS_CC vsNcnnCreate( d->out_vi->format = vsapi->registerFormat(cmRGB, stFloat, 32, 0, 0, core); } - auto ncnn_result = onnx2ncnn(onnx_model); - if (!ncnn_result.has_value()) { - return set_error("onnx2ncnn failed"); - } - - const auto & [ncnn_param, ncnn_model_bin] = ncnn_result.value(); - - // ncnn related code - if (auto device = ncnn::get_gpu_device(device_id); device != nullptr) { - d->device = device; - } else { - vs_aligned_free(ncnn_param); - vs_aligned_free(ncnn_model_bin); - return set_error("get_gpu_device failed"); - } - - d->net.opt.num_threads = 1; - d->net.opt.use_vulkan_compute = true; - d->net.opt.use_fp16_packed = d->fp16; - d->net.opt.use_fp16_storage = d->fp16; - d->net.opt.use_int8_storage = false; - d->net.set_vulkan_device(d->device); - if (d->net.load_param_mem(ncnn_param) != 0) { - vs_aligned_free(ncnn_param); - vs_aligned_free(ncnn_model_bin); - return set_error("load param failed"); - } - vs_aligned_free(ncnn_param); - // TODO: here returns the number of bytes read successfully - d->net.load_model(ncnn_model_bin); - vs_aligned_free(ncnn_model_bin); - d->input_index = d->net.input_indexes().front(); d->output_index = d->net.output_indexes().front(); @@ -634,6 +726,8 @@ VS_EXTERNAL_API(void) VapourSynthPluginInit( "builtindir:data:opt;" "fp16:int:opt;" "path_is_serialization:int:opt;" + "use_ncnn_network_format:int:opt;" + "ncnn_shape_hint:int[]:opt;" , vsNcnnCreate, nullptr, plugin