Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tidy app entrypoint #7668

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
104 changes: 1 addition & 103 deletions invokeai/app/api_app.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
import asyncio
import logging
import mimetypes
import socket
from contextlib import asynccontextmanager
from pathlib import Path

import torch
import uvicorn
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
Expand All @@ -15,11 +11,7 @@
from fastapi_events.handlers.local import local_handler
from fastapi_events.middleware import EventHandlerASGIMiddleware
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from torch.backends.mps import is_available as is_mps_available

# for PyCharm:
# noinspection PyUnresolvedReferences
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
import invokeai.frontend.web as web_dir
from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
Expand All @@ -36,39 +28,15 @@
workflows,
)
from invokeai.app.api.sockets import SocketIO
from invokeai.app.invocations.load_custom_nodes import load_custom_nodes
from invokeai.app.services.config.config_default import get_config
from invokeai.app.util.custom_openapi import get_openapi_func
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.logging import InvokeAILogger

app_config = get_config()


if is_mps_available():
import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import)


logger = InvokeAILogger.get_logger(config=app_config)
# fix for windows mimetypes registry entries being borked
# see https://github.com/invoke-ai/InvokeAI/discussions/3684#discussioncomment-6391352
mimetypes.add_type("application/javascript", ".js")
mimetypes.add_type("text/css", ".css")

torch_device_name = TorchDevice.get_torch_device_name()
logger.info(f"Using torch device: {torch_device_name}")

loop = asyncio.new_event_loop()

# We may change the port if the default is in use, this global variable is used to store the port so that we can log
# the correct port when the server starts in the lifespan handler.
port = app_config.port

# Load custom nodes. This must be done after importing the Graph class, which itself imports all modules from the
# invocations module. The ordering here is implicit, but important - we want to load custom nodes after all the
# core nodes have been imported so that we can catch when a custom node clobbers a core node.
load_custom_nodes(custom_nodes_path=app_config.custom_nodes_path)


@asynccontextmanager
async def lifespan(app: FastAPI):
Expand All @@ -77,7 +45,7 @@ async def lifespan(app: FastAPI):

# Log the server address when it starts - in case the network log level is not high enough to see the startup log
proto = "https" if app_config.ssl_certfile else "http"
msg = f"Invoke running on {proto}://{app_config.host}:{port} (Press CTRL+C to quit)"
msg = f"Invoke running on {proto}://{app_config.host}:{app_config.port} (Press CTRL+C to quit)"

