From 3a98003c2d0cf418bd12a738fc00cff84e68f0c9 Mon Sep 17 00:00:00 2001 From: ddPn08 Date: Mon, 8 May 2023 09:50:47 +0900 Subject: [PATCH] fix device error --- modules/diffusion/pipelines/diffusers.py | 17 ++++++++++++----- modules/diffusion/pipelines/tensorrt.py | 2 +- modules/model.py | 2 +- 3 files changed, 14 insertions(+), 7 deletions(-) diff --git a/modules/diffusion/pipelines/diffusers.py b/modules/diffusion/pipelines/diffusers.py index 7f11bc57..e4511b30 100644 --- a/modules/diffusion/pipelines/diffusers.py +++ b/modules/diffusion/pipelines/diffusers.py @@ -127,15 +127,22 @@ def to(self, device: torch.device = None, dtype: torch.dtype = None): device = self.device if dtype is None: dtype = self.dtype - self.vae.to(device, dtype) - self.text_encoder.to(device, dtype) - self.unet.to(device, dtype) - self.tokenizer - self.scheduler + + models = [ + self.vae, + self.text_encoder, + self.unet, + ] + for model in models: + if hasattr(model, "to"): + model.to(device, dtype) + if device is not None: self.device = device + self.lpw.device = device if dtype is not None: self.dtype = dtype + return self def enterers(self): diff --git a/modules/diffusion/pipelines/tensorrt.py b/modules/diffusion/pipelines/tensorrt.py index c3250679..87e91da5 100644 --- a/modules/diffusion/pipelines/tensorrt.py +++ b/modules/diffusion/pipelines/tensorrt.py @@ -114,7 +114,7 @@ def model_path(model_name): tokenizer=tokenizer, scheduler=scheduler, full_acceleration=full_acceleration, - ).to(device) + ).to(device=device) return pipe def __init__( diff --git a/modules/model.py b/modules/model.py index a237c65b..83dcca79 100644 --- a/modules/model.py +++ b/modules/model.py @@ -107,8 +107,8 @@ def activate(self): model_id=self.model_id, engine_dir=os.path.join(model_dir, "engine"), use_auth_token=config.get("hf_token"), - device=get_device(), max_batch_size=1, + device=device, hf_cache_dir=hf_diffusers_cache_dir(), full_acceleration=self.trt_full_acceleration_available(), )