From eb3c2b3e92e63a09514ec39df6865618287e35b7 Mon Sep 17 00:00:00 2001
From: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com>
Date: Tue, 7 Jan 2025 18:15:58 +0100
Subject: [PATCH 1/6] support chat generator as input of TextGenerationPipeline

---
 src/transformers/pipelines/text_generation.py | 27 ++++++++++++-------
 1 file changed, 18 insertions(+), 9 deletions(-)

diff --git a/src/transformers/pipelines/text_generation.py b/src/transformers/pipelines/text_generation.py
index e00d980ca1d..c55a2d9ab53 100644
--- a/src/transformers/pipelines/text_generation.py
+++ b/src/transformers/pipelines/text_generation.py
@@ -1,4 +1,5 @@
 import enum
+import itertools
 import warnings
 from typing import Dict

@@ -260,16 +261,24 @@ def __call__(self, text_inputs, **kwargs):
                 ids of the generated text.
         """
         if isinstance(
-            text_inputs, (list, tuple, KeyDataset) if is_torch_available() else (list, tuple)
-        ) and isinstance(text_inputs[0], (list, tuple, dict)):
-            # We have one or more prompts in list-of-dicts format, so this is chat mode
-            if isinstance(text_inputs[0], dict):
-                return super().__call__(Chat(text_inputs), **kwargs)
+            text_inputs, (list, tuple, types.GeneratorType, KeyDataset) if is_torch_available() else (list, tuple, types.GeneratorType)
+        ):
+            if isinstance(text_inputs, types.GeneratorType):
+                text_inputs, _ = itertools.tee(text_inputs)
+                first_item = next(_)
             else:
-                chats = [Chat(chat) for chat in text_inputs]  # 🐈 🐈 🐈
-                return super().__call__(chats, **kwargs)
-        else:
-            return super().__call__(text_inputs, **kwargs)
+                first_item = text_inputs[0]
+            if isinstance(first_item, (list, tuple, dict)):
+                # We have one or more prompts in list-of-dicts format, so this is chat mode
+                if isinstance(first_item, dict):
+                    return super().__call__(Chat(text_inputs), **kwargs)
+                else:
+                    chats = (Chat(chat) for chat in text_inputs)  # 🐈 🐈 🐈
+                    if isinstance(text_inputs, types.GeneratorType):
+                        return super().__call__(chats, **kwargs)
+                    else:
+                        return super().__call__(list(chats), **kwargs)
+        return super().__call__(text_inputs, **kwargs)

     def preprocess(
         self,

From 6f7c1cdc277bf50216d120851e4752aa789296a1 Mon Sep 17 00:00:00 2001
From: Quentin Lhoest
Date: Tue, 7 Jan 2025 18:29:10 +0100
Subject: [PATCH 2/6] missing import

---
 src/transformers/pipelines/text_generation.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/src/transformers/pipelines/text_generation.py b/src/transformers/pipelines/text_generation.py
index c55a2d9ab53..a0c2890a00d 100644
--- a/src/transformers/pipelines/text_generation.py
+++ b/src/transformers/pipelines/text_generation.py
@@ -1,5 +1,6 @@
 import enum
 import itertools
+import types
 import warnings
 from typing import Dict

@@ -261,7 +262,10 @@ def __call__(self, text_inputs, **kwargs):
                 ids of the generated text.
""" if isinstance( - text_inputs, (list, tuple, types.GeneratorType, KeyDataset) if is_torch_available() else (list, tuple, types.GeneratorType) + text_inputs, + (list, tuple, types.GeneratorType, KeyDataset) + if is_torch_available() + else (list, tuple, types.GeneratorType), ): if isinstance(text_inputs, types.GeneratorType): text_inputs, _ = itertools.tee(text_inputs) From bed5cd51ec3f5bd8851e6fd8bfb2bd55e19bb09f Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Tue, 7 Jan 2025 18:45:16 +0100 Subject: [PATCH 3/6] fix tests --- src/transformers/pipelines/text_generation.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/transformers/pipelines/text_generation.py b/src/transformers/pipelines/text_generation.py index a0c2890a00d..2a048b29745 100644 --- a/src/transformers/pipelines/text_generation.py +++ b/src/transformers/pipelines/text_generation.py @@ -270,15 +270,17 @@ def __call__(self, text_inputs, **kwargs): if isinstance(text_inputs, types.GeneratorType): text_inputs, _ = itertools.tee(text_inputs) first_item = next(_) + is_generator = True else: first_item = text_inputs[0] + is_generator = False if isinstance(first_item, (list, tuple, dict)): # We have one or more prompts in list-of-dicts format, so this is chat mode if isinstance(first_item, dict): return super().__call__(Chat(text_inputs), **kwargs) else: chats = (Chat(chat) for chat in text_inputs) # 🐈 🐈 🐈 - if isinstance(text_inputs, types.GeneratorType): + if is_generator: return super().__call__(chats, **kwargs) else: return super().__call__(list(chats), **kwargs) From 329b1e75a6bbd6142877815c626983e892df5f0e Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Tue, 7 Jan 2025 18:58:50 +0100 Subject: [PATCH 4/6] again --- src/transformers/pipelines/text_generation.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/transformers/pipelines/text_generation.py b/src/transformers/pipelines/text_generation.py index 2a048b29745..dac4c93e044 100644 --- a/src/transformers/pipelines/text_generation.py +++ b/src/transformers/pipelines/text_generation.py @@ -268,12 +268,12 @@ def __call__(self, text_inputs, **kwargs): else (list, tuple, types.GeneratorType), ): if isinstance(text_inputs, types.GeneratorType): - text_inputs, _ = itertools.tee(text_inputs) - first_item = next(_) is_generator = True + text_inputs, _ = itertools.tee(text_inputs) + text_inputs, first_item = (x for x in text_inputs), next(_) else: - first_item = text_inputs[0] is_generator = False + first_item = text_inputs[0] if isinstance(first_item, (list, tuple, dict)): # We have one or more prompts in list-of-dicts format, so this is chat mode if isinstance(first_item, dict): From 689e452a800d4013f18c80044654573f6dc3eecf Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Tue, 7 Jan 2025 19:00:47 +0100 Subject: [PATCH 5/6] simpler --- src/transformers/pipelines/text_generation.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/transformers/pipelines/text_generation.py b/src/transformers/pipelines/text_generation.py index dac4c93e044..e15228fbe0a 100644 --- a/src/transformers/pipelines/text_generation.py +++ b/src/transformers/pipelines/text_generation.py @@ -268,11 +268,9 @@ def __call__(self, text_inputs, **kwargs): else (list, tuple, types.GeneratorType), ): if isinstance(text_inputs, types.GeneratorType): - is_generator = True text_inputs, _ = itertools.tee(text_inputs) text_inputs, first_item = (x for x in text_inputs), next(_) else: - is_generator = False first_item = text_inputs[0] if 
             if isinstance(first_item, (list, tuple, dict)):
                 # We have one or more prompts in list-of-dicts format, so this is chat mode
@@ -280,7 +278,7 @@ def __call__(self, text_inputs, **kwargs):
                     return super().__call__(Chat(text_inputs), **kwargs)
                 else:
                     chats = (Chat(chat) for chat in text_inputs)  # 🐈 🐈 🐈
-                    if is_generator:
+                    if isinstance(text_inputs, types.GeneratorType):
                         return super().__call__(chats, **kwargs)
                     else:
                         return super().__call__(list(chats), **kwargs)

From 6745dfd55e5c9733610e50ba684c6c7aed956b6d Mon Sep 17 00:00:00 2001
From: Quentin Lhoest
Date: Wed, 8 Jan 2025 12:54:09 +0100
Subject: [PATCH 6/6] add test

---
 .../test_pipelines_text_generation.py | 44 +++++++++++++++++++
 1 file changed, 44 insertions(+)

diff --git a/tests/pipelines/test_pipelines_text_generation.py b/tests/pipelines/test_pipelines_text_generation.py
index 51f3cae5e31..7de84e646e1 100644
--- a/tests/pipelines/test_pipelines_text_generation.py
+++ b/tests/pipelines/test_pipelines_text_generation.py
@@ -292,6 +292,50 @@ def __getitem__(self, i):
             ],
         )

+    @require_torch
+    def test_small_chat_model_with_iterator_pt(self):
+        from transformers.pipelines.pt_utils import PipelineIterator
+
+        text_generator = pipeline(
+            task="text-generation", model="hf-internal-testing/tiny-gpt2-with-chatml-template", framework="pt"
+        )
+
+        # Using `do_sample=False` to force deterministic output
+        chat1 = [
+            {"role": "system", "content": "This is a system message."},
+            {"role": "user", "content": "This is a test"},
+        ]
+        chat2 = [
+            {"role": "system", "content": "This is a system message."},
+            {"role": "user", "content": "This is a second test"},
+        ]
+        expected_chat1 = chat1 + [
+            {
+                "role": "assistant",
+                "content": " factors factors factors factors factors factors factors factors factors factors",
+            }
+        ]
+        expected_chat2 = chat2 + [
+            {
+                "role": "assistant",
+                "content": " stairs stairs stairs stairs stairs stairs stairs stairs stairs stairs",
+            }
+        ]
+
+        def data():
+            yield from [chat1, chat2]
+
+        outputs = text_generator(data(), do_sample=False, max_new_tokens=10)
+        assert isinstance(outputs, PipelineIterator)
+        outputs = list(outputs)
+        self.assertEqual(
+            outputs,
+            [
+                [{"generated_text": expected_chat1}],
+                [{"generated_text": expected_chat2}],
+            ],
+        )
+
     @require_tf
     def test_small_model_tf(self):
         text_generator = pipeline(task="text-generation", model="sshleifer/tiny-ctrl", framework="tf")
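
The generator support added above hinges on a peek-without-consuming pattern: `itertools.tee` buffers items internally, so advancing one branch with `next` still leaves the first element available on the other branch, and because `tee` returns `itertools._tee` objects rather than generators, the surviving branch is re-wrapped in a generator expression so that the later `isinstance(text_inputs, types.GeneratorType)` check in the chat branch still holds. A minimal standalone sketch of the pattern (the `peek` helper is illustrative only, not part of the patches):

```python
import itertools
import types


def peek(inputs):
    """Return (inputs, first_item) without consuming a generator input."""
    if isinstance(inputs, types.GeneratorType):
        # tee() buffers items internally, so advancing `probe` does not
        # lose the first element on `branch`; re-wrap `branch` in a
        # generator expression because tee objects are not GeneratorType.
        branch, probe = itertools.tee(inputs)
        return (x for x in branch), next(probe)
    return inputs, inputs[0]


chats = ([{"role": "user", "content": f"message {i}"}] for i in range(3))
stream, first = peek(chats)
assert isinstance(stream, types.GeneratorType)
assert first == [{"role": "user", "content": "message 0"}]
assert len(list(stream)) == 3  # nothing was lost by peeking
```

As the new test demonstrates, a pipeline called with such a generator (`text_generator(data(), ...)`) wraps each yielded chat lazily and returns a `PipelineIterator`, so outputs can be consumed as a stream rather than materialized up front.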