Skip to content

Commit

Permalink
FLUX Regional Prompting (#7388)
Browse files Browse the repository at this point in the history
## Summary

This PR adds support for regional prompting with FLUX.

### Example 1
Global prompt: `An architecture rendering of the reception area of a
corporate office with modern decor.`
<img width="1386" alt="image"
src="https://github.com/user-attachments/assets/c8169bdb-49a9-44bc-bd9e-58d98e09094b">

![image](https://github.com/user-attachments/assets/4a426be9-9d7a-4527-b27c-2d2514ee73fe)

## QA Instructions

- [x] Test that there is no slowdown in the base case with a single
global prompt.
- [x] Test image fully covered by regional masks.
- [x] Test image covered by region masks with small gaps.
- [x] Test region masks with large unmasked ‘background’ regions
- [x] Test region masks with significant overlap
- [x] Test multiple global prompts.
- [x] Test no global prompt.
- [x] Test regional negative prompts (It runs... but results are not
great. Needs more tuning to be useful.)
- Test compatibility with:
    - [x] ControlNet
    - [x] LoRA
    - [x] IP-Adapter

## Remaining TODO

- [x] Disable the following UI features for FLUX prompt regions:
negative prompts, reference images, auto-negative.

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a
changelog_
- [x] _Tests added / updated (if applicable)_
- [x] _Documentation added / updated (if applicable)_
- [ ] _Updated `What's New` copy (if doing a release after this PR)_
  • Loading branch information
RyanJDick authored Nov 29, 2024
2 parents 5bca68d + 7d488a5 commit 54b7f9a
Show file tree
Hide file tree
Showing 25 changed files with 1,105 additions and 360 deletions.
5 changes: 5 additions & 0 deletions invokeai/app/invocations/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,11 @@ class FluxConditioningField(BaseModel):
"""A conditioning tensor primitive value"""

conditioning_name: str = Field(description="The name of conditioning tensor")
mask: Optional[TensorField] = Field(
default=None,
description="The mask associated with this conditioning tensor. Excluded regions should be set to False, "
"included regions should be set to True.",
)


class SD3ConditioningField(BaseModel):
Expand Down
126 changes: 78 additions & 48 deletions invokeai/app/invocations/flux_denoise.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from invokeai.backend.flux.denoise import denoise
from invokeai.backend.flux.extensions.inpaint_extension import InpaintExtension
from invokeai.backend.flux.extensions.instantx_controlnet_extension import InstantXControlNetExtension
from invokeai.backend.flux.extensions.regional_prompting_extension import RegionalPromptingExtension
from invokeai.backend.flux.extensions.xlabs_controlnet_extension import XLabsControlNetExtension
from invokeai.backend.flux.extensions.xlabs_ip_adapter_extension import XLabsIPAdapterExtension
from invokeai.backend.flux.ip_adapter.xlabs_ip_adapter_flux import XlabsIpAdapterFlux
Expand All @@ -42,6 +43,7 @@
pack,
unpack,
)
from invokeai.backend.flux.text_conditioning import FluxTextConditioning
from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.lora.lora_patcher import LoRAPatcher
Expand Down Expand Up @@ -87,10 +89,10 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
input=Input.Connection,
title="Transformer",
)
positive_text_conditioning: FluxConditioningField = InputField(
positive_text_conditioning: FluxConditioningField | list[FluxConditioningField] = InputField(
description=FieldDescriptions.positive_cond, input=Input.Connection
)
negative_text_conditioning: FluxConditioningField | None = InputField(
negative_text_conditioning: FluxConditioningField | list[FluxConditioningField] | None = InputField(
default=None,
description="Negative conditioning tensor. Can be None if cfg_scale is 1.0.",
input=Input.Connection,
Expand Down Expand Up @@ -139,36 +141,12 @@ def invoke(self, context: InvocationContext) -> LatentsOutput:
name = context.tensors.save(tensor=latents)
return LatentsOutput.build(latents_name=name, latents=latents, seed=None)

def _load_text_conditioning(
self, context: InvocationContext, conditioning_name: str, dtype: torch.dtype
) -> Tuple[torch.Tensor, torch.Tensor]:
# Load the conditioning data.
cond_data = context.conditioning.load(conditioning_name)
assert len(cond_data.conditionings) == 1
flux_conditioning = cond_data.conditionings[0]
assert isinstance(flux_conditioning, FLUXConditioningInfo)
flux_conditioning = flux_conditioning.to(dtype=dtype)
t5_embeddings = flux_conditioning.t5_embeds
clip_embeddings = flux_conditioning.clip_embeds
return t5_embeddings, clip_embeddings

def _run_diffusion(
self,
context: InvocationContext,
):
inference_dtype = torch.bfloat16

# Load the conditioning data.
pos_t5_embeddings, pos_clip_embeddings = self._load_text_conditioning(
context, self.positive_text_conditioning.conditioning_name, inference_dtype
)
neg_t5_embeddings: torch.Tensor | None = None
neg_clip_embeddings: torch.Tensor | None = None
if self.negative_text_conditioning is not None:
neg_t5_embeddings, neg_clip_embeddings = self._load_text_conditioning(
context, self.negative_text_conditioning.conditioning_name, inference_dtype
)

# Load the input latents, if provided.
init_latents = context.tensors.load(self.latents.latents_name) if self.latents else None
if init_latents is not None:
Expand All @@ -183,15 +161,45 @@ def _run_diffusion(
dtype=inference_dtype,
seed=self.seed,
)
b, _c, latent_h, latent_w = noise.shape
packed_h = latent_h // 2
packed_w = latent_w // 2

# Load the conditioning data.
pos_text_conditionings = self._load_text_conditioning(
context=context,
cond_field=self.positive_text_conditioning,
packed_height=packed_h,
packed_width=packed_w,
dtype=inference_dtype,
device=TorchDevice.choose_torch_device(),
)
neg_text_conditionings: list[FluxTextConditioning] | None = None
if self.negative_text_conditioning is not None:
neg_text_conditionings = self._load_text_conditioning(
context=context,
cond_field=self.negative_text_conditioning,
packed_height=packed_h,
packed_width=packed_w,
dtype=inference_dtype,
device=TorchDevice.choose_torch_device(),
)
pos_regional_prompting_extension = RegionalPromptingExtension.from_text_conditioning(
pos_text_conditionings, img_seq_len=packed_h * packed_w
)
neg_regional_prompting_extension = (
RegionalPromptingExtension.from_text_conditioning(neg_text_conditionings, img_seq_len=packed_h * packed_w)
if neg_text_conditionings
else None
)

transformer_info = context.models.load(self.transformer.transformer)
is_schnell = "schnell" in transformer_info.config.config_path

# Calculate the timestep schedule.
image_seq_len = noise.shape[-1] * noise.shape[-2] // 4
timesteps = get_schedule(
num_steps=self.num_steps,
image_seq_len=image_seq_len,
image_seq_len=packed_h * packed_w,
shift=not is_schnell,
)

Expand Down Expand Up @@ -228,28 +236,17 @@ def _run_diffusion(

inpaint_mask = self._prep_inpaint_mask(context, x)

b, _c, latent_h, latent_w = x.shape
img_ids = generate_img_ids(h=latent_h, w=latent_w, batch_size=b, device=x.device, dtype=x.dtype)

pos_bs, pos_t5_seq_len, _ = pos_t5_embeddings.shape
pos_txt_ids = torch.zeros(
pos_bs, pos_t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device()
)
neg_txt_ids: torch.Tensor | None = None
if neg_t5_embeddings is not None:
neg_bs, neg_t5_seq_len, _ = neg_t5_embeddings.shape
neg_txt_ids = torch.zeros(
neg_bs, neg_t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device()
)

# Pack all latent tensors.
init_latents = pack(init_latents) if init_latents is not None else None
inpaint_mask = pack(inpaint_mask) if inpaint_mask is not None else None
noise = pack(noise)
x = pack(x)

# Now that we have 'packed' the latent tensors, verify that we calculated the image_seq_len correctly.
assert image_seq_len == x.shape[1]
# Now that we have 'packed' the latent tensors, verify that we calculated the image_seq_len, packed_h, and
# packed_w correctly.
assert packed_h * packed_w == x.shape[1]

# Prepare inpaint extension.
inpaint_extension: InpaintExtension | None = None
Expand Down Expand Up @@ -338,12 +335,8 @@ def _run_diffusion(
model=transformer,
img=x,
img_ids=img_ids,
txt=pos_t5_embeddings,
txt_ids=pos_txt_ids,
vec=pos_clip_embeddings,
neg_txt=neg_t5_embeddings,
neg_txt_ids=neg_txt_ids,
neg_vec=neg_clip_embeddings,
pos_regional_prompting_extension=pos_regional_prompting_extension,
neg_regional_prompting_extension=neg_regional_prompting_extension,
timesteps=timesteps,
step_callback=self._build_step_callback(context),
guidance=self.guidance,
Expand All @@ -357,6 +350,43 @@ def _run_diffusion(
x = unpack(x.float(), self.height, self.width)
return x

def _load_text_conditioning(
self,
context: InvocationContext,
cond_field: FluxConditioningField | list[FluxConditioningField],
packed_height: int,
packed_width: int,
dtype: torch.dtype,
device: torch.device,
) -> list[FluxTextConditioning]:
"""Load text conditioning data from a FluxConditioningField or a list of FluxConditioningFields."""
# Normalize to a list of FluxConditioningFields.
cond_list = [cond_field] if isinstance(cond_field, FluxConditioningField) else cond_field

text_conditionings: list[FluxTextConditioning] = []
for cond_field in cond_list:
# Load the text embeddings.
cond_data = context.conditioning.load(cond_field.conditioning_name)
assert len(cond_data.conditionings) == 1
flux_conditioning = cond_data.conditionings[0]
assert isinstance(flux_conditioning, FLUXConditioningInfo)
flux_conditioning = flux_conditioning.to(dtype=dtype, device=device)
t5_embeddings = flux_conditioning.t5_embeds
clip_embeddings = flux_conditioning.clip_embeds

# Load the mask, if provided.
mask: Optional[torch.Tensor] = None
if cond_field.mask is not None:
mask = context.tensors.load(cond_field.mask.tensor_name)
mask = mask.to(device=device)
mask = RegionalPromptingExtension.preprocess_regional_prompt_mask(
mask, packed_height, packed_width, dtype, device
)

text_conditionings.append(FluxTextConditioning(t5_embeddings, clip_embeddings, mask))

return text_conditionings

@classmethod
def prep_cfg_scale(
cls, cfg_scale: float | list[float], timesteps: list[float], cfg_scale_start_step: int, cfg_scale_end_step: int
Expand Down
21 changes: 15 additions & 6 deletions invokeai/app/invocations/flux_text_encoder.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
from contextlib import ExitStack
from typing import Iterator, Literal, Tuple
from typing import Iterator, Literal, Optional, Tuple

import torch
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer

from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, UIComponent
from invokeai.app.invocations.fields import (
FieldDescriptions,
FluxConditioningField,
Input,
InputField,
TensorField,
UIComponent,
)
from invokeai.app.invocations.model import CLIPField, T5EncoderField
from invokeai.app.invocations.primitives import FluxConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
Expand Down Expand Up @@ -41,9 +48,9 @@ class FluxTextEncoderInvocation(BaseInvocation):
t5_max_seq_len: Literal[256, 512] = InputField(
description="Max sequence length for the T5 encoder. Expected to be 256 for FLUX schnell models and 512 for FLUX dev models."
)
prompt: str = InputField(
description="Text prompt to encode.",
ui_component=UIComponent.Textarea,
prompt: str = InputField(description="Text prompt to encode.", ui_component=UIComponent.Textarea)
mask: Optional[TensorField] = InputField(
default=None, description="A mask defining the region that this conditioning prompt applies to."
)

@torch.no_grad()
Expand All @@ -57,7 +64,9 @@ def invoke(self, context: InvocationContext) -> FluxConditioningOutput:
)

conditioning_name = context.conditioning.save(conditioning_data)
return FluxConditioningOutput.build(conditioning_name)
return FluxConditioningOutput(
conditioning=FluxConditioningField(conditioning_name=conditioning_name, mask=self.mask)
)

def _t5_encode(self, context: InvocationContext) -> torch.Tensor:
t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer)
Expand Down
63 changes: 59 additions & 4 deletions invokeai/backend/flux/custom_block_processor.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import einops
import torch

from invokeai.backend.flux.extensions.regional_prompting_extension import RegionalPromptingExtension
from invokeai.backend.flux.extensions.xlabs_ip_adapter_extension import XLabsIPAdapterExtension
from invokeai.backend.flux.math import attention
from invokeai.backend.flux.modules.layers import DoubleStreamBlock
from invokeai.backend.flux.modules.layers import DoubleStreamBlock, SingleStreamBlock


class CustomDoubleStreamBlockProcessor:
Expand All @@ -13,7 +14,12 @@ class CustomDoubleStreamBlockProcessor:

@staticmethod
def _double_stream_block_forward(
block: DoubleStreamBlock, img: torch.Tensor, txt: torch.Tensor, vec: torch.Tensor, pe: torch.Tensor
block: DoubleStreamBlock,
img: torch.Tensor,
txt: torch.Tensor,
vec: torch.Tensor,
pe: torch.Tensor,
attn_mask: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""This function is a direct copy of DoubleStreamBlock.forward(), but it returns some of the intermediate
values.
Expand All @@ -40,7 +46,7 @@ def _double_stream_block_forward(
k = torch.cat((txt_k, img_k), dim=2)
v = torch.cat((txt_v, img_v), dim=2)

attn = attention(q, k, v, pe=pe)
attn = attention(q, k, v, pe=pe, attn_mask=attn_mask)
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]

# calculate the img bloks
Expand All @@ -63,11 +69,15 @@ def custom_double_block_forward(
vec: torch.Tensor,
pe: torch.Tensor,
ip_adapter_extensions: list[XLabsIPAdapterExtension],
regional_prompting_extension: RegionalPromptingExtension,
) -> tuple[torch.Tensor, torch.Tensor]:
"""A custom implementation of DoubleStreamBlock.forward() with additional features:
- IP-Adapter support
"""
img, txt, img_q = CustomDoubleStreamBlockProcessor._double_stream_block_forward(block, img, txt, vec, pe)
attn_mask = regional_prompting_extension.get_double_stream_attn_mask(block_index)
img, txt, img_q = CustomDoubleStreamBlockProcessor._double_stream_block_forward(
block, img, txt, vec, pe, attn_mask=attn_mask
)

# Apply IP-Adapter conditioning.
for ip_adapter_extension in ip_adapter_extensions:
Expand All @@ -81,3 +91,48 @@ def custom_double_block_forward(
)

return img, txt


class CustomSingleStreamBlockProcessor:
"""A class containing a custom implementation of SingleStreamBlock.forward() with additional features (masking,
etc.)
"""

@staticmethod
def _single_stream_block_forward(
block: SingleStreamBlock,
x: torch.Tensor,
vec: torch.Tensor,
pe: torch.Tensor,
attn_mask: torch.Tensor | None = None,
) -> torch.Tensor:
"""This function is a direct copy of SingleStreamBlock.forward()."""
mod, _ = block.modulation(vec)
x_mod = (1 + mod.scale) * block.pre_norm(x) + mod.shift
qkv, mlp = torch.split(block.linear1(x_mod), [3 * block.hidden_size, block.mlp_hidden_dim], dim=-1)

q, k, v = einops.rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=block.num_heads)
q, k = block.norm(q, k, v)

# compute attention
attn = attention(q, k, v, pe=pe, attn_mask=attn_mask)
# compute activation in mlp stream, cat again and run second linear layer
output = block.linear2(torch.cat((attn, block.mlp_act(mlp)), 2))
return x + mod.gate * output

@staticmethod
def custom_single_block_forward(
timestep_index: int,
total_num_timesteps: int,
block_index: int,
block: SingleStreamBlock,
img: torch.Tensor,
vec: torch.Tensor,
pe: torch.Tensor,
regional_prompting_extension: RegionalPromptingExtension,
) -> torch.Tensor:
"""A custom implementation of SingleStreamBlock.forward() with additional features:
- Masking
"""
attn_mask = regional_prompting_extension.get_single_stream_attn_mask(block_index)
return CustomSingleStreamBlockProcessor._single_stream_block_forward(block, img, vec, pe, attn_mask=attn_mask)
Loading

0 comments on commit 54b7f9a

Please sign in to comment.