Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for PTH (VR Architecture) models and newer MDX models, restructure and refactor a lot #36

Merged
merged 16 commits into from
Feb 4, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Fixed MPS hardware acceleration for VR models
beveradb committed Feb 4, 2024
commit ea6f41aabb344d25ff0eeeacd045236bb4f2e026
4 changes: 3 additions & 1 deletion audio_separator/separator/architectures/mdx_separator.py
Original file line number Diff line number Diff line change
@@ -24,7 +24,6 @@ def __init__(self, common_config, arch_config):
self.hop_length = arch_config.get("hop_length")
self.segment_size = arch_config.get("segment_size")
self.overlap = arch_config.get("overlap")
self.batch_size = arch_config.get("batch_size")

# Initializing model parameters
self.compensate = self.model_data["compensate"]
@@ -38,6 +37,9 @@ def __init__(self, common_config, arch_config):
self.logger.debug(f"Model params: batch_size={self.batch_size}, compensate={self.compensate}, segment_size={self.segment_size}, dim_f={self.dim_f}, dim_t={self.dim_t}")
self.logger.debug(f"Model params: n_fft={self.n_fft}, hop={self.hop_length}")

# self.logger.warning("Torch MPS backend does not yet support FFT operations, Torch will still use CPU!")
# self.torch_device = self.torch_device_cpu

# Loading the model for inference
self.logger.debug("Loading ONNX model for inference...")
if self.segment_size == self.dim_t:
23 changes: 17 additions & 6 deletions audio_separator/separator/architectures/vr_separator.py
Original file line number Diff line number Diff line change
@@ -51,6 +51,9 @@ def __init__(self, common_config, arch_config):
# Note: Selecting this option can adversely affect the conversion process, depending on the track. Because of this, it is only recommended as a last resort.
self.is_post_process = self.model_data.get("is_post_process", False)

# post_process_threshold values = ('0.1', '0.2', '0.3')
self.post_process_threshold = 0.2

# '• Use GPU for Processing (if available):\n'
# ' - If checked, the application will attempt to use your GPU for faster processing.\n'
# ' - If a GPU is not detected, it will default to CPU processing.\n'
@@ -64,22 +67,20 @@ def __init__(self, common_config, arch_config):
# '• Higher values mean more RAM usage but slightly faster processing times.\n'
# '• Lower values mean less RAM usage but slightly longer processing times.\n'
# '• Batch size value has no effect on output quality.'
# Andrew note: for some reason, the HP_2 model run only worked with batch size set to 16, not 1 or 2
self.batch_size = self.model_data.get("batch_size", 16)

# 'Select window size to balance quality and speed:\n\n'
# '• 1024 - Quick but lesser quality.\n'
# '• 512 - Medium speed and quality.\n'
# '• 320 - Takes longer but may offer better quality.'
self.window_size = self.model_data.get("window_size", 1024)
self.window_size = self.model_data.get("window_size", 512)

# The application will mirror the missing frequency range of the output.
self.high_end_process = arch_config.get("high_end_process", "False")
self.input_high_end_h = None
self.input_high_end = None

# post_process_threshold values = ('0.1', '0.2', '0.3')
self.post_process_threshold = 0.2

# 'Adjust the intensity of primary stem extraction:\n\n'
# '• It ranges from -100 - 100.\n'
# '• Bigger values mean deeper extractions.\n'
@@ -97,11 +98,18 @@ def __init__(self, common_config, arch_config):
self.model_capacity = self.model_data["nout"], self.model_data["nout_lstm"]
self.is_vr_51_model = True

self.logger.debug(f"VR arch params: is_tta={self.is_tta}, is_post_process={self.is_post_process}, post_process_threshold={self.post_process_threshold}")
self.logger.debug(f"VR arch params: is_gpu_conversion={self.is_gpu_conversion}, batch_size={self.batch_size}, window_size={self.window_size}")
self.logger.debug(f"VR arch params: high_end_process={self.high_end_process}, aggression_setting={self.aggression_setting}")
self.logger.debug(f"VR arch params: is_vr_51_model={self.is_vr_51_model}, model_samplerate={self.model_samplerate}, model_capacity={self.model_capacity}")

self.model_run = lambda *args, **kwargs: self.logger.error("Model run method is not initialised yet.")

# This should go away once we refactor to remove soundfile.write and replace with pydub like we did for the MDX rewrite
self.wav_subtype = "PCM_16"

self.logger.info("")

def separate(self, audio_file_path):
"""
Separates the audio file into primary and secondary sources based on the model's configuration.
@@ -242,18 +250,21 @@ def inference_vr(self, X_spec, device, aggressiveness):
def _execute(X_mag_pad, roi_size):
X_dataset = []
patches = (X_mag_pad.shape[2] - 2 * self.model_run.offset) // roi_size
# total_iterations = patches // self.batch_size if not self.is_tta else (patches // self.batch_size) * 2

self.logger.debug(f"inference_vr iterating through {len(patches)} patches")
self.logger.debug(f"inference_vr appending to X_dataset for each of {patches} patches")
for i in tqdm(range(patches)):
start = i * roi_size
X_mag_window = X_mag_pad[:, :, start : start + self.window_size]
X_dataset.append(X_mag_window)

total_iterations = patches // self.batch_size if not self.is_tta else (patches // self.batch_size) * 2
self.logger.debug(f"inference_vr iterating through {total_iterations} batches, batch_size = {self.batch_size}")

X_dataset = np.asarray(X_dataset)
self.model_run.eval()
with torch.no_grad():
mask = []

for i in tqdm(range(0, patches, self.batch_size)):

X_batch = X_dataset[i : i + self.batch_size]
18 changes: 18 additions & 0 deletions audio_separator/separator/common_separator.py
Original file line number Diff line number Diff line change
@@ -6,6 +6,7 @@
from pydub import AudioSegment
from audio_separator.separator.uvr_lib_v5 import spec_utils


class CommonSeparator:
"""
This class contains the common methods and attributes common to all architecture-specific Separator classes.
@@ -70,6 +71,13 @@ def __init__(self, config):
self.output_dir = config.get("output_dir")
self.output_format = config.get("output_format")

