Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/source/en/internal/generation_utils.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -269,3 +269,5 @@ A [`Constraint`] can be used to force the generation to include specific tokens
## Streamers

[[autodoc]] TextStreamer

[[autodoc]] TextIteratorStreamer
4 changes: 2 additions & 2 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@
"feature_extraction_sequence_utils": ["SequenceFeatureExtractor"],
"feature_extraction_utils": ["BatchFeature", "FeatureExtractionMixin"],
"file_utils": [],
"generation": ["GenerationConfig", "TextStreamer"],
"generation": ["GenerationConfig", "TextIteratorStreamer", "TextStreamer"],
"hf_argparser": ["HfArgumentParser"],
"image_transforms": [],
"integrations": [
Expand Down Expand Up @@ -3770,7 +3770,7 @@
from .feature_extraction_utils import BatchFeature, FeatureExtractionMixin

# Generation
from .generation import GenerationConfig, TextStreamer
from .generation import GenerationConfig, TextIteratorStreamer, TextStreamer
from .hf_argparser import HfArgumentParser

# Integrations
Expand Down
7 changes: 5 additions & 2 deletions src/transformers/generation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@
from ..utils import OptionalDependencyNotAvailable, _LazyModule, is_flax_available, is_tf_available, is_torch_available


_import_structure = {"configuration_utils": ["GenerationConfig"], "streamers": ["TextStreamer"]}
_import_structure = {
"configuration_utils": ["GenerationConfig"],
"streamers": ["TextIteratorStreamer", "TextStreamer"],
}

try:
if not is_torch_available():
Expand Down Expand Up @@ -149,7 +152,7 @@

if TYPE_CHECKING:
from .configuration_utils import GenerationConfig
from .streamers import TextStreamer
from .streamers import TextIteratorStreamer, TextStreamer

try:
if not is_torch_available():
Expand Down
91 changes: 91 additions & 0 deletions src/transformers/generation/streamers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from queue import Queue
from typing import TYPE_CHECKING


Expand Down Expand Up @@ -102,3 +103,93 @@ def end(self):

# Print a newline (and the remaining text, if any)
print(printable_text, flush=True)


class TextIteratorStreamer(BaseStreamer):
"""
Streamer that stores print-ready text in a queue, to be used by a downstream application as an iterator. This is
useful for applications that want to use the generated text in a non-blocking way (e.g. in an interactive Gradio
demo).

Parameters:
tokenizer (`AutoTokenizer`):
The tokenized used to decode the tokens.

Examples:

```python
>>> from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
>>> from threading import Thread

>>> tok = AutoTokenizer.from_pretrained("gpt2")
>>> model = AutoModelForCausalLM.from_pretrained("gpt2")
>>> inputs = tok(["An increasing sequence: one,"], return_tensors="pt")
>>> streamer = TextIteratorStreamer(tok)

>>> # Run the generation in a separate thread, so that we can fetch the generated text in a non-blocking way.
>>> generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=20)
>>> thread = Thread(target=model.generate, kwargs=generation_kwargs)
>>> thread.start()
>>> generated_text = ""
>>> for new_text in streamer:
... generated_text += new_text
>>> generated_text
'An increasing sequence: one, two, three, four, five, six, seven, eight, nine, ten, eleven,'
```
"""

def __init__(self, tokenizer: "AutoTokenizer"):
self.tokenizer = tokenizer
self.token_cache = []
self.print_len = 0
self.queue = Queue()
self.stop_signal = None

def __iter__(self):
return self

def __next__(self):
value = self.queue.get()
if value == self.stop_signal:
raise StopIteration()
else:
return value

def put(self, value):

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: the logic in put() and end() is a near copy-paste from TextStreamer -- instead of printing, puts things in the queue.

Most of the logic here is to work around model-specific tokenization quirks.

"""
Recives tokens, decodes them, and pushes text to the queue as soon as it form entire words.
"""
if len(value.shape) > 1 and value.shape[0] > 1:
raise ValueError("TextStreamer only supports batch size 1")
elif len(value.shape) > 1:
value = value[0]

# Add the new token to the cache and decodes the entire thing.
self.token_cache.extend(value.tolist())
text = self.tokenizer.decode(self.token_cache)

# After symbol for a new line, we flush the cache.
if text.endswith("\n"):
printable_text = text[self.print_len :]
self.token_cache = []
self.print_len = 0
# Otherwise, prints until the last space char (simple heuristic to avoid printing incomplete words,
# which may change with the subsequent token -- there are probably smarter ways to do this!)
else:
printable_text = text[self.print_len : text.rfind(" ") + 1]
self.print_len += len(printable_text)
self.queue.put(printable_text)

def end(self):
"""Flushes any remaining cache and puts the stop signal in the queue."""
# Flush the cache, if it exists
if len(self.token_cache) > 0:
text = self.tokenizer.decode(self.token_cache)
printable_text = text[self.print_len :]
self.token_cache = []
self.print_len = 0
else:
printable_text = ""

self.queue.put(printable_text)
self.queue.put(self.stop_signal) # Put the stop signal
29 changes: 25 additions & 4 deletions tests/generation/test_streamers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@
# limitations under the License.

import unittest
from threading import Thread

from transformers import AutoTokenizer, TextStreamer, is_torch_available
from transformers import AutoTokenizer, TextIteratorStreamer, TextStreamer, is_torch_available
from transformers.testing_utils import CaptureStdout, require_torch, torch_device

from ..test_modeling_common import ids_tensor
Expand All @@ -27,7 +28,7 @@

@require_torch
class StreamerTester(unittest.TestCase):
def test_text_streamer_stdout(self):
def test_text_streamer_matches_non_streaming(self):
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
model.config.eos_token_id = -1
Expand All @@ -39,6 +40,26 @@ def test_text_streamer_stdout(self):
with CaptureStdout() as cs:
streamer = TextStreamer(tokenizer)
model.generate(input_ids, max_new_tokens=10, do_sample=False, streamer=streamer)

# The greedy text should be printed to stdout, except for the final "\n" in the streamer
self.assertEqual(cs.out[:-1], greedy_text)
streamer_text = cs.out[:-1]

self.assertEqual(streamer_text, greedy_text)

def test_iterator_streamer_matches_non_streaming(self):
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
model.config.eos_token_id = -1

input_ids = ids_tensor((1, 5), vocab_size=model.config.vocab_size).to(torch_device)
greedy_ids = model.generate(input_ids, max_new_tokens=10, do_sample=False)
greedy_text = tokenizer.decode(greedy_ids[0])

streamer = TextIteratorStreamer(tokenizer)
generation_kwargs = {"input_ids": input_ids, "max_new_tokens": 10, "do_sample": False, "streamer": streamer}
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
streamer_text = ""
for new_text in streamer:
streamer_text += new_text

self.assertEqual(streamer_text, greedy_text)