Skip to content

Commit

Permalink
Support receiving and returning image embeddings in the generate API
Browse files Browse the repository at this point in the history
  • Loading branch information
wkentaro committed May 5, 2024
1 parent 609cbdd commit 6f0cf00
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 10 deletions.
8 changes: 4 additions & 4 deletions osam/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,11 +108,11 @@ def serve(reload):
@click.option("--prompt", type=json.loads, help="prompt")
@click.option("--json", is_flag=True, help="json output")
def run(model_name: str, image_path: str, prompt, json: bool) -> None:
image: np.ndarray = np.asarray(PIL.Image.open(image_path))

try:
request: types.GenerateRequest = types.GenerateRequest(
model=model_name,
image=np.asarray(PIL.Image.open(image_path)),
prompt=prompt,
model=model_name, image=image, prompt=prompt
)
response: types.GenerateResponse = apis.generate(request=request)
except ValueError as e:
Expand All @@ -123,7 +123,7 @@ def run(model_name: str, image_path: str, prompt, json: bool) -> None:
click.echo(response.model_dump_json())
else:
visualization: np.ndarray = (
0.5 * request.image
0.5 * image
+ 0.5
* np.array([0, 255, 0])[None, None, :]
* (response.mask > 0)[:, :, None]
Expand Down
13 changes: 10 additions & 3 deletions osam/apis.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,13 @@ def generate(request: types.GenerateRequest) -> types.GenerateResponse:
if model is None or model.name != model_cls.name:
model = model_cls()

image: np.ndarray = request.image
if request.image_embedding is None:
if request.image is None:
raise ValueError("Either image_embedding or image must be given")
image: np.ndarray = request.image
image_embedding: types.ImageEmbedding = model.encode_image(image=image)
else:
image_embedding = request.image_embedding

if request.prompt is None:
height, width = image.shape[:2]
Expand All @@ -31,8 +37,9 @@ def generate(request: types.GenerateRequest) -> types.GenerateResponse:
else:
prompt = request.prompt

image_embedding: types.ImageEmbedding = model.encode_image(image=image)
mask: np.ndarray = model.generate_mask(
image_embedding=image_embedding, prompt=prompt
)
return types.GenerateResponse(model=request.model, mask=mask)
return types.GenerateResponse(
model=request.model, mask=mask, image_embedding=image_embedding
)
15 changes: 12 additions & 3 deletions osam/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ def validate_embedding(cls, embedding):
)
return embedding

@pydantic.field_serializer("embedding")
def serialize_embedding(self, embedding: np.ndarray) -> List[List[List[float]]]:
return embedding.tolist()


class Prompt(pydantic.BaseModel):
model_config = pydantic.ConfigDict(arbitrary_types_allowed=True)
Expand Down Expand Up @@ -65,11 +69,15 @@ class GenerateRequest(pydantic.BaseModel):
model_config = pydantic.ConfigDict(arbitrary_types_allowed=True)

model: str
image: np.ndarray
image_embedding: Optional[ImageEmbedding] = pydantic.Field(default=None)
image: Optional[np.ndarray] = pydantic.Field(default=None)
prompt: Optional[Prompt] = pydantic.Field(default=None)

@pydantic.validator("image", pre=True)
def validate_image(cls, image: Union[str, np.ndarray]) -> np.ndarray:
@pydantic.field_validator("image", mode="before")
@classmethod
def validate_image(
cls, image: Optional[Union[str, np.ndarray]]
) -> Optional[np.ndarray]:
if isinstance(image, str):
return _json.image_b64data_to_ndarray(b64data=image)
return image
Expand All @@ -79,6 +87,7 @@ class GenerateResponse(pydantic.BaseModel):
model_config = pydantic.ConfigDict(arbitrary_types_allowed=True)

model: str
image_embedding: ImageEmbedding
mask: np.ndarray

@pydantic.field_serializer("mask")
Expand Down

0 comments on commit 6f0cf00

Please sign in to comment.