From 8817cbf444278becc15a9e93d0bc2b85d7b10404 Mon Sep 17 00:00:00 2001 From: heyufan1995 Date: Tue, 16 Jul 2024 08:19:46 -0400 Subject: [PATCH 1/2] Add point postprocessing Signed-off-by: heyufan1995 --- scripts/infer.py | 7 +++++-- scripts/utils/trans_utils.py | 2 +- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/scripts/infer.py b/scripts/infer.py index 9b2be48..246e623 100644 --- a/scripts/infer.py +++ b/scripts/infer.py @@ -30,7 +30,7 @@ from .sliding_window import point_based_window_inferer, sliding_window_inference from .train import CONFIG -from .utils.trans_utils import VistaPostTransform +from .utils.trans_utils import VistaPostTransform, get_largest_connected_component_point rearrange, _ = optional_import("einops", name="rearrange") sys.path.insert(0, os.path.abspath(os.path.dirname(__file__))) @@ -168,7 +168,8 @@ def infer( batch_data = self.batch_data else: batch_data = self.infer_transforms(image_file) - batch_data["label_prompt"] = label_prompt + if label_prompt is not None: + batch_data["label_prompt"] = label_prompt batch_data = list_data_collate([batch_data]) self.batch_data = batch_data if point is not None: @@ -231,6 +232,8 @@ def infer( meta=batch_data["image"].meta, ) self.prev_mask = batch_data["pred"] + if label_prompt is None and point is not None: + batch_data['pred'] = get_largest_connected_component_point(batch_data['pred'],point_coords=point, point_labels=point_label) batch_data["image"] = batch_data["image"].to("cpu") batch_data["pred"] = batch_data["pred"].to("cpu") torch.cuda.empty_cache() diff --git a/scripts/utils/trans_utils.py b/scripts/utils/trans_utils.py index c446bad..ec3ac4d 100644 --- a/scripts/utils/trans_utils.py +++ b/scripts/utils/trans_utils.py @@ -195,7 +195,7 @@ def dilate3d(input_tensor, erosion=3): def get_largest_connected_component_point( - img: NdarrayTensor, point_coords=None, point_labels=None, post_idx=3 + img: NdarrayTensor, point_coords=None, point_labels=None ) -> NdarrayTensor: """ Gets the largest connected component mask of an image. img is before post process! And will include NaN values. From a6d8bbfdd36c04de0f98e8cea87c04e83604ed0d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 16 Jul 2024 12:21:59 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- scripts/infer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/scripts/infer.py b/scripts/infer.py index 246e623..f9102e8 100644 --- a/scripts/infer.py +++ b/scripts/infer.py @@ -233,7 +233,9 @@ def infer( ) self.prev_mask = batch_data["pred"] if label_prompt is None and point is not None: - batch_data['pred'] = get_largest_connected_component_point(batch_data['pred'],point_coords=point, point_labels=point_label) + batch_data["pred"] = get_largest_connected_component_point( + batch_data["pred"], point_coords=point, point_labels=point_label + ) batch_data["image"] = batch_data["image"].to("cpu") batch_data["pred"] = batch_data["pred"].to("cpu") torch.cuda.empty_cache()