Skip to content

Commit

Permalink
Removed need for use_cuda and use_coreml parameters, hardware acceleration is now detected and automatically configured more robustly
Browse files Browse the repository at this point in the history
  • Loading branch information
beveradb committed Dec 21, 2023
1 parent c391ab0 commit 42c9d03
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 67 deletions.
88 changes: 46 additions & 42 deletions audio_separator/separator.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@ def __init__(
output_dir=None,
primary_stem_path=None,
secondary_stem_path=None,
use_cuda=False,
use_coreml=False,
output_format="WAV",
output_subtype=None,
normalization_enabled=True,
Expand All @@ -48,14 +46,12 @@ def __init__(
self.logger.addHandler(self.log_handler)

self.logger.debug(
f"Separator instantiating with input file: {audio_file_path}, model_name: {model_name}, output_dir: {output_dir}, use_cuda: {use_cuda}, output_format: {output_format}"
f"Separator instantiating with input file: {audio_file_path}, model_name: {model_name}, output_dir: {output_dir}, output_format: {output_format}"
)

self.model_name = model_name
self.model_file_dir = model_file_dir
self.output_dir = output_dir
self.use_cuda = use_cuda
self.use_coreml = use_coreml
self.primary_stem_path = primary_stem_path
self.secondary_stem_path = secondary_stem_path

Expand Down Expand Up @@ -110,51 +106,57 @@ def __init__(
warnings.filterwarnings("ignore")
self.cpu = torch.device("cpu")

if self.use_cuda:
self.logger.debug("CUDA requested, checking Torch version and CUDA status")
self.logger.debug(f"Torch version: {str(torch.__version__)}")
# Prepare for hardware-accelerated inference by validating both Torch and ONNX Runtime support either CUDA or CoreML
self.logger.debug(f"Torch version: {str(torch.__version__)}")
ort_device = ort.get_device()
ort_providers = ort.get_available_providers()
hardware_acceleration_enabled = False

cuda_available = torch.cuda.is_available()
self.logger.debug(f"Is CUDA enabled for Torch? {str(cuda_available)}")
if torch.cuda.is_available():
self.logger.info("CUDA is available in Torch, setting Torch device to CUDA")
self.device = torch.device("cuda")

if cuda_available:
self.logger.info("Torch running in CUDA GPU mode")
self.device = torch.device("cuda")
self.run_type = ["CUDAExecutionProvider"]
if ort_device == "GPU" and "CUDAExecutionProvider" in ort_providers:
self.logger.info("ONNXruntime device is GPU with CUDAExecutionProvider available, enabling acceleration")
self.onnx_execution_provider = ["CUDAExecutionProvider"]
hardware_acceleration_enabled = True
else:
raise Exception("CUDA requested but not available with current Torch installation. Do you have an Nvidia GPU?")

# Check GPU inferencing is enabled for ONNXRuntime too, which is essential to actually use the GPU
ort_device = ort.get_device()
if ort_device == 'GPU':
self.logger.info("ONNX Runtime running in GPU mode")
else:
raise Exception("CUDA requested but not available with current ONNX Runtime installation. Try pip install --force-reinstall onnxruntime-gpu")

elif self.use_coreml:
self.logger.debug("Apple Silicon CoreML requested, checking Torch version")
self.logger.debug(f"Torch version: {str(torch.__version__)}")
self.logger.warning("CUDAExecutionProvider not available in ONNXruntime, so acceleration will NOT be enabled")
self.logger.warning("If you expect CUDA to work with your GPU, try pip install --force-reinstall onnxruntime-gpu")
else:
self.logger.debug(
"CUDA not available with Torch installation. If you have an Nvidia GPU and expect CUDA support to work, try: "
)
self.logger.debug(
"pip install --force-reinstall torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118"
)

mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
self.logger.debug(f"Is Apple Silicon CoreML MPS available? {str(mps_available)}")
if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
self.logger.info("Apple Silicon MPS/CoreML is available in Torch, setting Torch device to MPS")

if mps_available:
self.logger.debug("Running in Apple Silicon MPS GPU mode")
# TODO: Change this to use MPS once FFTs are supported, see https://github.com/pytorch/pytorch/issues/78044
# self.device = torch.device("mps")

# TODO: Change this to use MPS once FFTs are supported, see https://github.com/pytorch/pytorch/issues/78044
# self.device = torch.device("mps")
self.logger.warning("Torch MPS backend does not yet support FFT operations, Torch will still use CPU!")
self.logger.warning("To track progress towards Apple Silicon acceleration, see https://github.com/pytorch/pytorch/issues/78044")
self.device = torch.device("cpu")

self.device = torch.device("cpu")
self.run_type = ["CoreMLExecutionProvider"]
if "CoreMLExecutionProvider" in ort_providers:
self.logger.info("ONNXruntime has CoreMLExecutionProvider available, enabling acceleration")
self.onnx_execution_provider = ["CoreMLExecutionProvider"]
hardware_acceleration_enabled = True
else:
raise Exception(
"Apple Silicon CoreML / MPS requested but not available with current Torch installation. Do you have an Apple Silicon GPU?"
)

self.logger.warning("CoreMLExecutionProvider not available in ONNXruntime, so acceleration will NOT be enabled")
self.logger.warning("If you expect MPS/CoreML to work with your Mac, try pip install --force-reinstall onnxruntime-silicon")
else:
self.logger.debug("Running in CPU mode")
raise Exception(
"Apple Silicon CoreML / MPS requested but not available with current Torch installation. Do you have an Apple Silicon GPU?"
)

if not hardware_acceleration_enabled:
self.logger.info("No hardware acceleration could be configured, running in CPU mode")
self.device = torch.device("cpu")
self.run_type = ["CPUExecutionProvider"]
self.onnx_execution_provider = ["CPUExecutionProvider"]

def get_model_hash(self, model_path):
try:
Expand Down Expand Up @@ -195,7 +197,7 @@ def separate(self):
)

self.logger.debug("Loading model...")
ort_ = ort.InferenceSession(model_path, providers=self.run_type)
ort_ = ort.InferenceSession(model_path, providers=self.onnx_execution_provider)
self.model_run = lambda spek: ort_.run(None, {"input": spek.cpu().numpy()})[0]

self.initialize_model_settings()
Expand Down Expand Up @@ -235,7 +237,9 @@ def separate(self):
self.write_audio(self.secondary_stem_path, self.secondary_source, samplerate)
output_files.append(self.secondary_stem_path)

torch.cuda.empty_cache()
if hasattr(torch, "cuda"):
torch.cuda.empty_cache()

return output_files

def write_audio(self, stem_path, stem_source, samplerate):
Expand Down
14 changes: 0 additions & 14 deletions audio_separator/utils/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,18 +46,6 @@ def main():
help="Optional: directory to write output files (default: <current dir>). Example: --output_dir=/app/separated",
)

parser.add_argument(
"--use_cuda",
action="store_true",
help="Optional: use Nvidia GPU with CUDA for separation (default: %(default)s). Example: --use_cuda=true",
)

parser.add_argument(
"--use_coreml",
action="store_true",
help="Optional: use Apple Silicon GPU with CoreML for separation (default: %(default)s). Example: --use_coreml=true",
)

parser.add_argument(
"--output_format",
default="FLAC",
Expand Down Expand Up @@ -102,8 +90,6 @@ def main():
model_name=args.model_name,
model_file_dir=args.model_file_dir,
output_dir=args.output_dir,
use_cuda=args.use_cuda,
use_coreml=args.use_coreml,
output_format=args.output_format,
denoise_enabled=args.denoise,
normalization_enabled=args.normalize,
Expand Down
15 changes: 7 additions & 8 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 2 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "audio-separator"
version = "0.8.1"
version = "0.9.0"
description = "Easy to use vocal separation on CLI or as a python package, using the amazing MDX-Net models from UVR trained by @Anjok07"
authors = ["Andrew Beveridge <[email protected]>"]
license = "MIT"
Expand All @@ -19,12 +19,11 @@ librosa = ">=0.9"
torch = ">=2"
wget = ">=3"
six = ">=1.16"
onnxruntime = { version = ">=1.15", optional = true }
onnxruntime = ">=1.15"
onnxruntime-gpu = { version = ">=1.15", optional = true }
onnxruntime-silicon = { version = ">=1.15", optional = true }

[tool.poetry.extras]
cpu = ["onnxruntime"]
gpu = ["onnxruntime-gpu"]
silicon = ["onnxruntime-silicon"]

Expand Down

0 comments on commit 42c9d03

Please sign in to comment.