From ad6afbcce1b6b66fd5411df2467afa7c946b2cdc Mon Sep 17 00:00:00 2001 From: wuyongjianCODE <114390912+wuyongjianCODE@users.noreply.github.com> Date: Fri, 2 Jun 2023 16:56:21 +0800 Subject: [PATCH] Add files via upload --- dataloader/__init__.py | 0 dataloader/augs.py | 109 +++++ dataloader/infer_loader.py | 100 +++++ dataloader/train_loader.py | 194 +++++++++ infer/__init__.py | 0 infer/base.py | 94 +++++ infer/tile.py | 391 ++++++++++++++++++ infer/wsi.py | 751 ++++++++++++++++++++++++++++++++++ metrics/README.md | 55 +++ metrics/__init__.py | 0 metrics/stats_utils.py | 429 +++++++++++++++++++ misc/__init__.py | 0 misc/patch_extractor.py | 155 +++++++ misc/utils.py | 182 ++++++++ misc/viz_utils.py | 173 ++++++++ misc/wsi_handler.py | 204 +++++++++ models/__init__.py | 0 models/hovernet/__init__.py | 0 models/hovernet/net_desc.py | 153 +++++++ models/hovernet/net_utils.py | 295 +++++++++++++ models/hovernet/opt.py | 142 +++++++ models/hovernet/post_proc.py | 186 +++++++++ models/hovernet/run_desc.py | 344 ++++++++++++++++ models/hovernet/targets.py | 153 +++++++ models/hovernet/utils.py | 172 ++++++++ models/hovernet0/__init__.py | 0 models/hovernet0/net_desc.py | 153 +++++++ models/hovernet0/net_utils.py | 295 +++++++++++++ models/hovernet0/opt.py | 142 +++++++ models/hovernet0/post_proc.py | 186 +++++++++ models/hovernet0/run_desc.py | 344 ++++++++++++++++ models/hovernet0/targets.py | 153 +++++++ models/hovernet0/utils.py | 172 ++++++++ models/hovernetC/__init__.py | 0 models/hovernetC/net_desc.py | 221 ++++++++++ models/hovernetC/net_utils.py | 295 +++++++++++++ models/hovernetC/opt.py | 142 +++++++ models/hovernetC/post_proc.py | 186 +++++++++ models/hovernetC/run_desc.py | 331 +++++++++++++++ models/hovernetC/targets.py | 153 +++++++ models/hovernetC/utils.py | 172 ++++++++ 41 files changed, 7227 insertions(+) create mode 100644 dataloader/__init__.py create mode 100644 dataloader/augs.py create mode 100644 dataloader/infer_loader.py create mode 100644 dataloader/train_loader.py create mode 100644 infer/__init__.py create mode 100644 infer/base.py create mode 100644 infer/tile.py create mode 100644 infer/wsi.py create mode 100644 metrics/README.md create mode 100644 metrics/__init__.py create mode 100644 metrics/stats_utils.py create mode 100644 misc/__init__.py create mode 100644 misc/patch_extractor.py create mode 100644 misc/utils.py create mode 100644 misc/viz_utils.py create mode 100644 misc/wsi_handler.py create mode 100644 models/__init__.py create mode 100644 models/hovernet/__init__.py create mode 100644 models/hovernet/net_desc.py create mode 100644 models/hovernet/net_utils.py create mode 100644 models/hovernet/opt.py create mode 100644 models/hovernet/post_proc.py create mode 100644 models/hovernet/run_desc.py create mode 100644 models/hovernet/targets.py create mode 100644 models/hovernet/utils.py create mode 100644 models/hovernet0/__init__.py create mode 100644 models/hovernet0/net_desc.py create mode 100644 models/hovernet0/net_utils.py create mode 100644 models/hovernet0/opt.py create mode 100644 models/hovernet0/post_proc.py create mode 100644 models/hovernet0/run_desc.py create mode 100644 models/hovernet0/targets.py create mode 100644 models/hovernet0/utils.py create mode 100644 models/hovernetC/__init__.py create mode 100644 models/hovernetC/net_desc.py create mode 100644 models/hovernetC/net_utils.py create mode 100644 models/hovernetC/opt.py create mode 100644 models/hovernetC/post_proc.py create mode 100644 models/hovernetC/run_desc.py create mode 100644 models/hovernetC/targets.py create mode 100644 models/hovernetC/utils.py diff --git a/dataloader/__init__.py b/dataloader/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/dataloader/augs.py b/dataloader/augs.py new file mode 100644 index 0000000..6d23cb3 --- /dev/null +++ b/dataloader/augs.py @@ -0,0 +1,109 @@ +import math + +import cv2 +import matplotlib.cm as cm +import numpy as np + +from scipy import ndimage +from scipy.ndimage import measurements +from scipy.ndimage.filters import gaussian_filter +from scipy.ndimage.interpolation import affine_transform, map_coordinates + +from skimage import morphology as morph + +from misc.utils import cropping_center, get_bounding_box + + +#### +def fix_mirror_padding(ann): + """Deal with duplicated instances due to mirroring in interpolation + during shape augmentation (scale, rotation etc.). + + """ + current_max_id = np.amax(ann) + inst_list = list(np.unique(ann)) + inst_list.remove(0) # 0 is background + for inst_id in inst_list: + inst_map = np.array(ann == inst_id, np.uint8) + remapped_ids = measurements.label(inst_map)[0] + remapped_ids[remapped_ids > 1] += current_max_id + ann[remapped_ids > 1] = remapped_ids[remapped_ids > 1] + current_max_id = np.amax(ann) + return ann + + +#### +def gaussian_blur(images, random_state, parents, hooks, max_ksize=3): + """Apply Gaussian blur to input images.""" + img = images[0] # aleju input batch as default (always=1 in our case) + ksize = random_state.randint(0, max_ksize, size=(2,)) + ksize = tuple((ksize * 2 + 1).tolist()) + + ret = cv2.GaussianBlur( + img, ksize, sigmaX=0, sigmaY=0, borderType=cv2.BORDER_REPLICATE + ) + ret = np.reshape(ret, img.shape) + ret = ret.astype(np.uint8) + return [ret] + + +#### +def median_blur(images, random_state, parents, hooks, max_ksize=3): + """Apply median blur to input images.""" + img = images[0] # aleju input batch as default (always=1 in our case) + ksize = random_state.randint(0, max_ksize) + ksize = ksize * 2 + 1 + ret = cv2.medianBlur(img, ksize) + ret = ret.astype(np.uint8) + return [ret] + + +#### +def add_to_hue(images, random_state, parents, hooks, range=None): + """Perturbe the hue of input images.""" + img = images[0] # aleju input batch as default (always=1 in our case) + hue = random_state.uniform(*range) + hsv = cv2.cvtColor(img, cv2.COLOR_RGB2HSV) + if hsv.dtype.itemsize == 1: + # OpenCV uses 0-179 for 8-bit images + hsv[..., 0] = (hsv[..., 0] + hue) % 180 + else: + # OpenCV uses 0-360 for floating point images + hsv[..., 0] = (hsv[..., 0] + 2 * hue) % 360 + ret = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB) + ret = ret.astype(np.uint8) + return [ret] + + +#### +def add_to_saturation(images, random_state, parents, hooks, range=None): + """Perturbe the saturation of input images.""" + img = images[0] # aleju input batch as default (always=1 in our case) + value = 1 + random_state.uniform(*range) + gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY) + ret = img * value + (gray * (1 - value))[:, :, np.newaxis] + ret = np.clip(ret, 0, 255) + ret = ret.astype(np.uint8) + return [ret] + + +#### +def add_to_contrast(images, random_state, parents, hooks, range=None): + """Perturbe the contrast of input images.""" + img = images[0] # aleju input batch as default (always=1 in our case) + value = random_state.uniform(*range) + mean = np.mean(img, axis=(0, 1), keepdims=True) + ret = img * value + mean * (1 - value) + ret = np.clip(img, 0, 255) + ret = ret.astype(np.uint8) + return [ret] + + +#### +def add_to_brightness(images, random_state, parents, hooks, range=None): + """Perturbe the brightness of input images.""" + img = images[0] # aleju input batch as default (always=1 in our case) + value = random_state.uniform(*range) + ret = np.clip(img + value, 0, 255) + ret = ret.astype(np.uint8) + return [ret] diff --git a/dataloader/infer_loader.py b/dataloader/infer_loader.py new file mode 100644 index 0000000..fd25bb7 --- /dev/null +++ b/dataloader/infer_loader.py @@ -0,0 +1,100 @@ +import sys +import math +import numpy as np +import cv2 +import matplotlib.pyplot as plt + +import torch +import torch.utils.data as data + +import psutil + + +#### +class SerializeFileList(data.IterableDataset): + """Read a single file as multiple patches of same shape, perform the padding beforehand.""" + + def __init__(self, img_list, patch_info_list, patch_size, preproc=None): + super().__init__() + self.patch_size = patch_size + + self.img_list = img_list + self.patch_info_list = patch_info_list + + self.worker_start_img_idx = 0 + # * for internal worker state + self.curr_img_idx = 0 + self.stop_img_idx = 0 + self.curr_patch_idx = 0 + self.stop_patch_idx = 0 + self.preproc = preproc + return + + def __iter__(self): + worker_info = torch.utils.data.get_worker_info() + if worker_info is None: # single-process data loading, return the full iterator + self.stop_img_idx = len(self.img_list) + self.stop_patch_idx = len(self.patch_info_list) + return self + else: # in a worker process so split workload, return a reduced copy of self + per_worker = len(self.patch_info_list) / float(worker_info.num_workers) + per_worker = int(math.ceil(per_worker)) + + global_curr_patch_idx = worker_info.id * per_worker + global_stop_patch_idx = global_curr_patch_idx + per_worker + self.patch_info_list = self.patch_info_list[ + global_curr_patch_idx:global_stop_patch_idx + ] + self.curr_patch_idx = 0 + self.stop_patch_idx = len(self.patch_info_list) + # * check img indexer, implicit protocol in infer.py + global_curr_img_idx = self.patch_info_list[0][-1] + global_stop_img_idx = self.patch_info_list[-1][-1] + 1 + self.worker_start_img_idx = global_curr_img_idx + self.img_list = self.img_list[global_curr_img_idx:global_stop_img_idx] + self.curr_img_idx = 0 + self.stop_img_idx = len(self.img_list) + return self # does it mutate source copy? + + def __next__(self): + + if self.curr_patch_idx >= self.stop_patch_idx: + raise StopIteration # when there is nothing more to yield + patch_info = self.patch_info_list[self.curr_patch_idx] + img_ptr = self.img_list[patch_info[-1] - self.worker_start_img_idx] + patch_data = img_ptr[ + patch_info[0] : patch_info[0] + self.patch_size, + patch_info[1] : patch_info[1] + self.patch_size, + ] + self.curr_patch_idx += 1 + if self.preproc is not None: + patch_data = self.preproc(patch_data) + return patch_data, patch_info + + +#### +class SerializeArray(data.Dataset): + def __init__(self, mmap_array_path, patch_info_list, patch_size, preproc=None): + super().__init__() + self.patch_size = patch_size + + # use mmap as intermediate sharing, else variable will be duplicated + # accross torch worker => OOM error, open in read only mode + self.image = np.load(mmap_array_path, mmap_mode="r") + + self.patch_info_list = patch_info_list + self.preproc = preproc + return + + def __len__(self): + return len(self.patch_info_list) + + def __getitem__(self, idx): + patch_info = self.patch_info_list[idx] + patch_data = self.image[ + patch_info[0] : patch_info[0] + self.patch_size[0], + patch_info[1] : patch_info[1] + self.patch_size[1], + ] + if self.preproc is not None: + patch_data = self.preproc(patch_data) + return patch_data, patch_info diff --git a/dataloader/train_loader.py b/dataloader/train_loader.py new file mode 100644 index 0000000..5a65c2e --- /dev/null +++ b/dataloader/train_loader.py @@ -0,0 +1,194 @@ +import csv +import glob +import os +import re + +import cv2 +import matplotlib.pyplot as plt +import numpy as np +import scipy.io as sio +import torch.utils.data + +import imgaug as ia +from imgaug import augmenters as iaa +from misc.utils import cropping_center + +from .augs import ( + add_to_brightness, + add_to_contrast, + add_to_hue, + add_to_saturation, + gaussian_blur, + median_blur, +) + + +#### +class FileLoader(torch.utils.data.Dataset): + """Data Loader. Loads images from a file list and + performs augmentation with the albumentation library. + After augmentation, horizontal and vertical maps are + generated. + + Args: + file_list: list of filenames to load + input_shape: shape of the input [h,w] - defined in config.py + mask_shape: shape of the output [h,w] - defined in config.py + mode: 'train' or 'valid' + + """ + + # TODO: doc string + + def __init__( + self, + file_list, + with_type=False, + input_shape=None, + mask_shape=None, + mode="train", + setup_augmentor=True, + target_gen=None, + ): + assert input_shape is not None and mask_shape is not None + self.mode = mode + self.info_list = file_list + self.with_type = with_type + self.mask_shape = mask_shape + self.input_shape = input_shape + self.id = 0 + self.target_gen_func = target_gen[0] + self.target_gen_kwargs = target_gen[1] + if setup_augmentor: + self.setup_augmentor(0, 0) + return + + def setup_augmentor(self, worker_id, seed): + self.augmentor = self.__get_augmentation(self.mode, seed) + self.shape_augs = iaa.Sequential(self.augmentor[0]) + self.input_augs = iaa.Sequential(self.augmentor[1]) + self.id = self.id + worker_id + return + + def __len__(self): + return len(self.info_list) + + def __getitem__(self, idx): + path = self.info_list[idx] + data = np.load(path) + + # split stacked channel into image and label + img = (data[..., :3]).astype("uint8") # RGB images + ann = (data[..., 3:]).astype("int32") # instance ID map and type map + + if self.shape_augs is not None: + shape_augs = self.shape_augs.to_deterministic() + img = shape_augs.augment_image(img) + ann = shape_augs.augment_image(ann) + + if self.input_augs is not None: + input_augs = self.input_augs.to_deterministic() + img = input_augs.augment_image(img) + + img = cropping_center(img, self.input_shape) + feed_dict = {"img": img} + + inst_map = ann[..., 0] # HW1 -> HW + if self.with_type: + type_map = (ann[..., 1]).copy() + type_map = cropping_center(type_map, self.mask_shape) + #type_map[type_map == 5] = 1 # merge neoplastic and non-neoplastic + feed_dict["tp_map"] = type_map + + # TODO: document hard coded assumption about #input + target_dict = self.target_gen_func( + inst_map, self.mask_shape, **self.target_gen_kwargs + ) + feed_dict.update(target_dict) + + return feed_dict + + def __get_augmentation(self, mode, rng): + if mode == "train": + shape_augs = [ + # * order = ``0`` -> ``cv2.INTER_NEAREST`` + # * order = ``1`` -> ``cv2.INTER_LINEAR`` + # * order = ``2`` -> ``cv2.INTER_CUBIC`` + # * order = ``3`` -> ``cv2.INTER_CUBIC`` + # * order = ``4`` -> ``cv2.INTER_CUBIC`` + # ! for pannuke v0, no rotation or translation, just flip to avoid mirror padding + iaa.Affine( + # scale images to 80-120% of their size, individually per axis + scale={"x": (0.8, 1.2), "y": (0.8, 1.2)}, + # translate by -A to +A percent (per axis) + translate_percent={"x": (-0.01, 0.01), "y": (-0.01, 0.01)}, + shear=(-5, 5), # shear by -5 to +5 degrees + rotate=(-179, 179), # rotate by -179 to +179 degrees + order=0, # use nearest neighbour + backend="cv2", # opencv for fast processing + seed=rng, + ), + # set position to 'center' for center crop + # else 'uniform' for random crop + iaa.CropToFixedSize( + self.input_shape[0], self.input_shape[1], position="center" + ), + iaa.Fliplr(0.5, seed=rng), + iaa.Flipud(0.5, seed=rng), + ] + + input_augs = [ + iaa.OneOf( + [ + iaa.Lambda( + seed=rng, + func_images=lambda *args: gaussian_blur(*args, max_ksize=3), + ), + iaa.Lambda( + seed=rng, + func_images=lambda *args: median_blur(*args, max_ksize=3), + ), + iaa.AdditiveGaussianNoise( + loc=0, scale=(0.0, 0.05 * 255), per_channel=0.5 + ), + ] + ), + iaa.Sequential( + [ + iaa.Lambda( + seed=rng, + func_images=lambda *args: add_to_hue(*args, range=(-8, 8)), + ), + iaa.Lambda( + seed=rng, + func_images=lambda *args: add_to_saturation( + *args, range=(-0.2, 0.2) + ), + ), + iaa.Lambda( + seed=rng, + func_images=lambda *args: add_to_brightness( + *args, range=(-26, 26) + ), + ), + iaa.Lambda( + seed=rng, + func_images=lambda *args: add_to_contrast( + *args, range=(0.75, 1.25) + ), + ), + ], + random_order=True, + ), + ] + elif mode == "valid": + shape_augs = [ + # set position to 'center' for center crop + # else 'uniform' for random crop + iaa.CropToFixedSize( + self.input_shape[0], self.input_shape[1], position="center" + ) + ] + input_augs = [] + + return shape_augs, input_augs diff --git a/infer/__init__.py b/infer/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/infer/base.py b/infer/base.py new file mode 100644 index 0000000..81165ec --- /dev/null +++ b/infer/base.py @@ -0,0 +1,94 @@ +import argparse +import glob +import json +import math +import multiprocessing +import os +import re +import sys +from importlib import import_module +from multiprocessing import Lock, Pool + +import matplotlib.pyplot as plt +import numpy as np +import torch +import torch.utils.data as data +import tqdm + +from run_utils.utils import convert_pytorch_checkpoint + + +#### +class InferManager(object): + def __init__(self, **kwargs): + self.run_step = None + for variable, value in kwargs.items(): + self.__setattr__(variable, value) + self.__load_model() + self.nr_types = self.method["model_args"]["nr_types"] + # create type info name and colour + + # default + self.type_info_dict = { + None: ["no label", [0, 0, 0]], + } + + if self.nr_types is not None and self.type_info_path is not None: + self.type_info_dict = json.load(open(self.type_info_path, "r")) + self.type_info_dict = { + int(k): (v[0], tuple(v[1])) for k, v in self.type_info_dict.items() + } + # availability check + for k in range(self.nr_types): + if k not in self.type_info_dict: + assert False, "Not detect type_id=%d defined in json." % k + + if self.nr_types is not None and self.type_info_path is None: + cmap = plt.get_cmap("hot") + colour_list = np.arange(self.nr_types, dtype=np.int32) + colour_list = (cmap(colour_list)[..., :3] * 255).astype(np.uint8) + # should be compatible out of the box wrt qupath + self.type_info_dict = { + k: (str(k), tuple(v)) for k, v in enumerate(colour_list) + } + return + + def __load_model(self): + """Create the model, load the checkpoint and define + associated run steps to process each data batch. + + """ + model_desc = import_module("models.hovernet.net_desc") + model_creator = getattr(model_desc, "create_model") + + net = model_creator(**self.method["model_args"]) + saved_state_dict = torch.load(self.method["model_path"])["desc"] + saved_state_dict = convert_pytorch_checkpoint(saved_state_dict) + + net.load_state_dict(saved_state_dict, strict=True) + net = torch.nn.DataParallel(net) + net = net.to("cuda") + + module_lib = import_module("models.hovernet.run_desc") + run_step = getattr(module_lib, "infer_step") + self.run_step = lambda input_batch: run_step(input_batch, net) + + module_lib = import_module("models.hovernet.post_proc") + self.post_proc_func = getattr(module_lib, "process") + return + + def __save_json(self, path, old_dict, mag=None): + new_dict = {} + for inst_id, inst_info in old_dict.items(): + new_inst_info = {} + for info_name, info_value in inst_info.items(): + # convert to jsonable + if isinstance(info_value, np.ndarray): + info_value = info_value.tolist() + new_inst_info[info_name] = info_value + new_dict[int(inst_id)] = new_inst_info + + json_dict = {"mag": mag, "nuc": new_dict} # to sync the format protocol + with open(path, "w") as handle: + json.dump(json_dict, handle) + return new_dict diff --git a/infer/tile.py b/infer/tile.py new file mode 100644 index 0000000..127a863 --- /dev/null +++ b/infer/tile.py @@ -0,0 +1,391 @@ +import logging +import multiprocessing +from multiprocessing import Lock, Pool + +multiprocessing.set_start_method("spawn", True) # ! must be at top for VScode debugging +import argparse +import glob +import json +import math +import multiprocessing as mp +import os +import pathlib +import pickle +import re +import sys +import warnings +from concurrent.futures import FIRST_EXCEPTION, ProcessPoolExecutor, as_completed, wait +from functools import reduce +from importlib import import_module +from multiprocessing import Lock, Pool + +import cv2 +import numpy as np +import psutil +import scipy.io as sio +import torch +import torch.utils.data as data +import tqdm +from dataloader.infer_loader import SerializeArray, SerializeFileList +from misc.utils import ( + color_deconvolution, + cropping_center, + get_bounding_box, + log_debug, + log_info, + rm_n_mkdir, +) +from misc.viz_utils import colorize, visualize_instances_dict +from skimage import color + +import convert_format +from . import base + + +#### +def _prepare_patching(img, window_size, mask_size, return_src_top_corner=False): + """Prepare patch information for tile processing. + + Args: + img: original input image + window_size: input patch size + mask_size: output patch size + return_src_top_corner: whether to return coordiante information for top left corner of img + + """ + + win_size = window_size + msk_size = step_size = mask_size + + def get_last_steps(length, msk_size, step_size): + nr_step = math.ceil((length - msk_size) / step_size) + last_step = (nr_step + 1) * step_size + return int(last_step), int(nr_step + 1) + + im_h = img.shape[0] + im_w = img.shape[1] + + last_h, _ = get_last_steps(im_h, msk_size, step_size) + last_w, _ = get_last_steps(im_w, msk_size, step_size) + + diff = win_size - step_size + padt = padl = diff // 2 + padb = last_h + win_size - im_h + padr = last_w + win_size - im_w + + img = np.lib.pad(img, ((padt, padb), (padl, padr), (0, 0)), "reflect") + + # generating subpatches index from orginal + coord_y = np.arange(0, last_h, step_size, dtype=np.int32) + coord_x = np.arange(0, last_w, step_size, dtype=np.int32) + row_idx = np.arange(0, coord_y.shape[0], dtype=np.int32) + col_idx = np.arange(0, coord_x.shape[0], dtype=np.int32) + coord_y, coord_x = np.meshgrid(coord_y, coord_x) + row_idx, col_idx = np.meshgrid(row_idx, col_idx) + coord_y = coord_y.flatten() + coord_x = coord_x.flatten() + row_idx = row_idx.flatten() + col_idx = col_idx.flatten() + # + patch_info = np.stack([coord_y, coord_x, row_idx, col_idx], axis=-1) + if not return_src_top_corner: + return img, patch_info + else: + return img, patch_info, [padt, padl] + + +#### +def _post_process_patches( + post_proc_func, post_proc_kwargs, patch_info, image_info, overlay_kwargs, +): + """Apply post processing to patches. + + Args: + post_proc_func: post processing function to use + post_proc_kwargs: keyword arguments used in post processing function + patch_info: patch data and associated information + image_info: input image data and associated information + overlay_kwargs: overlay keyword arguments + + """ + # re-assemble the prediction, sort according to the patch location within the original image + patch_info = sorted(patch_info, key=lambda x: [x[0][0], x[0][1]]) + patch_info, patch_data = zip(*patch_info) + + src_shape = image_info["src_shape"] + src_image = image_info["src_image"] + + patch_shape = np.squeeze(patch_data[0]).shape + ch = 1 if len(patch_shape) == 2 else patch_shape[-1] + axes = [0, 2, 1, 3, 4] if ch != 1 else [0, 2, 1, 3] + + nr_row = max([x[2] for x in patch_info]) + 1 + nr_col = max([x[3] for x in patch_info]) + 1 + pred_map = np.concatenate(patch_data, axis=0) + pred_map = np.reshape(pred_map, (nr_row, nr_col) + patch_shape) + pred_map = np.transpose(pred_map, axes) + pred_map = np.reshape( + pred_map, (patch_shape[0] * nr_row, patch_shape[1] * nr_col, ch) + ) + # crop back to original shape + pred_map = np.squeeze(pred_map[: src_shape[0], : src_shape[1]]) + + # * Implicit protocol + # * a prediction map with instance of ID 1-N + # * and a dict contain the instance info, access via its ID + # * each instance may have type + pred_inst, inst_info_dict = post_proc_func(pred_map, **post_proc_kwargs) + + overlaid_img = visualize_instances_dict( + src_image.copy(), inst_info_dict, **overlay_kwargs + ) + + return image_info["name"], pred_map, pred_inst, inst_info_dict, overlaid_img + + +class InferManager(base.InferManager): + """Run inference on tiles.""" + + #### + def process_file_list(self, run_args): + """ + Process a single image tile < 5000x5000 in size. + """ + for variable, value in run_args.items(): + self.__setattr__(variable, value) + assert self.mem_usage < 1.0 and self.mem_usage > 0.0 + + # * depend on the number of samples and their size, this may be less efficient + patterning = lambda x: re.sub("([\[\]])", "[\\1]", x) + file_path_list = glob.glob(patterning("%s/*" % self.input_dir)) + file_path_list.sort() # ensure same order + assert len(file_path_list) > 0, 'Not Detected Any Files From Path' + + rm_n_mkdir(self.output_dir + '/json/') + rm_n_mkdir(self.output_dir + '/mat/') + rm_n_mkdir(self.output_dir + '/overlay/') + if self.save_qupath: + rm_n_mkdir(self.output_dir + "/qupath/") + + def proc_callback(results): + """Post processing callback. + + Output format is implicit assumption, taken from `_post_process_patches` + + """ + img_name, pred_map, pred_inst, inst_info_dict, overlaid_img = results + + nuc_val_list = list(inst_info_dict.values()) + # need singleton to make matlab happy + nuc_uid_list = np.array(list(inst_info_dict.keys()))[:,None] + nuc_type_list = np.array([v["type"] for v in nuc_val_list])[:,None] + nuc_coms_list = np.array([v["centroid"] for v in nuc_val_list]) + + mat_dict = { + "inst_map" : pred_inst, + "inst_uid" : nuc_uid_list, + "inst_type": nuc_type_list, + "inst_centroid": nuc_coms_list + } + if self.nr_types is None: # matlab does not have None type array + mat_dict.pop("inst_type", None) + + if self.save_raw_map: + mat_dict["raw_map"] = pred_map + save_path = "%s/mat/%s.mat" % (self.output_dir, img_name) + sio.savemat(save_path, mat_dict) + + save_path = "%s/overlay/%s.png" % (self.output_dir, img_name) + cv2.imwrite(save_path, cv2.cvtColor(overlaid_img, cv2.COLOR_RGB2BGR)) + + if self.save_qupath: + nuc_val_list = list(inst_info_dict.values()) + nuc_type_list = np.array([v["type"] for v in nuc_val_list]) + nuc_coms_list = np.array([v["centroid"] for v in nuc_val_list]) + save_path = "%s/qupath/%s.tsv" % (self.output_dir, img_name) + convert_format.to_qupath( + save_path, nuc_coms_list, nuc_type_list, self.type_info_dict + ) + + save_path = "%s/json/%s.json" % (self.output_dir, img_name) + self.__save_json(save_path, inst_info_dict, None) + return img_name + + def detach_items_of_uid(items_list, uid, nr_expected_items): + item_counter = 0 + detached_items_list = [] + remained_items_list = [] + while True: + pinfo, pdata = items_list.pop(0) + pinfo = np.squeeze(pinfo) + if pinfo[-1] == uid: + detached_items_list.append([pinfo, pdata]) + item_counter += 1 + else: + remained_items_list.append([pinfo, pdata]) + if item_counter == nr_expected_items: + break + # do this to ensure the ordering + remained_items_list = remained_items_list + items_list + return detached_items_list, remained_items_list + + proc_pool = None + if self.nr_post_proc_workers > 0: + proc_pool = ProcessPoolExecutor(self.nr_post_proc_workers) + + while len(file_path_list) > 0: + + hardware_stats = psutil.virtual_memory() + available_ram = getattr(hardware_stats, "available") + available_ram = int(available_ram * self.mem_usage) + # available_ram >> 20 for MB, >> 30 for GB + + # TODO: this portion looks clunky but seems hard to detach into separate func + + # * caching N-files into memory such that their expected (total) memory usage + # * does not exceed the designated percentage of currently available memory + # * the expected memory is a factor w.r.t original input file size and + # * must be manually provided + file_idx = 0 + use_path_list = [] + cache_image_list = [] + cache_patch_info_list = [] + cache_image_info_list = [] + while len(file_path_list) > 0: + file_path = file_path_list.pop(0) + + img = cv2.imread(file_path) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + src_shape = img.shape + # import skimage.io as io + # img = io.imread(file_path) + # src_shape = img.shape + + img, patch_info, top_corner = _prepare_patching( + img, self.patch_input_shape, self.patch_output_shape, True + ) + self_idx = np.full(patch_info.shape[0], file_idx, dtype=np.int32) + patch_info = np.concatenate([patch_info, self_idx[:, None]], axis=-1) + # ? may be expensive op + patch_info = np.split(patch_info, patch_info.shape[0], axis=0) + patch_info = [np.squeeze(p) for p in patch_info] + + # * this factor=5 is only applicable for HoVerNet + expected_usage = sys.getsizeof(img) * 5 + available_ram -= expected_usage + if available_ram < 0: + break + + file_idx += 1 + # if file_idx == 4: break + use_path_list.append(file_path) + cache_image_list.append(img) + cache_patch_info_list.extend(patch_info) + # TODO: refactor to explicit protocol + cache_image_info_list.append([src_shape, len(patch_info), top_corner]) + + # * apply neural net on cached data + dataset = SerializeFileList( + cache_image_list, cache_patch_info_list, self.patch_input_shape + ) + + dataloader = data.DataLoader( + dataset, + num_workers=self.nr_inference_workers, + batch_size=self.batch_size, + drop_last=False, + ) + + pbar = tqdm.tqdm( + desc="Process Patches", + leave=True, + total=int(len(cache_patch_info_list) / self.batch_size) + 1, + ncols=80, + ascii=True, + position=0, + ) + + accumulated_patch_output = [] + for batch_idx, batch_data in enumerate(dataloader): + sample_data_list, sample_info_list = batch_data + sample_output_list = self.run_step(sample_data_list) + sample_info_list = sample_info_list.numpy() + curr_batch_size = sample_output_list.shape[0] + sample_output_list = np.split( + sample_output_list, curr_batch_size, axis=0 + ) + sample_info_list = np.split(sample_info_list, curr_batch_size, axis=0) + sample_output_list = list(zip(sample_info_list, sample_output_list)) + accumulated_patch_output.extend(sample_output_list) + pbar.update() + pbar.close() + + # * parallely assemble the processed cache data for each file if possible + future_list = [] + for file_idx, file_path in enumerate(use_path_list): + image_info = cache_image_info_list[file_idx] + file_ouput_data, accumulated_patch_output = detach_items_of_uid( + accumulated_patch_output, file_idx, image_info[1] + ) + + # * detach this into func and multiproc dispatch it + src_pos = image_info[2] # src top left corner within padded image + src_image = cache_image_list[file_idx] + src_image = src_image[ + src_pos[0] : src_pos[0] + image_info[0][0], + src_pos[1] : src_pos[1] + image_info[0][1], + ] + + base_name = pathlib.Path(file_path).stem + file_info = { + "src_shape": image_info[0], + "src_image": src_image, + "name": base_name, + } + + post_proc_kwargs = { + "nr_types": self.nr_types, + "return_centroids": True, + } # dynamicalize this + + overlay_kwargs = { + "draw_dot": self.draw_dot, + "type_colour": self.type_info_dict, + "line_thickness": 2, + } + func_args = ( + self.post_proc_func, + post_proc_kwargs, + file_ouput_data, + file_info, + overlay_kwargs, + ) + + # dispatch for parallel post-processing + if proc_pool is not None: + proc_future = proc_pool.submit(_post_process_patches, *func_args) + # ! manually poll future and call callback later as there is no guarantee + # ! that the callback is called from main thread + future_list.append(proc_future) + else: + proc_output = _post_process_patches(*func_args) + proc_callback(proc_output) + + if proc_pool is not None: + # loop over all to check state a.k.a polling + for future in as_completed(future_list): + # TODO: way to retrieve which file crashed ? + # ! silent crash, cancel all and raise error + if future.exception() is not None: + log_info("Silent Crash") + # ! cancel somehow leads to cascade error later + # ! so just poll it then crash once all future + # ! acquired for now + # for future in future_list: + # future.cancel() + # break + else: + file_path = proc_callback(future.result()) + log_info("Done Assembling %s" % file_path) + return + diff --git a/infer/wsi.py b/infer/wsi.py new file mode 100644 index 0000000..a8409cd --- /dev/null +++ b/infer/wsi.py @@ -0,0 +1,751 @@ +import multiprocessing as mp +from concurrent.futures import FIRST_EXCEPTION, ProcessPoolExecutor, as_completed, wait +from multiprocessing import Lock, Pool + +mp.set_start_method("spawn", True) # ! must be at top for VScode debugging + +import argparse +import glob +import json +import logging +import math +import os +import pathlib +import re +import shutil +import sys +import time +from functools import reduce +from importlib import import_module + +import cv2 +import numpy as np +import psutil +import scipy.io as sio +import torch +import torch.utils.data as data +import tqdm +from dataloader.infer_loader import SerializeArray, SerializeFileList +from docopt import docopt +from misc.utils import ( + cropping_center, + get_bounding_box, + log_debug, + log_info, + rm_n_mkdir, +) +from misc.wsi_handler import get_file_handler + +from . import base + +thread_lock = Lock() + + +#### +def _init_worker_child(lock_): + global lock + lock = lock_ + + +#### +def _remove_inst(inst_map, remove_id_list): + """Remove instances with id in remove_id_list. + + Args: + inst_map: map of instances + remove_id_list: list of ids to remove from inst_map + """ + for inst_id in remove_id_list: + inst_map[inst_map == inst_id] = 0 + return inst_map + + +#### +def _get_patch_top_left_info(img_shape, input_size, output_size): + """Get top left coordinate information of patches from original image. + + Args: + img_shape: input image shape + input_size: patch input shape + output_size: patch output shape + + """ + in_out_diff = input_size - output_size + nr_step = np.floor((img_shape - in_out_diff) / output_size) + 1 + last_output_coord = (in_out_diff // 2) + (nr_step) * output_size + # generating subpatches index from orginal + output_tl_y_list = np.arange( + in_out_diff[0] // 2, last_output_coord[0], output_size[0], dtype=np.int32 + ) + output_tl_x_list = np.arange( + in_out_diff[1] // 2, last_output_coord[1], output_size[1], dtype=np.int32 + ) + output_tl_y_list, output_tl_x_list = np.meshgrid(output_tl_y_list, output_tl_x_list) + output_tl = np.stack( + [output_tl_y_list.flatten(), output_tl_x_list.flatten()], axis=-1 + ) + input_tl = output_tl - in_out_diff // 2 + return input_tl, output_tl + + +#### all must be np.array +def _get_tile_info(img_shape, tile_shape, ambiguous_size=128): + """Get information of tiles used for post processing. + + Args: + img_shape: input image shape + tile_shape: tile shape used for post processing + ambiguous_size: used to define area at tile boundaries + + """ + # * get normal tiling set + tile_grid_top_left, _ = _get_patch_top_left_info(img_shape, tile_shape, tile_shape) + tile_grid_bot_right = [] + for idx in list(range(tile_grid_top_left.shape[0])): + tile_tl = tile_grid_top_left[idx][:2] + tile_br = tile_tl + tile_shape + axis_sel = tile_br > img_shape + tile_br[axis_sel] = img_shape[axis_sel] + tile_grid_bot_right.append(tile_br) + tile_grid_bot_right = np.array(tile_grid_bot_right) + tile_grid = np.stack([tile_grid_top_left, tile_grid_bot_right], axis=1) + tile_grid_x = np.unique(tile_grid_top_left[:, 1]) + tile_grid_y = np.unique(tile_grid_top_left[:, 0]) + # * get tiling set to fix vertical and horizontal boundary between tiles + # for sanity, expand at boundary `ambiguous_size` to both side vertical and horizontal + stack_coord = lambda x: np.stack([x[0].flatten(), x[1].flatten()], axis=-1) + tile_boundary_x_top_left = np.meshgrid( + tile_grid_y, tile_grid_x[1:] - ambiguous_size + ) + tile_boundary_x_bot_right = np.meshgrid( + tile_grid_y + tile_shape[0], tile_grid_x[1:] + ambiguous_size + ) + tile_boundary_x_top_left = stack_coord(tile_boundary_x_top_left) + tile_boundary_x_bot_right = stack_coord(tile_boundary_x_bot_right) + tile_boundary_x = np.stack( + [tile_boundary_x_top_left, tile_boundary_x_bot_right], axis=1 + ) + # + tile_boundary_y_top_left = np.meshgrid( + tile_grid_y[1:] - ambiguous_size, tile_grid_x + ) + tile_boundary_y_bot_right = np.meshgrid( + tile_grid_y[1:] + ambiguous_size, tile_grid_x + tile_shape[1] + ) + tile_boundary_y_top_left = stack_coord(tile_boundary_y_top_left) + tile_boundary_y_bot_right = stack_coord(tile_boundary_y_bot_right) + tile_boundary_y = np.stack( + [tile_boundary_y_top_left, tile_boundary_y_bot_right], axis=1 + ) + tile_boundary = np.concatenate([tile_boundary_x, tile_boundary_y], axis=0) + # * get tiling set to fix the intersection of 4 tiles + tile_cross_top_left = np.meshgrid( + tile_grid_y[1:] - 2 * ambiguous_size, tile_grid_x[1:] - 2 * ambiguous_size + ) + tile_cross_bot_right = np.meshgrid( + tile_grid_y[1:] + 2 * ambiguous_size, tile_grid_x[1:] + 2 * ambiguous_size + ) + tile_cross_top_left = stack_coord(tile_cross_top_left) + tile_cross_bot_right = stack_coord(tile_cross_bot_right) + tile_cross = np.stack([tile_cross_top_left, tile_cross_bot_right], axis=1) + return tile_grid, tile_boundary, tile_cross + + +#### +def _get_chunk_patch_info( + img_shape, chunk_input_shape, patch_input_shape, patch_output_shape +): + """Get chunk patch info. Here, chunk refers to tiles used during inference. + + Args: + img_shape: input image shape + chunk_input_shape: shape of tiles used for post processing + patch_input_shape: input patch shape + patch_output_shape: output patch shape + + """ + round_to_multiple = lambda x, y: np.floor(x / y) * y + patch_diff_shape = patch_input_shape - patch_output_shape + + chunk_output_shape = chunk_input_shape - patch_diff_shape + chunk_output_shape = round_to_multiple( + chunk_output_shape, patch_output_shape + ).astype(np.int64) + chunk_input_shape = (chunk_output_shape + patch_diff_shape).astype(np.int64) + + patch_input_tl_list, _ = _get_patch_top_left_info( + img_shape, patch_input_shape, patch_output_shape + ) + patch_input_br_list = patch_input_tl_list + patch_input_shape + patch_output_tl_list = patch_input_tl_list + patch_diff_shape + patch_output_br_list = patch_output_tl_list + patch_output_shape + patch_info_list = np.stack( + [ + np.stack([patch_input_tl_list, patch_input_br_list], axis=1), + np.stack([patch_output_tl_list, patch_output_br_list], axis=1), + ], + axis=1, + ) + + chunk_input_tl_list, _ = _get_patch_top_left_info( + img_shape, chunk_input_shape, chunk_output_shape + ) + chunk_input_br_list = chunk_input_tl_list + chunk_input_shape + # * correct the coord so it stay within source image + y_sel = np.nonzero(chunk_input_br_list[:, 0] > img_shape[0])[0] + x_sel = np.nonzero(chunk_input_br_list[:, 1] > img_shape[1])[0] + chunk_input_br_list[y_sel, 0] = ( + img_shape[0] - patch_diff_shape[0] + ) - chunk_input_tl_list[y_sel, 0] + chunk_input_br_list[x_sel, 1] = ( + img_shape[1] - patch_diff_shape[1] + ) - chunk_input_tl_list[x_sel, 1] + chunk_input_br_list[y_sel, 0] = round_to_multiple( + chunk_input_br_list[y_sel, 0], patch_output_shape[0] + ) + chunk_input_br_list[x_sel, 1] = round_to_multiple( + chunk_input_br_list[x_sel, 1], patch_output_shape[1] + ) + chunk_input_br_list[y_sel, 0] += chunk_input_tl_list[y_sel, 0] + patch_diff_shape[0] + chunk_input_br_list[x_sel, 1] += chunk_input_tl_list[x_sel, 1] + patch_diff_shape[1] + chunk_output_tl_list = chunk_input_tl_list + patch_diff_shape // 2 + chunk_output_br_list = chunk_input_br_list - patch_diff_shape // 2 # may off pixels + chunk_info_list = np.stack( + [ + np.stack([chunk_input_tl_list, chunk_input_br_list], axis=1), + np.stack([chunk_output_tl_list, chunk_output_br_list], axis=1), + ], + axis=1, + ) + + return chunk_info_list, patch_info_list + + +#### +def _post_proc_para_wrapper(pred_map_mmap_path, tile_info, func, func_kwargs): + """Wrapper for parallel post processing.""" + idx, tile_tl, tile_br = tile_info + wsi_pred_map_ptr = np.load(pred_map_mmap_path, mmap_mode="r") + tile_pred_map = wsi_pred_map_ptr[tile_tl[0] : tile_br[0], tile_tl[1] : tile_br[1]] + tile_pred_map = np.array(tile_pred_map) # from mmap to ram + return func(tile_pred_map, **func_kwargs), tile_info + + +#### +def _assemble_and_flush(wsi_pred_map_mmap_path, chunk_info, patch_output_list): + """Assemble the results. Write to newly created holder for this wsi""" + wsi_pred_map_ptr = np.load(wsi_pred_map_mmap_path, mmap_mode="r+") + chunk_pred_map = wsi_pred_map_ptr[ + chunk_info[1][0][0] : chunk_info[1][1][0], + chunk_info[1][0][1] : chunk_info[1][1][1], + ] + if patch_output_list is None: + # chunk_pred_map[:] = 0 # zero flush when there is no-results + # print(chunk_info.flatten(), 'flush 0') + return + + for pinfo in patch_output_list: + pcoord, pdata = pinfo + pdata = np.squeeze(pdata) + pcoord = np.squeeze(pcoord)[:2] + chunk_pred_map[ + pcoord[0] : pcoord[0] + pdata.shape[0], + pcoord[1] : pcoord[1] + pdata.shape[1], + ] = pdata + # print(chunk_info.flatten(), 'pass') + return + + +#### +class InferManager(base.InferManager): + def __run_model(self, patch_top_left_list, pbar_desc): + # TODO: the cost of creating dataloader may not be cheap ? + dataset = SerializeArray( + "%s/cache_chunk.npy" % self.cache_path, + patch_top_left_list, + self.patch_input_shape, + ) + + dataloader = data.DataLoader( + dataset, + num_workers=self.nr_inference_workers, + batch_size=self.batch_size, + drop_last=False, + ) + + pbar = tqdm.tqdm( + desc=pbar_desc, + leave=True, + total=int(len(dataloader)), + ncols=80, + ascii=True, + position=0, + ) + + # run inference on input patches + accumulated_patch_output = [] + for batch_idx, batch_data in enumerate(dataloader): + sample_data_list, sample_info_list = batch_data + sample_output_list = self.run_step(sample_data_list) + sample_info_list = sample_info_list.numpy() + curr_batch_size = sample_output_list.shape[0] + sample_output_list = np.split(sample_output_list, curr_batch_size, axis=0) + sample_info_list = np.split(sample_info_list, curr_batch_size, axis=0) + sample_output_list = list(zip(sample_info_list, sample_output_list)) + accumulated_patch_output.extend(sample_output_list) + pbar.update() + pbar.close() + return accumulated_patch_output + + def __select_valid_patches(self, patch_info_list, has_output_info=True): + """Select valid patches from the list of input patch information. + + Args: + patch_info_list: patch input coordinate information + has_output_info: whether output information is given + + """ + down_sample_ratio = self.wsi_mask.shape[0] / self.wsi_proc_shape[0] + selected_indices = [] + for idx in range(patch_info_list.shape[0]): + patch_info = patch_info_list[idx] + patch_info = np.squeeze(patch_info) + # get the box at corresponding mag of the mask + if has_output_info: + output_bbox = patch_info[1] * down_sample_ratio + else: + output_bbox = patch_info * down_sample_ratio + output_bbox = np.rint(output_bbox).astype(np.int64) + # coord of the output of the patch (i.e center regions) + output_roi = self.wsi_mask[ + output_bbox[0][0] : output_bbox[1][0], + output_bbox[0][1] : output_bbox[1][1], + ] + if np.sum(output_roi) > 0: + selected_indices.append(idx) + sub_patch_info_list = patch_info_list[selected_indices] + return sub_patch_info_list + + def __get_raw_prediction(self, chunk_info_list, patch_info_list): + """Process input tiles (called chunks for inference) with HoVer-Net. + + Args: + chunk_info_list: list of inference tile coordinate information + patch_info_list: list of patch coordinate information + + """ + # 1 dedicated thread just to write results back to disk + proc_pool = Pool(processes=1) + wsi_pred_map_mmap_path = "%s/pred_map.npy" % self.cache_path + + masking = lambda x, a, b: (a <= x) & (x <= b) + for idx in range(0, chunk_info_list.shape[0]): + chunk_info = chunk_info_list[idx] + # select patch basing on top left coordinate of input + start_coord = chunk_info[0, 0] + end_coord = chunk_info[0, 1] - self.patch_input_shape + selection = masking( + patch_info_list[:, 0, 0, 0], start_coord[0], end_coord[0] + ) & masking(patch_info_list[:, 0, 0, 1], start_coord[1], end_coord[1]) + chunk_patch_info_list = np.array( + patch_info_list[selection] + ) # * do we need copy ? + + # further select only the patches within the provided mask + chunk_patch_info_list = self.__select_valid_patches(chunk_patch_info_list) + + # there no valid patches, so flush 0 and skip + if chunk_patch_info_list.shape[0] == 0: + proc_pool.apply_async( + _assemble_and_flush, args=(wsi_pred_map_mmap_path, chunk_info, None) + ) + continue + + # shift the coordinare from wrt slide to wrt chunk + chunk_patch_info_list -= chunk_info[:, 0] + chunk_data = self.wsi_handler.read_region( + chunk_info[0][0][::-1], (chunk_info[0][1] - chunk_info[0][0])[::-1] + ) + chunk_data = np.array(chunk_data)[..., :3] + np.save("%s/cache_chunk.npy" % self.cache_path, chunk_data) + + pbar_desc = "Process Chunk %d/%d" % (idx, chunk_info_list.shape[0]) + patch_output_list = self.__run_model( + chunk_patch_info_list[:, 0, 0], pbar_desc + ) + + proc_pool.apply_async( + _assemble_and_flush, + args=(wsi_pred_map_mmap_path, chunk_info, patch_output_list), + ) + proc_pool.close() + proc_pool.join() + return + + def __dispatch_post_processing(self, tile_info_list, callback): + """Post processing initialisation.""" + proc_pool = None + if self.nr_post_proc_workers > 0: + proc_pool = ProcessPoolExecutor(self.nr_post_proc_workers) + + future_list = [] + wsi_pred_map_mmap_path = "%s/pred_map.npy" % self.cache_path + for idx in list(range(tile_info_list.shape[0])): + tile_tl = tile_info_list[idx][0] + tile_br = tile_info_list[idx][1] + + tile_info = (idx, tile_tl, tile_br) + func_kwargs = { + "nr_types": self.method["model_args"]["nr_types"], + "return_centroids": True, + } + + # TODO: standarize protocol + if proc_pool is not None: + proc_future = proc_pool.submit( + _post_proc_para_wrapper, + wsi_pred_map_mmap_path, + tile_info, + self.post_proc_func, + func_kwargs, + ) + # ! manually poll future and call callback later as there is no guarantee + # ! that the callback is called from main thread + future_list.append(proc_future) + else: + results = _post_proc_para_wrapper( + wsi_pred_map_mmap_path, tile_info, self.post_proc_func, func_kwargs + ) + callback(results) + if proc_pool is not None: + silent_crash = False + # loop over all to check state a.k.a polling + for future in as_completed(future_list): + # ! silent crash, cancel all and raise error + if future.exception() is not None: + silent_crash = True + # ! cancel somehow leads to cascade error later + # ! so just poll it then crash once all future + # ! acquired for now + # for future in future_list: + # future.cancel() + # break + else: + callback(future.result()) + assert not silent_crash + return + + def _parse_args(self, run_args): + """Parse command line arguments and set as instance variables.""" + for variable, value in run_args.items(): + self.__setattr__(variable, value) + # to tuple + self.chunk_shape = [self.chunk_shape, self.chunk_shape] + self.tile_shape = [self.tile_shape, self.tile_shape] + self.patch_input_shape = [self.patch_input_shape, self.patch_input_shape] + self.patch_output_shape = [self.patch_output_shape, self.patch_output_shape] + return + + def process_single_file(self, wsi_path, msk_path, output_dir): + """Process a single whole-slide image and save the results. + + Args: + wsi_path: path to input whole-slide image + msk_path: path to input mask. If not supplied, mask will be automatically generated. + output_dir: path where output will be saved + + """ + # TODO: customize universal file handler to sync the protocol + ambiguous_size = self.ambiguous_size + tile_shape = (np.array(self.tile_shape)).astype(np.int64) + chunk_input_shape = np.array(self.chunk_shape) + patch_input_shape = np.array(self.patch_input_shape) + patch_output_shape = np.array(self.patch_output_shape) + + path_obj = pathlib.Path(wsi_path) + wsi_ext = path_obj.suffix + wsi_name = path_obj.stem + + start = time.perf_counter() + self.wsi_handler = get_file_handler(wsi_path, backend=wsi_ext) + self.wsi_proc_shape = self.wsi_handler.get_dimensions(self.proc_mag) + self.wsi_handler.prepare_reading( + read_mag=self.proc_mag, cache_path="%s/src_wsi.npy" % self.cache_path + ) + self.wsi_proc_shape = np.array(self.wsi_proc_shape[::-1]) # to Y, X + + if msk_path is not None and os.path.isfile(msk_path): + self.wsi_mask = cv2.imread(msk_path) + self.wsi_mask = cv2.cvtColor(self.wsi_mask, cv2.COLOR_BGR2GRAY) + self.wsi_mask[self.wsi_mask > 0] = 1 + else: + log_info( + "WARNING: No mask found, generating mask via thresholding at 1.25x!" + ) + + from skimage import morphology + + # simple method to extract tissue regions using intensity thresholding and morphological operations + def simple_get_mask(): + scaled_wsi_mag = 1.25 # ! hard coded + wsi_thumb_rgb = self.wsi_handler.get_full_img(read_mag=scaled_wsi_mag) + gray = cv2.cvtColor(wsi_thumb_rgb, cv2.COLOR_RGB2GRAY) + _, mask = cv2.threshold(gray, 0, 255, cv2.THRESH_OTSU) + mask = morphology.remove_small_objects( + mask == 0, min_size=16 * 16, connectivity=2 + ) + mask = morphology.remove_small_holes(mask, area_threshold=128 * 128) + mask = morphology.binary_dilation(mask, morphology.disk(16)) + return mask + + self.wsi_mask = np.array(simple_get_mask() > 0, dtype=np.uint8) + if np.sum(self.wsi_mask) == 0: + log_info("Skip due to empty mask!") + return + if self.save_mask: + cv2.imwrite("%s/mask/%s.png" % (output_dir, wsi_name), self.wsi_mask * 255) + if self.save_thumb: + wsi_thumb_rgb = self.wsi_handler.get_full_img(read_mag=1.25) + cv2.imwrite( + "%s/thumb/%s.png" % (output_dir, wsi_name), + cv2.cvtColor(wsi_thumb_rgb, cv2.COLOR_RGB2BGR), + ) + + # * declare holder for output + # create a memory-mapped .npy file with the predefined dimensions and dtype + # TODO: dynamicalize this, retrieve from model? + out_ch = 3 if self.method["model_args"]["nr_types"] is None else 4 + self.wsi_inst_info = {} + # TODO: option to use entire RAM if users have too much available, would be faster than mmap + self.wsi_inst_map = np.lib.format.open_memmap( + "%s/pred_inst.npy" % self.cache_path, + mode="w+", + shape=tuple(self.wsi_proc_shape), + dtype=np.int32, + ) + # self.wsi_inst_map[:] = 0 # flush fill + + # warning, the value within this is uninitialized + self.wsi_pred_map = np.lib.format.open_memmap( + "%s/pred_map.npy" % self.cache_path, + mode="w+", + shape=tuple(self.wsi_proc_shape) + (out_ch,), + dtype=np.float32, + ) + # ! for debug + # self.wsi_pred_map = np.load('%s/pred_map.npy' % self.cache_path, mmap_mode='r') + end = time.perf_counter() + log_info("Preparing Input Output Placement: {0}".format(end - start)) + + # * raw prediction + start = time.perf_counter() + chunk_info_list, patch_info_list = _get_chunk_patch_info( + self.wsi_proc_shape, + chunk_input_shape, + patch_input_shape, + patch_output_shape, + ) + + # get the raw prediction of HoVer-Net, given info of inference tiles and patches + self.__get_raw_prediction(chunk_info_list, patch_info_list) + end = time.perf_counter() + log_info("Inference Time: {0}".format(end - start)) + + # TODO: deal with error banding + ##### * post processing + ##### * done in 3 stages to ensure that nuclei at the boundaries are dealt with accordingly + start = time.perf_counter() + tile_coord_set = _get_tile_info(self.wsi_proc_shape, tile_shape, ambiguous_size) + # 3 sets of patches are extracted and are dealt with differently + # tile_grid_info: central region of post processing tiles + # tile_boundary_info: boundary region of post processing tiles + # tile_cross_info: region at corners of post processing tiles + tile_grid_info, tile_boundary_info, tile_cross_info = tile_coord_set + tile_grid_info = self.__select_valid_patches(tile_grid_info, False) + tile_boundary_info = self.__select_valid_patches(tile_boundary_info, False) + tile_cross_info = self.__select_valid_patches(tile_cross_info, False) + + ####################### * Callback can only receive 1 arg + def post_proc_normal_tile_callback(args): + results, pos_args = args + run_idx, tile_tl, tile_br = pos_args + pred_inst, inst_info_dict = results + + if len(inst_info_dict) == 0: + pbar.update() # external + return # when there is nothing to do + + top_left = pos_args[1][::-1] + + # ! WARNING: + # ! inst ID may not be contiguous, + # ! hence must use max as safeguard + + wsi_max_id = 0 + if len(self.wsi_inst_info) > 0: + wsi_max_id = max(self.wsi_inst_info.keys()) + for inst_id, inst_info in inst_info_dict.items(): + # now correct the coordinate wrt to wsi + inst_info["bbox"] += top_left + inst_info["contour"] += top_left + inst_info["centroid"] += top_left + self.wsi_inst_info[inst_id + wsi_max_id] = inst_info + pred_inst[pred_inst > 0] += wsi_max_id + self.wsi_inst_map[ + tile_tl[0] : tile_br[0], tile_tl[1] : tile_br[1] + ] = pred_inst + + pbar.update() # external + return + + ####################### * Callback can only receive 1 arg + def post_proc_fixing_tile_callback(args): + results, pos_args = args + run_idx, tile_tl, tile_br = pos_args + pred_inst, inst_info_dict = results + + if len(inst_info_dict) == 0: + pbar.update() # external + return # when there is nothing to do + + top_left = pos_args[1][::-1] + + # for fixing the boundary, keep all nuclei split at boundary (i.e within unambigous region) + # of the existing prediction map, and replace all nuclei within the region with newly predicted + + # ! WARNING: + # ! inst ID may not be contiguous, + # ! hence must use max as safeguard + + # ! must get before the removal happened + wsi_max_id = 0 + if len(self.wsi_inst_info) > 0: + wsi_max_id = max(self.wsi_inst_info.keys()) + + # * exclude ambiguous out from old prediction map + # check 1 pix of 4 edges to find nuclei split at boundary + roi_inst = self.wsi_inst_map[ + tile_tl[0] : tile_br[0], tile_tl[1] : tile_br[1] + ] + roi_inst = np.copy(roi_inst) + roi_edge = np.concatenate( + [roi_inst[[0, -1], :].flatten(), roi_inst[:, [0, -1]].flatten()] + ) + roi_boundary_inst_list = np.unique(roi_edge)[1:] # exclude background + roi_inner_inst_list = np.unique(roi_inst)[1:] + roi_inner_inst_list = np.setdiff1d( + roi_inner_inst_list, roi_boundary_inst_list, assume_unique=True + ) + roi_inst = _remove_inst(roi_inst, roi_inner_inst_list) + self.wsi_inst_map[ + tile_tl[0] : tile_br[0], tile_tl[1] : tile_br[1] + ] = roi_inst + for inst_id in roi_inner_inst_list: + self.wsi_inst_info.pop(inst_id, None) + + # * exclude unambiguous out from new prediction map + # check 1 pix of 4 edges to find nuclei split at boundary + roi_edge = pred_inst[roi_inst > 0] # remove all overlap + boundary_inst_list = np.unique(roi_edge) # no background to exclude + inner_inst_list = np.unique(pred_inst)[1:] + inner_inst_list = np.setdiff1d( + inner_inst_list, boundary_inst_list, assume_unique=True + ) + pred_inst = _remove_inst(pred_inst, boundary_inst_list) + + # * proceed to overwrite + for inst_id in inner_inst_list: + # ! happen because we alrd skip thoses with wrong + # ! contour (<3 points) within the postproc, so + # ! sanity gate here + if inst_id not in inst_info_dict: + log_info("Nuclei id=%d not in saved dict WRN1." % inst_id) + continue + inst_info = inst_info_dict[inst_id] + # now correct the coordinate wrt to wsi + inst_info["bbox"] += top_left + inst_info["contour"] += top_left + inst_info["centroid"] += top_left + self.wsi_inst_info[inst_id + wsi_max_id] = inst_info + pred_inst[pred_inst > 0] += wsi_max_id + pred_inst = roi_inst + pred_inst + self.wsi_inst_map[ + tile_tl[0] : tile_br[0], tile_tl[1] : tile_br[1] + ] = pred_inst + + pbar.update() # external + return + + ####################### + pbar_creator = lambda x, y: tqdm.tqdm( + desc=y, leave=True, total=int(len(x)), ncols=80, ascii=True, position=0 + ) + pbar = pbar_creator(tile_grid_info, "Post Proc Phase 1") + # * must be in sequential ordering + self.__dispatch_post_processing(tile_grid_info, post_proc_normal_tile_callback) + pbar.close() + + pbar = pbar_creator(tile_boundary_info, "Post Proc Phase 2") + self.__dispatch_post_processing( + tile_boundary_info, post_proc_fixing_tile_callback + ) + pbar.close() + + pbar = pbar_creator(tile_cross_info, "Post Proc Phase 3") + self.__dispatch_post_processing(tile_cross_info, post_proc_fixing_tile_callback) + pbar.close() + + end = time.perf_counter() + log_info("Total Post Proc Time: {0}".format(end - start)) + + # ! cant possibly save the inst map at high res, too large + start = time.perf_counter() + if self.save_mask or self.save_thumb: + json_path = "%s/json/%s.json" % (output_dir, wsi_name) + else: + json_path = "%s/%s.json" % (output_dir, wsi_name) + self.__save_json(json_path, self.wsi_inst_info, mag=self.proc_mag) + end = time.perf_counter() + log_info("Save Time: {0}".format(end - start)) + + def process_wsi_list(self, run_args): + """Process a list of whole-slide images. + + Args: + run_args: arguments as defined in run_infer.py + + """ + self._parse_args(run_args) + + if not os.path.exists(self.cache_path): + rm_n_mkdir(self.cache_path) + + if not os.path.exists(self.output_dir + "/json/"): + rm_n_mkdir(self.output_dir + "/json/") + if self.save_thumb: + if not os.path.exists(self.output_dir + "/thumb/"): + rm_n_mkdir(self.output_dir + "/thumb/") + if self.save_mask: + if not os.path.exists(self.output_dir + "/mask/"): + rm_n_mkdir(self.output_dir + "/mask/") + + wsi_path_list = glob.glob(self.input_dir + "/*") + wsi_path_list.sort() # ensure ordering + for wsi_path in wsi_path_list[:]: + wsi_base_name = pathlib.Path(wsi_path).stem + msk_path = "%s/%s.png" % (self.input_mask_dir, wsi_base_name) + if self.save_thumb or self.save_mask: + output_file = "%s/json/%s.json" % (self.output_dir, wsi_base_name) + else: + output_file = "%s/%s.json" % (self.output_dir, wsi_base_name) + if os.path.exists(output_file): + log_info("Skip: %s" % wsi_base_name) + continue + try: + log_info("Process: %s" % wsi_base_name) + self.process_single_file(wsi_path, msk_path, self.output_dir) + log_info("Finish") + except: + logging.exception("Crash") + rm_n_mkdir(self.cache_path) # clean up all cache + return diff --git a/metrics/README.md b/metrics/README.md new file mode 100644 index 0000000..ac4de15 --- /dev/null +++ b/metrics/README.md @@ -0,0 +1,55 @@ + +# Statistical Measurements for Instance Segmentation and Classification + +## Description + +In this directory, the script `stats_utils.py` contains the statistical measurements code for instance segmentation. In order of appearance, the available measurements are AJI+, AJI, DICE2, Panoptic Quality (PQ), DICE which can be access through following functions: + +`get_fast_aji()`: aji ported from the matlab code but is optimised for speed **[1]**.
+`get_fast_aji_plus()`: extension of aggregated jaccard index that doesn't suffer from over-penalisation.
+`get_dice_1()` and `get_dice_2()`: standard dice and ensemble dice (DICE2) **[2]** measures respectively.
+`get_fast_dice_2()`: ensemble dice optimised for speed.
+`get_fast_panoptic_quality()`: panoptic quality as used in **[3]**. + +## Sample + +

