Add LLaVA OneVision model support #7693

Draft · wants to merge 5 commits into main
52 changes: 52 additions & 0 deletions invokeai/app/invocations/llava_onevision_vllm.py
@@ -0,0 +1,52 @@
import torch
from PIL.Image import Image

from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import ImageField, InputField, UIComponent
from invokeai.app.invocations.primitives import StringOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.llava_onevision_model import LlavaOnevisionModel
from invokeai.backend.model_manager.config import BaseModelType, ModelType
from invokeai.backend.util.devices import TorchDevice


@invocation("llava_onevision_vllm", title="LLaVA OneVision VLLM", tags=["vllm"], category="vllm", version="1.0.0")
class LlavaOnevisionVllmInvocation(BaseInvocation):
    """Run a LLaVA OneVision VLLM model."""

    images: list[ImageField] | ImageField | None = InputField(default=None, description="Input image.")
    prompt: str = InputField(
        default="",
        description="Input text prompt.",
        ui_component=UIComponent.Textarea,
    )
    # vllm_model: ModelIdentifierField = InputField(
    #     title="Image-to-Image Model",
    #     description=FieldDescriptions.vllm_model,
    #     ui_type=UIType.LlavaOnevisionModel,
    # )

    def _get_images(self, context: InvocationContext) -> list[Image]:
        if self.images is None:
            return []

        image_fields = self.images if isinstance(self.images, list) else [self.images]
        return [context.images.get_pil(image_field.image_name, "RGB") for image_field in image_fields]

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> StringOutput:
        images = self._get_images(context)

        # with context.models.load(self.vllm_model) as vllm_model:
        with context.models.load_by_attrs(
            name="LLaVA Onevision Qwen2 0.5B", base=BaseModelType.Any, type=ModelType.LlavaOnevision
        ) as vllm_model:
            assert isinstance(vllm_model, LlavaOnevisionModel)
            output = vllm_model.run(
                prompt=self.prompt,
                images=images,
                device=TorchDevice.choose_torch_device(),
                dtype=TorchDevice.choose_torch_dtype(),
            )

        return StringOutput(value=output)
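
The `images` field accepts nothing, a single ImageField, or a list of ImageFields; `_get_images()` normalizes all three shapes into a plain list of PIL images before calling the model. A small illustration of the accepted shapes (not part of the diff; the image names are placeholders):

from invokeai.app.invocations.fields import ImageField

images = None                                                # _get_images() -> []
images = ImageField(image_name="photo.png")                  # -> [one PIL image]
images = [ImageField(image_name="left.png"), ImageField(image_name="right.png")]  # -> [two PIL images]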
49 changes: 49 additions & 0 deletions invokeai/backend/llava_onevision_model.py
@@ -0,0 +1,49 @@
from pathlib import Path
from typing import Optional

import torch
from PIL.Image import Image
from transformers import AutoProcessor, LlavaOnevisionForConditionalGeneration, LlavaOnevisionProcessor

from invokeai.backend.raw_model import RawModel


class LlavaOnevisionModel(RawModel):
    def __init__(self, vllm_model: LlavaOnevisionForConditionalGeneration, processor: LlavaOnevisionProcessor):
        self._vllm_model = vllm_model
        self._processor = processor

    @classmethod
    def load_from_path(cls, path: str | Path):
        vllm_model = LlavaOnevisionForConditionalGeneration.from_pretrained(path, local_files_only=True)
        assert isinstance(vllm_model, LlavaOnevisionForConditionalGeneration)
        processor = AutoProcessor.from_pretrained(path, local_files_only=True)
        assert isinstance(processor, LlavaOnevisionProcessor)
        return cls(vllm_model, processor)

    def run(self, prompt: str, images: list[Image], device: torch.device, dtype: torch.dtype) -> str:
        # TODO(ryand): Tune the max number of images that are useful for the model.
        if len(images) > 3:
            raise ValueError(
                f"{len(images)} images were provided as input to the LLaVA OneVision model. "
                "Pass <=3 images for good performance."
            )

        # Define a chat history and use `apply_chat_template` to get correctly formatted prompt.
        # "content" is a list of dicts with types "text" or "image".
        content = [{"type": "text", "text": prompt}]
        # Add the correct number of images.
        for _ in images:
            content.append({"type": "image"})

        conversation = [{"role": "user", "content": content}]
        prompt = self._processor.apply_chat_template(conversation, add_generation_prompt=True)
        inputs = self._processor(images=images or None, text=prompt, return_tensors="pt").to(device=device, dtype=dtype)
        output = self._vllm_model.generate(**inputs, max_new_tokens=400, do_sample=False)
        output_str: str = self._processor.decode(output[0][2:], skip_special_tokens=True)
        # The output_str will include the prompt, so we extract the response.
        response = output_str.split("assistant\n", 1)[1].strip()
        return response

    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
        self._vllm_model.to(device=device, dtype=dtype)
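
For context, a minimal standalone sketch of how this wrapper might be exercised outside of a node; the local model path is an assumption and must point at an already-downloaded copy, since `load_from_path` passes `local_files_only=True`:

from pathlib import Path

import torch
from PIL import Image

from invokeai.backend.llava_onevision_model import LlavaOnevisionModel

# Load the wrapper from a local HF-style folder and move weights to the working device/dtype.
model = LlavaOnevisionModel.load_from_path(Path("/models/llava-onevision-qwen2-0.5b-ov-hf"))
device, dtype = torch.device("cuda"), torch.float16
model.to(device=device, dtype=dtype)

# Ask a question about a single image; run() builds the chat template and decodes the reply.
answer = model.run(
    prompt="Describe this image in one sentence.",
    images=[Image.open("photo.jpg").convert("RGB")],
    device=device,
    dtype=dtype,
)
print(answer)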
13 changes: 13 additions & 0 deletions invokeai/backend/model_manager/config.py
@@ -76,6 +76,7 @@ class ModelType(str, Enum):
    T2IAdapter = "t2i_adapter"
    T5Encoder = "t5_encoder"
    SpandrelImageToImage = "spandrel_image_to_image"
    LlavaOnevision = "llava_onevision"


class SubModelType(str, Enum):
@@ -528,6 +529,17 @@ def get_tag() -> Tag:
        return Tag(f"{ModelType.SpandrelImageToImage.value}.{ModelFormat.Checkpoint.value}")


class LlavaOnevisionConfig(DiffusersConfigBase):
    """Model config for Llava Onevision models."""

    type: Literal[ModelType.LlavaOnevision] = ModelType.LlavaOnevision
    format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers

    @staticmethod
    def get_tag() -> Tag:
        return Tag(f"{ModelType.LlavaOnevision.value}.{ModelFormat.Diffusers.value}")


def get_model_discriminator_value(v: Any) -> str:
    """
    Computes the discriminator value for a model config.
@@ -575,6 +587,7 @@ def get_model_discriminator_value(v: Any) -> str:
        Annotated[CLIPEmbedDiffusersConfig, CLIPEmbedDiffusersConfig.get_tag()],
        Annotated[CLIPLEmbedDiffusersConfig, CLIPLEmbedDiffusersConfig.get_tag()],
        Annotated[CLIPGEmbedDiffusersConfig, CLIPGEmbedDiffusersConfig.get_tag()],
        Annotated[LlavaOnevisionConfig, LlavaOnevisionConfig.get_tag()],
    ],
    Discriminator(get_model_discriminator_value),
]
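
Under the tag scheme used by the other entries, the new config should discriminate on the string "llava_onevision.diffusers" (the ModelType value joined with the ModelFormat value). A quick sanity check, assuming pydantic's Tag exposes its value via the tag attribute:

from invokeai.backend.model_manager.config import LlavaOnevisionConfig

assert LlavaOnevisionConfig.get_tag().tag == "llava_onevision.diffusers"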
@@ -0,0 +1,32 @@
from pathlib import Path
from typing import Optional

from invokeai.backend.llava_onevision_model import LlavaOnevisionModel
from invokeai.backend.model_manager.config import (
    AnyModel,
    AnyModelConfig,
    BaseModelType,
    ModelFormat,
    ModelType,
    SubModelType,
)
from invokeai.backend.model_manager.load.load_default import ModelLoader
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry


@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.LlavaOnevision, format=ModelFormat.Diffusers)
class LlavaOnevisionModelLoader(ModelLoader):
    """Class for loading LLaVA Onevision VLLM models."""

    def _load_model(
        self,
        config: AnyModelConfig,
        submodel_type: Optional[SubModelType] = None,
    ) -> AnyModel:
        if submodel_type is not None:
            raise ValueError("Unexpected submodel requested for LLaVA OneVision model.")

        model_path = Path(config.path)
        model = LlavaOnevisionModel.load_from_path(model_path)
        model.to(dtype=self._torch_dtype)
        return model
13 changes: 13 additions & 0 deletions invokeai/backend/model_manager/probe.py
@@ -139,6 +139,7 @@ class ModelProbe(object):
        "FluxControlNetModel": ModelType.ControlNet,
        "SD3Transformer2DModel": ModelType.Main,
        "CLIPTextModelWithProjection": ModelType.CLIPEmbed,
        "LlavaOnevisionForConditionalGeneration": ModelType.LlavaOnevision,
    }

    TYPE2VARIANT: Dict[ModelType, Callable[[str], Optional[AnyVariant]]] = {ModelType.CLIPEmbed: get_clip_variant_type}
@@ -752,6 +753,11 @@ def get_base_type(self) -> BaseModelType:
        return BaseModelType.Any


class LlavaOnevisionCheckpointProbe(CheckpointProbeBase):
    def get_base_type(self) -> BaseModelType:
        raise NotImplementedError()


########################################################
# classes for probing folders
#######################################################
@@ -1022,6 +1028,11 @@ def get_base_type(self) -> BaseModelType:
        raise NotImplementedError()


class LlavaOnevisionFolderProbe(FolderProbeBase):
    def get_base_type(self) -> BaseModelType:
        return BaseModelType.Any


class T2IAdapterFolderProbe(FolderProbeBase):
    def get_base_type(self) -> BaseModelType:
        config_file = self.model_path / "config.json"
@@ -1055,6 +1066,7 @@ def get_base_type(self) -> BaseModelType:
ModelProbe.register_probe("diffusers", ModelType.CLIPVision, CLIPVisionFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.T2IAdapter, T2IAdapterFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.SpandrelImageToImage, SpandrelImageToImageFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.LlavaOnevision, LlavaOnevisionFolderProbe)

ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.VAE, VaeCheckpointProbe)
@@ -1066,5 +1078,6 @@ def get_base_type(self) -> BaseModelType:
ModelProbe.register_probe("checkpoint", ModelType.CLIPVision, CLIPVisionCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.T2IAdapter, T2IAdapterCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.SpandrelImageToImage, SpandrelImageToImageCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.LlavaOnevision, LlavaOnevisionCheckpointProbe)

ModelProbe.register_probe("onnx", ModelType.ONNX, ONNXFolderProbe)
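
Presumably the folder probe keys off the "architectures" entry of the model folder's config.json, as it does for the other transformers classes in the table above: "LlavaOnevisionForConditionalGeneration" maps to ModelType.LlavaOnevision via the class-name table, and the folder probe then reports BaseModelType.Any. A small check of that assumption against a downloaded copy of the starter model (the path is a placeholder):

import json
from pathlib import Path

config = json.loads((Path("/models/llava-onevision-qwen2-0.5b-ov-hf") / "config.json").read_text())
assert "LlavaOnevisionForConditionalGeneration" in config.get("architectures", [])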
11 changes: 10 additions & 1 deletion invokeai/backend/model_manager/starter_models.py
@@ -592,7 +592,15 @@ class StarterModelBundles(BaseModel):
)

# endregion

# region LlavaOnevisionModel
llava_onevision = StarterModel(
    name="LLaVA Onevision Qwen2 0.5B",
    base=BaseModelType.Any,
    source="llava-hf/llava-onevision-qwen2-0.5b-ov-hf",
    description="LLaVA Onevision VLLM model",
    type=ModelType.LlavaOnevision,
)
# endregion

# List of starter models, displayed on the frontend.
# The order/sort of this list is not changed by the frontend - set it how you want it here.
@@ -661,6 +669,7 @@ class StarterModelBundles(BaseModel):
    t5_base_encoder,
    t5_8b_quantized_encoder,
    clip_l_encoder,
    llava_onevision,
]

sd1_bundle: list[StarterModel] = [