diff --git a/audio_separator/separator/architectures/mdx_separator.py b/audio_separator/separator/architectures/mdx_separator.py index f79ac69..9313151 100644 --- a/audio_separator/separator/architectures/mdx_separator.py +++ b/audio_separator/separator/architectures/mdx_separator.py @@ -1,7 +1,9 @@ """Module for separating audio sources using MDX architecture models.""" import os +import platform import torch +import onnx import onnxruntime as ort import numpy as np import onnx2torch @@ -121,7 +123,12 @@ def load_model(self): self.model_run = lambda spek: ort_inference_session.run(None, {"input": spek.cpu().numpy()})[0] self.logger.debug("Model loaded successfully using ONNXruntime inferencing session.") else: - self.model_run = onnx2torch.convert(self.model_path) + if platform.system() == 'Windows': + onnx_model = onnx.load(self.model_path) + self.model_run = onnx2torch.convert(onnx_model) + else: + self.model_run = onnx2torch.convert(self.model_path) + self.model_run.to(self.torch_device).eval() self.logger.warning("Model converted from onnx to pytorch due to segment size not matching dim_t, processing may be slower.")