Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/transformers/pipelines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,9 @@
from .object_detection import ObjectDetectionPipeline
from .question_answering import QuestionAnsweringArgumentHandler, QuestionAnsweringPipeline
from .table_question_answering import TableQuestionAnsweringArgumentHandler, TableQuestionAnsweringPipeline
from .text2text_generation import SummarizationPipeline, Text2TextGenerationPipeline, TranslationPipeline
from .text2text_generation import Text2TextGenerationPipeline
from .text_classification import TextClassificationPipeline
from .text_generation import TextGenerationPipeline
from .text_generation import SummarizationPipeline, TextGenerationPipeline, TranslationPipeline
from .text_to_audio import TextToAudioPipeline
from .token_classification import (
AggregationStrategy,
Expand Down Expand Up @@ -267,7 +267,7 @@
"type": "text",
},
"summarization": {
"impl": SummarizationPipeline,
"impl": TextGenerationPipeline,
"tf": (TFAutoModelForSeq2SeqLM,) if is_tf_available() else (),
"pt": (AutoModelForSeq2SeqLM,) if is_torch_available() else (),
"default": {
Expand All @@ -288,7 +288,7 @@
"type": "text",
},
"text2text-generation": {
"impl": Text2TextGenerationPipeline,
"impl": TextGenerationPipeline,
"tf": (TFAutoModelForSeq2SeqLM,) if is_tf_available() else (),
"pt": (AutoModelForSeq2SeqLM,) if is_torch_available() else (),
"default": {"model": {"pt": ("google-t5/t5-base", "686f1db"), "tf": ("google-t5/t5-base", "686f1db")}},
Expand Down
179 changes: 0 additions & 179 deletions src/transformers/pipelines/text2text_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,27 +114,6 @@ def check_inputs(self, input_length: int, min_length: int, max_length: int):
"""
return True

def _parse_and_tokenize(self, *args, truncation):
    """
    Prepend the model's configured prefix to the input text(s) and tokenize them.

    Args:
        *args:
            Positional inputs. Only `args[0]` is used: either a single `str` or a `list` of texts (batch input).
            Any further positional arguments are dropped before the tokenizer is called.
        truncation:
            Forwarded unchanged to the tokenizer's `truncation` argument.

    Returns:
        The tokenizer output for the prefixed text(s), with the `token_type_ids` entry removed when present.

    Raises:
        ValueError: If a batch (`list`) is given but the tokenizer has no `pad_token_id`, or if `args[0]` is
            neither a `str` nor a `list`.
    """
    # `config.prefix` may be None; treat that as "no prefix".
    prefix = self.model.config.prefix if self.model.config.prefix is not None else ""
    first = args[0]

    if isinstance(first, list):
        # Batched inputs are padded, which requires a pad token.
        if self.tokenizer.pad_token_id is None:
            raise ValueError("Please make sure that the tokenizer has a pad_token_id when using a batch input")
        args = ([prefix + arg for arg in first],)
        padding = True
    elif isinstance(first, str):
        args = (prefix + first,)
        padding = False
    else:
        raise ValueError(
            f"`args[0]`: {first} has the wrong format. It should be either of type `str` or type `list`"
        )

    inputs = self.tokenizer(*args, padding=padding, truncation=truncation, return_tensors=self.framework)
    # This is produced by tokenizers but is an invalid generate kwargs
    if "token_type_ids" in inputs:
        del inputs["token_type_ids"]
    return inputs

def __call__(self, *args, **kwargs):
r"""
Generate the output text(s) using text(s) given as inputs.
Expand Down Expand Up @@ -211,161 +190,3 @@ def postprocess(self, model_outputs, return_type=ReturnType.TEXT, clean_up_token
}
records.append(record)
return records


@add_end_docstrings(build_pipeline_init_args(has_tokenizer=True))
class SummarizationPipeline(Text2TextGenerationPipeline):
    """
    Summarize news articles and other documents.

    This summarizing pipeline can currently be loaded from [`pipeline`] using the following task identifier:
    `"summarization"`.

    The models that this pipeline can use are models that have been fine-tuned on a summarization task, which is
    currently, '*bart-large-cnn*', '*google-t5/t5-small*', '*google-t5/t5-base*', '*google-t5/t5-large*', '*google-t5/t5-3b*', '*google-t5/t5-11b*'. See the up-to-date
    list of available models on [huggingface.co/models](https://huggingface.co/models?filter=summarization). For a list
    of available parameters, see the [following
    documentation](https://huggingface.co/docs/transformers/en/main_classes/text_generation#transformers.generation.GenerationMixin.generate)

    Usage:

    ```python
    # use bart in pytorch
    summarizer = pipeline("summarization")
    summarizer("An apple a day, keeps the doctor away", min_length=5, max_length=20)

    # use t5 in tf
    summarizer = pipeline("summarization", model="google-t5/t5-base", tokenizer="google-t5/t5-base", framework="tf")
    summarizer("An apple a day, keeps the doctor away", min_length=5, max_length=20)
    ```"""

    # Used in the return key of the pipeline.
    return_name = "summary"

    def __call__(self, *args, **kwargs):
        r"""
        Summarize the text(s) given as inputs.

        Args:
            documents (*str* or `List[str]`):
                One or several articles (or one list of articles) to summarize.
            return_text (`bool`, *optional*, defaults to `True`):
                Whether or not to include the decoded texts in the outputs
            return_tensors (`bool`, *optional*, defaults to `False`):
                Whether or not to include the tensors of predictions (as token indices) in the outputs.
            clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
                Whether or not to clean up the potential extra spaces in the text output.
            generate_kwargs:
                Additional keyword arguments to pass along to the generate method of the model (see the generate method
                corresponding to your framework [here](./model#generative-models)).

        Return:
            A list or a list of list of `dict`: Each result comes as a dictionary with the following keys:

            - **summary_text** (`str`, present when `return_text=True`) -- The summary of the corresponding input.
            - **summary_token_ids** (`torch.Tensor` or `tf.Tensor`, present when `return_tensors=True`) -- The token
              ids of the summary.
        """
        # Pure pass-through: this override exists only to carry the summarization-specific docstring.
        return super().__call__(*args, **kwargs)

    def check_inputs(self, input_length: int, min_length: int, max_length: int) -> bool:
        """
        Checks whether there might be something wrong with given input with regard to the model.

        Only emits warnings; it never raises and never rejects the input.

        Args:
            input_length (`int`): Length of the input, compared against `max_length`.
            min_length (`int`): Requested minimum generation length.
            max_length (`int`): Requested maximum generation length.

        Returns:
            `bool`: Always `True`, matching the annotated return type.
        """
        if max_length < min_length:
            logger.warning(f"Your min_length={min_length} must be inferior than your max_length={max_length}.")

        if input_length < max_length:
            logger.warning(
                f"Your max_length is set to {max_length}, but your input_length is only {input_length}. Since this is "
                "a summarization task, where outputs shorter than the input are typically wanted, you might "
                f"consider decreasing max_length manually, e.g. summarizer('...', max_length={input_length//2})"
            )
        # The original fell off the end here and returned None despite the `-> bool` annotation.
        return True


@add_end_docstrings(build_pipeline_init_args(has_tokenizer=True))
class TranslationPipeline(Text2TextGenerationPipeline):
    """
    Translates from one language to another.

    This translation pipeline can currently be loaded from [`pipeline`] using the following task identifier:
    `"translation_xx_to_yy"`.

    The models that this pipeline can use are models that have been fine-tuned on a translation task. See the
    up-to-date list of available models on [huggingface.co/models](https://huggingface.co/models?filter=translation).
    For a list of available parameters, see the [following
    documentation](https://huggingface.co/docs/transformers/en/main_classes/text_generation#transformers.generation.GenerationMixin.generate)

    Usage:

    ```python
    en_fr_translator = pipeline("translation_en_to_fr")
    en_fr_translator("How old are you?")
    ```"""

    # Used in the return key of the pipeline.
    return_name = "translation"

    def check_inputs(self, input_length: int, min_length: int, max_length: int):
        """
        Warn when the input length exceeds 90% of `max_length`.

        Only emits a warning; always returns `True`. `min_length` is accepted but not used here.
        """
        if input_length > 0.9 * max_length:
            logger.warning(
                f"Your input_length: {input_length} is bigger than 0.9 * max_length: {max_length}. You might consider "
                "increasing your max_length manually, e.g. translator('...', max_length=400)"
            )
        return True

    def preprocess(self, *args, truncation=TruncationStrategy.DO_NOT_TRUNCATE, src_lang=None, tgt_lang=None):
        """
        Build the model inputs for translation.

        If the tokenizer exposes a truthy `_build_translation_inputs` attribute, it is called with the source and
        target languages; otherwise the inputs go through the parent's `_parse_and_tokenize`, and `src_lang` /
        `tgt_lang` are not used.
        """
        if getattr(self.tokenizer, "_build_translation_inputs", None):
            return self.tokenizer._build_translation_inputs(
                *args, return_tensors=self.framework, truncation=truncation, src_lang=src_lang, tgt_lang=tgt_lang
            )
        else:
            return super()._parse_and_tokenize(*args, truncation=truncation)

    def _sanitize_parameters(self, src_lang=None, tgt_lang=None, **kwargs):
        """
        Split call-time parameters into preprocess / forward / postprocess dictionaries.

        Extends the parent implementation by routing `src_lang` and `tgt_lang` to the preprocess parameters. When
        neither is given, they are inferred from a four-part, underscore-separated task name (the `translation_XX_to_YY`
        form), taking the second and fourth parts.

        Returns:
            The `(preprocess_params, forward_params, postprocess_params)` tuple.
        """
        preprocess_params, forward_params, postprocess_params = super()._sanitize_parameters(**kwargs)
        if src_lang is not None:
            preprocess_params["src_lang"] = src_lang
        if tgt_lang is not None:
            preprocess_params["tgt_lang"] = tgt_lang
        if src_lang is None and tgt_lang is None:
            # Backward compatibility, direct arguments use is preferred.
            task = kwargs.get("task", self.task)
            # Test `task` before splitting it: the original split first, so a None task raised
            # AttributeError before its `if task and ...` guard could take effect.
            items = task.split("_") if task else []
            if len(items) == 4:
                # translation, XX, to YY
                preprocess_params["src_lang"] = items[1]
                preprocess_params["tgt_lang"] = items[3]
        return preprocess_params, forward_params, postprocess_params

    def __call__(self, *args, **kwargs):
        r"""
        Translate the text(s) given as inputs.

        Args:
            args (`str` or `List[str]`):
                Texts to be translated.
            return_tensors (`bool`, *optional*, defaults to `False`):
                Whether or not to include the tensors of predictions (as token indices) in the outputs.
            return_text (`bool`, *optional*, defaults to `True`):
                Whether or not to include the decoded texts in the outputs.
            clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
                Whether or not to clean up the potential extra spaces in the text output.
            src_lang (`str`, *optional*):
                The language of the input. Might be required for multilingual models. Will not have any effect for
                single pair translation models
            tgt_lang (`str`, *optional*):
                The language of the desired output. Might be required for multilingual models. Will not have any effect
                for single pair translation models
            generate_kwargs:
                Additional keyword arguments to pass along to the generate method of the model (see the generate method
                corresponding to your framework [here](./model#generative-models)).

        Return:
            A list or a list of list of `dict`: Each result comes as a dictionary with the following keys:

            - **translation_text** (`str`, present when `return_text=True`) -- The translation.
            - **translation_token_ids** (`torch.Tensor` or `tf.Tensor`, present when `return_tensors=True`) -- The
              token ids of the translation.
        """
        # Pure pass-through: this override exists only to carry the translation-specific docstring.
        return super().__call__(*args, **kwargs)
Loading
Loading