Skip to content

Commit

Permalink
Use osam_core in osam
Browse files Browse the repository at this point in the history
  • Loading branch information
wkentaro committed Jun 22, 2024
1 parent 6f0cf00 commit adaf2f7
Show file tree
Hide file tree
Showing 15 changed files with 63 additions and 404 deletions.
8 changes: 5 additions & 3 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,19 @@ all:
@echo
@$(MAKE) -pRrq -f $(lastword $(MAKEFILE_LIST)) : 2>/dev/null | awk -v RS= -F: '/^# File/,/^# Finished Make data base/ {if ($$1 !~ "^[#.]") {print $$1}}' | sort | egrep -v -e '^[^[:alnum:]]' -e '^$@$$' | xargs

PACKAGE_DIR=osam

lint:
mypy --package osam
ruff format --check
ruff check
mypy --package $(PACKAGE_DIR)

format:
ruff format
ruff check --fix

test:
python -m pytest -n auto -v tests
python -m pytest -n auto -v $(PACKAGE_DIR)

clean:
rm -rf build dist *.egg-info
Expand All @@ -22,7 +24,7 @@ build: clean
python -m build --sdist --wheel

upload: build
python -m twine upload dist/osam-*
python -m twine upload dist/$(PACKAGE_DIR)-*

publish: build upload

Expand Down
4 changes: 4 additions & 0 deletions osam/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
import importlib.metadata

__version__ = importlib.metadata.version("osam")

from . import _models # noqa: F401
from . import apis # noqa: F401
from . import types # noqa: F401
23 changes: 11 additions & 12 deletions osam/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,12 @@
import PIL.Image
import uvicorn
from loguru import logger
from osam_core import apis
from osam_core import types

from osam import __version__
from osam import _humanize
from osam import _models
from osam import _tabulate
from osam import apis
from osam import types
from . import __version__
from . import _humanize
from . import _tabulate


@click.group(context_settings=dict(help_option_names=["-h", "--help"]))
Expand Down Expand Up @@ -51,9 +50,9 @@ def help(ctx, subcommand):
@click.option("--all", "-a", "show_all", is_flag=True, help="show all models")
def list(show_all):
rows = []
for model in _models.MODELS:
size = model.get_size()
modified_at = model.get_modified_at()
for model_type in apis.registered_model_types:
size = model_type.get_size()
modified_at = model_type.get_modified_at()

if size is None or modified_at is None:
if show_all:
Expand All @@ -67,14 +66,14 @@ def list(show_all):
datetime.datetime.fromtimestamp(modified_at)
)

rows.append([model.name, model.get_id(), size, modified_at])
rows.append([model_type.name, model_type.get_id(), size, modified_at])
click.echo(_tabulate.tabulate(rows, headers=["NAME", "ID", "SIZE", "MODIFIED"]))


@cli.command(help="Pull a model")
@click.argument("model_name", metavar="model", type=str)
def pull(model_name):
cls = _models.get_model_class_by_name(model_name)
cls = apis.get_model_type_by_name(model_name)
logger.info("Pulling {model_name!r}...", model_name=model_name)
cls.pull()
logger.info("Pulled {model_name!r}", model_name=model_name)
Expand All @@ -83,7 +82,7 @@ def pull(model_name):
@cli.command(help="Remove a model")
@click.argument("model_name", metavar="model", type=str)
def rm(model_name):
cls = _models.get_model_class_by_name(model_name)
cls = apis.get_model_type_by_name(model_name)
logger.info("Removing {model_name!r}...", model_name=model_name)
cls.remove()
logger.info("Removed {model_name!r}", model_name=model_name)
Expand Down
File renamed without changes.
28 changes: 0 additions & 28 deletions osam/_contextlib.py

This file was deleted.

19 changes: 0 additions & 19 deletions osam/_json.py

This file was deleted.

35 changes: 2 additions & 33 deletions osam/_models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,2 @@
from typing import Type

from ._base import ModelBase # noqa: F401
from ._efficient_sam import EfficientSam10m
from ._efficient_sam import EfficientSam25m
from ._sam import Sam91m
from ._sam import Sam308m
from ._sam import Sam636m

# TODO: Provide a better way to register models.
# Currently, we have to manually add the model to the MODELS list.
MODELS = [
EfficientSam10m,
EfficientSam25m,
Sam91m,
Sam308m,
Sam636m,
]


def get_model_class_by_name(name: str) -> Type[ModelBase]:
model_name: str
if ":" in name:
model_name = name
else:
model_name = f"{name}:latest"

for cls in MODELS:
if cls.name == model_name:
break
else:
raise ValueError(f"Model {name!r} not found.")
return cls
from osam._models import _efficient_sam # noqa: F401
from osam._models import _sam # noqa: F401
136 changes: 0 additions & 136 deletions osam/_models/_base.py

This file was deleted.

14 changes: 9 additions & 5 deletions osam/_models/_efficient_sam.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import numpy as np

from osam._models._base import ModelBase
from osam._models._base import ModelBlob
from osam.types import ImageEmbedding
from osam.types import Prompt
from osam_core import apis
from osam_core.types import ImageEmbedding
from osam_core.types import ModelBase
from osam_core.types import ModelBlob
from osam_core.types import Prompt


class EfficientSam(ModelBase):
Expand Down Expand Up @@ -85,3 +85,7 @@ class EfficientSam25m(EfficientSam):
hash="sha256:4727baf23dacfb51d4c16795b2ac382c403505556d0284e84c6ff3d4e8e36f22",
),
}


apis.register_model_type(EfficientSam10m)
apis.register_model_type(EfficientSam25m)
Loading

0 comments on commit adaf2f7

Please sign in to comment.