diff --git a/src/transformers/pipelines/text_generation.py b/src/transformers/pipelines/text_generation.py
index e00d980ca1d..e15228fbe0a 100644
--- a/src/transformers/pipelines/text_generation.py
+++ b/src/transformers/pipelines/text_generation.py
@@ -1,4 +1,6 @@
 import enum
+import itertools
+import types
 import warnings
 from typing import Dict
 
@@ -260,16 +262,27 @@ def __call__(self, text_inputs, **kwargs):
               ids of the generated text.
         """
         if isinstance(
-            text_inputs, (list, tuple, KeyDataset) if is_torch_available() else (list, tuple)
-        ) and isinstance(text_inputs[0], (list, tuple, dict)):
-            # We have one or more prompts in list-of-dicts format, so this is chat mode
-            if isinstance(text_inputs[0], dict):
-                return super().__call__(Chat(text_inputs), **kwargs)
+            text_inputs,
+            (list, tuple, types.GeneratorType, KeyDataset)
+            if is_torch_available()
+            else (list, tuple, types.GeneratorType),
+        ):
+            if isinstance(text_inputs, types.GeneratorType):
+                text_inputs, _ = itertools.tee(text_inputs)
+                text_inputs, first_item = (x for x in text_inputs), next(_)
             else:
-                chats = [Chat(chat) for chat in text_inputs]  # 🐈 🐈 🐈
-                return super().__call__(chats, **kwargs)
-        else:
-            return super().__call__(text_inputs, **kwargs)
+                first_item = text_inputs[0]
+            if isinstance(first_item, (list, tuple, dict)):
+                # We have one or more prompts in list-of-dicts format, so this is chat mode
+                if isinstance(first_item, dict):
+                    return super().__call__(Chat(text_inputs), **kwargs)
+                else:
+                    chats = (Chat(chat) for chat in text_inputs)  # 🐈 🐈 🐈
+                    if isinstance(text_inputs, types.GeneratorType):
+                        return super().__call__(chats, **kwargs)
+                    else:
+                        return super().__call__(list(chats), **kwargs)
+        return super().__call__(text_inputs, **kwargs)
 
     def preprocess(
         self,
diff --git a/tests/pipelines/test_pipelines_text_generation.py b/tests/pipelines/test_pipelines_text_generation.py
index 51f3cae5e31..7de84e646e1 100644
--- a/tests/pipelines/test_pipelines_text_generation.py
+++ b/tests/pipelines/test_pipelines_text_generation.py
@@ -292,6 +292,50 @@ def __getitem__(self, i):
             ],
         )
 
+    @require_torch
+    def test_small_chat_model_with_iterator_pt(self):
+        from transformers.pipelines.pt_utils import PipelineIterator
+
+        text_generator = pipeline(
+            task="text-generation", model="hf-internal-testing/tiny-gpt2-with-chatml-template", framework="pt"
+        )
+
+        # Using `do_sample=False` to force deterministic output
+        chat1 = [
+            {"role": "system", "content": "This is a system message."},
+            {"role": "user", "content": "This is a test"},
+        ]
+        chat2 = [
+            {"role": "system", "content": "This is a system message."},
+            {"role": "user", "content": "This is a second test"},
+        ]
+        expected_chat1 = chat1 + [
+            {
+                "role": "assistant",
+                "content": " factors factors factors factors factors factors factors factors factors factors",
+            }
+        ]
+        expected_chat2 = chat2 + [
+            {
+                "role": "assistant",
+                "content": " stairs stairs stairs stairs stairs stairs stairs stairs stairs stairs",
+            }
+        ]
+
+        def data():
+            yield from [chat1, chat2]
+
+        outputs = text_generator(data(), do_sample=False, max_new_tokens=10)
+        assert isinstance(outputs, PipelineIterator)
+        outputs = list(outputs)
+        self.assertEqual(
+            outputs,
+            [
+                [{"generated_text": expected_chat1}],
+                [{"generated_text": expected_chat2}],
+            ],
+        )
+
     @require_tf
     def test_small_model_tf(self):
         text_generator = pipeline(task="text-generation", model="sshleifer/tiny-ctrl", framework="tf")
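
A minimal usage sketch of what this patch enables, mirroring the new test above (the tiny chat model and generation kwargs are taken from that test; any chat-templated model should behave the same way):

    from transformers import pipeline

    text_generator = pipeline(
        task="text-generation", model="hf-internal-testing/tiny-gpt2-with-chatml-template", framework="pt"
    )

    # Chats can now arrive from a generator instead of a list; the pipeline peeks at
    # the first item (via itertools.tee) to detect chat mode without consuming the input.
    def chats():
        yield [{"role": "user", "content": "This is a test"}]
        yield [{"role": "user", "content": "This is a second test"}]

    # Outputs come back lazily as a PipelineIterator rather than a fully built list.
    for output in text_generator(chats(), do_sample=False, max_new_tokens=10):
        print(output[0]["generated_text"])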