From 43a84e7c583b29491f71fd144c48945c6d07ea13 Mon Sep 17 00:00:00 2001
From: am
Date: Fri, 23 Aug 2024 13:21:41 -0700
Subject: [PATCH] vista2d net from monai component

Signed-off-by: am
---
 vista2d/configs/hyper_parameters.yaml |  2 +-
 vista2d/configs/inference.json        |  2 +-
 vista2d/scripts/cell_sam_wrapper.py   | 80 ---------------------------
 3 files changed, 2 insertions(+), 82 deletions(-)
 delete mode 100644 vista2d/scripts/cell_sam_wrapper.py

diff --git a/vista2d/configs/hyper_parameters.yaml b/vista2d/configs/hyper_parameters.yaml
index 9f6e8b7..2855457 100644
--- a/vista2d/configs/hyper_parameters.yaml
+++ b/vista2d/configs/hyper_parameters.yaml
@@ -83,7 +83,7 @@ infer:
 device: "$torch.device(('cuda:' + os.environ.get('LOCAL_RANK', '0')) if torch.cuda.is_available() else 'cpu')"
 
 network_def:
-  _target_: scripts.cell_sam_wrapper.CellSamWrapper
+  _target_: monai.networks.nets.cell_sam_wrapper.CellSamWrapper
   checkpoint: $os.path.join(@ckpt_path, "sam_vit_b_01ec64.pth")
 
 network: $@network_def.to(@device)
diff --git a/vista2d/configs/inference.json b/vista2d/configs/inference.json
index cf5bfa8..2439a4a 100644
--- a/vista2d/configs/inference.json
+++ b/vista2d/configs/inference.json
@@ -20,7 +20,7 @@
     "use_amp": true,
     "amp_dtype": "$torch.float",
     "network_def": {
-        "_target_": "scripts.cell_sam_wrapper.CellSamWrapper",
+        "_target_": "monai.networks.nets.cell_sam_wrapper.CellSamWrapper",
         "checkpoint": "@sam_ckpt_path"
     },
     "network": "$@network_def.to(@device)",
diff --git a/vista2d/scripts/cell_sam_wrapper.py b/vista2d/scripts/cell_sam_wrapper.py
deleted file mode 100644
index 8cbd8af..0000000
--- a/vista2d/scripts/cell_sam_wrapper.py
+++ /dev/null
@@ -1,80 +0,0 @@
-# Copyright (c) MONAI Consortium
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#     http://www.apache.org/licenses/LICENSE-2.0
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import torch
-from segment_anything.build_sam import build_sam_vit_b
-from torch import nn
-from torch.nn import functional as F
-
-
-class CellSamWrapper(torch.nn.Module):
-    def __init__(
-        self,
-        auto_resize_inputs=True,
-        network_resize_roi=[1024, 1024],
-        checkpoint="sam_vit_b_01ec64.pth",
-        return_features=False,
-        *args,
-        **kwargs,
-    ) -> None:
-        super().__init__(*args, **kwargs)
-
-        print(
-            f"CellSamWrapper auto_resize_inputs {auto_resize_inputs} network_resize_roi {network_resize_roi} checkpoint {checkpoint}"
-        )
-        self.network_resize_roi = network_resize_roi
-        self.auto_resize_inputs = auto_resize_inputs
-        self.return_features = return_features
-
-        model = build_sam_vit_b(checkpoint=checkpoint)
-
-        model.prompt_encoder = None
-        model.mask_decoder = None
-
-        model.mask_decoder = nn.Sequential(
-            nn.BatchNorm2d(num_features=256),
-            nn.ReLU(inplace=True),
-            nn.ConvTranspose2d(
-                256,
-                128,
-                kernel_size=3,
-                stride=2,
-                padding=1,
-                output_padding=1,
-                bias=False,
-            ),
-            nn.BatchNorm2d(num_features=128),
-            nn.ReLU(inplace=True),
-            nn.ConvTranspose2d(
-                128, 3, kernel_size=3, stride=2, padding=1, output_padding=1, bias=True
-            ),
-        )
-
-        self.model = model
-
-    def forward(self, x):
-        # print("CellSamWrapper x0", x.shape)
-        sh = x.shape[2:]
-
-        if self.auto_resize_inputs:
-            x = F.interpolate(x, size=self.network_resize_roi, mode="bilinear")
-
-        # print("CellSamWrapper x1", x.shape)
-        x = self.model.image_encoder(x)  # shape: (1, 256, 64, 64)
-        # print("CellSamWrapper image_embeddings", x.shape)
-
-        if not self.return_features:
-            x = self.model.mask_decoder(x)
-            if self.auto_resize_inputs:
-                x = F.interpolate(x, size=sh, mode="bilinear")
-
-        # print("CellSamWrapper x final", x.shape)
-        return x
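Note (not part of the patch): after this change the configs resolve the wrapper from MONAI instead of the bundle's local scripts module. Below is a minimal usage sketch, assuming MONAI >= 1.4 exposes monai.networks.nets.cell_sam_wrapper.CellSamWrapper with the same constructor as the removed file, that segment-anything is installed, and that a local SAM ViT-B checkpoint exists; the checkpoint path is illustrative only.

import torch
from monai.networks.nets.cell_sam_wrapper import CellSamWrapper

# Hypothetical local checkpoint path; the bundle configs build it from @ckpt_path.
ckpt = "models/sam_vit_b_01ec64.pth"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Mirrors the configs' "_target_" + "checkpoint" instantiation followed by "$@network_def.to(@device)".
network = CellSamWrapper(checkpoint=ckpt).to(device)
network.eval()

with torch.no_grad():
    x = torch.randn(1, 3, 256, 256, device=device)  # dummy 3-channel cell image
    y = network(x)  # inputs are resized to 1024x1024 internally, output resized back to the input size
print(y.shape)  # expected: torch.Size([1, 3, 256, 256])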