diff --git a/src/transformers/pipelines/base.py b/src/transformers/pipelines/base.py
index 41d9aded4c1b..124f2e290ebc 100644
--- a/src/transformers/pipelines/base.py
+++ b/src/transformers/pipelines/base.py
@@ -24,7 +24,7 @@
from ..file_utils import add_end_docstrings, is_tf_available, is_torch_available
from ..modelcard import ModelCard
-from ..tokenization_utils import PreTrainedTokenizer
+from ..tokenization_utils import PreTrainedTokenizer, TruncationStrategy
from ..utils import logging
@@ -577,7 +577,9 @@ def check_model_type(self, supported_models: Union[List[str], dict]):
                f"The model '{self.model.__class__.__name__}' is not supported for {self.task}. Supported models are {supported_models}",
            )

-    def _parse_and_tokenize(self, inputs, padding=True, add_special_tokens=True, **kwargs):
+    def _parse_and_tokenize(
+        self, inputs, padding=True, add_special_tokens=True, truncation=TruncationStrategy.DO_NOT_TRUNCATE, **kwargs
+    ):
        """
        Parse arguments and tokenize
        """
@@ -587,6 +589,7 @@ def _parse_and_tokenize(self, inputs, padding=True, add_special_tokens=True, **k
            add_special_tokens=add_special_tokens,
            return_tensors=self.framework,
            padding=padding,
+            truncation=truncation,
        )

        return inputs
diff --git a/src/transformers/pipelines/conversational.py b/src/transformers/pipelines/conversational.py
index 9dce94626d24..7e22b8b92b24 100644
--- a/src/transformers/pipelines/conversational.py
+++ b/src/transformers/pipelines/conversational.py
@@ -2,6 +2,7 @@
from typing import List, Optional, Union

from ..file_utils import add_end_docstrings, is_tf_available, is_torch_available
+from ..tokenization_utils import TruncationStrategy
from ..utils import logging
from .base import PIPELINE_INIT_ARGS, Pipeline
@@ -317,12 +318,14 @@ def __call__(
        else:
            return output

-    def _parse_and_tokenize(self, inputs, **kwargs):
+    def _parse_and_tokenize(
+        self, inputs, add_special_tokens=False, padding=False, truncation=TruncationStrategy.DO_NOT_TRUNCATE, **kwargs
+    ):
        """
        Parse arguments and tokenize, adding an EOS token at the end of the user input
        """
        # Parse arguments
-        inputs = self.tokenizer(inputs, add_special_tokens=False, padding=False).get("input_ids", [])
+        inputs = self.tokenizer(inputs, add_special_tokens=add_special_tokens, padding=padding).get("input_ids", [])
        for input in inputs:
            input.append(self.tokenizer.eos_token_id)
        return inputs
diff --git a/src/transformers/pipelines/text2text_generation.py b/src/transformers/pipelines/text2text_generation.py
index 67a2eb11e358..3fb7d00c6eb5 100644
--- a/src/transformers/pipelines/text2text_generation.py
+++ b/src/transformers/pipelines/text2text_generation.py
@@ -1,4 +1,5 @@
from ..file_utils import add_end_docstrings, is_tf_available, is_torch_available
+from ..tokenization_utils import TruncationStrategy
from ..utils import logging
from .base import PIPELINE_INIT_ARGS, Pipeline
@@ -50,7 +51,13 @@ def check_inputs(self, input_length: int, min_length: int, max_length: int):
        return True

    def __call__(
-        self, *args, return_tensors=False, return_text=True, clean_up_tokenization_spaces=False, **generate_kwargs
+        self,
+        *args,
+        return_tensors=False,
+        return_text=True,
+        clean_up_tokenization_spaces=False,
+        truncation=TruncationStrategy.DO_NOT_TRUNCATE,
+        **generate_kwargs
    ):
        r"""
        Generate the output text(s) using text(s) given as inputs.
@@ -64,6 +71,10 @@ def __call__(
                Whether or not to include the decoded texts in the outputs.
            clean_up_tokenization_spaces (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to clean up the potential extra spaces in the text output.
+            truncation (:obj:`TruncationStrategy`, `optional`, defaults to :obj:`TruncationStrategy.DO_NOT_TRUNCATE`):
+                The truncation strategy for the tokenization within the pipeline.
+                :obj:`TruncationStrategy.DO_NOT_TRUNCATE` (default) will never truncate, but it is sometimes desirable
+                to truncate the input to fit the model's max_length instead of throwing an error down the line.
            generate_kwargs:
                Additional keyword arguments to pass along to the generate method of the model (see the generate method
                corresponding to your framework `here <./model.html#generative-models>`__).
@@ -96,7 +107,7 @@ def __call__(
            )

        with self.device_placement():
-            inputs = self._parse_and_tokenize(*args, padding=padding, **generate_kwargs)
+            inputs = self._parse_and_tokenize(*args, padding=padding, truncation=truncation)

            if self.framework == "pt":
                inputs = self.ensure_tensor_on_device(**inputs)
@@ -108,9 +119,6 @@ def __call__(
            max_length = generate_kwargs.get("max_length", self.model.config.max_length)
            self.check_inputs(input_length, min_length, max_length)

-            # truncation should be used by _parse_and_tokenize
-            generate_kwargs.pop("truncation", None)
-
            generations = self.model.generate(
                inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
diff --git a/src/transformers/pipelines/text_generation.py b/src/transformers/pipelines/text_generation.py
index 9b33fb273e29..5ace38930ae6 100644
--- a/src/transformers/pipelines/text_generation.py
+++ b/src/transformers/pipelines/text_generation.py
@@ -50,25 +50,15 @@ def __init__(self, *args, **kwargs):
        self.check_model_type(self.ALLOWED_MODELS)

    # overriding _parse_and_tokenize to allow for unusual language-modeling tokenizer arguments
-
-    def _parse_and_tokenize(self, inputs, padding=True, add_special_tokens=True, **kwargs):
+    def _parse_and_tokenize(self, *args, **kwargs):
        """
        Parse arguments and tokenize
        """
        # Parse arguments
        if self.model.__class__.__name__ in ["TransfoXLLMHeadModel"]:
-            tokenizer_kwargs = {"add_space_before_punct_symbol": True}
-        else:
-            tokenizer_kwargs = {}
-        inputs = self.tokenizer(
-            inputs,
-            add_special_tokens=add_special_tokens,
-            return_tensors=self.framework,
-            padding=padding,
-            **tokenizer_kwargs,
-        )
-
-        return inputs
+            kwargs.update({"add_space_before_punct_symbol": True})
+
+        return super()._parse_and_tokenize(*args, **kwargs)

    def __call__(
        self,
diff --git a/src/transformers/pipelines/zero_shot_classification.py b/src/transformers/pipelines/zero_shot_classification.py
index b3c292888dcf..380188d286d4 100644
--- a/src/transformers/pipelines/zero_shot_classification.py
+++ b/src/transformers/pipelines/zero_shot_classification.py
@@ -3,6 +3,7 @@
import numpy as np

from ..file_utils import add_end_docstrings
+from ..tokenization_utils import TruncationStrategy
from ..utils import logging
from .base import PIPELINE_INIT_ARGS, ArgumentHandler, Pipeline
@@ -78,7 +79,14 @@ def entailment_id(self):
        return -1

    def _parse_and_tokenize(
-        self, sequences, candidate_labels, hypothesis_template, padding=True, add_special_tokens=True, **kwargs
+        self,
+        sequences,
+        candidate_labels,
+        hypothesis_template,
+        padding=True,
+        add_special_tokens=True,
+        truncation=TruncationStrategy.ONLY_FIRST,
+        **kwargs
    ):
        """
        Parse arguments and tokenize only_first so that hypothesis (label) is not truncated
        """
@@ -89,7 +97,7 @@ def _parse_and_tokenize(
            add_special_tokens=add_special_tokens,
return_tensors=self.framework, padding=padding, - truncation="only_first", + truncation=truncation, ) return inputs diff --git a/tests/test_pipelines_summarization.py b/tests/test_pipelines_summarization.py index 0622053204b4..2d2bc3330db2 100644 --- a/tests/test_pipelines_summarization.py +++ b/tests/test_pipelines_summarization.py @@ -14,15 +14,72 @@ import unittest -from transformers import pipeline +from transformers import AutoTokenizer, is_torch_available, pipeline from transformers.testing_utils import require_torch, slow, torch_device +from transformers.tokenization_utils import TruncationStrategy from .test_pipelines_common import MonoInputPipelineCommonMixin +if is_torch_available(): + import torch + + from transformers.models.bart import BartConfig, BartForConditionalGeneration + DEFAULT_DEVICE_NUM = -1 if torch_device == "cpu" else 0 +class SimpleSummarizationPipelineTests(unittest.TestCase): + @require_torch + def test_input_too_long(self): + torch.manual_seed(0) + config = BartConfig( + vocab_size=257, + d_model=32, + encoder_layers=1, + decoder_layers=1, + encoder_ffn_dim=32, + decoder_ffn_dim=32, + # So any text > 4 should raise an exception + max_position_embeddings=4, + encoder_attention_heads=1, + decoder_attention_heads=1, + max_length=4, + min_length=1, + ) + model = BartForConditionalGeneration(config) + # Bias output towards L + V, C = model.lm_head.weight.shape + + bias = torch.zeros(V, requires_grad=True) + bias[76] = 10 + + model.lm_head.bias = torch.nn.Parameter(bias) + + # # Generated with: + # import tempfile + # from tokenizers import Tokenizer, models + # from transformers import PreTrainedTokenizerFast + # model_max_length = 4 + # vocab = [(chr(i), i) for i in range(256)] + # tokenizer = Tokenizer(models.Unigram(vocab)) + # with tempfile.NamedTemporaryFile() as f: + # tokenizer.save(f.name) + # real_tokenizer = PreTrainedTokenizerFast(tokenizer_file=f.name, model_max_length=model_max_length) + # real_tokenizer._tokenizer.save("tokenizer.json") + # # + add missing config.json with albert as model_type + tokenizer = AutoTokenizer.from_pretrained("Narsil/small_summarization_test") + nlp = pipeline(task="summarization", model=model, tokenizer=tokenizer) + + with self.assertLogs("transformers", level="WARNING"): + with self.assertRaises(IndexError): + _ = nlp("This is a test") + + output = nlp("This is a test", truncation=TruncationStrategy.ONLY_FIRST) + # 2 is default BOS from Bart. + self.assertEqual(output, [{"summary_text": "\x02 L L L"}]) + + class SummarizationPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase): pipeline_task = "summarization" pipeline_running_kwargs = {"num_beams": 2, "min_length": 2, "max_length": 5}
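
Below is a minimal usage sketch (not part of the patch itself) of how the new truncation argument is meant to be passed to a pipeline call, mirroring the added summarization test; the task-default checkpoint, the sample text, and the generation lengths are illustrative placeholders.

from transformers import pipeline
from transformers.tokenization_utils import TruncationStrategy

# The task-default summarization checkpoint is used here; any seq2seq
# summarization model/tokenizer pair behaves the same way.
summarizer = pipeline("summarization")

# Pretend this text is longer than the model's maximum input length.
long_text = "The quick brown fox jumps over the lazy dog. " * 400

# With the default TruncationStrategy.DO_NOT_TRUNCATE the over-long input is
# tokenized as-is and fails later (the IndexError asserted in the new test);
# ONLY_FIRST truncates it to the model's maximum input length before generation.
output = summarizer(long_text, truncation=TruncationStrategy.ONLY_FIRST, max_length=60, min_length=10)
print(output[0]["summary_text"])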