Skip to content

Commit

Permalink
Add FLUX Control LoRA weight param (#7452)
Browse files Browse the repository at this point in the history
## Summary

Add the ability to control the weight of a FLUX Control LoRA.

## Example

Original image:
<div style="display: flex; gap: 10px;">
<img
src="https://github.com/user-attachments/assets/4a2d9f4a-b58b-4df6-af90-67b018763a38"
alt="Original input image" width="300"/>
</div>

Prompt: `a scarecrow playing tennis`
Weights: 0.4, 0.6, 0.8, 1.0
<div style="display: flex; gap: 10px;">
<img
src="https://github.com/user-attachments/assets/62b83fd6-46ce-460a-8d51-9c2cda9b05c9"
alt="Result at weight 0.4" width="300"/>
<img
src="https://github.com/user-attachments/assets/75442207-1538-46bc-9d6b-08ac5c235c93"
alt="Result at weight 0.6" width="300"/>
</div>
<div style="display: flex; gap: 10px;">
<img
src="https://github.com/user-attachments/assets/4a9dc9ea-9757-4965-837e-197fc9243007"
alt="Result at weight 0.8" width="300"/>
<img
src="https://github.com/user-attachments/assets/846f6918-ca82-4482-8c19-19172752fa8c"
alt="Result at weight 1.0" width="300"/>
</div>

## QA Instructions

- [x] Test that the weight control changes the strength of the control image.
- [x] Test that results match across both quantized and non-quantized models.

## Merge Plan

**_Do not merge this PR yet._**

1. Merge #7450 
2. Merge #7446 
3. Change target branch to main
4. Merge this branch.

## Checklist

- [ ] _The PR has a short but descriptive title, suitable for a
changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
- [ ] _Updated `What's New` copy (if doing a release after this PR)_
  • Loading branch information
RyanJDick authored Dec 17, 2024
2 parents 4d5f74c + d764aa4 commit 594511c
Show file tree
Hide file tree
Showing 8 changed files with 60 additions and 16 deletions.
5 changes: 3 additions & 2 deletions invokeai/app/invocations/flux_control_lora_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class FluxControlLoRALoaderOutput(BaseInvocationOutput):
title="Flux Control LoRA",
tags=["lora", "model", "flux"],
category="model",
version="1.0.0",
version="1.1.0",
classification=Classification.Prototype,
)
class FluxControlLoRALoaderInvocation(BaseInvocation):
Expand All @@ -34,6 +34,7 @@ class FluxControlLoRALoaderInvocation(BaseInvocation):
description=FieldDescriptions.control_lora_model, title="Control LoRA", ui_type=UIType.ControlLoRAModel
)
image: ImageField = InputField(description="The image to encode.")
weight: float = InputField(description="The weight of the LoRA.", default=1.0)

def invoke(self, context: InvocationContext) -> FluxControlLoRALoaderOutput:
if not context.models.exists(self.lora.key):
Expand All @@ -43,6 +44,6 @@ def invoke(self, context: InvocationContext) -> FluxControlLoRALoaderOutput:
control_lora=ControlLoRAField(
lora=self.lora,
img=self.image,
weight=1,
weight=self.weight,
)
)
4 changes: 3 additions & 1 deletion invokeai/backend/patches/layers/set_parameter_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@ def __init__(self, param_name: str, weight: torch.Tensor):
self.param_name = param_name

def get_parameters(self, orig_module: torch.nn.Module, weight: float) -> dict[str, torch.Tensor]:
# Note: We intentionally ignore the weight parameter here. This matches the behavior in the official FLUX
# Control LoRA implementation.
diff = self.weight - orig_module.get_parameter(self.param_name)
return {self.param_name: diff * weight}
return {self.param_name: diff}

