Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Maryhipp/workflow thumbnails #7676

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions invokeai/app/api/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from invokeai.app.services.style_preset_records.style_preset_records_sqlite import SqliteStylePresetRecordsStorage
from invokeai.app.services.urls.urls_default import LocalUrlService
from invokeai.app.services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage
from invokeai.app.services.workflow_thumbnails.workflow_thumbnails_disk import WorkflowThumbnailFileStorageDisk
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.version.invokeai_version import __version__
Expand Down Expand Up @@ -83,6 +84,7 @@ def initialize(

model_images_folder = config.models_path
style_presets_folder = config.style_presets_path
workflow_thumbnails_folder = config.workflow_thumbnails_path

db = init_db(config=config, logger=logger, image_files=image_files)

Expand Down Expand Up @@ -120,6 +122,7 @@ def initialize(
workflow_records = SqliteWorkflowRecordsStorage(db=db)
style_preset_records = SqliteStylePresetRecordsStorage(db=db)
style_preset_image_files = StylePresetImageFileStorageDisk(style_presets_folder / "images")
workflow_thumbnails = WorkflowThumbnailFileStorageDisk(workflow_thumbnails_folder)

services = InvocationServices(
board_image_records=board_image_records,
Expand Down Expand Up @@ -147,6 +150,7 @@ def initialize(
conditioning=conditioning,
style_preset_records=style_preset_records,
style_preset_image_files=style_preset_image_files,
workflow_thumbnails=workflow_thumbnails,
)

ApiDependencies.invoker = Invoker(services)
Expand Down
128 changes: 120 additions & 8 deletions invokeai/app/api/routers/workflows.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import io
import traceback
from typing import Optional

from fastapi import APIRouter, Body, HTTPException, Path, Query
from fastapi import APIRouter, Body, File, HTTPException, Path, Query, UploadFile
from fastapi.responses import FileResponse
from PIL import Image

from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.services.shared.pagination import PaginatedResults
Expand All @@ -10,27 +14,31 @@
WorkflowCategory,
WorkflowNotFoundError,
WorkflowRecordDTO,
WorkflowRecordListItemDTO,
WorkflowRecordListItemWithThumbnailDTO,
WorkflowRecordOrderBy,
WorkflowRecordWithThumbnailDTO,
WorkflowWithoutID,
)

IMAGE_MAX_AGE = 31536000
workflows_router = APIRouter(prefix="/v1/workflows", tags=["workflows"])


@workflows_router.get(
"/i/{workflow_id}",
operation_id="get_workflow",
responses={
200: {"model": WorkflowRecordDTO},
200: {"model": WorkflowRecordWithThumbnailDTO},
},
)
async def get_workflow(
workflow_id: str = Path(description="The workflow to get"),
) -> WorkflowRecordDTO:
) -> WorkflowRecordWithThumbnailDTO:
"""Gets a workflow"""
try:
return ApiDependencies.invoker.services.workflow_records.get(workflow_id)
thumbnail_url = ApiDependencies.invoker.services.workflow_thumbnails.get_url(workflow_id)
workflow = ApiDependencies.invoker.services.workflow_records.get(workflow_id)
return WorkflowRecordWithThumbnailDTO(thumbnail_url=thumbnail_url, **workflow.model_dump())
except WorkflowNotFoundError:
raise HTTPException(status_code=404, detail="Workflow not found")

Expand All @@ -57,6 +65,7 @@ async def delete_workflow(
workflow_id: str = Path(description="The workflow to delete"),
) -> None:
"""Deletes a workflow"""
ApiDependencies.invoker.services.workflow_thumbnails.delete(workflow_id)
ApiDependencies.invoker.services.workflow_records.delete(workflow_id)


Expand All @@ -78,7 +87,7 @@ async def create_workflow(
"/",
operation_id="list_workflows",
responses={
200: {"model": PaginatedResults[WorkflowRecordListItemDTO]},
200: {"model": PaginatedResults[WorkflowRecordListItemWithThumbnailDTO]},
},
)
async def list_workflows(
Expand All @@ -90,8 +99,111 @@ async def list_workflows(
direction: SQLiteDirection = Query(default=SQLiteDirection.Ascending, description="The direction to order by"),
category: WorkflowCategory = Query(default=WorkflowCategory.User, description="The category of workflow to get"),
query: Optional[str] = Query(default=None, description="The text to query by (matches name and description)"),
) -> PaginatedResults[WorkflowRecordListItemDTO]:
) -> PaginatedResults[WorkflowRecordListItemWithThumbnailDTO]:
"""Gets a page of workflows"""
return ApiDependencies.invoker.services.workflow_records.get_many(
workflows_with_thumbnails: list[WorkflowRecordListItemWithThumbnailDTO] = []
workflows = ApiDependencies.invoker.services.workflow_records.get_many(
order_by=order_by, direction=direction, page=page, per_page=per_page, query=query, category=category
)
for workflow in workflows.items:
workflows_with_thumbnails.append(
WorkflowRecordListItemWithThumbnailDTO(
thumbnail_url=ApiDependencies.invoker.services.workflow_thumbnails.get_url(workflow.workflow_id),
**workflow.model_dump(),
)
)
return PaginatedResults[WorkflowRecordListItemWithThumbnailDTO](
items=workflows_with_thumbnails,
total=workflows.total,
page=workflows.page,
pages=workflows.pages,
per_page=workflows.per_page,
)


@workflows_router.put(
"/i/{workflow_id}/thumbnail",
operation_id="set_workflow_thumbnail",
responses={
200: {"model": WorkflowRecordDTO},
},
)
async def set_workflow_thumbnail(
workflow_id: str = Path(description="The workflow to update"),
image: UploadFile = File(description="The image file to upload"),
):
"""Sets a workflow's thumbnail image"""
try:
ApiDependencies.invoker.services.workflow_records.get(workflow_id)
except WorkflowNotFoundError:
raise HTTPException(status_code=404, detail="Workflow not found")

if not image.content_type or not image.content_type.startswith("image"):
raise HTTPException(status_code=415, detail="Not an image")

contents = await image.read()
try:
pil_image = Image.open(io.BytesIO(contents))

except Exception:
ApiDependencies.invoker.services.logger.error(traceback.format_exc())
raise HTTPException(status_code=415, detail="Failed to read image")

try:
ApiDependencies.invoker.services.workflow_thumbnails.save(workflow_id, pil_image)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))


