Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions docs/source/en/generation_strategies.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,29 @@ one for summarization with beam search). You must have the right Hub permissions
['Les fichiers de configuration sont faciles à utiliser !']
```

## Streaming

The `generate()` supports streaming, through its `streamer` input. The `streamer` input is compatible any instance
from a class that has the following methods: `put()` and `end()`. Internally, `put()` is used to push new tokens and
`end()` is used to flag the end of text generation.

In practice, you can craft your own streaming class for all sorts of purposes! We also have basic streaming classes
ready for you to use. For example, you can use the [`TextStreamer`] class to stream the output of `generate()` into
your screen, one word at a time:

```python
>>> from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer

>>> tok = AutoTokenizer.from_pretrained("gpt2")
>>> model = AutoModelForCausalLM.from_pretrained("gpt2")
>>> inputs = tok(["An increasing sequence: one,"], return_tensors="pt")
>>> streamer = TextStreamer(tok)

>>> # Despite returning the usual output, the streamer will also print the generated text to stdout.
>>> _ = model.generate(**inputs, streamer=streamer, max_new_tokens=20)
An increasing sequence: one, two, three, four, five, six, seven, eight, nine, ten, eleven,
```

## Decoding strategies

Certain combinations of the `generate()` parameters, and ultimately `generation_config`, can be used to enable specific
Expand Down
4 changes: 4 additions & 0 deletions docs/source/en/internal/generation_utils.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -265,3 +265,7 @@ A [`Constraint`] can be used to force the generation to include specific tokens
[[autodoc]] top_k_top_p_filtering

[[autodoc]] tf_top_k_top_p_filtering

## Streamers

[[autodoc]] TextStreamer
3 changes: 2 additions & 1 deletion docs/source/en/main_classes/text_generation.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ of the generation method.

To learn how to inspect a model's generation configuration, what are the defaults, how to change the parameters ad hoc,
and how to create and save a customized generation configuration, refer to the
[text generation strategies guide](../generation_strategies).
[text generation strategies guide](../generation_strategies). The guide also explains how to use related features,
like token streaming.

## GenerationConfig

Expand Down
4 changes: 2 additions & 2 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@
"feature_extraction_sequence_utils": ["SequenceFeatureExtractor"],
"feature_extraction_utils": ["BatchFeature", "FeatureExtractionMixin"],
"file_utils": [],
"generation": ["GenerationConfig"],
"generation": ["GenerationConfig", "TextStreamer"],
"hf_argparser": ["HfArgumentParser"],
"image_transforms": [],
"integrations": [
Expand Down Expand Up @@ -3769,7 +3769,7 @@
from .feature_extraction_utils import BatchFeature, FeatureExtractionMixin

# Generation
from .generation import GenerationConfig
from .generation import GenerationConfig, TextStreamer
from .hf_argparser import HfArgumentParser

# Integrations
Expand Down
4 changes: 2 additions & 2 deletions src/transformers/generation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@
from ..utils import OptionalDependencyNotAvailable, _LazyModule, is_flax_available, is_tf_available, is_torch_available


_import_structure = {"configuration_utils": ["GenerationConfig"]}

_import_structure = {"configuration_utils": ["GenerationConfig"], "streamers": ["TextStreamer"]}

try:
if not is_torch_available():
Expand Down Expand Up @@ -150,6 +149,7 @@

if TYPE_CHECKING:
from .configuration_utils import GenerationConfig
from .streamers import TextStreamer

try:
if not is_torch_available():
Expand Down
104 changes: 104 additions & 0 deletions src/transformers/generation/streamers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING


if TYPE_CHECKING:
from ..models.auto import AutoTokenizer


class BaseStreamer:
"""
Base class from which `.generate()` streamers should inherit.
"""

def put(self, value):
"""Function that is called by `.generate()` to push new tokens"""
raise NotImplementedError()

def end(self):
"""Function that is called by `.generate()` to signal the end of generation"""
raise NotImplementedError()


class TextStreamer(BaseStreamer):
"""
Simple text streamer that prints the token(s) to stdout as soon as entire words are formed.

Parameters:
tokenizer (`AutoTokenizer`):
The tokenized used to decode the tokens.

Examples:

```python
>>> from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer

>>> tok = AutoTokenizer.from_pretrained("gpt2")
>>> model = AutoModelForCausalLM.from_pretrained("gpt2")
>>> inputs = tok(["An increasing sequence: one,"], return_tensors="pt")
>>> streamer = TextStreamer(tok)

