
Commit

fix device error
ddPn08 committed May 8, 2023
1 parent c23f576 commit 3a98003
Showing 3 changed files with 14 additions and 7 deletions.
17 changes: 12 additions & 5 deletions modules/diffusion/pipelines/diffusers.py
@@ -127,15 +127,22 @@ def to(self, device: torch.device = None, dtype: torch.dtype = None):
             device = self.device
         if dtype is None:
             dtype = self.dtype
-        self.vae.to(device, dtype)
-        self.text_encoder.to(device, dtype)
-        self.unet.to(device, dtype)
-        self.tokenizer
-        self.scheduler
+
+        models = [
+            self.vae,
+            self.text_encoder,
+            self.unet,
+        ]
+        for model in models:
+            if hasattr(model, "to"):
+                model.to(device, dtype)
+
         if device is not None:
             self.device = device
+            self.lpw.device = device
         if dtype is not None:
             self.dtype = dtype
+
         return self
 
     def enterers(self):
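
For context, the reworked to() only moves components that actually implement a to() method, so objects like the tokenizer are skipped instead of triggering an attribute or device error. Below is a minimal, self-contained sketch of the same guard pattern, using toy stand-in modules rather than the real pipeline classes:

import torch

class TinyPipeline:
    """Stand-in for the pipeline above: vae/text_encoder/unet are nn.Modules,
    while the tokenizer is a plain object with no .to() method."""

    def __init__(self):
        self.vae = torch.nn.Linear(4, 4)
        self.text_encoder = torch.nn.Linear(4, 4)
        self.unet = torch.nn.Linear(4, 4)
        self.tokenizer = object()  # calling .to() on this would raise AttributeError
        self.device = torch.device("cpu")
        self.dtype = torch.float32

    def to(self, device: torch.device = None, dtype: torch.dtype = None):
        if device is None:
            device = self.device
        if dtype is None:
            dtype = self.dtype
        # Same pattern as the diff: only move objects that expose .to()
        for model in [self.vae, self.text_encoder, self.unet, self.tokenizer]:
            if hasattr(model, "to"):
                model.to(device, dtype)
        if device is not None:
            self.device = device
        if dtype is not None:
            self.dtype = dtype
        return self


pipe = TinyPipeline().to(device=torch.device("cpu"), dtype=torch.float16)
print(pipe.unet.weight.dtype)  # torch.float16

The hasattr guard is what lets a mix of nn.Modules and plain Python objects share one device/dtype migration path.
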
2 changes: 1 addition & 1 deletion modules/diffusion/pipelines/tensorrt.py
@@ -114,7 +114,7 @@ def model_path(model_name):
             tokenizer=tokenizer,
             scheduler=scheduler,
             full_acceleration=full_acceleration,
-        ).to(device)
+        ).to(device=device)
         return pipe
 
     def __init__(
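
Assuming the TensorRT pipeline's to() follows the same (device=None, dtype=None) signature as the diffusers pipeline above, the keyword call pins the argument to the device parameter explicitly rather than relying on positional order. A rough illustration of the difference, with a toy function rather than the repository's API:

import torch

def to(device: torch.device = None, dtype: torch.dtype = None):
    # Toy stand-in for the pipeline's to(); just echoes what it received.
    return device, dtype

print(to(torch.device("cuda")))         # positional: binds to device only by argument order
print(to(device=torch.device("cuda")))  # keyword: explicit, unaffected by signature reordering
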
2 changes: 1 addition & 1 deletion modules/model.py
@@ -107,8 +107,8 @@ def activate(self):
             model_id=self.model_id,
             engine_dir=os.path.join(model_dir, "engine"),
             use_auth_token=config.get("hf_token"),
-            device=get_device(),
             max_batch_size=1,
+            device=device,
             hf_cache_dir=hf_diffusers_cache_dir(),
             full_acceleration=self.trt_full_acceleration_available(),
         )
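
This hunk only shows the call site, so where the new device variable is assigned is not visible here; presumably it is resolved once earlier in activate() and reused, rather than calling get_device() inline. A hypothetical sketch of that resolve-once pattern (the get_device() body here is a guess, not the repository's actual helper):

import torch

def get_device() -> torch.device:
    # Hypothetical resolver; the real get_device() in this repository may differ.
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")

device = get_device()  # resolve once...
print(device)          # ...then pass the same torch.device to the pipeline and its submodules
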
