Partial Loading PR5: Dynamic cache ram/vram limits (#7509)
## Summary

This PR enables RAM/VRAM cache size limits to be determined dynamically
based on availability.

**Config Changes**

This PR modifies the app configs in the following ways:
- A new `device_working_mem_gb` config was added. This is the amount of
non-model working memory to keep available on the execution device (i.e. the
GPU) when using dynamic cache limits. It defaults to 3GB.
- The `ram` and `vram` configs now default to `None`. If these configs
are set, they take precedence over the dynamic limits. **Note: Some
users may have previously overridden the `ram` and `vram` values in their
`invokeai.yaml`. They will need to remove these configs to enable the
new dynamic limit feature.** A sketch of how the dynamic limits behave is
shown below.
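
For intuition only, the sketch below approximates how a dynamic limit could be
derived when `ram`/`vram` are unset. The function names and the use of
`psutil`/`torch.cuda.mem_get_info` are illustrative assumptions, not the exact
implementation in this PR.

```python
import psutil
import torch


def dynamic_vram_limit_bytes(device_working_mem_gb: float = 3.0) -> int:
    """Illustrative sketch: size the VRAM cache from currently free device memory,
    holding back `device_working_mem_gb` of non-model working memory."""
    free_bytes, _total_bytes = torch.cuda.mem_get_info()
    working_mem_bytes = int(device_working_mem_gb * 2**30)
    return max(free_bytes - working_mem_bytes, 0)


def dynamic_ram_limit_bytes() -> int:
    """Illustrative sketch: size the RAM cache from currently available system memory."""
    return psutil.virtual_memory().available
```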

**Working Memory**

In addition to the new `device_working_mem_gb` config described above,
memory-intensive operations can estimate the amount of working memory
that they will need and request it from the model cache. This is
currently applied to the VAE decoding step for all models. In the
future, we may apply this to other operations as we work out which ops
tend to exceed the default working memory reservation.
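
The FLUX VAE decode change in this diff (`flux_vae_decode.py`, below) shows the
pattern: the op estimates its peak working memory from the output size and
passes it when locking the model onto the device. Condensed from that change
(the decode call itself is simplified here for illustration):

```python
# Estimate peak working memory for this op, then reserve it when locking the model.
estimated_working_memory = self._estimate_working_memory(latents, vae_info.model)
with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae):
    image = vae.decode(latents)  # illustrative; the real decode path does more work
```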

**Mitigations for #7513**

This PR includes some mitigations for the issue described in
#7513. Without these
mitigations, the issue would occur more frequently when dynamic RAM
limits are used and RAM is close to maxed out.
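
The main mitigation visible throughout the diff is to stop holding
`LoadedModel` handles across unrelated work and instead call
`context.models.load(...)` directly in the `with` block that locks the model
onto the device, leaving less time for the cache to evict the model in between.
A condensed before/after sketch (usage lines are illustrative):

```python
# Before: handle acquired early, locked much later; the model may be evicted
# from the RAM cache in the meantime and have to be loaded again.
tokenizer_info = context.models.load(self.clip.tokenizer)
# ... other heavy work (LoRA loading, text encoder, etc.) ...
with tokenizer_info as tokenizer:
    tokens = tokenizer(prompt)  # illustrative use

# After: load and lock in one step, immediately before use.
with context.models.load(self.clip.tokenizer) as tokenizer:
    tokens = tokenizer(prompt)  # illustrative use
```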

## Limitations / Future Work

- Only _models_ can be offloaded to RAM to conserve VRAM. That is, if VAE
decoding requires more working VRAM than is available, the best we can do
is keep the full model on the CPU, but we will still hit an OOM error.
In the future, we could detect this ahead of time and switch to running
inference on the CPU for those ops.
- There is often a non-negligible amount of VRAM 'reserved' by the torch
CUDA allocator but not used by any allocated tensors. We may be able to
tune the torch CUDA allocator to work better for our use case (see the
sketch after this list). Reference:
https://pytorch.org/docs/stable/notes/cuda.html#optimizing-memory-usage-with-pytorch-cuda-alloc-conf
- There may be some ops that require high working memory that haven't
been updated to request extra memory yet. We will update these as we
uncover them.
- If a model is 'locked' in VRAM, it won't be partially unloaded if a
later model load requests extra working memory. This should be uncommon,
but I can think of cases where it would matter.
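
As a possible starting point for the allocator tuning mentioned above, the
allocator can be configured via the `PYTORCH_CUDA_ALLOC_CONF` environment
variable (documented at the link above). Whether any particular setting helps
InvokeAI's allocation pattern is untested; this is only a sketch of how one
would experiment:

```python
import os

# Must be set before the first CUDA allocation in the process; whether
# `expandable_segments` actually reduces unused reserved VRAM here is an open question.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import torch  # noqa: E402

# Compare what the allocator is holding vs. what live tensors actually use.
print(f"reserved={torch.cuda.memory_reserved()}  allocated={torch.cuda.memory_allocated()}")
```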

## Related Issues / Discussions

- #7492 
- #7494 
- #7500 
- #7505 

## QA Instructions

Run a variety of models near the cache limits to ensure that model
switching works properly for the following configurations:
- [x] CUDA, `enable_partial_loading=true`, all other configs default
(i.e. dynamic memory limits)
- [x] CUDA, `enable_partial_loading=true`, CPU and CUDA memory reserved
in another process so there is limited RAM/VRAM remaining, all other
configs default (i.e. dynamic memory limits)
- [x] CUDA, `enable_partial_loading=false`, all other configs default
(i.e. dynamic memory limits)
- [x] CUDA, `ram`/`vram` limits set (these should take precedence over the
dynamic limits)
- [x] MPS, all other configs default (i.e. dynamic memory limits)
- [x] CPU, all other configs default (i.e. dynamic memory limits)

## Merge Plan

- [x] Merge #7505 first and change target branch to main

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a
changelog_
- [x] _Tests added / updated (if applicable)_
- [x] _Documentation added / updated (if applicable)_
- [ ] _Updated `What's New` copy (if doing a release after this PR)_
RyanJDick authored Jan 7, 2025
2 parents 87fdcb7 + d7ab464 commit 0258b6a
Showing 20 changed files with 314 additions and 298 deletions.
70 changes: 0 additions & 70 deletions invokeai/app/api/routers/model_manager.py
@@ -4,7 +4,6 @@
import contextlib
import io
import pathlib
import shutil
import traceback
from copy import deepcopy
from enum import Enum
@@ -21,7 +20,6 @@
from typing_extensions import Annotated

from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.services.config import get_config
from invokeai.app.services.model_images.model_images_common import ModelImageFileNotFoundException
from invokeai.app.services.model_install.model_install_common import ModelInstallJob
from invokeai.app.services.model_records import (
@@ -848,74 +846,6 @@ async def get_starter_models() -> StarterModelResponse:
return StarterModelResponse(starter_models=starter_models, starter_bundles=starter_bundles)


@model_manager_router.get(
"/model_cache",
operation_id="get_cache_size",
response_model=float,
summary="Get maximum size of model manager RAM or VRAM cache.",
)
async def get_cache_size(cache_type: CacheType = Query(description="The cache type", default=CacheType.RAM)) -> float:
"""Return the current RAM or VRAM cache size setting (in GB)."""
cache = ApiDependencies.invoker.services.model_manager.load.ram_cache
value = 0.0
if cache_type == CacheType.RAM:
value = cache.max_cache_size
elif cache_type == CacheType.VRAM:
value = cache.max_vram_cache_size
return value


@model_manager_router.put(
"/model_cache",
operation_id="set_cache_size",
response_model=float,
summary="Set maximum size of model manager RAM or VRAM cache, optionally writing new value out to invokeai.yaml config file.",
)
async def set_cache_size(
value: float = Query(description="The new value for the maximum cache size"),
cache_type: CacheType = Query(description="The cache type", default=CacheType.RAM),
persist: bool = Query(description="Write new value out to invokeai.yaml", default=False),
) -> float:
"""Set the current RAM or VRAM cache size setting (in GB). ."""
cache = ApiDependencies.invoker.services.model_manager.load.ram_cache
app_config = get_config()
# Record initial state.
vram_old = app_config.vram
ram_old = app_config.ram

# Prepare target state.
vram_new = vram_old
ram_new = ram_old
if cache_type == CacheType.RAM:
ram_new = value
elif cache_type == CacheType.VRAM:
vram_new = value
else:
raise ValueError(f"Unexpected {cache_type=}.")

config_path = app_config.config_file_path
new_config_path = config_path.with_suffix(".yaml.new")

try:
# Try to apply the target state.
cache.max_vram_cache_size = vram_new
cache.max_cache_size = ram_new
app_config.ram = ram_new
app_config.vram = vram_new
if persist:
app_config.write_file(new_config_path)
shutil.move(new_config_path, config_path)
except Exception as e:
# If there was a failure, restore the initial state.
cache.max_cache_size = ram_old
cache.max_vram_cache_size = vram_old
app_config.ram = ram_old
app_config.vram = vram_old

raise RuntimeError("Failed to update cache size") from e
return value


@model_manager_router.get(
"/stats",
operation_id="get_stats",
11 changes: 3 additions & 8 deletions invokeai/app/invocations/compel.py
@@ -63,9 +63,6 @@ class CompelInvocation(BaseInvocation):

@torch.no_grad()
def invoke(self, context: InvocationContext) -> ConditioningOutput:
tokenizer_info = context.models.load(self.clip.tokenizer)
text_encoder_info = context.models.load(self.clip.text_encoder)

def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:
for lora in self.clip.loras:
lora_info = context.models.load(lora.lora)
@@ -76,12 +73,13 @@ def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:

# loras = [(context.models.get(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]

text_encoder_info = context.models.load(self.clip.text_encoder)
ti_list = generate_ti_list(self.prompt, text_encoder_info.config.base, context)

with (
# apply all patches while the model is on the target device
text_encoder_info.model_on_device() as (cached_weights, text_encoder),
tokenizer_info as tokenizer,
context.models.load(self.clip.tokenizer) as tokenizer,
LayerPatcher.apply_smart_model_patches(
model=text_encoder,
patches=_lora_loader(),
@@ -140,9 +138,7 @@ def run_clip_compel(
lora_prefix: str,
zero_on_empty: bool,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
tokenizer_info = context.models.load(clip_field.tokenizer)
text_encoder_info = context.models.load(clip_field.text_encoder)

# return zero on empty
if prompt == "" and zero_on_empty:
cpu_text_encoder = text_encoder_info.model
@@ -180,7 +176,7 @@ def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:
with (
# apply all patches while the model is on the target device
text_encoder_info.model_on_device() as (cached_weights, text_encoder),
tokenizer_info as tokenizer,
context.models.load(clip_field.tokenizer) as tokenizer,
LayerPatcher.apply_smart_model_patches(
model=text_encoder,
patches=_lora_loader(),
@@ -226,7 +222,6 @@ def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:

del tokenizer
del text_encoder
del tokenizer_info
del text_encoder_info

c = c.detach().to("cpu")
14 changes: 4 additions & 10 deletions invokeai/app/invocations/denoise_latents.py
@@ -547,7 +547,6 @@ def prep_ip_adapter_image_prompts(
for single_ip_adapter in ip_adapters:
with context.models.load(single_ip_adapter.ip_adapter_model) as ip_adapter_model:
assert isinstance(ip_adapter_model, IPAdapter)
image_encoder_model_info = context.models.load(single_ip_adapter.image_encoder_model)
# `single_ip_adapter.image` could be a list or a single ImageField. Normalize to a list here.
single_ipa_image_fields = single_ip_adapter.image
if not isinstance(single_ipa_image_fields, list):
@@ -556,7 +555,7 @@
single_ipa_images = [
context.images.get_pil(image.image_name, mode="RGB") for image in single_ipa_image_fields
]
with image_encoder_model_info as image_encoder_model:
with context.models.load(single_ip_adapter.image_encoder_model) as image_encoder_model:
assert isinstance(image_encoder_model, CLIPVisionModelWithProjection)
# Get image embeddings from CLIP and ImageProjModel.
image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter_model.get_image_embeds(
@@ -621,7 +620,6 @@ def run_t2i_adapters(
t2i_adapter_data = []
for t2i_adapter_field in t2i_adapter:
t2i_adapter_model_config = context.models.get_config(t2i_adapter_field.t2i_adapter_model.key)
t2i_adapter_loaded_model = context.models.load(t2i_adapter_field.t2i_adapter_model)
image = context.images.get_pil(t2i_adapter_field.image.image_name, mode="RGB")

# The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally.
@@ -637,7 +635,7 @@
raise ValueError(f"Unexpected T2I-Adapter base model type: '{t2i_adapter_model_config.base}'.")

t2i_adapter_model: T2IAdapter
with t2i_adapter_loaded_model as t2i_adapter_model:
with context.models.load(t2i_adapter_field.t2i_adapter_model) as t2i_adapter_model:
total_downscale_factor = t2i_adapter_model.total_downscale_factor

# Note: We have hard-coded `do_classifier_free_guidance=False`. This is because we only want to prepare
@@ -926,10 +924,8 @@ def step_callback(state: PipelineIntermediateState) -> None:
# ext: t2i/ip adapter
ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx)

unet_info = context.models.load(self.unet.unet)
assert isinstance(unet_info.model, UNet2DConditionModel)
with (
unet_info.model_on_device() as (cached_weights, unet),
context.models.load(self.unet.unet).model_on_device() as (cached_weights, unet),
ModelPatcher.patch_unet_attention_processor(unet, denoise_ctx.inputs.attention_processor_cls),
# ext: controlnet
ext_manager.patch_extensions(denoise_ctx),
@@ -995,11 +991,9 @@ def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:
del lora_info
return

unet_info = context.models.load(self.unet.unet)
assert isinstance(unet_info.model, UNet2DConditionModel)
with (
ExitStack() as exit_stack,
unet_info.model_on_device() as (cached_weights, unet),
context.models.load(self.unet.unet).model_on_device() as (cached_weights, unet),
ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
SeamlessExt.static_patch_model(unet, self.unet.seamless_axes), # FIXME
# Apply the LoRA after unet has been moved to its target device for faster patching.
27 changes: 15 additions & 12 deletions invokeai/app/invocations/flux_denoise.py
@@ -199,8 +199,8 @@ def _run_diffusion(
else None
)

transformer_info = context.models.load(self.transformer.transformer)
is_schnell = "schnell" in getattr(transformer_info.config, "config_path", "")
transformer_config = context.models.get_config(self.transformer.transformer)
is_schnell = "schnell" in getattr(transformer_config, "config_path", "")

# Calculate the timestep schedule.
timesteps = get_schedule(
@@ -299,9 +299,11 @@
)

# Load the transformer model.
(cached_weights, transformer) = exit_stack.enter_context(transformer_info.model_on_device())
(cached_weights, transformer) = exit_stack.enter_context(
context.models.load(self.transformer.transformer).model_on_device()
)
assert isinstance(transformer, Flux)
config = transformer_info.config
config = transformer_config
assert config is not None

# Determine if the model is quantized.
@@ -512,15 +514,18 @@ def _prep_controlnet_extensions(
# before loading the models. Then make sure that all VAE encoding is done before loading the ControlNets to
# minimize peak memory.

# First, load the ControlNet models so that we can determine the ControlNet types.
controlnet_models = [context.models.load(controlnet.control_model) for controlnet in controlnets]

# Calculate the controlnet conditioning tensors.
# We do this before loading the ControlNet models because it may require running the VAE, and we are trying to
# keep peak memory down.
controlnet_conds: list[torch.Tensor] = []
for controlnet, controlnet_model in zip(controlnets, controlnet_models, strict=True):
for controlnet in controlnets:
image = context.images.get_pil(controlnet.image.image_name)

# HACK(ryand): We have to load the ControlNet model to determine whether the VAE needs to be run. We really
# shouldn't have to load the model here. There's a risk that the model will be dropped from the model cache
# before we load it into VRAM and thus we'll have to load it again (context:
# https://github.com/invoke-ai/InvokeAI/issues/7513).
controlnet_model = context.models.load(controlnet.control_model)
if isinstance(controlnet_model.model, InstantXControlNetFlux):
if self.controlnet_vae is None:
raise ValueError("A ControlNet VAE is required when using an InstantX FLUX ControlNet.")
@@ -550,10 +555,8 @@

# Finally, load the ControlNet models and initialize the ControlNet extensions.
controlnet_extensions: list[XLabsControlNetExtension | InstantXControlNetExtension] = []
for controlnet, controlnet_cond, controlnet_model in zip(
controlnets, controlnet_conds, controlnet_models, strict=True
):
model = exit_stack.enter_context(controlnet_model)
for controlnet, controlnet_cond in zip(controlnets, controlnet_conds, strict=True):
model = exit_stack.enter_context(context.models.load(controlnet.control_model))

if isinstance(model, XLabsControlNetFlux):
controlnet_extensions.append(
19 changes: 7 additions & 12 deletions invokeai/app/invocations/flux_text_encoder.py
@@ -69,14 +69,11 @@ def invoke(self, context: InvocationContext) -> FluxConditioningOutput:
)

def _t5_encode(self, context: InvocationContext) -> torch.Tensor:
t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer)
t5_text_encoder_info = context.models.load(self.t5_encoder.text_encoder)

prompt = [self.prompt]

with (
t5_text_encoder_info as t5_text_encoder,
t5_tokenizer_info as t5_tokenizer,
context.models.load(self.t5_encoder.text_encoder) as t5_text_encoder,
context.models.load(self.t5_encoder.tokenizer) as t5_tokenizer,
):
assert isinstance(t5_text_encoder, T5EncoderModel)
assert isinstance(t5_tokenizer, T5Tokenizer)
@@ -90,22 +87,20 @@ def _t5_encode(self, context: InvocationContext) -> torch.Tensor:
return prompt_embeds

def _clip_encode(self, context: InvocationContext) -> torch.Tensor:
clip_tokenizer_info = context.models.load(self.clip.tokenizer)
clip_text_encoder_info = context.models.load(self.clip.text_encoder)

prompt = [self.prompt]

clip_text_encoder_info = context.models.load(self.clip.text_encoder)
clip_text_encoder_config = clip_text_encoder_info.config
assert clip_text_encoder_config is not None

with (
clip_text_encoder_info.model_on_device() as (cached_weights, clip_text_encoder),
clip_tokenizer_info as clip_tokenizer,
context.models.load(self.clip.tokenizer) as clip_tokenizer,
ExitStack() as exit_stack,
):
assert isinstance(clip_text_encoder, CLIPTextModel)
assert isinstance(clip_tokenizer, CLIPTokenizer)

clip_text_encoder_config = clip_text_encoder_info.config
assert clip_text_encoder_config is not None

# Apply LoRA models to the CLIP encoder.
# Note: We apply the LoRA after the transformer has been moved to its target device for faster patching.
if clip_text_encoder_config.format in [ModelFormat.Diffusers]:
20 changes: 18 additions & 2 deletions invokeai/app/invocations/flux_vae_decode.py
@@ -3,6 +3,7 @@
from PIL import Image

from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.fields import (
FieldDescriptions,
Input,
@@ -24,7 +25,7 @@
title="FLUX Latents to Image",
tags=["latents", "image", "vae", "l2i", "flux"],
category="latents",
version="1.0.0",
version="1.0.1",
)
class FluxVaeDecodeInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Generates an image from latents."""
@@ -38,8 +39,23 @@ class FluxVaeDecodeInvocation(BaseInvocation, WithMetadata, WithBoard):
input=Input.Connection,
)

def _estimate_working_memory(self, latents: torch.Tensor, vae: AutoEncoder) -> int:
"""Estimate the working memory required by the invocation in bytes."""
# It was found experimentally that the peak working memory scales linearly with the number of pixels and the
# element size (precision).
out_h = LATENT_SCALE_FACTOR * latents.shape[-2]
out_w = LATENT_SCALE_FACTOR * latents.shape[-1]
element_size = next(vae.parameters()).element_size()
scaling_constant = 1090 # Determined experimentally.
working_memory = out_h * out_w * element_size * scaling_constant

# We add a 20% buffer to the working memory estimate to be safe.
working_memory = working_memory * 1.2
return int(working_memory)

def _vae_decode(self, vae_info: LoadedModel, latents: torch.Tensor) -> Image.Image:
with vae_info as vae:
estimated_working_memory = self._estimate_working_memory(latents, vae_info.model)
with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae):
assert isinstance(vae, AutoEncoder)
vae_dtype = next(iter(vae.parameters())).dtype
latents = latents.to(device=TorchDevice.choose_torch_device(), dtype=vae_dtype)