Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions docs/source/en/internal/generation_utils.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -265,3 +265,7 @@ A [`Constraint`] can be used to force the generation to include specific tokens
[[autodoc]] top_k_top_p_filtering

[[autodoc]] tf_top_k_top_p_filtering

## Streamers

[[autodoc]] TextStreamer
4 changes: 2 additions & 2 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@
"feature_extraction_sequence_utils": ["SequenceFeatureExtractor"],
"feature_extraction_utils": ["BatchFeature", "FeatureExtractionMixin"],
"file_utils": [],
"generation": ["GenerationConfig"],
"generation": ["GenerationConfig", "TextStreamer"],
"hf_argparser": ["HfArgumentParser"],
"image_transforms": [],
"integrations": [
Expand Down Expand Up @@ -3769,7 +3769,7 @@
from .feature_extraction_utils import BatchFeature, FeatureExtractionMixin

# Generation
from .generation import GenerationConfig
from .generation import GenerationConfig, TextStreamer
from .hf_argparser import HfArgumentParser

# Integrations
Expand Down
4 changes: 2 additions & 2 deletions src/transformers/generation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@
from ..utils import OptionalDependencyNotAvailable, _LazyModule, is_flax_available, is_tf_available, is_torch_available


_import_structure = {"configuration_utils": ["GenerationConfig"]}

_import_structure = {"configuration_utils": ["GenerationConfig"], "streamers": ["TextStreamer"]}

try:
if not is_torch_available():
Expand Down Expand Up @@ -150,6 +149,7 @@

if TYPE_CHECKING:
from .configuration_utils import GenerationConfig
from .streamers import TextStreamer

try:
if not is_torch_available():
Expand Down
72 changes: 72 additions & 0 deletions src/transformers/generation/streamers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING


if TYPE_CHECKING:
from ..models.auto import AutoTokenizer


class BaseStreamer:
    """
    Common parent for every streamer object that can be handed to `.generate()`.

    Subclasses are expected to override both hooks below; the versions defined here only raise
    `NotImplementedError`.
    """

    def put(self, value):
        """Hook invoked by `.generate()` whenever new tokens are available; `value` holds those tokens."""
        raise NotImplementedError()

    def end(self):
        """Hook invoked by `.generate()` once, to signal that generation has finished."""
        raise NotImplementedError()


class TextStreamer(BaseStreamer):
    """
    Simple text streamer that prints the generated text to stdout as soon as entire words are formed.

    Tokens are accumulated and decoded together rather than one by one: decoding a token in isolation can give a
    different string than decoding it in context (e.g. a tokenizer may drop the leading space of a word, or a
    character may be split across several tokens), which would make the streamed text differ from the decoded
    output of `.generate()`.

    Parameters:
        tokenizer (`AutoTokenizer`):
            The tokenizer used to decode the tokens.

    Examples:

        ```python
        >>> from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer

        >>> tok = AutoTokenizer.from_pretrained("distilgpt2")
        >>> model = AutoModelForCausalLM.from_pretrained("distilgpt2")
        >>> inputs = tok(["This cat is"], return_tensors="pt")
        >>> streamer = TextStreamer(tok)
        >>> model.generate(**inputs, streamer=streamer)
        ```
    """

    def __init__(self, tokenizer: "AutoTokenizer"):
        self.tokenizer = tokenizer
        # Token ids received since the last flush, and how many characters of their decoded text were printed.
        self.token_cache = []
        self.print_len = 0

    def put(self, value):
        """
        Receives new token(s), decodes them together with the cached ones and prints the completed words to stdout.

        Args:
            value: Token ids exposing `.shape`, either a flat sequence of ids or a batch holding a single row.

        Raises:
            ValueError: If `value` is a batch with more than one row.
        """
        if len(value.shape) > 1 and value.shape[0] > 1:
            raise ValueError("TextStreamer only supports batch size 1")
        elif len(value.shape) > 1:
            value = value[0]

        # Normalise to a flat list of ids; a 0-dim input yields a bare scalar from `tolist()`.
        token_ids = value.tolist() if hasattr(value, "tolist") else list(value)
        if not isinstance(token_ids, list):
            token_ids = [token_ids]

        # Decode everything cached so far, so that each token is decoded in the context of its predecessors.
        self.token_cache.extend(token_ids)
        text = self.tokenizer.decode(self.token_cache)

        if text.endswith("\n"):
            # A newline closes the current chunk: print all of it and restart from an empty cache.
            printable_text = text[self.print_len :]
            self.token_cache = []
            self.print_len = 0
        else:
            # Only print up to the last space: the trailing word may still be extended by upcoming tokens.
            printable_text = text[self.print_len : text.rfind(" ") + 1]
            self.print_len += len(printable_text)

        print(printable_text, flush=True, end="")

    def end(self):
        """Prints any text still held back in the cache, followed by a newline, and resets the streamer."""
        if self.token_cache:
            text = self.tokenizer.decode(self.token_cache)
            printable_text = text[self.print_len :]
            self.token_cache = []
            self.print_len = 0
        else:
            printable_text = ""

        print(printable_text, flush=True)
56 changes: 50 additions & 6 deletions src/transformers/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import inspect
import warnings
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union

import torch
import torch.distributed as dist
Expand Down Expand Up @@ -72,6 +72,10 @@
)


