From a09f5f8ba4b2fb6a6017743d2401f04698e09ff3 Mon Sep 17 00:00:00 2001 From: Matt Date: Fri, 8 Mar 2024 13:23:59 +0000 Subject: [PATCH] Fix the listification bug Move methods to the new classes Handle TranslationPipeline getting an extra kwarg Remove debug breakpoint again Correct logging import make fixup Reparent SummarizationPipeline and TranslationPipeline onto TextGenerationPipeline Change SUPPORTED_TASKS to move tasks to TextGenerationPipeline Change the pipeline tag for Seq2Seq models Fix postprocessing for Seq2Seq generation Cleanup the add_special_tokens logic Make the model checking work properly Pushed my debug breakpoint again Silently praying that TextGenerationPipeline still works after this --- src/transformers/pipelines/__init__.py | 8 +- .../pipelines/text2text_generation.py | 179 ----------- src/transformers/pipelines/text_generation.py | 285 ++++++++++++++++-- utils/update_metadata.py | 2 +- 4 files changed, 264 insertions(+), 210 deletions(-) diff --git a/src/transformers/pipelines/__init__.py b/src/transformers/pipelines/__init__.py index 8ee0137a20b3..50516f8e66b0 100755 --- a/src/transformers/pipelines/__init__.py +++ b/src/transformers/pipelines/__init__.py @@ -72,9 +72,9 @@ from .object_detection import ObjectDetectionPipeline from .question_answering import QuestionAnsweringArgumentHandler, QuestionAnsweringPipeline from .table_question_answering import TableQuestionAnsweringArgumentHandler, TableQuestionAnsweringPipeline -from .text2text_generation import SummarizationPipeline, Text2TextGenerationPipeline, TranslationPipeline +from .text2text_generation import Text2TextGenerationPipeline from .text_classification import TextClassificationPipeline -from .text_generation import TextGenerationPipeline +from .text_generation import SummarizationPipeline, TextGenerationPipeline, TranslationPipeline from .text_to_audio import TextToAudioPipeline from .token_classification import ( AggregationStrategy, @@ -267,7 +267,7 @@ "type": "text", }, "summarization": { - "impl": SummarizationPipeline, + "impl": TextGenerationPipeline, "tf": (TFAutoModelForSeq2SeqLM,) if is_tf_available() else (), "pt": (AutoModelForSeq2SeqLM,) if is_torch_available() else (), "default": { @@ -288,7 +288,7 @@ "type": "text", }, "text2text-generation": { - "impl": Text2TextGenerationPipeline, + "impl": TextGenerationPipeline, "tf": (TFAutoModelForSeq2SeqLM,) if is_tf_available() else (), "pt": (AutoModelForSeq2SeqLM,) if is_torch_available() else (), "default": {"model": {"pt": ("google-t5/t5-base", "686f1db"), "tf": ("google-t5/t5-base", "686f1db")}}, diff --git a/src/transformers/pipelines/text2text_generation.py b/src/transformers/pipelines/text2text_generation.py index bb8abdfcf7f5..52e4570aa32f 100644 --- a/src/transformers/pipelines/text2text_generation.py +++ b/src/transformers/pipelines/text2text_generation.py @@ -114,27 +114,6 @@ def check_inputs(self, input_length: int, min_length: int, max_length: int): """ return True - def _parse_and_tokenize(self, *args, truncation): - prefix = self.model.config.prefix if self.model.config.prefix is not None else "" - if isinstance(args[0], list): - if self.tokenizer.pad_token_id is None: - raise ValueError("Please make sure that the tokenizer has a pad_token_id when using a batch input") - args = ([prefix + arg for arg in args[0]],) - padding = True - - elif isinstance(args[0], str): - args = (prefix + args[0],) - padding = False - else: - raise ValueError( - f" `args[0]`: {args[0]} have the wrong format. The should be either of type `str` or type `list`" - ) - inputs = self.tokenizer(*args, padding=padding, truncation=truncation, return_tensors=self.framework) - # This is produced by tokenizers but is an invalid generate kwargs - if "token_type_ids" in inputs: - del inputs["token_type_ids"] - return inputs - def __call__(self, *args, **kwargs): r""" Generate the output text(s) using text(s) given as inputs. @@ -211,161 +190,3 @@ def postprocess(self, model_outputs, return_type=ReturnType.TEXT, clean_up_token } records.append(record) return records - - -@add_end_docstrings(build_pipeline_init_args(has_tokenizer=True)) -class SummarizationPipeline(Text2TextGenerationPipeline): - """ - Summarize news articles and other documents. - - This summarizing pipeline can currently be loaded from [`pipeline`] using the following task identifier: - `"summarization"`. - - The models that this pipeline can use are models that have been fine-tuned on a summarization task, which is - currently, '*bart-large-cnn*', '*google-t5/t5-small*', '*google-t5/t5-base*', '*google-t5/t5-large*', '*google-t5/t5-3b*', '*google-t5/t5-11b*'. See the up-to-date - list of available models on [huggingface.co/models](https://huggingface.co/models?filter=summarization). For a list - of available parameters, see the [following - documentation](https://huggingface.co/docs/transformers/en/main_classes/text_generation#transformers.generation.GenerationMixin.generate) - - Usage: - - ```python - # use bart in pytorch - summarizer = pipeline("summarization") - summarizer("An apple a day, keeps the doctor away", min_length=5, max_length=20) - - # use t5 in tf - summarizer = pipeline("summarization", model="google-t5/t5-base", tokenizer="google-t5/t5-base", framework="tf") - summarizer("An apple a day, keeps the doctor away", min_length=5, max_length=20) - ```""" - - # Used in the return key of the pipeline. - return_name = "summary" - - def __call__(self, *args, **kwargs): - r""" - Summarize the text(s) given as inputs. - - Args: - documents (*str* or `List[str]`): - One or several articles (or one list of articles) to summarize. - return_text (`bool`, *optional*, defaults to `True`): - Whether or not to include the decoded texts in the outputs - return_tensors (`bool`, *optional*, defaults to `False`): - Whether or not to include the tensors of predictions (as token indices) in the outputs. - clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`): - Whether or not to clean up the potential extra spaces in the text output. - generate_kwargs: - Additional keyword arguments to pass along to the generate method of the model (see the generate method - corresponding to your framework [here](./model#generative-models)). - - Return: - A list or a list of list of `dict`: Each result comes as a dictionary with the following keys: - - - **summary_text** (`str`, present when `return_text=True`) -- The summary of the corresponding input. - - **summary_token_ids** (`torch.Tensor` or `tf.Tensor`, present when `return_tensors=True`) -- The token - ids of the summary. - """ - return super().__call__(*args, **kwargs) - - def check_inputs(self, input_length: int, min_length: int, max_length: int) -> bool: - """ - Checks whether there might be something wrong with given input with regard to the model. - """ - if max_length < min_length: - logger.warning(f"Your min_length={min_length} must be inferior than your max_length={max_length}.") - - if input_length < max_length: - logger.warning( - f"Your max_length is set to {max_length}, but your input_length is only {input_length}. Since this is " - "a summarization task, where outputs shorter than the input are typically wanted, you might " - f"consider decreasing max_length manually, e.g. summarizer('...', max_length={input_length//2})" - ) - - -@add_end_docstrings(build_pipeline_init_args(has_tokenizer=True)) -class TranslationPipeline(Text2TextGenerationPipeline): - """ - Translates from one language to another. - - This translation pipeline can currently be loaded from [`pipeline`] using the following task identifier: - `"translation_xx_to_yy"`. - - The models that this pipeline can use are models that have been fine-tuned on a translation task. See the - up-to-date list of available models on [huggingface.co/models](https://huggingface.co/models?filter=translation). - For a list of available parameters, see the [following - documentation](https://huggingface.co/docs/transformers/en/main_classes/text_generation#transformers.generation.GenerationMixin.generate) - - Usage: - - ```python - en_fr_translator = pipeline("translation_en_to_fr") - en_fr_translator("How old are you?") - ```""" - - # Used in the return key of the pipeline. - return_name = "translation" - - def check_inputs(self, input_length: int, min_length: int, max_length: int): - if input_length > 0.9 * max_length: - logger.warning( - f"Your input_length: {input_length} is bigger than 0.9 * max_length: {max_length}. You might consider " - "increasing your max_length manually, e.g. translator('...', max_length=400)" - ) - return True - - def preprocess(self, *args, truncation=TruncationStrategy.DO_NOT_TRUNCATE, src_lang=None, tgt_lang=None): - if getattr(self.tokenizer, "_build_translation_inputs", None): - return self.tokenizer._build_translation_inputs( - *args, return_tensors=self.framework, truncation=truncation, src_lang=src_lang, tgt_lang=tgt_lang - ) - else: - return super()._parse_and_tokenize(*args, truncation=truncation) - - def _sanitize_parameters(self, src_lang=None, tgt_lang=None, **kwargs): - preprocess_params, forward_params, postprocess_params = super()._sanitize_parameters(**kwargs) - if src_lang is not None: - preprocess_params["src_lang"] = src_lang - if tgt_lang is not None: - preprocess_params["tgt_lang"] = tgt_lang - if src_lang is None and tgt_lang is None: - # Backward compatibility, direct arguments use is preferred. - task = kwargs.get("task", self.task) - items = task.split("_") - if task and len(items) == 4: - # translation, XX, to YY - preprocess_params["src_lang"] = items[1] - preprocess_params["tgt_lang"] = items[3] - return preprocess_params, forward_params, postprocess_params - - def __call__(self, *args, **kwargs): - r""" - Translate the text(s) given as inputs. - - Args: - args (`str` or `List[str]`): - Texts to be translated. - return_tensors (`bool`, *optional*, defaults to `False`): - Whether or not to include the tensors of predictions (as token indices) in the outputs. - return_text (`bool`, *optional*, defaults to `True`): - Whether or not to include the decoded texts in the outputs. - clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`): - Whether or not to clean up the potential extra spaces in the text output. - src_lang (`str`, *optional*): - The language of the input. Might be required for multilingual models. Will not have any effect for - single pair translation models - tgt_lang (`str`, *optional*): - The language of the desired output. Might be required for multilingual models. Will not have any effect - for single pair translation models - generate_kwargs: - Additional keyword arguments to pass along to the generate method of the model (see the generate method - corresponding to your framework [here](./model#generative-models)). - - Return: - A list or a list of list of `dict`: Each result comes as a dictionary with the following keys: - - - **translation_text** (`str`, present when `return_text=True`) -- The translation. - - **translation_token_ids** (`torch.Tensor` or `tf.Tensor`, present when `return_tensors=True`) -- The - token ids of the translation. - """ - return super().__call__(*args, **kwargs) diff --git a/src/transformers/pipelines/text_generation.py b/src/transformers/pipelines/text_generation.py index ef64fb84dddd..9d4c3dde3577 100644 --- a/src/transformers/pipelines/text_generation.py +++ b/src/transformers/pipelines/text_generation.py @@ -2,17 +2,26 @@ import warnings from typing import Dict -from ..utils import add_end_docstrings, is_tf_available, is_torch_available +from ..tokenization_utils import TruncationStrategy +from ..utils import add_end_docstrings, is_tf_available, is_torch_available, logging from .base import Pipeline, build_pipeline_init_args if is_torch_available(): - from ..models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES + from ..models.auto.modeling_auto import ( + MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, + MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES, + ) if is_tf_available(): import tensorflow as tf - from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES + from ..models.auto.modeling_tf_auto import ( + TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, + TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES, + ) + +logger = logging.get_logger(__name__) class ReturnType(enum.Enum): @@ -80,11 +89,27 @@ class TextGenerationPipeline(Pipeline): begging for his blessing. """ + return_name = "generated" + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.check_model_type( - TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES if self.framework == "tf" else MODEL_FOR_CAUSAL_LM_MAPPING_NAMES - ) + + if self.framework == "tf": + mapping_names = TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.copy() + update_names = TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES + else: + mapping_names = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.copy() + update_names = MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES + for key, val in update_names.items(): + # Combine the two dicts, merging matching keys as tuples + if key not in mapping_names: + mapping_names[key] = val + elif isinstance(mapping_names[key], str): + mapping_names[key] = (mapping_names[key], val) + else: + mapping_names[key] = mapping_names[key] + (val,) + + self.check_model_type(mapping_names) if "prefix" not in self._preprocess_params: # This is very specific. The logic is quite complex and needs to be done # as a "default". @@ -117,12 +142,23 @@ def _sanitize_parameters( prefix=None, handle_long_generation=None, stop_sequence=None, - add_special_tokens=False, + add_special_tokens=None, truncation=None, padding=False, max_length=None, **generate_kwargs, ): + if self.framework == "tf": + seq2seq_lm_map = TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES + else: + seq2seq_lm_map = MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES + if self.model.__class__.__name__ in seq2seq_lm_map.values(): + self.text2text = True + else: + self.text2text = False + if add_special_tokens is None: + add_special_tokens = self.text2text + preprocess_params = { "add_special_tokens": add_special_tokens, "truncation": truncation, @@ -178,17 +214,6 @@ def _sanitize_parameters( return preprocess_params, forward_params, postprocess_params - # overriding _parse_and_tokenize to allow for unusual language-modeling tokenizer arguments - def _parse_and_tokenize(self, *args, **kwargs): - """ - Parse arguments and tokenize - """ - # Parse arguments - if self.model.__class__.__name__ in ["TransfoXLLMHeadModel"]: - kwargs.update({"add_space_before_punct_symbol": True}) - - return super()._parse_and_tokenize(*args, **kwargs) - def __call__(self, text_inputs, **kwargs): """ Complete the prompt(s) given as inputs. @@ -233,12 +258,21 @@ def __call__(self, text_inputs, **kwargs): if isinstance(text_inputs, (list, tuple)) and isinstance(text_inputs[0], (list, tuple, dict)): # We have one or more prompts in list-of-dicts format, so this is chat mode if isinstance(text_inputs[0], dict): - return super().__call__(Chat(text_inputs), **kwargs) + out = super().__call__(Chat(text_inputs), **kwargs) else: chats = [Chat(chat) for chat in text_inputs] # 🐈 🐈 🐈 - return super().__call__(chats, **kwargs) + out = super().__call__(chats, **kwargs) else: - return super().__call__(text_inputs, **kwargs) + out = super().__call__(text_inputs, **kwargs) + if ( + self.text2text + and isinstance(text_inputs, list) + and all(isinstance(el, str) for el in text_inputs) + and all(len(res) == 1 for res in out) + ): + return [res[0] for res in out] + else: + return out def preprocess( self, @@ -331,17 +365,36 @@ def _forward(self, model_inputs, **generate_kwargs): generated_sequence = generated_sequence.reshape(in_b, out_b // in_b, *generated_sequence.shape[1:]) elif self.framework == "tf": generated_sequence = tf.reshape(generated_sequence, (in_b, out_b // in_b, *generated_sequence.shape[1:])) - return {"generated_sequence": generated_sequence, "input_ids": input_ids, "prompt_text": prompt_text} + return {f"{self.return_name}_sequence": generated_sequence, "input_ids": input_ids, "prompt_text": prompt_text} def postprocess(self, model_outputs, return_type=ReturnType.FULL_TEXT, clean_up_tokenization_spaces=True): - generated_sequence = model_outputs["generated_sequence"][0] + if self.text2text: + # Matt: In Seq2Seq generation the output sequence is separate from the input sequence, and so we don't + # ever want to concatenate the output to the input. + records = [] + for output_ids in model_outputs[f"{self.return_name}_sequence"][0]: + if return_type == ReturnType.TENSORS: + record = {f"{self.return_name}_token_ids": output_ids} + else: + # TODO Matt is there any return type besides TENSORS where we shouldn't do this? + record = { + f"{self.return_name}_text": self.tokenizer.decode( + output_ids, + skip_special_tokens=True, + clean_up_tokenization_spaces=clean_up_tokenization_spaces, + ) + } + records.append(record) + return records + + generated_sequence = model_outputs[f"{self.return_name}_sequence"][0] input_ids = model_outputs["input_ids"] prompt_text = model_outputs["prompt_text"] generated_sequence = generated_sequence.numpy().tolist() records = [] for sequence in generated_sequence: if return_type == ReturnType.TENSORS: - record = {"generated_token_ids": sequence} + record = {f"{self.return_name}_token_ids": sequence} elif return_type in {ReturnType.NEW_TEXT, ReturnType.FULL_TEXT}: # Decode text text = self.tokenizer.decode( @@ -361,7 +414,6 @@ def postprocess(self, model_outputs, return_type=ReturnType.FULL_TEXT, clean_up_ clean_up_tokenization_spaces=clean_up_tokenization_spaces, ) ) - all_text = text[prompt_length:] if return_type == ReturnType.FULL_TEXT: if isinstance(prompt_text, str): @@ -369,7 +421,188 @@ def postprocess(self, model_outputs, return_type=ReturnType.FULL_TEXT, clean_up_ elif isinstance(prompt_text, Chat): all_text = prompt_text.messages + [{"role": "assistant", "content": all_text}] - record = {"generated_text": all_text} + record = {f"{self.return_name}_text": all_text} records.append(record) return records + + +@add_end_docstrings(build_pipeline_init_args(has_tokenizer=True)) +class SummarizationPipeline(TextGenerationPipeline): + """ + Summarize news articles and other documents. + + This summarizing pipeline can currently be loaded from [`pipeline`] using the following task identifier: + `"summarization"`. + + The models that this pipeline can use are models that have been fine-tuned on a summarization task, which is + currently, '*bart-large-cnn*', '*google-t5/t5-small*', '*google-t5/t5-base*', '*google-t5/t5-large*', '*google-t5/t5-3b*', '*google-t5/t5-11b*'. See the up-to-date + list of available models on [huggingface.co/models](https://huggingface.co/models?filter=summarization). For a list + of available parameters, see the [following + documentation](https://huggingface.co/docs/transformers/en/main_classes/text_generation#transformers.generation.GenerationMixin.generate) + + Usage: + + ```python + # use bart in pytorch + summarizer = pipeline("summarization") + summarizer("An apple a day, keeps the doctor away", min_length=5, max_length=20) + + # use t5 in tf + summarizer = pipeline("summarization", model="google-t5/t5-base", tokenizer="google-t5/t5-base", framework="tf") + summarizer("An apple a day, keeps the doctor away", min_length=5, max_length=20) + ```""" + + # Used in the return key of the pipeline. + return_name = "summary" + + def __call__(self, *args, **kwargs): + r""" + Summarize the text(s) given as inputs. + + Args: + documents (*str* or `List[str]`): + One or several articles (or one list of articles) to summarize. + return_text (`bool`, *optional*, defaults to `True`): + Whether or not to include the decoded texts in the outputs + return_tensors (`bool`, *optional*, defaults to `False`): + Whether or not to include the tensors of predictions (as token indices) in the outputs. + clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`): + Whether or not to clean up the potential extra spaces in the text output. + generate_kwargs: + Additional keyword arguments to pass along to the generate method of the model (see the generate method + corresponding to your framework [here](./model#generative-models)). + + Return: + A list or a list of list of `dict`: Each result comes as a dictionary with the following keys: + + - **summary_text** (`str`, present when `return_text=True`) -- The summary of the corresponding input. + - **summary_token_ids** (`torch.Tensor` or `tf.Tensor`, present when `return_tensors=True`) -- The token + ids of the summary. + """ + return super().__call__(*args, **kwargs) + + def check_inputs(self, input_length: int, min_length: int, max_length: int) -> bool: + """ + Checks whether there might be something wrong with given input with regard to the model. + """ + if max_length < min_length: + logger.warning(f"Your min_length={min_length} must be inferior than your max_length={max_length}.") + + if input_length < max_length: + logger.warning( + f"Your max_length is set to {max_length}, but your input_length is only {input_length}. Since this is " + "a summarization task, where outputs shorter than the input are typically wanted, you might " + f"consider decreasing max_length manually, e.g. summarizer('...', max_length={input_length//2})" + ) + + +@add_end_docstrings(build_pipeline_init_args(has_tokenizer=True)) +class TranslationPipeline(TextGenerationPipeline): + """ + Translates from one language to another. + + This translation pipeline can currently be loaded from [`pipeline`] using the following task identifier: + `"translation_xx_to_yy"`. + + The models that this pipeline can use are models that have been fine-tuned on a translation task. See the + up-to-date list of available models on [huggingface.co/models](https://huggingface.co/models?filter=translation). + For a list of available parameters, see the [following + documentation](https://huggingface.co/docs/transformers/en/main_classes/text_generation#transformers.generation.GenerationMixin.generate) + + Usage: + + ```python + en_fr_translator = pipeline("translation_en_to_fr") + en_fr_translator("How old are you?") + ```""" + + # Used in the return key of the pipeline. + return_name = "translation" + + def check_inputs(self, input_length: int, min_length: int, max_length: int): + if input_length > 0.9 * max_length: + logger.warning( + f"Your input_length: {input_length} is bigger than 0.9 * max_length: {max_length}. You might consider " + "increasing your max_length manually, e.g. translator('...', max_length=400)" + ) + return True + + def preprocess( + self, prompt_text, truncation=TruncationStrategy.DO_NOT_TRUNCATE, src_lang=None, tgt_lang=None, **kwargs + ): + if getattr(self.tokenizer, "_build_translation_inputs", None): + out = self.tokenizer._build_translation_inputs( + prompt_text, return_tensors=self.framework, truncation=truncation, src_lang=src_lang, tgt_lang=tgt_lang + ) + else: + out = self._parse_and_tokenize(prompt_text, truncation=truncation) + out["prompt_text"] = prompt_text + return out + + def _parse_and_tokenize(self, prompt_text, truncation): + prefix = self.model.config.prefix if self.model.config.prefix is not None else "" + if isinstance(prompt_text, list): + if self.tokenizer.pad_token_id is None: + raise ValueError("Please make sure that the tokenizer has a pad_token_id when using a batch input") + prompt_text = [prefix + arg for arg in prompt_text] + padding = True + + elif isinstance(prompt_text, str): + prompt_text = prefix + prompt_text + padding = False + else: + raise ValueError("prompt_text has the wrong format. This should be either of type `str` or type `list`") + inputs = self.tokenizer(prompt_text, padding=padding, truncation=truncation, return_tensors=self.framework) + # This is produced by tokenizers but is an invalid generate kwargs + if "token_type_ids" in inputs: + del inputs["token_type_ids"] + return inputs + + def _sanitize_parameters(self, src_lang=None, tgt_lang=None, **kwargs): + preprocess_params, forward_params, postprocess_params = super()._sanitize_parameters(**kwargs) + if src_lang is not None: + preprocess_params["src_lang"] = src_lang + if tgt_lang is not None: + preprocess_params["tgt_lang"] = tgt_lang + if src_lang is None and tgt_lang is None: + # Backward compatibility, direct arguments use is preferred. + task = kwargs.get("task", self.task) + items = task.split("_") + if task and len(items) == 4: + # translation, XX, to YY + preprocess_params["src_lang"] = items[1] + preprocess_params["tgt_lang"] = items[3] + return preprocess_params, forward_params, postprocess_params + + def __call__(self, *args, **kwargs): + r""" + Translate the text(s) given as inputs. + + Args: + args (`str` or `List[str]`): + Texts to be translated. + return_tensors (`bool`, *optional*, defaults to `False`): + Whether or not to include the tensors of predictions (as token indices) in the outputs. + return_text (`bool`, *optional*, defaults to `True`): + Whether or not to include the decoded texts in the outputs. + clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`): + Whether or not to clean up the potential extra spaces in the text output. + src_lang (`str`, *optional*): + The language of the input. Might be required for multilingual models. Will not have any effect for + single pair translation models + tgt_lang (`str`, *optional*): + The language of the desired output. Might be required for multilingual models. Will not have any effect + for single pair translation models + generate_kwargs: + Additional keyword arguments to pass along to the generate method of the model (see the generate method + corresponding to your framework [here](./model#generative-models)). + + Return: + A list or a list of list of `dict`: Each result comes as a dictionary with the following keys: + + - **translation_text** (`str`, present when `return_text=True`) -- The translation. + - **translation_token_ids** (`torch.Tensor` or `tf.Tensor`, present when `return_tensors=True`) -- The + token ids of the translation. + """ + return super().__call__(*args, **kwargs) diff --git a/utils/update_metadata.py b/utils/update_metadata.py index 0762c4c2aa73..a9ce80146e9f 100755 --- a/utils/update_metadata.py +++ b/utils/update_metadata.py @@ -77,7 +77,7 @@ "AutoModelForZeroShotObjectDetection", ), ("question-answering", "MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES", "AutoModelForQuestionAnswering"), - ("text2text-generation", "MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES", "AutoModelForSeq2SeqLM"), + ("text-generation", "MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES", "AutoModelForSeq2SeqLM"), ("text-classification", "MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES", "AutoModelForSequenceClassification"), ("automatic-speech-recognition", "MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES", "AutoModelForSpeechSeq2Seq"), (