Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Minor code cleanup and refactoring #172

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion audio_separator/separator/separator.py
Original file line number Diff line number Diff line change
@@ -491,7 +491,11 @@ def list_supported_model_files(self):
"target_stem": model_scores.get(filename, {}).get("target_stem"),
"download_files": [filename],
} # Just the filename for MDX models
for name, filename in {**model_downloads_list["mdx_download_list"], **model_downloads_list["mdx_download_vip_list"], **audio_separator_models_list["mdx_download_list"]}.items()
for name, filename in {
**model_downloads_list["mdx_download_list"],
**model_downloads_list["mdx_download_vip_list"],
**audio_separator_models_list["mdx_download_list"],
}.items()
},
"Demucs": demucs_models,
"MDXC": {
1 change: 0 additions & 1 deletion audio_separator/separator/uvr_lib_v5/demucs/apply.py
Original file line number Diff line number Diff line change
@@ -10,7 +10,6 @@
from concurrent.futures import ThreadPoolExecutor
import random
import typing as tp
from multiprocessing import Process, Queue, Pipe

import torch as th
from torch import nn
1 change: 0 additions & 1 deletion audio_separator/separator/uvr_lib_v5/playsound.py
Original file line number Diff line number Diff line change
@@ -31,7 +31,6 @@ def _playsoundWin(sound, block = True):
sound = '"' + _canonicalizePath(sound) + '"'

from ctypes import create_unicode_buffer, windll, wintypes
from time import sleep
windll.winmm.mciSendStringW.argtypes = [wintypes.LPCWSTR, wintypes.LPWSTR, wintypes.UINT, wintypes.HANDLE]
windll.winmm.mciGetErrorStringW.argtypes = [wintypes.DWORD, wintypes.LPWSTR, wintypes.UINT]

8 changes: 6 additions & 2 deletions audio_separator/utils/cli.py
Original file line number Diff line number Diff line change
@@ -4,7 +4,6 @@
import json
import sys
from importlib import metadata
from typing import Optional


def main():
@@ -215,7 +214,12 @@ def main():
"post_process_threshold": args.vr_post_process_threshold,
"high_end_process": args.vr_high_end_process,
},
demucs_params={"segment_size": args.demucs_segment_size, "shifts": args.demucs_shifts, "overlap": args.demucs_overlap, "segments_enabled": args.demucs_segments_enabled},
demucs_params={
"segment_size": args.demucs_segment_size,
"shifts": args.demucs_shifts,
"overlap": args.demucs_overlap,
"segments_enabled": args.demucs_segments_enabled,
},
mdxc_params={
"segment_size": args.mdxc_segment_size,
"batch_size": args.mdxc_batch_size,
1 change: 0 additions & 1 deletion tests/model-metrics/test-all-models.py
Original file line number Diff line number Diff line change
@@ -4,7 +4,6 @@
import numpy as np
import soundfile as sf
from audio_separator.separator import Separator
from pathlib import Path
import json
import logging
import musdb
49 changes: 1 addition & 48 deletions tests/unit/test_cli.py
Original file line number Diff line number Diff line change
@@ -175,7 +175,7 @@ def test_cli_normalization_threshold_argument(common_expected_args):
mock_separator.assert_called_once_with(**expected_args)


# Test using normalization_threshold argument
# Test using amplification_threshold argument
def test_cli_amplification_threshold_argument(common_expected_args):
test_args = ["cli.py", "test_audio.mp3", "--amplification=0.75"]
with patch("sys.argv", test_args):
@@ -226,23 +226,6 @@ def test_cli_invert_spectrogram_argument(common_expected_args):
mock_separator.assert_called_once_with(**expected_args)


# Test using use_autocast argument
def test_cli_use_autocast_argument(common_expected_args):
    """Check that ``--use_autocast`` is forwarded to the ``Separator`` constructor.

    ``sys.argv`` and ``audio_separator.separator.Separator`` are both patched,
    so ``main()`` runs against a mock and no real separation is performed.

    Args:
        common_expected_args: Mapping of the baseline keyword arguments the
            ``Separator`` constructor is expected to receive (presumably
            supplied by a pytest fixture). It is copied here, never mutated.
    """
    test_args = ["cli.py", "test_audio.mp3", "--use_autocast"]
    with patch("sys.argv", test_args):
        with patch("audio_separator.separator.Separator") as mock_separator:
            mock_separator_instance = mock_separator.return_value
            mock_separator_instance.separate.return_value = ["output_file.mp3"]
            main()

            # Update expected args for this specific test
            expected_args = common_expected_args.copy()
            expected_args["use_autocast"] = True

            # Assert against the adjusted copy; comparing with the untouched
            # baseline would leave the flag under test completely unchecked.
            mock_separator.assert_called_once_with(**expected_args)


# Test using use_autocast argument
def test_cli_use_autocast_argument(common_expected_args):
test_args = ["cli.py", "test_audio.mp3", "--use_autocast"]
@@ -260,36 +243,6 @@ def test_cli_use_autocast_argument(common_expected_args):


# Test using custom_output_names argument
def test_cli_Vocals_output_name_argument(common_expected_args):
    """Run the CLI with a ``--custom_output_names`` mapping that names only "Vocals".

    The mapping is JSON-encoded into the command-line option. The test then
    checks that the patched ``Separator`` is constructed with the baseline
    arguments and that ``separate`` receives the decoded mapping.

    Args:
        common_expected_args: Mapping of the keyword arguments the
            ``Separator`` constructor is expected to be called with.
    """
    vocals_mapping = {"Vocals": "vocals_output"}
    names_option = f"--custom_output_names={json.dumps(vocals_mapping)}"
    argv = ["cli.py", "test_audio.mp3", names_option]

    with patch("sys.argv", argv), patch("audio_separator.separator.Separator") as separator_cls:
        separator = separator_cls.return_value
        separator.separate.return_value = ["output_file.mp3"]
        main()

    # The mock objects keep their call records after the patches are undone.
    separator_cls.assert_called_once_with(**common_expected_args)
    separator.separate.assert_called_once_with("test_audio.mp3", custom_output_names=vocals_mapping)


# Test using custom_output_names argument
def test_cli_Instrumental_output_name_argument(common_expected_args):
    """Run the CLI with a ``--custom_output_names`` mapping that names only "Instrumental".

    Both ``sys.argv`` and ``audio_separator.separator.Separator`` are patched.
    The test verifies the constructor call and that the JSON-decoded mapping
    is handed to ``separate`` as ``custom_output_names``.

    Args:
        common_expected_args: Mapping of the keyword arguments the
            ``Separator`` constructor is expected to be called with.
    """
    instrumental_mapping = {"Instrumental": "instrumental_output"}
    encoded_mapping = json.dumps(instrumental_mapping)
    cli_argv = ["cli.py", "test_audio.mp3", f"--custom_output_names={encoded_mapping}"]

    argv_patch = patch("sys.argv", cli_argv)
    separator_patch = patch("audio_separator.separator.Separator")
    with argv_patch:
        with separator_patch as patched_separator:
            patched_separator.return_value.separate.return_value = ["output_file.mp3"]
            main()

            # Constructor gets the untouched baseline arguments ...
            patched_separator.assert_called_once_with(**common_expected_args)
            # ... and separate() gets the input path plus the custom names.
            patched_separator.return_value.separate.assert_called_once_with(
                "test_audio.mp3", custom_output_names=instrumental_mapping
            )


# Test using custom_output_names arguments
def test_cli_custom_output_names_argument(common_expected_args):
custom_names = {
"Vocals": "vocals_output",
19 changes: 3 additions & 16 deletions tests/unit/test_stft.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import unittest
import numpy as np
import torch
from unittest.mock import Mock, patch
from unittest.mock import Mock
from audio_separator.separator.uvr_lib_v5.stft import STFT

# Short-Time Fourier Transform (STFT) Process Overview:
@@ -121,29 +121,16 @@ def test_prepare_for_istft(self):
# Assertions
self.assertEqual(complex_tensor.shape, expected_shape)

def test_inverse_device_handling(self):
def test_inverse_stft(self):
# Create a mock tensor with the correct input shape
input_tensor = torch.rand(1, 2, 1025, 32) # shape matching output of STFT

# Initialize STFT
stft = STFT(logger=MockLogger(), n_fft=2048, hop_length=512, dim_f=1025, device="cpu")

# Apply inverse STFT
output_tensor = stft.inverse(input_tensor)
output_tensor = self.stft.inverse(input_tensor)

# Check if the output tensor is on the CPU
self.assertEqual(output_tensor.device.type, "cpu")

def test_inverse_output_shape(self):
# Create a mock tensor
input_tensor = torch.rand(1, 2, 1025, 32) # shape matching output of STFT

# Initialize STFT
stft = STFT(logger=MockLogger(), n_fft=2048, hop_length=512, dim_f=1025, device="cpu")

# Apply inverse STFT
output_tensor = stft.inverse(input_tensor)

# Expected output shape: (Batch size, Channel dimension, Time dimension)
expected_shape = (1, 2, 7936) # Calculated based on STFT parameters