diff --git a/src/transformers/generation/streamers.py b/src/transformers/generation/streamers.py index 06f7be9d63e0..719c067d21c0 100644 --- a/src/transformers/generation/streamers.py +++ b/src/transformers/generation/streamers.py @@ -42,6 +42,10 @@ class TextStreamer(BaseStreamer): Parameters: tokenizer (`AutoTokenizer`): The tokenized used to decode the tokens. + skip_prompt (`bool`, *optional*, defaults to `False`): + Whether to skip the prompt to `.generate()` or not. Useful e.g. for chatbots. + decode_kwargs (`dict`, *optional*): + Additional keyword arguments to pass to the tokenizer's `decode` method. Examples: @@ -59,10 +63,15 @@ class TextStreamer(BaseStreamer): ``` """ - def __init__(self, tokenizer: "AutoTokenizer"): + def __init__(self, tokenizer: "AutoTokenizer", skip_prompt: bool = False, **decode_kwargs): self.tokenizer = tokenizer + self.skip_prompt = skip_prompt + self.decode_kwargs = decode_kwargs + + # variables used in the streaming process self.token_cache = [] self.print_len = 0 + self.next_tokens_are_prompt = True def put(self, value): """ @@ -73,11 +82,15 @@ def put(self, value): elif len(value.shape) > 1: value = value[0] + if self.skip_prompt and self.next_tokens_are_prompt: + self.next_tokens_are_prompt = False + return + # Add the new token to the cache and decodes the entire thing. self.token_cache.extend(value.tolist()) - text = self.tokenizer.decode(self.token_cache) + text = self.tokenizer.decode(self.token_cache, **self.decode_kwargs) - # After symbol for a new line, we flush the cache. + # After the symbol for a new line, we flush the cache. 
if text.endswith("\n"): printable_text = text[self.print_len :] self.token_cache = [] @@ -94,30 +107,34 @@ def end(self): """Flushes any remaining cache and prints a newline to stdout.""" # Flush the cache, if it exists if len(self.token_cache) > 0: - text = self.tokenizer.decode(self.token_cache) + text = self.tokenizer.decode(self.token_cache, **self.decode_kwargs) printable_text = text[self.print_len :] self.token_cache = [] self.print_len = 0 else: printable_text = "" - # Print a newline (and the remaining text, if any) + self.next_tokens_are_prompt = True self.on_finalized_text(printable_text, stream_end=True) - def on_finalized_text(self, token: str, stream_end: bool = False): - """Prints the new text to stdout.""" - print(token, flush=True, end="" if not stream_end else None) + def on_finalized_text(self, text: str, stream_end: bool = False): + """Prints the new text to stdout. If the stream is ending, also prints a newline.""" + print(text, flush=True, end="" if not stream_end else None) -class TextIteratorStreamer(BaseStreamer): +class TextIteratorStreamer(TextStreamer): """ Streamer that stores print-ready text in a queue, to be used by a downstream application as an iterator. This is - useful for applications that want to use the generated text in a non-blocking way (e.g. in an interactive Gradio - demo). + useful for applications that benefit from accessing the generated text in a non-blocking way (e.g. in an interactive + Gradio demo). Parameters: tokenizer (`AutoTokenizer`): The tokenized used to decode the tokens. + skip_prompt (`bool`, *optional*, defaults to `False`): + Whether to skip the prompt to `.generate()` or not. Useful e.g. for chatbots. + decode_kwargs (`dict`, *optional*): + Additional keyword arguments to pass to the tokenizer's `decode` method. 
Examples: @@ -142,58 +159,23 @@ class TextIteratorStreamer(BaseStreamer): ``` """ - def __init__(self, tokenizer: "AutoTokenizer"): - self.tokenizer = tokenizer - self.token_cache = [] - self.print_len = 0 - self.queue = Queue() + def __init__(self, tokenizer: "AutoTokenizer", skip_prompt: bool = False, **decode_kwargs): + super().__init__(tokenizer, skip_prompt, **decode_kwargs) + self.text_queue = Queue() self.stop_signal = None + def on_finalized_text(self, text: str, stream_end: bool = False): + """Put the new text in the queue. If the stream is ending, also put a stop signal in the queue.""" + self.text_queue.put(text) + if stream_end: + self.text_queue.put(self.stop_signal) + def __iter__(self): return self def __next__(self): - value = self.queue.get() + value = self.text_queue.get() if value == self.stop_signal: raise StopIteration() else: return value - - def put(self, value): - """ - Recives tokens, decodes them, and pushes text to the queue as soon as it form entire words. - """ - if len(value.shape) > 1 and value.shape[0] > 1: - raise ValueError("TextStreamer only supports batch size 1") - elif len(value.shape) > 1: - value = value[0] - - # Add the new token to the cache and decodes the entire thing. - self.token_cache.extend(value.tolist()) - text = self.tokenizer.decode(self.token_cache) - - # After symbol for a new line, we flush the cache. - if text.endswith("\n"): - printable_text = text[self.print_len :] - self.token_cache = [] - self.print_len = 0 - # Otherwise, prints until the last space char (simple heuristic to avoid printing incomplete words, - # which may change with the subsequent token -- there are probably smarter ways to do this!) 
- else: - printable_text = text[self.print_len : text.rfind(" ") + 1] - self.print_len += len(printable_text) - self.queue.put(printable_text) - - def end(self): - """Flushes any remaining cache and puts the stop signal in the queue.""" - # Flush the cache, if it exists - if len(self.token_cache) > 0: - text = self.tokenizer.decode(self.token_cache) - printable_text = text[self.print_len :] - self.token_cache = [] - self.print_len = 0 - else: - printable_text = "" - - self.queue.put(printable_text) - self.queue.put(self.stop_signal) # Put the stop signal diff --git a/tests/generation/test_streamers.py b/tests/generation/test_streamers.py index bf465ab31c6b..7214e56cd3c3 100644 --- a/tests/generation/test_streamers.py +++ b/tests/generation/test_streamers.py @@ -23,6 +23,8 @@ if is_torch_available(): + import torch + from transformers import AutoModelForCausalLM @@ -63,3 +65,40 @@ def test_iterator_streamer_matches_non_streaming(self): streamer_text += new_text self.assertEqual(streamer_text, greedy_text) + + def test_text_streamer_skip_prompt(self): + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2") + model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device) + model.config.eos_token_id = -1 + + input_ids = ids_tensor((1, 5), vocab_size=model.config.vocab_size).to(torch_device) + greedy_ids = model.generate(input_ids, max_new_tokens=10, do_sample=False) + new_greedy_ids = greedy_ids[:, input_ids.shape[1] :] + new_greedy_text = tokenizer.decode(new_greedy_ids[0]) + + with CaptureStdout() as cs: + streamer = TextStreamer(tokenizer, skip_prompt=True) + model.generate(input_ids, max_new_tokens=10, do_sample=False, streamer=streamer) + # The greedy text should be printed to stdout, except for the final "\n" in the streamer + streamer_text = cs.out[:-1] + + self.assertEqual(streamer_text, new_greedy_text) + + def test_text_streamer_decode_kwargs(self): + # Tests that we can pass `decode_kwargs` 
to the streamer to control how the tokens are decoded. Must be tested + # with actual models -- the dummy models' tokenizers are not aligned with their models, and + # `skip_special_tokens=True` has no effect on them + tokenizer = AutoTokenizer.from_pretrained("distilgpt2") + model = AutoModelForCausalLM.from_pretrained("distilgpt2").to(torch_device) + model.config.eos_token_id = -1 + + input_ids = torch.ones((1, 5), device=torch_device).long() * model.config.bos_token_id + with CaptureStdout() as cs: + streamer = TextStreamer(tokenizer, skip_special_tokens=True) + model.generate(input_ids, max_new_tokens=1, do_sample=False, streamer=streamer) + + # The prompt contains a special token, so the streamer should not print it. As such, the output text, when + # re-tokenized, must only contain one token + streamer_text = cs.out[:-1] # Remove the final "\n" + streamer_text_tokenized = tokenizer(streamer_text, return_tensors="pt") + self.assertEqual(streamer_text_tokenized.input_ids.shape, (1, 1))