Commit
Switch from use_cuda_malloc flag to a general pytorch_cuda_alloc_conf config field that allows full customization of the CUDA allocator.
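The allocator module itself is not part of this excerpt, so the following is only a minimal sketch of the mechanism implied by the updated test below, not the actual InvokeAI implementation. PyTorch reads the `PYTORCH_CUDA_ALLOC_CONF` environment variable when its CUDA allocator is initialized, so a function like `configure_torch_cuda_allocator` has to run before `torch` is first imported; the parameter name and module layout here are assumptions.

```python
# Sketch only: assumes the real function works by setting PYTORCH_CUDA_ALLOC_CONF
# before torch is imported, which is how PyTorch's allocator is configured.
import os
import sys


def configure_torch_cuda_allocator(pytorch_cuda_alloc_conf: str) -> None:
    """Set PYTORCH_CUDA_ALLOC_CONF, failing loudly if torch was already imported."""
    if "torch" in sys.modules:
        # Error message taken from the assertion in the updated test.
        raise RuntimeError("Failed to configure the PyTorch CUDA memory allocator.")
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = pytorch_cuda_alloc_conf
```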
Showing 4 changed files with 27 additions and 37 deletions.
Changes to the allocator unit tests:

```diff
@@ -1,18 +1,11 @@
 import pytest
 
-from invokeai.app.util.torch_cuda_allocator import enable_torch_cuda_malloc, is_torch_cuda_malloc_enabled
+from invokeai.app.util.torch_cuda_allocator import configure_torch_cuda_allocator
 
 
-def test_is_torch_cuda_malloc_enabled():
-    """Test that if torch CUDA malloc hasn't been explicitly enabled, then is_torch_cuda_malloc_enabled() returns
-    False.
-    """
-    assert not is_torch_cuda_malloc_enabled()
-
-
-def test_enable_torch_cuda_malloc_raises_if_torch_is_already_imported():
+def test_configure_torch_cuda_allocator_raises_if_torch_is_already_imported():
     """Test that enable_torch_cuda_malloc() raises a RuntimeError if torch is already imported."""
     import torch  # noqa: F401
 
-    with pytest.raises(RuntimeError):
-        enable_torch_cuda_malloc()
+    with pytest.raises(RuntimeError, match="Failed to configure the PyTorch CUDA memory allocator."):
+        configure_torch_cuda_allocator("backend:cudaMallocAsync")
```
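The test pins down the usage contract: the configuration string uses PyTorch's `PYTORCH_CUDA_ALLOC_CONF` syntax (for example `backend:cudaMallocAsync` to select the asynchronous CUDA allocator), and the call must happen before the first `import torch` or it raises. A hypothetical call site, with the actual startup wiring and config-field plumbing outside this excerpt:

```python
# Hypothetical startup ordering; in practice the value would come from the new
# pytorch_cuda_alloc_conf config field rather than being hard-coded.
from invokeai.app.util.torch_cuda_allocator import configure_torch_cuda_allocator

configure_torch_cuda_allocator("backend:cudaMallocAsync")

import torch  # noqa: E402  # imported only after the allocator is configured
```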