Skip to content

Commit

Permalink
Fixed model architecture code import to only load the correct code fo…
Browse files Browse the repository at this point in the history
…r the model you're loading... added clear error message for anyone attempting to use Demucs with Python 3.9
  • Loading branch information
beveradb committed Mar 29, 2024
1 parent cf6b88d commit bcbbc51
Show file tree
Hide file tree
Showing 5 changed files with 174 additions and 164 deletions.
4 changes: 0 additions & 4 deletions audio_separator/separator/architectures/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +0,0 @@
from .mdx_separator import MDXSeparator
from .vr_separator import VRSeparator
from .demucs_separator import DemucsSeparator
from .mdxc_separator import MDXCSeparator
20 changes: 13 additions & 7 deletions audio_separator/separator/separator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@

from importlib import metadata
import os
import sys
import platform
import subprocess
import time
import logging
import warnings
import importlib

import hashlib
import json
Expand Down Expand Up @@ -618,18 +620,22 @@ def load_model(self, model_filename="UVR-MDX-NET-Inst_HQ_3.onnx"):
"sample_rate": self.sample_rate,
}

if model_type not in self.arch_specific_params:
raise ValueError(f"Model type not supported (yet): {model_type}")

# Instantiate the appropriate separator class depending on the model type
separator_classes = {"MDX": "MDXSeparator", "VR": "VRSeparator", "Demucs": "DemucsSeparator", "MDXC": "MDXCSeparator"}
separator_classes = {"MDX": "mdx_separator.MDXSeparator", "VR": "vr_separator.VRSeparator", "Demucs": "demucs_separator.DemucsSeparator", "MDXC": "mdxc_separator.MDXCSeparator"}

if model_type not in separator_classes:
if model_type not in self.arch_specific_params or model_type not in separator_classes:
raise ValueError(f"Model type not supported (yet): {model_type}")

module = __import__("audio_separator.separator.architectures", fromlist=[separator_classes[model_type]])
if model_type == "Demucs" and sys.version_info < (3, 10):
raise Exception("Demucs models require Python version 3.10 or newer.")

self.logger.debug(f"Importing module for model type {model_type}: {separator_classes[model_type]}")

module_name, class_name = separator_classes[model_type].split(".")
module = importlib.import_module(f"audio_separator.separator.architectures.{module_name}")
separator_class = getattr(module, class_name)

separator_class = getattr(module, separator_classes[model_type])
self.logger.debug(f"Instantiating separator class for model type {model_type}: {separator_class}")
self.model_instance = separator_class(common_config=common_params, arch_config=self.arch_specific_params[model_type])

# Log the completion of the model load process
Expand Down
Loading

0 comments on commit bcbbc51

Please sign in to comment.