diff --git a/docs/source/en/internal/generation_utils.mdx b/docs/source/en/internal/generation_utils.mdx index dd93c79e92bf..10b050b3f8b2 100644 --- a/docs/source/en/internal/generation_utils.mdx +++ b/docs/source/en/internal/generation_utils.mdx @@ -269,3 +269,5 @@ A [`Constraint`] can be used to force the generation to include specific tokens ## Streamers [[autodoc]] TextStreamer + +[[autodoc]] TextIteratorStreamer diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index f94aa2de48b4..75ff4e8345a2 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -96,7 +96,7 @@ "feature_extraction_sequence_utils": ["SequenceFeatureExtractor"], "feature_extraction_utils": ["BatchFeature", "FeatureExtractionMixin"], "file_utils": [], - "generation": ["GenerationConfig", "TextStreamer"], + "generation": ["GenerationConfig", "TextIteratorStreamer", "TextStreamer"], "hf_argparser": ["HfArgumentParser"], "image_transforms": [], "integrations": [ @@ -3770,7 +3770,7 @@ from .feature_extraction_utils import BatchFeature, FeatureExtractionMixin # Generation - from .generation import GenerationConfig, TextStreamer + from .generation import GenerationConfig, TextIteratorStreamer, TextStreamer from .hf_argparser import HfArgumentParser # Integrations diff --git a/src/transformers/generation/__init__.py b/src/transformers/generation/__init__.py index d163c44dc7f0..bf87b6e5ff5f 100644 --- a/src/transformers/generation/__init__.py +++ b/src/transformers/generation/__init__.py @@ -17,7 +17,10 @@ from ..utils import OptionalDependencyNotAvailable, _LazyModule, is_flax_available, is_tf_available, is_torch_available -_import_structure = {"configuration_utils": ["GenerationConfig"], "streamers": ["TextStreamer"]} +_import_structure = { + "configuration_utils": ["GenerationConfig"], + "streamers": ["TextIteratorStreamer", "TextStreamer"], +} try: if not is_torch_available(): @@ -149,7 +152,7 @@ if TYPE_CHECKING: from .configuration_utils import GenerationConfig - from .streamers import TextStreamer + from .streamers import TextIteratorStreamer, TextStreamer try: if not is_torch_available(): diff --git a/src/transformers/generation/streamers.py b/src/transformers/generation/streamers.py index d110693b0eac..78d98666b3d1 100644 --- a/src/transformers/generation/streamers.py +++ b/src/transformers/generation/streamers.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from queue import Queue from typing import TYPE_CHECKING @@ -102,3 +103,93 @@ def end(self): # Print a newline (and the remaining text, if any) print(printable_text, flush=True) + + +class TextIteratorStreamer(BaseStreamer): + """ + Streamer that stores print-ready text in a queue, to be used by a downstream application as an iterator. This is + useful for applications that want to use the generated text in a non-blocking way (e.g. in an interactive Gradio + demo). + + Parameters: + tokenizer (`AutoTokenizer`): + The tokenized used to decode the tokens. + + Examples: + + ```python + >>> from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer + >>> from threading import Thread + + >>> tok = AutoTokenizer.from_pretrained("gpt2") + >>> model = AutoModelForCausalLM.from_pretrained("gpt2") + >>> inputs = tok(["An increasing sequence: one,"], return_tensors="pt") + >>> streamer = TextIteratorStreamer(tok) + + >>> # Run the generation in a separate thread, so that we can fetch the generated text in a non-blocking way. + >>> generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=20) + >>> thread = Thread(target=model.generate, kwargs=generation_kwargs) + >>> thread.start() + >>> generated_text = "" + >>> for new_text in streamer: + ... generated_text += new_text + >>> generated_text + 'An increasing sequence: one, two, three, four, five, six, seven, eight, nine, ten, eleven,' + ``` + """ + + def __init__(self, tokenizer: "AutoTokenizer"): + self.tokenizer = tokenizer + self.token_cache = [] + self.print_len = 0 + self.queue = Queue() + self.stop_signal = None + + def __iter__(self): + return self + + def __next__(self): + value = self.queue.get() + if value == self.stop_signal: + raise StopIteration() + else: + return value + + def put(self, value): + """ + Recives tokens, decodes them, and pushes text to the queue as soon as it form entire words. + """ + if len(value.shape) > 1 and value.shape[0] > 1: + raise ValueError("TextStreamer only supports batch size 1") + elif len(value.shape) > 1: + value = value[0] + + # Add the new token to the cache and decodes the entire thing. + self.token_cache.extend(value.tolist()) + text = self.tokenizer.decode(self.token_cache) + + # After symbol for a new line, we flush the cache. + if text.endswith("\n"): + printable_text = text[self.print_len :] + self.token_cache = [] + self.print_len = 0 + # Otherwise, prints until the last space char (simple heuristic to avoid printing incomplete words, + # which may change with the subsequent token -- there are probably smarter ways to do this!) + else: + printable_text = text[self.print_len : text.rfind(" ") + 1] + self.print_len += len(printable_text) + self.queue.put(printable_text) + + def end(self): + """Flushes any remaining cache and puts the stop signal in the queue.""" + # Flush the cache, if it exists + if len(self.token_cache) > 0: + text = self.tokenizer.decode(self.token_cache) + printable_text = text[self.print_len :] + self.token_cache = [] + self.print_len = 0 + else: + printable_text = "" + + self.queue.put(printable_text) + self.queue.put(self.stop_signal) # Put the stop signal diff --git a/tests/generation/test_streamers.py b/tests/generation/test_streamers.py index 120623285904..bf465ab31c6b 100644 --- a/tests/generation/test_streamers.py +++ b/tests/generation/test_streamers.py @@ -14,8 +14,9 @@ # limitations under the License. import unittest +from threading import Thread -from transformers import AutoTokenizer, TextStreamer, is_torch_available +from transformers import AutoTokenizer, TextIteratorStreamer, TextStreamer, is_torch_available from transformers.testing_utils import CaptureStdout, require_torch, torch_device from ..test_modeling_common import ids_tensor @@ -27,7 +28,7 @@ @require_torch class StreamerTester(unittest.TestCase): - def test_text_streamer_stdout(self): + def test_text_streamer_matches_non_streaming(self): tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2") model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device) model.config.eos_token_id = -1 @@ -39,6 +40,26 @@ def test_text_streamer_stdout(self): with CaptureStdout() as cs: streamer = TextStreamer(tokenizer) model.generate(input_ids, max_new_tokens=10, do_sample=False, streamer=streamer) - # The greedy text should be printed to stdout, except for the final "\n" in the streamer - self.assertEqual(cs.out[:-1], greedy_text) + streamer_text = cs.out[:-1] + + self.assertEqual(streamer_text, greedy_text) + + def test_iterator_streamer_matches_non_streaming(self): + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2") + model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device) + model.config.eos_token_id = -1 + + input_ids = ids_tensor((1, 5), vocab_size=model.config.vocab_size).to(torch_device) + greedy_ids = model.generate(input_ids, max_new_tokens=10, do_sample=False) + greedy_text = tokenizer.decode(greedy_ids[0]) + + streamer = TextIteratorStreamer(tokenizer) + generation_kwargs = {"input_ids": input_ids, "max_new_tokens": 10, "do_sample": False, "streamer": streamer} + thread = Thread(target=model.generate, kwargs=generation_kwargs) + thread.start() + streamer_text = "" + for new_text in streamer: + streamer_text += new_text + + self.assertEqual(streamer_text, greedy_text)