# 'Specify the number of batches to be processed at a time.\n\nNotes:\n\n'
# '• Higher values mean more RAM usage but slightly faster processing times.\n'
# '• Lower values mean less RAM usage but slightly longer processing times.\n'
# '• Batch size value has no effect on output quality.'
# BATCH_SIZE = ('1', '2', '3', '4', '5', '6', '7', '8', '9', '10')
self.batch_size = config.get("batch_size")

# Functional options which are applicable to all architectures and the user may tweak to affect the output
self.normalization_threshold = config.get("normalization_threshold")
self.denoise_enabled = config.get("denoise_enabled")
@@ -103,6 +111,16 @@ def __init__(self, config):
# We haven't implemented support for the checkpoint models here, so we're not using it.
# self.dim_c = 4

self.logger.debug(f"Common params: model_name={self.model_name}, model_path={self.model_path}")
self.logger.debug(f"Common params: primary_stem_output_path={self.primary_stem_output_path}, secondary_stem_output_path={self.secondary_stem_output_path}")
self.logger.debug(f"Common params: output_dir={self.output_dir}, output_format={self.output_format}")
self.logger.debug(f"Common params: batch_size={self.batch_size}, normalization_threshold={self.normalization_threshold}")
self.logger.debug(f"Common params: denoise_enabled={self.denoise_enabled}, output_single_stem={self.output_single_stem}")
self.logger.debug(f"Common params: invert_using_spec={self.invert_using_spec}, sample_rate={self.sample_rate}")

self.logger.debug(f"Common params: primary_stem_name={self.primary_stem_name}, secondary_stem_name={self.secondary_stem_name}")
self.logger.debug(f"Common params: is_karaoke={self.is_karaoke}, is_bv_model={self.is_bv_model}, bv_model_rebalance={self.bv_model_rebalance}")

self.cached_sources_map = {}

def separate(self, audio_file_path):
9 changes: 5 additions & 4 deletions audio_separator/separator/separator.py
Original file line number Diff line number Diff line change
@@ -224,8 +224,8 @@ def configure_mps(self, ort_providers):
self.logger.info("Apple Silicon MPS/CoreML is available in Torch, setting Torch device to MPS")
self.torch_device_mps = torch.device("mps")

self.logger.warning("Torch MPS backend does not yet support FFT operations, Torch will still use CPU!")
self.torch_device = self.torch_device_cpu
self.torch_device = self.torch_device_mps

if "CoreMLExecutionProvider" in ort_providers:
self.logger.info("ONNXruntime has CoreMLExecutionProvider available, enabling acceleration")
self.onnx_execution_provider = ["CoreMLExecutionProvider"]
@@ -496,6 +496,7 @@ def load_model(self, model_filename="UVR-MDX-NET-Inst_HQ_3.onnx", model_type=Non
"secondary_stem_output_path": self.secondary_stem_output_path,
"output_format": self.output_format,
"output_dir": self.output_dir,
"batch_size": self.batch_size,
"normalization_threshold": self.normalization_threshold,
"denoise_enabled": self.denoise_enabled,
"output_single_stem": self.output_single_stem,
@@ -505,12 +506,12 @@ def load_model(self, model_filename="UVR-MDX-NET-Inst_HQ_3.onnx", model_type=Non

# These are parameters which users may want to configure so we expose them to the top-level Separator class,
# even though they are specific to a single model architecture
arch_specific_params = {"MDX": {"hop_length": self.hop_length, "segment_size": self.segment_size, "overlap": self.overlap, "batch_size": self.batch_size}, "VR": {}}
arch_specific_params = {"MDX": {"hop_length": self.hop_length, "segment_size": self.segment_size, "overlap": self.overlap}, "VR": {}}

if model_type == "MDX":
self.model_instance = MDXSeparator(common_config=common_params, arch_config=arch_specific_params["MDX"])
elif model_type == "VR":
self.model_instance = VRSeparator(common_config=common_params, arch_config=arch_specific_params)
self.model_instance = VRSeparator(common_config=common_params, arch_config=arch_specific_params["VR"])
else:
raise ValueError(f"Unsupported model type: {model_type}")

2 changes: 1 addition & 1 deletion audio_separator/utils/cli.py
Original file line number Diff line number Diff line change
@@ -55,7 +55,7 @@ def main():

parser.add_argument("--overlap", type=float, default=0.25, help="Optional: overlap (default: %(default)s). Example: --overlap=0.25")

parser.add_argument("--batch_size", type=int, default=1, help="Optional: batch_size (default: %(default)s). Example: --batch_size=1")
parser.add_argument("--batch_size", type=int, default=1, help="Optional: batch_size (default: %(default)s). Example: --batch_size=4")

args = parser.parse_args()