diff --git a/invokeai/app/api_app.py b/invokeai/app/api_app.py index 9cfebbdca17..fda232496e7 100644 --- a/invokeai/app/api_app.py +++ b/invokeai/app/api_app.py @@ -1,12 +1,8 @@ import asyncio import logging -import mimetypes -import socket from contextlib import asynccontextmanager from pathlib import Path -import torch -import uvicorn from fastapi import FastAPI, Request from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.gzip import GZipMiddleware @@ -15,11 +11,7 @@ from fastapi_events.handlers.local import local_handler from fastapi_events.middleware import EventHandlerASGIMiddleware from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint -from torch.backends.mps import is_available as is_mps_available -# for PyCharm: -# noinspection PyUnresolvedReferences -import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import) import invokeai.frontend.web as web_dir from invokeai.app.api.dependencies import ApiDependencies from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles @@ -36,39 +28,15 @@ workflows, ) from invokeai.app.api.sockets import SocketIO -from invokeai.app.invocations.load_custom_nodes import load_custom_nodes from invokeai.app.services.config.config_default import get_config from invokeai.app.util.custom_openapi import get_openapi_func -from invokeai.backend.util.devices import TorchDevice from invokeai.backend.util.logging import InvokeAILogger app_config = get_config() - - -if is_mps_available(): - import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import) - - logger = InvokeAILogger.get_logger(config=app_config) -# fix for windows mimetypes registry entries being borked -# see https://github.com/invoke-ai/InvokeAI/discussions/3684#discussioncomment-6391352 -mimetypes.add_type("application/javascript", ".js") -mimetypes.add_type("text/css", ".css") - -torch_device_name = TorchDevice.get_torch_device_name() -logger.info(f"Using torch device: {torch_device_name}") loop = asyncio.new_event_loop() -# We may change the port if the default is in use, this global variable is used to store the port so that we can log -# the correct port when the server starts in the lifespan handler. -port = app_config.port - -# Load custom nodes. This must be done after importing the Graph class, which itself imports all modules from the -# invocations module. The ordering here is implicit, but important - we want to load custom nodes after all the -# core nodes have been imported so that we can catch when a custom node clobbers a core node. -load_custom_nodes(custom_nodes_path=app_config.custom_nodes_path) - @asynccontextmanager async def lifespan(app: FastAPI): @@ -77,7 +45,7 @@ async def lifespan(app: FastAPI): # Log the server address when it starts - in case the network log level is not high enough to see the startup log proto = "https" if app_config.ssl_certfile else "http" - msg = f"Invoke running on {proto}://{app_config.host}:{port} (Press CTRL+C to quit)" + msg = f"Invoke running on {proto}://{app_config.host}:{app_config.port} (Press CTRL+C to quit)" # Logging this way ignores the logger's log level and _always_ logs the message record = logger.makeRecord( @@ -192,73 +160,3 @@ def overridden_redoc() -> HTMLResponse: app.mount( "/static", NoCacheStaticFiles(directory=Path(web_root_path, "static/")), name="static" ) # docs favicon is in here - - -def check_cudnn(logger: logging.Logger) -> None: - """Check for cuDNN issues that could be causing degraded performance.""" - if torch.backends.cudnn.is_available(): - try: - # Note: At the time of writing (torch 2.2.1), torch.backends.cudnn.version() only raises an error the first - # time it is called. Subsequent calls will return the version number without complaining about a mismatch. - cudnn_version = torch.backends.cudnn.version() - logger.info(f"cuDNN version: {cudnn_version}") - except RuntimeError as e: - logger.warning( - "Encountered a cuDNN version issue. This may result in degraded performance. This issue is usually " - "caused by an incompatible cuDNN version installed in your python environment, or on the host " - f"system. Full error message:\n{e}" - ) - - -def invoke_api() -> None: - def find_port(port: int) -> int: - """Find a port not in use starting at given port""" - # Taken from https://waylonwalker.com/python-find-available-port/, thanks Waylon! - # https://github.com/WaylonWalker - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.settimeout(1) - if s.connect_ex(("localhost", port)) == 0: - return find_port(port=port + 1) - else: - return port - - if app_config.dev_reload: - try: - import jurigged - except ImportError as e: - logger.error( - 'Can\'t start `--dev_reload` because jurigged is not found; `pip install -e ".[dev]"` to include development dependencies.', - exc_info=e, - ) - else: - jurigged.watch(logger=InvokeAILogger.get_logger(name="jurigged").info) - - global port - port = find_port(app_config.port) - if port != app_config.port: - logger.warn(f"Port {app_config.port} in use, using port {port}") - - check_cudnn(logger) - - config = uvicorn.Config( - app=app, - host=app_config.host, - port=port, - loop="asyncio", - log_level=app_config.log_level_network, - ssl_certfile=app_config.ssl_certfile, - ssl_keyfile=app_config.ssl_keyfile, - ) - server = uvicorn.Server(config) - - # replace uvicorn's loggers with InvokeAI's for consistent appearance - uvicorn_logger = InvokeAILogger.get_logger("uvicorn") - uvicorn_logger.handlers.clear() - for hdlr in logger.handlers: - uvicorn_logger.addHandler(hdlr) - - loop.run_until_complete(server.serve()) - - -if __name__ == "__main__": - invoke_api() diff --git a/invokeai/app/run_app.py b/invokeai/app/run_app.py index 701f1dab739..6eb64909927 100644 --- a/invokeai/app/run_app.py +++ b/invokeai/app/run_app.py @@ -1,12 +1,74 @@ -"""This is a wrapper around the main app entrypoint, to allow for CLI args to be parsed before running the app.""" +import uvicorn +from invokeai.app.invocations.load_custom_nodes import load_custom_nodes +from invokeai.app.services.config.config_default import get_config +from invokeai.app.util.startup_utils import ( + apply_monkeypatches, + check_cudnn, + enable_dev_reload, + find_open_port, + register_mime_types, +) +from invokeai.backend.util.logging import InvokeAILogger +from invokeai.frontend.cli.arg_parser import InvokeAIArgs -def run_app() -> None: - # Before doing _anything_, parse CLI args! - from invokeai.frontend.cli.arg_parser import InvokeAIArgs +def get_app(): + """Import the app and event loop. We wrap this in a function to more explicitly control when it happens, because + importing from api_app does a bunch of stuff - it's more like calling a function than importing a module. + """ + from invokeai.app.api_app import app, loop + + return app, loop + + +def run_app() -> None: + """The main entrypoint for the app.""" + # Parse the CLI arguments. InvokeAIArgs.parse_args() - from invokeai.app.api_app import invoke_api + # Load config. + app_config = get_config() + + logger = InvokeAILogger.get_logger(config=app_config) + + # Find an open port, and modify the config accordingly. + orig_config_port = app_config.port + app_config.port = find_open_port(app_config.port) + if orig_config_port != app_config.port: + logger.warning(f"Port {orig_config_port} is already in use. Using port {app_config.port}.") + + # Miscellaneous startup tasks. + apply_monkeypatches() + register_mime_types() + if app_config.dev_reload: + enable_dev_reload() + check_cudnn(logger) + + # Initialize the app and event loop. + app, loop = get_app() + + # Load custom nodes. This must be done after importing the Graph class, which itself imports all modules from the + # invocations module. The ordering here is implicit, but important - we want to load custom nodes after all the + # core nodes have been imported so that we can catch when a custom node clobbers a core node. + load_custom_nodes(custom_nodes_path=app_config.custom_nodes_path) + + # Start the server. + config = uvicorn.Config( + app=app, + host=app_config.host, + port=app_config.port, + loop="asyncio", + log_level=app_config.log_level_network, + ssl_certfile=app_config.ssl_certfile, + ssl_keyfile=app_config.ssl_keyfile, + ) + server = uvicorn.Server(config) + + # replace uvicorn's loggers with InvokeAI's for consistent appearance + uvicorn_logger = InvokeAILogger.get_logger("uvicorn") + uvicorn_logger.handlers.clear() + for hdlr in logger.handlers: + uvicorn_logger.addHandler(hdlr) - invoke_api() + loop.run_until_complete(server.serve()) diff --git a/invokeai/app/util/startup_utils.py b/invokeai/app/util/startup_utils.py new file mode 100644 index 00000000000..726d40a7a65 --- /dev/null +++ b/invokeai/app/util/startup_utils.py @@ -0,0 +1,64 @@ +import logging +import mimetypes +import socket + +import torch + + +def find_open_port(port: int) -> int: + """Find a port not in use starting at given port""" + # Taken from https://waylonwalker.com/python-find-available-port/, thanks Waylon! + # https://github.com/WaylonWalker + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.settimeout(1) + if s.connect_ex(("localhost", port)) == 0: + return find_open_port(port=port + 1) + else: + return port + + +def check_cudnn(logger: logging.Logger) -> None: + """Check for cuDNN issues that could be causing degraded performance.""" + if torch.backends.cudnn.is_available(): + try: + # Note: At the time of writing (torch 2.2.1), torch.backends.cudnn.version() only raises an error the first + # time it is called. Subsequent calls will return the version number without complaining about a mismatch. + cudnn_version = torch.backends.cudnn.version() + logger.info(f"cuDNN version: {cudnn_version}") + except RuntimeError as e: + logger.warning( + "Encountered a cuDNN version issue. This may result in degraded performance. This issue is usually " + "caused by an incompatible cuDNN version installed in your python environment, or on the host " + f"system. Full error message:\n{e}" + ) + + +def enable_dev_reload() -> None: + """Enable hot reloading on python file changes during development.""" + from invokeai.backend.util.logging import InvokeAILogger + + try: + import jurigged + except ImportError as e: + raise RuntimeError( + 'Can\'t start `--dev_reload` because jurigged is not found; `pip install -e ".[dev]"` to include development dependencies.' + ) from e + else: + jurigged.watch(logger=InvokeAILogger.get_logger(name="jurigged").info) + + +def apply_monkeypatches() -> None: + """Apply monkeypatches to fix issues with third-party libraries.""" + + import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import) + + if torch.backends.mps.is_available(): + import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import) + + +def register_mime_types() -> None: + """Register additional mime types for windows.""" + # Fix for windows mimetypes registry entries being borked. + # see https://github.com/invoke-ai/InvokeAI/discussions/3684#discussioncomment-6391352 + mimetypes.add_type("application/javascript", ".js") + mimetypes.add_type("text/css", ".css")