Move back all modules to osam
wkentaro committed Jul 30, 2024
1 parent 2b833b7 commit 8f9929c
Showing 29 changed files with 1,196 additions and 17 deletions.
4 changes: 3 additions & 1 deletion Makefile
@@ -5,10 +5,12 @@ all:

PACKAGE_DIR=osam

mypy:
mypy --package $(PACKAGE_DIR)

lint:
ruff format --check
ruff check
mypy --package $(PACKAGE_DIR)

format:
ruff format
1 change: 0 additions & 1 deletion osam/__init__.py
@@ -2,6 +2,5 @@

__version__ = importlib.metadata.version("osam")

from . import _models # noqa: F401
from . import apis # noqa: F401
from . import types # noqa: F401
8 changes: 5 additions & 3 deletions osam/__main__.py
@@ -10,12 +10,12 @@
import PIL.Image
import uvicorn
from loguru import logger
from osam_core import apis
from osam_core import types

from . import __version__
from . import _humanize
from . import _tabulate
from . import apis
from . import types


@click.group(context_settings=dict(help_option_names=["-h", "--help"]))
@@ -131,7 +131,9 @@ def run(model_name: str, image_path: str, prompt, json: bool) -> None:

if request.prompt and request.prompt.texts is not None:
labels = [
1 + request.prompt.texts.index(annotation.text)
0
if annotation.text is None
else 1 + request.prompt.texts.index(annotation.text)
for annotation in response.annotations
]
else:
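
The updated label mapping above assigns 0 to annotations without text and a 1-based index into the prompt's text list otherwise. A minimal sketch of that logic with hypothetical data:

# Hypothetical stand-ins for request.prompt.texts and annotation.text,
# illustrating the list comprehension in osam/__main__.py above.
texts = ["cat", "dog"]
annotation_texts = ["dog", None, "cat"]

labels = [
    0 if text is None else 1 + texts.index(text)
    for text in annotation_texts
]
assert labels == [2, 0, 1]  # 0 marks an annotation with no text
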
28 changes: 28 additions & 0 deletions osam/_contextlib.py
@@ -0,0 +1,28 @@
import contextlib
import os
import sys
import tempfile


@contextlib.contextmanager
def suppress():
original_stdout_fd = os.dup(sys.stdout.fileno())
original_stderr_fd = os.dup(sys.stderr.fileno())

with tempfile.TemporaryFile(mode="w+b") as temp_stdout, tempfile.TemporaryFile(
mode="w+b"
) as temp_stderr:
os.dup2(temp_stdout.fileno(), sys.stdout.fileno())
os.dup2(temp_stderr.fileno(), sys.stderr.fileno())

try:
yield
finally:
sys.stdout.flush()
sys.stderr.flush()

os.dup2(original_stdout_fd, sys.stdout.fileno())
os.dup2(original_stderr_fd, sys.stderr.fileno())

os.close(original_stdout_fd)
os.close(original_stderr_fd)
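
suppress() redirects the stdout and stderr file descriptors to temporary files, so output written at the fd level by native libraries is silenced along with Python-level prints. A minimal usage sketch, assuming the module is importable as osam._contextlib:

import os
import sys

from osam import _contextlib

with _contextlib.suppress():
    print("redirected to a temporary file")            # Python-level output
    os.write(sys.stderr.fileno(), b"also silenced\n")  # fd-level output

print("visible again")  # descriptors are restored after the block
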
19 changes: 19 additions & 0 deletions osam/_json.py
@@ -0,0 +1,19 @@
import base64
import io

import numpy as np
import PIL.Image


def image_ndarray_to_b64data(ndarray):
pil = PIL.Image.fromarray(ndarray)
f = io.BytesIO()
pil.save(f, format="PNG")
data = f.getvalue()
return base64.b64encode(data).decode("utf-8")


def image_b64data_to_ndarray(b64data):
data = base64.b64decode(b64data)
pil = PIL.Image.open(io.BytesIO(data))
return np.asarray(pil)
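
These helpers round-trip an image array through an in-memory PNG; the test file below exercises exactly this. A short sketch with a hypothetical array:

import numpy as np

from osam import _json

image = np.arange(16, dtype=np.uint8).reshape(4, 4)  # hypothetical tiny image

b64data = _json.image_ndarray_to_b64data(image)      # base64-encoded PNG text
recovered = _json.image_b64data_to_ndarray(b64data)  # decoded back to ndarray

np.testing.assert_array_equal(image, recovered)      # lossless round trip
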
25 changes: 25 additions & 0 deletions osam/_json_test.py
@@ -0,0 +1,25 @@
import numpy as np
import pytest

from . import _json


@pytest.fixture
def image():
y, x = np.meshgrid(np.arange(10), np.arange(10))
center = (np.array(x.shape) - 1) / 2
image_float = np.exp(-((x - center[1]) ** 2 + (y - center[0]) ** 2) / 10)
image = (image_float * 255).astype(np.uint8)
return image


def test_image_ndarray_to_b64data(image):
b64data = _json.image_ndarray_to_b64data(image)
assert isinstance(b64data, str)
assert len(b64data) == 204


def test_image_b64data_to_ndarray(image):
b64data = _json.image_ndarray_to_b64data(image)
image_recovered = _json.image_b64data_to_ndarray(b64data)
np.testing.assert_array_equal(image, image_recovered)
4 changes: 0 additions & 4 deletions osam/_models.py

This file was deleted.

6 changes: 6 additions & 0 deletions osam/_models/__init__.py
@@ -0,0 +1,6 @@
from .efficientsam import EfficientSam10m # noqa: F401
from .efficientsam import EfficientSam30m # noqa: F401
from .sam import Sam100m # noqa: F401
from .sam import Sam300m # noqa: F401
from .sam import Sam600m # noqa: F401
from .yoloworld import YoloWorldXL # noqa: F401
126 changes: 126 additions & 0 deletions osam/_models/efficientsam/__init__.py
@@ -0,0 +1,126 @@
import imgviz
import numpy as np
from loguru import logger

from ... import types


class EfficientSam(types.Model):
def encode_image(self, image: np.ndarray) -> types.ImageEmbedding:
if image.ndim == 2:
raise ValueError("Grayscale images are not supported")
if image.ndim == 3 and image.shape[2] == 4:
raise ValueError("RGBA images are not supported")

batched_images = image.transpose(2, 0, 1)[None].astype(np.float32) / 255
image_embedding = self._inference_sessions["encoder"].run(
output_names=None,
input_feed={"batched_images": batched_images},
)[0][0] # (embedding_dim, height, width)

return types.ImageEmbedding(
original_height=image.shape[0],
original_width=image.shape[1],
embedding=image_embedding,
)

def generate(self, request: types.GenerateRequest) -> types.GenerateResponse:
if request.image_embedding is None:
if request.image is None:
raise ValueError("request.image or request.image_embedding is required")
image_embedding = self.encode_image(request.image)
else:
image_embedding = request.image_embedding

if request.prompt is None:
prompt = types.Prompt(
points=np.array(
[
[
image_embedding.original_width / 2,
image_embedding.original_height / 2,
]
],
dtype=np.float32,
),
point_labels=np.array([1], dtype=np.int32),
)
logger.warning(
"Prompt is not given, so using the center point as prompt: {prompt!r}",
prompt=prompt,
)
else:
prompt = request.prompt
del request

if prompt.points is None or prompt.point_labels is None:
raise ValueError("Prompt must contain points and point_labels: %r", prompt)

input_point = np.array(prompt.points, dtype=np.float32)
input_label = np.array(prompt.point_labels, dtype=np.float32)

# batch_size, embedding_dim, height, width
batched_image_embedding = image_embedding.embedding[None, :, :, :]
# batch_size, num_queries, num_points, 2
batched_point_coords = input_point[None, None, :, :]
# batch_size, num_queries, num_points
batched_point_labels = input_label[None, None, :]

decoder_inputs = {
"image_embeddings": batched_image_embedding,
"batched_point_coords": batched_point_coords,
"batched_point_labels": batched_point_labels,
"orig_im_size": np.array(
(image_embedding.original_height, image_embedding.original_width),
dtype=np.int64,
),
}

masks, _, _ = self._inference_sessions["decoder"].run(None, decoder_inputs)
mask = masks[0, 0, 0, :, :] # (1, 1, 3, H, W) -> (H, W)
mask = mask > 0.0

bbox = imgviz.instances.mask_to_bbox([mask])[0].astype(int)

return types.GenerateResponse(
model=self.name,
image_embedding=image_embedding,
annotations=[
types.Annotation(
mask=mask,
bounding_box=types.BoundingBox(
ymin=bbox[0], xmin=bbox[1], ymax=bbox[2], xmax=bbox[3]
),
)
],
)


class EfficientSam10m(EfficientSam):
name = "efficientsam:10m"

_blobs = {
"encoder": types.Blob(
url="https://github.com/labelmeai/efficient-sam/releases/download/onnx-models-20231225/efficient_sam_vitt_encoder.onnx",
hash="sha256:7a73ee65aa2c37237c89b4b18e73082f757ffb173899609c5d97a2bbd4ebb02d",
),
"decoder": types.Blob(
url="https://github.com/labelmeai/efficient-sam/releases/download/onnx-models-20231225/efficient_sam_vitt_decoder.onnx",
hash="sha256:e1afe46232c3bfa3470a6a81c7d3181836a94ea89528aff4e0f2d2c611989efd",
),
}


class EfficientSam30m(EfficientSam):
name = "efficientsam:latest"

_blobs = {
"encoder": types.Blob(
url="https://github.com/labelmeai/efficient-sam/releases/download/onnx-models-20231225/efficient_sam_vits_encoder.onnx",
hash="sha256:4cacbb23c6903b1acf87f1d77ed806b840800c5fcd4ac8f650cbffed474b8896",
),
"decoder": types.Blob(
url="https://github.com/labelmeai/efficient-sam/releases/download/onnx-models-20231225/efficient_sam_vits_decoder.onnx",
hash="sha256:4727baf23dacfb51d4c16795b2ac382c403505556d0284e84c6ff3d4e8e36f22",
),
}
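
Rough usage sketch for the model class above. How the base types.Model resolves the blobs and builds the ONNX inference sessions is not shown here, and the exact GenerateRequest constructor is an assumption; the Prompt construction and the response fields follow the code in generate():

import numpy as np
import PIL.Image

from osam import _models, types

# Hypothetical usage; blob download and ONNX session setup happen in the
# types.Model base class, which is not shown in this excerpt.
model = _models.EfficientSam10m()

image = np.asarray(PIL.Image.open("example.jpg").convert("RGB"))  # hypothetical path
request = types.GenerateRequest(
    image=image,
    prompt=types.Prompt(
        points=np.array([[320, 240]], dtype=np.float32),  # (x, y) click
        point_labels=np.array([1], dtype=np.int32),       # 1 = foreground point
    ),
)

response = model.generate(request)
mask = response.annotations[0].mask                  # boolean (H, W) array
bbox = response.annotations[0].bounding_box          # ymin/xmin/ymax/xmax
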