def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
self.weight = self.weight.to(device=device, dtype=dtype)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ export const ControlLayerControlAdapter = memo(() => {
/>
<input {...uploadApi.getUploadInputProps()} />
</Flex>
{controlAdapter.type !== 'control_lora' && <Weight weight={controlAdapter.weight} onChange={onChangeWeight} />}
<Weight weight={controlAdapter.weight} onChange={onChangeWeight} />
{controlAdapter.type !== 'control_lora' && (
<BeginEndStepPct beginEndStepPct={controlAdapter.beginEndStepPct} onChange={onChangeBeginEndStepPct} />
)}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ import {
getReferenceImageState,
getRegionalGuidanceState,
imageDTOToImageWithDims,
initialControlLoRA,
initialControlNet,
initialIPAdapter,
initialT2IAdapter,
Expand Down Expand Up @@ -462,38 +463,64 @@ export const canvasSlice = createSlice({
}
layer.controlAdapter.model = zModelIdentifierField.parse(modelConfig);

// When converting between control layer types, we may need to add or remove properties. For example, ControlNet
// has a control mode, while T2I Adapter does not - otherwise they are the same.

switch (layer.controlAdapter.model.type) {
// Converting to T2I adapter from...
case 't2i_adapter': {
if (layer.controlAdapter.type === 'controlnet') {
// T2I Adapters have all the ControlNet properties, minus control mode - strip it
const { controlMode: _, ...rest } = layer.controlAdapter;
const t2iAdapterConfig: T2IAdapterConfig = { ...rest, type: 't2i_adapter' };
const t2iAdapterConfig: T2IAdapterConfig = { ...initialT2IAdapter, ...rest, type: 't2i_adapter' };
layer.controlAdapter = t2iAdapterConfig;
} else if (layer.controlAdapter.type === 'control_lora') {
const t2iAdapterConfig: T2IAdapterConfig = { ...layer.controlAdapter, ...initialT2IAdapter };
// Control LoRAs have only model and weight
const t2iAdapterConfig: T2IAdapterConfig = {
...initialT2IAdapter,
...layer.controlAdapter,
type: 't2i_adapter',
};
layer.controlAdapter = t2iAdapterConfig;
}
break;
}

// Converting to ControlNet from...
case 'controlnet': {
if (layer.controlAdapter.type === 't2i_adapter') {
// ControlNets have all the T2I Adapter properties, plus control mode
const controlNetConfig: ControlNetConfig = {
...initialControlNet,
...layer.controlAdapter,
type: 'controlnet',
controlMode: initialControlNet.controlMode,
};
layer.controlAdapter = controlNetConfig;
} else if (layer.controlAdapter.type === 'control_lora') {
const controlNetConfig: ControlNetConfig = { ...layer.controlAdapter, ...initialControlNet };
// ControlNets have all the Control LoRA properties, plus control mode and begin/end step pct
const controlNetConfig: ControlNetConfig = {
...initialControlNet,
...layer.controlAdapter,
type: 'controlnet',
};
layer.controlAdapter = controlNetConfig;
}
break;
}

// Converting to ControlLoRA from...
case 'control_lora': {
const controlLoraConfig: ControlLoRAConfig = { ...layer.controlAdapter, type: 'control_lora' };
layer.controlAdapter = controlLoraConfig;

if (layer.controlAdapter.type === 'controlnet') {
// We only need the model and weight for Control LoRA
const { model, weight } = layer.controlAdapter;
const controlNetConfig: ControlLoRAConfig = { ...initialControlLoRA, model, weight };
layer.controlAdapter = controlNetConfig;
} else if (layer.controlAdapter.type === 't2i_adapter') {
// We only need the model and weight for Control LoRA
const { model, weight } = layer.controlAdapter;
const t2iAdapterConfig: ControlLoRAConfig = { ...initialControlLoRA, model, weight };
layer.controlAdapter = t2iAdapterConfig;
}
break;
}

Expand All @@ -518,7 +545,7 @@ export const canvasSlice = createSlice({
) => {
const { entityIdentifier, weight } = action.payload;
const layer = selectEntity(state, entityIdentifier);
if (!layer || !layer.controlAdapter || layer.controlAdapter.type === 'control_lora') {
if (!layer || !layer.controlAdapter) {
return;
}
layer.controlAdapter.weight = weight;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,7 @@ export type T2IAdapterConfig = z.infer<typeof zT2IAdapterConfig>;

const zControlLoRAConfig = z.object({
type: z.literal('control_lora'),
weight: z.number().gte(-1).lte(2),
model: zServerValidatedModelIdentifierField.nullable(),
});
export type ControlLoRAConfig = z.infer<typeof zControlLoRAConfig>;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import type {
CanvasRasterLayerState,
CanvasReferenceImageState,
CanvasRegionalGuidanceState,
ControlLoRAConfig,
ControlNetConfig,
ImageWithDims,
IPAdapterConfig,
Expand Down Expand Up @@ -82,6 +83,11 @@ export const initialControlNet: ControlNetConfig = {
beginEndStepPct: [0, 0.75],
controlMode: 'balanced',
};
export const initialControlLoRA: ControlLoRAConfig = {
type: 'control_lora',
model: null,
weight: 0.75,
};

export const getReferenceImageState = (
id: string,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ const addControlLoRAToGraph = (
) => {
const { id, controlAdapter } = layer;
assert(controlAdapter.type === 'control_lora');
const { model } = controlAdapter;
const { model, weight } = controlAdapter;
assert(model !== null);
const { image_name } = imageDTO;

Expand All @@ -216,6 +216,7 @@ const addControlLoRAToGraph = (
type: 'flux_control_lora_loader',
lora: model,
image: { image_name },
weight: weight,
});

g.addEdge(controlLoRA, 'control_lora', denoise, 'control_lora');
Expand Down
12 changes: 9 additions & 3 deletions invokeai/frontend/web/src/services/api/schema.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6708,6 +6708,12 @@ export type components = {
* @default null
*/
image?: components["schemas"]["ImageField"];
/**
* Weight
* @description The weight of the LoRA.
* @default 1
*/
weight?: number;
/**
* type
* @default flux_control_lora_loader
Expand All @@ -6722,11 +6728,11 @@ export type components = {
*/
FluxControlLoRALoaderOutput: {
/**
* Flux Control Lora
* Flux Control LoRA
* @description Control LoRAs to apply on model loading
* @default null
*/
control_lora: components["schemas"]["ControlLoRAField"] | null;
control_lora: components["schemas"]["ControlLoRAField"];
/**
* type
* @default flux_control_lora_loader_output
Expand Down Expand Up @@ -6926,7 +6932,7 @@ export type components = {
*/
transformer?: components["schemas"]["TransformerField"];
/**
* Control Lora
* Control LoRA
* @description Control LoRA model to load
* @default null
*/
Expand Down

0 comments on commit 594511c

Please sign in to comment.