@workflows_router.delete(
"/i/{workflow_id}/thumbnail",
operation_id="delete_workflow_thumbnail",
responses={
200: {"model": WorkflowRecordDTO},
},
)
async def delete_workflow_thumbnail(
workflow_id: str = Path(description="The workflow to update"),
):
"""Removes a workflow's thumbnail image"""
try:
ApiDependencies.invoker.services.workflow_records.get(workflow_id)
except WorkflowNotFoundError:
raise HTTPException(status_code=404, detail="Workflow not found")

try:
ApiDependencies.invoker.services.workflow_thumbnails.delete(workflow_id)
except ValueError as e:
raise HTTPException(status_code=500, detail=str(e))


@workflows_router.get(
"/i/{workflow_id}/thumbnail",
operation_id="get_workflow_thumbnail",
responses={
200: {
"description": "The workflow thumbnail was fetched successfully",
},
400: {"description": "Bad request"},
404: {"description": "The workflow thumbnail could not be found"},
},
status_code=200,
)
async def get_workflow_thumbnail(
workflow_id: str = Path(description="The id of the workflow thumbnail to get"),
) -> FileResponse:
"""Gets a workflow's thumbnail image"""

try:
path = ApiDependencies.invoker.services.workflow_thumbnails.get_path(workflow_id)

response = FileResponse(
path,
media_type="image/png",
filename=workflow_id + ".png",
content_disposition_type="inline",
)
response.headers["Cache-Control"] = f"max-age={IMAGE_MAX_AGE}"
return response
except Exception:
raise HTTPException(status_code=404)
7 changes: 7 additions & 0 deletions invokeai/app/services/config/config_default.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ class InvokeAIAppConfig(BaseSettings):
outputs_dir: Path to directory for outputs.
custom_nodes_dir: Path to directory for custom nodes.
style_presets_dir: Path to directory for style presets.
workflow_thumbnails_dir: Path to directory for workflow thumbnails.
log_handlers: Log handler. Valid options are "console", "file=<path>", "syslog=path|address:host:port", "http=<url>".
log_format: Log format. Use "plain" for text-only, "color" for colorized output, "legacy" for 2.3-style logging and "syslog" for syslog-style.<br>Valid values: `plain`, `color`, `syslog`, `legacy`
log_level: Emit logging messages at this level or higher.<br>Valid values: `debug`, `info`, `warning`, `error`, `critical`
Expand Down Expand Up @@ -141,6 +142,7 @@ class InvokeAIAppConfig(BaseSettings):
outputs_dir: Path = Field(default=Path("outputs"), description="Path to directory for outputs.")
custom_nodes_dir: Path = Field(default=Path("nodes"), description="Path to directory for custom nodes.")
style_presets_dir: Path = Field(default=Path("style_presets"), description="Path to directory for style presets.")
workflow_thumbnails_dir: Path = Field(default=Path("workflow_thumbnails"), description="Path to directory for workflow thumbnails.")

