From 1e7489bf1e7dd05a3e6d8e3a9a77b53344ab239f Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Wed, 23 Dec 2020 13:47:40 +0100 Subject: [PATCH 01/10] Enable TruncationStrategy override for pipelines --- src/transformers/pipelines/base.py | 7 +- src/transformers/pipelines/conversational.py | 7 +- src/transformers/pipelines/text_generation.py | 18 ++--- .../pipelines/zero_shot_classification.py | 12 +++- tests/test_pipelines_summarization.py | 66 ++++++++++++++++++- 5 files changed, 89 insertions(+), 21 deletions(-) diff --git a/src/transformers/pipelines/base.py b/src/transformers/pipelines/base.py index 41d9aded4c1b..124f2e290ebc 100644 --- a/src/transformers/pipelines/base.py +++ b/src/transformers/pipelines/base.py @@ -24,7 +24,7 @@ from ..file_utils import add_end_docstrings, is_tf_available, is_torch_available from ..modelcard import ModelCard -from ..tokenization_utils import PreTrainedTokenizer +from ..tokenization_utils import PreTrainedTokenizer, TruncationStrategy from ..utils import logging @@ -577,7 +577,9 @@ def check_model_type(self, supported_models: Union[List[str], dict]): f"The model '{self.model.__class__.__name__}' is not supported for {self.task}. Supported models are {supported_models}", ) - def _parse_and_tokenize(self, inputs, padding=True, add_special_tokens=True, **kwargs): + def _parse_and_tokenize( + self, inputs, padding=True, add_special_tokens=True, truncation=TruncationStrategy.DO_NOT_TRUNCATE, **kwargs + ): """ Parse arguments and tokenize """ @@ -587,6 +589,7 @@ def _parse_and_tokenize(self, inputs, padding=True, add_special_tokens=True, **k add_special_tokens=add_special_tokens, return_tensors=self.framework, padding=padding, + truncation=truncation, ) return inputs diff --git a/src/transformers/pipelines/conversational.py b/src/transformers/pipelines/conversational.py index 9dce94626d24..34c67b89d245 100644 --- a/src/transformers/pipelines/conversational.py +++ b/src/transformers/pipelines/conversational.py @@ -3,6 +3,7 @@ from ..file_utils import add_end_docstrings, is_tf_available, is_torch_available from ..utils import logging +from ..tokenization_utils import TruncationStrategy from .base import PIPELINE_INIT_ARGS, Pipeline @@ -317,12 +318,14 @@ def __call__( else: return output - def _parse_and_tokenize(self, inputs, **kwargs): + def _parse_and_tokenize( + self, inputs, add_special_tokens=False, padding=False, truncation=TruncationStrategy.DO_NOT_TRUNCATE, **kwargs + ): """ Parse arguments and tokenize, adding an EOS token at the end of the user input """ # Parse arguments - inputs = self.tokenizer(inputs, add_special_tokens=False, padding=False).get("input_ids", []) + inputs = self.tokenizer(inputs, add_special_tokens=add_special_tokens, padding=padding).get("input_ids", []) for input in inputs: input.append(self.tokenizer.eos_token_id) return inputs diff --git a/src/transformers/pipelines/text_generation.py b/src/transformers/pipelines/text_generation.py index 9b33fb273e29..6217718dc4ed 100644 --- a/src/transformers/pipelines/text_generation.py +++ b/src/transformers/pipelines/text_generation.py @@ -50,25 +50,15 @@ def __init__(self, *args, **kwargs): self.check_model_type(self.ALLOWED_MODELS) # overriding _parse_and_tokenize to allow for unusual language-modeling tokenizer arguments - - def _parse_and_tokenize(self, inputs, padding=True, add_special_tokens=True, **kwargs): + def _parse_and_tokenize(self, **kwargs): """ Parse arguments and tokenize """ # Parse arguments if self.model.__class__.__name__ in ["TransfoXLLMHeadModel"]: - 
tokenizer_kwargs = {"add_space_before_punct_symbol": True} - else: - tokenizer_kwargs = {} - inputs = self.tokenizer( - inputs, - add_special_tokens=add_special_tokens, - return_tensors=self.framework, - padding=padding, - **tokenizer_kwargs, - ) - - return inputs + kwargs.update({"add_space_before_punct_symbol": True}) + + return super()._parse_and_tokenize(**kwargs) def __call__( self, diff --git a/src/transformers/pipelines/zero_shot_classification.py b/src/transformers/pipelines/zero_shot_classification.py index b3c292888dcf..c52aa3821ac6 100644 --- a/src/transformers/pipelines/zero_shot_classification.py +++ b/src/transformers/pipelines/zero_shot_classification.py @@ -4,6 +4,7 @@ from ..file_utils import add_end_docstrings from ..utils import logging +from ..tokenization_utils import TruncationStrategy from .base import PIPELINE_INIT_ARGS, ArgumentHandler, Pipeline @@ -78,7 +79,14 @@ def entailment_id(self): return -1 def _parse_and_tokenize( - self, sequences, candidate_labels, hypothesis_template, padding=True, add_special_tokens=True, **kwargs + self, + sequences, + candidate_labels, + hypothesis_template, + padding=True, + add_special_tokens=True, + truncation=TruncationStrategy.ONLY_FIRST, + **kwargs ): """ Parse arguments and tokenize only_first so that hypothesis (label) is not truncated @@ -89,7 +97,7 @@ def _parse_and_tokenize( add_special_tokens=add_special_tokens, return_tensors=self.framework, padding=padding, - truncation="only_first", + truncation=truncation, ) return inputs diff --git a/tests/test_pipelines_summarization.py b/tests/test_pipelines_summarization.py index 0622053204b4..75ef3a5b2bde 100644 --- a/tests/test_pipelines_summarization.py +++ b/tests/test_pipelines_summarization.py @@ -14,14 +14,78 @@ import unittest -from transformers import pipeline +from transformers import is_torch_available, pipeline +from transformers.models.bart import BartConfig, BartForConditionalGeneration from transformers.testing_utils import require_torch, slow, torch_device +from transformers.tokenization_utils import TruncationStrategy from .test_pipelines_common import MonoInputPipelineCommonMixin DEFAULT_DEVICE_NUM = -1 if torch_device == "cpu" else 0 +if is_torch_available(): + import torch + + +class DummyTok: + pad_token_id = 0 + + def __init__(self, **kwargs): + for name, v in kwargs.items(): + setattr(self, name, v) + + def __call__(self, inputs, **kwargs): + if isinstance(inputs, str): + input_ids = self.encode(inputs).unsqueeze(0) + else: + input_ids = torch.nn.utils.rnn.pad_sequence( + [self.encode(input_) for input_ in inputs], + padding_value=self.pad_token_id, + ) + + if kwargs.get("truncation", TruncationStrategy.DO_NOT_TRUNCATE) == TruncationStrategy.ONLY_FIRST: + input_ids = input_ids[:, : self.model_max_length] + attention_mask = torch.zeros_like(input_ids).long() + 1 + return {"input_ids": input_ids, "attention_mask": attention_mask} + + def encode(self, input_): + return torch.LongTensor(list(input_.encode("utf-8"))) + + def decode(self, sequence, **kwargs): + try: + return bytes(sequence).decode("utf-8") + except Exception: + return "D" * len(sequence) + + +class SimpleSummarizationPipelineTests(unittest.TestCase): + @require_torch + def test_input_too_long(self): + config = BartConfig( + vocab_size=257, + d_model=32, + encoder_layers=1, + decoder_layers=1, + encoder_ffn_dim=32, + decoder_ffn_dim=32, + # So any text > 4 should raise an exception + max_position_embeddings=4, + encoder_attention_heads=1, + decoder_attention_heads=1, + max_length=4, + min_length=1, 
+ ) + model = BartForConditionalGeneration(config) + tokenizer = DummyTok(model_max_length=4) + nlp = pipeline(task="summarization", model=model, tokenizer=tokenizer) + + with self.assertRaises(IndexError): + _ = nlp("This is a test") + + output = nlp("This is a test", truncation=TruncationStrategy.ONLY_FIRST) + self.assertEquals(output, [{"summary_text": "\0\0\0\0"}]) + class SummarizationPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase): pipeline_task = "summarization" From 3836835878aaecf90527c454281129a0762a6c98 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Wed, 6 Jan 2021 10:27:16 +0100 Subject: [PATCH 02/10] Update isort. --- src/transformers/pipelines/conversational.py | 2 +- .../pipelines/zero_shot_classification.py | 2 +- src/transformers/test.py | 19 +++++++++++++++++++ 3 files changed, 21 insertions(+), 2 deletions(-) create mode 100644 src/transformers/test.py diff --git a/src/transformers/pipelines/conversational.py b/src/transformers/pipelines/conversational.py index 34c67b89d245..7e22b8b92b24 100644 --- a/src/transformers/pipelines/conversational.py +++ b/src/transformers/pipelines/conversational.py @@ -2,8 +2,8 @@ from typing import List, Optional, Union from ..file_utils import add_end_docstrings, is_tf_available, is_torch_available -from ..utils import logging from ..tokenization_utils import TruncationStrategy +from ..utils import logging from .base import PIPELINE_INIT_ARGS, Pipeline diff --git a/src/transformers/pipelines/zero_shot_classification.py b/src/transformers/pipelines/zero_shot_classification.py index c52aa3821ac6..380188d286d4 100644 --- a/src/transformers/pipelines/zero_shot_classification.py +++ b/src/transformers/pipelines/zero_shot_classification.py @@ -3,8 +3,8 @@ import numpy as np from ..file_utils import add_end_docstrings -from ..utils import logging from ..tokenization_utils import TruncationStrategy +from ..utils import logging from .base import PIPELINE_INIT_ARGS, ArgumentHandler, Pipeline diff --git a/src/transformers/test.py b/src/transformers/test.py new file mode 100644 index 000000000000..f4fac9183c3f --- /dev/null +++ b/src/transformers/test.py @@ -0,0 +1,19 @@ +import datetime + +import numpy as np + +import onnxruntime as rt + + +sess_options = rt.SessionOptions() + +# Set graph optimization level +sess_options.graph_optimization_level = rt.GraphOptimizationLevel.ORT_ENABLE_ALL + +sess = rt.InferenceSession("onnx/gpt2-optimized-quantized.onnx", sess_options) +input_name = sess.get_inputs()[0].name +X_test = np.zeros((1, 50)) +start = datetime.datetime.now() +pred_onx = sess.run(None, {input_name: X_test.astype(np.long)})[0] +print(datetime.datetime.now() - start) +print(pred_onx) From 64cf667d30076851b0f14f15f8a9406ca6da6487 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Wed, 6 Jan 2021 10:38:04 +0100 Subject: [PATCH 03/10] Fixing test --- tests/test_pipelines_summarization.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/tests/test_pipelines_summarization.py b/tests/test_pipelines_summarization.py index 75ef3a5b2bde..7f8e33ae90c4 100644 --- a/tests/test_pipelines_summarization.py +++ b/tests/test_pipelines_summarization.py @@ -13,6 +13,7 @@ # limitations under the License. 
import unittest +from string import ascii_lowercase from transformers import is_torch_available, pipeline from transformers.models.bart import BartConfig, BartForConditionalGeneration @@ -53,15 +54,17 @@ def encode(self, input_): return torch.LongTensor(list(input_.encode("utf-8"))) def decode(self, sequence, **kwargs): - try: - return bytes(sequence).decode("utf-8") - except Exception: - return "D" * len(sequence) + N = len(ascii_lowercase) + output = "" + for i in range(len(sequence)): + output += ascii_lowercase[i % N] + return output class SimpleSummarizationPipelineTests(unittest.TestCase): @require_torch def test_input_too_long(self): + torch.manual_seed(0) config = BartConfig( vocab_size=257, d_model=32, @@ -84,7 +87,7 @@ def test_input_too_long(self): _ = nlp("This is a test") output = nlp("This is a test", truncation=TruncationStrategy.ONLY_FIRST) - self.assertEquals(output, [{"summary_text": "\0\0\0\0"}]) + self.assertEqual(output, [{"summary_text": "ab"}]) class SummarizationPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase): From ce40c3f0c20827ab263f34fc586fd1a7cf0a68e5 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Wed, 6 Jan 2021 10:56:10 +0100 Subject: [PATCH 04/10] Fixing text_generation pipeline. --- src/transformers/pipelines/text_generation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/pipelines/text_generation.py b/src/transformers/pipelines/text_generation.py index 6217718dc4ed..5ace38930ae6 100644 --- a/src/transformers/pipelines/text_generation.py +++ b/src/transformers/pipelines/text_generation.py @@ -50,7 +50,7 @@ def __init__(self, *args, **kwargs): self.check_model_type(self.ALLOWED_MODELS) # overriding _parse_and_tokenize to allow for unusual language-modeling tokenizer arguments - def _parse_and_tokenize(self, **kwargs): + def _parse_and_tokenize(self, *args, **kwargs): """ Parse arguments and tokenize """ @@ -58,7 +58,7 @@ def _parse_and_tokenize(self, **kwargs): if self.model.__class__.__name__ in ["TransfoXLLMHeadModel"]: kwargs.update({"add_space_before_punct_symbol": True}) - return super()._parse_and_tokenize(**kwargs) + return super()._parse_and_tokenize(*args, **kwargs) def __call__( self, From 9c6676c1d983e6625ac4a3cf02aef84e49832853 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Wed, 6 Jan 2021 11:08:50 +0100 Subject: [PATCH 05/10] Using same DummyTok as other PR for easier merge later. --- tests/test_pipelines_common.py | 70 ++++++++++++++++++++++++++- tests/test_pipelines_summarization.py | 35 +------------- 2 files changed, 70 insertions(+), 35 deletions(-) diff --git a/tests/test_pipelines_common.py b/tests/test_pipelines_common.py index c8a66053a307..6f8aa46579a7 100644 --- a/tests/test_pipelines_common.py +++ b/tests/test_pipelines_common.py @@ -12,18 +12,86 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from string import ascii_lowercase from typing import List, Optional from unittest import mock from transformers import is_tf_available, is_torch_available, pipeline from transformers.pipelines import Pipeline from transformers.testing_utils import _run_slow_tests, is_pipeline_test, require_tf, require_torch, slow -from transformers.tokenization_utils_base import to_py_obj +from transformers.tokenization_utils_base import TruncationStrategy, to_py_obj +if is_torch_available(): + import torch + VALID_INPUTS = ["A simple string", ["list of strings"]] +class DummyTok: + pad_token_id = 0 + eos_token_id = 0 + + def __init__(self, **kwargs): + for name, v in kwargs.items(): + setattr(self, name, v) + self.index = 0 + + def __call__(self, inputs, **kwargs): + if kwargs.get("return_tensors", "") == "pt": + return self.encode_pt(inputs, **kwargs) + else: + return self.encode_list(inputs, **kwargs) + + def encode_list(self, inputs, **kwargs): + unwrap = False + if isinstance(inputs, str): + unwrap = True + inputs = [inputs] + + assert isinstance(inputs, list) + input_ids = [self.encode(input_) for input_ in inputs] + + if unwrap: + input_ids = input_ids[0] + + return {"input_ids": input_ids} + + def encode_pt(self, inputs, **kwargs): + if isinstance(inputs, str): + input_ids = torch.LongTensor(self.encode(inputs)).unsqueeze(0) + else: + input_ids = self._pad([self.encode(input_) for input_ in inputs]) + return self.finalize_pt(input_ids, **kwargs) + + def finalize_pt(self, input_ids, **kwargs): + if kwargs.get("truncation", TruncationStrategy.DO_NOT_TRUNCATE) == TruncationStrategy.ONLY_FIRST: + input_ids = input_ids[:, : self.model_max_length] + attention_mask = torch.zeros_like(input_ids).long() + 1 + return {"input_ids": input_ids, "attention_mask": attention_mask} + + def _pad(self, inputs): + return torch.nn.utils.rnn.pad_sequence( + [torch.LongTensor(input_) for input_ in inputs], + padding_value=self.pad_token_id, + ).transpose(1, 0) + + def pad(self, inputs, **kwargs): + input_ids = self._pad(inputs["input_ids"]) + return self.finalize_pt(input_ids, **kwargs) + + def encode(self, input_): + return list(input_.encode("utf-8")) + + def decode(self, sequence, **kwargs): + string = "" + for i in range(len(sequence)): + string += ascii_lowercase[self.index] + self.index += 1 + self.index %= len(ascii_lowercase) + return string + + @is_pipeline_test class CustomInputPipelineCommonMixin: pipeline_task = None diff --git a/tests/test_pipelines_summarization.py b/tests/test_pipelines_summarization.py index 7f8e33ae90c4..f6a72eba8796 100644 --- a/tests/test_pipelines_summarization.py +++ b/tests/test_pipelines_summarization.py @@ -13,14 +13,13 @@ # limitations under the License. 
import unittest -from string import ascii_lowercase from transformers import is_torch_available, pipeline from transformers.models.bart import BartConfig, BartForConditionalGeneration from transformers.testing_utils import require_torch, slow, torch_device from transformers.tokenization_utils import TruncationStrategy -from .test_pipelines_common import MonoInputPipelineCommonMixin +from .test_pipelines_common import DummyTok, MonoInputPipelineCommonMixin DEFAULT_DEVICE_NUM = -1 if torch_device == "cpu" else 0 @@ -29,38 +28,6 @@ import torch -class DummyTok: - pad_token_id = 0 - - def __init__(self, **kwargs): - for name, v in kwargs.items(): - setattr(self, name, v) - - def __call__(self, inputs, **kwargs): - if isinstance(inputs, str): - input_ids = self.encode(inputs).unsqueeze(0) - else: - input_ids = torch.nn.utils.rnn.pad_sequence( - [self.encode(input_) for input_ in inputs], - padding_value=self.pad_token_id, - ) - - if kwargs.get("truncation", TruncationStrategy.DO_NOT_TRUNCATE) == TruncationStrategy.ONLY_FIRST: - input_ids = input_ids[:, : self.model_max_length] - attention_mask = torch.zeros_like(input_ids).long() + 1 - return {"input_ids": input_ids, "attention_mask": attention_mask} - - def encode(self, input_): - return torch.LongTensor(list(input_.encode("utf-8"))) - - def decode(self, sequence, **kwargs): - N = len(ascii_lowercase) - output = "" - for i in range(len(sequence)): - output += ascii_lowercase[i % N] - return output - - class SimpleSummarizationPipelineTests(unittest.TestCase): @require_torch def test_input_too_long(self): From 1ec452cdc6ad75cee29346e3f3b4baa1a8c481ce Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Wed, 6 Jan 2021 11:18:34 +0100 Subject: [PATCH 06/10] Some more import guards. --- tests/test_pipelines_summarization.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/test_pipelines_summarization.py b/tests/test_pipelines_summarization.py index f6a72eba8796..374e75dfc451 100644 --- a/tests/test_pipelines_summarization.py +++ b/tests/test_pipelines_summarization.py @@ -15,18 +15,19 @@ import unittest from transformers import is_torch_available, pipeline -from transformers.models.bart import BartConfig, BartForConditionalGeneration from transformers.testing_utils import require_torch, slow, torch_device from transformers.tokenization_utils import TruncationStrategy from .test_pipelines_common import DummyTok, MonoInputPipelineCommonMixin -DEFAULT_DEVICE_NUM = -1 if torch_device == "cpu" else 0 - if is_torch_available(): import torch + from transformers.models.bart import BartConfig, BartForConditionalGeneration + +DEFAULT_DEVICE_NUM = -1 if torch_device == "cpu" else 0 + class SimpleSummarizationPipelineTests(unittest.TestCase): @require_torch From e4dc04ae94b7e6bdfb597e2fa105be55fbd43f6c Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Wed, 6 Jan 2021 14:12:35 +0100 Subject: [PATCH 07/10] Remove bogus file. 
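(The file removed below was a stray local benchmark script: it built an ONNX Runtime InferenceSession over `onnx/gpt2-optimized-quantized.onnx` and timed a single inference. It appears to have been committed by accident alongside the isort update in [PATCH 02/10].)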
--- src/transformers/test.py | 19 ------------------- 1 file changed, 19 deletions(-) delete mode 100644 src/transformers/test.py diff --git a/src/transformers/test.py b/src/transformers/test.py deleted file mode 100644 index f4fac9183c3f..000000000000 --- a/src/transformers/test.py +++ /dev/null @@ -1,19 +0,0 @@ -import datetime - -import numpy as np - -import onnxruntime as rt - - -sess_options = rt.SessionOptions() - -# Set graph optimization level -sess_options.graph_optimization_level = rt.GraphOptimizationLevel.ORT_ENABLE_ALL - -sess = rt.InferenceSession("onnx/gpt2-optimized-quantized.onnx", sess_options) -input_name = sess.get_inputs()[0].name -X_test = np.zeros((1, 50)) -start = datetime.datetime.now() -pred_onx = sess.run(None, {input_name: X_test.astype(np.long)})[0] -print(datetime.datetime.now() - start) -print(pred_onx) From 2e81658a860d2ea5ffb460a66c7d6a150f49e7f5 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Fri, 8 Jan 2021 10:51:58 +0100 Subject: [PATCH 08/10] Do not pass `generate_kwargs` to `_parse_and_tokenize`. @patrickvonplaten --- .../pipelines/text2text_generation.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/src/transformers/pipelines/text2text_generation.py b/src/transformers/pipelines/text2text_generation.py index 67a2eb11e358..4d67b4fb14d9 100644 --- a/src/transformers/pipelines/text2text_generation.py +++ b/src/transformers/pipelines/text2text_generation.py @@ -1,4 +1,5 @@ from ..file_utils import add_end_docstrings, is_tf_available, is_torch_available +from ..tokenization_utils import TruncationStrategy from ..utils import logging from .base import PIPELINE_INIT_ARGS, Pipeline @@ -50,7 +51,13 @@ def check_inputs(self, input_length: int, min_length: int, max_length: int): return True def __call__( - self, *args, return_tensors=False, return_text=True, clean_up_tokenization_spaces=False, **generate_kwargs + self, + *args, + return_tensors=False, + return_text=True, + clean_up_tokenization_spaces=False, + truncation=TruncationStrategy.DO_NOT_TRUNCATE, + **generate_kwargs ): r""" Generate the output text(s) using text(s) given as inputs. @@ -64,6 +71,9 @@ def __call__( Whether or not to include the decoded texts in the outputs. clean_up_tokenization_spaces (:obj:`bool`, `optional`, defaults to :obj:`False`): Whether or not to clean up the potential extra spaces in the text output. + truncation (:obj:`TruncationStrategy`, `optional`, defaults to :obj:`TruncationStrategy.DO_NOT_TRUNCATE`): + The truncation strategy for the tokenization within the pipeline. :obj:`TruncationStrategy.DO_NOT_TRUNCATE` (default) will never truncate, but it is sometimes desirable + to truncate the input to fit the model's max_length instead of throwing an error down the line. generate_kwargs: Additional keyword arguments to pass along to the generate method of the model (see the generate method corresponding to your framework `here <./model.html#generative-models>`__). 
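[Illustration, not part of the patch: a minimal sketch of how the `truncation` argument documented above is meant to be used once this change is applied. The checkpoint name is a placeholder (any seq2seq summarization checkpoint would do); the call signature itself is the one exercised by `test_input_too_long` in the test patches above.]

    from transformers import pipeline
    from transformers.tokenization_utils import TruncationStrategy

    # Placeholder checkpoint; any summarization model will do.
    summarizer = pipeline("summarization", model="sshleifer/distilbart-cnn-12-6")

    # An input far longer than the model's maximum input length.
    long_text = "word " * 5000

    # Without `truncation`, an over-long input eventually raises an IndexError
    # inside the position embeddings; with TruncationStrategy.ONLY_FIRST the
    # input is clipped to the tokenizer's model_max_length before generation.
    output = summarizer(long_text, truncation=TruncationStrategy.ONLY_FIRST)
    print(output[0]["summary_text"])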
@@ -96,7 +106,7 @@ def __call__( ) with self.device_placement(): - inputs = self._parse_and_tokenize(*args, padding=padding, **generate_kwargs) + inputs = self._parse_and_tokenize(*args, padding=padding, truncation=truncation) if self.framework == "pt": inputs = self.ensure_tensor_on_device(**inputs) @@ -108,9 +118,6 @@ def __call__( max_length = generate_kwargs.get("max_length", self.model.config.max_length) self.check_inputs(input_length, min_length, max_length) - # truncation should be used by _parse_and_tokenize - generate_kwargs.pop("truncation", None) - generations = self.model.generate( inputs["input_ids"], attention_mask=inputs["attention_mask"], From b5585decfe52e73169d1b08e95dd7441ea9b660d Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Fri, 8 Jan 2021 12:41:16 +0100 Subject: [PATCH 09/10] Removed DummyTok. --- tests/test_pipelines_common.py | 70 +-------------------------- tests/test_pipelines_summarization.py | 34 ++++++++++--- 2 files changed, 29 insertions(+), 75 deletions(-) diff --git a/tests/test_pipelines_common.py b/tests/test_pipelines_common.py index 6f8aa46579a7..c8a66053a307 100644 --- a/tests/test_pipelines_common.py +++ b/tests/test_pipelines_common.py @@ -12,86 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -from string import ascii_lowercase from typing import List, Optional from unittest import mock from transformers import is_tf_available, is_torch_available, pipeline from transformers.pipelines import Pipeline from transformers.testing_utils import _run_slow_tests, is_pipeline_test, require_tf, require_torch, slow -from transformers.tokenization_utils_base import TruncationStrategy, to_py_obj +from transformers.tokenization_utils_base import to_py_obj -if is_torch_available(): - import torch - VALID_INPUTS = ["A simple string", ["list of strings"]] -class DummyTok: - pad_token_id = 0 - eos_token_id = 0 - - def __init__(self, **kwargs): - for name, v in kwargs.items(): - setattr(self, name, v) - self.index = 0 - - def __call__(self, inputs, **kwargs): - if kwargs.get("return_tensors", "") == "pt": - return self.encode_pt(inputs, **kwargs) - else: - return self.encode_list(inputs, **kwargs) - - def encode_list(self, inputs, **kwargs): - unwrap = False - if isinstance(inputs, str): - unwrap = True - inputs = [inputs] - - assert isinstance(inputs, list) - input_ids = [self.encode(input_) for input_ in inputs] - - if unwrap: - input_ids = input_ids[0] - - return {"input_ids": input_ids} - - def encode_pt(self, inputs, **kwargs): - if isinstance(inputs, str): - input_ids = torch.LongTensor(self.encode(inputs)).unsqueeze(0) - else: - input_ids = self._pad([self.encode(input_) for input_ in inputs]) - return self.finalize_pt(input_ids, **kwargs) - - def finalize_pt(self, input_ids, **kwargs): - if kwargs.get("truncation", TruncationStrategy.DO_NOT_TRUNCATE) == TruncationStrategy.ONLY_FIRST: - input_ids = input_ids[:, : self.model_max_length] - attention_mask = torch.zeros_like(input_ids).long() + 1 - return {"input_ids": input_ids, "attention_mask": attention_mask} - - def _pad(self, inputs): - return torch.nn.utils.rnn.pad_sequence( - [torch.LongTensor(input_) for input_ in inputs], - padding_value=self.pad_token_id, - ).transpose(1, 0) - - def pad(self, inputs, **kwargs): - input_ids = self._pad(inputs["input_ids"]) - return self.finalize_pt(input_ids, **kwargs) - - def encode(self, input_): - return list(input_.encode("utf-8")) - - def decode(self, sequence, **kwargs): - string = "" - for i in 
range(len(sequence)):
-            string += ascii_lowercase[self.index]
-            self.index += 1
-            self.index %= len(ascii_lowercase)
-        return string
-
-
 @is_pipeline_test
 class CustomInputPipelineCommonMixin:
     pipeline_task = None
diff --git a/tests/test_pipelines_summarization.py b/tests/test_pipelines_summarization.py
index 374e75dfc451..2d2bc3330db2 100644
--- a/tests/test_pipelines_summarization.py
+++ b/tests/test_pipelines_summarization.py
@@ -14,11 +14,11 @@
 
 import unittest
 
-from transformers import is_torch_available, pipeline
+from transformers import AutoTokenizer, is_torch_available, pipeline
 from transformers.testing_utils import require_torch, slow, torch_device
 from transformers.tokenization_utils import TruncationStrategy
 
-from .test_pipelines_common import DummyTok, MonoInputPipelineCommonMixin
+from .test_pipelines_common import MonoInputPipelineCommonMixin
 
 
 if is_torch_available():
@@ -48,14 +48,36 @@ def test_input_too_long(self):
             min_length=1,
         )
         model = BartForConditionalGeneration(config)
-        tokenizer = DummyTok(model_max_length=4)
+        # Bias output towards L (token id 76 == ord("L"))
+        V, C = model.lm_head.weight.shape
+
+        bias = torch.zeros(V)  # plain tensor: in-place indexing on a leaf with requires_grad=True would raise
+        bias[76] = 10
+
+        model.lm_head.bias = torch.nn.Parameter(bias)
+
+        # # Generated with:
+        # import tempfile
+        # from tokenizers import Tokenizer, models
+        # from transformers import PreTrainedTokenizerFast
+        # model_max_length = 4
+        # vocab = [(chr(i), i) for i in range(256)]
+        # tokenizer = Tokenizer(models.Unigram(vocab))
+        # with tempfile.NamedTemporaryFile() as f:
+        #     tokenizer.save(f.name)
+        #     real_tokenizer = PreTrainedTokenizerFast(tokenizer_file=f.name, model_max_length=model_max_length)
+        # real_tokenizer._tokenizer.save("tokenizer.json")
+        # # + add missing config.json with albert as model_type
+        tokenizer = AutoTokenizer.from_pretrained("Narsil/small_summarization_test")
         nlp = pipeline(task="summarization", model=model, tokenizer=tokenizer)
 
-        with self.assertRaises(IndexError):
-            _ = nlp("This is a test")
+        with self.assertLogs("transformers", level="WARNING"):
+            with self.assertRaises(IndexError):
+                _ = nlp("This is a test")
 
         output = nlp("This is a test", truncation=TruncationStrategy.ONLY_FIRST)
-        self.assertEqual(output, [{"summary_text": "ab"}])
+        # 2 is default BOS from Bart.
+        self.assertEqual(output, [{"summary_text": "\x02 L L L"}])
 
 
 class SummarizationPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
     pipeline_task = "summarization"

From 0975eb67c606552af8dcb813adc0e4cb502dae09 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Fri, 8 Jan 2021 13:03:34 +0100
Subject: [PATCH 10/10] Doc quality.

---
 src/transformers/pipelines/text2text_generation.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/transformers/pipelines/text2text_generation.py b/src/transformers/pipelines/text2text_generation.py
index 4d67b4fb14d9..3fb7d00c6eb5 100644
--- a/src/transformers/pipelines/text2text_generation.py
+++ b/src/transformers/pipelines/text2text_generation.py
@@ -72,7 +72,8 @@ def __call__(
             clean_up_tokenization_spaces (:obj:`bool`, `optional`, defaults to :obj:`False`):
                 Whether or not to clean up the potential extra spaces in the text output.
             truncation (:obj:`TruncationStrategy`, `optional`, defaults to :obj:`TruncationStrategy.DO_NOT_TRUNCATE`):
-                The truncation strategy for the tokenization within the pipeline. :obj:`TruncationStrategy.DO_NOT_TRUNCATE` (default) will never truncate, but it is sometimes desirable
+                The truncation strategy for the tokenization within the pipeline.
+                :obj:`TruncationStrategy.DO_NOT_TRUNCATE` (default) will never truncate, but it is sometimes desirable
                 to truncate the input to fit the model's max_length instead of throwing an error down the line.
             generate_kwargs:
                 Additional keyword arguments to pass along to the generate method of the model (see the generate method