>>> # Despite returning the usual output, the streamer will also print the generated text to stdout.
>>> _ = model.generate(**inputs, streamer=streamer, max_new_tokens=20)
An increasing sequence: one, two, three, four, five, six, seven, eight, nine, ten, eleven,
```
"""

def __init__(self, tokenizer: "AutoTokenizer"):
self.tokenizer = tokenizer
self.token_cache = []
self.print_len = 0

def put(self, value):
"""
Recives tokens, decodes them, and prints them to stdout as soon as they form entire words.
"""
if len(value.shape) > 1 and value.shape[0] > 1:
raise ValueError("TextStreamer only supports batch size 1")
elif len(value.shape) > 1:
value = value[0]

# Add the new token to the cache and decodes the entire thing.
self.token_cache.extend(value.tolist())
text = self.tokenizer.decode(self.token_cache)

# After symbol for a new line, we flush the cache.
if text.endswith("\n"):
printable_text = text[self.print_len :]
self.token_cache = []
self.print_len = 0
# Otherwise, prints until the last space char (simple heuristic to avoid printing incomplete words,
# which may change with the subsequent token -- there are probably smarter ways to do this!)
else:
printable_text = text[self.print_len : text.rfind(" ") + 1]
self.print_len += len(printable_text)

print(printable_text, flush=True, end="")

def end(self):
"""Flushes any remaining cache and prints a newline to stdout."""
# Flush the cache, if it exists
if len(self.token_cache) > 0:
text = self.tokenizer.decode(self.token_cache)
printable_text = text[self.print_len :]
self.token_cache = []
self.print_len = 0
else:
printable_text = ""

# Print a newline (and the remaining text, if any)
print(printable_text, flush=True)
51 changes: 49 additions & 2 deletions src/transformers/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import inspect
import warnings
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union

import torch
import torch.distributed as dist
Expand Down Expand Up @@ -72,6 +72,10 @@
)


if TYPE_CHECKING:
from .streamers import BaseStreamer


logger = logging.get_logger(__name__)


Expand Down Expand Up @@ -1116,6 +1120,7 @@ def generate(
stopping_criteria: Optional[StoppingCriteriaList] = None,
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
synced_gpus: Optional[bool] = None,
streamer: Optional["BaseStreamer"] = None,
**kwargs,
) -> Union[GenerateOutput, torch.LongTensor]:
r"""
Expand Down Expand Up @@ -1165,6 +1170,9 @@ def generate(
Whether to continue running the while loop until max_length. Unless overridden this flag will be set to
`True` under DeepSpeed ZeRO Stage 3 multiple GPUs environment to avoid hanging if one GPU finished
generating before other GPUs. Otherwise it'll be set to `False`.
streamer (`BaseStreamer`, *optional*):
Streamer object that will be used to stream the generated sequences. Generated tokens are passed
through `streamer.put(token_ids)` and the streamer is responsible for any further processing.

kwargs:
Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
Expand Down Expand Up @@ -1291,6 +1299,9 @@ def generate(
else:
input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")

if streamer is not None:
streamer.put(input_ids.cpu())

# 6. Prepare `max_length` depending on other stopping criteria.
input_ids_seq_length = input_ids.shape[-1]
has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
Expand Down Expand Up @@ -1331,7 +1342,8 @@ def generate(
)

is_contrastive_search_gen_mode = (
generation_config.top_k is not None
(generation_config.num_beams == 1)
and generation_config.top_k is not None
and generation_config.top_k > 1
and generation_config.do_sample is False
and generation_config.penalty_alpha is not None
Expand Down Expand Up @@ -1380,6 +1392,11 @@ def generate(
"Diverse beam search cannot be used in sampling mode. Make sure that `do_sample` is set to `False`."
)

if streamer is not None and (generation_config.num_beams > 1):
raise ValueError(
"`streamer` cannot be used with beam search (yet!). Make sure that `num_beams` is set to 1."
)

if self.device.type != input_ids.device.type:
warnings.warn(
"You are calling .generate() with the `input_ids` being on a device type different"
Expand Down Expand Up @@ -1422,6 +1439,7 @@ def generate(
output_scores=generation_config.output_scores,
return_dict_in_generate=generation_config.return_dict_in_generate,
synced_gpus=synced_gpus,
streamer=streamer,
**model_kwargs,
)

Expand All @@ -1443,6 +1461,7 @@ def generate(
output_scores=generation_config.output_scores,
return_dict_in_generate=generation_config.return_dict_in_generate,
synced_gpus=synced_gpus,
streamer=streamer,
**model_kwargs,
)

Expand All @@ -1469,6 +1488,7 @@ def generate(
output_scores=generation_config.output_scores,
return_dict_in_generate=generation_config.return_dict_in_generate,
synced_gpus=synced_gpus,
streamer=streamer,
**model_kwargs,
)

Expand Down Expand Up @@ -1699,6 +1719,7 @@ def contrastive_search(
output_scores: Optional[bool] = None,
return_dict_in_generate: Optional[bool] = None,
synced_gpus: Optional[bool] = False,
streamer: Optional["BaseStreamer"] = None,
**model_kwargs,
) -> Union[ContrastiveSearchOutput, torch.LongTensor]:
r"""
Expand Down Expand Up @@ -1746,6 +1767,9 @@ def contrastive_search(
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
synced_gpus (`bool`, *optional*, defaults to `False`):
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
streamer (`BaseStreamer`, *optional*):
Streamer object that will be used to stream the generated sequences. Generated tokens are passed
through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
model_kwargs:
Additional model specific keyword arguments will be forwarded to the `forward` function of the model.
If model is an encoder-decoder model the kwargs should include `encoder_outputs`.
Expand Down Expand Up @@ -2006,6 +2030,8 @@ def contrastive_search(

# update generated ids, model inputs, and length for next step
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
if streamer is not None:
streamer.put(next_tokens.cpu())
model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
)
Expand All @@ -2023,6 +2049,9 @@ def contrastive_search(
else:
this_peer_finished = True

if streamer is not None:
streamer.end()

if return_dict_in_generate:
if self.config.is_encoder_decoder:
return ContrastiveSearchEncoderDecoderOutput(
Expand Down Expand Up @@ -2057,6 +2086,7 @@ def greedy_search(
output_scores: Optional[bool] = None,
return_dict_in_generate: Optional[bool] = None,
synced_gpus: Optional[bool] = False,
streamer: Optional["BaseStreamer"] = None,
**model_kwargs,
) -> Union[GreedySearchOutput, torch.LongTensor]:
r"""
Expand Down Expand Up @@ -2101,6 +2131,9 @@ def greedy_search(
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
synced_gpus (`bool`, *optional*, defaults to `False`):
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
streamer (`BaseStreamer`, *optional*):
Streamer object that will be used to stream the generated sequences. Generated tokens are passed
through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
model_kwargs:
Additional model specific keyword arguments will be forwarded to the `forward` function of the model.
If model is an encoder-decoder model the kwargs should include `encoder_outputs`.
Expand Down Expand Up @@ -2252,6 +2285,8 @@ def greedy_search(

# update generated ids, model inputs, and length for next step
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
if streamer is not None:
streamer.put(next_tokens.cpu())
model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
)
Expand All @@ -2269,6 +2304,9 @@ def greedy_search(
else:
this_peer_finished = True

if streamer is not None:
streamer.end()

if return_dict_in_generate:
if self.config.is_encoder_decoder:
return GreedySearchEncoderDecoderOutput(
Expand Down Expand Up @@ -2304,6 +2342,7 @@ def sample(
output_scores: Optional[bool] = None,
return_dict_in_generate: Optional[bool] = None,
synced_gpus: Optional[bool] = False,
streamer: Optional["BaseStreamer"] = None,
**model_kwargs,
) -> Union[SampleOutput, torch.LongTensor]:
r"""
Expand Down Expand Up @@ -2350,6 +2389,9 @@ def sample(
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
synced_gpus (`bool`, *optional*, defaults to `False`):
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
streamer (`BaseStreamer`, *optional*):
Streamer object that will be used to stream the generated sequences. Generated tokens are passed
through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
model_kwargs:
Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
an encoder-decoder model the kwargs should include `encoder_outputs`.
Expand Down Expand Up @@ -2521,6 +2563,8 @@ def sample(

# update generated ids, model inputs, and length for next step
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
if streamer is not None:
streamer.put(next_tokens.cpu())
model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
)
Expand All @@ -2538,6 +2582,9 @@ def sample(
else:
this_peer_finished = True

if streamer is not None:
streamer.end()

if return_dict_in_generate:
if self.config.is_encoder_decoder:
return SampleEncoderDecoderOutput(
Expand Down
Loading