
Commit

Fix type hints for mypy
LocalToasty committed Jan 17, 2023
1 parent f2255c0 commit 44b0c88
Showing 2 changed files with 18 additions and 13 deletions.
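Most of the changes below address the same mypy complaint: os.cpu_count() is annotated in typeshed as returning Optional[int] (it may be None), so passing its result where a plain int is expected fails type checking. Appending "or 1" narrows the value to int, falling back to a single worker when the CPU count cannot be determined. A minimal standalone sketch of that pattern (not part of the commit; the "workers" name is only illustrative):

import os

# os.cpu_count() is typed as Optional[int]; "or 1" narrows it to int,
# falling back to one worker when the count cannot be determined.
workers: int = os.cpu_count() or 1
print(f"using {workers} worker(s)")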
16 changes: 8 additions & 8 deletions create_heatmaps.py
@@ -192,13 +192,13 @@ def load_slide(slide: openslide.OpenSlide, target_mpp: float = 256 / 224) -> np.
     slide_mpp = float(slide.properties[openslide.PROPERTY_NAME_MPP_X])
     tile_target_size = np.round(stride * slide_mpp / target_mpp).astype(int)

-    with futures.ThreadPoolExecutor(min(32, os.cpu_count())) as executor:
+    with futures.ThreadPoolExecutor(min(32, os.cpu_count() or 1)) as executor:
         # map from future to its (row, col) index
         future_coords: Dict[futures.Future, Tuple[int, int]] = {}
         for i in range(steps): # row
             for j in range(steps): # column
                 future = executor.submit(
-                    _load_tile, slide, (stride * (j, i)), stride, tile_target_size
+                    _load_tile, slide, (stride * (j, i)), stride, tile_target_size # type: ignore
                 )
                 future_coords[future] = (i, j)
@@ -244,8 +244,8 @@ def linear_to_conv2d(linear):

 if __name__ == "__main__":
     # use all the threads
-    torch.set_num_threads(os.cpu_count())
-    torch.set_num_interop_threads(os.cpu_count())
+    torch.set_num_threads(os.cpu_count() or 1)
+    torch.set_num_interop_threads(os.cpu_count() or 1)

     if args.force_cpu:
         device = torch.device("cpu")
@@ -309,9 +309,9 @@ def linear_to_conv2d(linear):
     # we operate in two steps: we first collect all attention values / scores,
     # the entirety of which we then calculate our scaling parameters from. Only
     # then we output the actual maps.
-    attention_maps: Dict[Path, torch.Tensor] = {}
-    score_maps: Dict[Path, torch.Tensor] = {}
-    masks: Dict[Path, torch.Tensor] = {}
+    attention_maps: Dict[str, torch.Tensor] = {}
+    score_maps: Dict[str, torch.Tensor] = {}
+    masks: Dict[str, torch.Tensor] = {}

     print("Extracting features, attentions and scores...")
     for slide_url in (progress := tqdm(args.slide_urls, leave=False)):
@@ -354,7 +354,7 @@ def linear_to_conv2d(linear):
         feat_t = torch.concat(slices, 3).squeeze()
         # save the features (with compression)
         with ZstdFile(feats_pt, mode="wb") as fp:
-            torch.save(feat_t, fp)
+            torch.save(feat_t, fp) # type: ignore

         feat_t = feat_t.to(device)
         # pool features, but use gaussian blur instead of avg pooling to reduce artifacts
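For context, the tile-loading hunk above maps each submitted future back to its (row, column) grid position so results can be placed correctly regardless of completion order. The following standalone sketch (standard library only; load_tile_stub and steps are placeholders, not the repository's code) illustrates that pattern together with the mypy-friendly worker count:

from concurrent import futures
import os
from typing import Dict, Tuple

def load_tile_stub(row: int, col: int) -> str:
    # stand-in for the real tile loader
    return f"tile({row}, {col})"

steps = 4  # grid size, chosen arbitrarily for the example
with futures.ThreadPoolExecutor(min(32, os.cpu_count() or 1)) as executor:
    # map from future to its (row, col) index
    future_coords: Dict[futures.Future, Tuple[int, int]] = {}
    for i in range(steps):      # row
        for j in range(steps):  # column
            future = executor.submit(load_tile_stub, i, j)
            future_coords[future] = (i, j)

    # collect results in completion order, but keep their grid positions
    for future in futures.as_completed(future_coords):
        i, j = future_coords[future]
        print((i, j), future.result())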
15 changes: 10 additions & 5 deletions sftp.py
@@ -3,7 +3,7 @@
 import os
 from pathlib import Path
 import re
-from typing import Mapping
+from typing import MutableMapping, Tuple
 from urllib.parse import ParseResult
 import paramiko

@@ -18,26 +18,31 @@ def get_wsi(url: ParseResult, *, cache_dir: Path) -> Path:
         transport = paramiko.Transport((host, port))
         transport.connect(None, username, password)
         print("Authentication successful")
-        with paramiko.SFTPClient.from_transport(transport) as sftp:
+        with paramiko.SFTPClient.from_transport(transport) as sftp: # type: ignore
+            assert sftp is not None
             remote_stats = sftp.stat(url.path)

             cached_wsi_path = cache_dir / Path(url.path).name
             # do we have a cached copy?
             if (
                 cached_wsi_path.exists()
                 and (cached_stats := os.stat(cached_wsi_path))
+                and remote_stats.st_size
                 and cached_stats.st_size == remote_stats.st_size # same file size
+                and remote_stats.st_mtime
                 and remote_stats.st_mtime <= cached_stats.st_mtime
             ): # remote file not newer
                 return cached_wsi_path # yes, we have a good copy

-            sftp.get(remotepath=url.path, localpath=cached_wsi_path)
+            sftp.get(remotepath=str(url.path), localpath=str(cached_wsi_path))
             # if all else fails, download it
             return cached_wsi_path
     else:
         raise RuntimeError(f"unsupported scheme: {url.scheme}")


 def _get_password_for_netloc(
-    netloc: str, netloc_passwds: Mapping[str, str] = {}
+    netloc: str, netloc_passwds: MutableMapping[str, str] = {}
 ) -> str:
     # absolutely disgusting use of a "static variable" in the form of a default argument
     # don't try this at home
@@ -51,7 +56,7 @@ def _parse_netloc(netloc: str) -> str:
     return netloc_passwds[netloc]


-def _parse_netloc(netloc: str) -> str:
+def _parse_netloc(netloc: str) -> Tuple[str, str, int]:
     # parses a username / host / port in the format user@host[:port]
     if not (match := re.match(r"(.*)@([^:]*)(?::)?(.*)?", netloc)):
         raise RuntimeError(
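The last hunk changes only the signature of _parse_netloc; its body is cut off above. A hedged sketch of what such a parser could look like follows: the regex and the Tuple[str, str, int] return type come from the diff, while the function name, the int() conversion, and the default SFTP port of 22 are assumptions made purely for illustration.

import re
from typing import Tuple

def parse_netloc_sketch(netloc: str) -> Tuple[str, str, int]:
    # parses a username / host / port in the format user@host[:port]
    if not (match := re.match(r"(.*)@([^:]*)(?::)?(.*)?", netloc)):
        raise RuntimeError(f"invalid netloc: {netloc!r}")
    username, host, port = match.groups()
    # assumption: default to the standard SFTP port when none is given
    return username, host, int(port) if port else 22

print(parse_netloc_sketch("alice@example.org:2222"))  # ('alice', 'example.org', 2222)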
