diff --git a/README.md b/README.md
index 1faae7b..197b902 100644
--- a/README.md
+++ b/README.md
@@ -45,9 +45,9 @@ docker pull projectmonai/monai:1.3.2
 
 ### Inference
 
-We provide two ways to use the model for inference. 
-1. We recommend users to use the optimized and standardized [MONAI bundle]() model. The bundle provides a unified API for inference. 
-The [VISTA3D NVIDIA Inference Microservices (NIM)]() deploys the bundle with an interactive front-end. 
+We provide two ways to use the model for inference.
+1. We recommend using the optimized and standardized [MONAI bundle]() model. The bundle provides a unified API for inference.
+The [VISTA3D NVIDIA Inference Microservices (NIM)]() deploys the bundle with an interactive front-end.
 2. For quick debugging and model development purposes, we also provide the `infer.py` script and its light-weight front-end `debugger.py`. `python -m scripts.debugger run`. Note we will prioritize [NIM]() and [monai bundle]() developments and those functions will be deprecated in the future.
 ```
 export CUDA_VISIBLE_DEVICES=0; python -m scripts.infer --config_file 'configs/infer.yaml' - infer --image_file 'example-1.nii.gz' --label_prompt [1] --save_mask true
diff --git a/scripts/gui.py b/scripts/debugger.py
similarity index 91%
rename from scripts/gui.py
rename to scripts/debugger.py
index 9c2e97a..c297e7e 100644
--- a/scripts/gui.py
+++ b/scripts/debugger.py
@@ -1,12 +1,11 @@
 import copy
-from tkinter import Tk, filedialog, messagebox, simpledialog
+from tkinter import Tk, filedialog, messagebox
 
 import fire
-import pdb
 import matplotlib.pyplot as plt
 import nibabel as nib
 import numpy as np
-from matplotlib.widgets import Button, CheckButtons, TextBox
+from matplotlib.widgets import Button, TextBox
 
 from .infer import InferClass
 from .utils.workflow_utils import get_point_label
@@ -62,6 +61,12 @@ def generate_mask(self):
             prompt_class = None
             neg_id, pos_id = get_point_label(1)
         else:
+            if self.class_label in [2, 20, 21]:
+                messagebox.showwarning(
+                    "Warning",
+                    "The current debugger skips kidney (2), lung (20), and bone (21); use their subclasses instead.",
+                )
+                return
             label_prompt = int(self.class_label)
             neg_id, pos_id = get_point_label(label_prompt)
             label_prompt = np.array([label_prompt])[np.newaxis, ...]
@@ -85,7 +90,7 @@ def generate_mask(self):
             label_prompt,
             prompt_class,
             save_mask=True,
-            point_start=self.point_start
+            point_start=self.point_start,
         )[0]
         nan_mask = np.isnan(mask)
         mask = mask.data.cpu().numpy() > 0.5
@@ -118,8 +123,12 @@ def on_button_click(event, ax=ax):
             print("-- segmenting ---")
             self.generate_mask()
             print("-- done ---")
-            print("-- Note: Click points on different foreground class will cause segmentation conflicts. Clear first. ---")
-            print("-- Note: Click points not matching class prompts will also cause confusion. ---")
+            print(
+                "-- Note: Clicking points on different foreground classes will cause segmentation conflicts. Clear first. ---"
+            )
+            print(
+                "-- Note: Clicking points that do not match the class prompt will also cause confusion. ---"
+            )
             print("-- Note: CTRL + Right Click will be adding negative points. ---")
             self.update_slice(ax)
             # self.point_start = len(self.clicked_points)
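Editor's note on the debugger hunk above: `generate_mask` turns the single class index held in `self.class_label` into a `[1, 1]`-shaped array before handing it to `InferClass.infer`, and the new guard rejects the aggregate classes 2, 20, and 21. The snippet below is a minimal sketch of that prompt shaping; the concrete class index is a hypothetical example, not taken from the diff.

```python
# Minimal sketch of the prompt shaping done in debugger.py's generate_mask
# (class index 3 is a made-up example; the shaping follows the hunk above).
import numpy as np

class_label = 3  # would trigger the new warning if it were 2, 20, or 21
label_prompt = np.array([int(class_label)])[np.newaxis, ...]
print(label_prompt.shape)  # (1, 1)
```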
---") self.update_slice(ax) # self.point_start = len(self.clicked_points) diff --git a/scripts/infer.py b/scripts/infer.py index a421a8a..18752c1 100644 --- a/scripts/infer.py +++ b/scripts/infer.py @@ -20,50 +20,20 @@ import torch.distributed as dist from monai import transforms from monai.apps.auto3dseg.auto_runner import logger -from monai.apps.utils import DEFAULT_FMT from monai.auto3dseg.utils import datafold_read from monai.bundle import ConfigParser from monai.bundle.scripts import _pop_args, _update_args from monai.data import decollate_batch, list_data_collate, partition_dataset -from monai.utils import optional_import, RankFilter +from monai.utils import optional_import from vista3d import vista_model_registry from .sliding_window import point_based_window_inferer, sliding_window_inference -from .utils.trans_utils import get_largest_connected_component_point, VistaPostTransform -# from .train import CONFIG +from .train import CONFIG +from .utils.trans_utils import VistaPostTransform + rearrange, _ = optional_import("einops", name="rearrange") -RankFilter, _ = optional_import("monai.utils", name="RankFilter") sys.path.insert(0, os.path.abspath(os.path.dirname(__file__))) -CONFIG = { - "version": 1, - "disable_existing_loggers": False, - "formatters": {"monai_default": {"format": DEFAULT_FMT}}, - "loggers": { - "monai.apps.auto3dseg.auto_runner": { - "handlers": ["file", "console"], - "level": "DEBUG", - "propagate": False, - } - }, - "filters": {"rank_filter": {"{}": RankFilter}}, - "handlers": { - "file": { - "class": "logging.FileHandler", - "filename": "runner.log", - "mode": "a", # append or overwrite - "level": "DEBUG", - "formatter": "monai_default", - "filters": ["rank_filter"], - }, - "console": { - "class": "logging.StreamHandler", - "level": "INFO", - "formatter": "monai_default", - "filters": ["rank_filter"], - }, - }, -} IGNORE_PROMPT = set( [ 2, # kidney @@ -92,7 +62,7 @@ def infer_wrapper(inputs, model, **kwargs): class InferClass: - def __init__(self, config_file='./configs/infer.yaml', **override): + def __init__(self, config_file="./configs/infer.yaml", **override): logging.basicConfig(stream=sys.stdout, level=logging.INFO) _args = _update_args(config_file=config_file, **override) @@ -139,7 +109,7 @@ def __init__(self, config_file='./configs/infer.yaml', **override): meta_key_postfix="meta_dict", nearest_interp=True, to_tensor=True, - ) + ), ] # For Vista3d, sigmoid is always used, but for visualization, argmax is needed @@ -184,16 +154,16 @@ def infer( label_prompt=None, prompt_class=None, save_mask=False, - point_start=0 + point_start=0, ): """Infer a single image_file. If save_mask is true, save the argmax prediction to disk. If false, - do not save and return the probability maps (usually used by autorunner emsembler). point_start is + do not save and return the probability maps (usually used by autorunner emsembler). point_start is used together with prev_mask. If prev_mask is generated by N points, point_start should be N+1 to save time and avoid repeated inference. This is by default disabled. 
""" self.model.eval() - if type(image_file) is not dict: - image_file = {'image': image_file} + if not isinstance(image_file, dict): + image_file = {"image": image_file} if self.batch_data is not None: batch_data = self.batch_data else: @@ -279,15 +249,15 @@ def infer( if not finished: raise RuntimeError("Infer not finished due to OOM.") return batch_data[0]["pred"] - + @torch.no_grad() def infer_everything(self, image_file, label_prompt=EVERYTHING_PROMPT, rank=0): self.model.eval() device = f"cuda:{rank}" - if type(image_file) is not dict: - image_file = {'image': image_file} + if not isinstance(image_file, dict): + image_file = {"image": image_file} batch_data = self.infer_transforms(image_file) - batch_data['label_prompt'] = label_prompt + batch_data["label_prompt"] = label_prompt batch_data = list_data_collate([batch_data]) device_list_input = [device, device, "cpu"] device_list_output = [device, "cpu", "cpu"] diff --git a/scripts/train_finetune.py b/scripts/train_finetune.py index 1f77deb..bb945ee 100644 --- a/scripts/train_finetune.py +++ b/scripts/train_finetune.py @@ -36,7 +36,7 @@ from monai.data import DataLoader, DistributedSampler from monai.metrics import compute_dice from monai.networks.utils import copy_model_state -from monai.utils import set_determinism, RankFilter +from monai.utils import set_determinism from torch.nn.parallel import DistributedDataParallel from torch.utils.tensorboard import SummaryWriter from tqdm import tqdm diff --git a/scripts/utils/trans_utils.py b/scripts/utils/trans_utils.py index ef1feb4..c6dc0c0 100644 --- a/scripts/utils/trans_utils.py +++ b/scripts/utils/trans_utils.py @@ -311,9 +311,9 @@ def get_largest_connected_component_mask( class VistaPostTransform(MapTransform): def __init__( - self, - keys: KeysCollection, - allow_missing_keys: bool = False, + self, + keys: KeysCollection, + allow_missing_keys: bool = False, ) -> None: """ Args: @@ -326,29 +326,31 @@ def __init__( """ super().__init__(keys, allow_missing_keys) - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: + def __call__( + self, data: Mapping[Hashable, NdarrayOrTensor] + ) -> dict[Hashable, NdarrayOrTensor]: for keys in self.keys: if keys in data: pred = data[keys] object_num = pred.shape[0] - device = pred.device + # device = pred.device pred[pred < 0] = 0.0 # if it's multichannel, perform argmax if object_num > 1: # concate background channel. Make sure user did not provide 0 as prompt. 
diff --git a/tests/test_logger.py b/tests/test_logger.py
index b5f713c..748abd3 100644
--- a/tests/test_logger.py
+++ b/tests/test_logger.py
@@ -22,12 +22,6 @@ def test_vista3d_logger(self):
         logging.config.dictConfig(CONFIG)
         logger.warning("check train logging format")
 
-    def test_vista3d_logger_infer(self):
-        from scripts.infer import CONFIG
-
-        logging.config.dictConfig(CONFIG)
-        logger.warning("check infer logging format")
-
 
 if __name__ == "__main__":
     unittest.main()
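Editor's note on the removed test: `scripts/infer.py` no longer defines its own logging CONFIG and instead does `from .train import CONFIG` (see the infer.py hunk above), so the infer-specific logger test duplicated the train one. A quick sanity check of that assumption (a sketch, assuming `scripts.train` is importable in the test environment the same way `scripts.infer` already is):

```python
# Both modules should now hold the very same logging CONFIG object.
from scripts.infer import CONFIG as infer_config
from scripts.train import CONFIG as train_config

assert infer_config is train_config
```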