Merged
7 changes: 5 additions & 2 deletions src/transformers/pipelines/base.py
@@ -24,7 +24,7 @@

 from ..file_utils import add_end_docstrings, is_tf_available, is_torch_available
 from ..modelcard import ModelCard
-from ..tokenization_utils import PreTrainedTokenizer
+from ..tokenization_utils import PreTrainedTokenizer, TruncationStrategy
 from ..utils import logging


@@ -577,7 +577,9 @@ def check_model_type(self, supported_models: Union[List[str], dict]):
                 f"The model '{self.model.__class__.__name__}' is not supported for {self.task}. Supported models are {supported_models}",
             )

-    def _parse_and_tokenize(self, inputs, padding=True, add_special_tokens=True, **kwargs):
+    def _parse_and_tokenize(
+        self, inputs, padding=True, add_special_tokens=True, truncation=TruncationStrategy.DO_NOT_TRUNCATE, **kwargs
Contributor: Is "kwargs" just used to catch unused params here?

Contributor Author: yes
+    ):
         """
         Parse arguments and tokenize
         """
@@ -587,6 +589,7 @@ def _parse_and_tokenize(self, inputs, padding=True, add_special_tokens=True, **kwargs):
             add_special_tokens=add_special_tokens,
             return_tensors=self.framework,
             padding=padding,
+            truncation=truncation,
         )

         return inputs
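For reviewers trying the change locally: the new argument is forwarded untouched to the tokenizer, so the pipeline inherits the tokenizer's truncation semantics. A minimal sketch of those semantics, assuming a standard pretrained checkpoint (`bert-base-uncased` here is illustrative only):

```python
from transformers import AutoTokenizer
from transformers.tokenization_utils import TruncationStrategy

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # illustrative checkpoint

text = "some input " * 1000  # far longer than the model's 512-token limit

# DO_NOT_TRUNCATE (the new pipeline default) keeps every token; the tokenizer
# only logs a warning that the sequence exceeds model_max_length.
long_ids = tokenizer(text, truncation=TruncationStrategy.DO_NOT_TRUNCATE)["input_ids"]

# ONLY_FIRST clips the sequence down to model_max_length.
short_ids = tokenizer(text, truncation=TruncationStrategy.ONLY_FIRST)["input_ids"]

assert len(long_ids) > tokenizer.model_max_length
assert len(short_ids) <= tokenizer.model_max_length
```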
7 changes: 5 additions & 2 deletions src/transformers/pipelines/conversational.py
@@ -2,6 +2,7 @@
 from typing import List, Optional, Union

 from ..file_utils import add_end_docstrings, is_tf_available, is_torch_available
+from ..tokenization_utils import TruncationStrategy
 from ..utils import logging
 from .base import PIPELINE_INIT_ARGS, Pipeline

@@ -317,12 +318,14 @@ def __call__(
         else:
             return output

-    def _parse_and_tokenize(self, inputs, **kwargs):
+    def _parse_and_tokenize(
+        self, inputs, add_special_tokens=False, padding=False, truncation=TruncationStrategy.DO_NOT_TRUNCATE, **kwargs
Contributor: Is kwargs just used to catch "unused" params?

+    ):
         """
         Parse arguments and tokenize, adding an EOS token at the end of the user input
         """
         # Parse arguments
-        inputs = self.tokenizer(inputs, add_special_tokens=False, padding=False).get("input_ids", [])
+        inputs = self.tokenizer(inputs, add_special_tokens=add_special_tokens, padding=padding).get("input_ids", [])
         for input in inputs:
             input.append(self.tokenizer.eos_token_id)
         return inputs
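The override keeps the conversational-specific behaviour (raw `input_ids` lists with an explicit EOS per utterance) while letting the new keyword defaults flow through. A sketch of what the method now does, with an illustrative DialoGPT checkpoint:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-small")  # illustrative checkpoint

# Mirrors the body above: tokenize without special tokens or padding,
# then append an explicit EOS to each user input.
inputs = tokenizer(["Hi there!", "How are you?"], add_special_tokens=False, padding=False)["input_ids"]
for ids in inputs:
    ids.append(tokenizer.eos_token_id)
```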
18 changes: 13 additions & 5 deletions src/transformers/pipelines/text2text_generation.py
@@ -1,4 +1,5 @@
 from ..file_utils import add_end_docstrings, is_tf_available, is_torch_available
+from ..tokenization_utils import TruncationStrategy
 from ..utils import logging
 from .base import PIPELINE_INIT_ARGS, Pipeline

@@ -50,7 +51,13 @@ def check_inputs(self, input_length: int, min_length: int, max_length: int):
         return True

     def __call__(
-        self, *args, return_tensors=False, return_text=True, clean_up_tokenization_spaces=False, **generate_kwargs
+        self,
+        *args,
+        return_tensors=False,
+        return_text=True,
+        clean_up_tokenization_spaces=False,
+        truncation=TruncationStrategy.DO_NOT_TRUNCATE,
+        **generate_kwargs
     ):
r"""
Generate the output text(s) using text(s) given as inputs.
@@ -64,6 +71,10 @@ def __call__(
                 Whether or not to include the decoded texts in the outputs.
             clean_up_tokenization_spaces (:obj:`bool`, `optional`, defaults to :obj:`False`):
                 Whether or not to clean up the potential extra spaces in the text output.
+            truncation (:obj:`TruncationStrategy`, `optional`, defaults to :obj:`TruncationStrategy.DO_NOT_TRUNCATE`):
+                The truncation strategy for the tokenization within the pipeline.
+                :obj:`TruncationStrategy.DO_NOT_TRUNCATE` (default) will never truncate, but it is sometimes desirable
+                to truncate the input to fit the model's max_length instead of throwing an error down the line.
             generate_kwargs:
                 Additional keyword arguments to pass along to the generate method of the model (see the generate method
                 corresponding to your framework `here <./model.html#generative-models>`__).
@@ -96,7 +107,7 @@ def __call__(
             )

         with self.device_placement():
-            inputs = self._parse_and_tokenize(*args, padding=padding, **generate_kwargs)
+            inputs = self._parse_and_tokenize(*args, padding=padding, truncation=truncation)
Contributor: Awesome

if self.framework == "pt":
inputs = self.ensure_tensor_on_device(**inputs)
@@ -108,9 +119,6 @@ def __call__(
             max_length = generate_kwargs.get("max_length", self.model.config.max_length)
             self.check_inputs(input_length, min_length, max_length)

-            # truncation should be used by _parse_and_tokenize
-            generate_kwargs.pop("truncation", None)
-
             generations = self.model.generate(
                 inputs["input_ids"],
                 attention_mask=inputs["attention_mask"],
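Taken together with the test added below, the caller-facing behaviour is: over-long inputs still fail loudly by default, but can now be clipped by opting in. A hedged usage sketch (the checkpoint name is illustrative; any summarization model applies):

```python
from transformers import pipeline
from transformers.tokenization_utils import TruncationStrategy

summarizer = pipeline("summarization", model="sshleifer/distilbart-cnn-12-6")  # illustrative

article = "The quick brown fox jumps over the lazy dog. " * 500  # exceeds the encoder limit

# Default (DO_NOT_TRUNCATE): the over-long input reaches the model and errors out.
# Opting in clips the input to the model's max length before generation.
result = summarizer(article, truncation=TruncationStrategy.ONLY_FIRST)
print(result[0]["summary_text"])
```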
18 changes: 4 additions & 14 deletions src/transformers/pipelines/text_generation.py
@@ -50,25 +50,15 @@ def __init__(self, *args, **kwargs):
         self.check_model_type(self.ALLOWED_MODELS)

     # overriding _parse_and_tokenize to allow for unusual language-modeling tokenizer arguments
-
-    def _parse_and_tokenize(self, inputs, padding=True, add_special_tokens=True, **kwargs):
+    def _parse_and_tokenize(self, *args, **kwargs):
"""
Parse arguments and tokenize
"""
# Parse arguments
if self.model.__class__.__name__ in ["TransfoXLLMHeadModel"]:
tokenizer_kwargs = {"add_space_before_punct_symbol": True}
else:
tokenizer_kwargs = {}
inputs = self.tokenizer(
inputs,
add_special_tokens=add_special_tokens,
return_tensors=self.framework,
padding=padding,
**tokenizer_kwargs,
)

return inputs
kwargs.update({"add_space_before_punct_symbol": True})
Contributor: nice! I like this change


+
+        return super()._parse_and_tokenize(*args, **kwargs)
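The rewrite replaces the copy-pasted tokenizer call with plain delegation: inject the model-specific kwarg, then defer to the base `_parse_and_tokenize`, which now understands `truncation` too. A self-contained sketch of the pattern (toy classes, not the real ones):

```python
class Base:
    def _parse_and_tokenize(self, inputs, padding=True, **kwargs):
        # stand-in for the real base method, which forwards everything to the tokenizer
        return {"inputs": inputs, "padding": padding, **kwargs}

class TextGeneration(Base):
    def _parse_and_tokenize(self, *args, **kwargs):
        # inject the TransfoXL-style kwarg (conditional in the real code), then delegate
        kwargs.update({"add_space_before_punct_symbol": True})
        return super()._parse_and_tokenize(*args, **kwargs)

print(TextGeneration()._parse_and_tokenize("Hello", padding=False))
# -> {'inputs': 'Hello', 'padding': False, 'add_space_before_punct_symbol': True}
```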

     def __call__(
         self,
12 changes: 10 additions & 2 deletions src/transformers/pipelines/zero_shot_classification.py
@@ -3,6 +3,7 @@
 import numpy as np

 from ..file_utils import add_end_docstrings
+from ..tokenization_utils import TruncationStrategy
 from ..utils import logging
 from .base import PIPELINE_INIT_ARGS, ArgumentHandler, Pipeline

@@ -78,7 +79,14 @@ def entailment_id(self):
         return -1

     def _parse_and_tokenize(
-        self, sequences, candidate_labels, hypothesis_template, padding=True, add_special_tokens=True, **kwargs
+        self,
+        sequences,
+        candidate_labels,
+        hypothesis_template,
+        padding=True,
+        add_special_tokens=True,
+        truncation=TruncationStrategy.ONLY_FIRST,
+        **kwargs
     ):
         """
         Parse arguments and tokenize only_first so that hypothesis (label) is not truncated
@@ -89,7 +97,7 @@ def _parse_and_tokenize(
             add_special_tokens=add_special_tokens,
             return_tensors=self.framework,
             padding=padding,
-            truncation="only_first",
+            truncation=truncation,
         )

         return inputs
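Zero-shot keeps `ONLY_FIRST` as its default because inputs are (premise, hypothesis) pairs: truncating only the first member guarantees the hypothesis, which carries the candidate label, always reaches the model. A sketch at the tokenizer level (the NLI checkpoint is illustrative):

```python
from transformers import AutoTokenizer
from transformers.tokenization_utils import TruncationStrategy

tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-mnli")  # illustrative NLI checkpoint

premise = "word " * 2000  # over-long sequence to classify
hypothesis = "This example is about politics."  # built from a candidate label

# ONLY_FIRST clips the premise to fit, but never touches the hypothesis.
enc = tokenizer(premise, hypothesis, truncation=TruncationStrategy.ONLY_FIRST)
assert len(enc["input_ids"]) <= tokenizer.model_max_length
```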
59 changes: 58 additions & 1 deletion tests/test_pipelines_summarization.py
@@ -14,15 +14,72 @@

 import unittest

-from transformers import pipeline
+from transformers import AutoTokenizer, is_torch_available, pipeline
 from transformers.testing_utils import require_torch, slow, torch_device
+from transformers.tokenization_utils import TruncationStrategy

from .test_pipelines_common import MonoInputPipelineCommonMixin


+if is_torch_available():
+    import torch
+
+    from transformers.models.bart import BartConfig, BartForConditionalGeneration
+
 DEFAULT_DEVICE_NUM = -1 if torch_device == "cpu" else 0


+class SimpleSummarizationPipelineTests(unittest.TestCase):
+    @require_torch
+    def test_input_too_long(self):
+        torch.manual_seed(0)
+        config = BartConfig(
+            vocab_size=257,
+            d_model=32,
+            encoder_layers=1,
+            decoder_layers=1,
+            encoder_ffn_dim=32,
+            decoder_ffn_dim=32,
+            # So any text > 4 should raise an exception
+            max_position_embeddings=4,
+            encoder_attention_heads=1,
+            decoder_attention_heads=1,
+            max_length=4,
+            min_length=1,
+        )
+        model = BartForConditionalGeneration(config)
+        # Bias output towards L
+        V, C = model.lm_head.weight.shape
+
+        bias = torch.zeros(V, requires_grad=True)
+        bias[76] = 10
+
+        model.lm_head.bias = torch.nn.Parameter(bias)
+
+        # # Generated with:
+        # import tempfile
+        # from tokenizers import Tokenizer, models
+        # from transformers import PreTrainedTokenizerFast
+        # model_max_length = 4
+        # vocab = [(chr(i), i) for i in range(256)]
+        # tokenizer = Tokenizer(models.Unigram(vocab))
+        # with tempfile.NamedTemporaryFile() as f:
+        #     tokenizer.save(f.name)
+        #     real_tokenizer = PreTrainedTokenizerFast(tokenizer_file=f.name, model_max_length=model_max_length)
+        # real_tokenizer._tokenizer.save("tokenizer.json")
+        # # + add missing config.json with albert as model_type
+        tokenizer = AutoTokenizer.from_pretrained("Narsil/small_summarization_test")
+        nlp = pipeline(task="summarization", model=model, tokenizer=tokenizer)
+
+        with self.assertLogs("transformers", level="WARNING"):
+            with self.assertRaises(IndexError):
+                _ = nlp("This is a test")
+
+        output = nlp("This is a test", truncation=TruncationStrategy.ONLY_FIRST)
+        # 2 is default BOS from Bart.
+        self.assertEqual(output, [{"summary_text": "\x02 L L L"}])


 class SummarizationPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
     pipeline_task = "summarization"
     pipeline_running_kwargs = {"num_beams": 2, "min_length": 2, "max_length": 5}