Skip to content

Commit

Permalink
🐛 Fix in test_arch_mapde and test_arch_sccnn (#911)
Browse files Browse the repository at this point in the history
- If CUDA is available, the model should be moved to CUDA; otherwise the tests will fail, since the test data is moved to CUDA.
  • Loading branch information
Jiaqi-Lv authored Feb 19, 2025
1 parent 264b079 commit ba0109f
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 6 deletions.
5 changes: 2 additions & 3 deletions tests/models/test_arch_mapde.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@
ON_GPU = toolbox_env.has_gpu()


def _load_mapde(name: str) -> torch.nn.Module:
def _load_mapde(name: str) -> MapDe:
"""Loads MapDe model with specified weights."""
model = MapDe()
weights_path = fetch_pretrained_weights(name)
map_location = select_device(on_gpu=ON_GPU)
pretrained = torch.load(weights_path, map_location=map_location)
model.load_state_dict(pretrained)

model.to(map_location)
return model


Expand All @@ -45,7 +45,6 @@ def test_functionality(remote_sample: Callable) -> None:
model = _load_mapde(name="mapde-conic")
patch = model.preproc(patch)
batch = torch.from_numpy(patch)[None]
model = model.to()
output = model.infer_batch(model, batch, device=select_device(on_gpu=ON_GPU))
output = model.postproc(output[0])
assert np.all(output[0:2] == [[19, 171], [53, 89]])
5 changes: 2 additions & 3 deletions tests/models/test_arch_sccnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@
from tiatoolbox.wsicore.wsireader import WSIReader


def _load_sccnn(name: str) -> torch.nn.Module:
def _load_sccnn(name: str) -> SCCNN:
"""Loads SCCNN model with specified weights."""
model = SCCNN()
weights_path = fetch_pretrained_weights(name)
map_location = select_device(on_gpu=env_detection.has_gpu())
pretrained = torch.load(weights_path, map_location=map_location)
model.load_state_dict(pretrained)

model.to(map_location)
return model


Expand Down Expand Up @@ -48,7 +48,6 @@ def test_functionality(remote_sample: Callable) -> None:
)
output = model.postproc(output[0])
assert np.all(output == [[8, 7]])

model = _load_sccnn(name="sccnn-conic")
output = model.infer_batch(
model,
Expand Down

0 comments on commit ba0109f

Please sign in to comment.