Add a `device` keyword argument to the SAM builder functions so the model can
be constructed directly on a target device (default "cpu", backward
compatible). Review pass: added lines now follow the file's conventions
(double-quoted strings, trailing commas in multi-line argument lists); hunk
line counts are unchanged, so the patch still applies to index 37cd24512.

diff --git a/segment_anything/build_sam.py b/segment_anything/build_sam.py
index 37cd24512..85c0472c0 100644
--- a/segment_anything/build_sam.py
+++ b/segment_anything/build_sam.py
@@ -11,36 +11,39 @@
 from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer
 
 
-def build_sam_vit_h(checkpoint=None):
+def build_sam_vit_h(checkpoint=None, device="cpu"):
     return _build_sam(
         encoder_embed_dim=1280,
         encoder_depth=32,
         encoder_num_heads=16,
         encoder_global_attn_indexes=[7, 15, 23, 31],
         checkpoint=checkpoint,
+        device=device,
     )
 
 
 build_sam = build_sam_vit_h
 
 
-def build_sam_vit_l(checkpoint=None):
+def build_sam_vit_l(checkpoint=None, device="cpu"):
     return _build_sam(
         encoder_embed_dim=1024,
         encoder_depth=24,
         encoder_num_heads=16,
         encoder_global_attn_indexes=[5, 11, 17, 23],
         checkpoint=checkpoint,
+        device=device,
     )
 
 
-def build_sam_vit_b(checkpoint=None):
+def build_sam_vit_b(checkpoint=None, device="cpu"):
     return _build_sam(
         encoder_embed_dim=768,
         encoder_depth=12,
         encoder_num_heads=12,
         encoder_global_attn_indexes=[2, 5, 8, 11],
         checkpoint=checkpoint,
+        device=device,
     )
 
 
@@ -58,6 +61,7 @@ def _build_sam(
     encoder_num_heads,
     encoder_global_attn_indexes,
     checkpoint=None,
+    device="cpu",
 ):
     prompt_embed_dim = 256
     image_size = 1024
@@ -99,6 +103,7 @@ def _build_sam(
         pixel_mean=[123.675, 116.28, 103.53],
         pixel_std=[58.395, 57.12, 57.375],
     )
+    sam.to(device)  # NOTE(review): torch.load below may also need map_location=device for CUDA-saved checkpoints — confirm
     sam.eval()
     if checkpoint is not None:
         with open(checkpoint, "rb") as f: