Skip to content

Commit

Permalink
Update torch version
Browse files Browse the repository at this point in the history
  • Loading branch information
ddPn08 committed May 28, 2023
1 parent 0e38a1e commit 03b4363
Show file tree
Hide file tree
Showing 7 changed files with 16 additions and 28 deletions.
5 changes: 2 additions & 3 deletions launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,11 +168,11 @@ def prepare_environment():

torch_command = os.environ.get(
"TORCH_COMMAND",
"pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 --extra-index-url https://download.pytorch.org/whl/cu117",
"pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 --extra-index-url https://download.pytorch.org/whl/cu118",
)
xformers_command = os.environ.get(
"XFORMERS_COMMAND",
"pip install xformers==0.0.16",
"pip install xformers==0.0.20",
)

sys.argv, skip_install = extract_arg(sys.argv, "--skip-install")
Expand All @@ -181,7 +181,6 @@ def prepare_environment():

sys.argv, reinstall_torch = extract_arg(sys.argv, "--reinstall-torch")
sys.argv, reinstall_xformers = extract_arg(sys.argv, "--reinstall-xformers")
sys.argv, reinstall_tensorrt = extract_arg(sys.argv, "--reinstall-tensorrt")
tensorrt = "--tensorrt" in sys.argv

if reinstall_torch or not is_installed("torch") or not is_installed("torchvision"):
Expand Down
2 changes: 2 additions & 0 deletions modules/acceleration/tensorrt/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import tensorrt
import torch
from diffusers.models.attention_processor import AttnProcessor

from api.models.tensorrt import BuildEngineOptions, TensorRTEngineData
from lib.tensorrt.utilities import (
Expand Down Expand Up @@ -64,6 +65,7 @@ def build(self):
if model_name == "unet":
model = load_unet(self.model.model_id, device=self.device)
model = model.to(dtype=torch.float16)
model.set_attn_processor(AttnProcessor())
elif model_name == "clip":
model = load_text_encoder(self.model.model_id, device=self.device)
elif model_name == "vae":
Expand Down
4 changes: 3 additions & 1 deletion modules/diffusion/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ def convert_checkpoint_to_pipe(model_id: str):
)


def load_unet(model_id: str, device: Optional[torch.device] = None):
def load_unet(
model_id: str, device: Optional[torch.device] = None
) -> UNet2DConditionModel:
temporary_pipe = convert_checkpoint_to_pipe(model_id)
if temporary_pipe is not None:
unet = temporary_pipe.unet
Expand Down
5 changes: 1 addition & 4 deletions modules/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,7 @@ def trt_available(self):
filepath = os.path.join(trt_path, *file.split("/"))
if not os.path.exists(filepath):
return False
trt_module_status, trt_version_status = utils.tensorrt_is_available()
if not trt_module_status or not trt_version_status:
return False
return config.get("tensorrt")
return utils.tensorrt_is_available() and config.get("tensorrt")

def trt_full_acceleration_available(self):
trt_path = self.get_trt_path()
Expand Down
7 changes: 6 additions & 1 deletion modules/tabs/generate.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import time
from typing import *

import gradio as gr
Expand Down Expand Up @@ -93,6 +94,8 @@ def generate_image(self, opts, plugin_values):
else:
inference_steps = opts.num_inference_steps

start = time.perf_counter()

for data in model_manager.sd_model(opts, plugin_data):
if type(data) == tuple:
step, preview = data
Expand All @@ -112,11 +115,13 @@ def generate_image(self, opts, plugin_values):
else:
image = data

end = time.perf_counter()

results = []
for images, opts in image:
results.extend(images)

yield results, "Finished", gr.Button.update(
yield results, f"Finished in {end - start:0.4f} seconds", gr.Button.update(
value="Generate", variant="primary", interactive=True
)

Expand Down
3 changes: 1 addition & 2 deletions modules/tabs/tensorrt.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@ def sort(self):
return 2

def visible(self):
module, version = tensorrt_is_available()
return module and version and config.get("tensorrt")
return tensorrt_is_available() and config.get("tensorrt")

def ui(self, outlet):
with gr.Column():
Expand Down
18 changes: 1 addition & 17 deletions modules/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import base64
import importlib
import io
from distutils.version import LooseVersion
from typing import *

import numpy as np
Expand All @@ -12,11 +11,8 @@
from lib.diffusers.scheduler import SCHEDULERS

from . import config
from .logger import logger
from .shared import hf_diffusers_cache_dir

logged_trt_warning = False


def img2b64(img: Image.Image, format="png"):
buf = io.BytesIO()
Expand Down Expand Up @@ -56,19 +52,7 @@ def is_installed(package: str):


def tensorrt_is_available():
global logged_trt_warning
tensorrt = is_installed("tensorrt")
version = LooseVersion("2") > LooseVersion(torch.__version__) or LooseVersion(
"2.1"
) <= LooseVersion(torch.__version__)

if not tensorrt or not version:
if not logged_trt_warning and tensorrt and config.get("tensorrt"):
logger.warning(
"TensorRT is available, but torch version is not compatible."
)
logged_trt_warning = True
return tensorrt, version
return is_installed("tensorrt")


def fire_and_forget(f):
Expand Down

0 comments on commit 03b4363

Please sign in to comment.