# LOGGING
log_handlers: list[str] = Field(default=["console"], description='Log handler. Valid options are "console", "file=<path>", "syslog=path|address:host:port", "http=<url>".')
Expand Down Expand Up @@ -300,6 +302,11 @@ def style_presets_path(self) -> Path:
"""Path to the style presets directory, resolved to an absolute path.."""
return self._resolve(self.style_presets_dir)

@property
def workflow_thumbnails_path(self) -> Path:
"""Path to the workflow thumbnails directory, resolved to an absolute path.."""
return self._resolve(self.workflow_thumbnails_dir)

@property
def convert_cache_path(self) -> Path:
"""Path to the converted cache models directory, resolved to an absolute path.."""
Expand Down
3 changes: 3 additions & 0 deletions invokeai/app/services/invocation_services.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from invokeai.app.services.session_queue.session_queue_base import SessionQueueBase
from invokeai.app.services.urls.urls_base import UrlServiceBase
from invokeai.app.services.workflow_records.workflow_records_base import WorkflowRecordsStorageBase
from invokeai.app.services.workflow_thumbnails.workflow_thumbnails_common import WorkflowThumbnailServiceBase
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData


Expand Down Expand Up @@ -65,6 +66,7 @@ def __init__(
conditioning: "ObjectSerializerBase[ConditioningFieldData]",
style_preset_records: "StylePresetRecordsStorageBase",
style_preset_image_files: "StylePresetImageFileStorageBase",
workflow_thumbnails: "WorkflowThumbnailServiceBase",
):
self.board_images = board_images
self.board_image_records = board_image_records
Expand All @@ -91,3 +93,4 @@ def __init__(
self.conditioning = conditioning
self.style_preset_records = style_preset_records
self.style_preset_image_files = style_preset_image_files
self.workflow_thumbnails = workflow_thumbnails
5 changes: 5 additions & 0 deletions invokeai/app/services/urls/urls_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,8 @@ def get_model_image_url(self, model_key: str) -> str:
def get_style_preset_image_url(self, style_preset_id: str) -> str:
"""Gets the URL for a style preset image"""
pass

@abstractmethod
def get_workflow_thumbnail_url(self, workflow_id: str) -> str:
"""Gets the URL for a workflow thumbnail"""
pass
3 changes: 3 additions & 0 deletions invokeai/app/services/urls/urls_default.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,6 @@ def get_model_image_url(self, model_key: str) -> str:

def get_style_preset_image_url(self, style_preset_id: str) -> str:
return f"{self._base_url}/style_presets/i/{style_preset_id}/image"

def get_workflow_thumbnail_url(self, workflow_id: str) -> str:
return f"{self._base_url}/workflows/i/{workflow_id}/thumbnail"
Original file line number Diff line number Diff line change
Expand Up @@ -121,3 +121,11 @@ class WorkflowRecordListItemDTO(WorkflowRecordDTOBase):


WorkflowRecordListItemDTOValidator = TypeAdapter(WorkflowRecordListItemDTO)


class WorkflowRecordWithThumbnailDTO(WorkflowRecordDTO):
thumbnail_url: str | None = Field(default=None, description="The URL of the workflow thumbnail.")


class WorkflowRecordListItemWithThumbnailDTO(WorkflowRecordListItemDTO):
thumbnail_url: str | None = Field(default=None, description="The URL of the workflow thumbnail.")
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from pathlib import Path

from PIL import Image


class WorkflowThumbnailFileNotFoundException(Exception):
"""Raised when a workflow thumbnail file is not found"""

def __init__(self, message: str = "Workflow thumbnail file not found"):
self.message = message
super().__init__(self.message)


class WorkflowThumbnailFileSaveException(Exception):
"""Raised when a workflow thumbnail file cannot be saved"""

def __init__(self, message: str = "Workflow thumbnail file cannot be saved"):
self.message = message
super().__init__(self.message)


class WorkflowThumbnailFileDeleteException(Exception):
"""Raised when a workflow thumbnail file cannot be deleted"""

def __init__(self, message: str = "Workflow thumbnail file cannot be deleted"):
self.message = message
super().__init__(self.message)


class WorkflowThumbnailServiceBase:
"""Base class for workflow thumbnail services"""

def get_path(self, workflow_id: str) -> Path:
"""Gets the path to a workflow thumbnail"""
raise NotImplementedError

def get_url(self, workflow_id: str) -> str | None:
"""Gets the URL of a workflow thumbnail"""
raise NotImplementedError

def save(self, workflow_id: str, image: Image.Image) -> None:
"""Saves a workflow thumbnail"""
raise NotImplementedError

def delete(self, workflow_id: str) -> None:
"""Deletes a workflow thumbnail"""
raise NotImplementedError
Loading