if TYPE_CHECKING:
from .streamers import BaseStreamer


logger = logging.get_logger(__name__)


Expand Down Expand Up @@ -1116,6 +1120,7 @@ def generate(
stopping_criteria: Optional[StoppingCriteriaList] = None,
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
synced_gpus: Optional[bool] = None,
streamer: Optional["BaseStreamer"] = None,
**kwargs,
) -> Union[GenerateOutput, torch.LongTensor]:
r"""
Expand Down Expand Up @@ -1165,6 +1170,9 @@ def generate(
Whether to continue running the while loop until max_length. Unless overridden this flag will be set to
`True` under DeepSpeed ZeRO Stage 3 multiple GPUs environment to avoid hanging if one GPU finished
generating before other GPUs. Otherwise it'll be set to `False`.
streamer (`BaseStreamer`, *optional*):
Streamer object that will be used to stream the generated sequences. Generated tokens are passed
through `streamer.put(token_ids)` and the streamer is responsible for any further processing.

kwargs:
Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
Expand Down Expand Up @@ -1291,6 +1299,9 @@ def generate(
else:
input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")

if streamer is not None:
streamer.put(input_ids.cpu())

# 6. Prepare `max_length` depending on other stopping criteria.
input_ids_seq_length = input_ids.shape[-1]
has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
Expand Down Expand Up @@ -1326,12 +1337,13 @@ def generate(
)

# 7. determine generation mode
is_constraint_gen_mode = (
is_constraint_gen_mode = (generation_config.num_beams > 1) and (
generation_config.constraints is not None or generation_config.force_words_ids is not None
)

is_contrastive_search_gen_mode = (
generation_config.top_k is not None
(generation_config.num_beams == 1)
and generation_config.top_k is not None
and generation_config.top_k > 1
and generation_config.do_sample is False
and generation_config.penalty_alpha is not None
Expand Down Expand Up @@ -1380,6 +1392,11 @@ def generate(
"Diverse beam search cannot be used in sampling mode. Make sure that `do_sample` is set to `False`."
)

if streamer is not None and (generation_config.num_beams > 1):
raise ValueError(
"`streamer` cannot be used with beam search (yet!). Make sure that `num_beams` is set to 1."
)

if self.device.type != input_ids.device.type:
warnings.warn(
"You are calling .generate() with the `input_ids` being on a device type different"
Expand Down Expand Up @@ -1422,6 +1439,7 @@ def generate(
output_scores=generation_config.output_scores,
return_dict_in_generate=generation_config.return_dict_in_generate,
synced_gpus=synced_gpus,
streamer=streamer,
**model_kwargs,
)

Expand All @@ -1443,6 +1461,7 @@ def generate(
output_scores=generation_config.output_scores,
return_dict_in_generate=generation_config.return_dict_in_generate,
synced_gpus=synced_gpus,
streamer=streamer,
**model_kwargs,
)

Expand All @@ -1469,6 +1488,7 @@ def generate(
output_scores=generation_config.output_scores,
return_dict_in_generate=generation_config.return_dict_in_generate,
synced_gpus=synced_gpus,
streamer=streamer,
**model_kwargs,
)

Expand Down Expand Up @@ -1602,9 +1622,6 @@ def generate(
if stopping_criteria.max_length is None:
raise ValueError("`max_length` needs to be a stopping_criteria for now.")

if generation_config.num_beams <= 1:
raise ValueError("`num_beams` needs to be greater than 1 for constrained generation.")

Comment thread
gante marked this conversation as resolved.
if generation_config.do_sample:
raise ValueError("`do_sample` needs to be false for constrained generation.")

Expand Down Expand Up @@ -1699,6 +1716,7 @@ def contrastive_search(
output_scores: Optional[bool] = None,
return_dict_in_generate: Optional[bool] = None,
synced_gpus: Optional[bool] = False,
streamer: Optional["BaseStreamer"] = None,
**model_kwargs,
) -> Union[ContrastiveSearchOutput, torch.LongTensor]:
r"""
Expand Down Expand Up @@ -1746,6 +1764,9 @@ def contrastive_search(
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
synced_gpus (`bool`, *optional*, defaults to `False`):
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
streamer (`BaseStreamer`, *optional*):
Streamer object that will be used to stream the generated sequences. Generated tokens are passed
through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
model_kwargs:
Additional model specific keyword arguments will be forwarded to the `forward` function of the model.
If model is an encoder-decoder model the kwargs should include `encoder_outputs`.
Expand Down Expand Up @@ -2006,6 +2027,8 @@ def contrastive_search(

# update generated ids, model inputs, and length for next step
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
if streamer is not None:
streamer.put(next_tokens.cpu())
model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
)
Expand All @@ -2023,6 +2046,9 @@ def contrastive_search(
else:
this_peer_finished = True

if streamer is not None:
streamer.end()

if return_dict_in_generate:
if self.config.is_encoder_decoder:
return ContrastiveSearchEncoderDecoderOutput(
Expand Down Expand Up @@ -2057,6 +2083,7 @@ def greedy_search(
output_scores: Optional[bool] = None,
return_dict_in_generate: Optional[bool] = None,
synced_gpus: Optional[bool] = False,
streamer: Optional["BaseStreamer"] = None,
**model_kwargs,
) -> Union[GreedySearchOutput, torch.LongTensor]:
r"""
Expand Down Expand Up @@ -2101,6 +2128,9 @@ def greedy_search(
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
synced_gpus (`bool`, *optional*, defaults to `False`):
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
streamer (`BaseStreamer`, *optional*):
Streamer object that will be used to stream the generated sequences. Generated tokens are passed
through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
model_kwargs:
Additional model specific keyword arguments will be forwarded to the `forward` function of the model.
If model is an encoder-decoder model the kwargs should include `encoder_outputs`.
Expand Down Expand Up @@ -2252,6 +2282,8 @@ def greedy_search(

# update generated ids, model inputs, and length for next step
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
if streamer is not None:
streamer.put(next_tokens.cpu())
model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
)
Expand All @@ -2269,6 +2301,9 @@ def greedy_search(
else:
this_peer_finished = True

if streamer is not None:
streamer.end()

if return_dict_in_generate:
if self.config.is_encoder_decoder:
return GreedySearchEncoderDecoderOutput(
Expand Down Expand Up @@ -2304,6 +2339,7 @@ def sample(
output_scores: Optional[bool] = None,
return_dict_in_generate: Optional[bool] = None,
synced_gpus: Optional[bool] = False,
streamer: Optional["BaseStreamer"] = None,
**model_kwargs,
) -> Union[SampleOutput, torch.LongTensor]:
r"""
Expand Down Expand Up @@ -2350,6 +2386,9 @@ def sample(
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
synced_gpus (`bool`, *optional*, defaults to `False`):
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
streamer (`BaseStreamer`, *optional*):
Streamer object that will be used to stream the generated sequences. Generated tokens are passed
through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
model_kwargs:
Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
an encoder-decoder model the kwargs should include `encoder_outputs`.
Expand Down Expand Up @@ -2521,6 +2560,8 @@ def sample(

# update generated ids, model inputs, and length for next step
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
if streamer is not None:
streamer.put(next_tokens.cpu())
model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
)
Expand All @@ -2538,6 +2579,9 @@ def sample(
else:
this_peer_finished = True

if streamer is not None:
streamer.end()

if return_dict_in_generate:
if self.config.is_encoder_decoder:
return SampleEncoderDecoderOutput(
Expand Down
44 changes: 44 additions & 0 deletions tests/generation/test_streamers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Team Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from transformers import AutoTokenizer, TextStreamer, is_torch_available
from transformers.testing_utils import CaptureStdout, require_torch, torch_device

from ..test_modeling_common import ids_tensor


if is_torch_available():
from transformers import AutoModelForCausalLM


@require_torch
class StreamerTester(unittest.TestCase):
    """Tests for the streamer objects that can be passed to `.generate()`."""

    def test_text_streamer_stdout(self):
        """Text printed by `TextStreamer` during greedy generation must match the decoded output of `.generate()`."""
        checkpoint = "hf-internal-testing/tiny-random-gpt2"
        tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        model = AutoModelForCausalLM.from_pretrained(checkpoint).to(torch_device)
        # NOTE(review): presumably -1 is used so that no generated token can match EOS and end generation early.
        model.config.eos_token_id = -1

        prompt_ids = ids_tensor((1, 5), vocab_size=model.config.vocab_size).to(torch_device)
        generation_kwargs = {"max_new_tokens": 10, "do_sample": False}

        # Reference: plain greedy generation, decoded in one go.
        reference_ids = model.generate(prompt_ids, **generation_kwargs)
        expected_text = tokenizer.decode(reference_ids[0])

        # Same generation again, this time streaming to the captured stdout.
        with CaptureStdout() as captured:
            text_streamer = TextStreamer(tokenizer)
            model.generate(prompt_ids, streamer=text_streamer, **generation_kwargs)

        # The streamer terminates its output with a "\n" that is not part of the generated text.
        streamed_text = captured.out[:-1]
        self.assertEqual(streamed_text, expected_text)