Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

minor bug fixes #24

Merged
merged 4 commits into from
Jun 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,9 @@ docker pull projectmonai/monai:1.3.2


### Inference
We provide two ways to use the model for inference.
1. We recommend users to use the optimized and standardized [MONAI bundle]() model. The bundle provides a unified API for inference.
The [VISTA3D NVIDIA Inference Microservices (NIM)]() deploys the bundle with an interactive front-end.
We provide two ways to use the model for inference.
1. We recommend users to use the optimized and standardized [MONAI bundle]() model. The bundle provides a unified API for inference.
The [VISTA3D NVIDIA Inference Microservices (NIM)]() deploys the bundle with an interactive front-end.
2. For quick debugging and model development purposes, we also provide the `infer.py` script and its light-weight front-end `debugger.py`. `python -m scripts.debugger run`. Note we will prioritize [NIM]() and [monai bundle]() developments and those functions will be deprecated in the future.
```
export CUDA_VISIBLE_DEVICES=0; python -m scripts.infer --config_file 'configs/infer.yaml' - infer --image_file 'example-1.nii.gz' --label_prompt [1] --save_mask true
Expand Down
21 changes: 15 additions & 6 deletions scripts/gui.py → scripts/debugger.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import copy
from tkinter import Tk, filedialog, messagebox, simpledialog
from tkinter import Tk, filedialog, messagebox

import fire
import pdb
import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
from matplotlib.widgets import Button, CheckButtons, TextBox
from matplotlib.widgets import Button, TextBox

from .infer import InferClass
from .utils.workflow_utils import get_point_label
Expand Down Expand Up @@ -62,6 +61,12 @@ def generate_mask(self):
prompt_class = None
neg_id, pos_id = get_point_label(1)
else:
if self.class_label in [2, 20, 21]:
messagebox.showwarning(
"Warning",
"Current debugger skip kidney (2), lung (20), and bone (21). Use their subclasses.",
)
return
label_prompt = int(self.class_label)
neg_id, pos_id = get_point_label(label_prompt)
label_prompt = np.array([label_prompt])[np.newaxis, ...]
Expand All @@ -85,7 +90,7 @@ def generate_mask(self):
label_prompt,
prompt_class,
save_mask=True,
point_start=self.point_start
point_start=self.point_start,
)[0]
nan_mask = np.isnan(mask)
mask = mask.data.cpu().numpy() > 0.5
Expand Down Expand Up @@ -118,8 +123,12 @@ def on_button_click(event, ax=ax):
print("-- segmenting ---")
self.generate_mask()
print("-- done ---")
print("-- Note: Click points on different foreground class will cause segmentation conflicts. Clear first. ---")
print("-- Note: Click points not matching class prompts will also cause confusion. ---")
print(
"-- Note: Click points on different foreground class will cause segmentation conflicts. Clear first. ---"
)
print(
"-- Note: Click points not matching class prompts will also cause confusion. ---"
)
print("-- Note: CTRL + Right Click will be adding negative points. ---")
self.update_slice(ax)
# self.point_start = len(self.clicked_points)
Expand Down
58 changes: 14 additions & 44 deletions scripts/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,50 +20,20 @@
import torch.distributed as dist
from monai import transforms
from monai.apps.auto3dseg.auto_runner import logger
from monai.apps.utils import DEFAULT_FMT
from monai.auto3dseg.utils import datafold_read
from monai.bundle import ConfigParser
from monai.bundle.scripts import _pop_args, _update_args
from monai.data import decollate_batch, list_data_collate, partition_dataset
from monai.utils import optional_import, RankFilter
from monai.utils import optional_import

from vista3d import vista_model_registry

from .sliding_window import point_based_window_inferer, sliding_window_inference
from .utils.trans_utils import get_largest_connected_component_point, VistaPostTransform
# from .train import CONFIG
from .train import CONFIG
from .utils.trans_utils import VistaPostTransform

rearrange, _ = optional_import("einops", name="rearrange")
RankFilter, _ = optional_import("monai.utils", name="RankFilter")
sys.path.insert(0, os.path.abspath(os.path.dirname(__file__)))
CONFIG = {
"version": 1,
"disable_existing_loggers": False,
"formatters": {"monai_default": {"format": DEFAULT_FMT}},
"loggers": {
"monai.apps.auto3dseg.auto_runner": {
"handlers": ["file", "console"],
"level": "DEBUG",
"propagate": False,
}
},
"filters": {"rank_filter": {"{}": RankFilter}},
"handlers": {
"file": {
"class": "logging.FileHandler",
"filename": "runner.log",
"mode": "a", # append or overwrite
"level": "DEBUG",
"formatter": "monai_default",
"filters": ["rank_filter"],
},
"console": {
"class": "logging.StreamHandler",
"level": "INFO",
"formatter": "monai_default",
"filters": ["rank_filter"],
},
},
}
IGNORE_PROMPT = set(
[
2, # kidney
Expand Down Expand Up @@ -92,7 +62,7 @@ def infer_wrapper(inputs, model, **kwargs):


class InferClass:
def __init__(self, config_file='./configs/infer.yaml', **override):
def __init__(self, config_file="./configs/infer.yaml", **override):
logging.basicConfig(stream=sys.stdout, level=logging.INFO)

_args = _update_args(config_file=config_file, **override)
Expand Down Expand Up @@ -139,7 +109,7 @@ def __init__(self, config_file='./configs/infer.yaml', **override):
meta_key_postfix="meta_dict",
nearest_interp=True,
to_tensor=True,
)
),
]

# For Vista3d, sigmoid is always used, but for visualization, argmax is needed
Expand Down Expand Up @@ -184,16 +154,16 @@ def infer(
label_prompt=None,
prompt_class=None,
save_mask=False,
point_start=0
point_start=0,
):
"""Infer a single image_file. If save_mask is true, save the argmax prediction to disk. If false,
do not save and return the probability maps (usually used by autorunner emsembler). point_start is
do not save and return the probability maps (usually used by autorunner emsembler). point_start is
used together with prev_mask. If prev_mask is generated by N points, point_start should be N+1 to save
time and avoid repeated inference. This is by default disabled.
"""
self.model.eval()
if type(image_file) is not dict:
image_file = {'image': image_file}
if not isinstance(image_file, dict):
image_file = {"image": image_file}
if self.batch_data is not None:
batch_data = self.batch_data
else:
Expand Down Expand Up @@ -279,15 +249,15 @@ def infer(
if not finished:
raise RuntimeError("Infer not finished due to OOM.")
return batch_data[0]["pred"]

@torch.no_grad()
def infer_everything(self, image_file, label_prompt=EVERYTHING_PROMPT, rank=0):
self.model.eval()
device = f"cuda:{rank}"
if type(image_file) is not dict:
image_file = {'image': image_file}
if not isinstance(image_file, dict):
image_file = {"image": image_file}
batch_data = self.infer_transforms(image_file)
batch_data['label_prompt'] = label_prompt
batch_data["label_prompt"] = label_prompt
batch_data = list_data_collate([batch_data])
device_list_input = [device, device, "cpu"]
device_list_output = [device, "cpu", "cpu"]
Expand Down
2 changes: 1 addition & 1 deletion scripts/train_finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from monai.data import DataLoader, DistributedSampler
from monai.metrics import compute_dice
from monai.networks.utils import copy_model_state
from monai.utils import set_determinism, RankFilter
from monai.utils import set_determinism
from torch.nn.parallel import DistributedDataParallel
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
Expand Down
26 changes: 14 additions & 12 deletions scripts/utils/trans_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,9 +311,9 @@ def get_largest_connected_component_mask(

class VistaPostTransform(MapTransform):
def __init__(
self,
keys: KeysCollection,
allow_missing_keys: bool = False,
self,
keys: KeysCollection,
allow_missing_keys: bool = False,
) -> None:
"""
Args:
Expand All @@ -326,29 +326,31 @@ def __init__(
"""
super().__init__(keys, allow_missing_keys)

def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:
def __call__(
self, data: Mapping[Hashable, NdarrayOrTensor]
) -> dict[Hashable, NdarrayOrTensor]:
for keys in self.keys:
if keys in data:
pred = data[keys]
object_num = pred.shape[0]
device = pred.device
# device = pred.device
pred[pred < 0] = 0.0
# if it's multichannel, perform argmax
if object_num > 1:
# concate background channel. Make sure user did not provide 0 as prompt.
is_bk = torch.all(pred<=0, dim=0, keepdim=True)
is_bk = torch.all(pred <= 0, dim=0, keepdim=True)
pred = pred.argmax(0).unsqueeze(0).float() + 1.0
pred[is_bk] = 0.0
else:
# AsDiscrete will remove NaN
# pred = monai.transforms.AsDiscrete(threshold=0.5)(pred)
pred[pred > 0] = 1.0
if "label_prompt" in data and data['label_prompt'] is not None:
pred += 0.5 # inplace mapping to avoid cloning pred
for i in range(1, object_num + 1):
frac = i + 0.5
pred[pred == frac] = data['label_prompt'][i-1].to(pred.dtype)
pred[pred == 0.5] = 0.0
if "label_prompt" in data and data["label_prompt"] is not None:
pred += 0.5 # inplace mapping to avoid cloning pred
for i in range(1, object_num + 1):
frac = i + 0.5
pred[pred == frac] = data["label_prompt"][i - 1].to(pred.dtype)
pred[pred == 0.5] = 0.0
data[keys] = pred
return data

Expand Down
6 changes: 0 additions & 6 deletions tests/test_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,6 @@ def test_vista3d_logger(self):
logging.config.dictConfig(CONFIG)
logger.warning("check train logging format")

def test_vista3d_logger_infer(self):
from scripts.infer import CONFIG

logging.config.dictConfig(CONFIG)
logger.warning("check infer logging format")


if __name__ == "__main__":
unittest.main()