From 6f0cf009dfe6fcb3b4b1cf89b545bb0a59969bf3 Mon Sep 17 00:00:00 2001 From: Kentaro Wada Date: Sun, 5 May 2024 09:31:58 +0900 Subject: [PATCH] Support receiving and returning image embeddings in the generate API --- osam/__main__.py | 8 ++++---- osam/apis.py | 13 ++++++++++--- osam/types.py | 15 ++++++++++++--- 3 files changed, 26 insertions(+), 10 deletions(-) diff --git a/osam/__main__.py b/osam/__main__.py index 8145e7b..37c6d20 100644 --- a/osam/__main__.py +++ b/osam/__main__.py @@ -108,11 +108,11 @@ def serve(reload): @click.option("--prompt", type=json.loads, help="prompt") @click.option("--json", is_flag=True, help="json output") def run(model_name: str, image_path: str, prompt, json: bool) -> None: + image: np.ndarray = np.asarray(PIL.Image.open(image_path)) + try: request: types.GenerateRequest = types.GenerateRequest( - model=model_name, - image=np.asarray(PIL.Image.open(image_path)), - prompt=prompt, + model=model_name, image=image, prompt=prompt ) response: types.GenerateResponse = apis.generate(request=request) except ValueError as e: @@ -123,7 +123,7 @@ def run(model_name: str, image_path: str, prompt, json: bool) -> None: click.echo(response.model_dump_json()) else: visualization: np.ndarray = ( - 0.5 * request.image + 0.5 * image + 0.5 * np.array([0, 255, 0])[None, None, :] * (response.mask > 0)[:, :, None] diff --git a/osam/apis.py b/osam/apis.py index f5004b9..df6043a 100644 --- a/osam/apis.py +++ b/osam/apis.py @@ -16,7 +16,13 @@ def generate(request: types.GenerateRequest) -> types.GenerateResponse: if model is None or model.name != model_cls.name: model = model_cls() - image: np.ndarray = request.image + if request.image_embedding is None: + if request.image is None: + raise ValueError("Either image_embedding or image must be given") + image: np.ndarray = request.image + image_embedding: types.ImageEmbedding = model.encode_image(image=image) + else: + image_embedding = request.image_embedding if request.prompt is None: height, width = image.shape[:2] @@ -31,8 +37,9 @@ def generate(request: types.GenerateRequest) -> types.GenerateResponse: else: prompt = request.prompt - image_embedding: types.ImageEmbedding = model.encode_image(image=image) mask: np.ndarray = model.generate_mask( image_embedding=image_embedding, prompt=prompt ) - return types.GenerateResponse(model=request.model, mask=mask) + return types.GenerateResponse( + model=request.model, mask=mask, image_embedding=image_embedding + ) diff --git a/osam/types.py b/osam/types.py index bb94222..b8d688e 100644 --- a/osam/types.py +++ b/osam/types.py @@ -23,6 +23,10 @@ def validate_embedding(cls, embedding): ) return embedding + @pydantic.field_serializer("embedding") + def serialize_embedding(self, embedding: np.ndarray) -> List[List[List[float]]]: + return embedding.tolist() + class Prompt(pydantic.BaseModel): model_config = pydantic.ConfigDict(arbitrary_types_allowed=True) @@ -65,11 +69,15 @@ class GenerateRequest(pydantic.BaseModel): model_config = pydantic.ConfigDict(arbitrary_types_allowed=True) model: str - image: np.ndarray + image_embedding: Optional[ImageEmbedding] = pydantic.Field(default=None) + image: Optional[np.ndarray] = pydantic.Field(default=None) prompt: Optional[Prompt] = pydantic.Field(default=None) - @pydantic.validator("image", pre=True) - def validate_image(cls, image: Union[str, np.ndarray]) -> np.ndarray: + @pydantic.field_validator("image", mode="before") + @classmethod + def validate_image( + cls, image: Optional[Union[str, np.ndarray]] + ) -> Optional[np.ndarray]: if isinstance(image, str): return _json.image_b64data_to_ndarray(b64data=image) return image @@ -79,6 +87,7 @@ class GenerateResponse(pydantic.BaseModel): model_config = pydantic.ConfigDict(arbitrary_types_allowed=True) model: str + image_embedding: ImageEmbedding mask: np.ndarray @pydantic.field_serializer("mask")