Skip to content

Commit

Permalink
test_zero_shot_object_detection: use baseline fixture
Browse files Browse the repository at this point in the history
Signed-off-by: U. Artie Eoff <[email protected]>
  • Loading branch information
uartie committed Feb 20, 2025
1 parent d5eba07 commit c680cc9
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 9 deletions.
10 changes: 10 additions & 0 deletions tests/baselines/fixture/tests/test_zero_shot_object_detection.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
{
"tests/test_zero_shot_object_detection.py::GaudiOWlVITTester::test_no_latency_regression_bf16": {
"gaudi1": {
"latency": 8.460688591003418
},
"gaudi2": {
"latency": 4.213955687819833
}
}
}
22 changes: 13 additions & 9 deletions tests/test_zero_shot_object_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import habana_frameworks.torch as ht
import numpy as np
import pytest
import requests
import torch
from PIL import Image
Expand All @@ -30,19 +31,19 @@

adapt_transformers_to_gaudi()

# CI latency baseline for the OWLVIT bf16 HPU-graph regression test,
# chosen by which device generation the suite is running on.
# NOTE(review): values look like milliseconds (the test multiplies by 1000) — confirm.
LATENCY_OWLVIT_BF16_GRAPH_BASELINE = (
    4.2139556878198333  # Gaudi2 CI baseline
    if OH_DEVICE_CONTEXT in ["gaudi2"]
    else 8.460688591003418  # Gaudi1 CI baseline
)

class GaudiOWlVITTester(TestCase):
"""
Tests for Zero Shot Object Detection - OWLVIT
"""

@pytest.fixture(autouse=True)
def _use_(self, baseline):
    """
    Autouse fixture that stores the pytest ``baseline`` fixture on the
    test instance as ``self.baseline``, making it available to every
    test method in this class (e.g. for ``self.baseline.assertRef``).

    This is the documented pattern for accessing pytest fixtures from
    unittest-style test classes:
    https://docs.pytest.org/en/stable/how-to/unittest.html#using-autouse-fixtures-and-accessing-other-fixtures
    """
    self.baseline = baseline

def prepare_model_and_processor(self):
model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32").to("hpu")
model = model.eval()
Expand Down Expand Up @@ -120,5 +121,8 @@ def test_no_latency_regression_bf16(self):
model_end_time = time.time()
total_model_time = total_model_time + (model_end_time - model_start_time)

latency = total_model_time * 1000 / iterations # in terms of ms
self.assertLessEqual(latency, 1.05 * LATENCY_OWLVIT_BF16_GRAPH_BASELINE)
self.baseline.assertRef(
compare=lambda latency, expect: latency <= (1.05 * expect),
context=[OH_DEVICE_CONTEXT],
latency=total_model_time * 1000 / iterations, # in terms of ms
)

0 comments on commit c680cc9

Please sign in to comment.