diff --git a/launch.py b/launch.py index 0894c55..256130e 100644 --- a/launch.py +++ b/launch.py @@ -168,11 +168,11 @@ def prepare_environment(): torch_command = os.environ.get( "TORCH_COMMAND", - "pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 --extra-index-url https://download.pytorch.org/whl/cu117", + "pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 --extra-index-url https://download.pytorch.org/whl/cu118", ) xformers_command = os.environ.get( "XFORMERS_COMMAND", - "pip install xformers==0.0.16", + "pip install xformers==0.0.20", ) sys.argv, skip_install = extract_arg(sys.argv, "--skip-install") @@ -181,7 +181,6 @@ def prepare_environment(): sys.argv, reinstall_torch = extract_arg(sys.argv, "--reinstall-torch") sys.argv, reinstall_xformers = extract_arg(sys.argv, "--reinstall-xformers") - sys.argv, reinstall_tensorrt = extract_arg(sys.argv, "--reinstall-tensorrt") tensorrt = "--tensorrt" in sys.argv if reinstall_torch or not is_installed("torch") or not is_installed("torchvision"): diff --git a/modules/acceleration/tensorrt/engine.py b/modules/acceleration/tensorrt/engine.py index 8984e94..727cc58 100644 --- a/modules/acceleration/tensorrt/engine.py +++ b/modules/acceleration/tensorrt/engine.py @@ -3,6 +3,7 @@ import tensorrt import torch +from diffusers.models.attention_processor import AttnProcessor from api.models.tensorrt import BuildEngineOptions, TensorRTEngineData from lib.tensorrt.utilities import ( @@ -64,6 +65,7 @@ def build(self): if model_name == "unet": model = load_unet(self.model.model_id, device=self.device) model = model.to(dtype=torch.float16) + model.set_attn_processor(AttnProcessor()) elif model_name == "clip": model = load_text_encoder(self.model.model_id, device=self.device) elif model_name == "vae": diff --git a/modules/diffusion/utils.py b/modules/diffusion/utils.py index ad6bf91..e683e3e 100644 --- a/modules/diffusion/utils.py +++ b/modules/diffusion/utils.py @@ -20,7 +20,9 @@ def convert_checkpoint_to_pipe(model_id: str): ) -def load_unet(model_id: str, device: Optional[torch.device] = None): +def load_unet( + model_id: str, device: Optional[torch.device] = None +) -> UNet2DConditionModel: temporary_pipe = convert_checkpoint_to_pipe(model_id) if temporary_pipe is not None: unet = temporary_pipe.unet diff --git a/modules/model.py b/modules/model.py index 31c7aaf..5d502ac 100644 --- a/modules/model.py +++ b/modules/model.py @@ -57,10 +57,7 @@ def trt_available(self): filepath = os.path.join(trt_path, *file.split("/")) if not os.path.exists(filepath): return False - trt_module_status, trt_version_status = utils.tensorrt_is_available() - if not trt_module_status or not trt_version_status: - return False - return config.get("tensorrt") + return utils.tensorrt_is_available() and config.get("tensorrt") def trt_full_acceleration_available(self): trt_path = self.get_trt_path() diff --git a/modules/tabs/generate.py b/modules/tabs/generate.py index d544e2f..0cd6166 100644 --- a/modules/tabs/generate.py +++ b/modules/tabs/generate.py @@ -1,3 +1,4 @@ +import time from typing import * import gradio as gr @@ -93,6 +94,8 @@ def generate_image(self, opts, plugin_values): else: inference_steps = opts.num_inference_steps + start = time.perf_counter() + for data in model_manager.sd_model(opts, plugin_data): if type(data) == tuple: step, preview = data @@ -112,11 +115,13 @@ def generate_image(self, opts, plugin_values): else: image = data + end = time.perf_counter() + results = [] for images, opts in image: results.extend(images) - yield results, "Finished", gr.Button.update( + yield results, f"Finished in {end - start:0.4f} seconds", gr.Button.update( value="Generate", variant="primary", interactive=True ) diff --git a/modules/tabs/tensorrt.py b/modules/tabs/tensorrt.py index 593677e..8d50df6 100644 --- a/modules/tabs/tensorrt.py +++ b/modules/tabs/tensorrt.py @@ -15,8 +15,7 @@ def sort(self): return 2 def visible(self): - module, version = tensorrt_is_available() - return module and version and config.get("tensorrt") + return tensorrt_is_available() and config.get("tensorrt") def ui(self, outlet): with gr.Column(): diff --git a/modules/utils.py b/modules/utils.py index fbf5a9e..44dcec9 100644 --- a/modules/utils.py +++ b/modules/utils.py @@ -2,7 +2,6 @@ import base64 import importlib import io -from distutils.version import LooseVersion from typing import * import numpy as np @@ -12,11 +11,8 @@ from lib.diffusers.scheduler import SCHEDULERS from . import config -from .logger import logger from .shared import hf_diffusers_cache_dir -logged_trt_warning = False - def img2b64(img: Image.Image, format="png"): buf = io.BytesIO() @@ -56,19 +52,7 @@ def is_installed(package: str): def tensorrt_is_available(): - global logged_trt_warning - tensorrt = is_installed("tensorrt") - version = LooseVersion("2") > LooseVersion(torch.__version__) or LooseVersion( - "2.1" - ) <= LooseVersion(torch.__version__) - - if not tensorrt or not version: - if not logged_trt_warning and tensorrt and config.get("tensorrt"): - logger.warning( - "TensorRT is available, but torch version is not compatible." - ) - logged_trt_warning = True - return tensorrt, version + return is_installed("tensorrt") def fire_and_forget(f):