Increase VAE decode memory estimates #7674

Draft · wants to merge 2 commits into base: ryan/vae-decode-mem
18 changes: 10 additions & 8 deletions docs/features/low-vram.md
@@ -86,24 +86,26 @@ But, if your GPU has enough VRAM to hold models fully, you might get a perf boost
 # As an example, if your system has 32GB of RAM and no other heavy processes, setting the `max_cache_ram_gb` to 28GB
 # might be a good value to achieve aggressive model caching.
 max_cache_ram_gb: 28
 
 # The default max cache VRAM size is adjusted dynamically based on the amount of available VRAM (taking into
 # consideration the VRAM used by other processes).
-# You can override the default value by setting `max_cache_vram_gb`. Note that this value takes precedence over the
-# `device_working_mem_gb`.
-# It is recommended to set the VRAM cache size to be as large as possible while leaving enough room for the working
-# memory of the tasks you will be doing. For example, on a 24GB GPU that will be running unquantized FLUX without any
-# auxiliary models, 18GB might be a good value.
-max_cache_vram_gb: 18
+# You can override the default value by setting `max_cache_vram_gb`.
+# CAUTION: Most users should not manually set this value. See warning below.
+max_cache_vram_gb: 16
 ```
 
-!!! tip "Max safe value for `max_cache_vram_gb`"
+!!! warning "Max safe value for `max_cache_vram_gb`"
 
+    Most users should not manually configure the `max_cache_vram_gb`. This configuration value takes precedence over the `device_working_mem_gb` and any operations that explicitly reserve additional working memory (e.g. VAE decode). As such, manually configuring it increases the likelihood of encountering out-of-memory errors.
+
-    To determine the max safe value for `max_cache_vram_gb`, subtract `device_working_mem_gb` from your GPU's VRAM. As described below, the default for `device_working_mem_gb` is 3GB.
+    For users who wish to configure `max_cache_vram_gb`, the max safe value can be determined by subtracting `device_working_mem_gb` from your GPU's VRAM. As described below, the default for `device_working_mem_gb` is 3GB.
 
     For example, if you have a 12GB GPU, the max safe value for `max_cache_vram_gb` is `12GB - 3GB = 9GB`.
 
     If you had increased `device_working_mem_gb` to 4GB, then the max safe value for `max_cache_vram_gb` is `12GB - 4GB = 8GB`.
 
+    Most users who override `max_cache_vram_gb` are doing so because they wish to use significantly less VRAM, and should be setting `max_cache_vram_gb` to a value significantly less than the 'max safe value'.
 
 ### Working memory
 
 Invoke cannot use _all_ of your VRAM for model caching and loading. It requires some VRAM to use as working memory for various operations.
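
To make the arithmetic in the warning above concrete, here is a minimal sketch (not part of this PR or the InvokeAI codebase) of the "max safe value" calculation. The helper name is hypothetical; the default of 3GB for `device_working_mem_gb` follows the docs.

```python
def max_safe_cache_vram_gb(total_vram_gb: float, device_working_mem_gb: float = 3.0) -> float:
    """Hypothetical helper: largest `max_cache_vram_gb` that still leaves room for working memory.

    Illustrates the guidance in docs/features/low-vram.md; not part of the InvokeAI codebase.
    """
    return total_vram_gb - device_working_mem_gb


# Examples from the docs: a 12GB GPU with the default 3GB working memory allows up to 9GB,
# or 8GB if `device_working_mem_gb` is raised to 4GB.
print(max_safe_cache_vram_gb(12.0))       # 9.0
print(max_safe_cache_vram_gb(12.0, 4.0))  # 8.0
```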
7 changes: 1 addition & 6 deletions invokeai/app/invocations/flux_vae_decode.py
@@ -41,16 +41,11 @@ class FluxVaeDecodeInvocation(BaseInvocation, WithMetadata, WithBoard):
 
     def _estimate_working_memory(self, latents: torch.Tensor, vae: AutoEncoder) -> int:
         """Estimate the working memory required by the invocation in bytes."""
-        # It was found experimentally that the peak working memory scales linearly with the number of pixels and the
-        # element size (precision).
         out_h = LATENT_SCALE_FACTOR * latents.shape[-2]
         out_w = LATENT_SCALE_FACTOR * latents.shape[-1]
         element_size = next(vae.parameters()).element_size()
-        scaling_constant = 1090  # Determined experimentally.
+        scaling_constant = 2200  # Determined experimentally.
         working_memory = out_h * out_w * element_size * scaling_constant
-
-        # We add a 20% buffer to the working memory estimate to be safe.
-        working_memory = working_memory * 1.2
         return int(working_memory)
 
     def _vae_decode(self, vae_info: LoadedModel, latents: torch.Tensor) -> Image.Image:
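
Pulled out of the invocation class, the revised FLUX estimate reduces to the following sketch. `LATENT_SCALE_FACTOR = 8` and an FP16/BF16 element size of 2 bytes are assumptions made for this example, not values taken from the diff.

```python
import torch

LATENT_SCALE_FACTOR = 8  # assumed: output resolution is 8x the latent resolution


def estimate_flux_vae_decode_working_memory(latents: torch.Tensor, element_size: int = 2) -> int:
    """Sketch of the revised estimate: scaling constant raised to 2200, no separate 20% buffer."""
    out_h = LATENT_SCALE_FACTOR * latents.shape[-2]
    out_w = LATENT_SCALE_FACTOR * latents.shape[-1]
    scaling_constant = 2200  # determined experimentally (per the diff above)
    return int(out_h * out_w * element_size * scaling_constant)


# A 128x128 latent decodes to a 1024x1024 image; in FP16 the estimate is ~4.3 GiB.
latents = torch.zeros(1, 16, 128, 128)
print(f"{estimate_flux_vae_decode_working_memory(latents) / 2**30:.1f} GiB")
```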
6 changes: 2 additions & 4 deletions invokeai/app/invocations/latents_to_image.py
@@ -60,7 +60,7 @@ def _estimate_working_memory(
         # It was found experimentally that the peak working memory scales linearly with the number of pixels and the
         # element size (precision). This estimate is accurate for both SD1 and SDXL.
         element_size = 4 if self.fp32 else 2
-        scaling_constant = 960  # Determined experimentally.
+        scaling_constant = 2200  # Determined experimentally.
 
         if use_tiling:
             tile_size = self.tile_size
@@ -84,9 +84,7 @@
             # If we are running in FP32, then we should account for the likely increase in model size (~250MB).
             working_memory += 250 * 2**20
 
-        # We add 20% to the working memory estimate to be safe.
-        working_memory = int(working_memory * 1.2)
-        return working_memory
+        return int(working_memory)
 
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> ImageOutput:
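
As a rough sanity check on the new constant for the non-tiled SD1/SDXL path, here is a back-of-the-envelope sketch assuming FP16 precision (not code from this PR):

```python
# Non-tiled VAE decode of a 1024x1024 image in FP16 with the constants from this diff.
out_h, out_w = 1024, 1024
element_size = 2  # fp16; would be 4 with self.fp32
scaling_constant = 2200  # new value (previously 960 plus a 20% buffer)

working_memory = out_h * out_w * element_size * scaling_constant
print(f"{working_memory / 2**30:.1f} GiB")  # ~4.3 GiB, up from ~2.3 GiB under the old estimate
```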
7 changes: 1 addition & 6 deletions invokeai/app/invocations/sd3_latents_to_image.py
@@ -43,16 +43,11 @@ class SD3LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
 
     def _estimate_working_memory(self, latents: torch.Tensor, vae: AutoencoderKL) -> int:
         """Estimate the working memory required by the invocation in bytes."""
-        # It was found experimentally that the peak working memory scales linearly with the number of pixels and the
-        # element size (precision).
         out_h = LATENT_SCALE_FACTOR * latents.shape[-2]
         out_w = LATENT_SCALE_FACTOR * latents.shape[-1]
         element_size = next(vae.parameters()).element_size()
-        scaling_constant = 1230  # Determined experimentally.
+        scaling_constant = 2200  # Determined experimentally.
         working_memory = out_h * out_w * element_size * scaling_constant
-
-        # We add a 20% buffer to the working memory estimate to be safe.
-        working_memory = working_memory * 1.2
         return int(working_memory)
 
     @torch.no_grad()
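
Taken together, the three invocations move from per-model constants (each previously padded with a 20% buffer at the end) to a single larger constant of 2200. A small sketch of the effective change in bytes per output pixel per element:

```python
# Effective scaling constants before this PR (old constant x 1.2 buffer) vs. the new flat 2200.
old_constants = {
    "flux_vae_decode": 1090,
    "latents_to_image (SD1/SDXL)": 960,
    "sd3_latents_to_image": 1230,
}
NEW_CONSTANT = 2200  # the 20% buffer is no longer applied separately

for name, old in old_constants.items():
    effective_old = old * 1.2
    print(f"{name}: {effective_old:.0f} -> {NEW_CONSTANT} ({NEW_CONSTANT / effective_old:.2f}x)")
```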