From 3ad9e612834c78c8108ec19abd7f05e47afd30df Mon Sep 17 00:00:00 2001 From: Andrew Beveridge Date: Thu, 25 Apr 2024 16:25:21 -0400 Subject: [PATCH] Attempt to fix windows MDX ONNX loading tempfile bug in onnx2torch --- audio_separator/separator/architectures/mdx_separator.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/audio_separator/separator/architectures/mdx_separator.py b/audio_separator/separator/architectures/mdx_separator.py index f79ac69..9313151 100644 --- a/audio_separator/separator/architectures/mdx_separator.py +++ b/audio_separator/separator/architectures/mdx_separator.py @@ -1,7 +1,9 @@ """Module for separating audio sources using MDX architecture models.""" import os +import platform import torch +import onnx import onnxruntime as ort import numpy as np import onnx2torch @@ -121,7 +123,12 @@ def load_model(self): self.model_run = lambda spek: ort_inference_session.run(None, {"input": spek.cpu().numpy()})[0] self.logger.debug("Model loaded successfully using ONNXruntime inferencing session.") else: - self.model_run = onnx2torch.convert(self.model_path) + if platform.system() == 'Windows': + onnx_model = onnx.load(self.model_path) + self.model_run = onnx2torch.convert(onnx_model) + else: + self.model_run = onnx2torch.convert(self.model_path) + self.model_run.to(self.torch_device).eval() self.logger.warning("Model converted from onnx to pytorch due to segment size not matching dim_t, processing may be slower.")