Skip to content

Commit

Permalink
🔨 Use model.to
Browse files Browse the repository at this point in the history
  • Loading branch information
shaneahmed committed Feb 9, 2024
1 parent 030d84f commit 2c0a5b8
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions tiatoolbox/models/engine/engine_abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from torch.utils.data import DataLoader

from tiatoolbox.annotation import AnnotationStore
from tiatoolbox.models.models_abc import ModelABC
from tiatoolbox.typing import IntPair, Resolution, Units
from tiatoolbox.wsicore.wsireader import WSIReader

Expand Down Expand Up @@ -85,7 +86,7 @@ class EngineABC(ABC):
"""Abstract base class for engines used in tiatoolbox.
Args:
model (str | nn.Module):
model (str | ModelABC):
A PyTorch model. Default is `None`.
The user can request pretrained models from the toolbox using
the list of pretrained models available at this `link
Expand Down Expand Up @@ -191,7 +192,7 @@ class EngineABC(ABC):

def __init__(
self: EngineABC,
model: str | nn.Module,
model: str | ModelABC,
batch_size: int = 8,
num_loader_workers: int = 0,
num_post_proc_workers: int = 0,
Expand All @@ -213,7 +214,7 @@ def __init__(
model=model,
weights=weights,
)
self.model = model_to(model=self.model, device=self.device)
self.model.to(device=self.device)
self._ioconfig = self.ioconfig # runtime ioconfig

self.batch_size = batch_size
Expand Down

0 comments on commit 2c0a5b8

Please sign in to comment.