diff --git a/invokeai/app/invocations/llava_onevision_vllm.py b/invokeai/app/invocations/llava_onevision_vllm.py
new file mode 100644
index 00000000000..fa103545030
--- /dev/null
+++ b/invokeai/app/invocations/llava_onevision_vllm.py
@@ -0,0 +1,52 @@
+import torch
+from PIL.Image import Image
+
+from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
+from invokeai.app.invocations.fields import ImageField, InputField, UIComponent
+from invokeai.app.invocations.primitives import StringOutput
+from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.backend.llava_onevision_model import LlavaOnevisionModel
+from invokeai.backend.model_manager.config import BaseModelType, ModelType
+from invokeai.backend.util.devices import TorchDevice
+
+
+@invocation("llava_onevision_vllm", title="LLaVA OneVision VLLM", tags=["vllm"], category="vllm", version="1.0.0")
+class LlavaOnevisionVllmInvocation(BaseInvocation):
+    """Run a LLaVA OneVision VLLM model."""
+
+    images: list[ImageField] | ImageField | None = InputField(default=None, description="Input image.")
+    prompt: str = InputField(
+        default="",
+        description="Input text prompt.",
+        ui_component=UIComponent.Textarea,
+    )
+    # vllm_model: ModelIdentifierField = InputField(
+    #     title="Image-to-Image Model",
+    #     description=FieldDescriptions.vllm_model,
+    #     ui_type=UIType.LlavaOnevisionModel,
+    # )
+
+    def _get_images(self, context: InvocationContext) -> list[Image]:
+        if self.images is None:
+            return []
+
+        image_fields = self.images if isinstance(self.images, list) else [self.images]
+        return [context.images.get_pil(image_field.image_name, "RGB") for image_field in image_fields]
+
+    @torch.no_grad()
+    def invoke(self, context: InvocationContext) -> StringOutput:
+        images = self._get_images(context)
+
+        # with context.models.load(self.vllm_model) as vllm_model:
+        with context.models.load_by_attrs(
+            name="LLaVA Onevision Qwen2 0.5B", base=BaseModelType.Any, type=ModelType.LlavaOnevision
+        ) as vllm_model:
+            assert isinstance(vllm_model, LlavaOnevisionModel)
+            output = vllm_model.run(
+                prompt=self.prompt,
+                images=images,
+                device=TorchDevice.choose_torch_device(),
+                dtype=TorchDevice.choose_torch_dtype(),
+            )
+
+        return StringOutput(value=output)
diff --git a/invokeai/backend/llava_onevision_model.py b/invokeai/backend/llava_onevision_model.py
new file mode 100644
index 00000000000..af6aee6467d
--- /dev/null
+++ b/invokeai/backend/llava_onevision_model.py
@@ -0,0 +1,49 @@
+from pathlib import Path
+from typing import Optional
+
+import torch
+from PIL.Image import Image
+from transformers import AutoProcessor, LlavaOnevisionForConditionalGeneration, LlavaOnevisionProcessor
+
+from invokeai.backend.raw_model import RawModel
+
+
+class LlavaOnevisionModel(RawModel):
+    def __init__(self, vllm_model: LlavaOnevisionForConditionalGeneration, processor: LlavaOnevisionProcessor):
+        self._vllm_model = vllm_model
+        self._processor = processor
+
+    @classmethod
+    def load_from_path(cls, path: str | Path):
+        vllm_model = LlavaOnevisionForConditionalGeneration.from_pretrained(path, local_files_only=True)
+        assert isinstance(vllm_model, LlavaOnevisionForConditionalGeneration)
+        processor = AutoProcessor.from_pretrained(path, local_files_only=True)
+        assert isinstance(processor, LlavaOnevisionProcessor)
+        return cls(vllm_model, processor)
+
+    def run(self, prompt: str, images: list[Image], device: torch.device, dtype: torch.dtype) -> str:
+        # TODO(ryand): Tune the max number of images that are useful for the model.
+        if len(images) > 3:
+            raise ValueError(
+                f"{len(images)} images were provided as input to the LLaVA OneVision model. "
+                "Pass <=3 images for good performance."
+            )
+
+        # Define a chat history and use `apply_chat_template` to get correctly formatted prompt.
+        # "content" is a list of dicts with types "text" or "image".
+        content = [{"type": "text", "text": prompt}]
+        # Add the correct number of images.
+        for _ in images:
+            content.append({"type": "image"})
+
+        conversation = [{"role": "user", "content": content}]
+        prompt = self._processor.apply_chat_template(conversation, add_generation_prompt=True)
+        inputs = self._processor(images=images or None, text=prompt, return_tensors="pt").to(device=device, dtype=dtype)
+        output = self._vllm_model.generate(**inputs, max_new_tokens=400, do_sample=False)
+        output_str: str = self._processor.decode(output[0][2:], skip_special_tokens=True)
+        # The output_str will include the prompt, so we extract the response.
+        response = output_str.split("assistant\n", 1)[1].strip()
+        return response
+
+    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
+        self._vllm_model.to(device=device, dtype=dtype)
diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py
index 956a8468fa2..1625c8c7e6f 100644
--- a/invokeai/backend/model_manager/config.py
+++ b/invokeai/backend/model_manager/config.py
@@ -76,6 +76,7 @@ class ModelType(str, Enum):
     T2IAdapter = "t2i_adapter"
     T5Encoder = "t5_encoder"
     SpandrelImageToImage = "spandrel_image_to_image"
+    LlavaOnevision = "llava_onevision"
 
 
 class SubModelType(str, Enum):
@@ -528,6 +529,17 @@ def get_tag() -> Tag:
         return Tag(f"{ModelType.SpandrelImageToImage.value}.{ModelFormat.Checkpoint.value}")
 
 
+class LlavaOnevisionConfig(DiffusersConfigBase):
+    """Model config for Llava Onevision models."""
+
+    type: Literal[ModelType.LlavaOnevision] = ModelType.LlavaOnevision
+    format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
+
+    @staticmethod
+    def get_tag() -> Tag:
+        return Tag(f"{ModelType.LlavaOnevision.value}.{ModelFormat.Diffusers.value}")
+
+
 def get_model_discriminator_value(v: Any) -> str:
     """
     Computes the discriminator value for a model config.
@@ -575,6 +587,7 @@ def get_model_discriminator_value(v: Any) -> str:
         Annotated[CLIPEmbedDiffusersConfig, CLIPEmbedDiffusersConfig.get_tag()],
         Annotated[CLIPLEmbedDiffusersConfig, CLIPLEmbedDiffusersConfig.get_tag()],
         Annotated[CLIPGEmbedDiffusersConfig, CLIPGEmbedDiffusersConfig.get_tag()],
+        Annotated[LlavaOnevisionConfig, LlavaOnevisionConfig.get_tag()],
     ],
     Discriminator(get_model_discriminator_value),
 ]
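Note (not part of the patch): `LlavaOnevisionConfig` joins the `AnyModelConfig` discriminated union under the tag `llava_onevision.diffusers`, built from the enum values added above. A minimal sketch of how that tag could be inspected in a dev environment with this change applied; the comments describe expected values, not captured output:

```python
# Hypothetical check of the new discriminator tag (illustrative only).
from invokeai.backend.model_manager.config import LlavaOnevisionConfig, ModelFormat, ModelType

# The tag pydantic's Discriminator matches for this config class.
print(LlavaOnevisionConfig.get_tag())  # Tag for "llava_onevision.diffusers"

# The same string, built directly from the enum values added in this diff.
print(f"{ModelType.LlavaOnevision.value}.{ModelFormat.Diffusers.value}")  # llava_onevision.diffusers
```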
diff --git a/invokeai/backend/model_manager/load/model_loaders/llava_onevision.py b/invokeai/backend/model_manager/load/model_loaders/llava_onevision.py
new file mode 100644
index 00000000000..32c98c6a351
--- /dev/null
+++ b/invokeai/backend/model_manager/load/model_loaders/llava_onevision.py
@@ -0,0 +1,32 @@
+from pathlib import Path
+from typing import Optional
+
+from invokeai.backend.llava_onevision_model import LlavaOnevisionModel
+from invokeai.backend.model_manager.config import (
+    AnyModel,
+    AnyModelConfig,
+    BaseModelType,
+    ModelFormat,
+    ModelType,
+    SubModelType,
+)
+from invokeai.backend.model_manager.load.load_default import ModelLoader
+from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
+
+
+@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.LlavaOnevision, format=ModelFormat.Diffusers)
+class LlavaOnevisionModelLoader(ModelLoader):
+    """Class for loading LLaVA Onevision VLLM models."""
+
+    def _load_model(
+        self,
+        config: AnyModelConfig,
+        submodel_type: Optional[SubModelType] = None,
+    ) -> AnyModel:
+        if submodel_type is not None:
+            raise ValueError("Unexpected submodel requested for LLaVA OneVision model.")
+
+        model_path = Path(config.path)
+        model = LlavaOnevisionModel.load_from_path(model_path)
+        model.to(dtype=self._torch_dtype)
+        return model
diff --git a/invokeai/backend/model_manager/probe.py b/invokeai/backend/model_manager/probe.py
index 82378d08e01..9ce9ad1ed22 100644
--- a/invokeai/backend/model_manager/probe.py
+++ b/invokeai/backend/model_manager/probe.py
@@ -139,6 +139,7 @@ class ModelProbe(object):
         "FluxControlNetModel": ModelType.ControlNet,
         "SD3Transformer2DModel": ModelType.Main,
         "CLIPTextModelWithProjection": ModelType.CLIPEmbed,
+        "LlavaOnevisionForConditionalGeneration": ModelType.LlavaOnevision,
     }
 
     TYPE2VARIANT: Dict[ModelType, Callable[[str], Optional[AnyVariant]]] = {ModelType.CLIPEmbed: get_clip_variant_type}
@@ -752,6 +753,11 @@ def get_base_type(self) -> BaseModelType:
         return BaseModelType.Any
 
 
+class LlavaOnevisionCheckpointProbe(CheckpointProbeBase):
+    def get_base_type(self) -> BaseModelType:
+        raise NotImplementedError()
+
+
 ########################################################
 # classes for probing folders
 #######################################################
@@ -1022,6 +1028,11 @@ def get_base_type(self) -> BaseModelType:
         raise NotImplementedError()
 
 
+class LlavaOnevisionFolderProbe(FolderProbeBase):
+    def get_base_type(self) -> BaseModelType:
+        return BaseModelType.Any
+
+
 class T2IAdapterFolderProbe(FolderProbeBase):
     def get_base_type(self) -> BaseModelType:
         config_file = self.model_path / "config.json"
@@ -1055,6 +1066,7 @@ def get_base_type(self) -> BaseModelType:
 ModelProbe.register_probe("diffusers", ModelType.CLIPVision, CLIPVisionFolderProbe)
 ModelProbe.register_probe("diffusers", ModelType.T2IAdapter, T2IAdapterFolderProbe)
 ModelProbe.register_probe("diffusers", ModelType.SpandrelImageToImage, SpandrelImageToImageFolderProbe)
+ModelProbe.register_probe("diffusers", ModelType.LlavaOnevision, LlavaOnevisionFolderProbe)
 
 ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe)
 ModelProbe.register_probe("checkpoint", ModelType.VAE, VaeCheckpointProbe)
@@ -1066,5 +1078,6 @@ def get_base_type(self) -> BaseModelType:
 ModelProbe.register_probe("checkpoint", ModelType.CLIPVision, CLIPVisionCheckpointProbe)
 ModelProbe.register_probe("checkpoint", ModelType.T2IAdapter, T2IAdapterCheckpointProbe)
 ModelProbe.register_probe("checkpoint", ModelType.SpandrelImageToImage, SpandrelImageToImageCheckpointProbe)
+ModelProbe.register_probe("checkpoint", ModelType.LlavaOnevision, LlavaOnevisionCheckpointProbe)
 
 ModelProbe.register_probe("onnx", ModelType.ONNX, ONNXFolderProbe)
diff --git a/invokeai/backend/model_manager/starter_models.py b/invokeai/backend/model_manager/starter_models.py
index 5674a0a6034..c2727df877f 100644
--- a/invokeai/backend/model_manager/starter_models.py
+++ b/invokeai/backend/model_manager/starter_models.py
@@ -592,7 +592,15 @@ class StarterModelBundles(BaseModel):
 )
 # endregion
 
-
+# region LlavaOnevisionModel
+llava_onevision = StarterModel(
+    name="LLaVA Onevision Qwen2 0.5B",
+    base=BaseModelType.Any,
+    source="llava-hf/llava-onevision-qwen2-0.5b-ov-hf",
+    description="LLaVA Onevision VLLM model",
+    type=ModelType.LlavaOnevision,
+)
+# endregion
 
 # List of starter models, displayed on the frontend.
 # The order/sort of this list is not changed by the frontend - set it how you want it here.
@@ -661,6 +669,7 @@ class StarterModelBundles(BaseModel):
     t5_base_encoder,
     t5_8b_quantized_encoder,
     clip_l_encoder,
+    llava_onevision,
 ]
 
 sd1_bundle: list[StarterModel] = [
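Note (not part of the patch): a minimal smoke-test sketch of the new backend wrapper, using only the APIs introduced above. The weights path, image file, and device are placeholders; the model directory is expected to hold a local copy of the `llava-hf/llava-onevision-qwen2-0.5b-ov-hf` starter-model weights, since `load_from_path` passes `local_files_only=True`.

```python
# Hypothetical manual test of LlavaOnevisionModel (paths/devices are placeholders).
from pathlib import Path

import torch
from PIL import Image

from invokeai.backend.llava_onevision_model import LlavaOnevisionModel

# Load processor + model from a local copy of the starter-model weights.
model = LlavaOnevisionModel.load_from_path(Path("/path/to/llava-onevision-qwen2-0.5b-ov-hf"))
model.to(device=torch.device("cuda"), dtype=torch.bfloat16)

# run() accepts at most 3 images; here we pass one.
image = Image.open("example.png").convert("RGB")
response = model.run(
    prompt="Describe this image.",
    images=[image],
    device=torch.device("cuda"),
    dtype=torch.bfloat16,
)
print(response)
```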