Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 38 additions & 56 deletions src/transformers/generation/streamers.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ class TextStreamer(BaseStreamer):
Parameters:
tokenizer (`AutoTokenizer`):
The tokenizer used to decode the tokens.
skip_prompt (`bool`, *optional*, defaults to `False`):
Whether to skip the prompt to `.generate()` or not. Useful e.g. for chatbots.
decode_kwargs (`dict`, *optional*):
Additional keyword arguments to pass to the tokenizer's `decode` method.

Examples:

Expand All @@ -59,10 +63,15 @@ class TextStreamer(BaseStreamer):
```
"""

def __init__(self, tokenizer: "AutoTokenizer", skip_prompt: bool = False, **decode_kwargs):
    """
    Stores the streamer configuration and initializes the per-generation streaming state.

    Parameters:
        tokenizer (`AutoTokenizer`):
            The tokenizer used to decode the tokens.
        skip_prompt (`bool`, *optional*, defaults to `False`):
            Whether the first batch of tokens received by `put` (the prompt) is dropped instead of decoded.
        decode_kwargs (`dict`, *optional*):
            Additional keyword arguments forwarded to the tokenizer's `decode` method.
    """
    self.tokenizer = tokenizer
    self.skip_prompt = skip_prompt
    self.decode_kwargs = decode_kwargs

    # variables used in the streaming process
    self.token_cache = []  # token ids received so far that are still being decoded together
    self.print_len = 0  # number of characters of the decoded cache that were already emitted
    self.next_tokens_are_prompt = True  # the first `put` call carries the prompt

def put(self, value):
"""
Expand All @@ -73,11 +82,15 @@ def put(self, value):
elif len(value.shape) > 1:
value = value[0]

if self.skip_prompt and self.next_tokens_are_prompt:
self.next_tokens_are_prompt = False
return

# Add the new token to the cache and decodes the entire thing.
self.token_cache.extend(value.tolist())
text = self.tokenizer.decode(self.token_cache)
text = self.tokenizer.decode(self.token_cache, **self.decode_kwargs)

# After symbol for a new line, we flush the cache.
# After the symbol for a new line, we flush the cache.
if text.endswith("\n"):
printable_text = text[self.print_len :]
self.token_cache = []
Expand All @@ -94,30 +107,34 @@ def end(self):
"""Flushes any remaining cache and prints a newline to stdout."""
# Flush the cache, if it exists
if len(self.token_cache) > 0:
text = self.tokenizer.decode(self.token_cache)
text = self.tokenizer.decode(self.token_cache, **self.decode_kwargs)
printable_text = text[self.print_len :]
self.token_cache = []
self.print_len = 0
else:
printable_text = ""

# Print a newline (and the remaining text, if any)
self.next_tokens_are_prompt = True
self.on_finalized_text(printable_text, stream_end=True)

def on_finalized_text(self, text: str, stream_end: bool = False):
    """Prints the new text to stdout. If the stream is ending, also prints a newline."""
    # `end=None` makes `print` fall back to its default terminator ("\n"); mid-stream chunks are
    # written back-to-back with no separator.
    print(text, flush=True, end="" if not stream_end else None)


class TextIteratorStreamer(BaseStreamer):
class TextIteratorStreamer(TextStreamer):
"""
Streamer that stores print-ready text in a queue, to be used by a downstream application as an iterator. This is
useful for applications that want to use the generated text in a non-blocking way (e.g. in an interactive Gradio
demo).
useful for applications that benefit from accessing the generated text in a non-blocking way (e.g. in an interactive
Gradio demo).

Parameters:
tokenizer (`AutoTokenizer`):
The tokenizer used to decode the tokens.
skip_prompt (`bool`, *optional*, defaults to `False`):
Whether to skip the prompt to `.generate()` or not. Useful e.g. for chatbots.
decode_kwargs (`dict`, *optional*):
Additional keyword arguments to pass to the tokenizer's `decode` method.

Examples:

Expand All @@ -142,58 +159,23 @@ class TextIteratorStreamer(BaseStreamer):
```
"""

def __init__(self, tokenizer: "AutoTokenizer", skip_prompt: bool = False, **decode_kwargs):
    """
    Initializes the parent streamer state and creates the queue that carries the finalized text.

    Parameters:
        tokenizer (`AutoTokenizer`):
            The tokenizer used to decode the tokens.
        skip_prompt (`bool`, *optional*, defaults to `False`):
            Forwarded to the parent constructor.
        decode_kwargs (`dict`, *optional*):
            Additional keyword arguments, forwarded to the parent constructor.
    """
    super().__init__(tokenizer, skip_prompt, **decode_kwargs)
    self.text_queue = Queue()
    # Sentinel placed in the queue after the last piece of text; `__next__` stops when it sees it
    self.stop_signal = None

def on_finalized_text(self, text: str, stream_end: bool = False):
    """Enqueues the new text. When the stream is ending, the stop signal is enqueued right after it."""
    pending = [text]
    if stream_end:
        pending.append(self.stop_signal)
    for item in pending:
        self.text_queue.put(item)

def __iter__(self):
    """Returns the streamer itself; the items are produced by `__next__`."""
    return self

def __next__(self):
    """
    Returns the next piece of text taken from `self.text_queue`.

    Raises:
        StopIteration: when the item taken from the queue equals `self.stop_signal`.
    """
    # NOTE(review): presumably `text_queue` is a `queue.Queue`, so `get()` waits for the producer -- confirm
    value = self.text_queue.get()
    if value == self.stop_signal:
        raise StopIteration()
    return value

def put(self, value):
    """
    Receives tokens, decodes them, and hands text to `on_finalized_text` as soon as it forms entire words.

    Parameters:
        value:
            The new token ids. Must expose `.shape` and `.tolist()`; a 2D value must have a batch size of 1.

    Raises:
        ValueError: if `value` has more than one dimension and its first dimension is larger than 1.
    """
    if len(value.shape) > 1 and value.shape[0] > 1:
        raise ValueError("TextStreamer only supports batch size 1")
    elif len(value.shape) > 1:
        value = value[0]

    # The first call after construction (or after `end`) carries the prompt: drop it when requested
    if self.skip_prompt and self.next_tokens_are_prompt:
        self.next_tokens_are_prompt = False
        return

    # Add the new tokens to the cache and decode the entire thing.
    self.token_cache.extend(value.tolist())
    text = self.tokenizer.decode(self.token_cache, **self.decode_kwargs)

    # After the symbol for a new line, we flush the cache.
    if text.endswith("\n"):
        printable_text = text[self.print_len :]
        self.token_cache = []
        self.print_len = 0
    # Otherwise, emit up to the last space char (simple heuristic to avoid emitting incomplete words,
    # which may change with the subsequent token -- there are probably smarter ways to do this!)
    else:
        printable_text = text[self.print_len : text.rfind(" ") + 1]
        self.print_len += len(printable_text)

    # Go through the shared hook rather than touching the queue attribute directly
    self.on_finalized_text(printable_text)

def end(self):
    """Flushes any remaining cache and signals the end of the stream through `on_finalized_text`."""
    # Flush the cache, if it exists
    if self.token_cache:
        text = self.tokenizer.decode(self.token_cache, **self.decode_kwargs)
        printable_text = text[self.print_len :]
        self.token_cache = []
        self.print_len = 0
    else:
        printable_text = ""

    # The next `put` call belongs to a new generation and therefore starts with its prompt
    self.next_tokens_are_prompt = True
    # Always emitted, even when empty, so that the consumer is told the stream has ended
    self.on_finalized_text(printable_text, stream_end=True)
39 changes: 39 additions & 0 deletions tests/generation/test_streamers.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@


if is_torch_available():
import torch

from transformers import AutoModelForCausalLM


Expand Down Expand Up @@ -63,3 +65,40 @@ def test_iterator_streamer_matches_non_streaming(self):
streamer_text += new_text

self.assertEqual(streamer_text, greedy_text)

def test_text_streamer_skip_prompt(self):
    """With `skip_prompt=True`, the streamer must print only the newly generated text."""
    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
    model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
    # NOTE(review): presumably -1 can never be generated, so generation always runs for
    # `max_new_tokens` steps -- confirm
    model.config.eos_token_id = -1

    # Reference: plain greedy generation, with the prompt tokens sliced off before decoding
    prompt_ids = ids_tensor((1, 5), vocab_size=model.config.vocab_size).to(torch_device)
    prompt_len = prompt_ids.shape[1]
    reference_ids = model.generate(prompt_ids, max_new_tokens=10, do_sample=False)
    expected_text = tokenizer.decode(reference_ids[:, prompt_len:][0])

    with CaptureStdout() as captured:
        streamer = TextStreamer(tokenizer, skip_prompt=True)
        model.generate(prompt_ids, max_new_tokens=10, do_sample=False, streamer=streamer)
    # The streamer terminates its output with a "\n" that the reference text does not contain
    printed_text = captured.out[:-1]

    self.assertEqual(printed_text, expected_text)

def test_text_streamer_decode_kwargs(self):
    # Checks that keyword arguments given to the streamer reach the tokenizer's `decode`. This needs a real
    # checkpoint: the dummy models' tokenizers are not aligned with their models, and
    # `skip_special_tokens=True` has no effect on them
    tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
    model = AutoModelForCausalLM.from_pretrained("distilgpt2").to(torch_device)
    model.config.eos_token_id = -1

    # A prompt made only of BOS tokens, i.e. only of special tokens
    bos_prompt = torch.ones((1, 5), device=torch_device).long() * model.config.bos_token_id
    with CaptureStdout() as captured:
        streamer = TextStreamer(tokenizer, skip_special_tokens=True)
        model.generate(bos_prompt, max_new_tokens=1, do_sample=False, streamer=streamer)

    # The special tokens of the prompt must not be printed, so re-tokenizing what was printed has to
    # yield exactly one token: the single newly generated one
    printed_text = captured.out[:-1]  # Remove the final "\n"
    retokenized = tokenizer(printed_text, return_tensors="pt")
    self.assertEqual(retokenized.input_ids.shape, (1, 1))