+ Metric +

+ +Given the predictions as above, basic difference between AJI, AJI+ and Panoptic Quality is summarized +in the following table. + +| | DICE2 | AJI | AJI+ | PQ | +| ------------- |:------:|:------:|:------:|:------:| +| Prediction A | 0.6477 | 0.4790 | 0.6375 | 0.6803 | +| Prediction B | 0.9007 | 0.6414 | 0.6414 | 0.6863 | + +## Processing + +### Instance Segmentation + +To get the instance segmentation measurements, run:
+`python compute_stats.py --mode=instance --pred_dir='pred_dir' --true_dir='true_dir'` + +Toggle `print_img_stats` to determine whether to show the stats for each image. + +### Classification + +To get the classification measurements, run:
+`python compute_stats.py --mode=type --pred_dir='pred_dir' --true_dir='true_dir'` + +The above calculates the classification metrics, as discussed in the evaluation metrics section of our paper. + + +## References +**[1]** Kumar, Neeraj, Ruchika Verma, Sanuj Sharma, Surabhi Bhargava, Abhishek Vahadane, and Amit Sethi. "A dataset and a technique for generalized nuclear segmentation for computational pathology." IEEE transactions on medical imaging 36, no. 7 (2017): 1550-1560.
+**[2]** Vu, Quoc Dang, Simon Graham, Minh Nguyen Nhat To, Muhammad Shaban, Talha Qaiser, Navid Alemi Koohbanani, Syed Ali Khurram et al. "Methods for Segmentation and Classification of Digital Microscopy Tissue Images." arXiv preprint arXiv:1810.13230 (2018).
+**[3]** Kirillov, Alexander, Kaiming He, Ross Girshick, Carsten Rother, and Piotr Dollár. "Panoptic Segmentation." arXiv preprint arXiv:1801.00868 (2018). + + + + + + + diff --git a/metrics/__init__.py b/metrics/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/metrics/stats_utils.py b/metrics/stats_utils.py new file mode 100644 index 0000000..d6fd0e0 --- /dev/null +++ b/metrics/stats_utils.py @@ -0,0 +1,429 @@ +import warnings + +import cv2 +import matplotlib.pyplot as plt +import numpy as np +import scipy +from scipy.optimize import linear_sum_assignment + + +# --------------------------Optimised for Speed +def get_fast_aji(true, pred): + """AJI version distributed by MoNuSeg, has no permutation problem but suffered from + over-penalisation similar to DICE2. + + Fast computation requires instance IDs are in contiguous orderding i.e [1, 2, 3, 4] + not [2, 3, 6, 10]. Please call `remap_label` before hand and `by_size` flag has no + effect on the result. + + """ + true = np.copy(true) # ? do we need this + pred = np.copy(pred) + true_id_list = list(np.unique(true)) + pred_id_list = list(np.unique(pred)) + + true_masks = [ + None, + ] + for t in true_id_list[1:]: + t_mask = np.array(true == t, np.uint8) + true_masks.append(t_mask) + + pred_masks = [ + None, + ] + for p in pred_id_list[1:]: + p_mask = np.array(pred == p, np.uint8) + pred_masks.append(p_mask) + + # prefill with value + pairwise_inter = np.zeros( + [len(true_id_list) - 1, len(pred_id_list) - 1], dtype=np.float64 + ) + pairwise_union = np.zeros( + [len(true_id_list) - 1, len(pred_id_list) - 1], dtype=np.float64 + ) + + # caching pairwise + for true_id in true_id_list[1:]: # 0-th is background + t_mask = true_masks[true_id] + pred_true_overlap = pred[t_mask > 0] + pred_true_overlap_id = np.unique(pred_true_overlap) + pred_true_overlap_id = list(pred_true_overlap_id) + for pred_id in pred_true_overlap_id: + if pred_id == 0: # ignore + continue # overlaping background + p_mask = pred_masks[pred_id] + total = (t_mask + p_mask).sum() + inter = (t_mask * p_mask).sum() + pairwise_inter[true_id - 1, pred_id - 1] = inter + pairwise_union[true_id - 1, pred_id - 1] = total - inter + + pairwise_iou = pairwise_inter / (pairwise_union + 1.0e-6) + # pair of pred that give highest iou for each true, dont care + # about reusing pred instance multiple times + paired_pred = np.argmax(pairwise_iou, axis=1) + pairwise_iou = np.max(pairwise_iou, axis=1) + # exlude those dont have intersection + paired_true = np.nonzero(pairwise_iou > 0.0)[0] + paired_pred = paired_pred[paired_true] + # print(paired_true.shape, paired_pred.shape) + overall_inter = (pairwise_inter[paired_true, paired_pred]).sum() + overall_union = (pairwise_union[paired_true, paired_pred]).sum() + + paired_true = list(paired_true + 1) # index to instance ID + paired_pred = list(paired_pred + 1) + # add all unpaired GT and Prediction into the union + unpaired_true = np.array( + [idx for idx in true_id_list[1:] if idx not in paired_true] + ) + unpaired_pred = np.array( + [idx for idx in pred_id_list[1:] if idx not in paired_pred] + ) + for true_id in unpaired_true: + overall_union += true_masks[true_id].sum() + for pred_id in unpaired_pred: + overall_union += pred_masks[pred_id].sum() + + aji_score = overall_inter / overall_union + return aji_score + + +##### +def get_fast_aji_plus(true, pred): + """AJI+, an AJI version with maximal unique pairing to obtain overall intersecion. + Every prediction instance is paired with at most 1 GT instance (1 to 1) mapping, unlike AJI + where a prediction instance can be paired against many GT instances (1 to many). + Remaining unpaired GT and Prediction instances will be added to the overall union. + The 1 to 1 mapping prevents AJI's over-penalisation from happening. + + Fast computation requires instance IDs are in contiguous orderding i.e [1, 2, 3, 4] + not [2, 3, 6, 10]. Please call `remap_label` before hand and `by_size` flag has no + effect on the result. + + """ + true = np.copy(true) # ? do we need this + pred = np.copy(pred) + true_id_list = list(np.unique(true)) + pred_id_list = list(np.unique(pred)) + + true_masks = [ + None, + ] + for t in true_id_list[1:]: + t_mask = np.array(true == t, np.uint8) + true_masks.append(t_mask) + + pred_masks = [ + None, + ] + for p in pred_id_list[1:]: + p_mask = np.array(pred == p, np.uint8) + pred_masks.append(p_mask) + + # prefill with value + pairwise_inter = np.zeros( + [len(true_id_list) - 1, len(pred_id_list) - 1], dtype=np.float64 + ) + pairwise_union = np.zeros( + [len(true_id_list) - 1, len(pred_id_list) - 1], dtype=np.float64 + ) + + # caching pairwise + for true_id in true_id_list[1:]: # 0-th is background + t_mask = true_masks[true_id] + pred_true_overlap = pred[t_mask > 0] + pred_true_overlap_id = np.unique(pred_true_overlap) + pred_true_overlap_id = list(pred_true_overlap_id) + for pred_id in pred_true_overlap_id: + if pred_id == 0: # ignore + continue # overlaping background + p_mask = pred_masks[pred_id] + total = (t_mask + p_mask).sum() + inter = (t_mask * p_mask).sum() + pairwise_inter[true_id - 1, pred_id - 1] = inter + pairwise_union[true_id - 1, pred_id - 1] = total - inter + # + pairwise_iou = pairwise_inter / (pairwise_union + 1.0e-6) + #### Munkres pairing to find maximal unique pairing + paired_true, paired_pred = linear_sum_assignment(-pairwise_iou) + ### extract the paired cost and remove invalid pair + paired_iou = pairwise_iou[paired_true, paired_pred] + # now select all those paired with iou != 0.0 i.e have intersection + paired_true = paired_true[paired_iou > 0.0] + paired_pred = paired_pred[paired_iou > 0.0] + paired_inter = pairwise_inter[paired_true, paired_pred] + paired_union = pairwise_union[paired_true, paired_pred] + paired_true = list(paired_true + 1) # index to instance ID + paired_pred = list(paired_pred + 1) + overall_inter = paired_inter.sum() + overall_union = paired_union.sum() + # add all unpaired GT and Prediction into the union + unpaired_true = np.array( + [idx for idx in true_id_list[1:] if idx not in paired_true] + ) + unpaired_pred = np.array( + [idx for idx in pred_id_list[1:] if idx not in paired_pred] + ) + for true_id in unpaired_true: + overall_union += true_masks[true_id].sum() + for pred_id in unpaired_pred: + overall_union += pred_masks[pred_id].sum() + # + aji_score = overall_inter / overall_union + return aji_score + + +##### +def get_fast_pq(true, pred, match_iou=0.5): + """`match_iou` is the IoU threshold level to determine the pairing between + GT instances `p` and prediction instances `g`. `p` and `g` is a pair + if IoU > `match_iou`. However, pair of `p` and `g` must be unique + (1 prediction instance to 1 GT instance mapping). + + If `match_iou` < 0.5, Munkres assignment (solving minimum weight matching + in bipartite graphs) is caculated to find the maximal amount of unique pairing. + + If `match_iou` >= 0.5, all IoU(p,g) > 0.5 pairing is proven to be unique and + the number of pairs is also maximal. + + Fast computation requires instance IDs are in contiguous orderding + i.e [1, 2, 3, 4] not [2, 3, 6, 10]. Please call `remap_label` beforehand + and `by_size` flag has no effect on the result. + + Returns: + [dq, sq, pq]: measurement statistic + + [paired_true, paired_pred, unpaired_true, unpaired_pred]: + pairing information to perform measurement + + """ + assert match_iou >= 0.0, "Cant' be negative" + + true = np.copy(true) + pred = np.copy(pred) + true_id_list = list(np.unique(true)) + pred_id_list = list(np.unique(pred)) + + true_masks = [ + None, + ] + for t in true_id_list[1:]: + t_mask = np.array(true == t, np.uint8) + true_masks.append(t_mask) + + pred_masks = [ + None, + ] + for p in pred_id_list[1:]: + p_mask = np.array(pred == p, np.uint8) + pred_masks.append(p_mask) + + # prefill with value + pairwise_iou = np.zeros( + [len(true_id_list) - 1, len(pred_id_list) - 1], dtype=np.float64 + ) + + # caching pairwise iou + for true_id in true_id_list[1:]: # 0-th is background + t_mask = true_masks[true_id] + pred_true_overlap = pred[t_mask > 0] + pred_true_overlap_id = np.unique(pred_true_overlap) + pred_true_overlap_id = list(pred_true_overlap_id) + for pred_id in pred_true_overlap_id: + if pred_id == 0: # ignore + continue # overlaping background + p_mask = pred_masks[pred_id] + total = (t_mask + p_mask).sum() + inter = (t_mask * p_mask).sum() + iou = inter / (total - inter) + pairwise_iou[true_id - 1, pred_id - 1] = iou + # + if match_iou >= 0.5: + paired_iou = pairwise_iou[pairwise_iou > match_iou] + pairwise_iou[pairwise_iou <= match_iou] = 0.0 + paired_true, paired_pred = np.nonzero(pairwise_iou) + paired_iou = pairwise_iou[paired_true, paired_pred] + paired_true += 1 # index is instance id - 1 + paired_pred += 1 # hence return back to original + else: # * Exhaustive maximal unique pairing + #### Munkres pairing with scipy library + # the algorithm return (row indices, matched column indices) + # if there is multiple same cost in a row, index of first occurence + # is return, thus the unique pairing is ensure + # inverse pair to get high IoU as minimum + paired_true, paired_pred = linear_sum_assignment(-pairwise_iou) + ### extract the paired cost and remove invalid pair + paired_iou = pairwise_iou[paired_true, paired_pred] + + # now select those above threshold level + # paired with iou = 0.0 i.e no intersection => FP or FN + paired_true = list(paired_true[paired_iou > match_iou] + 1) + paired_pred = list(paired_pred[paired_iou > match_iou] + 1) + paired_iou = paired_iou[paired_iou > match_iou] + + # get the actual FP and FN + unpaired_true = [idx for idx in true_id_list[1:] if idx not in paired_true] + unpaired_pred = [idx for idx in pred_id_list[1:] if idx not in paired_pred] + # print(paired_iou.shape, paired_true.shape, len(unpaired_true), len(unpaired_pred)) + + # + tp = len(paired_true) + fp = len(unpaired_pred) + fn = len(unpaired_true) + # get the F1-score i.e DQ + dq = tp / (tp + 0.5 * fp + 0.5 * fn) + # get the SQ, no paired has 0 iou so not impact + sq = paired_iou.sum() / (tp + 1.0e-6) + + return [dq, sq, dq * sq], [paired_true, paired_pred, unpaired_true, unpaired_pred] + + +##### +def get_fast_dice_2(true, pred): + """Ensemble dice.""" + true = np.copy(true) + pred = np.copy(pred) + true_id = list(np.unique(true)) + pred_id = list(np.unique(pred)) + + overall_total = 0 + overall_inter = 0 + + true_masks = [np.zeros(true.shape)] + for t in true_id[1:]: + t_mask = np.array(true == t, np.uint8) + true_masks.append(t_mask) + + pred_masks = [np.zeros(true.shape)] + for p in pred_id[1:]: + p_mask = np.array(pred == p, np.uint8) + pred_masks.append(p_mask) + + for true_idx in range(1, len(true_id)): + t_mask = true_masks[true_idx] + pred_true_overlap = pred[t_mask > 0] + pred_true_overlap_id = np.unique(pred_true_overlap) + pred_true_overlap_id = list(pred_true_overlap_id) + try: # blinly remove background + pred_true_overlap_id.remove(0) + except ValueError: + pass # just mean no background + for pred_idx in pred_true_overlap_id: + p_mask = pred_masks[pred_idx] + total = (t_mask + p_mask).sum() + inter = (t_mask * p_mask).sum() + overall_total += total + overall_inter += inter + + return 2 * overall_inter / overall_total + + +#####--------------------------As pseudocode +def get_dice_1(true, pred): + """Traditional dice.""" + # cast to binary 1st + true = np.copy(true) + pred = np.copy(pred) + true[true > 0] = 1 + pred[pred > 0] = 1 + inter = true * pred + denom = true + pred + return 2.0 * np.sum(inter) / np.sum(denom) + + +#### +def get_dice_2(true, pred): + """Ensemble Dice as used in Computational Precision Medicine Challenge.""" + true = np.copy(true) + pred = np.copy(pred) + true_id = list(np.unique(true)) + pred_id = list(np.unique(pred)) + # remove background aka id 0 + true_id.remove(0) + pred_id.remove(0) + + total_markup = 0 + total_intersect = 0 + for t in true_id: + t_mask = np.array(true == t, np.uint8) + for p in pred_id: + p_mask = np.array(pred == p, np.uint8) + intersect = p_mask * t_mask + if intersect.sum() > 0: + total_intersect += intersect.sum() + total_markup += t_mask.sum() + p_mask.sum() + return 2 * total_intersect / total_markup + + +##### +def remap_label(pred, by_size=False): + """Rename all instance id so that the id is contiguous i.e [0, 1, 2, 3] + not [0, 2, 4, 6]. The ordering of instances (which one comes first) + is preserved unless by_size=True, then the instances will be reordered + so that bigger nucler has smaller ID. + + Args: + pred : the 2d array contain instances where each instances is marked + by non-zero integer + by_size : renaming with larger nuclei has smaller id (on-top) + + """ + pred_id = list(np.unique(pred)) + pred_id.remove(0) + if len(pred_id) == 0: + return pred # no label + if by_size: + pred_size = [] + for inst_id in pred_id: + size = (pred == inst_id).sum() + pred_size.append(size) + # sort the id by size in descending order + pair_list = zip(pred_id, pred_size) + pair_list = sorted(pair_list, key=lambda x: x[1], reverse=True) + pred_id, pred_size = zip(*pair_list) + + new_pred = np.zeros(pred.shape, np.int32) + for idx, inst_id in enumerate(pred_id): + new_pred[pred == inst_id] = idx + 1 + return new_pred + + +##### +def pair_coordinates(setA, setB, radius): + """Use the Munkres or Kuhn-Munkres algorithm to find the most optimal + unique pairing (largest possible match) when pairing points in set B + against points in set A, using distance as cost function. + + Args: + setA, setB: np.array (float32) of size Nx2 contains the of XY coordinate + of N different points + radius: valid area around a point in setA to consider + a given coordinate in setB a candidate for match + Return: + pairing: pairing is an array of indices + where point at index pairing[0] in set A paired with point + in set B at index pairing[1] + unparedA, unpairedB: remaining poitn in set A and set B unpaired + + """ + # * Euclidean distance as the cost matrix + pair_distance = scipy.spatial.distance.cdist(setA, setB, metric='euclidean') + + # * Munkres pairing with scipy library + # the algorithm return (row indices, matched column indices) + # if there is multiple same cost in a row, index of first occurence + # is return, thus the unique pairing is ensured + indicesA, paired_indicesB = linear_sum_assignment(pair_distance) + + # extract the paired cost and remove instances + # outside of designated radius + pair_cost = pair_distance[indicesA, paired_indicesB] + + pairedA = indicesA[pair_cost <= radius] + pairedB = paired_indicesB[pair_cost <= radius] + + pairing = np.concatenate([pairedA[:,None], pairedB[:,None]], axis=-1) + unpairedA = np.delete(np.arange(setA.shape[0]), pairedA) + unpairedB = np.delete(np.arange(setB.shape[0]), pairedB) + return pairing, unpairedA, unpairedB diff --git a/misc/__init__.py b/misc/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/misc/patch_extractor.py b/misc/patch_extractor.py new file mode 100644 index 0000000..f6aa015 --- /dev/null +++ b/misc/patch_extractor.py @@ -0,0 +1,155 @@ +import math +import time + +import cv2 +import matplotlib.pyplot as plt +import numpy as np + +from .utils import cropping_center + + +##### +class PatchExtractor(object): + """Extractor to generate patches with or without padding. + Turn on debug mode to see how it is done. + + Args: + x : input image, should be of shape HWC + win_size : a tuple of (h, w) + step_size : a tuple of (h, w) + debug : flag to see how it is done + Return: + a list of sub patches, each patch has dtype same as x + + Examples: + >>> xtractor = PatchExtractor((450, 450), (120, 120)) + >>> img = np.full([1200, 1200, 3], 255, np.uint8) + >>> patches = xtractor.extract(img, 'mirror') + + """ + + def __init__(self, win_size, step_size, debug=False): + + self.patch_type = "mirror" + self.win_size = win_size + self.step_size = step_size + self.debug = debug + self.counter = 0 + + def __get_patch(self, x, ptx): + pty = (ptx[0] + self.win_size[0], ptx[1] + self.win_size[1]) + win = x[ptx[0] : pty[0], ptx[1] : pty[1]] + assert ( + win.shape[0] == self.win_size[0] and win.shape[1] == self.win_size[1] + ), "[BUG] Incorrect Patch Size {0}".format(win.shape) + if self.debug: + if self.patch_type == "mirror": + cen = cropping_center(win, self.step_size) + cen = cen[..., self.counter % 3] + cen.fill(150) + cv2.rectangle(x, ptx, pty, (255, 0, 0), 2) + plt.imshow(x) + plt.show(block=False) + plt.pause(1) + plt.close() + self.counter += 1 + return win + + def __extract_valid(self, x): + """Extracted patches without padding, only work in case win_size > step_size. + + Note: to deal with the remaining portions which are at the boundary a.k.a + those which do not fit when slide left->right, top->bottom), we flip + the sliding direction then extract 1 patch starting from right / bottom edge. + There will be 1 additional patch extracted at the bottom-right corner. + + Args: + x : input image, should be of shape HWC + win_size : a tuple of (h, w) + step_size : a tuple of (h, w) + Return: + a list of sub patches, each patch is same dtype as x + + """ + im_h = x.shape[0] + im_w = x.shape[1] + + def extract_infos(length, win_size, step_size): + flag = (length - win_size) % step_size != 0 + last_step = math.floor((length - win_size) / step_size) + last_step = (last_step + 1) * step_size + return flag, last_step + + h_flag, h_last = extract_infos(im_h, self.win_size[0], self.step_size[0]) + w_flag, w_last = extract_infos(im_w, self.win_size[1], self.step_size[1]) + + sub_patches = [] + #### Deal with valid block + for row in range(0, h_last, self.step_size[0]): + for col in range(0, w_last, self.step_size[1]): + win = self.__get_patch(x, (row, col)) + sub_patches.append(win) + #### Deal with edge case + if h_flag: + row = im_h - self.win_size[0] + for col in range(0, w_last, self.step_size[1]): + win = self.__get_patch(x, (row, col)) + sub_patches.append(win) + if w_flag: + col = im_w - self.win_size[1] + for row in range(0, h_last, self.step_size[0]): + win = self.__get_patch(x, (row, col)) + sub_patches.append(win) + if h_flag and w_flag: + ptx = (im_h - self.win_size[0], im_w - self.win_size[1]) + win = self.__get_patch(x, ptx) + sub_patches.append(win) + return sub_patches + + def __extract_mirror(self, x): + """Extracted patches with mirror padding the boundary such that the + central region of each patch is always within the orginal (non-padded) + image while all patches' central region cover the whole orginal image. + + Args: + x : input image, should be of shape HWC + win_size : a tuple of (h, w) + step_size : a tuple of (h, w) + Return: + a list of sub patches, each patch is same dtype as x + + """ + diff_h = self.win_size[0] - self.step_size[0] + padt = diff_h // 2 + padb = diff_h - padt + + diff_w = self.win_size[1] - self.step_size[1] + padl = diff_w // 2 + padr = diff_w - padl + + pad_type = "constant" if self.debug else "reflect" + x = np.lib.pad(x, ((padt, padb), (padl, padr), (0, 0)), pad_type) + sub_patches = self.__extract_valid(x) + return sub_patches + + def extract(self, x, patch_type): + patch_type = patch_type.lower() + self.patch_type = patch_type + if patch_type == "valid": + return self.__extract_valid(x) + elif patch_type == "mirror": + return self.__extract_mirror(x) + else: + assert False, "Unknown Patch Type [%s]" % patch_type + return + + +# ---------------------------------------------------------------------------- + +if __name__ == "__main__": + # toy example for debug + # 355x355, 480x480 + xtractor = PatchExtractor((450, 450), (120, 120), debug=True) + a = np.full([1200, 1200, 3], 255, np.uint8) + xtractor.extract(a, "mirror") + xtractor.extract(a, "valid") diff --git a/misc/utils.py b/misc/utils.py new file mode 100644 index 0000000..2b59dc1 --- /dev/null +++ b/misc/utils.py @@ -0,0 +1,182 @@ +import glob +import inspect +import logging +import os +import shutil + +import cv2 +import numpy as np +from scipy import ndimage + + +#### +def normalize(mask, dtype=np.uint8): + return (255 * mask / np.amax(mask)).astype(dtype) + + +#### +def get_bounding_box(img): + """Get bounding box coordinate information.""" + rows = np.any(img, axis=1) + cols = np.any(img, axis=0) + rmin, rmax = np.where(rows)[0][[0, -1]] + cmin, cmax = np.where(cols)[0][[0, -1]] + # due to python indexing, need to add 1 to max + # else accessing will be 1px in the box, not out + rmax += 1 + cmax += 1 + return [rmin, rmax, cmin, cmax] + + +#### +def cropping_center(x, crop_shape, batch=False): + """Crop an input image at the centre. + + Args: + x: input array + crop_shape: dimensions of cropped array + + Returns: + x: cropped array + + """ + orig_shape = x.shape + if not batch: + h0 = int((orig_shape[0] - crop_shape[0]) * 0.5) + w0 = int((orig_shape[1] - crop_shape[1]) * 0.5) + x = x[h0 : h0 + crop_shape[0], w0 : w0 + crop_shape[1]] + else: + h0 = int((orig_shape[1] - crop_shape[0]) * 0.5) + w0 = int((orig_shape[2] - crop_shape[1]) * 0.5) + x = x[:, h0 : h0 + crop_shape[0], w0 : w0 + crop_shape[1]] + return x + + +#### +def rm_n_mkdir(dir_path): + """Remove and make directory.""" + if os.path.isdir(dir_path): + shutil.rmtree(dir_path) + os.makedirs(dir_path) + + +#### +def mkdir(dir_path): + """Make directory.""" + if not os.path.isdir(dir_path): + os.makedirs(dir_path) + + +#### +def get_inst_centroid(inst_map): + """Get instance centroids given an input instance map. + + Args: + inst_map: input instance map + + Returns: + array of centroids + + """ + inst_centroid_list = [] + inst_id_list = list(np.unique(inst_map)) + for inst_id in inst_id_list[1:]: # avoid 0 i.e background + mask = np.array(inst_map == inst_id, np.uint8) + inst_moment = cv2.moments(mask) + inst_centroid = [ + (inst_moment["m10"] / inst_moment["m00"]), + (inst_moment["m01"] / inst_moment["m00"]), + ] + inst_centroid_list.append(inst_centroid) + return np.array(inst_centroid_list) + + +#### +def center_pad_to_shape(img, size, cval=255): + """Pad input image.""" + # rounding down, add 1 + pad_h = size[0] - img.shape[0] + pad_w = size[1] - img.shape[1] + pad_h = (pad_h // 2, pad_h - pad_h // 2) + pad_w = (pad_w // 2, pad_w - pad_w // 2) + if len(img.shape) == 2: + pad_shape = (pad_h, pad_w) + else: + pad_shape = (pad_h, pad_w, (0, 0)) + img = np.pad(img, pad_shape, "constant", constant_values=cval) + return img + + +#### +def color_deconvolution(rgb, stain_mat): + """Apply colour deconvolution.""" + log255 = np.log(255) # to base 10, not base e + rgb_float = rgb.astype(np.float64) + log_rgb = -((255.0 * np.log((rgb_float + 1) / 255.0)) / log255) + output = np.exp(-(log_rgb @ stain_mat - 255.0) * log255 / 255.0) + output[output > 255] = 255 + output = np.floor(output + 0.5).astype("uint8") + return output + + +#### +def log_debug(msg): + frame, filename, line_number, function_name, lines, index = inspect.getouterframes( + inspect.currentframe() + )[1] + line = lines[0] + indentation_level = line.find(line.lstrip()) + logging.debug("{i} {m}".format(i="." * indentation_level, m=msg)) + + +#### +def log_info(msg): + frame, filename, line_number, function_name, lines, index = inspect.getouterframes( + inspect.currentframe() + )[1] + line = lines[0] + indentation_level = line.find(line.lstrip()) + logging.info("{i} {m}".format(i="." * indentation_level, m=msg)) + + +def remove_small_objects(pred, min_size=64, connectivity=1): + """Remove connected components smaller than the specified size. + + This function is taken from skimage.morphology.remove_small_objects, but the warning + is removed when a single label is provided. + + Args: + pred: input labelled array + min_size: minimum size of instance in output array + connectivity: The connectivity defining the neighborhood of a pixel. + + Returns: + out: output array with instances removed under min_size + + """ + out = pred + + if min_size == 0: # shortcut for efficiency + return out + + if out.dtype == bool: + selem = ndimage.generate_binary_structure(pred.ndim, connectivity) + ccs = np.zeros_like(pred, dtype=np.int32) + ndimage.label(pred, selem, output=ccs) + else: + ccs = out + + try: + component_sizes = np.bincount(ccs.ravel()) + except ValueError: + raise ValueError( + "Negative value labels are not supported. Try " + "relabeling the input with `scipy.ndimage.label` or " + "`skimage.morphology.label`." + ) + + too_small = component_sizes < min_size + too_small_mask = too_small[ccs] + out[too_small_mask] = 0 + + return out diff --git a/misc/viz_utils.py b/misc/viz_utils.py new file mode 100644 index 0000000..798f798 --- /dev/null +++ b/misc/viz_utils.py @@ -0,0 +1,173 @@ +import cv2 +import math +import random +import colorsys +import numpy as np +import itertools +import matplotlib.pyplot as plt +from matplotlib import cm + +from .utils import get_bounding_box + +#### +def colorize(ch, vmin, vmax): + """Will clamp value value outside the provided range to vmax and vmin.""" + cmap = plt.get_cmap("jet") + ch = np.squeeze(ch.astype("float32")) + vmin = vmin if vmin is not None else ch.min() + vmax = vmax if vmax is not None else ch.max() + ch[ch > vmax] = vmax # clamp value + ch[ch < vmin] = vmin + ch = (ch - vmin) / (vmax - vmin + 1.0e-16) + # take RGB from RGBA heat map + ch_cmap = (cmap(ch)[..., :3] * 255).astype("uint8") + return ch_cmap + + +#### +def random_colors(N, bright=True): + """Generate random colors. + + To get visually distinct colors, generate them in HSV space then + convert to RGB. + """ + brightness = 1.0 if bright else 0.7 + hsv = [(i / N, 1, brightness) for i in range(N)] + colors = list(map(lambda c: colorsys.hsv_to_rgb(*c), hsv)) + random.shuffle(colors) + return colors + + +#### +def visualize_instances_map( + input_image, inst_map, type_map=None, type_colour=None, line_thickness=2 +): + """Overlays segmentation results on image as contours. + + Args: + input_image: input image + inst_map: instance mask with unique value for every object + type_map: type mask with unique value for every class + type_colour: a dict of {type : colour} , `type` is from 0-N + and `colour` is a tuple of (R, G, B) + line_thickness: line thickness of contours + + Returns: + overlay: output image with segmentation overlay as contours + """ + overlay = np.copy((input_image).astype(np.uint8)) + + inst_list = list(np.unique(inst_map)) # get list of instances + inst_list.remove(0) # remove background + + inst_rng_colors = random_colors(len(inst_list)) + inst_rng_colors = np.array(inst_rng_colors) * 255 + inst_rng_colors = inst_rng_colors.astype(np.uint8) + + for inst_idx, inst_id in enumerate(inst_list): + inst_map_mask = np.array(inst_map == inst_id, np.uint8) # get single object + y1, y2, x1, x2 = get_bounding_box(inst_map_mask) + y1 = y1 - 2 if y1 - 2 >= 0 else y1 + x1 = x1 - 2 if x1 - 2 >= 0 else x1 + x2 = x2 + 2 if x2 + 2 <= inst_map.shape[1] - 1 else x2 + y2 = y2 + 2 if y2 + 2 <= inst_map.shape[0] - 1 else y2 + inst_map_crop = inst_map_mask[y1:y2, x1:x2] + contours_crop = cv2.findContours( + inst_map_crop, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE + ) + # only has 1 instance per map, no need to check #contour detected by opencv + contours_crop = np.squeeze( + contours_crop[0][0].astype("int32") + ) # * opencv protocol format may break + contours_crop += np.asarray([[x1, y1]]) # index correction + if type_map is not None: + type_map_crop = type_map[y1:y2, x1:x2] + type_id = np.unique(type_map_crop).max() # non-zero + inst_colour = type_colour[type_id] + else: + inst_colour = (inst_rng_colors[inst_idx]).tolist() + cv2.drawContours(overlay, [contours_crop], -1, inst_colour, line_thickness) + return overlay + + +#### +def visualize_instances_dict( + input_image, inst_dict, draw_dot=False, type_colour=None, line_thickness=2 +): + """Overlays segmentation results (dictionary) on image as contours. + + Args: + input_image: input image + inst_dict: dict of output prediction, defined as in this library + draw_dot: to draw a dot for each centroid + type_colour: a dict of {type_id : (type_name, colour)} , + `type_id` is from 0-N and `colour` is a tuple of (R, G, B) + line_thickness: line thickness of contours + """ + overlay = np.copy((input_image)) + + inst_rng_colors = random_colors(len(inst_dict)) + inst_rng_colors = np.array(inst_rng_colors) * 255 + inst_rng_colors = inst_rng_colors.astype(np.uint8) + + for idx, [inst_id, inst_info] in enumerate(inst_dict.items()): + inst_contour = inst_info["contour"] + if "type" in inst_info and type_colour is not None: + inst_colour = type_colour[inst_info["type"]][1] + else: + inst_colour = (inst_rng_colors[idx]).tolist() + cv2.drawContours(overlay, [inst_contour], -1, inst_colour, line_thickness) + + if draw_dot: + inst_centroid = inst_info["centroid"] + inst_centroid = tuple([int(v) for v in inst_centroid]) + overlay = cv2.circle(overlay, inst_centroid, 3, (255, 0, 0), -1) + return overlay + + +#### +def gen_figure( + imgs_list, + titles, + fig_inch, + shape=None, + share_ax="all", + show=False, + colormap=plt.get_cmap("jet"), +): + """Generate figure.""" + num_img = len(imgs_list) + if shape is None: + ncols = math.ceil(math.sqrt(num_img)) + nrows = math.ceil(num_img / ncols) + else: + nrows, ncols = shape + + # generate figure + fig, axes = plt.subplots(nrows=nrows, ncols=ncols, sharex=share_ax, sharey=share_ax) + axes = [axes] if nrows == 1 else axes + + # not very elegant + idx = 0 + for ax in axes: + for cell in ax: + cell.set_title(titles[idx]) + cell.imshow(imgs_list[idx], cmap=colormap) + cell.tick_params( + axis="both", + which="both", + bottom="off", + top="off", + labelbottom="off", + right="off", + left="off", + labelleft="off", + ) + idx += 1 + if idx == len(titles): + break + if idx == len(titles): + break + + fig.tight_layout() + return fig diff --git a/misc/wsi_handler.py b/misc/wsi_handler.py new file mode 100644 index 0000000..49aa4ea --- /dev/null +++ b/misc/wsi_handler.py @@ -0,0 +1,204 @@ +from collections import OrderedDict +import cv2 +import numpy as np +from skimage import img_as_ubyte +from skimage import color +import re +import subprocess + +import openslide + + +class FileHandler(object): + def __init__(self): + """The handler is responsible for storing the processed data, parsing + the metadata from original file, and reading it from storage. + """ + self.metadata = { + ("available_mag", None), + ("base_mag", None), + ("vendor", None), + ("mpp ", None), + ("base_shape", None), + } + pass + + def __load_metadata(self): + raise NotImplementedError + + def get_full_img(self, read_mag=None, read_mpp=None): + """Only use `read_mag` or `read_mpp`, not both, prioritize `read_mpp`. + + `read_mpp` is in X, Y format + """ + raise NotImplementedError + + def read_region(self, coords, size): + """Must call `prepare_reading` before hand. + + Args: + coords (tuple): (dims_x, dims_y), + top left coordinates of image region at selected + `read_mag` or `read_mpp` from `prepare_reading` + size (tuple): (dims_x, dims_y) + width and height of image region at selected + `read_mag` or `read_mpp` from `prepare_reading` + + """ + raise NotImplementedError + + def get_dimensions(self, read_mag=None, read_mpp=None): + """Will be in X, Y.""" + if read_mpp is not None: + read_scale = (self.metadata["base_mpp"] / read_mpp)[0] + read_mag = read_scale * self.metadata["base_mag"] + scale = read_mag / self.metadata["base_mag"] + # may off some pixels wrt existing mag + return (self.metadata["base_shape"] * scale).astype(np.int32) + + def prepare_reading(self, read_mag=None, read_mpp=None, cache_path=None): + """Only use `read_mag` or `read_mpp`, not both, prioritize `read_mpp`. + + `read_mpp` is in X, Y format. + """ + read_lv, scale_factor = self._get_read_info( + read_mag=read_mag, read_mpp=read_mpp + ) + + if scale_factor is None: + self.image_ptr = None + self.read_lv = read_lv + else: + np.save(cache_path, self.get_full_img(read_mag=read_mag)) + self.image_ptr = np.load(cache_path, mmap_mode="r") + return + + def _get_read_info(self, read_mag=None, read_mpp=None): + if read_mpp is not None: + assert read_mpp[0] == read_mpp[1], "Not supported uneven `read_mpp`" + read_scale = (self.metadata["base_mpp"] / read_mpp)[0] + read_mag = read_scale * self.metadata["base_mag"] + + hires_mag = read_mag + scale_factor = None + if read_mag not in self.metadata["available_mag"]: + if read_mag > self.metadata["base_mag"]: + scale_factor = read_mag / self.metadata["base_mag"] + hires_mag = self.metadata["base_mag"] + else: + mag_list = np.array(self.metadata["available_mag"]) + mag_list = np.sort(mag_list)[::-1] + hires_mag = mag_list - read_mag + # only use higher mag as base for loading + hires_mag = hires_mag[hires_mag > 0] + # use the immediate higher to save compuration + hires_mag = mag_list[np.argmin(hires_mag)] + scale_factor = read_mag / hires_mag + + hires_lv = self.metadata["available_mag"].index(hires_mag) + return hires_lv, scale_factor + + +class OpenSlideHandler(FileHandler): + """Class for handling OpenSlide supported whole-slide images.""" + + def __init__(self, file_path): + """file_path (string): path to single whole-slide image.""" + super().__init__() + self.file_ptr = openslide.OpenSlide(file_path) # load OpenSlide object + self.metadata = self.__load_metadata() + + # only used for cases where the read magnification is different from + self.image_ptr = None # the existing modes of the read file + self.read_level = None + + def __load_metadata(self): + metadata = {} + + wsi_properties = self.file_ptr.properties + level_0_magnification = wsi_properties[openslide.PROPERTY_NAME_OBJECTIVE_POWER] + level_0_magnification = float(level_0_magnification) + + downsample_level = self.file_ptr.level_downsamples + magnification_level = [level_0_magnification / lv for lv in downsample_level] + + mpp = [ + wsi_properties[openslide.PROPERTY_NAME_MPP_X], + wsi_properties[openslide.PROPERTY_NAME_MPP_Y], + ] + mpp = np.array(mpp) + + metadata = [ + ("available_mag", magnification_level), # highest to lowest mag + ("base_mag", magnification_level[0]), + ("vendor", wsi_properties[openslide.PROPERTY_NAME_VENDOR]), + ("mpp ", mpp), + ("base_shape", np.array(self.file_ptr.dimensions)), + ] + return OrderedDict(metadata) + + def read_region(self, coords, size): + """Must call `prepare_reading` before hand. + + Args: + coords (tuple): (dims_x, dims_y), + top left coordinates of image region at selected + `read_mag` or `read_mpp` from `prepare_reading` + size (tuple): (dims_x, dims_y) + width and height of image region at selected + `read_mag` or `read_mpp` from `prepare_reading` + + """ + if self.image_ptr is None: + # convert coord from read lv to lv zero + lv_0_shape = np.array(self.file_ptr.level_dimensions[0]) + lv_r_shape = np.array(self.file_ptr.level_dimensions[self.read_lv]) + up_sample = (lv_0_shape / lv_r_shape)[0] + new_coord = [0, 0] + new_coord[0] = int(coords[0] * up_sample) + new_coord[1] = int(coords[1] * up_sample) + region = self.file_ptr.read_region(new_coord, self.read_lv, size) + else: + region = self.image_ptr[ + coords[1] : coords[1] + size[1], coords[0] : coords[0] + size[0] + ] + return np.array(region)[..., :3] + + def get_full_img(self, read_mag=None, read_mpp=None): + """Only use `read_mag` or `read_mpp`, not both, prioritize `read_mpp`. + + `read_mpp` is in X, Y format. + """ + + read_lv, scale_factor = self._get_read_info( + read_mag=read_mag, read_mpp=read_mpp + ) + + read_size = self.file_ptr.level_dimensions[read_lv] + + wsi_img = self.file_ptr.read_region((0, 0), read_lv, read_size) + wsi_img = np.array(wsi_img)[..., :3] # remove alpha channel + if scale_factor is not None: + # now rescale then return + if scale_factor > 1.0: + interp = cv2.INTER_CUBIC + else: + interp = cv2.INTER_LINEAR + wsi_img = cv2.resize( + wsi_img, (0, 0), fx=scale_factor, fy=scale_factor, interpolation=interp + ) + return wsi_img + + +def get_file_handler(path, backend): + if backend in [ + '.svs', '.tif', + '.vms', '.vmu', '.ndpi', + '.scn', '.mrxs', '.tiff', + '.svslide', + '.bif', + ]: + return OpenSlideHandler(path) + else: + assert False, "Unknown WSI format `%s`" % backend + diff --git a/models/__init__.py b/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/models/hovernet/__init__.py b/models/hovernet/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/models/hovernet/net_desc.py b/models/hovernet/net_desc.py new file mode 100644 index 0000000..745f201 --- /dev/null +++ b/models/hovernet/net_desc.py @@ -0,0 +1,153 @@ +import math +from collections import OrderedDict + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .net_utils import (DenseBlock, Net, ResidualBlock, TFSamepaddingLayer, + UpSample2x) +from .utils import crop_op, crop_to_shape + +#### +class HoVerNet(Net): + """Initialise HoVer-Net.""" + + def __init__(self, input_ch=3, nr_types=None, freeze=False, mode='original'): + super().__init__() + self.mode = mode + self.freeze = freeze + self.nr_types = nr_types + self.output_ch = 3 if nr_types is None else 4 + + assert mode == 'original' or mode == 'fast', \ + 'Unknown mode `%s` for HoVerNet %s. Only support `original` or `fast`.' % mode + + module_list = [ + ("/", nn.Conv2d(input_ch, 64, 7, stride=1, padding=0, bias=False)), + ("bn", nn.BatchNorm2d(64, eps=1e-5)), + ("relu", nn.ReLU(inplace=True)), + ] + if mode == 'fast': # prepend the padding for `fast` mode + module_list = [("pad", TFSamepaddingLayer(ksize=7, stride=1))] + module_list + + self.conv0 = nn.Sequential(OrderedDict(module_list)) + self.d0 = ResidualBlock(64, [1, 3, 1], [64, 64, 256], 3, stride=1) + self.d1 = ResidualBlock(256, [1, 3, 1], [128, 128, 512], 4, stride=2) + self.d2 = ResidualBlock(512, [1, 3, 1], [256, 256, 1024], 6, stride=2) + self.d3 = ResidualBlock(1024, [1, 3, 1], [512, 512, 2048], 3, stride=2) + + self.conv_bot = nn.Conv2d(2048, 1024, 1, stride=1, padding=0, bias=False) + + def create_decoder_branch(out_ch=2, ksize=5): + module_list = [ + ("conva", nn.Conv2d(1024, 256, ksize, stride=1, padding=0, bias=False)), + ("dense", DenseBlock(256, [1, ksize], [128, 32], 8, split=4)), + ("convf", nn.Conv2d(512, 512, 1, stride=1, padding=0, bias=False),), + ] + u3 = nn.Sequential(OrderedDict(module_list)) + + module_list = [ + ("conva", nn.Conv2d(512, 128, ksize, stride=1, padding=0, bias=False)), + ("dense", DenseBlock(128, [1, ksize], [128, 32], 4, split=4)), + ("convf", nn.Conv2d(256, 256, 1, stride=1, padding=0, bias=False),), + ] + u2 = nn.Sequential(OrderedDict(module_list)) + + module_list = [ + ("conva/pad", TFSamepaddingLayer(ksize=ksize, stride=1)), + ("conva", nn.Conv2d(256, 64, ksize, stride=1, padding=0, bias=False),), + ] + u1 = nn.Sequential(OrderedDict(module_list)) + + module_list = [ + ("bn", nn.BatchNorm2d(64, eps=1e-5)), + ("relu", nn.ReLU(inplace=True)), + ("conv", nn.Conv2d(64, out_ch, 1, stride=1, padding=0, bias=True),), + ] + u0 = nn.Sequential(OrderedDict(module_list)) + + decoder = nn.Sequential( + OrderedDict([("u3", u3), ("u2", u2), ("u1", u1), ("u0", u0),]) + ) + return decoder + + ksize = 5 if mode == 'original' else 3 + if nr_types is None: + self.decoder = nn.ModuleDict( + OrderedDict( + [ + ("np", create_decoder_branch(ksize=ksize,out_ch=2)), + ("hv", create_decoder_branch(ksize=ksize,out_ch=2)), + ] + ) + ) + else: + self.decoder = nn.ModuleDict( + OrderedDict( + [ + ("tp", create_decoder_branch(ksize=ksize, out_ch=nr_types)), + ("np", create_decoder_branch(ksize=ksize, out_ch=2)), + ("hv", create_decoder_branch(ksize=ksize, out_ch=2)), + ] + ) + ) + + self.upsample2x = UpSample2x() + # TODO: pytorch still require the channel eventhough its ignored + self.weights_init() + + def forward(self, imgs): + + imgs = imgs / 255.0 # to 0-1 range to match XY + + if self.training: + d0 = self.conv0(imgs) + d0 = self.d0(d0, self.freeze) + with torch.set_grad_enabled(not self.freeze): + d1 = self.d1(d0) + d2 = self.d2(d1) + d3 = self.d3(d2) + d3 = self.conv_bot(d3) + d = [d0, d1, d2, d3] + else: + d0 = self.conv0(imgs) + d0 = self.d0(d0) + d1 = self.d1(d0) + d2 = self.d2(d1) + d3 = self.d3(d2) + d3 = self.conv_bot(d3) + d = [d0, d1, d2, d3] + + # TODO: switch to `crop_to_shape` ? + if self.mode == 'original': + d[0] = crop_op(d[0], [184, 184]) + d[1] = crop_op(d[1], [72, 72]) + else: + d[0] = crop_op(d[0], [92, 92]) + d[1] = crop_op(d[1], [36, 36]) + + out_dict = OrderedDict() + for branch_name, branch_desc in self.decoder.items(): + u3 = self.upsample2x(d[-1]) + d[-2] + u3 = branch_desc[0](u3) + + u2 = self.upsample2x(u3) + d[-3] + u2 = branch_desc[1](u2) + + u1 = self.upsample2x(u2) + d[-4] + u1 = branch_desc[2](u1) + + u0 = branch_desc[3](u1) + out_dict[branch_name] = u0 + + return out_dict + + +#### +def create_model(mode=None, **kwargs): + if mode not in ['original', 'fast']: + assert "Unknown Model Mode %s" % mode + return HoVerNet(mode=mode, **kwargs) + diff --git a/models/hovernet/net_utils.py b/models/hovernet/net_utils.py new file mode 100644 index 0000000..7f13624 --- /dev/null +++ b/models/hovernet/net_utils.py @@ -0,0 +1,295 @@ +import numpy as np +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from collections import OrderedDict + +from .utils import crop_op, crop_to_shape +from config import Config + + +#### +class Net(nn.Module): + """ A base class provides a common weight initialisation scheme.""" + + def weights_init(self): + for m in self.modules(): + classname = m.__class__.__name__ + + # ! Fixed the type checking + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + + if "norm" in classname.lower(): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + if "linear" in classname.lower(): + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, x): + return x + + +#### +class TFSamepaddingLayer(nn.Module): + """To align with tf `same` padding. + + Putting this before any conv layer that need padding + Assuming kernel has Height == Width for simplicity + """ + + def __init__(self, ksize, stride): + super(TFSamepaddingLayer, self).__init__() + self.ksize = ksize + self.stride = stride + + def forward(self, x): + if x.shape[2] % self.stride == 0: + pad = max(self.ksize - self.stride, 0) + else: + pad = max(self.ksize - (x.shape[2] % self.stride), 0) + + if pad % 2 == 0: + pad_val = pad // 2 + padding = (pad_val, pad_val, pad_val, pad_val) + else: + pad_val_start = pad // 2 + pad_val_end = pad - pad_val_start + padding = (pad_val_start, pad_val_end, pad_val_start, pad_val_end) + # print(x.shape, padding) + x = F.pad(x, padding, "constant", 0) + # print(x.shape) + return x + + +#### +class DenseBlock(Net): + """Dense Block as defined in: + + Huang, Gao, Zhuang Liu, Laurens Van Der Maaten, and Kilian Q. Weinberger. + "Densely connected convolutional networks." In Proceedings of the IEEE conference + on computer vision and pattern recognition, pp. 4700-4708. 2017. + + Only performs `valid` convolution. + + """ + + def __init__(self, in_ch, unit_ksize, unit_ch, unit_count, split=1): + super(DenseBlock, self).__init__() + assert len(unit_ksize) == len(unit_ch), "Unbalance Unit Info" + + self.nr_unit = unit_count + self.in_ch = in_ch + self.unit_ch = unit_ch + + # ! For inference only so init values for batchnorm may not match tensorflow + unit_in_ch = in_ch + self.units = nn.ModuleList() + for idx in range(unit_count): + self.units.append( + nn.Sequential( + OrderedDict( + [ + ("preact_bna/bn", nn.BatchNorm2d(unit_in_ch, eps=1e-5)), + ("preact_bna/relu", nn.ReLU(inplace=True)), + ( + "conv1", + nn.Conv2d( + unit_in_ch, + unit_ch[0], + unit_ksize[0], + stride=1, + padding=0, + bias=False, + ), + ), + ("conv1/bn", nn.BatchNorm2d(unit_ch[0], eps=1e-5)), + ("conv1/relu", nn.ReLU(inplace=True)), + # ('conv2/pool', TFSamepaddingLayer(ksize=unit_ksize[1], stride=1)), + ( + "conv2", + nn.Conv2d( + unit_ch[0], + unit_ch[1], + unit_ksize[1], + groups=split, + stride=1, + padding=0, + bias=False, + ), + ), + ] + ) + ) + ) + unit_in_ch += unit_ch[1] + + self.blk_bna = nn.Sequential( + OrderedDict( + [ + ("bn", nn.BatchNorm2d(unit_in_ch, eps=1e-5)), + ("relu", nn.ReLU(inplace=True)), + ] + ) + ) + + def out_ch(self): + return self.in_ch + self.nr_unit * self.unit_ch[-1] + + def forward(self, prev_feat): + for idx in range(self.nr_unit): + new_feat = self.units[idx](prev_feat) + prev_feat = crop_to_shape(prev_feat, new_feat) + prev_feat = torch.cat([prev_feat, new_feat], dim=1) + prev_feat = self.blk_bna(prev_feat) + + return prev_feat + + +#### +class ResidualBlock(Net): + """Residual block as defined in: + + He, Kaiming, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. "Deep residual learning + for image recognition." In Proceedings of the IEEE conference on computer vision + and pattern recognition, pp. 770-778. 2016. + + """ + + def __init__(self, in_ch, unit_ksize, unit_ch, unit_count, stride=1): + super(ResidualBlock, self).__init__() + assert len(unit_ksize) == len(unit_ch), "Unbalance Unit Info" + + self.nr_unit = unit_count + self.in_ch = in_ch + self.unit_ch = unit_ch + + # ! For inference only so init values for batchnorm may not match tensorflow + unit_in_ch = in_ch + self.units = nn.ModuleList() + for idx in range(unit_count): + unit_layer = [ + ("preact/bn", nn.BatchNorm2d(unit_in_ch, eps=1e-5)), + ("preact/relu", nn.ReLU(inplace=True)), + ( + "conv1", + nn.Conv2d( + unit_in_ch, + unit_ch[0], + unit_ksize[0], + stride=1, + padding=0, + bias=False, + ), + ), + ("conv1/bn", nn.BatchNorm2d(unit_ch[0], eps=1e-5)), + ("conv1/relu", nn.ReLU(inplace=True)), + ( + "conv2/pad", + TFSamepaddingLayer( + ksize=unit_ksize[1], stride=stride if idx == 0 else 1 + ), + ), + ( + "conv2", + nn.Conv2d( + unit_ch[0], + unit_ch[1], + unit_ksize[1], + stride=stride if idx == 0 else 1, + padding=0, + bias=False, + ), + ), + ("conv2/bn", nn.BatchNorm2d(unit_ch[1], eps=1e-5)), + ("conv2/relu", nn.ReLU(inplace=True)), + ( + "conv3", + nn.Conv2d( + unit_ch[1], + unit_ch[2], + unit_ksize[2], + stride=1, + padding=0, + bias=False, + ), + ), + ] + # * has bna to conclude each previous block so + # * must not put preact for the first unit of this block + unit_layer = unit_layer if idx != 0 else unit_layer[2:] + self.units.append(nn.Sequential(OrderedDict(unit_layer))) + unit_in_ch = unit_ch[-1] + + if in_ch != unit_ch[-1] or stride != 1: + self.shortcut = nn.Conv2d(in_ch, unit_ch[-1], 1, stride=stride, bias=False) + else: + self.shortcut = None + + self.blk_bna = nn.Sequential( + OrderedDict( + [ + ("bn", nn.BatchNorm2d(unit_in_ch, eps=1e-5)), + ("relu", nn.ReLU(inplace=True)), + ] + ) + ) + + # print(self.units[0]) + # print(self.units[1]) + # exit() + + def out_ch(self): + return self.unit_ch[-1] + + def forward(self, prev_feat, freeze=False): + if self.shortcut is None: + shortcut = prev_feat + else: + shortcut = self.shortcut(prev_feat) + + for idx in range(0, len(self.units)): + new_feat = prev_feat + if self.training: + with torch.set_grad_enabled(not freeze): + new_feat = self.units[idx](new_feat) + else: + new_feat = self.units[idx](new_feat) + prev_feat = new_feat + shortcut + shortcut = prev_feat + feat = self.blk_bna(prev_feat) + return feat + + +#### +class UpSample2x(nn.Module): + """Upsample input by a factor of 2. + + Assume input is of NCHW, port FixedUnpooling from TensorPack. + """ + + def __init__(self): + super(UpSample2x, self).__init__() + # correct way to create constant within module + self.register_buffer( + "unpool_mat", torch.from_numpy(np.ones((2, 2), dtype="float32")) + ) + self.unpool_mat.unsqueeze(0) + + def forward(self, x): + input_shape = list(x.shape) + # unsqueeze is expand_dims equivalent + # permute is transpose equivalent + # view is reshape equivalent + x = x.unsqueeze(-1) # bchwx1 + mat = self.unpool_mat.unsqueeze(0) # 1xshxsw + ret = torch.tensordot(x, mat, dims=1) # bxcxhxwxshxsw + ret = ret.permute(0, 1, 2, 4, 3, 5) + ret = ret.reshape((-1, input_shape[1], input_shape[2] * 2, input_shape[3] * 2)) + return ret + diff --git a/models/hovernet/opt.py b/models/hovernet/opt.py new file mode 100644 index 0000000..049a146 --- /dev/null +++ b/models/hovernet/opt.py @@ -0,0 +1,142 @@ +import torch.optim as optim + +from run_utils.callbacks.base import ( + AccumulateRawOutput, + PeriodicSaver, + ProcessAccumulatedRawOutput, + ScalarMovingAverage, + ScheduleLr, + TrackLr, + VisualizeOutput, + TriggerEngine, +) +from run_utils.callbacks.logging import LoggingEpochOutput, LoggingGradient +from run_utils.engine import Events + +from .targets import gen_targets, prep_sample +from .net_desc import create_model +from .run_desc import proc_valid_step_output, train_step, valid_step, viz_step_output + + +# TODO: training config only ? +# TODO: switch all to function name String for all option +def get_config(nr_type, mode): + return { + # ------------------------------------------------------------------ + # ! All phases have the same number of run engine + # phases are run sequentially from index 0 to N + "phase_list": [ + { + "run_info": { + # may need more dynamic for each network + "net": { + "desc": lambda: create_model( + input_ch=3, nr_types=nr_type, + freeze=True, mode=mode + ), + "optimizer": [ + optim.Adam, + { # should match keyword for parameters within the optimizer + "lr": 1.0e-4, # initial learning rate, + "betas": (0.9, 0.999), + }, + ], + # learning rate scheduler + "lr_scheduler": lambda x: optim.lr_scheduler.StepLR(x, 25), + "extra_info": { + "loss": { + "np": {"bce": 1, "dice": 1}, + "hv": {"mse": 1, "msge": 1}, + "tp": {"bce": 1, "dice": 1}, + }, + }, + # path to load, -1 to auto load checkpoint from previous phase, + # None to start from scratch + "pretrained": "ImageNet-ResNet50-Preact_pytorch.tar", + # 'pretrained': None, + }, + }, + "target_info": {"gen": (gen_targets, {}), "viz": (prep_sample, {})}, + "batch_size": {"train": 2, "valid": 2,}, # engine name : value + "nr_epochs": 50, + }, + { + "run_info": { + # may need more dynamic for each network + "net": { + "desc": lambda: create_model( + input_ch=3, nr_types=nr_type, + freeze=False, mode=mode + ), + "optimizer": [ + optim.Adam, + { # should match keyword for parameters within the optimizer + "lr": 1.0e-4, # initial learning rate, + "betas": (0.9, 0.999), + }, + ], + # learning rate scheduler + "lr_scheduler": lambda x: optim.lr_scheduler.StepLR(x, 25), + "extra_info": { + "loss": { + "np": {"bce": 1, "dice": 1}, + "hv": {"mse": 1, "msge": 1}, + "tp": {"bce": 1, "dice": 1}, + }, + }, + # path to load, -1 to auto load checkpoint from previous phase, + # None to start from scratch + "pretrained": -1, + }, + }, + "target_info": {"gen": (gen_targets, {}), "viz": (prep_sample, {})}, + "batch_size": {"train": 2, "valid": 2,}, # batch size per gpu + "nr_epochs": 0, + }, + ], + # ------------------------------------------------------------------ + # TODO: dynamically for dataset plugin selection and processing also? + # all enclosed engine shares the same neural networks + # as the on at the outer calling it + "run_engine": { + "train": { + # TODO: align here, file path or what? what about CV? + "dataset": "", # whats about compound dataset ? + "nr_procs": 16, # number of threads for dataloader + "run_step": train_step, # TODO: function name or function variable ? + "reset_per_run": False, + # callbacks are run according to the list order of the event + "callbacks": { + Events.STEP_COMPLETED: [ + # LoggingGradient(), # TODO: very slow, may be due to back forth of tensor/numpy ? + ScalarMovingAverage(), + ], + Events.EPOCH_COMPLETED: [ + TrackLr(), + PeriodicSaver(), + VisualizeOutput(viz_step_output), + LoggingEpochOutput(), + TriggerEngine("valid"), + ScheduleLr(), + ], + }, + }, + "valid": { + "dataset": "", # whats about compound dataset ? + "nr_procs": 8, # number of threads for dataloader + "run_step": valid_step, + "reset_per_run": True, # * to stop aggregating output etc. from last run + # callbacks are run according to the list order of the event + "callbacks": { + Events.STEP_COMPLETED: [AccumulateRawOutput(),], + Events.EPOCH_COMPLETED: [ + # TODO: is there way to preload these ? + ProcessAccumulatedRawOutput( + lambda a: proc_valid_step_output(a, nr_types=nr_type) + ), + LoggingEpochOutput(), + ], + }, + }, + }, + } diff --git a/models/hovernet/post_proc.py b/models/hovernet/post_proc.py new file mode 100644 index 0000000..4fe1ffb --- /dev/null +++ b/models/hovernet/post_proc.py @@ -0,0 +1,186 @@ +import cv2 +import numpy as np + +from scipy.ndimage import filters, measurements +from scipy.ndimage.morphology import ( + binary_dilation, + binary_fill_holes, + distance_transform_cdt, + distance_transform_edt, +) + +from skimage.segmentation import watershed +from misc.utils import get_bounding_box, remove_small_objects + +import warnings + + +def noop(*args, **kargs): + pass + + +warnings.warn = noop + + +#### +def __proc_np_hv(pred): + """Process Nuclei Prediction with XY Coordinate Map. + + Args: + pred: prediction output, assuming + channel 0 contain probability map of nuclei + channel 1 containing the regressed X-map + channel 2 containing the regressed Y-map + + """ + pred = np.array(pred, dtype=np.float32) + + blb_raw = pred[..., 0] + h_dir_raw = pred[..., 1] + v_dir_raw = pred[..., 2] + + # processing + blb = np.array(blb_raw >= 0.5, dtype=np.int32) + + blb = measurements.label(blb)[0] + blb = remove_small_objects(blb, min_size=10) + blb[blb > 0] = 1 # background is 0 already + + h_dir = cv2.normalize( + h_dir_raw, None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F + ) + v_dir = cv2.normalize( + v_dir_raw, None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F + ) + + sobelh = cv2.Sobel(h_dir, cv2.CV_64F, 1, 0, ksize=21) + sobelv = cv2.Sobel(v_dir, cv2.CV_64F, 0, 1, ksize=21) + + sobelh = 1 - ( + cv2.normalize( + sobelh, None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F + ) + ) + sobelv = 1 - ( + cv2.normalize( + sobelv, None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F + ) + ) + + overall = np.maximum(sobelh, sobelv) + overall = overall - (1 - blb) + overall[overall < 0] = 0 + + dist = (1.0 - overall) * blb + ## nuclei values form mountains so inverse to get basins + dist = -cv2.GaussianBlur(dist, (3, 3), 0) + + overall = np.array(overall >= 0.4, dtype=np.int32) + + marker = blb - overall + marker[marker < 0] = 0 + marker = binary_fill_holes(marker).astype("uint8") + kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5)) + marker = cv2.morphologyEx(marker, cv2.MORPH_OPEN, kernel) + marker = measurements.label(marker)[0] + marker = remove_small_objects(marker, min_size=10) + + proced_pred = watershed(dist, markers=marker, mask=blb) + + return proced_pred + + +#### +def process(pred_map, nr_types=None, return_centroids=False): + """Post processing script for image tiles. + + Args: + pred_map: commbined output of tp, np and hv branches, in the same order + nr_types: number of types considered at output of nc branch + overlaid_img: img to overlay the predicted instances upon, `None` means no + type_colour (dict) : `None` to use random, else overlay instances of a type to colour in the dict + output_dtype: data type of output + + Returns: + pred_inst: pixel-wise nuclear instance segmentation prediction + pred_type_out: pixel-wise nuclear type prediction + + """ + if nr_types is not None: + pred_type = pred_map[..., :1] + pred_inst = pred_map[..., 1:] + pred_type = pred_type.astype(np.int32) + else: + pred_inst = pred_map + + pred_inst = np.squeeze(pred_inst) + pred_inst = __proc_np_hv(pred_inst) + + inst_info_dict = None + if return_centroids or nr_types is not None: + inst_id_list = np.unique(pred_inst)[1:] # exlcude background + inst_info_dict = {} + for inst_id in inst_id_list: + inst_map = pred_inst == inst_id + # TODO: chane format of bbox output + rmin, rmax, cmin, cmax = get_bounding_box(inst_map) + inst_bbox = np.array([[rmin, cmin], [rmax, cmax]]) + inst_map = inst_map[ + inst_bbox[0][0] : inst_bbox[1][0], inst_bbox[0][1] : inst_bbox[1][1] + ] + inst_map = inst_map.astype(np.uint8) + inst_moment = cv2.moments(inst_map) + inst_contour = cv2.findContours( + inst_map, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE + ) + # * opencv protocol format may break + inst_contour = np.squeeze(inst_contour[0][0].astype("int32")) + # < 3 points dont make a contour, so skip, likely artifact too + # as the contours obtained via approximation => too small or sthg + if inst_contour.shape[0] < 3: + continue + if len(inst_contour.shape) != 2: + continue # ! check for trickery shape + inst_centroid = [ + (inst_moment["m10"] / inst_moment["m00"]), + (inst_moment["m01"] / inst_moment["m00"]), + ] + inst_centroid = np.array(inst_centroid) + inst_contour[:, 0] += inst_bbox[0][1] # X + inst_contour[:, 1] += inst_bbox[0][0] # Y + inst_centroid[0] += inst_bbox[0][1] # X + inst_centroid[1] += inst_bbox[0][0] # Y + inst_info_dict[inst_id] = { # inst_id should start at 1 + "bbox": inst_bbox, + "centroid": inst_centroid, + "contour": inst_contour, + "type_prob": None, + "type": None, + } + + if nr_types is not None: + #### * Get class of each instance id, stored at index id-1 + for inst_id in list(inst_info_dict.keys()): + rmin, cmin, rmax, cmax = (inst_info_dict[inst_id]["bbox"]).flatten() + inst_map_crop = pred_inst[rmin:rmax, cmin:cmax] + inst_type_crop = pred_type[rmin:rmax, cmin:cmax] + inst_map_crop = ( + inst_map_crop == inst_id + ) # TODO: duplicated operation, may be expensive + inst_type = inst_type_crop[inst_map_crop] + type_list, type_pixels = np.unique(inst_type, return_counts=True) + type_list = list(zip(type_list, type_pixels)) + type_list = sorted(type_list, key=lambda x: x[1], reverse=True) + inst_type = type_list[0][0] + if inst_type == 0: # ! pick the 2nd most dominant if exist + if len(type_list) > 1: + inst_type = type_list[1][0] + type_dict = {v[0]: v[1] for v in type_list} + type_prob = type_dict[inst_type] / (np.sum(inst_map_crop) + 1.0e-6) + inst_info_dict[inst_id]["type"] = int(inst_type) + inst_info_dict[inst_id]["type_prob"] = float(type_prob) + + # print('here') + # ! WARNING: ID MAY NOT BE CONTIGUOUS + # inst_id in the dict maps to the same value in the `pred_inst` + return pred_inst, inst_info_dict diff --git a/models/hovernet/run_desc.py b/models/hovernet/run_desc.py new file mode 100644 index 0000000..026873c --- /dev/null +++ b/models/hovernet/run_desc.py @@ -0,0 +1,344 @@ +import numpy as np +import matplotlib.pyplot as plt +import torch +import torch.nn.functional as F + +from misc.utils import center_pad_to_shape, cropping_center +from .utils import crop_to_shape, dice_loss, mse_loss, msge_loss, xentropy_loss + +from collections import OrderedDict + +#### +def train_step(batch_data, run_info): + # TODO: synchronize the attach protocol + run_info, state_info = run_info + loss_func_dict = { + "bce": xentropy_loss, + "dice": dice_loss, + "mse": mse_loss, + "msge": msge_loss, + } + # use 'ema' to add for EMA calculation, must be scalar! + result_dict = {"EMA": {}} + track_value = lambda name, value: result_dict["EMA"].update({name: value}) + + #### + model = run_info["net"]["desc"] + optimizer = run_info["net"]["optimizer"] + + #### + imgs = batch_data["img"] + true_np = batch_data["np_map"] + true_hv = batch_data["hv_map"] + + imgs = imgs.to("cuda").type(torch.float32) # to NCHW + imgs = imgs.permute(0, 3, 1, 2).contiguous() + + # HWC + true_np = true_np.to("cuda").type(torch.int64) + true_hv = true_hv.to("cuda").type(torch.float32) + + true_np_onehot = (F.one_hot(true_np, num_classes=2)).type(torch.float32) + true_dict = { + "np": true_np_onehot, + "hv": true_hv, + } + + if model.module.nr_types is not None: + true_tp = batch_data["tp_map"] + true_tp = torch.squeeze(true_tp).to("cuda").type(torch.int64) + true_tp_onehot = F.one_hot(true_tp, num_classes=model.module.nr_types) + true_tp_onehot = true_tp_onehot.type(torch.float32) + true_dict["tp"] = true_tp_onehot + + #### + model.train() + model.zero_grad() # not rnn so not accumulate + + pred_dict = model(imgs) + pred_dict = OrderedDict( + [[k, v.permute(0, 2, 3, 1).contiguous()] for k, v in pred_dict.items()] + ) + pred_dict["np"] = F.softmax(pred_dict["np"], dim=-1) + if model.module.nr_types is not None: + pred_dict["tp"] = F.softmax(pred_dict["tp"], dim=-1) + + #### + loss = 0 + loss_opts = run_info["net"]["extra_info"]["loss"] + for branch_name in pred_dict.keys(): + for loss_name, loss_weight in loss_opts[branch_name].items(): + loss_func = loss_func_dict[loss_name] + loss_args = [true_dict[branch_name], pred_dict[branch_name]] + if loss_name == "msge": + loss_args.append(true_np_onehot[..., 1]) + term_loss = loss_func(*loss_args) + track_value("loss_%s_%s" % (branch_name, loss_name), term_loss.cpu().item()) + loss += loss_weight * term_loss + + track_value("overall_loss", loss.cpu().item()) + # * gradient update + + # torch.set_printoptions(precision=10) + loss.backward() + optimizer.step() + #### + + # pick 2 random sample from the batch for visualization + sample_indices = torch.randint(0, true_np.shape[0], (2,)) + + imgs = (imgs[sample_indices]).byte() # to uint8 + imgs = imgs.permute(0, 2, 3, 1).contiguous().cpu().numpy() + + pred_dict["np"] = pred_dict["np"][..., 1] # return pos only + pred_dict = { + k: v[sample_indices].detach().cpu().numpy() for k, v in pred_dict.items() + } + + true_dict["np"] = true_np + true_dict = { + k: v[sample_indices].detach().cpu().numpy() for k, v in true_dict.items() + } + + # * Its up to user to define the protocol to process the raw output per step! + result_dict["raw"] = { # protocol for contents exchange within `raw` + "img": imgs, + "np": (true_dict["np"], pred_dict["np"]), + "hv": (true_dict["hv"], pred_dict["hv"]), + } + return result_dict + + +#### +def valid_step(batch_data, run_info): + run_info, state_info = run_info + #### + model = run_info["net"]["desc"] + model.eval() # infer mode + + #### + imgs = batch_data["img"] + true_np = batch_data["np_map"] + true_hv = batch_data["hv_map"] + + imgs_gpu = imgs.to("cuda").type(torch.float32) # to NCHW + imgs_gpu = imgs_gpu.permute(0, 3, 1, 2).contiguous() + + # HWC + true_np = torch.squeeze(true_np).type(torch.int64) + true_hv = torch.squeeze(true_hv).type(torch.float32) + + true_dict = { + "np": true_np, + "hv": true_hv, + } + + if model.module.nr_types is not None: + true_tp = batch_data["tp_map"] + true_tp = torch.squeeze(true_tp).type(torch.int64) + true_dict["tp"] = true_tp + + # -------------------------------------------------------------- + with torch.no_grad(): # dont compute gradient + pred_dict = model(imgs_gpu) + pred_dict = OrderedDict( + [[k, v.permute(0, 2, 3, 1).contiguous()] for k, v in pred_dict.items()] + ) + pred_dict["np"] = F.softmax(pred_dict["np"], dim=-1)[..., 1] + if model.module.nr_types is not None: + type_map = F.softmax(pred_dict["tp"], dim=-1) + type_map = torch.argmax(type_map, dim=-1, keepdim=False) + type_map = type_map.type(torch.float32) + pred_dict["tp"] = type_map + + # * Its up to user to define the protocol to process the raw output per step! + result_dict = { # protocol for contents exchange within `raw` + "raw": { + "imgs": imgs.numpy(), + "true_np": true_dict["np"].numpy(), + "true_hv": true_dict["hv"].numpy(), + "prob_np": pred_dict["np"].cpu().numpy(), + "pred_hv": pred_dict["hv"].cpu().numpy(), + } + } + if model.module.nr_types is not None: + result_dict["raw"]["true_tp"] = true_dict["tp"].numpy() + result_dict["raw"]["pred_tp"] = pred_dict["tp"].cpu().numpy() + return result_dict + + +#### +def infer_step(batch_data, model): + + #### + patch_imgs = batch_data + + patch_imgs_gpu = patch_imgs.to("cuda").type(torch.float32) # to NCHW + patch_imgs_gpu = patch_imgs_gpu.permute(0, 3, 1, 2).contiguous() + + #### + model.eval() # infer mode + + # -------------------------------------------------------------- + with torch.no_grad(): # dont compute gradient + pred_dict = model(patch_imgs_gpu) + pred_dict = OrderedDict( + [[k, v.permute(0, 2, 3, 1).contiguous()] for k, v in pred_dict.items()] + ) + pred_dict["np"] = F.softmax(pred_dict["np"], dim=-1)[..., 1:] + if "tp" in pred_dict: + type_map = F.softmax(pred_dict["tp"], dim=-1) + type_map = torch.argmax(type_map, dim=-1, keepdim=True) + type_map = type_map.type(torch.float32) + pred_dict["tp"] = type_map + pred_output = torch.cat(list(pred_dict.values()), -1) + + # * Its up to user to define the protocol to process the raw output per step! + return pred_output.cpu().numpy() + + +#### +def viz_step_output(raw_data, nr_types=None): + """ + `raw_data` will be implicitly provided in the similar format as the + return dict from train/valid step, but may have been accumulated across N running step + """ + + imgs = raw_data["img"] + true_np, pred_np = raw_data["np"] + true_hv, pred_hv = raw_data["hv"] + if nr_types is not None: + true_tp, pred_tp = raw_data["tp"] + + aligned_shape = [list(imgs.shape), list(true_np.shape), list(pred_np.shape)] + aligned_shape = np.min(np.array(aligned_shape), axis=0)[1:3] + + cmap = plt.get_cmap("jet") + + def colorize(ch, vmin, vmax): + """ + Will clamp value value outside the provided range to vmax and vmin + """ + ch = np.squeeze(ch.astype("float32")) + ch[ch > vmax] = vmax # clamp value + ch[ch < vmin] = vmin + ch = (ch - vmin) / (vmax - vmin + 1.0e-16) + # take RGB from RGBA heat map + ch_cmap = (cmap(ch)[..., :3] * 255).astype("uint8") + # ch_cmap = center_pad_to_shape(ch_cmap, aligned_shape) + return ch_cmap + + viz_list = [] + for idx in range(imgs.shape[0]): + # img = center_pad_to_shape(imgs[idx], aligned_shape) + img = cropping_center(imgs[idx], aligned_shape) + + true_viz_list = [img] + # cmap may randomly fails if of other types + true_viz_list.append(colorize(true_np[idx], 0, 1)) + true_viz_list.append(colorize(true_hv[idx][..., 0], -1, 1)) + true_viz_list.append(colorize(true_hv[idx][..., 1], -1, 1)) + if nr_types is not None: # TODO: a way to pass through external info + true_viz_list.append(colorize(true_tp[idx], 0, nr_types)) + true_viz_list = np.concatenate(true_viz_list, axis=1) + + pred_viz_list = [img] + # cmap may randomly fails if of other types + pred_viz_list.append(colorize(pred_np[idx], 0, 1)) + pred_viz_list.append(colorize(pred_hv[idx][..., 0], -1, 1)) + pred_viz_list.append(colorize(pred_hv[idx][..., 1], -1, 1)) + if nr_types is not None: + pred_viz_list.append(colorize(pred_tp[idx], 0, nr_types)) + pred_viz_list = np.concatenate(pred_viz_list, axis=1) + + viz_list.append(np.concatenate([true_viz_list, pred_viz_list], axis=0)) + viz_list = np.concatenate(viz_list, axis=0) + return viz_list + + +#### +from itertools import chain + + +def proc_valid_step_output(raw_data, nr_types=None): + # TODO: add auto populate from main state track list + track_dict = {"scalar": {}, "image": {}} + + def track_value(name, value, vtype): + return track_dict[vtype].update({name: value}) + + def _dice_info(true, pred, label): + true = np.array(true == label, np.int32) + pred = np.array(pred == label, np.int32) + inter = (pred * true).sum() + total = (pred + true).sum() + return inter, total + + over_inter = 0 + over_total = 0 + over_correct = 0 + prob_np = raw_data["prob_np"] + true_np = raw_data["true_np"] + for idx in range(len(raw_data["true_np"])): + patch_prob_np = prob_np[idx] + patch_true_np = true_np[idx] + patch_pred_np = np.array(patch_prob_np > 0.5, dtype=np.int32) + inter, total = _dice_info(patch_true_np, patch_pred_np, 1) + correct = (patch_pred_np == patch_true_np).sum() + over_inter += inter + over_total += total + over_correct += correct + nr_pixels = len(true_np) * np.size(true_np[0]) + acc_np = over_correct / nr_pixels + dice_np = 2 * over_inter / (over_total + 1.0e-8) + track_value("np_acc", acc_np, "scalar") + track_value("np_dice", dice_np, "scalar") + + # * TP statistic + if nr_types is not None: + pred_tp = raw_data["pred_tp"] + true_tp = raw_data["true_tp"] + for type_id in range(0, nr_types): + over_inter = 0 + over_total = 0 + for idx in range(len(raw_data["true_np"])): + patch_pred_tp = pred_tp[idx] + patch_true_tp = true_tp[idx] + inter, total = _dice_info(patch_true_tp, patch_pred_tp, type_id) + over_inter += inter + over_total += total + dice_tp = 2 * over_inter / (over_total + 1.0e-8) + track_value("tp_dice_%d" % type_id, dice_tp, "scalar") + + # * HV regression statistic + pred_hv = raw_data["pred_hv"] + true_hv = raw_data["true_hv"] + + over_squared_error = 0 + for idx in range(len(raw_data["true_np"])): + patch_pred_hv = pred_hv[idx] + patch_true_hv = true_hv[idx] + squared_error = patch_pred_hv - patch_true_hv + squared_error = squared_error * squared_error + over_squared_error += squared_error.sum() + mse = over_squared_error / nr_pixels + track_value("hv_mse", mse, "scalar") + + # * + imgs = raw_data["imgs"] + selected_idx = np.random.randint(0, len(imgs), size=(8,)).tolist() + imgs = np.array([imgs[idx] for idx in selected_idx]) + true_np = np.array([true_np[idx] for idx in selected_idx]) + true_hv = np.array([true_hv[idx] for idx in selected_idx]) + prob_np = np.array([prob_np[idx] for idx in selected_idx]) + pred_hv = np.array([pred_hv[idx] for idx in selected_idx]) + viz_raw_data = {"img": imgs, "np": (true_np, prob_np), "hv": (true_hv, pred_hv)} + + if nr_types is not None: + true_tp = np.array([true_tp[idx] for idx in selected_idx]) + pred_tp = np.array([pred_tp[idx] for idx in selected_idx]) + viz_raw_data["tp"] = (true_tp, pred_tp) + viz_fig = viz_step_output(viz_raw_data, nr_types) + track_dict["image"]["output"] = viz_fig + + return track_dict diff --git a/models/hovernet/targets.py b/models/hovernet/targets.py new file mode 100644 index 0000000..d9466d9 --- /dev/null +++ b/models/hovernet/targets.py @@ -0,0 +1,153 @@ +import math +import numpy as np + +import torch +import torch.nn.functional as F + +from scipy import ndimage +from scipy.ndimage import measurements +from skimage import morphology as morph +import matplotlib.pyplot as plt + +from misc.utils import center_pad_to_shape, cropping_center, get_bounding_box +from dataloader.augs import fix_mirror_padding + + +#### +def gen_instance_hv_map(ann, crop_shape): + """Input annotation must be of original shape. + + The map is calculated only for instances within the crop portion + but based on the original shape in original image. + + Perform following operation: + Obtain the horizontal and vertical distance maps for each + nuclear instance. + + """ + orig_ann = ann.copy() # instance ID map + fixed_ann = fix_mirror_padding(orig_ann) + # re-cropping with fixed instance id map + crop_ann = cropping_center(fixed_ann, crop_shape) + # TODO: deal with 1 label warning + crop_ann = morph.remove_small_objects(crop_ann, min_size=30) + + x_map = np.zeros(orig_ann.shape[:2], dtype=np.float32) + y_map = np.zeros(orig_ann.shape[:2], dtype=np.float32) + + inst_list = list(np.unique(crop_ann)) + inst_list.remove(0) # 0 is background + for inst_id in inst_list: + inst_map = np.array(fixed_ann == inst_id, np.uint8) + inst_box = get_bounding_box(inst_map) + + # expand the box by 2px + # Because we first pad the ann at line 207, the bboxes + # will remain valid after expansion + inst_box[0] -= 2 + inst_box[2] -= 2 + inst_box[1] += 2 + inst_box[3] += 2 + + inst_map = inst_map[inst_box[0] : inst_box[1], inst_box[2] : inst_box[3]] + + if inst_map.shape[0] < 2 or inst_map.shape[1] < 2: + continue + + # instance center of mass, rounded to nearest pixel + inst_com = list(measurements.center_of_mass(inst_map)) + + inst_com[0] = int(inst_com[0] + 0.5) + inst_com[1] = int(inst_com[1] + 0.5) + + inst_x_range = np.arange(1, inst_map.shape[1] + 1) + inst_y_range = np.arange(1, inst_map.shape[0] + 1) + # shifting center of pixels grid to instance center of mass + inst_x_range -= inst_com[1] + inst_y_range -= inst_com[0] + + inst_x, inst_y = np.meshgrid(inst_x_range, inst_y_range) + + # remove coord outside of instance + inst_x[inst_map == 0] = 0 + inst_y[inst_map == 0] = 0 + inst_x = inst_x.astype("float32") + inst_y = inst_y.astype("float32") + + # normalize min into -1 scale + if np.min(inst_x) < 0: + inst_x[inst_x < 0] /= -np.amin(inst_x[inst_x < 0]) + if np.min(inst_y) < 0: + inst_y[inst_y < 0] /= -np.amin(inst_y[inst_y < 0]) + # normalize max into +1 scale + if np.max(inst_x) > 0: + inst_x[inst_x > 0] /= np.amax(inst_x[inst_x > 0]) + if np.max(inst_y) > 0: + inst_y[inst_y > 0] /= np.amax(inst_y[inst_y > 0]) + + #### + x_map_box = x_map[inst_box[0] : inst_box[1], inst_box[2] : inst_box[3]] + x_map_box[inst_map > 0] = inst_x[inst_map > 0] + + y_map_box = y_map[inst_box[0] : inst_box[1], inst_box[2] : inst_box[3]] + y_map_box[inst_map > 0] = inst_y[inst_map > 0] + + hv_map = np.dstack([x_map, y_map]) + return hv_map + + +#### +def gen_targets(ann, crop_shape, **kwargs): + """Generate the targets for the network.""" + hv_map = gen_instance_hv_map(ann, crop_shape) + np_map = ann.copy() + np_map[np_map > 0] = 1 + + hv_map = cropping_center(hv_map, crop_shape) + np_map = cropping_center(np_map, crop_shape) + + target_dict = { + "hv_map": hv_map, + "np_map": np_map, + } + + return target_dict + + +#### +def prep_sample(data, is_batch=False, **kwargs): + """ + Designed to process direct output from loader + """ + cmap = plt.get_cmap("jet") + + def colorize(ch, vmin, vmax, shape): + ch = np.squeeze(ch.astype("float32")) + ch = ch / (vmax - vmin + 1.0e-16) + # take RGB from RGBA heat map + ch_cmap = (cmap(ch)[..., :3] * 255).astype("uint8") + ch_cmap = center_pad_to_shape(ch_cmap, shape) + return ch_cmap + + def prep_one_sample(data): + shape_array = [np.array(v.shape[:2]) for v in data.values()] + shape = np.maximum(*shape_array) + viz_list = [] + viz_list.append(colorize(data["np_map"], 0, 1, shape)) + # map to [0,2] for better visualisation. + # Note, [-1,1] is used for training. + viz_list.append(colorize(data["hv_map"][..., 0] + 1, 0, 2, shape)) + viz_list.append(colorize(data["hv_map"][..., 1] + 1, 0, 2, shape)) + img = center_pad_to_shape(data["img"], shape) + return np.concatenate([img] + viz_list, axis=1) + + # cmap may randomly fails if of other types + if is_batch: + viz_list = [] + data_shape = list(data.values())[0].shape + for batch_idx in range(data_shape[0]): + sub_data = {k : v[batch_idx] for k, v in data.items()} + viz_list.append(prep_one_sample(sub_data)) + return np.concatenate(viz_list, axis=0) + else: + return prep_one_sample(data) diff --git a/models/hovernet/utils.py b/models/hovernet/utils.py new file mode 100644 index 0000000..3842287 --- /dev/null +++ b/models/hovernet/utils.py @@ -0,0 +1,172 @@ +import math +import numpy as np + +import torch +import torch.nn.functional as F + +from matplotlib import cm + + +#### +def crop_op(x, cropping, data_format="NCHW"): + """Center crop image. + + Args: + x: input image + cropping: the substracted amount + data_format: choose either `NCHW` or `NHWC` + + """ + crop_t = cropping[0] // 2 + crop_b = cropping[0] - crop_t + crop_l = cropping[1] // 2 + crop_r = cropping[1] - crop_l + if data_format == "NCHW": + x = x[:, :, crop_t:-crop_b, crop_l:-crop_r] + else: + x = x[:, crop_t:-crop_b, crop_l:-crop_r, :] + return x + + +#### +def crop_to_shape(x, y, data_format="NCHW"): + """Centre crop x so that x has shape of y. y dims must be smaller than x dims. + + Args: + x: input array + y: array with desired shape. + + """ + assert ( + y.shape[0] <= x.shape[0] and y.shape[1] <= x.shape[1] + ), "Ensure that y dimensions are smaller than x dimensions!" + + x_shape = x.size() + y_shape = y.size() + if data_format == "NCHW": + crop_shape = (x_shape[2] - y_shape[2], x_shape[3] - y_shape[3]) + else: + crop_shape = (x_shape[1] - y_shape[1], x_shape[2] - y_shape[2]) + return crop_op(x, crop_shape, data_format) + + +#### +def xentropy_loss(true, pred, reduction="mean"): + """Cross entropy loss. Assumes NHWC! + + Args: + pred: prediction array + true: ground truth array + + Returns: + cross entropy loss + + """ + epsilon = 10e-8 + # scale preds so that the class probs of each sample sum to 1 + pred = pred / torch.sum(pred, -1, keepdim=True) + # manual computation of crossentropy + pred = torch.clamp(pred, epsilon, 1.0 - epsilon) + loss = -torch.sum((true * torch.log(pred)), -1, keepdim=True) + loss = loss.mean() if reduction == "mean" else loss.sum() + return loss + + +#### +def dice_loss(true, pred, smooth=1e-3): + """`pred` and `true` must be of torch.float32. Assuming of shape NxHxWxC.""" + inse = torch.sum(pred * true, (0, 1, 2)) + l = torch.sum(pred, (0, 1, 2)) + r = torch.sum(true, (0, 1, 2)) + loss = 1.0 - (2.0 * inse + smooth) / (l + r + smooth) + loss = torch.sum(loss) + return loss + + +#### +def mse_loss(true, pred): + """Calculate mean squared error loss. + + Args: + true: ground truth of combined horizontal + and vertical maps + pred: prediction of combined horizontal + and vertical maps + + Returns: + loss: mean squared error + + """ + loss = pred - true + loss = (loss * loss).mean() + return loss + + +#### +def msge_loss(true, pred, focus): + """Calculate the mean squared error of the gradients of + horizontal and vertical map predictions. Assumes + channel 0 is Vertical and channel 1 is Horizontal. + + Args: + true: ground truth of combined horizontal + and vertical maps + pred: prediction of combined horizontal + and vertical maps + focus: area where to apply loss (we only calculate + the loss within the nuclei) + + Returns: + loss: mean squared error of gradients + + """ + + def get_sobel_kernel(size): + """Get sobel kernel with a given size.""" + assert size % 2 == 1, "Must be odd, get size=%d" % size + + h_range = torch.arange( + -size // 2 + 1, + size // 2 + 1, + dtype=torch.float32, + device="cuda", + requires_grad=False, + ) + v_range = torch.arange( + -size // 2 + 1, + size // 2 + 1, + dtype=torch.float32, + device="cuda", + requires_grad=False, + ) + h, v = torch.meshgrid(h_range, v_range) + kernel_h = h / (h * h + v * v + 1.0e-15) + kernel_v = v / (h * h + v * v + 1.0e-15) + return kernel_h, kernel_v + + #### + def get_gradient_hv(hv): + """For calculating gradient.""" + kernel_h, kernel_v = get_sobel_kernel(5) + kernel_h = kernel_h.view(1, 1, 5, 5) # constant + kernel_v = kernel_v.view(1, 1, 5, 5) # constant + + h_ch = hv[..., 0].unsqueeze(1) # Nx1xHxW + v_ch = hv[..., 1].unsqueeze(1) # Nx1xHxW + + # can only apply in NCHW mode + h_dh_ch = F.conv2d(h_ch, kernel_h, padding=2) + v_dv_ch = F.conv2d(v_ch, kernel_v, padding=2) + dhv = torch.cat([h_dh_ch, v_dv_ch], dim=1) + dhv = dhv.permute(0, 2, 3, 1).contiguous() # to NHWC + return dhv + + focus = (focus[..., None]).float() # assume input NHW + focus = torch.cat([focus, focus], axis=-1) + true_grad = get_gradient_hv(true) + pred_grad = get_gradient_hv(pred) + loss = pred_grad - true_grad + loss = focus * (loss * loss) + # artificial reduce_mean with focused region + loss = loss.sum() / (focus.sum() + 1.0e-8) + return loss diff --git a/models/hovernet0/__init__.py b/models/hovernet0/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/models/hovernet0/net_desc.py b/models/hovernet0/net_desc.py new file mode 100644 index 0000000..745f201 --- /dev/null +++ b/models/hovernet0/net_desc.py @@ -0,0 +1,153 @@ +import math +from collections import OrderedDict + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .net_utils import (DenseBlock, Net, ResidualBlock, TFSamepaddingLayer, + UpSample2x) +from .utils import crop_op, crop_to_shape + +#### +class HoVerNet(Net): + """Initialise HoVer-Net.""" + + def __init__(self, input_ch=3, nr_types=None, freeze=False, mode='original'): + super().__init__() + self.mode = mode + self.freeze = freeze + self.nr_types = nr_types + self.output_ch = 3 if nr_types is None else 4 + + assert mode == 'original' or mode == 'fast', \ + 'Unknown mode `%s` for HoVerNet %s. Only support `original` or `fast`.' % mode + + module_list = [ + ("/", nn.Conv2d(input_ch, 64, 7, stride=1, padding=0, bias=False)), + ("bn", nn.BatchNorm2d(64, eps=1e-5)), + ("relu", nn.ReLU(inplace=True)), + ] + if mode == 'fast': # prepend the padding for `fast` mode + module_list = [("pad", TFSamepaddingLayer(ksize=7, stride=1))] + module_list + + self.conv0 = nn.Sequential(OrderedDict(module_list)) + self.d0 = ResidualBlock(64, [1, 3, 1], [64, 64, 256], 3, stride=1) + self.d1 = ResidualBlock(256, [1, 3, 1], [128, 128, 512], 4, stride=2) + self.d2 = ResidualBlock(512, [1, 3, 1], [256, 256, 1024], 6, stride=2) + self.d3 = ResidualBlock(1024, [1, 3, 1], [512, 512, 2048], 3, stride=2) + + self.conv_bot = nn.Conv2d(2048, 1024, 1, stride=1, padding=0, bias=False) + + def create_decoder_branch(out_ch=2, ksize=5): + module_list = [ + ("conva", nn.Conv2d(1024, 256, ksize, stride=1, padding=0, bias=False)), + ("dense", DenseBlock(256, [1, ksize], [128, 32], 8, split=4)), + ("convf", nn.Conv2d(512, 512, 1, stride=1, padding=0, bias=False),), + ] + u3 = nn.Sequential(OrderedDict(module_list)) + + module_list = [ + ("conva", nn.Conv2d(512, 128, ksize, stride=1, padding=0, bias=False)), + ("dense", DenseBlock(128, [1, ksize], [128, 32], 4, split=4)), + ("convf", nn.Conv2d(256, 256, 1, stride=1, padding=0, bias=False),), + ] + u2 = nn.Sequential(OrderedDict(module_list)) + + module_list = [ + ("conva/pad", TFSamepaddingLayer(ksize=ksize, stride=1)), + ("conva", nn.Conv2d(256, 64, ksize, stride=1, padding=0, bias=False),), + ] + u1 = nn.Sequential(OrderedDict(module_list)) + + module_list = [ + ("bn", nn.BatchNorm2d(64, eps=1e-5)), + ("relu", nn.ReLU(inplace=True)), + ("conv", nn.Conv2d(64, out_ch, 1, stride=1, padding=0, bias=True),), + ] + u0 = nn.Sequential(OrderedDict(module_list)) + + decoder = nn.Sequential( + OrderedDict([("u3", u3), ("u2", u2), ("u1", u1), ("u0", u0),]) + ) + return decoder + + ksize = 5 if mode == 'original' else 3 + if nr_types is None: + self.decoder = nn.ModuleDict( + OrderedDict( + [ + ("np", create_decoder_branch(ksize=ksize,out_ch=2)), + ("hv", create_decoder_branch(ksize=ksize,out_ch=2)), + ] + ) + ) + else: + self.decoder = nn.ModuleDict( + OrderedDict( + [ + ("tp", create_decoder_branch(ksize=ksize, out_ch=nr_types)), + ("np", create_decoder_branch(ksize=ksize, out_ch=2)), + ("hv", create_decoder_branch(ksize=ksize, out_ch=2)), + ] + ) + ) + + self.upsample2x = UpSample2x() + # TODO: pytorch still require the channel eventhough its ignored + self.weights_init() + + def forward(self, imgs): + + imgs = imgs / 255.0 # to 0-1 range to match XY + + if self.training: + d0 = self.conv0(imgs) + d0 = self.d0(d0, self.freeze) + with torch.set_grad_enabled(not self.freeze): + d1 = self.d1(d0) + d2 = self.d2(d1) + d3 = self.d3(d2) + d3 = self.conv_bot(d3) + d = [d0, d1, d2, d3] + else: + d0 = self.conv0(imgs) + d0 = self.d0(d0) + d1 = self.d1(d0) + d2 = self.d2(d1) + d3 = self.d3(d2) + d3 = self.conv_bot(d3) + d = [d0, d1, d2, d3] + + # TODO: switch to `crop_to_shape` ? + if self.mode == 'original': + d[0] = crop_op(d[0], [184, 184]) + d[1] = crop_op(d[1], [72, 72]) + else: + d[0] = crop_op(d[0], [92, 92]) + d[1] = crop_op(d[1], [36, 36]) + + out_dict = OrderedDict() + for branch_name, branch_desc in self.decoder.items(): + u3 = self.upsample2x(d[-1]) + d[-2] + u3 = branch_desc[0](u3) + + u2 = self.upsample2x(u3) + d[-3] + u2 = branch_desc[1](u2) + + u1 = self.upsample2x(u2) + d[-4] + u1 = branch_desc[2](u1) + + u0 = branch_desc[3](u1) + out_dict[branch_name] = u0 + + return out_dict + + +#### +def create_model(mode=None, **kwargs): + if mode not in ['original', 'fast']: + assert "Unknown Model Mode %s" % mode + return HoVerNet(mode=mode, **kwargs) + diff --git a/models/hovernet0/net_utils.py b/models/hovernet0/net_utils.py new file mode 100644 index 0000000..7f13624 --- /dev/null +++ b/models/hovernet0/net_utils.py @@ -0,0 +1,295 @@ +import numpy as np +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from collections import OrderedDict + +from .utils import crop_op, crop_to_shape +from config import Config + + +#### +class Net(nn.Module): + """ A base class provides a common weight initialisation scheme.""" + + def weights_init(self): + for m in self.modules(): + classname = m.__class__.__name__ + + # ! Fixed the type checking + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + + if "norm" in classname.lower(): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + if "linear" in classname.lower(): + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, x): + return x + + +#### +class TFSamepaddingLayer(nn.Module): + """To align with tf `same` padding. + + Putting this before any conv layer that need padding + Assuming kernel has Height == Width for simplicity + """ + + def __init__(self, ksize, stride): + super(TFSamepaddingLayer, self).__init__() + self.ksize = ksize + self.stride = stride + + def forward(self, x): + if x.shape[2] % self.stride == 0: + pad = max(self.ksize - self.stride, 0) + else: + pad = max(self.ksize - (x.shape[2] % self.stride), 0) + + if pad % 2 == 0: + pad_val = pad // 2 + padding = (pad_val, pad_val, pad_val, pad_val) + else: + pad_val_start = pad // 2 + pad_val_end = pad - pad_val_start + padding = (pad_val_start, pad_val_end, pad_val_start, pad_val_end) + # print(x.shape, padding) + x = F.pad(x, padding, "constant", 0) + # print(x.shape) + return x + + +#### +class DenseBlock(Net): + """Dense Block as defined in: + + Huang, Gao, Zhuang Liu, Laurens Van Der Maaten, and Kilian Q. Weinberger. + "Densely connected convolutional networks." In Proceedings of the IEEE conference + on computer vision and pattern recognition, pp. 4700-4708. 2017. + + Only performs `valid` convolution. + + """ + + def __init__(self, in_ch, unit_ksize, unit_ch, unit_count, split=1): + super(DenseBlock, self).__init__() + assert len(unit_ksize) == len(unit_ch), "Unbalance Unit Info" + + self.nr_unit = unit_count + self.in_ch = in_ch + self.unit_ch = unit_ch + + # ! For inference only so init values for batchnorm may not match tensorflow + unit_in_ch = in_ch + self.units = nn.ModuleList() + for idx in range(unit_count): + self.units.append( + nn.Sequential( + OrderedDict( + [ + ("preact_bna/bn", nn.BatchNorm2d(unit_in_ch, eps=1e-5)), + ("preact_bna/relu", nn.ReLU(inplace=True)), + ( + "conv1", + nn.Conv2d( + unit_in_ch, + unit_ch[0], + unit_ksize[0], + stride=1, + padding=0, + bias=False, + ), + ), + ("conv1/bn", nn.BatchNorm2d(unit_ch[0], eps=1e-5)), + ("conv1/relu", nn.ReLU(inplace=True)), + # ('conv2/pool', TFSamepaddingLayer(ksize=unit_ksize[1], stride=1)), + ( + "conv2", + nn.Conv2d( + unit_ch[0], + unit_ch[1], + unit_ksize[1], + groups=split, + stride=1, + padding=0, + bias=False, + ), + ), + ] + ) + ) + ) + unit_in_ch += unit_ch[1] + + self.blk_bna = nn.Sequential( + OrderedDict( + [ + ("bn", nn.BatchNorm2d(unit_in_ch, eps=1e-5)), + ("relu", nn.ReLU(inplace=True)), + ] + ) + ) + + def out_ch(self): + return self.in_ch + self.nr_unit * self.unit_ch[-1] + + def forward(self, prev_feat): + for idx in range(self.nr_unit): + new_feat = self.units[idx](prev_feat) + prev_feat = crop_to_shape(prev_feat, new_feat) + prev_feat = torch.cat([prev_feat, new_feat], dim=1) + prev_feat = self.blk_bna(prev_feat) + + return prev_feat + + +#### +class ResidualBlock(Net): + """Residual block as defined in: + + He, Kaiming, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. "Deep residual learning + for image recognition." In Proceedings of the IEEE conference on computer vision + and pattern recognition, pp. 770-778. 2016. + + """ + + def __init__(self, in_ch, unit_ksize, unit_ch, unit_count, stride=1): + super(ResidualBlock, self).__init__() + assert len(unit_ksize) == len(unit_ch), "Unbalance Unit Info" + + self.nr_unit = unit_count + self.in_ch = in_ch + self.unit_ch = unit_ch + + # ! For inference only so init values for batchnorm may not match tensorflow + unit_in_ch = in_ch + self.units = nn.ModuleList() + for idx in range(unit_count): + unit_layer = [ + ("preact/bn", nn.BatchNorm2d(unit_in_ch, eps=1e-5)), + ("preact/relu", nn.ReLU(inplace=True)), + ( + "conv1", + nn.Conv2d( + unit_in_ch, + unit_ch[0], + unit_ksize[0], + stride=1, + padding=0, + bias=False, + ), + ), + ("conv1/bn", nn.BatchNorm2d(unit_ch[0], eps=1e-5)), + ("conv1/relu", nn.ReLU(inplace=True)), + ( + "conv2/pad", + TFSamepaddingLayer( + ksize=unit_ksize[1], stride=stride if idx == 0 else 1 + ), + ), + ( + "conv2", + nn.Conv2d( + unit_ch[0], + unit_ch[1], + unit_ksize[1], + stride=stride if idx == 0 else 1, + padding=0, + bias=False, + ), + ), + ("conv2/bn", nn.BatchNorm2d(unit_ch[1], eps=1e-5)), + ("conv2/relu", nn.ReLU(inplace=True)), + ( + "conv3", + nn.Conv2d( + unit_ch[1], + unit_ch[2], + unit_ksize[2], + stride=1, + padding=0, + bias=False, + ), + ), + ] + # * has bna to conclude each previous block so + # * must not put preact for the first unit of this block + unit_layer = unit_layer if idx != 0 else unit_layer[2:] + self.units.append(nn.Sequential(OrderedDict(unit_layer))) + unit_in_ch = unit_ch[-1] + + if in_ch != unit_ch[-1] or stride != 1: + self.shortcut = nn.Conv2d(in_ch, unit_ch[-1], 1, stride=stride, bias=False) + else: + self.shortcut = None + + self.blk_bna = nn.Sequential( + OrderedDict( + [ + ("bn", nn.BatchNorm2d(unit_in_ch, eps=1e-5)), + ("relu", nn.ReLU(inplace=True)), + ] + ) + ) + + # print(self.units[0]) + # print(self.units[1]) + # exit() + + def out_ch(self): + return self.unit_ch[-1] + + def forward(self, prev_feat, freeze=False): + if self.shortcut is None: + shortcut = prev_feat + else: + shortcut = self.shortcut(prev_feat) + + for idx in range(0, len(self.units)): + new_feat = prev_feat + if self.training: + with torch.set_grad_enabled(not freeze): + new_feat = self.units[idx](new_feat) + else: + new_feat = self.units[idx](new_feat) + prev_feat = new_feat + shortcut + shortcut = prev_feat + feat = self.blk_bna(prev_feat) + return feat + + +#### +class UpSample2x(nn.Module): + """Upsample input by a factor of 2. + + Assume input is of NCHW, port FixedUnpooling from TensorPack. + """ + + def __init__(self): + super(UpSample2x, self).__init__() + # correct way to create constant within module + self.register_buffer( + "unpool_mat", torch.from_numpy(np.ones((2, 2), dtype="float32")) + ) + self.unpool_mat.unsqueeze(0) + + def forward(self, x): + input_shape = list(x.shape) + # unsqueeze is expand_dims equivalent + # permute is transpose equivalent + # view is reshape equivalent + x = x.unsqueeze(-1) # bchwx1 + mat = self.unpool_mat.unsqueeze(0) # 1xshxsw + ret = torch.tensordot(x, mat, dims=1) # bxcxhxwxshxsw + ret = ret.permute(0, 1, 2, 4, 3, 5) + ret = ret.reshape((-1, input_shape[1], input_shape[2] * 2, input_shape[3] * 2)) + return ret + diff --git a/models/hovernet0/opt.py b/models/hovernet0/opt.py new file mode 100644 index 0000000..9ed56d7 --- /dev/null +++ b/models/hovernet0/opt.py @@ -0,0 +1,142 @@ +import torch.optim as optim + +from run_utils.callbacks.base import ( + AccumulateRawOutput, + PeriodicSaver, + ProcessAccumulatedRawOutput, + ScalarMovingAverage, + ScheduleLr, + TrackLr, + VisualizeOutput, + TriggerEngine, +) +from run_utils.callbacks.logging import LoggingEpochOutput, LoggingGradient +from run_utils.engine import Events + +from .targets import gen_targets, prep_sample +from .net_desc import create_model +from .run_desc import proc_valid_step_output, train_step, valid_step, viz_step_output + + +# TODO: training config only ? +# TODO: switch all to function name String for all option +def get_config(nr_type, mode): + return { + # ------------------------------------------------------------------ + # ! All phases have the same number of run engine + # phases are run sequentially from index 0 to N + "phase_list": [ + { + "run_info": { + # may need more dynamic for each network + "net": { + "desc": lambda: create_model( + input_ch=3, nr_types=nr_type, + freeze=True, mode=mode + ), + "optimizer": [ + optim.Adam, + { # should match keyword for parameters within the optimizer + "lr": 1.0e-4, # initial learning rate, + "betas": (0.9, 0.999), + }, + ], + # learning rate scheduler + "lr_scheduler": lambda x: optim.lr_scheduler.StepLR(x, 25), + "extra_info": { + "loss": { + "np": {"bce": 1, "dice": 1}, + "hv": {"mse": 1, "msge": 1}, + "tp": {"bce": 1, "dice": 1}, + }, + }, + # path to load, -1 to auto load checkpoint from previous phase, + # None to start from scratch + "pretrained": "../pretrained/ImageNet-ResNet50-Preact_pytorch.tar", + # 'pretrained': None, + }, + }, + "target_info": {"gen": (gen_targets, {}), "viz": (prep_sample, {})}, + "batch_size": {"train": 16, "valid": 16,}, # engine name : value + "nr_epochs": 50, + }, + { + "run_info": { + # may need more dynamic for each network + "net": { + "desc": lambda: create_model( + input_ch=3, nr_types=nr_type, + freeze=False, mode=mode + ), + "optimizer": [ + optim.Adam, + { # should match keyword for parameters within the optimizer + "lr": 1.0e-4, # initial learning rate, + "betas": (0.9, 0.999), + }, + ], + # learning rate scheduler + "lr_scheduler": lambda x: optim.lr_scheduler.StepLR(x, 25), + "extra_info": { + "loss": { + "np": {"bce": 1, "dice": 1}, + "hv": {"mse": 1, "msge": 1}, + "tp": {"bce": 1, "dice": 1}, + }, + }, + # path to load, -1 to auto load checkpoint from previous phase, + # None to start from scratch + "pretrained": -1, + }, + }, + "target_info": {"gen": (gen_targets, {}), "viz": (prep_sample, {})}, + "batch_size": {"train": 4, "valid": 8,}, # batch size per gpu + "nr_epochs": 50, + }, + ], + # ------------------------------------------------------------------ + # TODO: dynamically for dataset plugin selection and processing also? + # all enclosed engine shares the same neural networks + # as the on at the outer calling it + "run_engine": { + "train": { + # TODO: align here, file path or what? what about CV? + "dataset": "", # whats about compound dataset ? + "nr_procs": 16, # number of threads for dataloader + "run_step": train_step, # TODO: function name or function variable ? + "reset_per_run": False, + # callbacks are run according to the list order of the event + "callbacks": { + Events.STEP_COMPLETED: [ + # LoggingGradient(), # TODO: very slow, may be due to back forth of tensor/numpy ? + ScalarMovingAverage(), + ], + Events.EPOCH_COMPLETED: [ + TrackLr(), + PeriodicSaver(), + VisualizeOutput(viz_step_output), + LoggingEpochOutput(), + TriggerEngine("valid"), + ScheduleLr(), + ], + }, + }, + "valid": { + "dataset": "", # whats about compound dataset ? + "nr_procs": 8, # number of threads for dataloader + "run_step": valid_step, + "reset_per_run": True, # * to stop aggregating output etc. from last run + # callbacks are run according to the list order of the event + "callbacks": { + Events.STEP_COMPLETED: [AccumulateRawOutput(),], + Events.EPOCH_COMPLETED: [ + # TODO: is there way to preload these ? + ProcessAccumulatedRawOutput( + lambda a: proc_valid_step_output(a, nr_types=nr_type) + ), + LoggingEpochOutput(), + ], + }, + }, + }, + } diff --git a/models/hovernet0/post_proc.py b/models/hovernet0/post_proc.py new file mode 100644 index 0000000..4fe1ffb --- /dev/null +++ b/models/hovernet0/post_proc.py @@ -0,0 +1,186 @@ +import cv2 +import numpy as np + +from scipy.ndimage import filters, measurements +from scipy.ndimage.morphology import ( + binary_dilation, + binary_fill_holes, + distance_transform_cdt, + distance_transform_edt, +) + +from skimage.segmentation import watershed +from misc.utils import get_bounding_box, remove_small_objects + +import warnings + + +def noop(*args, **kargs): + pass + + +warnings.warn = noop + + +#### +def __proc_np_hv(pred): + """Process Nuclei Prediction with XY Coordinate Map. + + Args: + pred: prediction output, assuming + channel 0 contain probability map of nuclei + channel 1 containing the regressed X-map + channel 2 containing the regressed Y-map + + """ + pred = np.array(pred, dtype=np.float32) + + blb_raw = pred[..., 0] + h_dir_raw = pred[..., 1] + v_dir_raw = pred[..., 2] + + # processing + blb = np.array(blb_raw >= 0.5, dtype=np.int32) + + blb = measurements.label(blb)[0] + blb = remove_small_objects(blb, min_size=10) + blb[blb > 0] = 1 # background is 0 already + + h_dir = cv2.normalize( + h_dir_raw, None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F + ) + v_dir = cv2.normalize( + v_dir_raw, None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F + ) + + sobelh = cv2.Sobel(h_dir, cv2.CV_64F, 1, 0, ksize=21) + sobelv = cv2.Sobel(v_dir, cv2.CV_64F, 0, 1, ksize=21) + + sobelh = 1 - ( + cv2.normalize( + sobelh, None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F + ) + ) + sobelv = 1 - ( + cv2.normalize( + sobelv, None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F + ) + ) + + overall = np.maximum(sobelh, sobelv) + overall = overall - (1 - blb) + overall[overall < 0] = 0 + + dist = (1.0 - overall) * blb + ## nuclei values form mountains so inverse to get basins + dist = -cv2.GaussianBlur(dist, (3, 3), 0) + + overall = np.array(overall >= 0.4, dtype=np.int32) + + marker = blb - overall + marker[marker < 0] = 0 + marker = binary_fill_holes(marker).astype("uint8") + kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5)) + marker = cv2.morphologyEx(marker, cv2.MORPH_OPEN, kernel) + marker = measurements.label(marker)[0] + marker = remove_small_objects(marker, min_size=10) + + proced_pred = watershed(dist, markers=marker, mask=blb) + + return proced_pred + + +#### +def process(pred_map, nr_types=None, return_centroids=False): + """Post processing script for image tiles. + + Args: + pred_map: commbined output of tp, np and hv branches, in the same order + nr_types: number of types considered at output of nc branch + overlaid_img: img to overlay the predicted instances upon, `None` means no + type_colour (dict) : `None` to use random, else overlay instances of a type to colour in the dict + output_dtype: data type of output + + Returns: + pred_inst: pixel-wise nuclear instance segmentation prediction + pred_type_out: pixel-wise nuclear type prediction + + """ + if nr_types is not None: + pred_type = pred_map[..., :1] + pred_inst = pred_map[..., 1:] + pred_type = pred_type.astype(np.int32) + else: + pred_inst = pred_map + + pred_inst = np.squeeze(pred_inst) + pred_inst = __proc_np_hv(pred_inst) + + inst_info_dict = None + if return_centroids or nr_types is not None: + inst_id_list = np.unique(pred_inst)[1:] # exlcude background + inst_info_dict = {} + for inst_id in inst_id_list: + inst_map = pred_inst == inst_id + # TODO: chane format of bbox output + rmin, rmax, cmin, cmax = get_bounding_box(inst_map) + inst_bbox = np.array([[rmin, cmin], [rmax, cmax]]) + inst_map = inst_map[ + inst_bbox[0][0] : inst_bbox[1][0], inst_bbox[0][1] : inst_bbox[1][1] + ] + inst_map = inst_map.astype(np.uint8) + inst_moment = cv2.moments(inst_map) + inst_contour = cv2.findContours( + inst_map, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE + ) + # * opencv protocol format may break + inst_contour = np.squeeze(inst_contour[0][0].astype("int32")) + # < 3 points dont make a contour, so skip, likely artifact too + # as the contours obtained via approximation => too small or sthg + if inst_contour.shape[0] < 3: + continue + if len(inst_contour.shape) != 2: + continue # ! check for trickery shape + inst_centroid = [ + (inst_moment["m10"] / inst_moment["m00"]), + (inst_moment["m01"] / inst_moment["m00"]), + ] + inst_centroid = np.array(inst_centroid) + inst_contour[:, 0] += inst_bbox[0][1] # X + inst_contour[:, 1] += inst_bbox[0][0] # Y + inst_centroid[0] += inst_bbox[0][1] # X + inst_centroid[1] += inst_bbox[0][0] # Y + inst_info_dict[inst_id] = { # inst_id should start at 1 + "bbox": inst_bbox, + "centroid": inst_centroid, + "contour": inst_contour, + "type_prob": None, + "type": None, + } + + if nr_types is not None: + #### * Get class of each instance id, stored at index id-1 + for inst_id in list(inst_info_dict.keys()): + rmin, cmin, rmax, cmax = (inst_info_dict[inst_id]["bbox"]).flatten() + inst_map_crop = pred_inst[rmin:rmax, cmin:cmax] + inst_type_crop = pred_type[rmin:rmax, cmin:cmax] + inst_map_crop = ( + inst_map_crop == inst_id + ) # TODO: duplicated operation, may be expensive + inst_type = inst_type_crop[inst_map_crop] + type_list, type_pixels = np.unique(inst_type, return_counts=True) + type_list = list(zip(type_list, type_pixels)) + type_list = sorted(type_list, key=lambda x: x[1], reverse=True) + inst_type = type_list[0][0] + if inst_type == 0: # ! pick the 2nd most dominant if exist + if len(type_list) > 1: + inst_type = type_list[1][0] + type_dict = {v[0]: v[1] for v in type_list} + type_prob = type_dict[inst_type] / (np.sum(inst_map_crop) + 1.0e-6) + inst_info_dict[inst_id]["type"] = int(inst_type) + inst_info_dict[inst_id]["type_prob"] = float(type_prob) + + # print('here') + # ! WARNING: ID MAY NOT BE CONTIGUOUS + # inst_id in the dict maps to the same value in the `pred_inst` + return pred_inst, inst_info_dict diff --git a/models/hovernet0/run_desc.py b/models/hovernet0/run_desc.py new file mode 100644 index 0000000..026873c --- /dev/null +++ b/models/hovernet0/run_desc.py @@ -0,0 +1,344 @@ +import numpy as np +import matplotlib.pyplot as plt +import torch +import torch.nn.functional as F + +from misc.utils import center_pad_to_shape, cropping_center +from .utils import crop_to_shape, dice_loss, mse_loss, msge_loss, xentropy_loss + +from collections import OrderedDict + +#### +def train_step(batch_data, run_info): + # TODO: synchronize the attach protocol + run_info, state_info = run_info + loss_func_dict = { + "bce": xentropy_loss, + "dice": dice_loss, + "mse": mse_loss, + "msge": msge_loss, + } + # use 'ema' to add for EMA calculation, must be scalar! + result_dict = {"EMA": {}} + track_value = lambda name, value: result_dict["EMA"].update({name: value}) + + #### + model = run_info["net"]["desc"] + optimizer = run_info["net"]["optimizer"] + + #### + imgs = batch_data["img"] + true_np = batch_data["np_map"] + true_hv = batch_data["hv_map"] + + imgs = imgs.to("cuda").type(torch.float32) # to NCHW + imgs = imgs.permute(0, 3, 1, 2).contiguous() + + # HWC + true_np = true_np.to("cuda").type(torch.int64) + true_hv = true_hv.to("cuda").type(torch.float32) + + true_np_onehot = (F.one_hot(true_np, num_classes=2)).type(torch.float32) + true_dict = { + "np": true_np_onehot, + "hv": true_hv, + } + + if model.module.nr_types is not None: + true_tp = batch_data["tp_map"] + true_tp = torch.squeeze(true_tp).to("cuda").type(torch.int64) + true_tp_onehot = F.one_hot(true_tp, num_classes=model.module.nr_types) + true_tp_onehot = true_tp_onehot.type(torch.float32) + true_dict["tp"] = true_tp_onehot + + #### + model.train() + model.zero_grad() # not rnn so not accumulate + + pred_dict = model(imgs) + pred_dict = OrderedDict( + [[k, v.permute(0, 2, 3, 1).contiguous()] for k, v in pred_dict.items()] + ) + pred_dict["np"] = F.softmax(pred_dict["np"], dim=-1) + if model.module.nr_types is not None: + pred_dict["tp"] = F.softmax(pred_dict["tp"], dim=-1) + + #### + loss = 0 + loss_opts = run_info["net"]["extra_info"]["loss"] + for branch_name in pred_dict.keys(): + for loss_name, loss_weight in loss_opts[branch_name].items(): + loss_func = loss_func_dict[loss_name] + loss_args = [true_dict[branch_name], pred_dict[branch_name]] + if loss_name == "msge": + loss_args.append(true_np_onehot[..., 1]) + term_loss = loss_func(*loss_args) + track_value("loss_%s_%s" % (branch_name, loss_name), term_loss.cpu().item()) + loss += loss_weight * term_loss + + track_value("overall_loss", loss.cpu().item()) + # * gradient update + + # torch.set_printoptions(precision=10) + loss.backward() + optimizer.step() + #### + + # pick 2 random sample from the batch for visualization + sample_indices = torch.randint(0, true_np.shape[0], (2,)) + + imgs = (imgs[sample_indices]).byte() # to uint8 + imgs = imgs.permute(0, 2, 3, 1).contiguous().cpu().numpy() + + pred_dict["np"] = pred_dict["np"][..., 1] # return pos only + pred_dict = { + k: v[sample_indices].detach().cpu().numpy() for k, v in pred_dict.items() + } + + true_dict["np"] = true_np + true_dict = { + k: v[sample_indices].detach().cpu().numpy() for k, v in true_dict.items() + } + + # * Its up to user to define the protocol to process the raw output per step! + result_dict["raw"] = { # protocol for contents exchange within `raw` + "img": imgs, + "np": (true_dict["np"], pred_dict["np"]), + "hv": (true_dict["hv"], pred_dict["hv"]), + } + return result_dict + + +#### +def valid_step(batch_data, run_info): + run_info, state_info = run_info + #### + model = run_info["net"]["desc"] + model.eval() # infer mode + + #### + imgs = batch_data["img"] + true_np = batch_data["np_map"] + true_hv = batch_data["hv_map"] + + imgs_gpu = imgs.to("cuda").type(torch.float32) # to NCHW + imgs_gpu = imgs_gpu.permute(0, 3, 1, 2).contiguous() + + # HWC + true_np = torch.squeeze(true_np).type(torch.int64) + true_hv = torch.squeeze(true_hv).type(torch.float32) + + true_dict = { + "np": true_np, + "hv": true_hv, + } + + if model.module.nr_types is not None: + true_tp = batch_data["tp_map"] + true_tp = torch.squeeze(true_tp).type(torch.int64) + true_dict["tp"] = true_tp + + # -------------------------------------------------------------- + with torch.no_grad(): # dont compute gradient + pred_dict = model(imgs_gpu) + pred_dict = OrderedDict( + [[k, v.permute(0, 2, 3, 1).contiguous()] for k, v in pred_dict.items()] + ) + pred_dict["np"] = F.softmax(pred_dict["np"], dim=-1)[..., 1] + if model.module.nr_types is not None: + type_map = F.softmax(pred_dict["tp"], dim=-1) + type_map = torch.argmax(type_map, dim=-1, keepdim=False) + type_map = type_map.type(torch.float32) + pred_dict["tp"] = type_map + + # * Its up to user to define the protocol to process the raw output per step! + result_dict = { # protocol for contents exchange within `raw` + "raw": { + "imgs": imgs.numpy(), + "true_np": true_dict["np"].numpy(), + "true_hv": true_dict["hv"].numpy(), + "prob_np": pred_dict["np"].cpu().numpy(), + "pred_hv": pred_dict["hv"].cpu().numpy(), + } + } + if model.module.nr_types is not None: + result_dict["raw"]["true_tp"] = true_dict["tp"].numpy() + result_dict["raw"]["pred_tp"] = pred_dict["tp"].cpu().numpy() + return result_dict + + +#### +def infer_step(batch_data, model): + + #### + patch_imgs = batch_data + + patch_imgs_gpu = patch_imgs.to("cuda").type(torch.float32) # to NCHW + patch_imgs_gpu = patch_imgs_gpu.permute(0, 3, 1, 2).contiguous() + + #### + model.eval() # infer mode + + # -------------------------------------------------------------- + with torch.no_grad(): # dont compute gradient + pred_dict = model(patch_imgs_gpu) + pred_dict = OrderedDict( + [[k, v.permute(0, 2, 3, 1).contiguous()] for k, v in pred_dict.items()] + ) + pred_dict["np"] = F.softmax(pred_dict["np"], dim=-1)[..., 1:] + if "tp" in pred_dict: + type_map = F.softmax(pred_dict["tp"], dim=-1) + type_map = torch.argmax(type_map, dim=-1, keepdim=True) + type_map = type_map.type(torch.float32) + pred_dict["tp"] = type_map + pred_output = torch.cat(list(pred_dict.values()), -1) + + # * Its up to user to define the protocol to process the raw output per step! + return pred_output.cpu().numpy() + + +#### +def viz_step_output(raw_data, nr_types=None): + """ + `raw_data` will be implicitly provided in the similar format as the + return dict from train/valid step, but may have been accumulated across N running step + """ + + imgs = raw_data["img"] + true_np, pred_np = raw_data["np"] + true_hv, pred_hv = raw_data["hv"] + if nr_types is not None: + true_tp, pred_tp = raw_data["tp"] + + aligned_shape = [list(imgs.shape), list(true_np.shape), list(pred_np.shape)] + aligned_shape = np.min(np.array(aligned_shape), axis=0)[1:3] + + cmap = plt.get_cmap("jet") + + def colorize(ch, vmin, vmax): + """ + Will clamp value value outside the provided range to vmax and vmin + """ + ch = np.squeeze(ch.astype("float32")) + ch[ch > vmax] = vmax # clamp value + ch[ch < vmin] = vmin + ch = (ch - vmin) / (vmax - vmin + 1.0e-16) + # take RGB from RGBA heat map + ch_cmap = (cmap(ch)[..., :3] * 255).astype("uint8") + # ch_cmap = center_pad_to_shape(ch_cmap, aligned_shape) + return ch_cmap + + viz_list = [] + for idx in range(imgs.shape[0]): + # img = center_pad_to_shape(imgs[idx], aligned_shape) + img = cropping_center(imgs[idx], aligned_shape) + + true_viz_list = [img] + # cmap may randomly fails if of other types + true_viz_list.append(colorize(true_np[idx], 0, 1)) + true_viz_list.append(colorize(true_hv[idx][..., 0], -1, 1)) + true_viz_list.append(colorize(true_hv[idx][..., 1], -1, 1)) + if nr_types is not None: # TODO: a way to pass through external info + true_viz_list.append(colorize(true_tp[idx], 0, nr_types)) + true_viz_list = np.concatenate(true_viz_list, axis=1) + + pred_viz_list = [img] + # cmap may randomly fails if of other types + pred_viz_list.append(colorize(pred_np[idx], 0, 1)) + pred_viz_list.append(colorize(pred_hv[idx][..., 0], -1, 1)) + pred_viz_list.append(colorize(pred_hv[idx][..., 1], -1, 1)) + if nr_types is not None: + pred_viz_list.append(colorize(pred_tp[idx], 0, nr_types)) + pred_viz_list = np.concatenate(pred_viz_list, axis=1) + + viz_list.append(np.concatenate([true_viz_list, pred_viz_list], axis=0)) + viz_list = np.concatenate(viz_list, axis=0) + return viz_list + + +#### +from itertools import chain + + +def proc_valid_step_output(raw_data, nr_types=None): + # TODO: add auto populate from main state track list + track_dict = {"scalar": {}, "image": {}} + + def track_value(name, value, vtype): + return track_dict[vtype].update({name: value}) + + def _dice_info(true, pred, label): + true = np.array(true == label, np.int32) + pred = np.array(pred == label, np.int32) + inter = (pred * true).sum() + total = (pred + true).sum() + return inter, total + + over_inter = 0 + over_total = 0 + over_correct = 0 + prob_np = raw_data["prob_np"] + true_np = raw_data["true_np"] + for idx in range(len(raw_data["true_np"])): + patch_prob_np = prob_np[idx] + patch_true_np = true_np[idx] + patch_pred_np = np.array(patch_prob_np > 0.5, dtype=np.int32) + inter, total = _dice_info(patch_true_np, patch_pred_np, 1) + correct = (patch_pred_np == patch_true_np).sum() + over_inter += inter + over_total += total + over_correct += correct + nr_pixels = len(true_np) * np.size(true_np[0]) + acc_np = over_correct / nr_pixels + dice_np = 2 * over_inter / (over_total + 1.0e-8) + track_value("np_acc", acc_np, "scalar") + track_value("np_dice", dice_np, "scalar") + + # * TP statistic + if nr_types is not None: + pred_tp = raw_data["pred_tp"] + true_tp = raw_data["true_tp"] + for type_id in range(0, nr_types): + over_inter = 0 + over_total = 0 + for idx in range(len(raw_data["true_np"])): + patch_pred_tp = pred_tp[idx] + patch_true_tp = true_tp[idx] + inter, total = _dice_info(patch_true_tp, patch_pred_tp, type_id) + over_inter += inter + over_total += total + dice_tp = 2 * over_inter / (over_total + 1.0e-8) + track_value("tp_dice_%d" % type_id, dice_tp, "scalar") + + # * HV regression statistic + pred_hv = raw_data["pred_hv"] + true_hv = raw_data["true_hv"] + + over_squared_error = 0 + for idx in range(len(raw_data["true_np"])): + patch_pred_hv = pred_hv[idx] + patch_true_hv = true_hv[idx] + squared_error = patch_pred_hv - patch_true_hv + squared_error = squared_error * squared_error + over_squared_error += squared_error.sum() + mse = over_squared_error / nr_pixels + track_value("hv_mse", mse, "scalar") + + # * + imgs = raw_data["imgs"] + selected_idx = np.random.randint(0, len(imgs), size=(8,)).tolist() + imgs = np.array([imgs[idx] for idx in selected_idx]) + true_np = np.array([true_np[idx] for idx in selected_idx]) + true_hv = np.array([true_hv[idx] for idx in selected_idx]) + prob_np = np.array([prob_np[idx] for idx in selected_idx]) + pred_hv = np.array([pred_hv[idx] for idx in selected_idx]) + viz_raw_data = {"img": imgs, "np": (true_np, prob_np), "hv": (true_hv, pred_hv)} + + if nr_types is not None: + true_tp = np.array([true_tp[idx] for idx in selected_idx]) + pred_tp = np.array([pred_tp[idx] for idx in selected_idx]) + viz_raw_data["tp"] = (true_tp, pred_tp) + viz_fig = viz_step_output(viz_raw_data, nr_types) + track_dict["image"]["output"] = viz_fig + + return track_dict diff --git a/models/hovernet0/targets.py b/models/hovernet0/targets.py new file mode 100644 index 0000000..d9466d9 --- /dev/null +++ b/models/hovernet0/targets.py @@ -0,0 +1,153 @@ +import math +import numpy as np + +import torch +import torch.nn.functional as F + +from scipy import ndimage +from scipy.ndimage import measurements +from skimage import morphology as morph +import matplotlib.pyplot as plt + +from misc.utils import center_pad_to_shape, cropping_center, get_bounding_box +from dataloader.augs import fix_mirror_padding + + +#### +def gen_instance_hv_map(ann, crop_shape): + """Input annotation must be of original shape. + + The map is calculated only for instances within the crop portion + but based on the original shape in original image. + + Perform following operation: + Obtain the horizontal and vertical distance maps for each + nuclear instance. + + """ + orig_ann = ann.copy() # instance ID map + fixed_ann = fix_mirror_padding(orig_ann) + # re-cropping with fixed instance id map + crop_ann = cropping_center(fixed_ann, crop_shape) + # TODO: deal with 1 label warning + crop_ann = morph.remove_small_objects(crop_ann, min_size=30) + + x_map = np.zeros(orig_ann.shape[:2], dtype=np.float32) + y_map = np.zeros(orig_ann.shape[:2], dtype=np.float32) + + inst_list = list(np.unique(crop_ann)) + inst_list.remove(0) # 0 is background + for inst_id in inst_list: + inst_map = np.array(fixed_ann == inst_id, np.uint8) + inst_box = get_bounding_box(inst_map) + + # expand the box by 2px + # Because we first pad the ann at line 207, the bboxes + # will remain valid after expansion + inst_box[0] -= 2 + inst_box[2] -= 2 + inst_box[1] += 2 + inst_box[3] += 2 + + inst_map = inst_map[inst_box[0] : inst_box[1], inst_box[2] : inst_box[3]] + + if inst_map.shape[0] < 2 or inst_map.shape[1] < 2: + continue + + # instance center of mass, rounded to nearest pixel + inst_com = list(measurements.center_of_mass(inst_map)) + + inst_com[0] = int(inst_com[0] + 0.5) + inst_com[1] = int(inst_com[1] + 0.5) + + inst_x_range = np.arange(1, inst_map.shape[1] + 1) + inst_y_range = np.arange(1, inst_map.shape[0] + 1) + # shifting center of pixels grid to instance center of mass + inst_x_range -= inst_com[1] + inst_y_range -= inst_com[0] + + inst_x, inst_y = np.meshgrid(inst_x_range, inst_y_range) + + # remove coord outside of instance + inst_x[inst_map == 0] = 0 + inst_y[inst_map == 0] = 0 + inst_x = inst_x.astype("float32") + inst_y = inst_y.astype("float32") + + # normalize min into -1 scale + if np.min(inst_x) < 0: + inst_x[inst_x < 0] /= -np.amin(inst_x[inst_x < 0]) + if np.min(inst_y) < 0: + inst_y[inst_y < 0] /= -np.amin(inst_y[inst_y < 0]) + # normalize max into +1 scale + if np.max(inst_x) > 0: + inst_x[inst_x > 0] /= np.amax(inst_x[inst_x > 0]) + if np.max(inst_y) > 0: + inst_y[inst_y > 0] /= np.amax(inst_y[inst_y > 0]) + + #### + x_map_box = x_map[inst_box[0] : inst_box[1], inst_box[2] : inst_box[3]] + x_map_box[inst_map > 0] = inst_x[inst_map > 0] + + y_map_box = y_map[inst_box[0] : inst_box[1], inst_box[2] : inst_box[3]] + y_map_box[inst_map > 0] = inst_y[inst_map > 0] + + hv_map = np.dstack([x_map, y_map]) + return hv_map + + +#### +def gen_targets(ann, crop_shape, **kwargs): + """Generate the targets for the network.""" + hv_map = gen_instance_hv_map(ann, crop_shape) + np_map = ann.copy() + np_map[np_map > 0] = 1 + + hv_map = cropping_center(hv_map, crop_shape) + np_map = cropping_center(np_map, crop_shape) + + target_dict = { + "hv_map": hv_map, + "np_map": np_map, + } + + return target_dict + + +#### +def prep_sample(data, is_batch=False, **kwargs): + """ + Designed to process direct output from loader + """ + cmap = plt.get_cmap("jet") + + def colorize(ch, vmin, vmax, shape): + ch = np.squeeze(ch.astype("float32")) + ch = ch / (vmax - vmin + 1.0e-16) + # take RGB from RGBA heat map + ch_cmap = (cmap(ch)[..., :3] * 255).astype("uint8") + ch_cmap = center_pad_to_shape(ch_cmap, shape) + return ch_cmap + + def prep_one_sample(data): + shape_array = [np.array(v.shape[:2]) for v in data.values()] + shape = np.maximum(*shape_array) + viz_list = [] + viz_list.append(colorize(data["np_map"], 0, 1, shape)) + # map to [0,2] for better visualisation. + # Note, [-1,1] is used for training. + viz_list.append(colorize(data["hv_map"][..., 0] + 1, 0, 2, shape)) + viz_list.append(colorize(data["hv_map"][..., 1] + 1, 0, 2, shape)) + img = center_pad_to_shape(data["img"], shape) + return np.concatenate([img] + viz_list, axis=1) + + # cmap may randomly fails if of other types + if is_batch: + viz_list = [] + data_shape = list(data.values())[0].shape + for batch_idx in range(data_shape[0]): + sub_data = {k : v[batch_idx] for k, v in data.items()} + viz_list.append(prep_one_sample(sub_data)) + return np.concatenate(viz_list, axis=0) + else: + return prep_one_sample(data) diff --git a/models/hovernet0/utils.py b/models/hovernet0/utils.py new file mode 100644 index 0000000..3842287 --- /dev/null +++ b/models/hovernet0/utils.py @@ -0,0 +1,172 @@ +import math +import numpy as np + +import torch +import torch.nn.functional as F + +from matplotlib import cm + + +#### +def crop_op(x, cropping, data_format="NCHW"): + """Center crop image. + + Args: + x: input image + cropping: the substracted amount + data_format: choose either `NCHW` or `NHWC` + + """ + crop_t = cropping[0] // 2 + crop_b = cropping[0] - crop_t + crop_l = cropping[1] // 2 + crop_r = cropping[1] - crop_l + if data_format == "NCHW": + x = x[:, :, crop_t:-crop_b, crop_l:-crop_r] + else: + x = x[:, crop_t:-crop_b, crop_l:-crop_r, :] + return x + + +#### +def crop_to_shape(x, y, data_format="NCHW"): + """Centre crop x so that x has shape of y. y dims must be smaller than x dims. + + Args: + x: input array + y: array with desired shape. + + """ + assert ( + y.shape[0] <= x.shape[0] and y.shape[1] <= x.shape[1] + ), "Ensure that y dimensions are smaller than x dimensions!" + + x_shape = x.size() + y_shape = y.size() + if data_format == "NCHW": + crop_shape = (x_shape[2] - y_shape[2], x_shape[3] - y_shape[3]) + else: + crop_shape = (x_shape[1] - y_shape[1], x_shape[2] - y_shape[2]) + return crop_op(x, crop_shape, data_format) + + +#### +def xentropy_loss(true, pred, reduction="mean"): + """Cross entropy loss. Assumes NHWC! + + Args: + pred: prediction array + true: ground truth array + + Returns: + cross entropy loss + + """ + epsilon = 10e-8 + # scale preds so that the class probs of each sample sum to 1 + pred = pred / torch.sum(pred, -1, keepdim=True) + # manual computation of crossentropy + pred = torch.clamp(pred, epsilon, 1.0 - epsilon) + loss = -torch.sum((true * torch.log(pred)), -1, keepdim=True) + loss = loss.mean() if reduction == "mean" else loss.sum() + return loss + + +#### +def dice_loss(true, pred, smooth=1e-3): + """`pred` and `true` must be of torch.float32. Assuming of shape NxHxWxC.""" + inse = torch.sum(pred * true, (0, 1, 2)) + l = torch.sum(pred, (0, 1, 2)) + r = torch.sum(true, (0, 1, 2)) + loss = 1.0 - (2.0 * inse + smooth) / (l + r + smooth) + loss = torch.sum(loss) + return loss + + +#### +def mse_loss(true, pred): + """Calculate mean squared error loss. + + Args: + true: ground truth of combined horizontal + and vertical maps + pred: prediction of combined horizontal + and vertical maps + + Returns: + loss: mean squared error + + """ + loss = pred - true + loss = (loss * loss).mean() + return loss + + +#### +def msge_loss(true, pred, focus): + """Calculate the mean squared error of the gradients of + horizontal and vertical map predictions. Assumes + channel 0 is Vertical and channel 1 is Horizontal. + + Args: + true: ground truth of combined horizontal + and vertical maps + pred: prediction of combined horizontal + and vertical maps + focus: area where to apply loss (we only calculate + the loss within the nuclei) + + Returns: + loss: mean squared error of gradients + + """ + + def get_sobel_kernel(size): + """Get sobel kernel with a given size.""" + assert size % 2 == 1, "Must be odd, get size=%d" % size + + h_range = torch.arange( + -size // 2 + 1, + size // 2 + 1, + dtype=torch.float32, + device="cuda", + requires_grad=False, + ) + v_range = torch.arange( + -size // 2 + 1, + size // 2 + 1, + dtype=torch.float32, + device="cuda", + requires_grad=False, + ) + h, v = torch.meshgrid(h_range, v_range) + kernel_h = h / (h * h + v * v + 1.0e-15) + kernel_v = v / (h * h + v * v + 1.0e-15) + return kernel_h, kernel_v + + #### + def get_gradient_hv(hv): + """For calculating gradient.""" + kernel_h, kernel_v = get_sobel_kernel(5) + kernel_h = kernel_h.view(1, 1, 5, 5) # constant + kernel_v = kernel_v.view(1, 1, 5, 5) # constant + + h_ch = hv[..., 0].unsqueeze(1) # Nx1xHxW + v_ch = hv[..., 1].unsqueeze(1) # Nx1xHxW + + # can only apply in NCHW mode + h_dh_ch = F.conv2d(h_ch, kernel_h, padding=2) + v_dv_ch = F.conv2d(v_ch, kernel_v, padding=2) + dhv = torch.cat([h_dh_ch, v_dv_ch], dim=1) + dhv = dhv.permute(0, 2, 3, 1).contiguous() # to NHWC + return dhv + + focus = (focus[..., None]).float() # assume input NHW + focus = torch.cat([focus, focus], axis=-1) + true_grad = get_gradient_hv(true) + pred_grad = get_gradient_hv(pred) + loss = pred_grad - true_grad + loss = focus * (loss * loss) + # artificial reduce_mean with focused region + loss = loss.sum() / (focus.sum() + 1.0e-8) + return loss diff --git a/models/hovernetC/__init__.py b/models/hovernetC/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/models/hovernetC/net_desc.py b/models/hovernetC/net_desc.py new file mode 100644 index 0000000..56b169a --- /dev/null +++ b/models/hovernetC/net_desc.py @@ -0,0 +1,221 @@ +import math +from collections import OrderedDict + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .net_utils import (DenseBlock, Net, ResidualBlock, TFSamepaddingLayer, + UpSample2x) +from .utils import crop_op, crop_to_shape + +#### +class HoVerNet0(Net): + """Initialise HoVer-Net.""" + + def __init__(self, input_ch=3, nr_types=None, freeze=False, mode='original'): + super().__init__() + self.mode = mode + self.freeze = freeze + self.nr_types = nr_types + self.output_ch = 3 if nr_types is None else 4 + + assert mode == 'original' or mode == 'fast', \ + 'Unknown mode `%s` for HoVerNet %s. Only support `original` or `fast`.' % mode + + module_list = [ + ("/", nn.Conv2d(input_ch, 64, 7, stride=1, padding=0, bias=False)), + ("bn", nn.BatchNorm2d(64, eps=1e-5)), + ("relu", nn.ReLU(inplace=True)), + ] + if mode == 'fast': # prepend the padding for `fast` mode + module_list = [("pad", TFSamepaddingLayer(ksize=7, stride=1))] + module_list + + self.conv0 = nn.Sequential(OrderedDict(module_list)) + self.d0 = ResidualBlock(64, [1, 3, 1], [64, 64, 256], 3, stride=1) + self.d1 = ResidualBlock(256, [1, 3, 1], [128, 128, 512], 4, stride=2) + self.d2 = ResidualBlock(512, [1, 3, 1], [256, 256, 1024], 6, stride=2) + self.d3 = ResidualBlock(1024, [1, 3, 1], [512, 512, 2048], 3, stride=2) + + self.conv_bot = nn.Conv2d(2048, 1024, 1, stride=1, padding=0, bias=False) + + def create_decoder_branch(out_ch=2, ksize=5): + module_list = [ + ("conva", nn.Conv2d(1024, 256, ksize, stride=1, padding=0, bias=False)), + ("dense", DenseBlock(256, [1, ksize], [128, 32], 8, split=4)), + ("convf", nn.Conv2d(512, 512, 1, stride=1, padding=0, bias=False),), + ] + u3 = nn.Sequential(OrderedDict(module_list)) + + module_list = [ + ("conva", nn.Conv2d(512, 128, ksize, stride=1, padding=0, bias=False)), + ("dense", DenseBlock(128, [1, ksize], [128, 32], 4, split=4)), + ("convf", nn.Conv2d(256, 256, 1, stride=1, padding=0, bias=False),), + ] + u2 = nn.Sequential(OrderedDict(module_list)) + + module_list = [ + ("conva/pad", TFSamepaddingLayer(ksize=ksize, stride=1)), + ("conva", nn.Conv2d(256, 64, ksize, stride=1, padding=0, bias=False),), + ] + u1 = nn.Sequential(OrderedDict(module_list)) + + module_list = [ + ("bn", nn.BatchNorm2d(64, eps=1e-5)), + ("relu", nn.ReLU(inplace=True)), + ("conv", nn.Conv2d(64, out_ch, 1, stride=1, padding=0, bias=True),), + ] + u0 = nn.Sequential(OrderedDict(module_list)) + + decoder = nn.Sequential( + OrderedDict([("u3", u3), ("u2", u2), ("u1", u1), ("u0", u0),]) + ) + return decoder + + ksize = 5 if mode == 'original' else 3 + if nr_types is None: + self.decoder = nn.ModuleDict( + OrderedDict( + [ + ("np", create_decoder_branch(ksize=ksize,out_ch=2)), + ("hv", create_decoder_branch(ksize=ksize,out_ch=2)), + ] + ) + ) + else: + self.decoder = nn.ModuleDict( + OrderedDict( + [ + ("tp", create_decoder_branch(ksize=ksize, out_ch=nr_types)), + ("np", create_decoder_branch(ksize=ksize, out_ch=2)), + ("hv", create_decoder_branch(ksize=ksize, out_ch=2)), + ] + ) + ) + + self.upsample2x = UpSample2x() + # TODO: pytorch still require the channel eventhough its ignored + self.weights_init() + + def forward(self, imgs): + + imgs = imgs / 255.0 # to 0-1 range to match XY + + if self.training: + d0 = self.conv0(imgs) + d0 = self.d0(d0, self.freeze) + with torch.set_grad_enabled(not self.freeze): + d1 = self.d1(d0) + d2 = self.d2(d1) + d3 = self.d3(d2) + d3 = self.conv_bot(d3) + d = [d0, d1, d2, d3] + else: + d0 = self.conv0(imgs) + d0 = self.d0(d0) + d1 = self.d1(d0) + d2 = self.d2(d1) + d3 = self.d3(d2) + d3 = self.conv_bot(d3) + d = [d0, d1, d2, d3] + + # TODO: switch to `crop_to_shape` ? + if self.mode == 'original': + d[0] = crop_op(d[0], [184, 184]) + d[1] = crop_op(d[1], [72, 72]) + else: + d[0] = crop_op(d[0], [92, 92]) + d[1] = crop_op(d[1], [36, 36]) + + out_dict = OrderedDict() + for branch_name, branch_desc in self.decoder.items(): + u3 = self.upsample2x(d[-1]) + d[-2] + u3 = branch_desc[0](u3) + + u2 = self.upsample2x(u3) + d[-3] + u2 = branch_desc[1](u2) + + u1 = self.upsample2x(u2) + d[-4] + u1 = branch_desc[2](u1) + + u0 = branch_desc[3](u1) + out_dict[branch_name] = u0 + + return out_dict +class HoVerNet(nn.Module): + """Initialise HoVer-Net.""" + + def __init__(self, input_ch=3, nr_types=None, freeze=False, mode='original'): + super().__init__() + self.finalwid = 28 + self.mode = mode + self.freeze = freeze + self.nr_types = nr_types + self.output_ch = 3 if nr_types is None else 4 + + assert mode == 'original' or mode == 'fast', \ + 'Unknown mode `%s` for HoVerNet %s. Only support `original` or `fast`.' % mode + + module_list = [ + ("/", nn.Conv2d(input_ch, 64, 7, stride=1, padding=0, bias=False)), + ("bn", nn.BatchNorm2d(64, eps=1e-5)), + ("relu", nn.ReLU(inplace=True)), + ] + if mode == 'fast': # prepend the padding for `fast` mode + module_list = [("pad", TFSamepaddingLayer(ksize=7, stride=1))] + module_list + + self.conv0 = nn.Sequential(OrderedDict(module_list)) + self.d0 = ResidualBlock(64, [1, 3, 1], [64, 64, 256], 3, stride=1) + self.d1 = ResidualBlock(256, [1, 3, 1], [128, 128, 512], 4, stride=2) + self.d2 = ResidualBlock(512, [1, 3, 1], [256, 256, 1024], 6, stride=2) + self.d3 = ResidualBlock(1024, [1, 3, 1], [512, 512, 2048], 3, stride=2) + + self.conv_bot = nn.Conv2d(2048, 1024, 1, stride=1, padding=0, bias=False) + # self.maxpool = nn.MaxPool2d(3, stride=2) + self.fc1=nn.Linear(1024,out_features=1024) + self.fc2 = nn.Linear(1024,out_features=2) + # TODO: pytorch still require the channel eventhough its ignored + + def forward(self, imgs): + + imgs = imgs / 255.0 # to 0-1 range to match XY + + if self.training: + d0 = self.conv0(imgs) + d0 = self.d0(d0, self.freeze) + with torch.set_grad_enabled(not self.freeze): + d1 = self.d1(d0) + d2 = self.d2(d1) + d3 = self.d3(d2) + d3 = self.conv_bot(d3) + d = [d0, d1, d2, d3] + else: + d0 = self.conv0(imgs) + d0 = self.d0(d0) + d1 = self.d1(d0) + d2 = self.d2(d1) + d3 = self.d3(d2) + d3 = self.conv_bot(d3) + d = [d0, d1, d2, d3] + gap = F.adaptive_avg_pool2d(d3, (1, 1)) + gap=gap.view(-1,1024) + # d3=d3.view(-1,1024) + # d3=self.maxpool(d3) + # d3 = d3.view(-1, 1024 * self.finalwid * self.finalwid) + fc1=self.fc1(gap) + fc2 = self.fc2(fc1) + + # TODO: switch to `crop_to_shape` ? + + + return fc2 + +#### +def create_model(mode=None, **kwargs): + if mode not in ['original', 'fast']: + assert "Unknown Model Mode %s" % mode + return HoVerNet(mode=mode, **kwargs) + + + diff --git a/models/hovernetC/net_utils.py b/models/hovernetC/net_utils.py new file mode 100644 index 0000000..7f13624 --- /dev/null +++ b/models/hovernetC/net_utils.py @@ -0,0 +1,295 @@ +import numpy as np +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from collections import OrderedDict + +from .utils import crop_op, crop_to_shape +from config import Config + + +#### +class Net(nn.Module): + """ A base class provides a common weight initialisation scheme.""" + + def weights_init(self): + for m in self.modules(): + classname = m.__class__.__name__ + + # ! Fixed the type checking + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + + if "norm" in classname.lower(): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + if "linear" in classname.lower(): + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, x): + return x + + +#### +class TFSamepaddingLayer(nn.Module): + """To align with tf `same` padding. + + Putting this before any conv layer that need padding + Assuming kernel has Height == Width for simplicity + """ + + def __init__(self, ksize, stride): + super(TFSamepaddingLayer, self).__init__() + self.ksize = ksize + self.stride = stride + + def forward(self, x): + if x.shape[2] % self.stride == 0: + pad = max(self.ksize - self.stride, 0) + else: + pad = max(self.ksize - (x.shape[2] % self.stride), 0) + + if pad % 2 == 0: + pad_val = pad // 2 + padding = (pad_val, pad_val, pad_val, pad_val) + else: + pad_val_start = pad // 2 + pad_val_end = pad - pad_val_start + padding = (pad_val_start, pad_val_end, pad_val_start, pad_val_end) + # print(x.shape, padding) + x = F.pad(x, padding, "constant", 0) + # print(x.shape) + return x + + +#### +class DenseBlock(Net): + """Dense Block as defined in: + + Huang, Gao, Zhuang Liu, Laurens Van Der Maaten, and Kilian Q. Weinberger. + "Densely connected convolutional networks." In Proceedings of the IEEE conference + on computer vision and pattern recognition, pp. 4700-4708. 2017. + + Only performs `valid` convolution. + + """ + + def __init__(self, in_ch, unit_ksize, unit_ch, unit_count, split=1): + super(DenseBlock, self).__init__() + assert len(unit_ksize) == len(unit_ch), "Unbalance Unit Info" + + self.nr_unit = unit_count + self.in_ch = in_ch + self.unit_ch = unit_ch + + # ! For inference only so init values for batchnorm may not match tensorflow + unit_in_ch = in_ch + self.units = nn.ModuleList() + for idx in range(unit_count): + self.units.append( + nn.Sequential( + OrderedDict( + [ + ("preact_bna/bn", nn.BatchNorm2d(unit_in_ch, eps=1e-5)), + ("preact_bna/relu", nn.ReLU(inplace=True)), + ( + "conv1", + nn.Conv2d( + unit_in_ch, + unit_ch[0], + unit_ksize[0], + stride=1, + padding=0, + bias=False, + ), + ), + ("conv1/bn", nn.BatchNorm2d(unit_ch[0], eps=1e-5)), + ("conv1/relu", nn.ReLU(inplace=True)), + # ('conv2/pool', TFSamepaddingLayer(ksize=unit_ksize[1], stride=1)), + ( + "conv2", + nn.Conv2d( + unit_ch[0], + unit_ch[1], + unit_ksize[1], + groups=split, + stride=1, + padding=0, + bias=False, + ), + ), + ] + ) + ) + ) + unit_in_ch += unit_ch[1] + + self.blk_bna = nn.Sequential( + OrderedDict( + [ + ("bn", nn.BatchNorm2d(unit_in_ch, eps=1e-5)), + ("relu", nn.ReLU(inplace=True)), + ] + ) + ) + + def out_ch(self): + return self.in_ch + self.nr_unit * self.unit_ch[-1] + + def forward(self, prev_feat): + for idx in range(self.nr_unit): + new_feat = self.units[idx](prev_feat) + prev_feat = crop_to_shape(prev_feat, new_feat) + prev_feat = torch.cat([prev_feat, new_feat], dim=1) + prev_feat = self.blk_bna(prev_feat) + + return prev_feat + + +#### +class ResidualBlock(Net): + """Residual block as defined in: + + He, Kaiming, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. "Deep residual learning + for image recognition." In Proceedings of the IEEE conference on computer vision + and pattern recognition, pp. 770-778. 2016. + + """ + + def __init__(self, in_ch, unit_ksize, unit_ch, unit_count, stride=1): + super(ResidualBlock, self).__init__() + assert len(unit_ksize) == len(unit_ch), "Unbalance Unit Info" + + self.nr_unit = unit_count + self.in_ch = in_ch + self.unit_ch = unit_ch + + # ! For inference only so init values for batchnorm may not match tensorflow + unit_in_ch = in_ch + self.units = nn.ModuleList() + for idx in range(unit_count): + unit_layer = [ + ("preact/bn", nn.BatchNorm2d(unit_in_ch, eps=1e-5)), + ("preact/relu", nn.ReLU(inplace=True)), + ( + "conv1", + nn.Conv2d( + unit_in_ch, + unit_ch[0], + unit_ksize[0], + stride=1, + padding=0, + bias=False, + ), + ), + ("conv1/bn", nn.BatchNorm2d(unit_ch[0], eps=1e-5)), + ("conv1/relu", nn.ReLU(inplace=True)), + ( + "conv2/pad", + TFSamepaddingLayer( + ksize=unit_ksize[1], stride=stride if idx == 0 else 1 + ), + ), + ( + "conv2", + nn.Conv2d( + unit_ch[0], + unit_ch[1], + unit_ksize[1], + stride=stride if idx == 0 else 1, + padding=0, + bias=False, + ), + ), + ("conv2/bn", nn.BatchNorm2d(unit_ch[1], eps=1e-5)), + ("conv2/relu", nn.ReLU(inplace=True)), + ( + "conv3", + nn.Conv2d( + unit_ch[1], + unit_ch[2], + unit_ksize[2], + stride=1, + padding=0, + bias=False, + ), + ), + ] + # * has bna to conclude each previous block so + # * must not put preact for the first unit of this block + unit_layer = unit_layer if idx != 0 else unit_layer[2:] + self.units.append(nn.Sequential(OrderedDict(unit_layer))) + unit_in_ch = unit_ch[-1] + + if in_ch != unit_ch[-1] or stride != 1: + self.shortcut = nn.Conv2d(in_ch, unit_ch[-1], 1, stride=stride, bias=False) + else: + self.shortcut = None + + self.blk_bna = nn.Sequential( + OrderedDict( + [ + ("bn", nn.BatchNorm2d(unit_in_ch, eps=1e-5)), + ("relu", nn.ReLU(inplace=True)), + ] + ) + ) + + # print(self.units[0]) + # print(self.units[1]) + # exit() + + def out_ch(self): + return self.unit_ch[-1] + + def forward(self, prev_feat, freeze=False): + if self.shortcut is None: + shortcut = prev_feat + else: + shortcut = self.shortcut(prev_feat) + + for idx in range(0, len(self.units)): + new_feat = prev_feat + if self.training: + with torch.set_grad_enabled(not freeze): + new_feat = self.units[idx](new_feat) + else: + new_feat = self.units[idx](new_feat) + prev_feat = new_feat + shortcut + shortcut = prev_feat + feat = self.blk_bna(prev_feat) + return feat + + +#### +class UpSample2x(nn.Module): + """Upsample input by a factor of 2. + + Assume input is of NCHW, port FixedUnpooling from TensorPack. + """ + + def __init__(self): + super(UpSample2x, self).__init__() + # correct way to create constant within module + self.register_buffer( + "unpool_mat", torch.from_numpy(np.ones((2, 2), dtype="float32")) + ) + self.unpool_mat.unsqueeze(0) + + def forward(self, x): + input_shape = list(x.shape) + # unsqueeze is expand_dims equivalent + # permute is transpose equivalent + # view is reshape equivalent + x = x.unsqueeze(-1) # bchwx1 + mat = self.unpool_mat.unsqueeze(0) # 1xshxsw + ret = torch.tensordot(x, mat, dims=1) # bxcxhxwxshxsw + ret = ret.permute(0, 1, 2, 4, 3, 5) + ret = ret.reshape((-1, input_shape[1], input_shape[2] * 2, input_shape[3] * 2)) + return ret + diff --git a/models/hovernetC/opt.py b/models/hovernetC/opt.py new file mode 100644 index 0000000..2fe6ef2 --- /dev/null +++ b/models/hovernetC/opt.py @@ -0,0 +1,142 @@ +import torch.optim as optim + +from run_utils.callbacks.base import ( + AccumulateRawOutput, + PeriodicSaver, + ProcessAccumulatedRawOutput, + ScalarMovingAverage, + ScheduleLr, + TrackLr, + VisualizeOutput, + TriggerEngine, +) +from run_utils.callbacks.logging import LoggingEpochOutput, LoggingGradient +from run_utils.engine import Events + +from .targets import gen_targets, prep_sample +from .net_desc import create_model +from .run_desc import proc_valid_step_output, train_step, valid_step, viz_step_output + + +# TODO: training config only ? +# TODO: switch all to function name String for all option +def get_config(nr_type, mode): + return { + # ------------------------------------------------------------------ + # ! All phases have the same number of run engine + # phases are run sequentially from index 0 to N + "phase_list": [ + { + "run_info": { + # may need more dynamic for each network + "net": { + "desc": lambda: create_model( + input_ch=3, nr_types=nr_type, + freeze=True, mode=mode + ), + "optimizer": [ + optim.Adam, + { # should match keyword for parameters within the optimizer + "lr": 1.0e-4, # initial learning rate, + "betas": (0.9, 0.999), + }, + ], + # learning rate scheduler + "lr_scheduler": lambda x: optim.lr_scheduler.StepLR(x, 25), + "extra_info": { + "loss": { + "np": {"bce": 1, "dice": 1}, + "hv": {"mse": 1, "msge": 1}, + "tp": {"bce": 1, "dice": 1}, + }, + }, + # path to load, -1 to auto load checkpoint from previous phase, + # None to start from scratch + "pretrained": "ImageNet-ResNet50-Preact_pytorch.tar", + # 'pretrained': None, + }, + }, + "target_info": {"gen": (gen_targets, {}), "viz": (prep_sample, {})}, + "batch_size": {"train": 2, "valid": 2,}, # engine name : value + "nr_epochs": 50, + }, + { + "run_info": { + # may need more dynamic for each network + "net": { + "desc": lambda: create_model( + input_ch=3, nr_types=nr_type, + freeze=False, mode=mode + ), + "optimizer": [ + optim.Adam, + { # should match keyword for parameters within the optimizer + "lr": 1.0e-4, # initial learning rate, + "betas": (0.9, 0.999), + }, + ], + # learning rate scheduler + "lr_scheduler": lambda x: optim.lr_scheduler.StepLR(x, 25), + "extra_info": { + "loss": { + "np": {"bce": 1, "dice": 1}, + "hv": {"mse": 1, "msge": 1}, + "tp": {"bce": 1, "dice": 1}, + }, + }, + # path to load, -1 to auto load checkpoint from previous phase, + # None to start from scratch + "pretrained": -1, + }, + }, + "target_info": {"gen": (gen_targets, {}), "viz": (prep_sample, {})}, + "batch_size": {"train": 2, "valid": 2,}, # batch size per gpu + "nr_epochs": 50, + }, + ], + # ------------------------------------------------------------------ + # TODO: dynamically for dataset plugin selection and processing also? + # all enclosed engine shares the same neural networks + # as the on at the outer calling it + "run_engine": { + "train": { + # TODO: align here, file path or what? what about CV? + "dataset": "", # whats about compound dataset ? + "nr_procs": 16, # number of threads for dataloader + "run_step": train_step, # TODO: function name or function variable ? + "reset_per_run": False, + # callbacks are run according to the list order of the event + "callbacks": { + Events.STEP_COMPLETED: [ + # LoggingGradient(), # TODO: very slow, may be due to back forth of tensor/numpy ? + ScalarMovingAverage(), + ], + Events.EPOCH_COMPLETED: [ + TrackLr(), + PeriodicSaver(), + VisualizeOutput(viz_step_output), + LoggingEpochOutput(), + TriggerEngine("valid"), + ScheduleLr(), + ], + }, + }, + "valid": { + "dataset": "", # whats about compound dataset ? + "nr_procs": 8, # number of threads for dataloader + "run_step": valid_step, + "reset_per_run": True, # * to stop aggregating output etc. from last run + # callbacks are run according to the list order of the event + "callbacks": { + Events.STEP_COMPLETED: [AccumulateRawOutput(),], + Events.EPOCH_COMPLETED: [ + # TODO: is there way to preload these ? + ProcessAccumulatedRawOutput( + lambda a: proc_valid_step_output(a, nr_types=nr_type) + ), + LoggingEpochOutput(), + ], + }, + }, + }, + } diff --git a/models/hovernetC/post_proc.py b/models/hovernetC/post_proc.py new file mode 100644 index 0000000..4fe1ffb --- /dev/null +++ b/models/hovernetC/post_proc.py @@ -0,0 +1,186 @@ +import cv2 +import numpy as np + +from scipy.ndimage import filters, measurements +from scipy.ndimage.morphology import ( + binary_dilation, + binary_fill_holes, + distance_transform_cdt, + distance_transform_edt, +) + +from skimage.segmentation import watershed +from misc.utils import get_bounding_box, remove_small_objects + +import warnings + + +def noop(*args, **kargs): + pass + + +warnings.warn = noop + + +#### +def __proc_np_hv(pred): + """Process Nuclei Prediction with XY Coordinate Map. + + Args: + pred: prediction output, assuming + channel 0 contain probability map of nuclei + channel 1 containing the regressed X-map + channel 2 containing the regressed Y-map + + """ + pred = np.array(pred, dtype=np.float32) + + blb_raw = pred[..., 0] + h_dir_raw = pred[..., 1] + v_dir_raw = pred[..., 2] + + # processing + blb = np.array(blb_raw >= 0.5, dtype=np.int32) + + blb = measurements.label(blb)[0] + blb = remove_small_objects(blb, min_size=10) + blb[blb > 0] = 1 # background is 0 already + + h_dir = cv2.normalize( + h_dir_raw, None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F + ) + v_dir = cv2.normalize( + v_dir_raw, None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F + ) + + sobelh = cv2.Sobel(h_dir, cv2.CV_64F, 1, 0, ksize=21) + sobelv = cv2.Sobel(v_dir, cv2.CV_64F, 0, 1, ksize=21) + + sobelh = 1 - ( + cv2.normalize( + sobelh, None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F + ) + ) + sobelv = 1 - ( + cv2.normalize( + sobelv, None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F + ) + ) + + overall = np.maximum(sobelh, sobelv) + overall = overall - (1 - blb) + overall[overall < 0] = 0 + + dist = (1.0 - overall) * blb + ## nuclei values form mountains so inverse to get basins + dist = -cv2.GaussianBlur(dist, (3, 3), 0) + + overall = np.array(overall >= 0.4, dtype=np.int32) + + marker = blb - overall + marker[marker < 0] = 0 + marker = binary_fill_holes(marker).astype("uint8") + kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5)) + marker = cv2.morphologyEx(marker, cv2.MORPH_OPEN, kernel) + marker = measurements.label(marker)[0] + marker = remove_small_objects(marker, min_size=10) + + proced_pred = watershed(dist, markers=marker, mask=blb) + + return proced_pred + + +#### +def process(pred_map, nr_types=None, return_centroids=False): + """Post processing script for image tiles. + + Args: + pred_map: commbined output of tp, np and hv branches, in the same order + nr_types: number of types considered at output of nc branch + overlaid_img: img to overlay the predicted instances upon, `None` means no + type_colour (dict) : `None` to use random, else overlay instances of a type to colour in the dict + output_dtype: data type of output + + Returns: + pred_inst: pixel-wise nuclear instance segmentation prediction + pred_type_out: pixel-wise nuclear type prediction + + """ + if nr_types is not None: + pred_type = pred_map[..., :1] + pred_inst = pred_map[..., 1:] + pred_type = pred_type.astype(np.int32) + else: + pred_inst = pred_map + + pred_inst = np.squeeze(pred_inst) + pred_inst = __proc_np_hv(pred_inst) + + inst_info_dict = None + if return_centroids or nr_types is not None: + inst_id_list = np.unique(pred_inst)[1:] # exlcude background + inst_info_dict = {} + for inst_id in inst_id_list: + inst_map = pred_inst == inst_id + # TODO: chane format of bbox output + rmin, rmax, cmin, cmax = get_bounding_box(inst_map) + inst_bbox = np.array([[rmin, cmin], [rmax, cmax]]) + inst_map = inst_map[ + inst_bbox[0][0] : inst_bbox[1][0], inst_bbox[0][1] : inst_bbox[1][1] + ] + inst_map = inst_map.astype(np.uint8) + inst_moment = cv2.moments(inst_map) + inst_contour = cv2.findContours( + inst_map, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE + ) + # * opencv protocol format may break + inst_contour = np.squeeze(inst_contour[0][0].astype("int32")) + # < 3 points dont make a contour, so skip, likely artifact too + # as the contours obtained via approximation => too small or sthg + if inst_contour.shape[0] < 3: + continue + if len(inst_contour.shape) != 2: + continue # ! check for trickery shape + inst_centroid = [ + (inst_moment["m10"] / inst_moment["m00"]), + (inst_moment["m01"] / inst_moment["m00"]), + ] + inst_centroid = np.array(inst_centroid) + inst_contour[:, 0] += inst_bbox[0][1] # X + inst_contour[:, 1] += inst_bbox[0][0] # Y + inst_centroid[0] += inst_bbox[0][1] # X + inst_centroid[1] += inst_bbox[0][0] # Y + inst_info_dict[inst_id] = { # inst_id should start at 1 + "bbox": inst_bbox, + "centroid": inst_centroid, + "contour": inst_contour, + "type_prob": None, + "type": None, + } + + if nr_types is not None: + #### * Get class of each instance id, stored at index id-1 + for inst_id in list(inst_info_dict.keys()): + rmin, cmin, rmax, cmax = (inst_info_dict[inst_id]["bbox"]).flatten() + inst_map_crop = pred_inst[rmin:rmax, cmin:cmax] + inst_type_crop = pred_type[rmin:rmax, cmin:cmax] + inst_map_crop = ( + inst_map_crop == inst_id + ) # TODO: duplicated operation, may be expensive + inst_type = inst_type_crop[inst_map_crop] + type_list, type_pixels = np.unique(inst_type, return_counts=True) + type_list = list(zip(type_list, type_pixels)) + type_list = sorted(type_list, key=lambda x: x[1], reverse=True) + inst_type = type_list[0][0] + if inst_type == 0: # ! pick the 2nd most dominant if exist + if len(type_list) > 1: + inst_type = type_list[1][0] + type_dict = {v[0]: v[1] for v in type_list} + type_prob = type_dict[inst_type] / (np.sum(inst_map_crop) + 1.0e-6) + inst_info_dict[inst_id]["type"] = int(inst_type) + inst_info_dict[inst_id]["type_prob"] = float(type_prob) + + # print('here') + # ! WARNING: ID MAY NOT BE CONTIGUOUS + # inst_id in the dict maps to the same value in the `pred_inst` + return pred_inst, inst_info_dict diff --git a/models/hovernetC/run_desc.py b/models/hovernetC/run_desc.py new file mode 100644 index 0000000..98a2ba9 --- /dev/null +++ b/models/hovernetC/run_desc.py @@ -0,0 +1,331 @@ +import numpy as np +import matplotlib.pyplot as plt +import torch +import torch.nn.functional as F + +from misc.utils import center_pad_to_shape, cropping_center +from .utils import crop_to_shape, dice_loss, mse_loss, msge_loss, xentropy_loss + +from collections import OrderedDict + +#### +def train_step(batch_data, run_info): + # TODO: synchronize the attach protocol + run_info, state_info = run_info + loss_func_dict = { + "bce": xentropy_loss, + "dice": dice_loss, + "mse": mse_loss, + "msge": msge_loss, + } + # use 'ema' to add for EMA calculation, must be scalar! + result_dict = {"EMA": {}} + track_value = lambda name, value: result_dict["EMA"].update({name: value}) + + #### + model = run_info["net"]["desc"] + optimizer = run_info["net"]["optimizer"] + + #### + imgs = batch_data["img"] + true_np = batch_data["np_map"] + true_hv = batch_data["hv_map"] + + imgs = imgs.to("cuda").type(torch.float32) # to NCHW + imgs = imgs.permute(0, 3, 1, 2).contiguous() + + # HWC + true_np = true_np.to("cuda").type(torch.int64) + true_hv = true_hv.to("cuda").type(torch.float32) + + true_np_onehot = (F.one_hot(true_np, num_classes=2)).type(torch.float32) + true_dict = { + "np": true_np_onehot, + "hv": true_hv, + } + + if model.module.nr_types is not None: + true_tp = batch_data["tp_map"] + true_tp = torch.squeeze(true_tp).to("cuda").type(torch.int64) + true_tp_onehot = F.one_hot(true_tp, num_classes=model.module.nr_types) + true_tp_onehot = true_tp_onehot.type(torch.float32) + true_dict["tp"] = true_tp_onehot + + #### + model.train() + model.zero_grad() # not rnn so not accumulate + + pred_dict = model(imgs) + + pred_dict = F.softmax(pred_dict, dim=-1) + + #### + loss = 0 + loss += xentropy_loss(pred_dict,pred_dict) + + track_value("overall_loss", loss.cpu().item()) + # * gradient update + + # torch.set_printoptions(precision=10) + loss.backward() + optimizer.step() + #### + + # # pick 2 random sample from the batch for visualization + # sample_indices = torch.randint(0, true_np.shape[0], (2,)) + # + # imgs = (imgs[sample_indices]).byte() # to uint8 + # imgs = imgs.permute(0, 2, 3, 1).contiguous().cpu().numpy() + # + # pred_dict["np"] = pred_dict["np"][..., 1] # return pos only + # pred_dict = { + # k: v[sample_indices].detach().cpu().numpy() for k, v in pred_dict.items() + # } + # + # true_dict["np"] = true_np + # true_dict = { + # k: v[sample_indices].detach().cpu().numpy() for k, v in true_dict.items() + # } + # + # # * Its up to user to define the protocol to process the raw output per step! + # result_dict["raw"] = { # protocol for contents exchange within `raw` + # "img": imgs, + # "np": (true_dict["np"], pred_dict["np"]), + # "hv": (true_dict["hv"], pred_dict["hv"]), + # } + return pred_dict + + +#### +def valid_step(batch_data, run_info): + run_info, state_info = run_info + #### + model = run_info["net"]["desc"] + model.eval() # infer mode + + #### + imgs = batch_data["img"] + true_np = batch_data["np_map"] + true_hv = batch_data["hv_map"] + + imgs_gpu = imgs.to("cuda").type(torch.float32) # to NCHW + imgs_gpu = imgs_gpu.permute(0, 3, 1, 2).contiguous() + + # HWC + true_np = torch.squeeze(true_np).type(torch.int64) + true_hv = torch.squeeze(true_hv).type(torch.float32) + + true_dict = { + "np": true_np, + "hv": true_hv, + } + + if model.module.nr_types is not None: + true_tp = batch_data["tp_map"] + true_tp = torch.squeeze(true_tp).type(torch.int64) + true_dict["tp"] = true_tp + + # -------------------------------------------------------------- + with torch.no_grad(): # dont compute gradient + pred_dict = model(imgs_gpu) + pred_dict = OrderedDict( + [[k, v.permute(0, 2, 3, 1).contiguous()] for k, v in pred_dict.items()] + ) + pred_dict["np"] = F.softmax(pred_dict["np"], dim=-1)[..., 1] + if model.module.nr_types is not None: + type_map = F.softmax(pred_dict["tp"], dim=-1) + type_map = torch.argmax(type_map, dim=-1, keepdim=False) + type_map = type_map.type(torch.float32) + pred_dict["tp"] = type_map + + # * Its up to user to define the protocol to process the raw output per step! + result_dict = { # protocol for contents exchange within `raw` + "raw": { + "imgs": imgs.numpy(), + "true_np": true_dict["np"].numpy(), + "true_hv": true_dict["hv"].numpy(), + "prob_np": pred_dict["np"].cpu().numpy(), + "pred_hv": pred_dict["hv"].cpu().numpy(), + } + } + if model.module.nr_types is not None: + result_dict["raw"]["true_tp"] = true_dict["tp"].numpy() + result_dict["raw"]["pred_tp"] = pred_dict["tp"].cpu().numpy() + return result_dict + + +#### +def infer_step(batch_data, model): + + #### + patch_imgs = batch_data + + patch_imgs_gpu = patch_imgs.to("cuda").type(torch.float32) # to NCHW + patch_imgs_gpu = patch_imgs_gpu.permute(0, 3, 1, 2).contiguous() + + #### + model.eval() # infer mode + + # -------------------------------------------------------------- + with torch.no_grad(): # dont compute gradient + pred_dict = model(patch_imgs_gpu) + pred_dict = OrderedDict( + [[k, v.permute(0, 2, 3, 1).contiguous()] for k, v in pred_dict.items()] + ) + pred_dict["np"] = F.softmax(pred_dict["np"], dim=-1)[..., 1:] + if "tp" in pred_dict: + type_map = F.softmax(pred_dict["tp"], dim=-1) + type_map = torch.argmax(type_map, dim=-1, keepdim=True) + type_map = type_map.type(torch.float32) + pred_dict["tp"] = type_map + pred_output = torch.cat(list(pred_dict.values()), -1) + + # * Its up to user to define the protocol to process the raw output per step! + return pred_output.cpu().numpy() + + +#### +def viz_step_output(raw_data, nr_types=None): + """ + `raw_data` will be implicitly provided in the similar format as the + return dict from train/valid step, but may have been accumulated across N running step + """ + + imgs = raw_data["img"] + true_np, pred_np = raw_data["np"] + true_hv, pred_hv = raw_data["hv"] + if nr_types is not None: + true_tp, pred_tp = raw_data["tp"] + + aligned_shape = [list(imgs.shape), list(true_np.shape), list(pred_np.shape)] + aligned_shape = np.min(np.array(aligned_shape), axis=0)[1:3] + + cmap = plt.get_cmap("jet") + + def colorize(ch, vmin, vmax): + """ + Will clamp value value outside the provided range to vmax and vmin + """ + ch = np.squeeze(ch.astype("float32")) + ch[ch > vmax] = vmax # clamp value + ch[ch < vmin] = vmin + ch = (ch - vmin) / (vmax - vmin + 1.0e-16) + # take RGB from RGBA heat map + ch_cmap = (cmap(ch)[..., :3] * 255).astype("uint8") + # ch_cmap = center_pad_to_shape(ch_cmap, aligned_shape) + return ch_cmap + + viz_list = [] + for idx in range(imgs.shape[0]): + # img = center_pad_to_shape(imgs[idx], aligned_shape) + img = cropping_center(imgs[idx], aligned_shape) + + true_viz_list = [img] + # cmap may randomly fails if of other types + true_viz_list.append(colorize(true_np[idx], 0, 1)) + true_viz_list.append(colorize(true_hv[idx][..., 0], -1, 1)) + true_viz_list.append(colorize(true_hv[idx][..., 1], -1, 1)) + if nr_types is not None: # TODO: a way to pass through external info + true_viz_list.append(colorize(true_tp[idx], 0, nr_types)) + true_viz_list = np.concatenate(true_viz_list, axis=1) + + pred_viz_list = [img] + # cmap may randomly fails if of other types + pred_viz_list.append(colorize(pred_np[idx], 0, 1)) + pred_viz_list.append(colorize(pred_hv[idx][..., 0], -1, 1)) + pred_viz_list.append(colorize(pred_hv[idx][..., 1], -1, 1)) + if nr_types is not None: + pred_viz_list.append(colorize(pred_tp[idx], 0, nr_types)) + pred_viz_list = np.concatenate(pred_viz_list, axis=1) + + viz_list.append(np.concatenate([true_viz_list, pred_viz_list], axis=0)) + viz_list = np.concatenate(viz_list, axis=0) + return viz_list + + +#### +from itertools import chain + + +def proc_valid_step_output(raw_data, nr_types=None): + # TODO: add auto populate from main state track list + track_dict = {"scalar": {}, "image": {}} + + def track_value(name, value, vtype): + return track_dict[vtype].update({name: value}) + + def _dice_info(true, pred, label): + true = np.array(true == label, np.int32) + pred = np.array(pred == label, np.int32) + inter = (pred * true).sum() + total = (pred + true).sum() + return inter, total + + over_inter = 0 + over_total = 0 + over_correct = 0 + prob_np = raw_data["prob_np"] + true_np = raw_data["true_np"] + for idx in range(len(raw_data["true_np"])): + patch_prob_np = prob_np[idx] + patch_true_np = true_np[idx] + patch_pred_np = np.array(patch_prob_np > 0.5, dtype=np.int32) + inter, total = _dice_info(patch_true_np, patch_pred_np, 1) + correct = (patch_pred_np == patch_true_np).sum() + over_inter += inter + over_total += total + over_correct += correct + nr_pixels = len(true_np) * np.size(true_np[0]) + acc_np = over_correct / nr_pixels + dice_np = 2 * over_inter / (over_total + 1.0e-8) + track_value("np_acc", acc_np, "scalar") + track_value("np_dice", dice_np, "scalar") + + # * TP statistic + if nr_types is not None: + pred_tp = raw_data["pred_tp"] + true_tp = raw_data["true_tp"] + for type_id in range(0, nr_types): + over_inter = 0 + over_total = 0 + for idx in range(len(raw_data["true_np"])): + patch_pred_tp = pred_tp[idx] + patch_true_tp = true_tp[idx] + inter, total = _dice_info(patch_true_tp, patch_pred_tp, type_id) + over_inter += inter + over_total += total + dice_tp = 2 * over_inter / (over_total + 1.0e-8) + track_value("tp_dice_%d" % type_id, dice_tp, "scalar") + + # * HV regression statistic + pred_hv = raw_data["pred_hv"] + true_hv = raw_data["true_hv"] + + over_squared_error = 0 + for idx in range(len(raw_data["true_np"])): + patch_pred_hv = pred_hv[idx] + patch_true_hv = true_hv[idx] + squared_error = patch_pred_hv - patch_true_hv + squared_error = squared_error * squared_error + over_squared_error += squared_error.sum() + mse = over_squared_error / nr_pixels + track_value("hv_mse", mse, "scalar") + + # * + imgs = raw_data["imgs"] + selected_idx = np.random.randint(0, len(imgs), size=(8,)).tolist() + imgs = np.array([imgs[idx] for idx in selected_idx]) + true_np = np.array([true_np[idx] for idx in selected_idx]) + true_hv = np.array([true_hv[idx] for idx in selected_idx]) + prob_np = np.array([prob_np[idx] for idx in selected_idx]) + pred_hv = np.array([pred_hv[idx] for idx in selected_idx]) + viz_raw_data = {"img": imgs, "np": (true_np, prob_np), "hv": (true_hv, pred_hv)} + + if nr_types is not None: + true_tp = np.array([true_tp[idx] for idx in selected_idx]) + pred_tp = np.array([pred_tp[idx] for idx in selected_idx]) + viz_raw_data["tp"] = (true_tp, pred_tp) + viz_fig = viz_step_output(viz_raw_data, nr_types) + track_dict["image"]["output"] = viz_fig + + return track_dict diff --git a/models/hovernetC/targets.py b/models/hovernetC/targets.py new file mode 100644 index 0000000..d9466d9 --- /dev/null +++ b/models/hovernetC/targets.py @@ -0,0 +1,153 @@ +import math +import numpy as np + +import torch +import torch.nn.functional as F + +from scipy import ndimage +from scipy.ndimage import measurements +from skimage import morphology as morph +import matplotlib.pyplot as plt + +from misc.utils import center_pad_to_shape, cropping_center, get_bounding_box +from dataloader.augs import fix_mirror_padding + + +#### +def gen_instance_hv_map(ann, crop_shape): + """Input annotation must be of original shape. + + The map is calculated only for instances within the crop portion + but based on the original shape in original image. + + Perform following operation: + Obtain the horizontal and vertical distance maps for each + nuclear instance. + + """ + orig_ann = ann.copy() # instance ID map + fixed_ann = fix_mirror_padding(orig_ann) + # re-cropping with fixed instance id map + crop_ann = cropping_center(fixed_ann, crop_shape) + # TODO: deal with 1 label warning + crop_ann = morph.remove_small_objects(crop_ann, min_size=30) + + x_map = np.zeros(orig_ann.shape[:2], dtype=np.float32) + y_map = np.zeros(orig_ann.shape[:2], dtype=np.float32) + + inst_list = list(np.unique(crop_ann)) + inst_list.remove(0) # 0 is background + for inst_id in inst_list: + inst_map = np.array(fixed_ann == inst_id, np.uint8) + inst_box = get_bounding_box(inst_map) + + # expand the box by 2px + # Because we first pad the ann at line 207, the bboxes + # will remain valid after expansion + inst_box[0] -= 2 + inst_box[2] -= 2 + inst_box[1] += 2 + inst_box[3] += 2 + + inst_map = inst_map[inst_box[0] : inst_box[1], inst_box[2] : inst_box[3]] + + if inst_map.shape[0] < 2 or inst_map.shape[1] < 2: + continue + + # instance center of mass, rounded to nearest pixel + inst_com = list(measurements.center_of_mass(inst_map)) + + inst_com[0] = int(inst_com[0] + 0.5) + inst_com[1] = int(inst_com[1] + 0.5) + + inst_x_range = np.arange(1, inst_map.shape[1] + 1) + inst_y_range = np.arange(1, inst_map.shape[0] + 1) + # shifting center of pixels grid to instance center of mass + inst_x_range -= inst_com[1] + inst_y_range -= inst_com[0] + + inst_x, inst_y = np.meshgrid(inst_x_range, inst_y_range) + + # remove coord outside of instance + inst_x[inst_map == 0] = 0 + inst_y[inst_map == 0] = 0 + inst_x = inst_x.astype("float32") + inst_y = inst_y.astype("float32") + + # normalize min into -1 scale + if np.min(inst_x) < 0: + inst_x[inst_x < 0] /= -np.amin(inst_x[inst_x < 0]) + if np.min(inst_y) < 0: + inst_y[inst_y < 0] /= -np.amin(inst_y[inst_y < 0]) + # normalize max into +1 scale + if np.max(inst_x) > 0: + inst_x[inst_x > 0] /= np.amax(inst_x[inst_x > 0]) + if np.max(inst_y) > 0: + inst_y[inst_y > 0] /= np.amax(inst_y[inst_y > 0]) + + #### + x_map_box = x_map[inst_box[0] : inst_box[1], inst_box[2] : inst_box[3]] + x_map_box[inst_map > 0] = inst_x[inst_map > 0] + + y_map_box = y_map[inst_box[0] : inst_box[1], inst_box[2] : inst_box[3]] + y_map_box[inst_map > 0] = inst_y[inst_map > 0] + + hv_map = np.dstack([x_map, y_map]) + return hv_map + + +#### +def gen_targets(ann, crop_shape, **kwargs): + """Generate the targets for the network.""" + hv_map = gen_instance_hv_map(ann, crop_shape) + np_map = ann.copy() + np_map[np_map > 0] = 1 + + hv_map = cropping_center(hv_map, crop_shape) + np_map = cropping_center(np_map, crop_shape) + + target_dict = { + "hv_map": hv_map, + "np_map": np_map, + } + + return target_dict + + +#### +def prep_sample(data, is_batch=False, **kwargs): + """ + Designed to process direct output from loader + """ + cmap = plt.get_cmap("jet") + + def colorize(ch, vmin, vmax, shape): + ch = np.squeeze(ch.astype("float32")) + ch = ch / (vmax - vmin + 1.0e-16) + # take RGB from RGBA heat map + ch_cmap = (cmap(ch)[..., :3] * 255).astype("uint8") + ch_cmap = center_pad_to_shape(ch_cmap, shape) + return ch_cmap + + def prep_one_sample(data): + shape_array = [np.array(v.shape[:2]) for v in data.values()] + shape = np.maximum(*shape_array) + viz_list = [] + viz_list.append(colorize(data["np_map"], 0, 1, shape)) + # map to [0,2] for better visualisation. + # Note, [-1,1] is used for training. + viz_list.append(colorize(data["hv_map"][..., 0] + 1, 0, 2, shape)) + viz_list.append(colorize(data["hv_map"][..., 1] + 1, 0, 2, shape)) + img = center_pad_to_shape(data["img"], shape) + return np.concatenate([img] + viz_list, axis=1) + + # cmap may randomly fails if of other types + if is_batch: + viz_list = [] + data_shape = list(data.values())[0].shape + for batch_idx in range(data_shape[0]): + sub_data = {k : v[batch_idx] for k, v in data.items()} + viz_list.append(prep_one_sample(sub_data)) + return np.concatenate(viz_list, axis=0) + else: + return prep_one_sample(data) diff --git a/models/hovernetC/utils.py b/models/hovernetC/utils.py new file mode 100644 index 0000000..3842287 --- /dev/null +++ b/models/hovernetC/utils.py @@ -0,0 +1,172 @@ +import math +import numpy as np + +import torch +import torch.nn.functional as F + +from matplotlib import cm + + +#### +def crop_op(x, cropping, data_format="NCHW"): + """Center crop image. + + Args: + x: input image + cropping: the substracted amount + data_format: choose either `NCHW` or `NHWC` + + """ + crop_t = cropping[0] // 2 + crop_b = cropping[0] - crop_t + crop_l = cropping[1] // 2 + crop_r = cropping[1] - crop_l + if data_format == "NCHW": + x = x[:, :, crop_t:-crop_b, crop_l:-crop_r] + else: + x = x[:, crop_t:-crop_b, crop_l:-crop_r, :] + return x + + +#### +def crop_to_shape(x, y, data_format="NCHW"): + """Centre crop x so that x has shape of y. y dims must be smaller than x dims. + + Args: + x: input array + y: array with desired shape. + + """ + assert ( + y.shape[0] <= x.shape[0] and y.shape[1] <= x.shape[1] + ), "Ensure that y dimensions are smaller than x dimensions!" + + x_shape = x.size() + y_shape = y.size() + if data_format == "NCHW": + crop_shape = (x_shape[2] - y_shape[2], x_shape[3] - y_shape[3]) + else: + crop_shape = (x_shape[1] - y_shape[1], x_shape[2] - y_shape[2]) + return crop_op(x, crop_shape, data_format) + + +#### +def xentropy_loss(true, pred, reduction="mean"): + """Cross entropy loss. Assumes NHWC! + + Args: + pred: prediction array + true: ground truth array + + Returns: + cross entropy loss + + """ + epsilon = 10e-8 + # scale preds so that the class probs of each sample sum to 1 + pred = pred / torch.sum(pred, -1, keepdim=True) + # manual computation of crossentropy + pred = torch.clamp(pred, epsilon, 1.0 - epsilon) + loss = -torch.sum((true * torch.log(pred)), -1, keepdim=True) + loss = loss.mean() if reduction == "mean" else loss.sum() + return loss + + +#### +def dice_loss(true, pred, smooth=1e-3): + """`pred` and `true` must be of torch.float32. Assuming of shape NxHxWxC.""" + inse = torch.sum(pred * true, (0, 1, 2)) + l = torch.sum(pred, (0, 1, 2)) + r = torch.sum(true, (0, 1, 2)) + loss = 1.0 - (2.0 * inse + smooth) / (l + r + smooth) + loss = torch.sum(loss) + return loss + + +#### +def mse_loss(true, pred): + """Calculate mean squared error loss. + + Args: + true: ground truth of combined horizontal + and vertical maps + pred: prediction of combined horizontal + and vertical maps + + Returns: + loss: mean squared error + + """ + loss = pred - true + loss = (loss * loss).mean() + return loss + + +#### +def msge_loss(true, pred, focus): + """Calculate the mean squared error of the gradients of + horizontal and vertical map predictions. Assumes + channel 0 is Vertical and channel 1 is Horizontal. + + Args: + true: ground truth of combined horizontal + and vertical maps + pred: prediction of combined horizontal + and vertical maps + focus: area where to apply loss (we only calculate + the loss within the nuclei) + + Returns: + loss: mean squared error of gradients + + """ + + def get_sobel_kernel(size): + """Get sobel kernel with a given size.""" + assert size % 2 == 1, "Must be odd, get size=%d" % size + + h_range = torch.arange( + -size // 2 + 1, + size // 2 + 1, + dtype=torch.float32, + device="cuda", + requires_grad=False, + ) + v_range = torch.arange( + -size // 2 + 1, + size // 2 + 1, + dtype=torch.float32, + device="cuda", + requires_grad=False, + ) + h, v = torch.meshgrid(h_range, v_range) + kernel_h = h / (h * h + v * v + 1.0e-15) + kernel_v = v / (h * h + v * v + 1.0e-15) + return kernel_h, kernel_v + + #### + def get_gradient_hv(hv): + """For calculating gradient.""" + kernel_h, kernel_v = get_sobel_kernel(5) + kernel_h = kernel_h.view(1, 1, 5, 5) # constant + kernel_v = kernel_v.view(1, 1, 5, 5) # constant + + h_ch = hv[..., 0].unsqueeze(1) # Nx1xHxW + v_ch = hv[..., 1].unsqueeze(1) # Nx1xHxW + + # can only apply in NCHW mode + h_dh_ch = F.conv2d(h_ch, kernel_h, padding=2) + v_dv_ch = F.conv2d(v_ch, kernel_v, padding=2) + dhv = torch.cat([h_dh_ch, v_dv_ch], dim=1) + dhv = dhv.permute(0, 2, 3, 1).contiguous() # to NHWC + return dhv + + focus = (focus[..., None]).float() # assume input NHW + focus = torch.cat([focus, focus], axis=-1) + true_grad = get_gradient_hv(true) + pred_grad = get_gradient_hv(pred) + loss = pred_grad - true_grad + loss = focus * (loss * loss) + # artificial reduce_mean with focused region + loss = loss.sum() / (focus.sum() + 1.0e-8) + return loss