From af044bfd7c0b2bdc7aa3b56a6db40731d74365fb Mon Sep 17 00:00:00 2001
From: "U. Artie Eoff"
Date: Tue, 18 Feb 2025 17:19:56 -0500
Subject: [PATCH] tests: use OH_DEVICE_CONTEXT instead of GAUDI2_CI

... in preparation for future gaudi3-specific cases and refs.

Signed-off-by: U. Artie Eoff
---
 tests/test_bnb_inference.py                   |  5 ++--
 tests/test_bnb_qlora.py                       |  5 ++--
 tests/test_custom_file_input.py               |  4 ++-
 tests/test_diffusers.py                       | 12 ++++----
 tests/test_encoder_decoder.py                 |  8 ++----
 tests/test_examples.py                        |  5 ++--
 tests/test_feature_extraction.py              |  5 ++--
 tests/test_fp8_examples.py                    | 10 +++----
 tests/test_fsdp_examples.py                   | 11 ++++----
 ...test_functional_text_generation_example.py |  4 ++-
 tests/test_image_to_text_example.py           |  7 ++---
 tests/test_object_detection.py                |  3 +-
 tests/test_openclip_vqa.py                    |  5 ++--
 tests/test_sentence_transformers.py           |  5 ++--
 tests/test_table_transformer.py               |  5 ++--
 tests/test_text_generation_example.py         | 28 ++++++++++---------
 tests/test_video_llava.py                     |  3 +-
 tests/test_video_mae.py                       |  5 ++--
 tests/test_zero_shot_object_detection.py      |  5 ++--
 19 files changed, 71 insertions(+), 64 deletions(-)

diff --git a/tests/test_bnb_inference.py b/tests/test_bnb_inference.py
index 9218869669..1768efacd9 100644
--- a/tests/test_bnb_inference.py
+++ b/tests/test_bnb_inference.py
@@ -14,7 +14,6 @@
 # limitations under the License.
 
 import copy
-import os
 
 import torch
 from habana_frameworks.torch.hpu import wrap_in_hpu_graph
@@ -22,10 +21,12 @@
 
 from optimum.habana.transformers import modeling_utils
 
+from .utils import OH_DEVICE_CONTEXT
+
 modeling_utils.adapt_transformers_to_gaudi()
 
-assert os.environ.get("GAUDI2_CI", "0") == "1", "Execution does not support on Gaudi1"
+assert OH_DEVICE_CONTEXT != "gaudi1", "Execution is not supported on Gaudi1"
 
 MODEL_ID = "meta-llama/Llama-3.2-1B"
 
diff --git a/tests/test_bnb_qlora.py b/tests/test_bnb_qlora.py
index ac33a74ee1..95c531a3dd 100644
--- a/tests/test_bnb_qlora.py
+++ b/tests/test_bnb_qlora.py
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import os
 import subprocess
 
 import pytest
@@ -24,10 +23,12 @@
 from optimum.habana import GaudiConfig, GaudiTrainer, GaudiTrainingArguments
 from optimum.habana.transformers import modeling_utils
 
+from .utils import OH_DEVICE_CONTEXT
+
 modeling_utils.adapt_transformers_to_gaudi()
 
-assert os.environ.get("GAUDI2_CI", "0") == "1", "Execution does not support on Gaudi1"
+assert OH_DEVICE_CONTEXT != "gaudi1", "Execution is not supported on Gaudi1"
 
 try:
     import sys
 
diff --git a/tests/test_custom_file_input.py b/tests/test_custom_file_input.py
index 1fb0e0e7fd..14dc84058e 100644
--- a/tests/test_custom_file_input.py
+++ b/tests/test_custom_file_input.py
@@ -8,10 +8,12 @@
 import pytest
 from transformers.testing_utils import slow
 
+from .utils import OH_DEVICE_CONTEXT
+
 PATH_TO_RESOURCES = Path(__file__).resolve().parent.parent / "tests/resource"
 
-if os.environ.get("GAUDI2_CI", "0") == "1":
+if OH_DEVICE_CONTEXT in ["gaudi2"]:
     MODEL_FILE_OPTIONS_TO_TEST = {
         "bf16": [
             (
diff --git a/tests/test_diffusers.py b/tests/test_diffusers.py
index 396dc8f35e..fe3a9aba11 100644
--- a/tests/test_diffusers.py
+++ b/tests/test_diffusers.py
@@ -125,12 +125,12 @@
 from optimum.habana.utils import set_seed
 
 from .clip_coco_utils import calculate_clip_score, download_files
+from .utils import OH_DEVICE_CONTEXT
 
-IS_GAUDI2 = os.environ.get("GAUDI2_CI", "0") == "1"
+IS_GAUDI1 = OH_DEVICE_CONTEXT == "gaudi1"
 
-
-if IS_GAUDI2:
+if OH_DEVICE_CONTEXT in ["gaudi2"]:
     THROUGHPUT_BASELINE_BF16 = 1.086
     THROUGHPUT_BASELINE_AUTOCAST = 0.394
     TEXTUAL_INVERSION_THROUGHPUT = 131.7606336456344
@@ -1695,7 +1695,7 @@ def test_fused_qkv_projections(self):
 
     @slow
     @check_gated_model_access("stabilityai/stable-diffusion-3-medium-diffusers")
-    @pytest.mark.skipif(not IS_GAUDI2, reason="does not fit into Gaudi1 memory")
+    @pytest.mark.skipif(IS_GAUDI1, reason="does not fit into Gaudi1 memory")
     def test_sd3_inference(self):
         repo_id = "stabilityai/stable-diffusion-3-medium-diffusers"
 
@@ -5985,7 +5985,7 @@ def test_flux_prompt_embeds(self):
         assert max_diff < 1e-4
 
     @slow
-    @pytest.mark.skipif(not IS_GAUDI2, reason="does not fit into Gaudi1 memory")
+    @pytest.mark.skipif(IS_GAUDI1, reason="does not fit into Gaudi1 memory")
     def test_flux_inference(self):
         prompts = [
             "A cat holding a sign that says hello world",
@@ -6154,7 +6154,7 @@ def test_flux_prompt_embeds(self):
 
     @slow
     @check_gated_model_access("black-forest-labs/FLUX.1-dev")
-    @pytest.mark.skipif(not IS_GAUDI2, reason="does not fit into Gaudi1 memory")
+    @pytest.mark.skipif(IS_GAUDI1, reason="does not fit into Gaudi1 memory")
     def test_flux_img2img_inference(self):
         repo_id = "black-forest-labs/FLUX.1-dev"
         image_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png"
diff --git a/tests/test_encoder_decoder.py b/tests/test_encoder_decoder.py
index b9be7b77f6..ceac8fc54f 100644
--- a/tests/test_encoder_decoder.py
+++ b/tests/test_encoder_decoder.py
@@ -1,5 +1,4 @@
 import json
-import os
 import re
 import subprocess
 from pathlib import Path
@@ -9,6 +8,7 @@
 import pytest
 
 from .test_examples import ACCURACY_PERF_FACTOR, TIME_PERF_FACTOR
+from .utils import OH_DEVICE_CONTEXT
 
 
 MODELS_TO_TEST = {
@@ -88,12 +88,10 @@ def _run_test(
         with open(Path(tmp_dir) / "predict_results.json") as fp:
             results = json.load(fp)
 
-        device = "gaudi2" if os.environ.get("GAUDI2_CI", "0") == "1" else "gaudi1"
-
         # Ensure performance requirements (throughput) are met
         self.baseline.assertRef(
             compare=lambda actual, ref: actual >= (2 - TIME_PERF_FACTOR) * ref,
-            context=[device],
+            context=[OH_DEVICE_CONTEXT],
             predict_samples_per_second=results["predict_samples_per_second"],
         )
 
@@ -103,7 +101,7 @@ def _run_test(
             accuracy_metric = "predict_bleu"
         self.baseline.assertRef(
             compare=lambda actual, ref: actual >= ACCURACY_PERF_FACTOR * ref,
-            context=[device],
+            context=[OH_DEVICE_CONTEXT],
             **{accuracy_metric: results[accuracy_metric]},
         )
 
diff --git a/tests/test_examples.py b/tests/test_examples.py
index f82850a885..578ba7825e 100644
--- a/tests/test_examples.py
+++ b/tests/test_examples.py
@@ -50,6 +50,7 @@
     MODELS_TO_TEST_FOR_SEQUENCE_CLASSIFICATION,
     MODELS_TO_TEST_FOR_SPEECH_RECOGNITION,
     MODELS_TO_TEST_MAPPING,
+    OH_DEVICE_CONTEXT,
 )
 
 
@@ -60,7 +61,7 @@
 TIME_PERF_FACTOR = 1.05
 
 
-IS_GAUDI2 = os.environ.get("GAUDI2_CI", "0") == "1"
+IS_GAUDI2 = OH_DEVICE_CONTEXT == "gaudi2"
 
 
 def _get_supported_models_for_script(
@@ -439,7 +440,7 @@ def test(self):
                 # Assess accuracy
                 with open(Path(tmp_dir) / "accuracy_metrics.json") as fp:
                     results = json.load(fp)
-                baseline = 0.43 if os.environ.get("GAUDI2_CI", "0") == "1" else 0.42
+                baseline = 0.43 if IS_GAUDI2 else 0.42
                 self.assertGreaterEqual(results["accuracy"], baseline)
                 return
             elif self.EXAMPLE_NAME == "run_clip":
diff --git a/tests/test_feature_extraction.py b/tests/test_feature_extraction.py
index f934ac15a9..85f25354b5 100644
--- a/tests/test_feature_extraction.py
+++ b/tests/test_feature_extraction.py
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import os
 import time
 from unittest import TestCase
@@ -25,10 +24,12 @@
 
 from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
 
+from .utils import OH_DEVICE_CONTEXT
+
 adapt_transformers_to_gaudi()
 
-if os.environ.get("GAUDI2_CI", "0") == "1":
+if OH_DEVICE_CONTEXT in ["gaudi2"]:
     # Gaudi2 CI baselines
     LATENCY_GTE_SMALL_BF16_GRAPH_BASELINE = 0.6812
 else:
diff --git a/tests/test_fp8_examples.py b/tests/test_fp8_examples.py
index 4e2382b8b7..94ccc360f8 100644
--- a/tests/test_fp8_examples.py
+++ b/tests/test_fp8_examples.py
@@ -1,5 +1,4 @@
 import json
-import os
 import re
 import subprocess
 from pathlib import Path
@@ -8,9 +7,10 @@
 import pytest
 
 from .test_examples import ACCURACY_PERF_FACTOR, TIME_PERF_FACTOR
+from .utils import OH_DEVICE_CONTEXT
 
 
-if os.environ.get("GAUDI2_CI", "0") == "1":
+if OH_DEVICE_CONTEXT in ["gaudi2"]:
     # Gaudi2 CI baselines
     MODELS_TO_TEST = {
         "fp8": [
@@ -109,17 +109,15 @@ def _test_fp8_train(
         with open(Path(tmp_dir) / "all_results.json") as fp:
             results = json.load(fp)
 
-        device = "gaudi2" if os.environ.get("GAUDI2_CI", "0") == "1" else "gaudi1"
-
         # Ensure performance requirements (throughput) are met
         baseline.assertRef(
             compare=lambda actual, ref: actual >= (2 - TIME_PERF_FACTOR) * ref,
-            context=[device],
+            context=[OH_DEVICE_CONTEXT],
             train_samples_per_second=results["train_samples_per_second"],
         )
 
         baseline.assertRef(
             compare=lambda actual, ref: actual >= ACCURACY_PERF_FACTOR * ref,
-            context=[device],
+            context=[OH_DEVICE_CONTEXT],
             eval_accuracy=results["eval_accuracy"],
         )
 
diff --git a/tests/test_fsdp_examples.py b/tests/test_fsdp_examples.py
index 90931e1e25..69ca704f96 100644
--- a/tests/test_fsdp_examples.py
+++ b/tests/test_fsdp_examples.py
@@ -8,9 +8,10 @@
 import pytest
 
 from .test_examples import ACCURACY_PERF_FACTOR, TIME_PERF_FACTOR
+from .utils import OH_DEVICE_CONTEXT
 
 
-if os.environ.get("GAUDI2_CI", "0") == "1":
+if OH_DEVICE_CONTEXT in ["gaudi2"]:
     # Gaudi2 CI baselines
     MODELS_TO_TEST = {
         "bf16": [
@@ -145,24 +146,22 @@ def _test_fsdp(
         with open(Path(tmp_dir) / "all_results.json") as fp:
             results = json.load(fp)
 
-        device = "gaudi2" if os.environ.get("GAUDI2_CI", "0") == "1" else "gaudi1"
-
         # Ensure performance requirements (throughput) are met
         baseline.assertRef(
             compare=lambda actual, ref: actual >= (2 - TIME_PERF_FACTOR) * ref,
-            context=[device],
+            context=[OH_DEVICE_CONTEXT],
             train_samples_per_second=results["train_samples_per_second"],
         )
 
         if model_name == "bert-base-uncased":
             baseline.assertRef(
                 compare=lambda actual, ref: actual >= ACCURACY_PERF_FACTOR * ref,
-                context=[device],
+                context=[OH_DEVICE_CONTEXT],
                 eval_f1=results["eval_f1"],
             )
         else:
             baseline.assertRef(
                 compare=lambda actual, ref: actual <= (2 - ACCURACY_PERF_FACTOR) * ref,
-                context=[device],
+                context=[OH_DEVICE_CONTEXT],
                 train_loss=results["train_loss"],
             )
 
diff --git a/tests/test_functional_text_generation_example.py b/tests/test_functional_text_generation_example.py
index 8b57847d17..b002d7c4f6 100644
--- a/tests/test_functional_text_generation_example.py
+++ b/tests/test_functional_text_generation_example.py
@@ -9,8 +9,10 @@
 
 from optimum.habana.utils import set_seed
 
+from .utils import OH_DEVICE_CONTEXT
 
-if os.environ.get("GAUDI2_CI", "0") == "1":
+
+if OH_DEVICE_CONTEXT in ["gaudi2"]:
     MODEL_OUTPUTS = {
         "bigcode/starcoder": 'def print_hello_world():\n print("Hello World")\n\ndef print_hello_world_twice():\n print_hello_world()\n print_hello_world()\n\ndef print_hello_world_thrice():\n print_hello_world()\n print_hello_world()\n print_hello_world()\n\ndef print_hello_world_four_times():\n print_hello_world()\n print_hello_world()\n print_hello_world()\n ',
         "bigcode/starcoder2-3b": 'def print_hello_world():\n print("Hello World")\n\ndef print_hello_world_with_name(name):\n print("Hello World, " + name)\n\ndef print_hello_world_with_name_and_age(name, age):\n print("Hello World, " + name + ", " + str(age))\n\ndef print_hello_world_with_name_and_age_and_gender(name, age, gender):\n print("Hello',
diff --git a/tests/test_image_to_text_example.py b/tests/test_image_to_text_example.py
index 51e99b8466..192ba75dab 100644
--- a/tests/test_image_to_text_example.py
+++ b/tests/test_image_to_text_example.py
@@ -8,9 +8,10 @@
 import pytest
 
 from .test_examples import TIME_PERF_FACTOR
+from .utils import OH_DEVICE_CONTEXT
 
 
-if os.environ.get("GAUDI2_CI", "0") == "1":
+if OH_DEVICE_CONTEXT in ["gaudi2"]:
     # Gaudi2 CI baselines
     MODELS_TO_TEST = {
         "bf16": [
@@ -119,12 +120,10 @@ def _test_image_to_text(
         with open(Path(tmp_dir) / "results.json") as fp:
             results = json.load(fp)
 
-        device = "gaudi2" if os.environ.get("GAUDI2_CI", "0") == "1" else "gaudi1"
-
         # Ensure performance requirements (throughput) are met
         baseline.assertRef(
             compare=lambda actual, ref: actual >= (2 - TIME_PERF_FACTOR) * ref,
-            context=[device],
+            context=[OH_DEVICE_CONTEXT],
             throughput=results["throughput"],
         )
 
diff --git a/tests/test_object_detection.py b/tests/test_object_detection.py
index 014a7704da..3ff7c4e2e9 100644
--- a/tests/test_object_detection.py
+++ b/tests/test_object_detection.py
@@ -27,11 +27,12 @@
 from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
 
 from .test_examples import TIME_PERF_FACTOR
+from .utils import OH_DEVICE_CONTEXT
 
 
 adapt_transformers_to_gaudi()
 
-if os.environ.get("GAUDI2_CI", "0") == "1":
+if OH_DEVICE_CONTEXT in ["gaudi2"]:
     # Gaudi2 CI baselines
     LATENCY_DETR_BF16_GRAPH_BASELINE = 7.0
 else:
diff --git a/tests/test_openclip_vqa.py b/tests/test_openclip_vqa.py
index 812db05645..3ac12d4b46 100644
--- a/tests/test_openclip_vqa.py
+++ b/tests/test_openclip_vqa.py
@@ -8,6 +8,7 @@
 import pytest
 
 from .test_examples import TIME_PERF_FACTOR
+from .utils import OH_DEVICE_CONTEXT
 
 
 MODELS_TO_TEST = {
@@ -62,12 +63,10 @@ def _test_openclip_vqa(model_name: str, baseline):
     with open(Path(tmp_dir) / "results.json") as fp:
         results = json.load(fp)
 
-    device = "gaudi2" if os.environ.get("GAUDI2_CI", "0") == "1" else "gaudi1"
-
     # Ensure performance requirements (throughput) are met
     baseline.assertRef(
         compare=lambda actual, ref: actual >= (2 - TIME_PERF_FACTOR) * ref,
-        context=[device],
+        context=[OH_DEVICE_CONTEXT],
         throughput=results["throughput"],
     )
 
diff --git a/tests/test_sentence_transformers.py b/tests/test_sentence_transformers.py
index f9b3033a7f..a8ddcbb78a 100644
--- a/tests/test_sentence_transformers.py
+++ b/tests/test_sentence_transformers.py
@@ -7,6 +7,7 @@
 from sentence_transformers import SentenceTransformer, util
 
 from .test_examples import TIME_PERF_FACTOR
+from .utils import OH_DEVICE_CONTEXT
 
 
 MODELS_TO_TEST = [
@@ -56,12 +57,10 @@ def _test_sentence_transformers(
         diff_time = end_time - start_time
         measured_throughput = len(sentences) / diff_time
 
-        device = "gaudi2" if os.environ.get("GAUDI2_CI", "0") == "1" else "gaudi1"
-
         # Only assert the last measured throughtput as the first iteration is used as a warmup
         baseline.assertRef(
             compare=lambda actual, ref: actual >= (2 - TIME_PERF_FACTOR) * ref,
-            context=[device],
+            context=[OH_DEVICE_CONTEXT],
             measured_throughput=measured_throughput,
         )
 
diff --git a/tests/test_table_transformer.py b/tests/test_table_transformer.py
index 946e5cb676..12a9105544 100644
--- a/tests/test_table_transformer.py
+++ b/tests/test_table_transformer.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import os
 import time
 from unittest import TestCase
@@ -25,11 +24,13 @@
 
 from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
 
+from .utils import OH_DEVICE_CONTEXT
+
 adapt_transformers_to_gaudi()
 
 MODEL_NAME = "microsoft/table-transformer-detection"
 
-if os.environ.get("GAUDI2_CI", "0") == "1":
+if OH_DEVICE_CONTEXT in ["gaudi2"]:
     LATENCY_TABLE_TRANSFORMER_BF16_GRAPH_BASELINE = 2.2
 else:
     LATENCY_TABLE_TRANSFORMER_BF16_GRAPH_BASELINE = 6.6
diff --git a/tests/test_text_generation_example.py b/tests/test_text_generation_example.py
index 0c513d8bb1..1725c587ba 100644
--- a/tests/test_text_generation_example.py
+++ b/tests/test_text_generation_example.py
@@ -13,13 +13,13 @@
 from optimum.habana.utils import set_seed
 
 from .test_examples import TIME_PERF_FACTOR
+from .utils import OH_DEVICE_CONTEXT
 
 
 prev_quant_model_name = None
 prev_quant_rank = 0
 
-if os.environ.get("GAUDI2_CI", "0") == "1":
-    # Gaudi2 CI
+if OH_DEVICE_CONTEXT in ["gaudi2"]:
     MODELS_TO_TEST = {
         "bf16_1x": [
             ("bigscience/bloomz-7b1", 1, False, False),
@@ -366,18 +366,20 @@ def _test_text_generation(
     with open(Path(tmp_dir) / "results.json") as fp:
         results = json.load(fp)
 
-    device = "gaudi2" if os.environ.get("GAUDI2_CI", "0") == "1" else "gaudi1"
-
     # Ensure performance requirements (throughput) are met
     baseline.assertRef(
         compare=lambda actual, ref: actual >= (2 - TIME_PERF_FACTOR) * ref,
-        context=[device],
+        context=[OH_DEVICE_CONTEXT],
         throughput=results["throughput"],
     )
 
     # Verify output for 1 HPU, BF16
     if check_output:
-        baseline.assertRef(compare=operator.eq, context=[device], output=results["output"][0][0])
+        baseline.assertRef(
+            compare=operator.eq,
+            context=[OH_DEVICE_CONTEXT],
+            output=results["output"][0][0],
+        )
 
 
 @pytest.mark.parametrize("model_name, batch_size, reuse_cache, check_output", MODELS_TO_TEST["bf16_1x"])
@@ -394,7 +396,7 @@ def test_text_generation_bf16_1x(
     )
 
 
-@pytest.mark.skipif(condition=not bool(int(os.environ.get("GAUDI2_CI", "0"))), reason="Skipping test for G1")
+@pytest.mark.skipif(condition=(OH_DEVICE_CONTEXT == "gaudi1"), reason=f"Skipping test for {OH_DEVICE_CONTEXT}")
 @pytest.mark.parametrize(
     "model_name, world_size, batch_size, reuse_cache, input_len, output_len", MODELS_TO_TEST["fp8"]
 )
@@ -423,7 +425,7 @@ def test_text_generation_fp8(
     )
 
 
-@pytest.mark.skipif(condition=not bool(int(os.environ.get("GAUDI2_CI", "0"))), reason="Skipping test for G1")
+@pytest.mark.skipif(condition=(OH_DEVICE_CONTEXT == "gaudi1"), reason=f"Skipping test for {OH_DEVICE_CONTEXT}")
 @pytest.mark.parametrize(
     "model_name, world_size, batch_size, reuse_cache, input_len, output_len",
     MODELS_TO_TEST["load_quantized_model_with_autogptq"],
 )
@@ -454,7 +456,7 @@ def test_text_generation_gptq(
     )
 
 
-@pytest.mark.skipif(condition=not bool(int(os.environ.get("GAUDI2_CI", "0"))), reason="Skipping test for G1")
+@pytest.mark.skipif(condition=(OH_DEVICE_CONTEXT == "gaudi1"), reason=f"Skipping test for {OH_DEVICE_CONTEXT}")
 @pytest.mark.parametrize(
     "model_name, world_size, batch_size, reuse_cache, input_len, output_len",
     MODELS_TO_TEST["load_quantized_model_with_autoawq"],
 )
@@ -490,20 +492,20 @@ def test_text_generation_deepspeed(model_name: str, world_size: int, batch_size:
     _test_text_generation(model_name, baseline, token, deepspeed=True, world_size=world_size, batch_size=batch_size)
 
 
-@pytest.mark.skipif(condition=not bool(int(os.environ.get("GAUDI2_CI", "0"))), reason="Skipping test for G1")
+@pytest.mark.skipif(condition=(OH_DEVICE_CONTEXT == "gaudi1"), reason=f"Skipping test for {OH_DEVICE_CONTEXT}")
{OH_DEVICE_CONTEXT}") @pytest.mark.parametrize("model_name", MODELS_TO_TEST["torch_compile"]) def test_text_generation_torch_compile(model_name: str, baseline, token): _test_text_generation(model_name, baseline, token, torch_compile=True) -@pytest.mark.skipif(condition=not bool(int(os.environ.get("GAUDI2_CI", "0"))), reason="Skipping test for G1") +@pytest.mark.skipif(condition=bool("gaudi1" == OH_DEVICE_CONTEXT), reason=f"Skipping test for {OH_DEVICE_CONTEXT}") @pytest.mark.parametrize("model_name", MODELS_TO_TEST["torch_compile_distributed"]) def test_text_generation_torch_compile_distributed(model_name: str, baseline, token): world_size = 8 _test_text_generation(model_name, baseline, token, deepspeed=True, world_size=world_size, torch_compile=True) -@pytest.mark.skipif(condition=not bool(int(os.environ.get("GAUDI2_CI", "0"))), reason="Skipping test for G1") +@pytest.mark.skipif(condition=bool("gaudi1" == OH_DEVICE_CONTEXT), reason=f"Skipping test for {OH_DEVICE_CONTEXT}") @pytest.mark.parametrize("model_name", MODELS_TO_TEST["distributed_tp"]) def test_text_generation_distributed_tp(model_name: str, baseline, token): world_size = 8 @@ -524,7 +526,7 @@ def test_text_generation_contrastive_search(model_name: str, batch_size: int, re _test_text_generation(model_name, baseline, token, batch_size, reuse_cache, contrastive_search=True) -@pytest.mark.skipif(condition=not bool(int(os.environ.get("GAUDI2_CI", "0"))), reason="Skipping test for G1") +@pytest.mark.skipif(condition=bool("gaudi1" == OH_DEVICE_CONTEXT), reason=f"Skipping test for {OH_DEVICE_CONTEXT}") @pytest.mark.parametrize("model_name, batch_size, reuse_cache", MODELS_TO_TEST["beam_search"]) def test_text_generation_beam_search(model_name: str, batch_size: int, reuse_cache: bool, baseline, token): _test_text_generation(model_name, baseline, token, batch_size, reuse_cache, num_beams=3) diff --git a/tests/test_video_llava.py b/tests/test_video_llava.py index 30c42b0cd8..ec584e7ab8 100644 --- a/tests/test_video_llava.py +++ b/tests/test_video_llava.py @@ -8,9 +8,10 @@ import pytest from .test_examples import TIME_PERF_FACTOR +from .utils import OH_DEVICE_CONTEXT -if os.environ.get("GAUDI2_CI", "0") == "1": +if OH_DEVICE_CONTEXT in ["gaudi2"]: # Gaudi2 CI baselines MODELS_TO_TEST = { "bf16": [ diff --git a/tests/test_video_mae.py b/tests/test_video_mae.py index 00dc9c2d26..a2026f446a 100644 --- a/tests/test_video_mae.py +++ b/tests/test_video_mae.py @@ -14,7 +14,6 @@ # limitations under the License. -import os import time from unittest import TestCase @@ -24,8 +23,10 @@ import torch from transformers import VideoMAEForVideoClassification, VideoMAEImageProcessor +from .utils import OH_DEVICE_CONTEXT -if os.environ.get("GAUDI2_CI", "0") == "1": + +if OH_DEVICE_CONTEXT in ["gaudi2"]: # Gaudi2 CI baselines LATENCY_VIDEOMAE_BF16_GRAPH_BASELINE = 17.544198036193848 else: diff --git a/tests/test_zero_shot_object_detection.py b/tests/test_zero_shot_object_detection.py index a70f8f9f36..85838c3de9 100644 --- a/tests/test_zero_shot_object_detection.py +++ b/tests/test_zero_shot_object_detection.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import os
 import time
 from unittest import TestCase
@@ -26,10 +25,12 @@
 
 from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
 
+from .utils import OH_DEVICE_CONTEXT
+
 adapt_transformers_to_gaudi()
 
-if os.environ.get("GAUDI2_CI", "0") == "1":
+if OH_DEVICE_CONTEXT in ["gaudi2"]:
     # Gaudi2 CI baselines
     LATENCY_OWLVIT_BF16_GRAPH_BASELINE = 4.2139556878198333
 else:
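
Note on OH_DEVICE_CONTEXT: every file touched above imports it from tests/utils.py, but the helper's definition is not part of this diff. A minimal sketch of a compatible definition, assuming only the legacy GAUDI2_CI flag that the removed checks read (the real utility presumably queries the HPU runtime so it can also report "gaudi3"):

    import os

    # Same mapping the removed inline checks computed: CI runners exported
    # GAUDI2_CI=1 on Gaudi2 and everything else was treated as Gaudi1.
    # (Hypothetical stand-in; a runtime device query would replace this.)
    OH_DEVICE_CONTEXT = "gaudi2" if os.environ.get("GAUDI2_CI", "0") == "1" else "gaudi1"

Because the context is a plain string, the new membership checks extend to additional devices without another sweep of the test suite, e.g. 'if OH_DEVICE_CONTEXT in ["gaudi2", "gaudi3"]:' for a baseline shared by both devices.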