Skip to content

Commit

Permalink
Improved error handling for onnxruntime GPU vs non-GPU
Browse files Browse the repository at this point in the history
  • Loading branch information
beveradb committed Dec 21, 2023
1 parent 90e1347 commit 6d932e2
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions audio_separator/separator.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,15 +115,22 @@ def __init__(
self.logger.debug(f"Torch version: {str(torch.__version__)}")

cuda_available = torch.cuda.is_available()
self.logger.debug(f"Is CUDA enabled? {str(cuda_available)}")
self.logger.debug(f"Is CUDA enabled for Torch? {str(cuda_available)}")

if cuda_available:
self.logger.debug("Running in GPU mode")
self.logger.info("Torch running in CUDA GPU mode")
self.device = torch.device("cuda")
self.run_type = ["CUDAExecutionProvider"]
else:
raise Exception("CUDA requested but not available with current Torch installation. Do you have an Nvidia GPU?")

# Check GPU inferencing is enabled for ONNXRuntime too, which is essential to actually use the GPU
ort_device = ort.get_device()
if ort_device == 'GPU':
self.logger.info("ONNX Runtime running in GPU mode")
else:
raise Exception("CUDA requested but not available with current ONNX Runtime installation. Try pip install --force-reinstall onnxruntime-gpu")

elif self.use_coreml:
self.logger.debug("Apple Silicon CoreML requested, checking Torch version")
self.logger.debug(f"Torch version: {str(torch.__version__)}")
Expand Down

0 comments on commit 6d932e2

Please sign in to comment.