Add Hugging Face integration #45

Open · wants to merge 1 commit into base: main
2 changes: 1 addition & 1 deletion efficient_sam/build_efficient_sam.py
@@ -4,7 +4,7 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from .efficient_sam import build_efficient_sam
from efficient_sam import build_efficient_sam

def build_efficient_sam_vitt():
    return build_efficient_sam(
107 changes: 104 additions & 3 deletions efficient_sam/efficient_sam.py
@@ -12,9 +12,11 @@

from torch import nn, Tensor

from .efficient_sam_decoder import MaskDecoder, PromptEncoder
from .efficient_sam_encoder import ImageEncoderViT
from .two_way_transformer import TwoWayAttentionBlock, TwoWayTransformer
from efficient_sam_decoder import MaskDecoder, PromptEncoder
from efficient_sam_encoder import ImageEncoderViT
from two_way_transformer import TwoWayAttentionBlock, TwoWayTransformer

from huggingface_hub import PyTorchModelHubMixin, hf_hub_download

class EfficientSam(nn.Module):
    mask_threshold: float = 0.0
@@ -303,3 +305,102 @@ def build_efficient_sam(encoder_patch_embed_dim, encoder_num_heads, checkpoint=None):
            state_dict = torch.load(f, map_location="cpu")
        sam.load_state_dict(state_dict["model"])
    return sam


class EfficientSAM(EfficientSam, PyTorchModelHubMixin):
    def __init__(self, config):
        assert config["activation"] in ["relu", "gelu"]
        if config["activation"] == "relu":
            activation_fn = nn.ReLU
        else:
            activation_fn = nn.GELU

        image_encoder = ImageEncoderViT(
            img_size=config["img_size"],
            patch_size=config["encoder_patch_size"],
            in_chans=3,
            patch_embed_dim=config["encoder_patch_embed_dim"],
            normalization_type=config["normalization_type"],
            depth=config["encoder_depth"],
            num_heads=config["encoder_num_heads"],
            mlp_ratio=config["encoder_mlp_ratio"],
            neck_dims=config["encoder_neck_dims"],
            act_layer=activation_fn,
        )

        image_embedding_size = image_encoder.image_embedding_size
        encoder_transformer_output_dim = image_encoder.transformer_output_dim
        prompt_encoder = PromptEncoder(
            embed_dim=encoder_transformer_output_dim,
            image_embedding_size=(image_embedding_size, image_embedding_size),
            input_image_size=(config["img_size"], config["img_size"]),
        )

        decoder_max_num_input_points = config["decoder_max_num_input_points"]

        mask_decoder = MaskDecoder(
            transformer_dim=encoder_transformer_output_dim,
            transformer=TwoWayTransformer(
                depth=config["decoder_transformer_depth"],
                embedding_dim=encoder_transformer_output_dim,
                num_heads=config["decoder_num_heads"],
                mlp_dim=config["decoder_transformer_mlp_dim"],
                activation=activation_fn,
                normalize_before_activation=config["normalize_before_activation"],
            ),
            num_multimask_outputs=config["num_multimask_outputs"],
            activation=activation_fn,
            normalization_type=config["normalization_type"],
            normalize_before_activation=config["normalize_before_activation"],
            iou_head_depth=config["iou_head_depth"] - 1,
            iou_head_hidden_dim=config["iou_head_hidden_dim"],
            upscaling_layer_dims=config["decoder_upscaling_layer_dims"],
        )

        super().__init__(
            image_encoder=image_encoder,
            prompt_encoder=prompt_encoder,
            mask_decoder=mask_decoder,
            decoder_max_num_input_points=decoder_max_num_input_points,
            pixel_mean=config["pixel_mean"],
            pixel_std=config["pixel_std"],
        )
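
# Note: PyTorchModelHubMixin is what provides the save_pretrained /
# push_to_hub / from_pretrained methods used below. As assumed here, the dict
# passed as `config` is serialized to config.json alongside the weights, and
# from_pretrained re-instantiates the model by calling EfficientSAM(config)
# before loading the state dict.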


config = dict(
    img_size=1024,
    encoder_patch_embed_dim=192,
    encoder_num_heads=3,
    encoder_patch_size=16,
    encoder_depth=12,
    encoder_mlp_ratio=4.0,
    encoder_neck_dims=[256, 256],
    decoder_max_num_input_points=6,
    decoder_transformer_depth=2,
    decoder_transformer_mlp_dim=2048,
    decoder_num_heads=8,
    decoder_upscaling_layer_dims=[64, 32],
    num_multimask_outputs=3,
    iou_head_depth=3,
    iou_head_hidden_dim=256,
    activation="gelu",
    normalization_type="layer_norm",
    normalize_before_activation=False,
    pixel_mean=[0.485, 0.456, 0.406],
    pixel_std=[0.229, 0.224, 0.225],
)
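
# These hyperparameters appear to mirror build_efficient_sam_vitt()
# (the ViT-tiny variant: encoder_patch_embed_dim=192, encoder_num_heads=3),
# matching the efficient_sam_vitt.pt checkpoint loaded below.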

model = EfficientSAM(config)

# load the pretrained weights from the Hub
filepath = hf_hub_download("merve/EfficientSAM", filename="efficient_sam_vitt.pt", repo_type="model")
state_dict = torch.load(filepath, map_location="cpu")

# inspect the checkpoint contents
for name, param in state_dict["model"].items():
    print(name, param.shape)

model.load_state_dict(state_dict["model"])

# save locally
model.save_pretrained("efficient_sam")

# push to HF hub
model.push_to_hub("nielsr/efficientsam-tiny", config=config)

# reload
model = EfficientSAM.from_pretrained("nielsr/efficientsam-tiny")
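
# Example (not part of this PR): a minimal point-prompt inference sketch with
# the reloaded model, assuming the
# EfficientSam.forward(batched_images, batched_points, batched_point_labels)
# signature from this repo; the image tensor and prompt below are placeholders.
batched_images = torch.zeros(1, 3, 1024, 1024)      # [B, 3, H, W], RGB in [0, 1]
batched_points = torch.tensor([[[[580, 350]]]])     # [B, num_queries, num_points, 2] as (x, y)
batched_point_labels = torch.tensor([[[1]]])        # 1 = foreground point

model.eval()
with torch.no_grad():
    predicted_logits, predicted_iou = model(batched_images, batched_points, batched_point_labels)
print(predicted_logits.shape, predicted_iou.shape)  # mask logits and their predicted IoU scores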
2 changes: 1 addition & 1 deletion efficient_sam/efficient_sam_decoder.py
@@ -11,7 +11,7 @@
import torch.nn as nn
import torch.nn.functional as F

from .mlp import MLPBlock
from mlp import MLPBlock


class PromptEncoder(nn.Module):
2 changes: 1 addition & 1 deletion efficient_sam/two_way_transformer.py
@@ -2,7 +2,7 @@
from typing import Tuple, Type
import torch
from torch import nn, Tensor
from .mlp import MLPBlock
from mlp import MLPBlock


