diff --git a/.gitignore b/.gitignore
index 409fc1261..66c072da5 100644
--- a/.gitignore
+++ b/.gitignore
@@ -116,3 +116,6 @@ ENV/
 
 # vim/vi generated
 *.swp
+
+# output zarr generated
+*.zarr
diff --git a/tests/models/test_abc.py b/tests/models/test_abc.py
index d8af37193..f7a60e34c 100644
--- a/tests/models/test_abc.py
+++ b/tests/models/test_abc.py
@@ -6,14 +6,15 @@
 
 import pytest
 import torch
+import torchvision.models as torch_models
 from torch import nn
 
-from tiatoolbox import rcParam
+from tiatoolbox import rcParam, utils
 from tiatoolbox.models.architecture import (
     fetch_pretrained_weights,
     get_pretrained_model,
 )
-from tiatoolbox.models.models_abc import ModelABC
+from tiatoolbox.models.models_abc import ModelABC, model_to
 from tiatoolbox.utils import env_detection as toolbox_env
 
 if TYPE_CHECKING:
@@ -149,3 +150,18 @@ def test_model_abc() -> None:
     weights_path = fetch_pretrained_weights("alexnet-kather100k")
     with pytest.raises(RuntimeError, match=r".*loading state_dict*"):
         _ = model.load_weights_from_file(weights_path)
+
+
+def test_model_to() -> None:
+    """Test for placing model on device."""
+    # Test on GPU
+    # GitHub Actions runners have no GPU, so moving to CUDA should raise
+    if not utils.env_detection.has_gpu():
+        model = torch_models.resnet18()
+        with pytest.raises((AssertionError, RuntimeError)):
+            _ = model_to(device="cuda", model=model)
+
+    # Test on CPU
+    model = torch_models.resnet18()
+    model = model_to(device="cpu", model=model)
+    assert isinstance(model, nn.Module)
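
Note: `model_to` has moved from `tiatoolbox.utils.misc` to
`tiatoolbox.models.models_abc` and now takes a device string instead of an
`on_gpu` flag. A minimal sketch of the new call pattern, mirroring the test
above (resnet18 is just a stand-in model):

    import torchvision.models as torch_models

    from tiatoolbox.models.models_abc import model_to

    # "cuda" raises on CPU-only hosts; "cpu" always works.
    model = model_to(device="cpu", model=torch_models.resnet18())
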
diff --git a/tests/models/test_arch_mapde.py b/tests/models/test_arch_mapde.py
index febcfbdec..61bfde817 100644
--- a/tests/models/test_arch_mapde.py
+++ b/tests/models/test_arch_mapde.py
@@ -45,7 +45,7 @@ def test_functionality(remote_sample: Callable) -> None:
     model = _load_mapde(name="mapde-conic")
     patch = model.preproc(patch)
     batch = torch.from_numpy(patch)[None]
-    model = model.to(select_device(on_gpu=ON_GPU))
-    output = model.infer_batch(model, batch, on_gpu=ON_GPU)
+    model = model.to(select_device(on_gpu=ON_GPU))
+    output = model.infer_batch(model, batch, device=select_device(on_gpu=ON_GPU))
     output = model.postproc(output[0])
     assert np.all(output[0:2] == [[19, 171], [53, 89]])
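
Note: the tests keep their `ON_GPU` flags but translate them once via
`select_device`, which maps the boolean onto a device string. A sketch of the
assumed mapping, consistent with `test_select_device` in tests/test_utils.py:

    from tiatoolbox.utils.misc import select_device

    assert select_device(on_gpu=False) == "cpu"
    assert select_device(on_gpu=True) == "cuda"  # pure string mapping
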
diff --git a/tests/models/test_arch_micronet.py b/tests/models/test_arch_micronet.py
index cd4bd0833..e7aa23d5b 100644
--- a/tests/models/test_arch_micronet.py
+++ b/tests/models/test_arch_micronet.py
@@ -39,7 +39,7 @@ def test_functionality(
     model = model.to(map_location)
     pretrained = torch.load(weights_path, map_location=map_location)
     model.load_state_dict(pretrained)
-    output = model.infer_batch(model, batch, on_gpu=ON_GPU)
+    output = model.infer_batch(model, batch, device=map_location)
     output, _ = model.postproc(output[0])
     assert np.max(np.unique(output)) == 46
 
diff --git a/tests/models/test_arch_nuclick.py b/tests/models/test_arch_nuclick.py
index fda0c01a6..b84516125 100644
--- a/tests/models/test_arch_nuclick.py
+++ b/tests/models/test_arch_nuclick.py
@@ -10,6 +10,7 @@
 from tiatoolbox.models import NuClick
 from tiatoolbox.models.architecture import fetch_pretrained_weights
 from tiatoolbox.utils import imread
+from tiatoolbox.utils.misc import select_device
 
 ON_GPU = False
 
@@ -53,7 +54,7 @@ def test_functional_nuclick(
     model = NuClick(num_input_channels=5, num_output_channels=1)
     pretrained = torch.load(weights_path, map_location="cpu")
     model.load_state_dict(pretrained)
-    output = model.infer_batch(model, batch, on_gpu=ON_GPU)
+    output = model.infer_batch(model, batch, device=select_device(on_gpu=ON_GPU))
     postproc_masks = model.postproc(
         output,
         do_reconstruction=True,
diff --git a/tests/models/test_arch_sccnn.py b/tests/models/test_arch_sccnn.py
index b3dd94e50..2729d2b3a 100644
--- a/tests/models/test_arch_sccnn.py
+++ b/tests/models/test_arch_sccnn.py
@@ -5,9 +5,10 @@
 import numpy as np
 import torch
 
-from tiatoolbox import utils
 from tiatoolbox.models import SCCNN
 from tiatoolbox.models.architecture import fetch_pretrained_weights
+from tiatoolbox.utils import env_detection
+from tiatoolbox.utils.misc import select_device
 from tiatoolbox.wsicore.wsireader import WSIReader
 
 
@@ -15,7 +16,7 @@ def _load_sccnn(name: str) -> torch.nn.Module:
     """Loads SCCNN model with specified weights."""
     model = SCCNN()
     weights_path = fetch_pretrained_weights(name)
-    map_location = utils.misc.select_device(on_gpu=utils.env_detection.has_gpu())
+    map_location = select_device(on_gpu=env_detection.has_gpu())
     pretrained = torch.load(weights_path, map_location=map_location)
     model.load_state_dict(pretrained)
 
@@ -40,11 +41,19 @@ def test_functionality(remote_sample: Callable) -> None:
     )
     batch = torch.from_numpy(patch)[None]
     model = _load_sccnn(name="sccnn-crchisto")
-    output = model.infer_batch(model, batch, on_gpu=False)
+    output = model.infer_batch(
+        model,
+        batch,
+        device=select_device(on_gpu=env_detection.has_gpu()),
+    )
     output = model.postproc(output[0])
     assert np.all(output == [[8, 7]])
 
     model = _load_sccnn(name="sccnn-conic")
-    output = model.infer_batch(model, batch, on_gpu=False)
+    output = model.infer_batch(
+        model,
+        batch,
+        device=select_device(on_gpu=env_detection.has_gpu()),
+    )
     output = model.postproc(output[0])
     assert np.all(output == [[7, 8]])
diff --git a/tests/models/test_arch_unet.py b/tests/models/test_arch_unet.py
index b0cbc6085..2ac231c7c 100644
--- a/tests/models/test_arch_unet.py
+++ b/tests/models/test_arch_unet.py
@@ -9,6 +9,7 @@
 
 from tiatoolbox.models.architecture import fetch_pretrained_weights
 from tiatoolbox.models.architecture.unet import UNetModel
+from tiatoolbox.utils.misc import select_device
 from tiatoolbox.wsicore.wsireader import WSIReader
 
 ON_GPU = False
@@ -48,7 +49,7 @@ def test_functional_unet(remote_sample: Callable) -> None:
     model = UNetModel(3, 2, encoder="resnet50", decoder_block=[3])
     pretrained = torch.load(pretrained_weights, map_location="cpu")
     model.load_state_dict(pretrained)
-    output = model.infer_batch(model, batch, on_gpu=ON_GPU)
+    output = model.infer_batch(model, batch, device=select_device(on_gpu=ON_GPU))
     _ = output[0]
 
     # run untrained network to test for architecture
@@ -60,4 +61,4 @@ def test_functional_unet(remote_sample: Callable) -> None:
         encoder_levels=[32, 64],
         skip_type="concat",
     )
-    _ = model.infer_batch(model, batch, on_gpu=ON_GPU)
+    _ = model.infer_batch(model, batch, device=select_device(on_gpu=ON_GPU))
diff --git a/tests/models/test_arch_vanilla.py b/tests/models/test_arch_vanilla.py
index 29c76ab4e..a87424dfd 100644
--- a/tests/models/test_arch_vanilla.py
+++ b/tests/models/test_arch_vanilla.py
@@ -5,10 +5,11 @@
 import torch
 
 from tiatoolbox.models.architecture.vanilla import CNNModel, TimmModel
-from tiatoolbox.utils.misc import model_to
+from tiatoolbox.models.models_abc import model_to
 
 ON_GPU = False
 RNG = np.random.default_rng()  # Numpy Random Generator
+device = "cuda" if ON_GPU else "cpu"
 
 
 def test_functional() -> None:
@@ -43,8 +44,8 @@ def test_functional() -> None:
     try:
         for backbone in backbones:
             model = CNNModel(backbone, num_classes=1)
-            model_ = model_to(on_gpu=ON_GPU, model=model)
-            model.infer_batch(model_, samples, on_gpu=ON_GPU)
+            model_ = model_to(device=device, model=model)
+            model.infer_batch(model_, samples, device=device)
     except ValueError as exc:
         msg = f"Model {backbone} failed."
         raise AssertionError(msg) from exc
@@ -70,8 +71,8 @@ def test_timm_functional() -> None:
     try:
         for backbone in backbones:
             model = TimmModel(backbone=backbone, num_classes=1, pretrained=False)
-            model_ = model_to(on_gpu=ON_GPU, model=model)
-            model.infer_batch(model_, samples, on_gpu=ON_GPU)
+            model_ = model_to(device=device, model=model)
+            model.infer_batch(model_, samples, device=device)
     except ValueError as exc:
         msg = f"Model {backbone} failed."
         raise AssertionError(msg) from exc
diff --git a/tests/models/test_feature_extractor.py b/tests/models/test_feature_extractor.py
index 15468ab32..cd33f0a5a 100644
--- a/tests/models/test_feature_extractor.py
+++ b/tests/models/test_feature_extractor.py
@@ -14,6 +14,7 @@
     IOSegmentorConfig,
 )
 from tiatoolbox.utils import env_detection as toolbox_env
+from tiatoolbox.utils.misc import select_device
 from tiatoolbox.wsicore.wsireader import WSIReader
 
 ON_GPU = not toolbox_env.running_on_ci() and toolbox_env.has_gpu()
@@ -35,7 +36,7 @@ def test_engine(remote_sample: Callable, tmp_path: Path) -> None:
     output_list = extractor.predict(
         [mini_wsi_svs],
         mode="wsi",
-        on_gpu=ON_GPU,
+        device=select_device(on_gpu=ON_GPU),
         crash_on_exception=True,
         save_dir=save_dir,
     )
@@ -82,7 +83,7 @@ def test_full_inference(
         [mini_wsi_svs],
         mode="wsi",
         ioconfig=ioconfig,
-        on_gpu=ON_GPU,
+        device=select_device(on_gpu=ON_GPU),
         crash_on_exception=True,
         save_dir=save_dir,
     )
diff --git a/tests/models/test_hovernet.py b/tests/models/test_hovernet.py
index b2271ab4c..2567018b8 100644
--- a/tests/models/test_hovernet.py
+++ b/tests/models/test_hovernet.py
@@ -14,6 +14,7 @@
     ResidualBlock,
     TFSamepaddingLayer,
 )
+from tiatoolbox.utils.misc import select_device
 from tiatoolbox.wsicore.wsireader import WSIReader
 
 
@@ -34,7 +35,7 @@ def test_functionality(remote_sample: Callable) -> None:
     weights_path = fetch_pretrained_weights("hovernet_fast-pannuke")
     pretrained = torch.load(weights_path)
     model.load_state_dict(pretrained)
-    output = model.infer_batch(model, batch, on_gpu=False)
+    output = model.infer_batch(model, batch, device=select_device(on_gpu=False))
     output = [v[0] for v in output]
     output = model.postproc(output)
     assert len(output[1]) > 0, "Must have some nuclei."
@@ -51,7 +52,7 @@ def test_functionality(remote_sample: Callable) -> None:
     weights_path = fetch_pretrained_weights("hovernet_fast-monusac")
     pretrained = torch.load(weights_path)
     model.load_state_dict(pretrained)
-    output = model.infer_batch(model, batch, on_gpu=False)
+    output = model.infer_batch(model, batch, device=select_device(on_gpu=False))
     output = [v[0] for v in output]
     output = model.postproc(output)
     assert len(output[1]) > 0, "Must have some nuclei."
@@ -68,7 +69,7 @@ def test_functionality(remote_sample: Callable) -> None:
     weights_path = fetch_pretrained_weights("hovernet_original-consep")
     pretrained = torch.load(weights_path)
     model.load_state_dict(pretrained)
-    output = model.infer_batch(model, batch, on_gpu=False)
+    output = model.infer_batch(model, batch, device=select_device(on_gpu=False))
     output = [v[0] for v in output]
     output = model.postproc(output)
     assert len(output[1]) > 0, "Must have some nuclei."
@@ -85,7 +86,7 @@ def test_functionality(remote_sample: Callable) -> None:
     weights_path = fetch_pretrained_weights("hovernet_original-kumar")
     pretrained = torch.load(weights_path)
     model.load_state_dict(pretrained)
-    output = model.infer_batch(model, batch, on_gpu=False)
+    output = model.infer_batch(model, batch, device=select_device(on_gpu=False))
     output = [v[0] for v in output]
     output = model.postproc(output)
     assert len(output[1]) > 0, "Must have some nuclei."
diff --git a/tests/models/test_hovernetplus.py b/tests/models/test_hovernetplus.py
index 96d0f9d23..1377fdd82 100644
--- a/tests/models/test_hovernetplus.py
+++ b/tests/models/test_hovernetplus.py
@@ -7,6 +7,7 @@
 from tiatoolbox.models import HoVerNetPlus
 from tiatoolbox.models.architecture import fetch_pretrained_weights
 from tiatoolbox.utils import imread
+from tiatoolbox.utils.misc import select_device
 from tiatoolbox.utils.transforms import imresize
 
 
@@ -28,7 +29,7 @@ def test_functionality(remote_sample: Callable) -> None:
     weights_path = fetch_pretrained_weights("hovernetplus-oed")
     pretrained = torch.load(weights_path)
     model.load_state_dict(pretrained)
-    output = model.infer_batch(model, batch, on_gpu=False)
+    output = model.infer_batch(model, batch, device=select_device(on_gpu=False))
     assert len(output) == 4, "Must contain predictions for: np, hv, tp and ls branches."
     output = [v[0] for v in output]
     output = model.postproc(output)
diff --git a/tests/models/test_multi_task_segmentor.py b/tests/models/test_multi_task_segmentor.py
index a7e76f719..1f135b303 100644
--- a/tests/models/test_multi_task_segmentor.py
+++ b/tests/models/test_multi_task_segmentor.py
@@ -17,6 +17,7 @@
 from tiatoolbox.utils import env_detection as toolbox_env
 from tiatoolbox.utils import imwrite
 from tiatoolbox.utils.metrics import f1_detection
+from tiatoolbox.utils.misc import select_device
 
 ON_GPU = toolbox_env.has_gpu()
 BATCH_SIZE = 1 if not ON_GPU else 8  # 16
@@ -64,7 +65,7 @@ def test_functionality_local(remote_sample: Callable, tmp_path: Path) -> None:
     output = multi_segmentor.predict(
         [mini_wsi_svs],
         mode="wsi",
-        on_gpu=ON_GPU,
+        device=select_device(on_gpu=ON_GPU),
         crash_on_exception=True,
         save_dir=save_dir,
     )
@@ -83,7 +84,7 @@ def test_functionality_local(remote_sample: Callable, tmp_path: Path) -> None:
     output = multi_segmentor.predict(
         [mini_wsi_svs],
         mode="wsi",
-        on_gpu=ON_GPU,
+        device=select_device(on_gpu=ON_GPU),
         crash_on_exception=True,
         save_dir=save_dir,
     )
@@ -117,7 +118,7 @@ def test_functionality_hovernetplus(remote_sample: Callable, tmp_path: Path) ->
     output = multi_segmentor.predict(
         [mini_wsi_svs],
         mode="wsi",
-        on_gpu=ON_GPU,
+        device=select_device(on_gpu=ON_GPU),
         crash_on_exception=True,
         save_dir=save_dir,
     )
@@ -148,7 +149,7 @@ def test_functionality_hovernet(remote_sample: Callable, tmp_path: Path) -> None
     output = multi_segmentor.predict(
         [mini_wsi_svs],
         mode="wsi",
-        on_gpu=ON_GPU,
+        device=select_device(on_gpu=ON_GPU),
         crash_on_exception=True,
         save_dir=save_dir,
     )
@@ -195,7 +196,7 @@ def test_masked_segmentor(remote_sample: Callable, tmp_path: Path) -> None:
         masks=[sample_wsi_msk],
         mode="wsi",
         ioconfig=ioconfig,
-        on_gpu=ON_GPU,
+        device=select_device(on_gpu=ON_GPU),
         crash_on_exception=True,
         save_dir=save_dir,
     )
@@ -230,7 +231,7 @@ def test_functionality_process_instance_predictions(
     output = semantic_segmentor.predict(
         [mini_wsi_svs],
         mode="wsi",
-        on_gpu=ON_GPU,
+        device=select_device(on_gpu=ON_GPU),
         crash_on_exception=True,
         save_dir=save_dir,
     )
@@ -268,7 +269,7 @@ def test_empty_image(tmp_path: Path) -> None:
     _ = multi_segmentor.predict(
         [sample_patch_path],
         mode="tile",
-        on_gpu=ON_GPU,
+        device=select_device(on_gpu=ON_GPU),
         crash_on_exception=True,
         save_dir=save_dir,
     )
@@ -284,7 +285,7 @@ def test_empty_image(tmp_path: Path) -> None:
     _ = multi_segmentor.predict(
         [sample_patch_path],
         mode="tile",
-        on_gpu=ON_GPU,
+        device=select_device(on_gpu=ON_GPU),
         crash_on_exception=True,
         save_dir=save_dir,
     )
@@ -312,7 +313,7 @@ def test_empty_image(tmp_path: Path) -> None:
     _ = multi_segmentor.predict(
         [sample_patch_path],
         mode="tile",
-        on_gpu=ON_GPU,
+        device=select_device(on_gpu=ON_GPU),
         crash_on_exception=True,
         save_dir=save_dir,
         ioconfig=bcc_wsi_ioconfig,
@@ -361,7 +362,7 @@ def test_functionality_semantic(remote_sample: Callable, tmp_path: Path) -> None
     output = multi_segmentor.predict(
         [mini_wsi_svs],
         mode="wsi",
-        on_gpu=ON_GPU,
+        device=select_device(on_gpu=ON_GPU),
         crash_on_exception=True,
         save_dir=save_dir,
         ioconfig=bcc_wsi_ioconfig,
@@ -413,7 +414,7 @@ def test_crash_segmentor(remote_sample: Callable, tmp_path: Path) -> None:
             masks=[sample_wsi_msk],
             mode="wsi",
             ioconfig=ioconfig,
-            on_gpu=ON_GPU,
+            device=select_device(on_gpu=ON_GPU),
             crash_on_exception=True,
             save_dir=save_dir,
         )
diff --git a/tests/models/test_nucleus_instance_segmentor.py b/tests/models/test_nucleus_instance_segmentor.py
index ff6b9a4cc..2956849fb 100644
--- a/tests/models/test_nucleus_instance_segmentor.py
+++ b/tests/models/test_nucleus_instance_segmentor.py
@@ -28,6 +28,7 @@
 from tiatoolbox.utils import env_detection as toolbox_env
 from tiatoolbox.utils import imwrite
 from tiatoolbox.utils.metrics import f1_detection
+from tiatoolbox.utils.misc import select_device
 from tiatoolbox.wsicore.wsireader import WSIReader
 
 ON_GPU = toolbox_env.has_gpu()
@@ -278,7 +279,7 @@ def test_crash_segmentor(remote_sample: Callable, tmp_path: Path) -> None:
             masks=[sample_wsi_msk],
             mode="wsi",
             ioconfig=ioconfig,
-            on_gpu=ON_GPU,
+            device=select_device(on_gpu=ON_GPU),
             crash_on_exception=True,
             save_dir=save_dir,
         )
@@ -326,7 +327,7 @@ def test_functionality_ci(remote_sample: Callable, tmp_path: Path) -> None:
         [mini_wsi_svs],
         mode="wsi",
         ioconfig=ioconfig,
-        on_gpu=ON_GPU,
+        device=select_device(on_gpu=ON_GPU),
         crash_on_exception=True,
         save_dir=save_dir,
     )
@@ -373,7 +374,7 @@ def test_functionality_merge_tile_predictions_ci(
     output = semantic_segmentor.predict(
         [mini_wsi_svs],
         mode="wsi",
-        on_gpu=ON_GPU,
+        device=select_device(on_gpu=ON_GPU),
         ioconfig=ioconfig,
         crash_on_exception=True,
         save_dir=save_dir,
@@ -453,7 +454,7 @@ def test_functionality_local(remote_sample: Callable, tmp_path: Path) -> None:
     output = inst_segmentor.predict(
         [mini_wsi_svs],
         mode="wsi",
-        on_gpu=True,
+        device=select_device(on_gpu=ON_GPU),
         crash_on_exception=True,
         save_dir=save_dir,
     )
@@ -471,7 +472,7 @@ def test_functionality_local(remote_sample: Callable, tmp_path: Path) -> None:
     output = inst_segmentor.predict(
         [mini_wsi_svs],
         mode="wsi",
-        on_gpu=True,
+        device=select_device(on_gpu=ON_GPU),
         crash_on_exception=True,
         save_dir=save_dir,
     )
@@ -496,7 +497,7 @@ def test_functionality_local(remote_sample: Callable, tmp_path: Path) -> None:
     output = semantic_segmentor.predict(
         [mini_wsi_svs],
         mode="wsi",
-        on_gpu=True,
+        device=select_device(on_gpu=ON_GPU),
         crash_on_exception=True,
         save_dir=save_dir,
     )
diff --git a/tests/models/test_patch_predictor.py b/tests/models/test_patch_predictor.py
index 5fd930138..913d63241 100644
--- a/tests/models/test_patch_predictor.py
+++ b/tests/models/test_patch_predictor.py
@@ -25,6 +25,7 @@
 )
 from tiatoolbox.utils import download_data, imread, imwrite
 from tiatoolbox.utils import env_detection as toolbox_env
+from tiatoolbox.utils.misc import select_device
 from tiatoolbox.wsicore.wsireader import WSIReader
 
 ON_GPU = toolbox_env.has_gpu()
@@ -547,7 +548,7 @@ def test_io_config_delegation(remote_sample: Callable, tmp_path: Path) -> None:
                 [mini_wsi_svs],
                 mode="wsi",
                 save_dir=f"{tmp_path}/dump",
-                on_gpu=ON_GPU,
+                device=select_device(on_gpu=ON_GPU),
                 **_kwargs,
             )
         shutil.rmtree(tmp_path / "dump", ignore_errors=True)
@@ -563,7 +564,7 @@ def test_io_config_delegation(remote_sample: Callable, tmp_path: Path) -> None:
         ioconfig=ioconfig,
         mode="wsi",
         save_dir=f"{tmp_path}/dump",
-        on_gpu=ON_GPU,
+        device=select_device(on_gpu=ON_GPU),
     )
     shutil.rmtree(tmp_path / "dump", ignore_errors=True)
 
@@ -571,7 +572,7 @@ def test_io_config_delegation(remote_sample: Callable, tmp_path: Path) -> None:
         [mini_wsi_svs],
         mode="wsi",
         save_dir=f"{tmp_path}/dump",
-        on_gpu=ON_GPU,
+        device=select_device(on_gpu=ON_GPU),
         **kwargs,
     )
     shutil.rmtree(tmp_path / "dump", ignore_errors=True)
@@ -582,7 +583,7 @@ def test_io_config_delegation(remote_sample: Callable, tmp_path: Path) -> None:
         [mini_wsi_svs],
         patch_input_shape=(300, 300),
         mode="wsi",
-        on_gpu=ON_GPU,
+        device=select_device(on_gpu=ON_GPU),
         save_dir=f"{tmp_path}/dump",
     )
     assert predictor._ioconfig.patch_input_shape == (300, 300)
@@ -592,7 +593,7 @@ def test_io_config_delegation(remote_sample: Callable, tmp_path: Path) -> None:
         [mini_wsi_svs],
         stride_shape=(300, 300),
         mode="wsi",
-        on_gpu=ON_GPU,
+        device=select_device(on_gpu=ON_GPU),
         save_dir=f"{tmp_path}/dump",
     )
     assert predictor._ioconfig.stride_shape == (300, 300)
@@ -602,7 +603,7 @@ def test_io_config_delegation(remote_sample: Callable, tmp_path: Path) -> None:
         [mini_wsi_svs],
         resolution=1.99,
         mode="wsi",
-        on_gpu=ON_GPU,
+        device=select_device(on_gpu=ON_GPU),
         save_dir=f"{tmp_path}/dump",
     )
     assert predictor._ioconfig.input_resolutions[0]["resolution"] == 1.99
@@ -612,7 +613,7 @@ def test_io_config_delegation(remote_sample: Callable, tmp_path: Path) -> None:
         [mini_wsi_svs],
         units="baseline",
         mode="wsi",
-        on_gpu=ON_GPU,
+        device=select_device(on_gpu=ON_GPU),
         save_dir=f"{tmp_path}/dump",
     )
     assert predictor._ioconfig.input_resolutions[0]["units"] == "baseline"
@@ -624,7 +625,7 @@ def test_io_config_delegation(remote_sample: Callable, tmp_path: Path) -> None:
         mode="wsi",
         merge_predictions=True,
         save_dir=f"{tmp_path}/dump",
-        on_gpu=ON_GPU,
+        device=select_device(on_gpu=ON_GPU),
     )
     shutil.rmtree(tmp_path / "dump", ignore_errors=True)
 
@@ -643,7 +644,7 @@ def test_patch_predictor_api(
     # don't run test on GPU
     output = predictor.predict(
         inputs,
-        on_gpu=ON_GPU,
+        device=select_device(on_gpu=ON_GPU),
         save_dir=save_dir_path,
     )
     assert sorted(output.keys()) == ["predictions"]
@@ -654,7 +655,7 @@ def test_patch_predictor_api(
         inputs,
         labels=[1, "a"],
         return_labels=True,
-        on_gpu=ON_GPU,
+        device=select_device(on_gpu=ON_GPU),
         save_dir=save_dir_path,
     )
     assert sorted(output.keys()) == sorted(["labels", "predictions"])
@@ -665,7 +666,7 @@ def test_patch_predictor_api(
     output = predictor.predict(
         inputs,
         return_probabilities=True,
-        on_gpu=ON_GPU,
+        device=select_device(on_gpu=ON_GPU),
         save_dir=save_dir_path,
     )
     assert sorted(output.keys()) == sorted(["predictions", "probabilities"])
@@ -677,7 +678,7 @@ def test_patch_predictor_api(
         return_probabilities=True,
         labels=[1, "a"],
         return_labels=True,
-        on_gpu=ON_GPU,
+        device=select_device(on_gpu=ON_GPU),
         save_dir=save_dir_path,
     )
     assert sorted(output.keys()) == sorted(["labels", "predictions", "probabilities"])
@@ -687,7 +688,7 @@ def test_patch_predictor_api(
     # test saving output, should have no effect
     _ = predictor.predict(
         inputs,
-        on_gpu=ON_GPU,
+        device=select_device(on_gpu=ON_GPU),
         save_dir="special_dir_not_exist",
     )
     assert not Path.is_dir(Path("special_dir_not_exist"))
@@ -721,7 +722,7 @@ def test_patch_predictor_api(
         return_probabilities=True,
         labels=[1, "a"],
         return_labels=True,
-        on_gpu=ON_GPU,
+        device=select_device(on_gpu=ON_GPU),
         save_dir=save_dir_path,
     )
     assert sorted(output.keys()) == sorted(["labels", "predictions", "probabilities"])
@@ -751,7 +752,7 @@ def test_wsi_predictor_api(
     kwargs = {
         "return_probabilities": True,
         "return_labels": True,
-        "on_gpu": ON_GPU,
+        "device": select_device(on_gpu=ON_GPU),
         "patch_input_shape": patch_size,
         "stride_shape": patch_size,
         "resolution": 1.0,
@@ -788,7 +789,7 @@ def test_wsi_predictor_api(
     kwargs = {
         "return_probabilities": True,
         "return_labels": True,
-        "on_gpu": ON_GPU,
+        "device": select_device(on_gpu=ON_GPU),
         "patch_input_shape": patch_size,
         "stride_shape": patch_size,
         "resolution": 0.5,
@@ -903,7 +904,7 @@ def test_wsi_predictor_merge_predictions(sample_wsi_dict: dict) -> None:
     kwargs = {
         "return_probabilities": True,
         "return_labels": True,
-        "on_gpu": ON_GPU,
+        "device": select_device(on_gpu=ON_GPU),
         "patch_input_shape": np.array([224, 224]),
         "stride_shape": np.array([224, 224]),
         "resolution": 1.0,
@@ -958,8 +959,7 @@ def _test_predictor_output(
     pretrained_model: str,
     probabilities_check: list | None = None,
     predictions_check: list | None = None,
-    *,
-    on_gpu: bool = ON_GPU,
+    device: str = select_device(on_gpu=ON_GPU),
 ) -> None:
     """Test the predictions of multiple models included in tiatoolbox."""
     predictor = PatchPredictor(
@@ -972,7 +972,7 @@ def _test_predictor_output(
         inputs,
         return_probabilities=True,
         return_labels=False,
-        on_gpu=on_gpu,
+        device=device,
     )
     predictions = output["predictions"]
     probabilities = output["probabilities"]
@@ -1025,7 +1025,7 @@ def test_patch_predictor_kather100k_output(
             pretrained_model,
             probabilities_check=expected_prob,
             predictions_check=[6, 3],
-            on_gpu=ON_GPU,
+            device=select_device(on_gpu=ON_GPU),
         )
         # only test 1 on travis to limit runtime
         if toolbox_env.running_on_ci():
@@ -1060,7 +1060,7 @@ def test_patch_predictor_pcam_output(sample_patch3: Path, sample_patch4: Path) -
             pretrained_model,
             probabilities_check=expected_prob,
             predictions_check=[1, 0],
-            on_gpu=ON_GPU,
+            device=select_device(on_gpu=ON_GPU),
         )
         # only test 1 on travis to limit runtime
         if toolbox_env.running_on_ci():
diff --git a/tests/models/test_semantic_segmentation.py b/tests/models/test_semantic_segmentation.py
index 8fee41a9b..01776b800 100644
--- a/tests/models/test_semantic_segmentation.py
+++ b/tests/models/test_semantic_segmentation.py
@@ -32,6 +32,7 @@
 from tiatoolbox.models.models_abc import ModelABC
 from tiatoolbox.utils import env_detection as toolbox_env
 from tiatoolbox.utils import imread, imwrite
+from tiatoolbox.utils.misc import select_device
 from tiatoolbox.wsicore.wsireader import WSIReader
 
 ON_GPU = toolbox_env.has_gpu()
@@ -70,12 +71,7 @@ def forward(self: _CNNTo1, img: np.ndarray) -> torch.Tensor:
         return self.conv(img)
 
     @staticmethod
-    def infer_batch(
-        model: nn.Module,
-        batch_data: torch.Tensor,
-        *,
-        on_gpu: bool,
-    ) -> list:
+    def infer_batch(model: nn.Module, batch_data: torch.Tensor, device: str) -> list:
         """Run inference on an input batch.
 
         Contains logic for forward operation as well as i/o
@@ -85,10 +81,13 @@ def infer_batch(
             model (nn.Module): PyTorch defined model.
             batch_data (torch.Tensor): A batch of data generated by
                 torch.utils.data.DataLoader.
-            on_gpu (bool): Whether to run inference on a GPU.
+            device (str):
+                :class:`torch.device` on which to run the model,
+                e.g. "cpu" or "cuda". See
+                https://pytorch.org/docs/stable/tensor_attributes.html#torch.device
+                for details.
 
         """
-        device = "cuda" if on_gpu else "cpu"
+        device = "cuda" if ON_GPU else "cpu"
         ####
         model.eval()  # infer mode
 
@@ -307,7 +307,7 @@ def test_crash_segmentor(remote_sample: Callable, tmp_path: Path) -> None:
         semantic_segmentor.predict(
             [mini_wsi_jpg],
             mode="tile",
-            on_gpu=ON_GPU,
+            device=select_device(on_gpu=ON_GPU),
             crash_on_exception=True,
             save_dir=save_dir,
         )
@@ -325,7 +325,7 @@ def test_crash_segmentor(remote_sample: Callable, tmp_path: Path) -> None:
         semantic_segmentor.predict(
             [mini_wsi_jpg],
             mode="tile",
-            on_gpu=ON_GPU,
+            device=select_device(on_gpu=ON_GPU),
             crash_on_exception=True,
             save_dir=save_dir,
         )
@@ -339,7 +339,7 @@ def test_crash_segmentor(remote_sample: Callable, tmp_path: Path) -> None:
             [mini_wsi_svs],
             patch_input_shape=(2048, 2048),
             mode="wsi",
-            on_gpu=ON_GPU,
+            device=select_device(on_gpu=ON_GPU),
             crash_on_exception=True,
             save_dir=save_dir,
         )
@@ -350,7 +350,7 @@ def test_crash_segmentor(remote_sample: Callable, tmp_path: Path) -> None:
         [mini_wsi_svs],
         patch_input_shape=(2048, 2048),
         mode="wsi",
-        on_gpu=ON_GPU,
+        device=select_device(on_gpu=ON_GPU),
         crash_on_exception=False,
         save_dir=save_dir,
     )
@@ -494,7 +494,7 @@ def test_functional_segmentor(
     semantic_segmentor.predict(
         [mini_wsi_jpg],
         mode="tile",
-        on_gpu=ON_GPU,
+        device=select_device(on_gpu=ON_GPU),
         patch_input_shape=(512, 512),
         resolution=resolution,
         units="mpp",
@@ -506,7 +506,7 @@ def test_functional_segmentor(
     semantic_segmentor.predict(
         [mini_wsi_jpg],
         mode="tile",
-        on_gpu=ON_GPU,
+        device=select_device(on_gpu=ON_GPU),
         patch_input_shape=(512, 512),
         resolution=1 / resolution,
         units="baseline",
@@ -521,7 +521,7 @@ def test_functional_segmentor(
         semantic_segmentor.predict(
             [mini_wsi_jpg],
             mode="tile",
-            on_gpu=ON_GPU,
+            device=select_device(on_gpu=ON_GPU),
             patch_input_shape=(512, 512),
             patch_output_shape=(512, 512),
             stride_shape=(512, 512),
@@ -552,7 +552,7 @@ def test_functional_segmentor(
     output_list = semantic_segmentor.predict(
         file_list,
         mode="tile",
-        on_gpu=ON_GPU,
+        device=select_device(on_gpu=ON_GPU),
         ioconfig=ioconfig,
         crash_on_exception=True,
         save_dir=f"{save_dir}/raw/",
@@ -581,7 +581,7 @@ def test_functional_segmentor(
         [mini_wsi_svs],
         masks=[mini_wsi_msk],
         mode="wsi",
-        on_gpu=ON_GPU,
+        device=select_device(on_gpu=ON_GPU),
         ioconfig=ioconfig,
         crash_on_exception=True,
         save_dir=f"{save_dir}/raw/",
@@ -605,7 +605,7 @@ def test_functional_segmentor(
         [mini_wsi_svs],
         masks=[mini_wsi_msk],
         mode="wsi",
-        on_gpu=ON_GPU,
+        device=select_device(on_gpu=ON_GPU),
         ioconfig=ioconfig,
         crash_on_exception=True,
         save_dir=f"{save_dir}/raw/",
@@ -631,7 +631,7 @@ def __init__(self: XSegmentor) -> None:
     semantic_segmentor.predict(
         [mini_wsi_jpg],
         mode="tile",
-        on_gpu=ON_GPU,
+        device=select_device(on_gpu=ON_GPU),
         patch_input_shape=(1024, 1024),
         patch_output_shape=(512, 512),
         stride_shape=(256, 256),
@@ -661,7 +661,7 @@ def test_functional_pretrained(remote_sample: Callable, tmp_path: Path) -> None:
     semantic_segmentor.predict(
         [mini_wsi_svs],
         mode="wsi",
-        on_gpu=ON_GPU,
+        device=select_device(on_gpu=ON_GPU),
         crash_on_exception=True,
         save_dir=f"{save_dir}/raw/",
     )
@@ -672,7 +672,7 @@ def test_functional_pretrained(remote_sample: Callable, tmp_path: Path) -> None:
     semantic_segmentor.predict(
         [mini_wsi_jpg],
         mode="tile",
-        on_gpu=ON_GPU,
+        device=select_device(on_gpu=ON_GPU),
         crash_on_exception=True,
         save_dir=f"{save_dir}/raw/",
     )
@@ -699,7 +699,7 @@ def test_behavior_tissue_mask_local(remote_sample: Callable, tmp_path: Path) ->
     semantic_segmentor.predict(
         [wsi_with_artifacts],
         mode="wsi",
-        on_gpu=True,
+        device=select_device(on_gpu=ON_GPU),
         crash_on_exception=True,
         save_dir=save_dir / "raw",
     )
@@ -715,7 +715,7 @@ def test_behavior_tissue_mask_local(remote_sample: Callable, tmp_path: Path) ->
     semantic_segmentor.predict(
         [mini_wsi_jpg],
         mode="tile",
-        on_gpu=True,
+        device=select_device(on_gpu=ON_GPU),
         crash_on_exception=True,
         save_dir=f"{save_dir}/raw/",
     )
@@ -738,7 +738,7 @@ def test_behavior_bcss_local(remote_sample: Callable, tmp_path: Path) -> None:
     semantic_segmentor.predict(
         [wsi_breast],
         mode="wsi",
-        on_gpu=True,
+        device=select_device(on_gpu=ON_GPU),
         crash_on_exception=True,
         save_dir=save_dir / "raw",
     )
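
Note: custom models plugged into the segmentors must now implement
`infer_batch(model, batch_data, device)` rather than the old `on_gpu` form, as
the `_CNNTo1` change above shows. A minimal sketch of the updated contract
(this class is illustrative, not part of the toolbox):

    import torch
    from torch import nn

    from tiatoolbox.models.models_abc import ModelABC


    class TinySegModel(ModelABC):
        """Illustrative model following the new device-string contract."""

        def __init__(self) -> None:
            super().__init__()
            self.conv = nn.Conv2d(3, 1, kernel_size=1)

        def forward(self, imgs: torch.Tensor) -> torch.Tensor:
            return self.conv(imgs)

        @staticmethod
        def infer_batch(
            model: nn.Module, batch_data: torch.Tensor, device: str
        ) -> list:
            model.eval()
            imgs = batch_data.to(device).type(torch.float32)  # NHWC batch
            with torch.inference_mode():
                logits = model(imgs.permute(0, 3, 1, 2).contiguous())  # to NCHW
            return [logits.permute(0, 2, 3, 1).cpu().numpy()]
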
diff --git a/tests/test_annotation_stores.py b/tests/test_annotation_stores.py
index 01bbdac45..66c990161 100644
--- a/tests/test_annotation_stores.py
+++ b/tests/test_annotation_stores.py
@@ -53,14 +53,6 @@
 FILLED_LEN = 2 * (GRID_SIZE[0] * GRID_SIZE[1])
 RNG = np.random.default_rng(0)  # Numpy Random Generator
 
-# ----------------------------------------------------------------------
-# Resets
-# ----------------------------------------------------------------------
-
-# Reset filters in logger.
-for filter_ in logger.filters:
-    logger.removeFilter(filter_)
-
 # ----------------------------------------------------------------------
 # Helper Functions
 # ----------------------------------------------------------------------
@@ -546,6 +538,9 @@ def test_sqlite_store_compile_options_missing_math(
     caplog: pytest.LogCaptureFixture,
 ) -> None:
     """Test that a warning is shown if the sqlite math module is missing."""
+    # Reset filters in logger.
+    for filter_ in logger.filters[:]:
+        logger.removeFilter(filter_)
     monkeypatch.setattr(
         SQLiteStore,
         "compile_options",
diff --git a/tests/test_annotation_tilerendering.py b/tests/test_annotation_tilerendering.py
index 0734b9164..0ee34b17b 100644
--- a/tests/test_annotation_tilerendering.py
+++ b/tests/test_annotation_tilerendering.py
@@ -23,7 +23,7 @@
 from tiatoolbox.annotation import Annotation, AnnotationStore, SQLiteStore
 from tiatoolbox.tools.pyramid import AnnotationTileGenerator
 from tiatoolbox.utils.env_detection import running_on_travis
-from tiatoolbox.utils.visualization import AnnotationRenderer
+from tiatoolbox.utils.visualization import AnnotationRenderer, _find_minimum_mpp_sf
 from tiatoolbox.wsicore import wsireader
 
 RNG = np.random.default_rng(0)  # Numpy Random Generator
@@ -462,6 +462,7 @@ def test_function_mapper(fill_store: Callable, tmp_path: Path) -> None:
     _, store = fill_store(SQLiteStore, tmp_path / "test.db")
 
     def color_fn(props: dict[str, str]) -> tuple[int, int, int]:
+        """Tests Red for cells, otherwise green."""
         # simple test function that returns red for cells, otherwise green.
         if props["type"] == "cell":
             return 1, 0, 0
@@ -480,3 +481,15 @@ def color_fn(props: dict[str, str]) -> tuple[int, int, int]:
     assert num == 50  # expect 50 green objects
     _, num = label(np.array(thumb)[:, :, 2])
     assert num == 0  # expect 0 blue objects
+
+
+def test_minimum_mpp_sf() -> None:
+    """Test minimum mpp_sf."""
+    mpp_sf = _find_minimum_mpp_sf((0.5, 0.5))
+    assert mpp_sf == 1.0
+
+    mpp_sf = _find_minimum_mpp_sf((0.20, 0.20))
+    assert mpp_sf == 0.20 / 0.25
+
+    mpp_sf = _find_minimum_mpp_sf(None)
+    assert mpp_sf == 1.0
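
Note: the new test pins `_find_minimum_mpp_sf` at three points: slides at or
coarser than 0.25 mpp map to a scale factor of 1.0, finer slides scale by
mpp / 0.25, and a missing mpp falls back to 1.0. A hypothetical reference
implementation consistent with those assertions (the real helper lives in
tiatoolbox.utils.visualization and may differ in detail):

    def find_minimum_mpp_sf(mpp: tuple[float, float] | None) -> float:
        """Hypothetical sketch matching the asserted behaviour."""
        if mpp is None:
            return 1.0
        return min(min(mpp) / 0.25, 1.0)
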
diff --git a/tests/test_init.py b/tests/test_init.py
index 509a9c49f..6d8ed8238 100644
--- a/tests/test_init.py
+++ b/tests/test_init.py
@@ -114,7 +114,7 @@ def test_duplicate_filter(caplog: pytest.LogCaptureFixture) -> None:
     logger.addFilter(duplicate_filter)
 
     # Reset filters in logger.
-    for filter_ in logger.filters:
+    for filter_ in logger.filters[:]:
         logger.removeFilter(filter_)
 
     for _ in range(2):
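
Note: the `[:]` copy matters here. Calling `removeFilter` while iterating the
live `logger.filters` list mutates it mid-iteration and skips every other
filter; iterating a copy removes them all. A minimal demonstration:

    import logging

    logger = logging.getLogger("demo")
    logger.addFilter(logging.Filter("a"))
    logger.addFilter(logging.Filter("b"))

    for filter_ in logger.filters[:]:  # copy, so removal is safe
        logger.removeFilter(filter_)

    assert logger.filters == []  # without the copy, one filter would be left
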
diff --git a/tests/test_utils.py b/tests/test_utils.py
index fe18e0d36..95e6ee520 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -1336,24 +1336,6 @@ def test_select_device() -> None:
     assert device == "cpu"
 
 
-def test_model_to() -> None:
-    """Test for placing model on device."""
-    import torchvision.models as torch_models
-    from torch import nn
-
-    # Test on GPU
-    # no GPU on Travis so this will crash
-    if not utils.env_detection.has_gpu():
-        model = torch_models.resnet18()
-        with pytest.raises((AssertionError, RuntimeError)):
-            _ = misc.model_to(on_gpu=True, model=model)
-
-    # Test on CPU
-    model = torch_models.resnet18()
-    model = misc.model_to(on_gpu=False, model=model)
-    assert isinstance(model, nn.Module)
-
-
 def test_save_as_json(tmp_path: Path) -> None:
     """Test save data to json."""
     # This should be broken up into separate tests!
@@ -1673,7 +1655,7 @@ def test_patch_pred_store() -> None:
 
     store = misc.dict_to_store(patch_output, (1.0, 1.0))
 
-    # Check that its an SQLiteStore containing the expected annotations
+    # Check that it is an SQLiteStore containing the expected annotations
     assert isinstance(store, SQLiteStore)
     assert len(store) == 3
     for annotation in store.values():
@@ -1700,7 +1682,7 @@ def test_patch_pred_store_cdict() -> None:
     class_dict = {0: "class0", 1: "class1"}
     store = misc.dict_to_store(patch_output, (1.0, 1.0), class_dict=class_dict)
 
-    # Check that its an SQLiteStore containing the expected annotations
+    # Check that it is an SQLiteStore containing the expected annotations
     assert isinstance(store, SQLiteStore)
     assert len(store) == 3
     for annotation in store.values():
diff --git a/tests/test_wsimeta.py b/tests/test_wsimeta.py
index bc3555e36..01b1cac8b 100644
--- a/tests/test_wsimeta.py
+++ b/tests/test_wsimeta.py
@@ -8,7 +8,6 @@
 from tiatoolbox.wsicore import WSIMeta, wsimeta, wsireader
 
 
-# noinspection PyTypeChecker
 def test_wsimeta_init_fail() -> None:
     """Test incorrect init for WSIMeta raises TypeError."""
     with pytest.raises(TypeError):
diff --git a/tiatoolbox/annotation/storage.py b/tiatoolbox/annotation/storage.py
index 0cd476358..420e94085 100644
--- a/tiatoolbox/annotation/storage.py
+++ b/tiatoolbox/annotation/storage.py
@@ -2556,7 +2556,21 @@ def _unpack_wkb(
         cx: float,
         cy: float,
     ) -> bytes:
-        """Unpack WKB data."""
+        """Return the geometry as bytes using WKB.
+
+        Args:
+            data (bytes or str):
+                The WKB/WKT data to be unpacked.
+            cx (float):
+                The X coordinate of the centroid/representative point.
+            cy (float):
+                The Y coordinate of the centroid/representative point.
+
+        Returns:
+            bytes:
+                The geometry as bytes.
+
+        """
         return (
             self._decompress_data(data)
             if data
diff --git a/tiatoolbox/cli/common.py b/tiatoolbox/cli/common.py
index 18e731b4c..26f85625e 100644
--- a/tiatoolbox/cli/common.py
+++ b/tiatoolbox/cli/common.py
@@ -234,6 +234,18 @@ def cli_pretrained_weights(
     )
 
 
+def cli_device(
+    usage_help: str = "Select the device (cpu/cuda/mps) to use for inference.",
+    default: str = "cpu",
+) -> Callable:
+    """Enables --pretrained-weights option for cli."""
+    return click.option(
+        "--device",
+        help=add_default_to_usage_help(usage_help, default),
+        default=default,
+    )
+
+
 def cli_return_probabilities(
     usage_help: str = "Whether to return raw model probabilities.",
     *,
@@ -333,20 +345,6 @@ def cli_yaml_config_path(
     )
 
 
-def cli_on_gpu(
-    usage_help: str = "Run the model on GPU.",
-    *,
-    default: bool = False,
-) -> Callable:
-    """Enables --on-gpu option for cli."""
-    return click.option(
-        "--on-gpu",
-        type=bool,
-        default=default,
-        help=add_default_to_usage_help(usage_help, default),
-    )
-
-
 def cli_num_loader_workers(
     usage_help: str = "Number of workers to load the data. Please note that they will "
     "also perform preprocessing.",
diff --git a/tiatoolbox/cli/nucleus_instance_segment.py b/tiatoolbox/cli/nucleus_instance_segment.py
index b38dcdaed..fdb4b95ca 100644
--- a/tiatoolbox/cli/nucleus_instance_segment.py
+++ b/tiatoolbox/cli/nucleus_instance_segment.py
@@ -7,13 +7,13 @@
 from tiatoolbox.cli.common import (
     cli_auto_generate_mask,
     cli_batch_size,
+    cli_device,
     cli_file_type,
     cli_img_input,
     cli_masks,
     cli_mode,
     cli_num_loader_workers,
     cli_num_postproc_workers,
-    cli_on_gpu,
     cli_output_path,
     cli_pretrained_model,
     cli_pretrained_weights,
@@ -41,7 +41,7 @@
 )
 @cli_pretrained_model(default="hovernet_fast-pannuke")
 @cli_pretrained_weights(default=None)
-@cli_on_gpu(default=False)
+@cli_device(default="cpu")
 @cli_batch_size()
 @cli_masks(default=None)
 @cli_yaml_config_path(default=None)
@@ -61,9 +61,9 @@ def nucleus_instance_segment(
     yaml_config_path: str,
     num_loader_workers: int,
     num_postproc_workers: int,
+    device: str,
     *,
     auto_generate_mask: bool,
-    on_gpu: bool,
     verbose: bool,
 ) -> None:
     """Process an image/directory of input images with a patch classification CNN."""
@@ -97,7 +97,7 @@ def nucleus_instance_segment(
         imgs=files_all,
         masks=masks_all,
         mode=mode,
-        on_gpu=on_gpu,
+        device=device,
         save_dir=output_path,
         ioconfig=ioconfig,
     )
diff --git a/tiatoolbox/cli/patch_predictor.py b/tiatoolbox/cli/patch_predictor.py
index a97ecb571..069b6c367 100644
--- a/tiatoolbox/cli/patch_predictor.py
+++ b/tiatoolbox/cli/patch_predictor.py
@@ -6,13 +6,13 @@
 
 from tiatoolbox.cli.common import (
     cli_batch_size,
+    cli_device,
     cli_file_type,
     cli_img_input,
     cli_masks,
     cli_merge_predictions,
     cli_mode,
     cli_num_loader_workers,
-    cli_on_gpu,
     cli_output_path,
     cli_pretrained_model,
     cli_pretrained_weights,
@@ -45,7 +45,7 @@
 @cli_return_probabilities(default=False)
 @cli_merge_predictions(default=True)
 @cli_return_labels(default=True)
-@cli_on_gpu(default=False)
+@cli_device(default="cpu")
 @cli_batch_size(default=1)
 @cli_resolution(default=0.5)
 @cli_units(default="mpp")
@@ -64,11 +64,11 @@ def patch_predictor(
     resolution: float,
     units: str,
     num_loader_workers: int,
+    device: str,
     *,
     return_probabilities: bool,
     return_labels: bool,
     merge_predictions: bool,
-    on_gpu: bool,
     verbose: bool,
 ) -> None:
     """Process an image/directory of input images with a patch classification CNN."""
@@ -100,7 +100,7 @@ def patch_predictor(
         return_labels=return_labels,
         resolution=resolution,
         units=units,
-        on_gpu=on_gpu,
+        device=device,
         save_dir=output_path,
         save_output=True,
     )
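
Note: on the command line the migration means `--on-gpu True` becomes
`--device cuda`. A hypothetical invocation (command and option names inferred
from the definitions above; the input path is a placeholder):

    tiatoolbox patch-predictor \
        --img-input sample.svs \
        --mode wsi \
        --device cuda \
        --output-path results
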
diff --git a/tiatoolbox/cli/semantic_segment.py b/tiatoolbox/cli/semantic_segment.py
index 8947b2beb..cbfe18e58 100644
--- a/tiatoolbox/cli/semantic_segment.py
+++ b/tiatoolbox/cli/semantic_segment.py
@@ -6,12 +6,12 @@
 
 from tiatoolbox.cli.common import (
     cli_batch_size,
+    cli_device,
     cli_file_type,
     cli_img_input,
     cli_masks,
     cli_mode,
     cli_num_loader_workers,
-    cli_on_gpu,
     cli_output_path,
     cli_pretrained_model,
     cli_pretrained_weights,
@@ -39,7 +39,7 @@
 )
 @cli_pretrained_model(default="fcn-tissue_mask")
 @cli_pretrained_weights(default=None)
-@cli_on_gpu()
+@cli_device()
 @cli_batch_size()
 @cli_masks(default=None)
 @cli_yaml_config_path()
@@ -56,8 +56,8 @@ def semantic_segment(
     batch_size: int,
     yaml_config_path: str,
     num_loader_workers: int,
+    device: str,
     *,
-    on_gpu: bool,
     verbose: bool,
 ) -> None:
     """Process an image/directory of input images with a patch classification CNN."""
@@ -89,7 +89,7 @@ def semantic_segment(
         imgs=files_all,
         masks=masks_all,
         mode=mode,
-        on_gpu=on_gpu,
+        device=device,
         save_dir=output_path,
         ioconfig=ioconfig,
     )
diff --git a/tiatoolbox/models/architecture/hovernet.py b/tiatoolbox/models/architecture/hovernet.py
index 8f061d273..2853c4946 100644
--- a/tiatoolbox/models/architecture/hovernet.py
+++ b/tiatoolbox/models/architecture/hovernet.py
@@ -20,7 +20,6 @@
     centre_crop_to_shape,
 )
 from tiatoolbox.models.models_abc import ModelABC
-from tiatoolbox.utils import misc
 from tiatoolbox.utils.misc import get_bounding_box
 
 
@@ -766,7 +765,7 @@ def postproc(raw_maps: list[np.ndarray]) -> tuple[np.ndarray, dict]:
             >>> pretrained = torch.load(weights_path)
             >>> model = HoVerNet(num_types=6, mode="fast")
             >>> model.load_state_dict(pretrained)
-            >>> output = model.infer_batch(model, batch, on_gpu=False)
+            >>> output = model.infer_batch(model, batch, device="cuda")
             >>> output = [v[0] for v in output]
             >>> output = model.postproc(output)
 
@@ -785,7 +784,9 @@ def postproc(raw_maps: list[np.ndarray]) -> tuple[np.ndarray, dict]:
         return pred_inst, nuc_inst_info_dict
 
     @staticmethod
-    def infer_batch(model: nn.Module, batch_data: np.ndarray, *, on_gpu: bool) -> tuple:
+    def infer_batch(  # skipcq: PYL-W0221
+        model: nn.Module, batch_data: np.ndarray, *, device: str
+    ) -> tuple:
         """Run inference on an input batch.
 
         This contains logic for forward operation as well as batch i/o
@@ -797,8 +798,8 @@ def infer_batch(model: nn.Module, batch_data: np.ndarray, *, on_gpu: bool) -> tu
             batch_data (ndarray):
                 A batch of data generated by
                 `torch.utils.data.DataLoader`.
-            on_gpu (bool):
-                Whether to run inference on a GPU.
+            device (str):
+                The device on which to run inference, e.g. "cpu" or "cuda".
 
         Returns:
             tuple:
@@ -810,7 +811,6 @@ def infer_batch(model: nn.Module, batch_data: np.ndarray, *, on_gpu: bool) -> tu
         """
         patch_imgs = batch_data
 
-        device = misc.select_device(on_gpu=on_gpu)
         patch_imgs_gpu = patch_imgs.to(device).type(torch.float32)  # to NCHW
         patch_imgs_gpu = patch_imgs_gpu.permute(0, 3, 1, 2).contiguous()
 
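
Note: since `infer_batch` no longer derives the device internally, callers are
responsible for placing the model and passing the same device string, as the
updated tests do. A sketch of the pattern (weight loading elided; the random
batch is a placeholder):

    import torch

    from tiatoolbox.models import HoVerNet

    device = "cpu"  # or "cuda" on a GPU host
    model = HoVerNet(num_types=6, mode="fast").to(device)
    batch = torch.rand(4, 256, 256, 3)  # placeholder NHWC batch
    output = model.infer_batch(model, batch, device=device)
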
diff --git a/tiatoolbox/models/architecture/hovernetplus.py b/tiatoolbox/models/architecture/hovernetplus.py
index 87db17295..700eb303f 100644
--- a/tiatoolbox/models/architecture/hovernetplus.py
+++ b/tiatoolbox/models/architecture/hovernetplus.py
@@ -13,7 +13,6 @@
 
 from tiatoolbox.models.architecture.hovernet import HoVerNet
 from tiatoolbox.models.architecture.utils import UpSample2x
-from tiatoolbox.utils import misc
 
 
 class HoVerNetPlus(HoVerNet):
@@ -306,7 +305,7 @@ def postproc(raw_maps: list[np.ndarray]) -> tuple:
             >>> pretrained = torch.load(weights_path)
             >>> model = HoVerNetPlus(num_types=3, num_layers=5)
             >>> model.load_state_dict(pretrained)
-            >>> output = model.infer_batch(model, batch, on_gpu=False)
+            >>> output = model.infer_batch(model, batch, device="cuda")
             >>> output = [v[0] for v in output]
             >>> output = model.postproc(output)
 
@@ -325,7 +324,7 @@ def postproc(raw_maps: list[np.ndarray]) -> tuple:
         return pred_inst, nuc_inst_info_dict, pred_layer, layer_info_dict
 
     @staticmethod
-    def infer_batch(model: nn.Module, batch_data: np.ndarray, *, on_gpu: bool) -> tuple:
+    def infer_batch(model: nn.Module, batch_data: np.ndarray, *, device: str) -> tuple:
         """Run inference on an input batch.
 
         This contains logic for forward operation as well as batch i/o
@@ -337,13 +336,12 @@ def infer_batch(model: nn.Module, batch_data: np.ndarray, *, on_gpu: bool) -> tu
             batch_data (ndarray):
                 A batch of data generated by
                 `torch.utils.data.DataLoader`.
-            on_gpu (bool):
-                Whether to run inference on a GPU.
+            device (str):
+                The device on which to run inference, e.g. "cpu" or "cuda".
 
         """
         patch_imgs = batch_data
 
-        device = misc.select_device(on_gpu=on_gpu)
         patch_imgs_gpu = patch_imgs.to(device).type(torch.float32)  # to NCHW
         patch_imgs_gpu = patch_imgs_gpu.permute(0, 3, 1, 2).contiguous()
 
diff --git a/tiatoolbox/models/architecture/mapde.py b/tiatoolbox/models/architecture/mapde.py
index a7156531f..bbb468bb8 100644
--- a/tiatoolbox/models/architecture/mapde.py
+++ b/tiatoolbox/models/architecture/mapde.py
@@ -14,7 +14,6 @@
 from skimage.feature import peak_local_max
 
 from tiatoolbox.models.architecture.micronet import MicroNet
-from tiatoolbox.utils.misc import select_device
 
 
 class MapDe(MicroNet):
@@ -259,7 +258,7 @@ def infer_batch(
         model: torch.nn.Module,
         batch_data: torch.Tensor,
         *,
-        on_gpu: bool,
+        device: str,
     ) -> list[np.ndarray]:
         """Run inference on an input batch.
 
@@ -272,8 +271,8 @@ def infer_batch(
             batch_data (:class:`numpy.ndarray`):
                 A batch of data generated by
                 `torch.utils.data.DataLoader`.
-            on_gpu (bool):
-                Whether to run inference on a GPU.
+            device (str):
+                The device on which to run inference, e.g. "cpu" or "cuda".
 
         Returns:
             list(np.ndarray):
@@ -282,7 +281,6 @@ def infer_batch(
         """
         patch_imgs = batch_data
 
-        device = select_device(on_gpu=on_gpu)
         patch_imgs_gpu = patch_imgs.to(device).type(torch.float32)  # to NCHW
         patch_imgs_gpu = patch_imgs_gpu.permute(0, 3, 1, 2).contiguous()
 
diff --git a/tiatoolbox/models/architecture/micronet.py b/tiatoolbox/models/architecture/micronet.py
index bfc62c8ab..6065fcd46 100644
--- a/tiatoolbox/models/architecture/micronet.py
+++ b/tiatoolbox/models/architecture/micronet.py
@@ -19,7 +19,6 @@
 
 from tiatoolbox.models.architecture.hovernet import HoVerNet
 from tiatoolbox.models.models_abc import ModelABC
-from tiatoolbox.utils import misc
 
 
 def group1_forward_branch(
@@ -625,11 +624,11 @@ def preproc(image: np.ndarray) -> np.ndarray:
         return np.transpose(image.numpy(), axes=(1, 2, 0))
 
     @staticmethod
-    def infer_batch(
+    def infer_batch(  # skipcq: PYL-W0221
         model: torch.nn.Module,
         batch_data: torch.Tensor,
         *,
-        on_gpu: bool,
+        device: str,
     ) -> list[np.ndarray]:
         """Run inference on an input batch.
 
@@ -642,8 +641,8 @@ def infer_batch(
             batch_data (:class:`torch.Tensor`):
                 A batch of data generated by
                 `torch.utils.data.DataLoader`.
-            on_gpu (bool):
-                Whether to run inference on a GPU.
+            device (str):
+                The device on which to run inference, e.g. "cpu" or "cuda".
 
         Returns:
             list(np.ndarray):
@@ -652,7 +651,6 @@ def infer_batch(
         """
         patch_imgs = batch_data
 
-        device = misc.select_device(on_gpu=on_gpu)
         patch_imgs_gpu = patch_imgs.to(device).type(torch.float32)  # to NCHW
         patch_imgs_gpu = patch_imgs_gpu.permute(0, 3, 1, 2).contiguous()
 
diff --git a/tiatoolbox/models/architecture/nuclick.py b/tiatoolbox/models/architecture/nuclick.py
index 339777eb1..77f4ad993 100644
--- a/tiatoolbox/models/architecture/nuclick.py
+++ b/tiatoolbox/models/architecture/nuclick.py
@@ -22,7 +22,6 @@
 
 from tiatoolbox import logger
 from tiatoolbox.models.models_abc import ModelABC
-from tiatoolbox.utils import misc
 
 if TYPE_CHECKING:  # pragma: no cover
     from tiatoolbox.typing import IntPair
@@ -647,7 +646,7 @@ def infer_batch(
         model: nn.Module,
         batch_data: torch.Tensor,
         *,
-        on_gpu: bool,
+        device: str,
     ) -> np.ndarray:
         """Run inference on an input batch.
 
@@ -656,16 +655,16 @@ def infer_batch(
 
         Args:
             model (nn.Module): PyTorch defined model.
-            batch_data (torch.Tensor): a batch of data generated by
-                torch.utils.data.DataLoader.
-            on_gpu (bool): Whether to run inference on a GPU.
+            batch_data (torch.Tensor):
+                A batch of data generated by torch.utils.data.DataLoader.
+            device (str):
+                The device on which to run inference, e.g. "cpu" or "cuda".
 
         Returns:
             Pixel-wise nuclei prediction for each patch, shape: (no.patch, h, w).
 
         """
         model.eval()
-        device = misc.select_device(on_gpu=on_gpu)
 
         # Assume batch_data is NCHW
         batch_data = batch_data.to(device).type(torch.float32)
diff --git a/tiatoolbox/models/architecture/sccnn.py b/tiatoolbox/models/architecture/sccnn.py
index 9941eabff..4da0f9dca 100644
--- a/tiatoolbox/models/architecture/sccnn.py
+++ b/tiatoolbox/models/architecture/sccnn.py
@@ -17,7 +17,6 @@
 from torch import nn
 
 from tiatoolbox.models.models_abc import ModelABC
-from tiatoolbox.utils import misc
 
 
 class SCCNN(ModelABC):
@@ -354,8 +353,7 @@ def postproc(self: SCCNN, prediction_map: np.ndarray) -> np.ndarray:
     def infer_batch(
         model: nn.Module,
         batch_data: np.ndarray | torch.Tensor,
-        *,
-        on_gpu: bool,
+        device: str,
     ) -> list[np.ndarray]:
         """Run inference on an input batch.
 
@@ -368,8 +366,8 @@ def infer_batch(
             batch_data (:class:`numpy.ndarray` or :class:`torch.Tensor`):
                 A batch of data generated by
                 `torch.utils.data.DataLoader`.
-            on_gpu (bool):
-                Whether to run inference on a GPU.
+            device (str):
+                The device on which to run inference, e.g. "cpu" or "cuda".
 
         Returns:
             list of :class:`numpy.ndarray`:
@@ -378,7 +376,6 @@ def infer_batch(
         """
         patch_imgs = batch_data
 
-        device = misc.select_device(on_gpu=on_gpu)
         patch_imgs_gpu = patch_imgs.to(device).type(torch.float32)
         # to NCHW
         patch_imgs_gpu = patch_imgs_gpu.permute(0, 3, 1, 2).contiguous()
diff --git a/tiatoolbox/models/architecture/unet.py b/tiatoolbox/models/architecture/unet.py
index fe1a97cc9..6385e7587 100644
--- a/tiatoolbox/models/architecture/unet.py
+++ b/tiatoolbox/models/architecture/unet.py
@@ -12,7 +12,6 @@
 
 from tiatoolbox.models.architecture.utils import UpSample2x, centre_crop
 from tiatoolbox.models.models_abc import ModelABC
-from tiatoolbox.utils import misc
 
 
 class ResNetEncoder(ResNet):
@@ -416,7 +415,7 @@ def infer_batch(
         model: nn.Module,
         batch_data: torch.Tensor,
         *,
-        on_gpu: bool,
+        device: str,
     ) -> list:
         """Run inference on an input batch.
 
@@ -429,8 +428,8 @@ def infer_batch(
             batch_data (:class:`torch.Tensor`):
                 A batch of data generated by
                 `torch.utils.data.DataLoader`.
-            on_gpu (bool):
-                Whether to run inference on a GPU.
+            device (str):
+                The device on which to run inference, e.g. "cpu" or "cuda".
 
         Returns:
             list:
@@ -439,7 +438,6 @@ def infer_batch(
 
         """
         model.eval()
-        device = misc.select_device(on_gpu=on_gpu)
 
         ####
         imgs = batch_data
diff --git a/tiatoolbox/models/architecture/utils.py b/tiatoolbox/models/architecture/utils.py
index 9df4dd56f..2ec47d99d 100644
--- a/tiatoolbox/models/architecture/utils.py
+++ b/tiatoolbox/models/architecture/utils.py
@@ -3,7 +3,7 @@
 from __future__ import annotations
 
 import sys
-from typing import Callable, NoReturn
+from typing import NoReturn
 
 import numpy as np
 import torch
@@ -41,7 +41,7 @@ def compile_model(
     model: nn.Module | None = None,
     *,
     mode: str = "default",
-) -> Callable:
+) -> nn.Module:
     """A decorator to compile a model using torch-compile.
 
     Args:
@@ -60,7 +60,7 @@ def compile_model(
               CUDA graphs
 
     Returns:
-        Callable:
+        torch.nn.Module:
             Compiled model.
 
     """
@@ -71,7 +71,7 @@ def compile_model(
     is_torch_compile_compatible()
 
     # This check will be removed when torch.compile is supported in Python 3.12+
-    if sys.version_info >= (3, 12):  # pragma: no cover
+    if sys.version_info > (3, 12):  # pragma: no cover
         logger.warning(
             ("torch-compile is currently not supported in Python 3.12+. ",),
         )
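Since `compile_model` is now annotated as returning `nn.Module` rather than `Callable`, the compiled result can be treated as an ordinary module. A small sketch (the `resnet18` stand-in is illustrative):

```python
import torchvision.models as torch_models
from tiatoolbox.models.architecture.utils import compile_model

model = compile_model(torch_models.resnet18(), mode="default")
model.eval()  # regular nn.Module methods remain available on the result
```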
diff --git a/tiatoolbox/models/architecture/vanilla.py b/tiatoolbox/models/architecture/vanilla.py
index 4879ce04c..c7d3d1498 100644
--- a/tiatoolbox/models/architecture/vanilla.py
+++ b/tiatoolbox/models/architecture/vanilla.py
@@ -11,7 +11,6 @@
 from torch import nn
 
 from tiatoolbox.models.models_abc import ModelABC
-from tiatoolbox.utils.misc import select_device
 
 if TYPE_CHECKING:  # pragma: no cover
     from torchvision.models import WeightsEnum
@@ -149,9 +148,8 @@ def _postproc(image: np.ndarray) -> np.ndarray:
 def _infer_batch(
     model: nn.Module,
     batch_data: torch.Tensor,
-    *,
-    on_gpu: bool,
-) -> np.ndarray:
+    device: str,
+) -> dict[str, np.ndarray]:
     """Run inference on an input batch.
 
     Contains logic for forward operation as well as i/o aggregation.
@@ -162,11 +160,11 @@ def _infer_batch(
         batch_data (torch.Tensor):
             A batch of data generated by
             `torch.utils.data.DataLoader`.
-        on_gpu (bool):
-            Whether to run inference on a GPU.
+        device (str):
+            The device on which to run inference, e.g. "cpu" or "cuda".
 
     """
-    img_patches_device = batch_data.to(select_device(on_gpu=on_gpu)).type(
+    img_patches_device = batch_data.to(device=device).type(
         torch.float32,
     )  # to NCHW
     img_patches_device = img_patches_device.permute(0, 3, 1, 2).contiguous()
@@ -243,9 +241,8 @@ def postproc(image: np.ndarray) -> np.ndarray:
     def infer_batch(
         model: nn.Module,
         batch_data: torch.Tensor,
-        *,
-        on_gpu: bool,
-    ) -> np.ndarray:
+        device: str = "cpu",
+    ) -> dict[str, np.ndarray]:
         """Run inference on an input batch.
 
         Contains logic for forward operation as well as i/o aggregation.
@@ -256,11 +253,11 @@ def infer_batch(
             batch_data (torch.Tensor):
                 A batch of data generated by
                 `torch.utils.data.DataLoader`.
-            on_gpu (bool):
-                Whether to run inference on a GPU.
+            device (str):
+                The device on which to run inference. Default is "cpu".
 
         """
-        return _infer_batch(model=model, batch_data=batch_data, on_gpu=on_gpu)
+        return _infer_batch(model=model, batch_data=batch_data, device=device)
 
 
 class TimmModel(ModelABC):
@@ -339,9 +336,8 @@ def postproc(image: np.ndarray) -> np.ndarray:
     def infer_batch(
         model: nn.Module,
         batch_data: torch.Tensor,
-        *,
-        on_gpu: bool,
-    ) -> np.ndarray:
+        device: str,
+    ) -> dict[str, np.ndarray]:
         """Run inference on an input batch.
 
         Contains logic for forward operation as well as i/o aggregation.
@@ -352,11 +348,11 @@ def infer_batch(
             batch_data (torch.Tensor):
                 A batch of data generated by
                 `torch.utils.data.DataLoader`.
-            on_gpu (bool):
-                Whether to run inference on a GPU.
+            device (str):
+                The device on which to run inference, e.g. "cpu" or "cuda".
 
         """
-        return _infer_batch(model=model, batch_data=batch_data, on_gpu=on_gpu)
+        return _infer_batch(model=model, batch_data=batch_data, device=device)
 
 
 class CNNBackbone(ModelABC):
@@ -425,9 +421,8 @@ def forward(self: CNNBackbone, imgs: torch.Tensor) -> torch.Tensor:
     def infer_batch(
         model: nn.Module,
         batch_data: torch.Tensor,
-        *,
-        on_gpu: bool,
-    ) -> list[np.ndarray]:
+        device: str,
+    ) -> list[dict[str, np.ndarray]]:
         """Run inference on an input batch.
 
         Contains logic for forward operation as well as i/o aggregation.
@@ -438,15 +433,15 @@ def infer_batch(
             batch_data (torch.Tensor):
                 A batch of data generated by
                 `torch.utils.data.DataLoader`.
-            on_gpu (bool):
-                Whether to run inference on a GPU.
+            device (str):
+                The device on which to run inference, e.g. "cpu" or "cuda".
 
         Returns:
-            list[np.ndarray]:
-                list of numpy arrays.
+            list[dict[str, np.ndarray]]:
+                List of dictionaries mapping output names to numpy arrays.
 
         """
-        return [_infer_batch(model=model, batch_data=batch_data, on_gpu=on_gpu)]
+        return [_infer_batch(model=model, batch_data=batch_data, device=device)]
 
 
 class TimmBackbone(ModelABC):
@@ -500,9 +495,8 @@ def forward(self: TimmBackbone, imgs: torch.Tensor) -> torch.Tensor:
     def infer_batch(
         model: nn.Module,
         batch_data: torch.Tensor,
-        *,
-        on_gpu: bool,
-    ) -> list[np.ndarray]:
+        device: str,
+    ) -> list[dict[str, np.ndarray]]:
         """Run inference on an input batch.
 
         Contains logic for forward operation as well as i/o aggregation.
@@ -513,12 +507,12 @@ def infer_batch(
             batch_data (torch.Tensor):
                 A batch of data generated by
                 `torch.utils.data.DataLoader`.
-            on_gpu (bool):
-                Whether to run inference on a GPU.
+            device (str):
+                The device on which to run inference, e.g. "cpu" or "cuda".
 
         Returns:
-            list[np.ndarray]:
-                list of numpy arrays.
+            list[dict[str, np.ndarray]]:
+                List of dictionaries mapping output names to numpy arrays.
 
         """
-        return [_infer_batch(model=model, batch_data=batch_data, on_gpu=on_gpu)]
+        return [_infer_batch(model=model, batch_data=batch_data, device=device)]
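The backbone classes follow the same convention; a hedged sketch with random data (the input shape is illustrative):

```python
import numpy as np
import torch
from tiatoolbox.models import CNNBackbone

model = CNNBackbone("resnet18")  # torchvision feature extractor
batch = torch.from_numpy(np.random.rand(4, 224, 224, 3).astype(np.float32))
features = CNNBackbone.infer_batch(model, batch, device="cpu")
```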
diff --git a/tiatoolbox/models/engine/nucleus_instance_segmentor.py b/tiatoolbox/models/engine/nucleus_instance_segmentor.py
index cc76b68a0..6649324b1 100644
--- a/tiatoolbox/models/engine/nucleus_instance_segmentor.py
+++ b/tiatoolbox/models/engine/nucleus_instance_segmentor.py
@@ -450,7 +450,7 @@ def _get_tile_info(
             * ioconfig.patch_output_shape
         ).astype(np.int32)
         image_shape = np.array(image_shape)
-        (_, tile_outputs) = PatchExtractor.get_coordinates(
+        tile_outputs = PatchExtractor.get_coordinates(
             image_shape=image_shape,
             patch_input_shape=tile_shape,
             patch_output_shape=tile_shape,
@@ -459,7 +459,7 @@ def _get_tile_info(
 
         # * === Now generating the flags to indicate which side should
         # * === be removed in postproc callback
-        boxes = tile_outputs
+        boxes = tile_outputs[1]
 
         # This saves computation time if the image is smaller than the expected tile
         if np.all(image_shape <= tile_shape):
@@ -485,7 +485,7 @@ def unset_removal_flag(boxes: tuple, removal_flag: np.ndarray) -> np.ndarray:
             return removal_flag
 
         w, h = image_shape
-        boxes = tile_outputs
+        boxes = tile_outputs[1]
         #  expand to full four corners
         boxes_br = boxes[:, 2:]
         boxes_tr = np.dstack([boxes[:, 2], boxes[:, 1]])[0]
@@ -646,7 +646,7 @@ def _infer_once(self: NucleusInstanceSegmentor) -> list:
             sample_outputs = self.model.infer_batch(
                 self._model,
                 sample_datas,
-                on_gpu=self._on_gpu,
+                device=self._device,
             )
             # repackage so that it's a N list, each contains
             # L x etc. output
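`PatchExtractor.get_coordinates` returns a pair of coordinate arrays, and the tile logic above now keeps the full tuple and indexes the output boxes explicitly. A sketch with illustrative shapes; the `stride_shape` argument is assumed to follow the call pattern above:

```python
import numpy as np
from tiatoolbox.tools.patchextraction import PatchExtractor

tile_outputs = PatchExtractor.get_coordinates(
    image_shape=np.array([4096, 4096]),
    patch_input_shape=np.array([2048, 2048]),
    patch_output_shape=np.array([2048, 2048]),
    stride_shape=np.array([2048, 2048]),
)
boxes = tile_outputs[1]  # the second element holds the output bounding boxes
```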
diff --git a/tiatoolbox/models/engine/patch_predictor.py b/tiatoolbox/models/engine/patch_predictor.py
index da4420cb0..9989c313e 100644
--- a/tiatoolbox/models/engine/patch_predictor.py
+++ b/tiatoolbox/models/engine/patch_predictor.py
@@ -16,7 +16,8 @@
 from tiatoolbox.models.architecture.utils import compile_model
 from tiatoolbox.models.dataset.classification import PatchDataset, WSIPatchDataset
 from tiatoolbox.models.engine.semantic_segmentor import IOSegmentorConfig
-from tiatoolbox.utils import misc, save_as_json
+from tiatoolbox.models.models_abc import model_to
+from tiatoolbox.utils import save_as_json
 from tiatoolbox.wsicore.wsireader import VirtualWSIReader, WSIReader
 
 if TYPE_CHECKING:  # pragma: no cover
@@ -383,11 +384,11 @@ def merge_predictions(
     def _predict_engine(
         self: PatchPredictor,
         dataset: torch.utils.data.Dataset,
+        device: str = "cpu",
         *,
         return_probabilities: bool = False,
         return_labels: bool = False,
         return_coordinates: bool = False,
-        on_gpu: bool = True,
     ) -> np.ndarray:
         """Make a prediction on a dataset. The dataset may be mutated.
 
@@ -401,8 +402,11 @@ def _predict_engine(
                 Whether to return labels.
             return_coordinates (bool):
                 Whether to return patch coordinates.
-            on_gpu (bool):
-                Whether to run model on the GPU.
+            device (str):
+                Device on which to run the model, as accepted by
+                :class:`torch.device`. See
+                https://pytorch.org/docs/stable/tensor_attributes.html#torch.device
+                for details. Default is "cpu".
 
         Returns:
             :class:`numpy.ndarray`:
@@ -430,7 +434,7 @@ def _predict_engine(
             )
 
         # use external for testing
-        model = misc.model_to(model=self.model, on_gpu=on_gpu)
+        model = model_to(model=self.model, device=device)
 
         cum_output = {
             "probabilities": [],
@@ -442,7 +446,7 @@ def _predict_engine(
             batch_output_probabilities = self.model.infer_batch(
                 model,
                 batch_data["image"],
-                on_gpu=on_gpu,
+                device=device,
             )
             # We get the index of the class with the maximum probability
             batch_output_predictions = self.model.postproc_func(
@@ -587,10 +591,10 @@ def _predict_patch(
         self: PatchPredictor,
         imgs: list | np.ndarray,
         labels: list,
+        device: str = "cpu",
         *,
         return_probabilities: bool,
         return_labels: bool,
-        on_gpu: bool,
     ) -> np.ndarray:
         """Process patch mode.
 
@@ -609,8 +613,11 @@ def _predict_patch(
                 Whether to return per-class probabilities.
             return_labels (bool):
                 Whether to return the labels with the predictions.
-            on_gpu (bool):
-                Whether to run model on the GPU.
+            device (str):
+                Device on which to run the model, as accepted by
+                :class:`torch.device`. See
+                https://pytorch.org/docs/stable/tensor_attributes.html#torch.device
+                for details. Default is "cpu".
 
         Returns:
             :class:`numpy.ndarray`:
@@ -635,7 +642,7 @@ def _predict_patch(
             return_probabilities=return_probabilities,
             return_labels=return_labels,
             return_coordinates=return_coordinates,
-            on_gpu=on_gpu,
+            device=device,
         )
 
     def _predict_tile_wsi(  # noqa: PLR0913
@@ -647,11 +654,11 @@ def _predict_tile_wsi(  # noqa: PLR0913
         ioconfig: IOPatchPredictorConfig,
         save_dir: str | Path,
         highest_input_resolution: list[dict],
+        device: str = "cpu",
         *,
         save_output: bool,
         return_probabilities: bool,
         merge_predictions: bool,
-        on_gpu: bool,
     ) -> list | dict:
         """Predict on Tile and WSIs.
 
@@ -678,8 +685,11 @@ def _predict_tile_wsi(  # noqa: PLR0913
                 `tile` or `wsi`.
             return_probabilities (bool):
                 Whether to return per-class probabilities.
-            on_gpu (bool):
-                Whether to run model on the GPU.
+            device (str):
+                Device on which to run the model, as accepted by
+                :class:`torch.device`. See
+                https://pytorch.org/docs/stable/tensor_attributes.html#torch.device
+                for details. Default is "cpu".
             ioconfig (IOPatchPredictorConfig):
                 Patch Predictor IO configuration..
             merge_predictions (bool):
@@ -747,7 +757,7 @@ def _predict_tile_wsi(  # noqa: PLR0913
                 return_labels=False,
                 return_probabilities=return_probabilities,
                 return_coordinates=return_coordinates,
-                on_gpu=on_gpu,
+                device=device,
             )
             output_model["label"] = img_label
             # add extra information useful for downstream analysis
@@ -795,10 +805,10 @@ def predict(  # noqa: PLR0913
         stride_shape: tuple[int, int] | None = None,
         resolution: Resolution | None = None,
         units: Units = None,
+        device: str = "cpu",
         *,
         return_probabilities: bool = False,
         return_labels: bool = False,
-        on_gpu: bool = True,
         merge_predictions: bool = False,
         save_dir: str | Path | None = None,
         save_output: bool = False,
@@ -830,8 +840,11 @@ def predict(  # noqa: PLR0913
                 Whether to return per-class probabilities.
             return_labels (bool):
                 Whether to return the labels with the predictions.
-            on_gpu (bool):
-                Whether to run model on the GPU.
+            device (str):
+                Device on which to run the model, as accepted by
+                :class:`torch.device`. See
+                https://pytorch.org/docs/stable/tensor_attributes.html#torch.device
+                for details. Default is "cpu".
             ioconfig (IOPatchPredictorConfig):
                 Patch Predictor IO configuration.
             patch_input_shape (tuple):
@@ -901,7 +914,7 @@ def predict(  # noqa: PLR0913
                 labels,
                 return_probabilities=return_probabilities,
                 return_labels=return_labels,
-                on_gpu=on_gpu,
+                device=device,
             )
 
         if not isinstance(imgs, list):
@@ -948,7 +961,7 @@ def predict(  # noqa: PLR0913
             labels=labels,
             mode=mode,
             return_probabilities=return_probabilities,
-            on_gpu=on_gpu,
+            device=device,
             ioconfig=ioconfig,
             merge_predictions=merge_predictions,
             save_dir=save_dir,
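At the engine level the migration is the same: `predict` accepts `device` (default `"cpu"`) in place of `on_gpu`. A hedged sketch, assuming the `resnet18-kather100k` pretrained weights are available; the input patches are random stand-ins:

```python
import numpy as np
from tiatoolbox.models import PatchPredictor

patches = [np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8)]
predictor = PatchPredictor(pretrained_model="resnet18-kather100k", batch_size=16)
# Old: predictor.predict(patches, mode="patch", on_gpu=False)
output = predictor.predict(patches, mode="patch", device="cpu")
```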
diff --git a/tiatoolbox/models/engine/semantic_segmentor.py b/tiatoolbox/models/engine/semantic_segmentor.py
index 271d49150..029deebf9 100644
--- a/tiatoolbox/models/engine/semantic_segmentor.py
+++ b/tiatoolbox/models/engine/semantic_segmentor.py
@@ -20,9 +20,9 @@
 from tiatoolbox import logger, rcParam
 from tiatoolbox.models.architecture import get_pretrained_model
 from tiatoolbox.models.architecture.utils import compile_model
-from tiatoolbox.models.models_abc import IOConfigABC
+from tiatoolbox.models.models_abc import IOConfigABC, model_to
 from tiatoolbox.tools.patchextraction import PatchExtractor
-from tiatoolbox.utils import imread, misc
+from tiatoolbox.utils import imread
 from tiatoolbox.wsicore.wsireader import VirtualWSIReader, WSIMeta, WSIReader
 
 if TYPE_CHECKING:  # pragma: no cover
@@ -554,7 +554,7 @@ def __init__(
         self._cache_dir = None
         self._loader = None
         self._model = None
-        self._on_gpu = None
+        self._device = None
         self._mp_shared_space = None
         self._postproc_workers = None
         self.num_postproc_workers = num_postproc_workers
@@ -818,7 +818,7 @@ def _predict_one_wsi(
             sample_outputs = self.model.infer_batch(
                 self._model,
                 sample_datas,
-                on_gpu=self._on_gpu,
+                device=self._device,
             )
             # repackage so that it's an N list, each contains
             # L x etc. output
@@ -1168,7 +1168,7 @@ def _memory_cleanup(self: SemanticSegmentor) -> None:
         self._cache_dir = None
         self._model = None
         self._loader = None
-        self._on_gpu = None
+        self._device = None
         self._futures = None
         self._mp_shared_space = None
         if self._postproc_workers is not None:
@@ -1266,8 +1266,8 @@ def predict(  # noqa: PLR0913
         resolution: Resolution = 1.0,
         units: Units = "baseline",
         save_dir: str | Path | None = None,
+        device: str = "cpu",
         *,
-        on_gpu: bool = True,
         crash_on_exception: bool = False,
     ) -> list[tuple[Path, Path]]:
         """Make a prediction for a list of input data.
@@ -1305,8 +1305,11 @@ def predict(  # noqa: PLR0913
                 `stride_shape`, `resolution`, and `units` arguments are
                 ignored. Otherwise, those arguments will be internally
                 converted to a :class:`IOSegmentorConfig` object.
-            on_gpu (bool):
-                Whether to run the model on the GPU.
+            device (str):
+                Device on which to run the model, as accepted by
+                :class:`torch.device`. See
+                https://pytorch.org/docs/stable/tensor_attributes.html#torch.device
+                for details. Default is "cpu".
             patch_input_shape (tuple):
                 Size of patches input to the model. The values
                 are at requested read resolution and must be positive.
@@ -1366,8 +1369,8 @@ def predict(  # noqa: PLR0913
         )
 
         # use external for testing
-        self._on_gpu = on_gpu
-        self._model = misc.model_to(model=self.model, on_gpu=on_gpu)
+        self._device = device
+        self._model = model_to(model=self.model, device=device)
 
         # workers should be > 0 else Value Error will be thrown
         self._prepare_workers()
@@ -1566,8 +1569,8 @@ def predict(  # noqa: PLR0913
         resolution: Resolution = 1.0,
         units: Units = "baseline",
         save_dir: str | Path | None = None,
+        device: str = "cpu",
         *,
-        on_gpu: bool = True,
         crash_on_exception: bool = False,
     ) -> list[tuple[Path, Path]]:
         """Make a prediction for a list of input data.
@@ -1605,8 +1608,11 @@ def predict(  # noqa: PLR0913
                 `stride_shape`, `resolution`, and `units` arguments are
                 ignored. Otherwise, those arguments will be internally
                 converted to a :class:`IOSegmentorConfig` object.
-            on_gpu (bool):
-                Whether to run the model on the GPU.
+            device (str):
+                Device on which to run the model, as accepted by
+                :class:`torch.device`. See
+                https://pytorch.org/docs/stable/tensor_attributes.html#torch.device
+                for details. Default is "cpu".
             patch_input_shape (IntPair):
                 Size of patches input to the model. The values are at
                 requested read resolution and must be positive.
@@ -1662,7 +1668,7 @@ def predict(  # noqa: PLR0913
             imgs=imgs,
             masks=masks,
             mode=mode,
-            on_gpu=on_gpu,
+            device=device,
             ioconfig=ioconfig,
             patch_input_shape=patch_input_shape,
             patch_output_shape=patch_output_shape,
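The segmentor engines mirror this: a device string flows from `predict` through `model_to` into every `infer_batch` call. A hedged sketch (the input path and save directory are hypothetical):

```python
from tiatoolbox.models import SemanticSegmentor

segmentor = SemanticSegmentor(pretrained_model="fcn-tissue_mask", batch_size=4)
output = segmentor.predict(
    ["sample_wsi.svs"],  # hypothetical input WSI
    mode="wsi",
    device="cuda",       # previously on_gpu=True
    save_dir="sample_out",
    crash_on_exception=True,
)
```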
diff --git a/tiatoolbox/models/models_abc.py b/tiatoolbox/models/models_abc.py
index e16540c87..a3af4e7f0 100644
--- a/tiatoolbox/models/models_abc.py
+++ b/tiatoolbox/models/models_abc.py
@@ -39,6 +39,28 @@ def output_resolutions(self: IOConfigABC) -> None:
         raise NotImplementedError
 
 
+def model_to(model: torch.nn.Module, device: str = "cpu") -> torch.nn.Module:
+    """Transfers model to specified device e.g., "cpu" or "cuda".
+
+    Args:
+        model (torch.nn.Module):
+            PyTorch defined model.
+        device (str):
+            Transfers model to the specified device. Default is "cpu".
+
+    Returns:
+        torch.nn.Module:
+            The model after being moved to specified device.
+
+    """
+    if device != "cpu":
+        # DataParallel works only for CUDA devices.
+        model = torch.nn.DataParallel(model)
+
+    device = torch.device(device)
+    return model.to(device)
+
+
 class ModelABC(ABC, torch.nn.Module):
     """Abstract base class for models used in tiatoolbox."""
 
@@ -59,8 +81,7 @@ def forward(self: ModelABC, *args: tuple[Any, ...], **kwargs: dict) -> None:
     def infer_batch(
         model: torch.nn.Module,
         batch_data: np.ndarray,
-        *,
-        on_gpu: bool,
+        device: str,
     ) -> None:
         """Run inference on an input batch.
 
@@ -72,8 +93,13 @@ def infer_batch(
             batch_data (np.ndarray):
                 A batch of data generated by
                 `torch.utils.data.DataLoader`.
-            on_gpu (bool):
-                Whether to run inference on a GPU.
+            device (str):
+                The device on which to run inference, e.g. "cpu" or "cuda".
+
+        Returns:
+            dict:
+                A dictionary of predictions and other outputs, depending on
+                the network architecture.
 
         """
         ...  # pragma: no cover
@@ -106,7 +132,7 @@ def preproc_func(self: ModelABC, func: Callable) -> None:
             >>> # `func` is a user defined function
             >>> model = ModelABC()
             >>> model.preproc_func = func
-            >>> transformed_img = model.preproc_func(img)
+            >>> transformed_img = model.preproc_func(image=img)
 
         """
         if func is not None and not callable(func):
@@ -137,7 +163,7 @@ def postproc_func(self: ModelABC, func: Callable) -> None:
             >>> # `func` is a user defined function
             >>> model = ModelABC()
             >>> model.postproc_func = func
-            >>> transformed_img = model.postproc_func(img)
+            >>> transformed_img = model.postproc_func(image=img)
 
         """
         if func is not None and not callable(func):
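The relocated `model_to` keeps the DataParallel wrapping for non-CPU devices. A short sketch of its behaviour:

```python
import torch
from torchvision.models import resnet18
from tiatoolbox.models.models_abc import model_to

cpu_model = model_to(model=resnet18(), device="cpu")
assert isinstance(cpu_model, torch.nn.Module)
# Any non-CPU device first wraps the model in nn.DataParallel, so
# model_to(model=resnet18(), device="cuda") requires a CUDA build.
```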
diff --git a/tiatoolbox/utils/misc.py b/tiatoolbox/utils/misc.py
index 7239f0a8c..7c1c349e7 100644
--- a/tiatoolbox/utils/misc.py
+++ b/tiatoolbox/utils/misc.py
@@ -16,7 +16,6 @@
 import numpy as np
 import pandas as pd
 import requests
-import torch
 import yaml
 import zarr
 from filelock import FileLock
@@ -878,24 +877,6 @@ def select_device(*, on_gpu: bool) -> str:
     return "cpu"
 
 
-def model_to(model: torch.nn.Module, *, on_gpu: bool) -> torch.nn.Module:
-    """Transfers model to cpu/gpu.
-
-    Args:
-        model (torch.nn.Module): PyTorch defined model.
-        on_gpu (bool): Transfers model to gpu if True otherwise to cpu.
-
-    Returns:
-        torch.nn.Module:
-            The model after being moved to cpu/gpu.
-    """
-    if on_gpu:  # DataParallel work only for cuda
-        model = torch.nn.DataParallel(model)
-        return model.to("cuda")
-
-    return model.to("cpu")
-
-
 def get_bounding_box(img: np.ndarray) -> np.ndarray:
     """Get bounding box coordinate information.
 
diff --git a/tiatoolbox/utils/visualization.py b/tiatoolbox/utils/visualization.py
index 142b6e061..817485711 100644
--- a/tiatoolbox/utils/visualization.py
+++ b/tiatoolbox/utils/visualization.py
@@ -559,6 +559,13 @@ def to_int_tuple(x: tuple[int, ...] | np.ndarray) -> tuple[int, ...]:
     return canvas
 
 
+def _find_minimum_mpp_sf(mpp: tuple[float, float] | None) -> float:
+    """Calculates minimum mpp scale factor."""
+    if mpp is not None:
+        return np.minimum(mpp[0] / 0.25, 1)
+    return 1.0
+
+
 class AnnotationRenderer:
     """Renders AnnotationStore to a tile.
 
@@ -971,9 +978,7 @@ def render_annotations(
             int((bounds[2] - bounds[0]) / scale),
         ]
 
-        mpp_sf = 1
-        if self.info["mpp"] is not None:
-            mpp_sf = np.minimum(self.info["mpp"][0] / 0.25, 1)
+        mpp_sf = _find_minimum_mpp_sf(self.info["mpp"])
 
         min_area = 0.0005 * (output_size[0] * output_size[1]) * (scale * mpp_sf) ** 2
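The extracted `_find_minimum_mpp_sf` caps the scale factor at 1 and only shrinks it for slides scanned finer than 0.25 mpp. A sketch of its behaviour (importing the private helper purely for illustration):

```python
from tiatoolbox.utils.visualization import _find_minimum_mpp_sf

assert _find_minimum_mpp_sf(None) == 1.0
assert _find_minimum_mpp_sf((0.125, 0.125)) == 0.5  # 0.125 / 0.25
assert _find_minimum_mpp_sf((0.5, 0.5)) == 1.0      # min(0.5 / 0.25, 1)
```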
 
diff --git a/tiatoolbox/visualization/bokeh_app/main.py b/tiatoolbox/visualization/bokeh_app/main.py
index 6f9aff33d..4e4195558 100644
--- a/tiatoolbox/visualization/bokeh_app/main.py
+++ b/tiatoolbox/visualization/bokeh_app/main.py
@@ -69,6 +69,7 @@
     NucleusInstanceSegmentor,
 )
 from tiatoolbox.tools.pyramid import ZoomifyGenerator
+from tiatoolbox.utils.misc import select_device
 from tiatoolbox.utils.visualization import random_colors
 from tiatoolbox.visualization.ui_utils import get_level_by_extent
 from tiatoolbox.wsicore.wsireader import WSIReader
@@ -1237,7 +1238,7 @@ def segment_on_box() -> None:
         [tmp_mask_dir / "mask.png"],
         save_dir=tmp_save_dir / "hover_out",
         mode="wsi",
-        on_gpu=torch.cuda.is_available(),
+        device=select_device(on_gpu=torch.cuda.is_available()),
         crash_on_exception=True,
     )
 
diff --git a/whitelist.txt b/whitelist.txt
index 07a1b13c3..d1e723f26 100644
--- a/whitelist.txt
+++ b/whitelist.txt
@@ -96,6 +96,7 @@ coord
 coords
 csv
 cuda
+customizable
 cv2
 dataframe
 dataset