diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 3d2601f..9b618cf 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -101,7 +101,7 @@ jobs: mkdir slides && cd slides wget -q https://openslide.cs.cmu.edu/download/openslide-testdata/Aperio/JP2K-33003-1.svs cd .. - wsinfer run --wsi-dir slides/ --results-dir results/ --model breast-tumor-resnet34.tcga-brca + WSINFER_FORCE_CPU=1 wsinfer run --wsi-dir slides/ --results-dir results/ --model breast-tumor-resnet34.tcga-brca test -f results/run_metadata_*.json test -f results/patches/JP2K-33003-1.h5 test -f results/model-outputs-csv/JP2K-33003-1.csv diff --git a/pyproject.toml b/pyproject.toml index d57c5d4..b9cb93d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -74,7 +74,7 @@ dev = [ ] docs = [ "pydata-sphinx-theme", - "sphinx<6.0.0", + "sphinx", "sphinx-autoapi", "sphinx-click", "sphinx-copybutton", diff --git a/wsinfer/modellib/run_inference.py b/wsinfer/modellib/run_inference.py index b27805c..8750984 100644 --- a/wsinfer/modellib/run_inference.py +++ b/wsinfer/modellib/run_inference.py @@ -6,6 +6,7 @@ """ from __future__ import annotations +import os from pathlib import Path from typing import TYPE_CHECKING from typing import cast as type_cast @@ -93,7 +94,14 @@ def run_inference( model.eval() # Set the device. - if torch.cuda.is_available(): + # Use CPU if env var specifies. Prefer checking if this is false, because + # if the var is set to almost anything (other than 0, False, f), then it + # should be true. + # This env var was introduced mainly for continuous integration tests that + # failed on apple silicon in github actions. Forcing to cpu avoids failures. + if os.getenv("WSINFER_FORCE_CPU", "0").lower() not in {"0", "f", "false"}: + device = torch.device("cpu") + elif torch.cuda.is_available(): device = torch.device("cuda") if torch.cuda.device_count() > 1: model = torch.nn.DataParallel(model) diff --git a/wsinfer/wsi.py b/wsinfer/wsi.py index dd1f6be..5bb8721 100644 --- a/wsinfer/wsi.py +++ b/wsinfer/wsi.py @@ -21,6 +21,10 @@ try: import openslide + # Test that OpenSlide object exists. If it doesn't, an error will be thrown and + # caught. For some reason, it is possible that openslide-python can be installed + # but the OpenSlide object (and other openslide things) are not available. + openslide.OpenSlide HAS_OPENSLIDE = True logger.debug("Imported openslide") except Exception as err: