
Commit

encoder and decoder now support different model types
add two more checkpoints (vit_l and vit_b)
coolzhao committed Jun 24, 2023
1 parent d7e2137 commit 989ba10
Showing 6 changed files with 94 additions and 82 deletions.
Binary file added checkpoint/sam_vit_b_01ec64_no_img_encoder.pth
Binary file not shown.
Binary file added checkpoint/sam_vit_l_0b3195_no_img_encoder.pth
Binary file not shown.
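
The two new checkpoints ship without image-encoder weights. As a hedged sketch (not part of this commit), such a `*_no_img_encoder.pth` file can be derived from an official SAM checkpoint by dropping every key under the `image_encoder.` prefix, which is how segment_anything's `Sam` module names them:

```python
# Assumption: official SAM checkpoints are flat state dicts whose
# image-encoder weights all live under the "image_encoder." prefix.
import torch

state = torch.load("sam_vit_b_01ec64.pth", map_location="cpu")
slim = {k: v for k, v in state.items() if not k.startswith("image_encoder.")}
torch.save(slim, "checkpoint/sam_vit_b_01ec64_no_img_encoder.pth")
```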
24 changes: 17 additions & 7 deletions tools/SAMTool.py
@@ -11,7 +11,7 @@
from qgis.core import QgsRectangle, QgsMessageLog, Qgis
from torch.utils.data import DataLoader
from .torchgeo_sam import SamTestFeatureDataset, SamTestFeatureGeoSampler
from .sam_ext import sam_model_registry_no_encoder, SamPredictorNoImgEncoder
from .sam_ext import build_sam_no_encoder, SamPredictorNoImgEncoder
from .geoTool import LayerExtent, ImageCRSManager
from .canvasTool import SAM_PolygonFeature, Canvas_Rectangle, Canvas_Points
from torchgeo.datasets import BoundingBox, stack_samples
@@ -21,19 +21,29 @@
class SAM_Model:
def __init__(self, feature_dir, cwd, model_type="vit_h"):
self.feature_dir = feature_dir
self.sam_checkpoint = cwd + "/checkpoint/sam_vit_h_4b8939_no_img_encoder.pth"
self.model_type = model_type
self.sam_checkpoint = {
"vit_h": cwd + "/checkpoint/sam_vit_h_4b8939_no_img_encoder.pth", # vit huge model
"vit_l": cwd + "/checkpoint/sam_vit_l_0b3195_no_img_encoder.pth", # vit large model
"vit_b": cwd + "/checkpoint/sam_vit_b_01ec64_no_img_encoder.pth", # vit base model
}
self.model_type = None
self.img_crs = None
self.extent = None
self.sample_path = None # necessary
self._prepare_data_and_layer()
self.sample_path = None

def _prepare_data_and_layer(self):
"""Prepares data and layer."""
self.test_features = SamTestFeatureDataset(
root=self.feature_dir, bands=None, cache=False) # display(test_imgs.index) #
root=self.feature_dir, bands=None, cache=False)
self.img_crs = str(self.test_features.crs)
# Load sam decoder
sam = sam_model_registry_no_encoder[self.model_type](
checkpoint=self.sam_checkpoint)
self.model_type = self.test_features.model_type
if self.model_type is None:
raise Exception("No sam model type info. found in feature files")

sam = build_sam_no_encoder(
checkpoint=self.sam_checkpoint[self.model_type])
self.predictor = SamPredictorNoImgEncoder(sam)

feature_bounds = self.test_features.index.bounds
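Net effect in `SAMTool.py`: the checkpoint is no longer hard-coded to ViT-H, and the `model_type` argument is no longer trusted; the type is read back from the encoded feature files and used to index the checkpoint dict. A minimal usage sketch (paths are illustrative):

```python
# Hypothetical call site: feature_dir holds features written by the
# encoder step; model_type is recovered from those files, not passed in.
model = SAM_Model(feature_dir="features/my_raster", cwd="/path/to/Geo-SAM")
print(model.model_type)  # e.g. "vit_l"; an Exception is raised if unknown
```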
41 changes: 10 additions & 31 deletions tools/sam_ext.py
@@ -1,4 +1,4 @@
## Modified from sam.build_sam.py
# Modified from sam.build_sam.py

from typing import Tuple, Type
import torch
@@ -9,38 +9,17 @@
from segment_anything.utils.transforms import ResizeLongestSide
import torch.nn as nn

def build_sam_vit_h_no_encoder(checkpoint=None):
return _build_sam_no_encoder(
encoder_embed_dim=1280,
encoder_depth=32,
encoder_num_heads=16,
encoder_global_attn_indexes=[7, 15, 23, 31],
checkpoint=checkpoint,
)

# build_sam = build_sam_vit_h

sam_model_registry_no_encoder = {
"default": build_sam_vit_h_no_encoder,
"vit_h": build_sam_vit_h_no_encoder,
}

class FakeImageEncoderViT(nn.Module):
def __init__(self, img_size: int = 1024) -> None:
super().__init__()
self.img_size = img_size

def forward(self, x: torch.Tensor) -> torch.Tensor:
return x


def _build_sam_no_encoder(
encoder_embed_dim,
encoder_depth,
encoder_num_heads,
encoder_global_attn_indexes,
checkpoint=None,
):
def build_sam_no_encoder(checkpoint=None):
prompt_embed_dim = 256
image_size = 1024
vit_patch_size = 16
@@ -75,20 +54,20 @@ def _build_sam_no_encoder(
sam.load_state_dict(state_dict)
return sam
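
Collapsing the per-variant registry into one `build_sam_no_encoder` is sound because the ViT variants differ only in their image encoders; the prompt encoder and mask decoder share `prompt_embed_dim=256` and a 64x64 embedding grid. A hedged sketch of what the elided `Sam(...)` construction plausibly looks like, following segment_anything's `build_sam.py` with the ViT replaced by `FakeImageEncoderViT`:

```python
# Sketch under stated assumptions; dims follow segment_anything's build_sam.py.
from segment_anything.modeling import (MaskDecoder, PromptEncoder, Sam,
                                       TwoWayTransformer)

prompt_embed_dim = 256
image_size = 1024
vit_patch_size = 16
image_embedding_size = image_size // vit_patch_size  # 1024 // 16 = 64

sam = Sam(
    image_encoder=FakeImageEncoderViT(img_size=image_size),  # pass-through
    prompt_encoder=PromptEncoder(
        embed_dim=prompt_embed_dim,
        image_embedding_size=(image_embedding_size, image_embedding_size),
        input_image_size=(image_size, image_size),
        mask_in_chans=16,
    ),
    mask_decoder=MaskDecoder(
        num_multimask_outputs=3,
        transformer=TwoWayTransformer(
            depth=2, embedding_dim=prompt_embed_dim, mlp_dim=2048, num_heads=8),
        transformer_dim=prompt_embed_dim,
    ),
    pixel_mean=[123.675, 116.28, 103.53],
    pixel_std=[58.395, 57.12, 57.375],
)
```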


class SamPredictorNoImgEncoder(SamPredictor):
def __init__(
self,
sam_model: Sam,
# image_encoder_img_size: int = 1024
) -> None:
# super(SamPredictor, self).__init__()
) -> None:
self.model = sam_model
# self.transform = ResizeLongestSide(image_encoder_img_size)
self.transform = ResizeLongestSide(sam_model.image_encoder.img_size)
self.reset_image()

def set_image_feature(self, img_features: np.ndarray, img_shape: Tuple[int, int]):
self.features = torch.as_tensor(img_features, device=self.device) # .to(device=device)
self.features = torch.as_tensor(
img_features, device=self.device) # .to(device=device)
self.original_size = img_shape
self.input_size = self.transform.get_preprocess_shape(img_shape[0], img_shape[1], self.model.image_encoder.img_size)
self.is_image_set = True
self.input_size = self.transform.get_preprocess_shape(
img_shape[0], img_shape[1], self.model.image_encoder.img_size)
self.is_image_set = True
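
For completeness, a sketch of driving the predictor with features loaded back from disk. The rasterio round-trip and the added batch dimension are assumptions for illustration; `predict` itself keeps the upstream `SamPredictor` signature:

```python
import numpy as np
import rasterio

sam = build_sam_no_encoder(
    checkpoint="checkpoint/sam_vit_b_01ec64_no_img_encoder.pth")
predictor = SamPredictorNoImgEncoder(sam)

# The encoder step saves features as 256-band GeoTIFF patches.
with rasterio.open("sam_feat_vit_b_<bbox_hash>.tif") as src:
    feats = src.read()[None, ...]  # (1, 256, 64, 64), float32

predictor.set_image_feature(feats, img_shape=(1024, 1024))
masks, scores, _ = predictor.predict(
    point_coords=np.array([[512, 512]]),
    point_labels=np.array([1]),
    multimask_output=False,
)
```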
55 changes: 35 additions & 20 deletions tools/sam_processing_algorithm.py
@@ -110,7 +110,7 @@ def initAlgorithm(self, config=None):
)
)

self.model_type_options = ['vit_h', 'vit_l', 'vit_s']
self.model_type_options = ['vit_h', 'vit_l', 'vit_b']
self.addParameter(
QgsProcessingParameterEnum(
name=self.MODEL_TYPE,
@@ -164,21 +164,21 @@ def processAlgorithm(self, parameters, context, feedback):
stride = self.parameterAsInt(
parameters, self.STRIDE, context)

bbox = self.parameterAsExtent(parameters, self.EXTENT, context)
extent = self.parameterAsExtent(parameters, self.EXTENT, context)
# if bbox.isNull() and not rlayer:
# raise QgsProcessingException(
# self.tr("No reference layer selected nor extent box provided"))

if not bbox.isNull():
bboxCrs = self.parameterAsExtentCrs(
if not extent.isNull():
extentCrs = self.parameterAsExtentCrs(
parameters, self.EXTENT, context)
if bboxCrs != rlayer.crs():
if extentCrs != rlayer.crs():
transform = QgsCoordinateTransform(
bboxCrs, rlayer.crs(), context.transformContext())
bbox = transform.transformBoundingBox(bbox)
extentCrs, rlayer.crs(), context.transformContext())
extent = transform.transformBoundingBox(extent)

if bbox.isNull() and rlayer:
bbox = rlayer.extent() # QgsProcessingUtils.combineLayerExtents(layers, crs, context)
if extent.isNull() and rlayer:
extent = rlayer.extent() # QgsProcessingUtils.combineLayerExtents(layers, crs, context)

# output_dir = self.parameterAsFileOutput(
# parameters, self.OUTPUT, context)
@@ -197,7 +197,7 @@ def processAlgorithm(self, parameters, context, feedback):
# feedback.pushInfo('Layer display band name is {}'.format(
# rlayer.dataProvider().displayBandName(1)))
feedback.pushInfo(
f'Layer extent: minx:{bbox.xMinimum():.2f}, maxx:{bbox.xMaximum():.2f}, miny:{bbox.yMinimum():.2f}, maxy:{bbox.yMaximum():.2f}')
f'Layer extent: minx:{extent.xMinimum():.2f}, maxx:{extent.xMaximum():.2f}, miny:{extent.yMinimum():.2f}, maxy:{extent.yMaximum():.2f}')

# If sink was not created, throw an exception to indicate that the algorithm
# encountered a fatal error. The exception text can be any string, but in this
@@ -264,13 +264,16 @@ def processAlgorithm(self, parameters, context, feedback):
f'RasterDS info, input bands: {rlayer_ds.bands}, \n all bands: {rlayer_ds.all_bands}, \
\n raster_ds crs: {rlayer_ds.crs}, \
\n raster_ds index: {rlayer_ds.index}')
roi = BoundingBox(minx=bbox.xMinimum(), maxx=bbox.xMaximum(), miny=bbox.yMinimum(), maxy=bbox.yMaximum(),
mint=rlayer_ds.index.bounds[4], maxt=rlayer_ds.index.bounds[5])
extent_bbox = BoundingBox(minx=extent.xMinimum(), maxx=extent.xMaximum(), miny=extent.yMinimum(), maxy=extent.yMaximum(),
mint=rlayer_ds.index.bounds[4], maxt=rlayer_ds.index.bounds[5])
ds_sampler = TestGridGeoSampler(
rlayer_ds, size=1024, stride=stride, roi=roi, units=Units.PIXELS) # Units.CRS or Units.PIXELS
rlayer_ds, size=1024, stride=stride, roi=extent_bbox, units=Units.PIXELS) # Units.CRS or Units.PIXELS

if len(ds_sampler) == 0:
feedback.pushInfo(
f'No available patch sample inside the chosen extent')
return {'Input layer dir': rlayer_dir, 'Sample num': len(ds_sampler)}

feedback.pushInfo(f'Sample number: {len(ds_sampler)}')

ds_dataloader = DataLoader(
Expand All @@ -295,7 +298,8 @@ def processAlgorithm(self, parameters, context, feedback):
for size in list(features.shape)))
feedback.pushInfo(
f"SAM encoding executed with {elapsed_time:.3f} ms")
self.save_sam_feature(output_dir, batch, features, model_type)
self.save_sam_feature(
output_dir, batch, features, extent_bbox, model_type)

# Update the progress bar
feedback.setProgress(int((current+1) * total))
@@ -320,7 +324,14 @@ def get_sam_feature(self, sam_model, batch_input):
features = sam_model.image_encoder(batch_input)
return features.cpu().numpy()

def save_sam_feature(self, export_dir_str: str, data_batch: Tensor, feature: np.ndarray, model_type: str = "vit_h"):
def save_sam_feature(
self,
export_dir_str: str,
data_batch: Tensor,
feature: np.ndarray,
extent: BoundingBox,
model_type: str = "vit_h"
):
export_dir = Path(export_dir_str)
# iterate over batch_size dimension
for idx in range(feature.shape[-4]):
@@ -333,25 +344,29 @@ def save_sam_feature(self, export_dir_str: str, data_batch: Tensor, feature: np.
filepath = Path(data_batch['path'][idx])
bbox = [bbox.minx, bbox.miny, bbox.maxx, bbox.maxy]
bbox_str = '_'.join(map("{:.6f}".format, bbox))
extent = [extent.minx, extent.miny, extent.maxx, extent.maxy]
extent_str = '_'.join(map("{:.6f}".format, extent))
# bbox_hash = hashlib.md5()
# Unicode-objects must be encoded before hashing with hashlib and
# because strings in Python 3 are Unicode by default (unlike Python 2),
# you'll need to encode the string using the .encode method.
# bbox_hash.update(bbox_str.encode("utf-8"))
bbox_hash = hashlib.sha256(bbox_str.encode("utf-8")).hexdigest()
extent_hash = hashlib.sha256(
extent_str.encode("utf-8")).hexdigest()

export_dir_sub = export_dir / filepath.stem
export_dir_sub = (export_dir / filepath.stem /
f"sam_feat_{model_type}_{extent_hash}")
# display(export_dir_sub)
export_dir_sub.mkdir(parents=True, exist_ok=True)
feature_tiff = export_dir_sub / "sam_feat_{model}_{bbox}.tif".format(
model=model_type, bbox=bbox_hash)
feature_tiff = (export_dir_sub /
f"sam_feat_{model_type}_{bbox_hash}.tif")
# print(feature_tiff)
with rasterio.open(
feature_tiff,
mode="w",
driver="GTiff",
height=height,
width=width,
height=height, width=width,
count=band_num,
dtype='float32',
crs=data_batch['crs'][idx],
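The new layout nests a per-extent directory between the raster stem and the feature tiles, so exporting with a different extent or model type no longer collides with earlier runs. A sketch of the naming scheme (the hashing mirrors the diff; the helper name is ours):

```python
import hashlib
from pathlib import Path

def coords_hash(minx: float, miny: float, maxx: float, maxy: float) -> str:
    # sha256 over "_"-joined coordinates formatted to 6 decimals, as above.
    s = "_".join(map("{:.6f}".format, (minx, miny, maxx, maxy)))
    return hashlib.sha256(s.encode("utf-8")).hexdigest()

# <output_dir>/<raster_stem>/sam_feat_<model>_<extent_hash>/sam_feat_<model>_<bbox_hash>.tif
tile = (Path("output") / "my_raster"
        / f"sam_feat_vit_b_{coords_hash(0, 0, 2048, 2048)}"
        / f"sam_feat_vit_b_{coords_hash(0, 0, 1024, 1024)}.tif")
```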
56 changes: 32 additions & 24 deletions tools/torchgeo_sam.py
@@ -96,18 +96,25 @@ def __init__(self, root: str = "data",
self.root = root
self.bands = bands or self.all_bands
self.cache = cache
self.model_type = None
model_type_list = ["vit_h", "vit_l", "vit_b"]

# Populate the dataset index
i = 0
# pathname = os.path.join(root, "**", self.filename_glob)
pathname = os.path.join(root, self.filename_glob)
raster_list = glob.glob(pathname, recursive=True)
raster_name = os.path.basename(raster_list[0])
for m_type in model_type_list:
if m_type in raster_name:
self.model_type = m_type
break
dir_name = os.path.basename(root)
csv_filepath = os.path.join(root, dir_name + '.csv')
index_set = False
if os.path.exists(csv_filepath):
self.index_df = pd.read_csv(csv_filepath)
filepath_csv = self.index_df.loc[0, 'filepath']
# filepath_csv = self.index_df.loc[0, 'filepath']
# and os.path.dirname(filepath_csv) == os.path.dirname(raster_list[0]):
if len(self.index_df) == len(raster_list):
for _, row_df in self.index_df.iterrows():
@@ -180,29 +187,30 @@ def __init__(self, root: str = "data",
# change to relative path
filepath_list.append(os.path.basename(filepath))
i += 1
self.index_df['id'] = id_list
self.index_df['filepath'] = filepath_list
self.index_df['minx'] = pd.to_numeric(
[coord[0] for coord in coords_list], downcast='float')
self.index_df['maxx'] = pd.to_numeric(
[coord[1] for coord in coords_list], downcast='float')
self.index_df['miny'] = pd.to_numeric(
[coord[2] for coord in coords_list], downcast='float')
self.index_df['maxy'] = pd.to_numeric(
[coord[3] for coord in coords_list], downcast='float')
self.index_df['mint'] = pd.to_numeric(
[coord[4] for coord in coords_list], downcast='float')
self.index_df['maxt'] = pd.to_numeric(
[coord[5] for coord in coords_list], downcast='float')
# print(type(crs), res)
self.index_df.loc[:, 'crs'] = str(crs)
self.index_df.loc[:, 'res'] = res
# print(self.index_df.dtypes)
index_set = True
self.index_df.to_csv(csv_filepath)
# print('index file: ', os.path.basename(csv_filepath), ' saved')
QgsMessageLog.logMessage(
f"Index file: {os.path.basename(csv_filepath)} saved", 'Geo SAM', level=Qgis.Info)
if i>0:
self.index_df['id'] = id_list
self.index_df['filepath'] = filepath_list
self.index_df['minx'] = pd.to_numeric(
[coord[0] for coord in coords_list], downcast='float')
self.index_df['maxx'] = pd.to_numeric(
[coord[1] for coord in coords_list], downcast='float')
self.index_df['miny'] = pd.to_numeric(
[coord[2] for coord in coords_list], downcast='float')
self.index_df['maxy'] = pd.to_numeric(
[coord[3] for coord in coords_list], downcast='float')
self.index_df['mint'] = pd.to_numeric(
[coord[4] for coord in coords_list], downcast='float')
self.index_df['maxt'] = pd.to_numeric(
[coord[5] for coord in coords_list], downcast='float')
# print(type(crs), res)
self.index_df.loc[:, 'crs'] = str(crs)
self.index_df.loc[:, 'res'] = res
# print(self.index_df.dtypes)
index_set = True
self.index_df.to_csv(csv_filepath)
# print('index file: ', os.path.basename(csv_filepath), ' saved')
QgsMessageLog.logMessage(
f"Index file: {os.path.basename(csv_filepath)} saved", 'Geo SAM', level=Qgis.Info)

if i == 0:
msg = f"No {self.__class__.__name__} data was found in `root='{self.root}'`"
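The dataset now recovers the model type by substring-matching the first feature filename against the known variants (the filenames embed it via the `sam_feat_<model>_...` scheme above), and the index CSV is only written when at least one raster was found. A standalone sketch of the detection:

```python
def detect_model_type(raster_name: str):
    # Mirrors the loop added above; None means no known variant matched,
    # which SAM_Model turns into an Exception downstream.
    for m_type in ("vit_h", "vit_l", "vit_b"):
        if m_type in raster_name:
            return m_type
    return None

assert detect_model_type("sam_feat_vit_l_1a2b.tif") == "vit_l"
assert detect_model_type("plain.tif") is None
```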
