Skip to content

Commit

Permalink
Provide non_maximum_suppression in osam.apis
Browse files Browse the repository at this point in the history
  • Loading branch information
wkentaro committed Jul 30, 2024
1 parent 8d4d3c5 commit 6379171
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 41 deletions.
43 changes: 2 additions & 41 deletions osam/_models/yoloworld/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

import imgviz
import numpy as np
import onnxruntime
from loguru import logger

from ... import apis
from ... import types
from . import clip

Expand Down Expand Up @@ -54,8 +54,7 @@ def generate(self, request: types.GenerateRequest) -> types.GenerateResponse:
max_annotations = (
len(bboxes) if prompt.max_annotations is None else prompt.max_annotations
)
bboxes, scores, labels = _non_maximum_suppression(
inference_session=self._inference_sessions["nms"],
bboxes, scores, labels = apis.non_maximum_suppression(
boxes=bboxes,
scores=scores,
iou_threshold=iou_threshold,
Expand Down Expand Up @@ -98,10 +97,6 @@ class YoloWorldXL(_YoloWorld):
url="https://github.com/wkentaro/yolo-world-onnx/releases/download/v0.1.0/yolo_world_v2_xl_vlpan_bn_2e-3_100e_4x8gpus_obj365v1_goldg_train_lvis_minival.onnx",
hash="sha256:92660c6456766439a2670cf19a8a258ccd3588118622a15959f39e253731c05d",
),
"nms": types.Blob(
url="https://github.com/wkentaro/yolo-world-onnx/releases/download/v0.1.0/non_maximum_suppression.onnx",
hash="sha256:328310ba8fdd386c7ca63fc9df3963cc47b1268909647abd469e8ebdf7f3d20a",
),
}


Expand Down Expand Up @@ -145,37 +140,3 @@ def _untransform_bboxes(
bboxes[:, 1::2] = np.clip(bboxes[:, 1::2], 0, original_image_hw[0])
bboxes = bboxes.round().astype(int)
return bboxes


def _non_maximum_suppression(
inference_session: onnxruntime.InferenceSession,
boxes: np.ndarray,
scores: np.ndarray,
iou_threshold: float,
score_threshold: float,
max_num_detections: int,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
selected_indices = inference_session.run(
output_names=["selected_indices"],
input_feed={
"boxes": boxes[None, :, :],
"scores": scores[None, :, :].transpose(0, 2, 1),
"max_output_boxes_per_class": np.array(
[max_num_detections], dtype=np.int64
),
"iou_threshold": np.array([iou_threshold], dtype=np.float32),
"score_threshold": np.array([score_threshold], dtype=np.float32),
},
)[0]
labels = selected_indices[:, 1]
box_indices = selected_indices[:, 2]
boxes = boxes[box_indices]
scores = scores[box_indices, labels]

if len(boxes) > max_num_detections:
keep_indices = np.argsort(scores)[-max_num_detections:]
boxes = boxes[keep_indices]
scores = scores[keep_indices]
labels = labels[keep_indices]

return boxes, scores, labels
54 changes: 54 additions & 0 deletions osam/apis.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
from typing import List
from typing import Optional
from typing import Tuple
from typing import Type

import numpy as np
import onnxruntime

from . import _models
from . import types

Expand Down Expand Up @@ -42,3 +46,53 @@ def generate(request: types.GenerateRequest) -> types.GenerateResponse:

response: types.GenerateResponse = running_model.generate(request=request)
return response


_non_maximum_suppression_inference_session: Optional[onnxruntime.InferenceSession] = (
None
)


def non_maximum_suppression(
boxes: np.ndarray,
scores: np.ndarray,
iou_threshold: float,
score_threshold: float,
max_num_detections: int,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
global _non_maximum_suppression_inference_session
if _non_maximum_suppression_inference_session is None:
blob = types.Blob(
url="https://github.com/wkentaro/yolo-world-onnx/releases/download/v0.1.0/non_maximum_suppression.onnx", # noqa
hash="sha256:328310ba8fdd386c7ca63fc9df3963cc47b1268909647abd469e8ebdf7f3d20a",
)
blob.pull()
_non_maximum_suppression_inference_session = onnxruntime.InferenceSession(
blob.path, providers=["CPUExecutionProvider"]
)
inference_session = _non_maximum_suppression_inference_session

selected_indices = inference_session.run(
output_names=["selected_indices"],
input_feed={
"boxes": boxes[None, :, :],
"scores": scores[None, :, :].transpose(0, 2, 1),
"max_output_boxes_per_class": np.array(
[max_num_detections], dtype=np.int64
),
"iou_threshold": np.array([iou_threshold], dtype=np.float32),
"score_threshold": np.array([score_threshold], dtype=np.float32),
},
)[0]
labels = selected_indices[:, 1]
box_indices = selected_indices[:, 2]
boxes = boxes[box_indices]
scores = scores[box_indices, labels]

if len(boxes) > max_num_detections:
keep_indices = np.argsort(scores)[-max_num_detections:]
boxes = boxes[keep_indices]
scores = scores[keep_indices]
labels = labels[keep_indices]

return boxes, scores, labels

0 comments on commit 6379171

Please sign in to comment.