Skip to content

Commit

Permalink
test_pipeline: use baseline fixture
Browse files: browse the repository at this point in the history.
Use the new baseline fixture to validate test results.

Signed-off-by: U. Artie Eoff <[email protected]>
  • Loading branch information
uartie committed Feb 5, 2025
1 parent 0de45b1 commit 0c92632
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 11 deletions.
17 changes: 17 additions & 0 deletions tests/baselines/fixture/tests/test_pipeline.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
{
"tests/test_pipeline.py::TestGaudiPipeline::test_image_to_text[Salesforce/blip-image-captioning-base-44]": {
"generated_text": "a soccer player is playing a game on the app"
},
"tests/test_pipeline.py::TestGaudiPipeline::test_image_to_text[nlpconnect/vit-gpt2-image-captioning-44]": {
"generated_text": "a soccer game with a player jumping to catch"
},
"tests/test_pipeline.py::TestGaudiPipeline::test_text_to_speech[facebook/hf-seamless-m4t-medium]": {
"sampling_rate": 16000
},
"tests/test_pipeline.py::TestGaudiPipeline::test_text_to_speech[facebook/mms-tts-eng]": {
"sampling_rate": 16000
},
"tests/test_pipeline.py::TestGaudiPipeline::test_text_to_speech[microsoft/speecht5_tts]": {
"sampling_rate": 16000
}
}
26 changes: 15 additions & 11 deletions tests/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import operator
import os

import numpy as np
Expand All @@ -27,20 +28,20 @@

MODELS_TO_TEST = {
"text-to-speech": [
("microsoft/speecht5_tts", 16000),
("facebook/hf-seamless-m4t-medium", 16000),
("facebook/mms-tts-eng", 16000),
"microsoft/speecht5_tts",
"facebook/hf-seamless-m4t-medium",
"facebook/mms-tts-eng",
],
"image-to-text": [
("Salesforce/blip-image-captioning-base", "a soccer player is playing a game on the app"),
("nlpconnect/vit-gpt2-image-captioning", "a soccer game with a player jumping to catch"),
("Salesforce/blip-image-captioning-base", 44),
("nlpconnect/vit-gpt2-image-captioning", 44),
],
}


class TestGaudiPipeline:
@pytest.mark.parametrize("model, expected_result", MODELS_TO_TEST["image-to-text"])
def test_image_to_text(self, model, expected_result):
@pytest.mark.parametrize("model, validate_length", MODELS_TO_TEST["image-to-text"])
def test_image_to_text(self, model, validate_length, baseline):
adapt_transformers_to_gaudi()
MODEL_DTYPE_LIST = [torch.bfloat16, torch.float32]
generate_kwargs = {
Expand All @@ -60,10 +61,12 @@ def test_image_to_text(self, model, expected_result):
generator.model = wrap_in_hpu_graph(generator.model)
for i in range(3):
output = generator(image, generate_kwargs=generate_kwargs)
assert output[0]["generated_text"].startswith(expected_result)

@pytest.mark.parametrize("model, expected_sample_rate", MODELS_TO_TEST["text-to-speech"])
def test_text_to_speech(self, model, expected_sample_rate):
result = output[0]["generated_text"][:validate_length]
baseline.assertRef(compare=operator.eq, generated_text=result)

@pytest.mark.parametrize("model", MODELS_TO_TEST["text-to-speech"])
def test_text_to_speech(self, model, baseline):
adapt_transformers_to_gaudi()
MODEL_DTYPE_LIST = [torch.bfloat16, torch.float32]
text = "hello, the dog is cooler"
Expand Down Expand Up @@ -95,4 +98,5 @@ def test_text_to_speech(self, model, expected_sample_rate):
for i in range(3):
output = generator(text, forward_params=forward_params, generate_kwargs=generate_kwargs)
assert isinstance(output["audio"], np.ndarray)
assert output["sampling_rate"] == expected_sample_rate

baseline.assertRef(compare=operator.eq, sampling_rate=output["sampling_rate"])

0 comments on commit 0c92632

Please sign in to comment.