test_text_generation_example: SQUASH
Signed-off-by: U. Artie Eoff <[email protected]>
uartie committed Jan 29, 2025
1 parent 06b3b4d commit 131fe97
Showing 1 changed file with 18 additions and 11 deletions.
29 changes: 18 additions & 11 deletions tests/test_text_generation_example.py
@@ -1,4 +1,5 @@
 import json
+import operator
 import os
 import re
 import subprocess
@@ -18,7 +19,7 @@
 prev_quant_rank = 0
 
 if os.environ.get("GAUDI2_CI", "0") == "1":
-    # Gaudi2 CI baselines
+    # Gaudi2 CI
     MODELS_TO_TEST = {
         "bf16_1x": [
             ("bigscience/bloomz-7b1", 1, False, False),
@@ -107,7 +108,7 @@
         ],
     }
 else:
-    # Gaudi1 CI baselines
+    # Gaudi1 CI
     MODELS_TO_TEST = {
         "bf16_1x": [
             ("bigscience/bloomz-7b1", 1, False, False),
@@ -357,18 +358,24 @@ def _test_text_generation(
     device = "gaudi2" if os.environ.get("GAUDI2_CI", "0") == "1" else "gaudi1"
 
     # Ensure performance requirements (throughput) are met
-    def check_throughput(key, expect, actual):
-        assert actual >= (2 - TIME_PERF_FACTOR) * expect
-    baseline.check(compare = check_throughput, context = [device], throughput = results["throughput"])
+    # def check_throughput(key, expect, actual):
+    #     assert actual >= (2 - TIME_PERF_FACTOR) * expect
+    # baseline.check(compare = check_throughput, context = [device], throughput = results["throughput"])
+
+    baseline.assertRef(
+        compare = lambda actual, ref: actual >= (2 - TIME_PERF_FACTOR) * ref,
+        context = [device], throughput = results["throughput"],
+    )
 
     # Verify output for 1 HPU, BF16
     if check_output:
-        def do_check_output(key, expect, actual):
-            assert expect is not None, (
-                f"Failed functional testing, missing expected output for model {model_name}"
-            )
-            assert actual == expect
-        baseline.check(compare = do_check_output, context = [device], output = results["output"][0][0])
+        # def do_check_output(key, expect, actual):
+        #     assert expect is not None, (
+        #         f"Failed functional testing, missing expected output for model {model_name}"
+        #     )
+        #     assert actual == expect
+        # baseline.check(compare = do_check_output, context = [device], output = results["output"][0][0])
+        baseline.assertRef(compare = operator.eq, context = [device], output = results["output"][0][0])
 
 
 @pytest.mark.parametrize("model_name, batch_size, reuse_cache, check_output", MODELS_TO_TEST["bf16_1x"])
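Note: the baseline fixture invoked above comes from the test suite's shared fixtures and is not shown in this diff. The following is only a rough sketch of an assertRef-style helper that is consistent with the calls in the new code; the class name, reference storage, tolerance value, and example data are assumptions made for illustration, not the project's implementation.

# Rough sketch only -- the real fixture lives elsewhere in the repository.
# The class name, reference storage, and example values are assumptions.
import operator

TIME_PERF_FACTOR = 1.05  # example tolerance factor, not the project's value


class Baseline:
    def __init__(self, references):
        # references maps (context..., metric_name) -> reference value
        self.references = references

    def assertRef(self, compare, context, **metrics):
        # Look up the reference for each named metric and apply the comparator.
        for name, actual in metrics.items():
            ref = self.references[tuple(context) + (name,)]
            assert compare(actual, ref), f"{name}: {actual!r} vs reference {ref!r}"


# Usage mirroring the calls in the diff (device name and values are made up):
baseline = Baseline({("gaudi2", "throughput"): 130.0, ("gaudi2", "output"): "a cat"})
baseline.assertRef(
    compare=lambda actual, ref: actual >= (2 - TIME_PERF_FACTOR) * ref,
    context=["gaudi2"],
    throughput=128.0,
)
baseline.assertRef(compare=operator.eq, context=["gaudi2"], output="a cat")

Compared with the commented-out baseline.check path, the comparator now receives (actual, ref) directly rather than (key, expect, actual), so plain callables such as operator.eq can be passed without a wrapper.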
