Skip to content

Commit

Permalink
Support <model_name>[:latest] syntax, rename model tags to :latest
Browse files Browse the repository at this point in the history
  • Loading branch information
wkentaro committed Feb 13, 2024
1 parent d839ce0 commit 003d86e
Show file tree
Hide file tree
Showing 7 changed files with 36 additions and 33 deletions.
20 changes: 11 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ pip install osam
To run with Efficient-SAM:

```bash
osam run efficient-sam:25m --image <image_file>
osam run efficient-sam --image <image_file>
```

## Model library
Expand All @@ -50,24 +50,26 @@ Here are models that can be downloaded:
|-------------------|------------|-------|------------------------------|
| SAM 91M | 91M | 100MB | `osam run sam:91m` |
| SAM 308M | 308M | 320MB | `osam run sam:308m` |
| SAM 636M | 636M | 630MB | `osam run sam:636m` |
| SAM 636M | 636M | 630MB | `osam run sam` |
| Efficient-SAM 10M | 10M | 40MB | `osam run efficient-sam:10m` |
| Efficient-SAM 25M | 25M | 100MB | `osam run efficient-sam:25m` |
| Efficient-SAM 25M | 25M | 100MB | `osam run efficient-sam` |

PS. `sam`, `efficient-sam` is equivalent to `sam:latest`, `efficient-sam:latest`.

## Usage

### CLI

```bash
# Run a model with an image
osam run efficient-sam:25m --image examples/_images/dogs.jpg > output.png
osam run efficient-sam --image examples/_images/dogs.jpg > output.png

# Get a JSON output
osam run efficient-sam:25m --image examples/_images/dogs.jpg --json
# {"model": "efficient-sam:25m", "mask": "..."}
osam run efficient-sam --image examples/_images/dogs.jpg --json
# {"model": "efficient-sam", "mask": "..."}

# Give a prompt
osam run efficient-sam:25m --image examples/_images/dogs.jpg \
osam run efficient-sam --image examples/_images/dogs.jpg \
--prompt '{"points": [[1439, 504], [1439, 1289]], "point_labels": [1, 1]}' > output.png
```

Expand All @@ -81,7 +83,7 @@ import osam.apis
import osam.types

request = osam.types.GenerateRequest(
model="efficient-sam:25m",
model="efficient-sam",
image=np.asarray(PIL.Image.open("examples/_images/dogs.jpg")),
prompt=osam.types.Prompt(points=[[1439, 504], [1439, 1289]], point_labels=[1, 1]),
)
Expand All @@ -100,7 +102,7 @@ osam serve
# POST request
curl 127.0.0.1:11368/api/generate -X POST \
-H "Content-Type: application/json" \
-d "{\"model\": \"efficient-sam:25m\", \"image\": \"$(cat examples/_images/dogs.jpg | base64)\"}" \
-d "{\"model\": \"efficient-sam\", \"image\": \"$(cat examples/_images/dogs.jpg | base64)\"}" \
| jq -r .mask | base64 --decode > mask.png
```

Expand Down
2 changes: 1 addition & 1 deletion examples/benchmark/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

def benchmark(n_times: int):
request = osam.types.GenerateRequest(
model="efficient-sam:25m",
model="efficient-sam",
image=np.asarray(PIL.Image.open("../_images/dogs.jpg")),
prompt={"points": [[1280, 800]], "point_labels": [1]},
)
Expand Down
16 changes: 2 additions & 14 deletions osam/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,13 +74,7 @@ def list(show_all):
@cli.command(help="Pull a model")
@click.argument("model_name", metavar="model", type=str)
def pull(model_name):
for cls in _models.MODELS:
if cls.name == model_name:
break
else:
logger.warning("Model {model_name!r} not found.", model_name=model_name)
sys.exit(1)

cls = _models.get_model_class_by_name(model_name)
logger.info("Pulling {model_name!r}...", model_name=model_name)
cls.pull()
logger.info("Pulled {model_name!r}", model_name=model_name)
Expand All @@ -89,13 +83,7 @@ def pull(model_name):
@cli.command(help="Remove a model")
@click.argument("model_name", metavar="model", type=str)
def rm(model_name):
for cls in _models.MODELS:
if cls.name == model_name:
break
else:
logger.warning("Model {model_name} not found.", model_name=model_name)
sys.exit(1)

cls = _models.get_model_class_by_name(model_name)
logger.info("Removing {model_name!r}...", model_name=model_name)
cls.remove()
logger.info("Removed {model_name!r}", model_name=model_name)
Expand Down
17 changes: 17 additions & 0 deletions osam/_models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Type

from ._base import ModelBase # noqa: F401
from ._efficient_sam import EfficientSam10m
from ._efficient_sam import EfficientSam25m
Expand All @@ -12,3 +14,18 @@
Sam308m,
Sam636m,
]


def get_model_class_by_name(name: str) -> Type[ModelBase]:
model_name: str
if ":" in name:
model_name = name
else:
model_name = f"{name}:latest"

for cls in MODELS:
if cls.name == model_name:
break
else:
raise ValueError(f"Model {name!r} not found.")
return cls
2 changes: 1 addition & 1 deletion osam/_models/_efficient_sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ class EfficientSam10m(EfficientSam):


class EfficientSam25m(EfficientSam):
name = "efficient-sam:25m"
name = "efficient-sam:latest"

_blobs = {
"encoder": ModelBlob(
Expand Down
2 changes: 1 addition & 1 deletion osam/_models/_sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ class Sam308m(Sam):


class Sam636m(Sam):
name = "sam:636m"
name = "sam:latest"

_blobs = {
"encoder": ModelBlob(
Expand Down
10 changes: 3 additions & 7 deletions osam/apis.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,9 @@
def generate(request: types.GenerateRequest) -> types.GenerateResponse:
global model

if model is None or model.name != request.model:
for model_cls in _models.MODELS:
if model_cls.name == request.model:
model = model_cls()
break
else:
raise ValueError(f"Model not found: {request.model!r}")
model_cls = _models.get_model_class_by_name(name=request.model)
if model is None or model.name != model_cls.name:
model = model_cls()

image: np.ndarray = request.image

Expand Down

0 comments on commit 003d86e

Please sign in to comment.