# Logging this way ignores the logger's log level and _always_ logs the message
record = logger.makeRecord(
Expand Down Expand Up @@ -192,73 +160,3 @@ def overridden_redoc() -> HTMLResponse:
app.mount(
"/static", NoCacheStaticFiles(directory=Path(web_root_path, "static/")), name="static"
) # docs favicon is in here


def check_cudnn(logger: logging.Logger) -> None:
"""Check for cuDNN issues that could be causing degraded performance."""
if torch.backends.cudnn.is_available():
try:
# Note: At the time of writing (torch 2.2.1), torch.backends.cudnn.version() only raises an error the first
# time it is called. Subsequent calls will return the version number without complaining about a mismatch.
cudnn_version = torch.backends.cudnn.version()
logger.info(f"cuDNN version: {cudnn_version}")
except RuntimeError as e:
logger.warning(
"Encountered a cuDNN version issue. This may result in degraded performance. This issue is usually "
"caused by an incompatible cuDNN version installed in your python environment, or on the host "
f"system. Full error message:\n{e}"
)


def invoke_api() -> None:
def find_port(port: int) -> int:
"""Find a port not in use starting at given port"""
# Taken from https://waylonwalker.com/python-find-available-port/, thanks Waylon!
# https://github.com/WaylonWalker
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.settimeout(1)
if s.connect_ex(("localhost", port)) == 0:
return find_port(port=port + 1)
else:
return port

if app_config.dev_reload:
try:
import jurigged
except ImportError as e:
logger.error(
'Can\'t start `--dev_reload` because jurigged is not found; `pip install -e ".[dev]"` to include development dependencies.',
exc_info=e,
)
else:
jurigged.watch(logger=InvokeAILogger.get_logger(name="jurigged").info)

global port
port = find_port(app_config.port)
if port != app_config.port:
logger.warn(f"Port {app_config.port} in use, using port {port}")

check_cudnn(logger)

config = uvicorn.Config(
app=app,
host=app_config.host,
port=port,
loop="asyncio",
log_level=app_config.log_level_network,
ssl_certfile=app_config.ssl_certfile,
ssl_keyfile=app_config.ssl_keyfile,
)
server = uvicorn.Server(config)

# replace uvicorn's loggers with InvokeAI's for consistent appearance
uvicorn_logger = InvokeAILogger.get_logger("uvicorn")
uvicorn_logger.handlers.clear()
for hdlr in logger.handlers:
uvicorn_logger.addHandler(hdlr)

loop.run_until_complete(server.serve())


if __name__ == "__main__":
invoke_api()
74 changes: 68 additions & 6 deletions invokeai/app/run_app.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,74 @@
"""This is a wrapper around the main app entrypoint, to allow for CLI args to be parsed before running the app."""
import uvicorn

from invokeai.app.invocations.load_custom_nodes import load_custom_nodes
from invokeai.app.services.config.config_default import get_config
from invokeai.app.util.startup_utils import (
apply_monkeypatches,
check_cudnn,
enable_dev_reload,
find_open_port,
register_mime_types,
)
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.frontend.cli.arg_parser import InvokeAIArgs

def run_app() -> None:
# Before doing _anything_, parse CLI args!
from invokeai.frontend.cli.arg_parser import InvokeAIArgs

def get_app():
"""Import the app and event loop. We wrap this in a function to more explicitly control when it happens, because
importing from api_app does a bunch of stuff - it's more like calling a function than importing a module.
"""
from invokeai.app.api_app import app, loop

return app, loop


def run_app() -> None:
"""The main entrypoint for the app."""
# Parse the CLI arguments.
InvokeAIArgs.parse_args()

from invokeai.app.api_app import invoke_api
# Load config.
app_config = get_config()

logger = InvokeAILogger.get_logger(config=app_config)

# Find an open port, and modify the config accordingly.
orig_config_port = app_config.port
app_config.port = find_open_port(app_config.port)
if orig_config_port != app_config.port:
logger.warning(f"Port {orig_config_port} is already in use. Using port {app_config.port}.")

# Miscellaneous startup tasks.
apply_monkeypatches()
register_mime_types()
if app_config.dev_reload:
enable_dev_reload()
check_cudnn(logger)

# Initialize the app and event loop.
app, loop = get_app()

# Load custom nodes. This must be done after importing the Graph class, which itself imports all modules from the
# invocations module. The ordering here is implicit, but important - we want to load custom nodes after all the
# core nodes have been imported so that we can catch when a custom node clobbers a core node.
load_custom_nodes(custom_nodes_path=app_config.custom_nodes_path)

# Start the server.
config = uvicorn.Config(
app=app,
host=app_config.host,
port=app_config.port,
loop="asyncio",
log_level=app_config.log_level_network,
ssl_certfile=app_config.ssl_certfile,
ssl_keyfile=app_config.ssl_keyfile,
)
server = uvicorn.Server(config)

# replace uvicorn's loggers with InvokeAI's for consistent appearance
uvicorn_logger = InvokeAILogger.get_logger("uvicorn")
uvicorn_logger.handlers.clear()
for hdlr in logger.handlers:
uvicorn_logger.addHandler(hdlr)

invoke_api()
loop.run_until_complete(server.serve())
64 changes: 64 additions & 0 deletions invokeai/app/util/startup_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import logging
import mimetypes
import socket

import torch


def find_open_port(port: int) -> int:
"""Find a port not in use starting at given port"""
# Taken from https://waylonwalker.com/python-find-available-port/, thanks Waylon!
# https://github.com/WaylonWalker
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.settimeout(1)
if s.connect_ex(("localhost", port)) == 0:
return find_open_port(port=port + 1)
else:
return port


def check_cudnn(logger: logging.Logger) -> None:
"""Check for cuDNN issues that could be causing degraded performance."""
if torch.backends.cudnn.is_available():
try:
# Note: At the time of writing (torch 2.2.1), torch.backends.cudnn.version() only raises an error the first
# time it is called. Subsequent calls will return the version number without complaining about a mismatch.
cudnn_version = torch.backends.cudnn.version()
logger.info(f"cuDNN version: {cudnn_version}")
except RuntimeError as e:
logger.warning(
"Encountered a cuDNN version issue. This may result in degraded performance. This issue is usually "
"caused by an incompatible cuDNN version installed in your python environment, or on the host "
f"system. Full error message:\n{e}"
)


def enable_dev_reload() -> None:
"""Enable hot reloading on python file changes during development."""
from invokeai.backend.util.logging import InvokeAILogger

try:
import jurigged
except ImportError as e:
raise RuntimeError(
'Can\'t start `--dev_reload` because jurigged is not found; `pip install -e ".[dev]"` to include development dependencies.'
) from e
else:
jurigged.watch(logger=InvokeAILogger.get_logger(name="jurigged").info)


def apply_monkeypatches() -> None:
"""Apply monkeypatches to fix issues with third-party libraries."""

import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)

if torch.backends.mps.is_available():
import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import)


def register_mime_types() -> None:
"""Register additional mime types for windows."""
# Fix for windows mimetypes registry entries being borked.
# see https://github.com/invoke-ai/InvokeAI/discussions/3684#discussioncomment-6391352
mimetypes.add_type("application/javascript", ".js")
mimetypes.add_type("text/css", ".css")