diff --git a/checkpoint/sam_vit_b_01ec64_no_img_encoder.pth b/checkpoint/sam_vit_b_01ec64_no_img_encoder.pth new file mode 100644 index 0000000..5ecac88 Binary files /dev/null and b/checkpoint/sam_vit_b_01ec64_no_img_encoder.pth differ diff --git a/checkpoint/sam_vit_l_0b3195_no_img_encoder.pth b/checkpoint/sam_vit_l_0b3195_no_img_encoder.pth new file mode 100644 index 0000000..9b73c4f Binary files /dev/null and b/checkpoint/sam_vit_l_0b3195_no_img_encoder.pth differ diff --git a/tools/SAMTool.py b/tools/SAMTool.py index da2b886..fe2b0b0 100644 --- a/tools/SAMTool.py +++ b/tools/SAMTool.py @@ -11,7 +11,7 @@ from qgis.core import QgsRectangle, QgsMessageLog, Qgis from torch.utils.data import DataLoader from .torchgeo_sam import SamTestFeatureDataset, SamTestFeatureGeoSampler -from .sam_ext import sam_model_registry_no_encoder, SamPredictorNoImgEncoder +from .sam_ext import build_sam_no_encoder, SamPredictorNoImgEncoder from .geoTool import LayerExtent, ImageCRSManager from .canvasTool import SAM_PolygonFeature, Canvas_Rectangle, Canvas_Points from torchgeo.datasets import BoundingBox, stack_samples @@ -21,19 +21,29 @@ class SAM_Model: def __init__(self, feature_dir, cwd, model_type="vit_h"): self.feature_dir = feature_dir - self.sam_checkpoint = cwd + "/checkpoint/sam_vit_h_4b8939_no_img_encoder.pth" - self.model_type = model_type + self.sam_checkpoint = { + "vit_h": cwd + "/checkpoint/sam_vit_h_4b8939_no_img_encoder.pth", # vit huge model + "vit_l": cwd + "/checkpoint/sam_vit_l_0b3195_no_img_encoder.pth", # vit large model + "vit_b": cwd + "/checkpoint/sam_vit_b_01ec64_no_img_encoder.pth", # vit base model + } + self.model_type = None + self.img_crs = None + self.extent = None + self.sample_path = None # necessary self._prepare_data_and_layer() - self.sample_path = None def _prepare_data_and_layer(self): """Prepares data and layer.""" self.test_features = SamTestFeatureDataset( - root=self.feature_dir, bands=None, cache=False) # display(test_imgs.index) # + root=self.feature_dir, bands=None, cache=False) self.img_crs = str(self.test_features.crs) # Load sam decoder - sam = sam_model_registry_no_encoder[self.model_type]( - checkpoint=self.sam_checkpoint) + self.model_type = self.test_features.model_type + if self.model_type is None: + raise Exception("No sam model type info. found in feature files") + + sam = build_sam_no_encoder( + checkpoint=self.sam_checkpoint[self.model_type]) self.predictor = SamPredictorNoImgEncoder(sam) feature_bounds = self.test_features.index.bounds diff --git a/tools/sam_ext.py b/tools/sam_ext.py index 8f90a6a..641caa7 100644 --- a/tools/sam_ext.py +++ b/tools/sam_ext.py @@ -1,4 +1,4 @@ -## Modified from sam.build_sam.py +# Modified from sam.build_sam.py from typing import Tuple, Type import torch @@ -9,38 +9,17 @@ from segment_anything.utils.transforms import ResizeLongestSide import torch.nn as nn -def build_sam_vit_h_no_encoder(checkpoint=None): - return _build_sam_no_encoder( - encoder_embed_dim=1280, - encoder_depth=32, - encoder_num_heads=16, - encoder_global_attn_indexes=[7, 15, 23, 31], - checkpoint=checkpoint, - ) - -# build_sam = build_sam_vit_h - -sam_model_registry_no_encoder = { - "default": build_sam_vit_h_no_encoder, - "vit_h": build_sam_vit_h_no_encoder, -} class FakeImageEncoderViT(nn.Module): def __init__(self, img_size: int = 1024) -> None: super().__init__() self.img_size = img_size - + def forward(self, x: torch.Tensor) -> torch.Tensor: return x -def _build_sam_no_encoder( - encoder_embed_dim, - encoder_depth, - encoder_num_heads, - encoder_global_attn_indexes, - checkpoint=None, - ): +def build_sam_no_encoder(checkpoint=None): prompt_embed_dim = 256 image_size = 1024 vit_patch_size = 16 @@ -75,20 +54,20 @@ def _build_sam_no_encoder( sam.load_state_dict(state_dict) return sam + class SamPredictorNoImgEncoder(SamPredictor): def __init__( self, sam_model: Sam, - # image_encoder_img_size: int = 1024 - ) -> None: - # super(SamPredictor, self).__init__() + ) -> None: self.model = sam_model - # self.transform = ResizeLongestSide(image_encoder_img_size) self.transform = ResizeLongestSide(sam_model.image_encoder.img_size) self.reset_image() def set_image_feature(self, img_features: np.ndarray, img_shape: Tuple[int, int]): - self.features = torch.as_tensor(img_features, device=self.device) # .to(device=device) + self.features = torch.as_tensor( + img_features, device=self.device) # .to(device=device) self.original_size = img_shape - self.input_size = self.transform.get_preprocess_shape(img_shape[0], img_shape[1], self.model.image_encoder.img_size) - self.is_image_set = True \ No newline at end of file + self.input_size = self.transform.get_preprocess_shape( + img_shape[0], img_shape[1], self.model.image_encoder.img_size) + self.is_image_set = True diff --git a/tools/sam_processing_algorithm.py b/tools/sam_processing_algorithm.py index 757730e..86d94d9 100644 --- a/tools/sam_processing_algorithm.py +++ b/tools/sam_processing_algorithm.py @@ -110,7 +110,7 @@ def initAlgorithm(self, config=None): ) ) - self.model_type_options = ['vit_h', 'vit_l', 'vit_s'] + self.model_type_options = ['vit_h', 'vit_l', 'vit_b'] self.addParameter( QgsProcessingParameterEnum( name=self.MODEL_TYPE, @@ -164,21 +164,21 @@ def processAlgorithm(self, parameters, context, feedback): stride = self.parameterAsInt( parameters, self.STRIDE, context) - bbox = self.parameterAsExtent(parameters, self.EXTENT, context) + extent = self.parameterAsExtent(parameters, self.EXTENT, context) # if bbox.isNull() and not rlayer: # raise QgsProcessingException( # self.tr("No reference layer selected nor extent box provided")) - if not bbox.isNull(): - bboxCrs = self.parameterAsExtentCrs( + if not extent.isNull(): + extentCrs = self.parameterAsExtentCrs( parameters, self.EXTENT, context) - if bboxCrs != rlayer.crs(): + if extentCrs != rlayer.crs(): transform = QgsCoordinateTransform( - bboxCrs, rlayer.crs(), context.transformContext()) - bbox = transform.transformBoundingBox(bbox) + extentCrs, rlayer.crs(), context.transformContext()) + extent = transform.transformBoundingBox(extent) - if bbox.isNull() and rlayer: - bbox = rlayer.extent() # QgsProcessingUtils.combineLayerExtents(layers, crs, context) + if extent.isNull() and rlayer: + extent = rlayer.extent() # QgsProcessingUtils.combineLayerExtents(layers, crs, context) # output_dir = self.parameterAsFileOutput( # parameters, self.OUTPUT, context) @@ -197,7 +197,7 @@ def processAlgorithm(self, parameters, context, feedback): # feedback.pushInfo('Layer display band name is {}'.format( # rlayer.dataProvider().displayBandName(1))) feedback.pushInfo( - f'Layer extent: minx:{bbox.xMinimum():.2f}, maxx:{bbox.xMaximum():.2f}, miny:{bbox.yMinimum():.2f}, maxy:{bbox.yMaximum():.2f}') + f'Layer extent: minx:{extent.xMinimum():.2f}, maxx:{extent.xMaximum():.2f}, miny:{extent.yMinimum():.2f}, maxy:{extent.yMaximum():.2f}') # If sink was not created, throw an exception to indicate that the algorithm # encountered a fatal error. The exception text can be any string, but in this @@ -264,13 +264,16 @@ def processAlgorithm(self, parameters, context, feedback): f'RasterDS info, input bands: {rlayer_ds.bands}, \n all bands: {rlayer_ds.all_bands}, \ \n raster_ds crs: {rlayer_ds.crs}, \ \n raster_ds index: {rlayer_ds.index}') - roi = BoundingBox(minx=bbox.xMinimum(), maxx=bbox.xMaximum(), miny=bbox.yMinimum(), maxy=bbox.yMaximum(), - mint=rlayer_ds.index.bounds[4], maxt=rlayer_ds.index.bounds[5]) + extent_bbox = BoundingBox(minx=extent.xMinimum(), maxx=extent.xMaximum(), miny=extent.yMinimum(), maxy=extent.yMaximum(), + mint=rlayer_ds.index.bounds[4], maxt=rlayer_ds.index.bounds[5]) ds_sampler = TestGridGeoSampler( - rlayer_ds, size=1024, stride=stride, roi=roi, units=Units.PIXELS) # Units.CRS or Units.PIXELS + rlayer_ds, size=1024, stride=stride, roi=extent_bbox, units=Units.PIXELS) # Units.CRS or Units.PIXELS if len(ds_sampler) == 0: + feedback.pushInfo( + f'No available patch sample inside the chosen extent') return {'Input layer dir': rlayer_dir, 'Sample num': len(ds_sampler)} + feedback.pushInfo(f'Sample number: {len(ds_sampler)}') ds_dataloader = DataLoader( @@ -295,7 +298,8 @@ def processAlgorithm(self, parameters, context, feedback): for size in list(features.shape))) feedback.pushInfo( f"SAM encoding executed with {elapsed_time:.3f} ms") - self.save_sam_feature(output_dir, batch, features, model_type) + self.save_sam_feature( + output_dir, batch, features, extent_bbox, model_type) # Update the progress bar feedback.setProgress(int((current+1) * total)) @@ -320,7 +324,14 @@ def get_sam_feature(self, sam_model, batch_input): features = sam_model.image_encoder(batch_input) return features.cpu().numpy() - def save_sam_feature(self, export_dir_str: str, data_batch: Tensor, feature: np.ndarray, model_type: str = "vit_h"): + def save_sam_feature( + self, + export_dir_str: str, + data_batch: Tensor, + feature: np.ndarray, + extent: BoundingBox, + model_type: str = "vit_h" + ): export_dir = Path(export_dir_str) # iterate over batch_size dimension for idx in range(feature.shape[-4]): @@ -333,25 +344,29 @@ def save_sam_feature(self, export_dir_str: str, data_batch: Tensor, feature: np. filepath = Path(data_batch['path'][idx]) bbox = [bbox.minx, bbox.miny, bbox.maxx, bbox.maxy] bbox_str = '_'.join(map("{:.6f}".format, bbox)) + extent = [extent.minx, extent.miny, extent.maxx, extent.maxy] + extent_str = '_'.join(map("{:.6f}".format, extent)) # bbox_hash = hashlib.md5() # Unicode-objects must be encoded before hashing with hashlib and # because strings in Python 3 are Unicode by default (unlike Python 2), # you'll need to encode the string using the .encode method. # bbox_hash.update(bbox_str.encode("utf-8")) bbox_hash = hashlib.sha256(bbox_str.encode("utf-8")).hexdigest() + extent_hash = hashlib.sha256( + extent_str.encode("utf-8")).hexdigest() - export_dir_sub = export_dir / filepath.stem + export_dir_sub = (export_dir / filepath.stem / + f"sam_feat_{model_type}_{extent_hash}") # display(export_dir_sub) export_dir_sub.mkdir(parents=True, exist_ok=True) - feature_tiff = export_dir_sub / "sam_feat_{model}_{bbox}.tif".format( - model=model_type, bbox=bbox_hash) + feature_tiff = (export_dir_sub / + f"sam_feat_{model_type}_{bbox_hash}.tif") # print(feature_tiff) with rasterio.open( feature_tiff, mode="w", driver="GTiff", - height=height, - width=width, + height=height, width=width, count=band_num, dtype='float32', crs=data_batch['crs'][idx], diff --git a/tools/torchgeo_sam.py b/tools/torchgeo_sam.py index 12df512..405847e 100644 --- a/tools/torchgeo_sam.py +++ b/tools/torchgeo_sam.py @@ -96,18 +96,25 @@ def __init__(self, root: str = "data", self.root = root self.bands = bands or self.all_bands self.cache = cache + self.model_type = None + model_type_list = ["vit_h", "vit_l", "vit_b"] # Populate the dataset index i = 0 # pathname = os.path.join(root, "**", self.filename_glob) pathname = os.path.join(root, self.filename_glob) raster_list = glob.glob(pathname, recursive=True) + raster_name = os.path.basename(raster_list[0]) + for m_type in model_type_list: + if m_type in raster_name: + self.model_type = m_type + break dir_name = os.path.basename(root) csv_filepath = os.path.join(root, dir_name + '.csv') index_set = False if os.path.exists(csv_filepath): self.index_df = pd.read_csv(csv_filepath) - filepath_csv = self.index_df.loc[0, 'filepath'] + # filepath_csv = self.index_df.loc[0, 'filepath'] # and os.path.dirname(filepath_csv) == os.path.dirname(raster_list[0]): if len(self.index_df) == len(raster_list): for _, row_df in self.index_df.iterrows(): @@ -180,29 +187,30 @@ def __init__(self, root: str = "data", # change to relative path filepath_list.append(os.path.basename(filepath)) i += 1 - self.index_df['id'] = id_list - self.index_df['filepath'] = filepath_list - self.index_df['minx'] = pd.to_numeric( - [coord[0] for coord in coords_list], downcast='float') - self.index_df['maxx'] = pd.to_numeric( - [coord[1] for coord in coords_list], downcast='float') - self.index_df['miny'] = pd.to_numeric( - [coord[2] for coord in coords_list], downcast='float') - self.index_df['maxy'] = pd.to_numeric( - [coord[3] for coord in coords_list], downcast='float') - self.index_df['mint'] = pd.to_numeric( - [coord[4] for coord in coords_list], downcast='float') - self.index_df['maxt'] = pd.to_numeric( - [coord[5] for coord in coords_list], downcast='float') - # print(type(crs), res) - self.index_df.loc[:, 'crs'] = str(crs) - self.index_df.loc[:, 'res'] = res - # print(self.index_df.dtypes) - index_set = True - self.index_df.to_csv(csv_filepath) - # print('index file: ', os.path.basename(csv_filepath), ' saved') - QgsMessageLog.logMessage( - f"Index file: {os.path.basename(csv_filepath)} saved", 'Geo SAM', level=Qgis.Info) + if i>0: + self.index_df['id'] = id_list + self.index_df['filepath'] = filepath_list + self.index_df['minx'] = pd.to_numeric( + [coord[0] for coord in coords_list], downcast='float') + self.index_df['maxx'] = pd.to_numeric( + [coord[1] for coord in coords_list], downcast='float') + self.index_df['miny'] = pd.to_numeric( + [coord[2] for coord in coords_list], downcast='float') + self.index_df['maxy'] = pd.to_numeric( + [coord[3] for coord in coords_list], downcast='float') + self.index_df['mint'] = pd.to_numeric( + [coord[4] for coord in coords_list], downcast='float') + self.index_df['maxt'] = pd.to_numeric( + [coord[5] for coord in coords_list], downcast='float') + # print(type(crs), res) + self.index_df.loc[:, 'crs'] = str(crs) + self.index_df.loc[:, 'res'] = res + # print(self.index_df.dtypes) + index_set = True + self.index_df.to_csv(csv_filepath) + # print('index file: ', os.path.basename(csv_filepath), ' saved') + QgsMessageLog.logMessage( + f"Index file: {os.path.basename(csv_filepath)} saved", 'Geo SAM', level=Qgis.Info) if i == 0: msg = f"No {self.__class__.__name__} data was found in `root='{self.root}'`"