diff --git a/src/transformers/pipelines.py b/src/transformers/pipelines.py
index 16452de1ab65..f63071b407db 100755
--- a/src/transformers/pipelines.py
+++ b/src/transformers/pipelines.py
@@ -2430,18 +2430,31 @@ def __call__(
                 **generate_kwargs,
             )
 
-            cleaned_history = self._clean_padding_history(generated_responses)
+            if self.model.config.is_encoder_decoder:
+                if self.framework == "pt":
+                    history = torch.cat((inputs["input_ids"], generated_responses[:, 1:]), 1)
+                elif self.framework == "tf":
+                    history = tf.concat([inputs["input_ids"], generated_responses[:, 1:]], 1)
+            else:
+                history = generated_responses
+
+            history = self._clean_padding_history(history)
+            if self.model.config.is_encoder_decoder:
+                start_position = 1
+            else:
+                start_position = input_length
+
             output = []
             for conversation_index, conversation in enumerate(conversations):
                 conversation.mark_processed()
                 conversation.generated_responses.append(
                     self.tokenizer.decode(
-                        cleaned_history[conversation_index][input_length:],
+                        generated_responses[conversation_index][start_position:],
                         skip_special_tokens=True,
                         clean_up_tokenization_spaces=clean_up_tokenization_spaces,
                     )
                 )
-                conversation.set_history(cleaned_history[conversation_index])
+                conversation.set_history(history[conversation_index])
                 output.append(conversation)
             if len(output) == 1:
                 return output[0]
@@ -2475,6 +2488,8 @@ def _clean_padding_history(self, generated_tensor) -> List[List[int]]:
             is_previous_pad = False
             for token in sequence:
                 if token == self.tokenizer.pad_token_id:
+                    if self.tokenizer.pad_token_id != self.tokenizer.eos_token_id:
+                        continue
                     if is_previous_pad:
                         continue
                     else:
diff --git a/tests/test_pipelines_conversational.py b/tests/test_pipelines_conversational.py
index 2c5da9ee2491..066dc97fef84 100644
--- a/tests/test_pipelines_conversational.py
+++ b/tests/test_pipelines_conversational.py
@@ -1,6 +1,6 @@
 import unittest
 
-from transformers import Conversation, pipeline
+from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, Conversation, ConversationalPipeline, pipeline
 from transformers.testing_utils import require_torch, slow, torch_device
 
 from .test_pipelines_common import MonoInputPipelineCommonMixin
@@ -15,8 +15,9 @@ class ConversationalPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCas
     large_models = ["microsoft/DialoGPT-medium"]  # Models tested with the @slow decorator
     invalid_inputs = ["Hi there!", Conversation()]
 
-    def _test_pipeline(self, nlp):
-        # e overide the default test method to check that the output is a `Conversation` object
+    def _test_pipeline(
+        self, nlp
+    ):  # override the default test method to check that the output is a `Conversation` object
         self.assertIsNotNone(nlp)
 
         # We need to recreate conversation for successive tests to pass as
@@ -95,3 +96,50 @@ def test_integration_torch_conversation_truncated_history(self):
         self.assertEqual(len(result.generated_responses), 2)
         self.assertEqual(result.past_user_inputs[1], "Is it an action movie?")
         self.assertEqual(result.generated_responses[1], "It's a comedy.")
+
+    @require_torch
+    @slow
+    def test_integration_torch_conversation_encoder_decoder(self):
+        # When
+        tokenizer = AutoTokenizer.from_pretrained("facebook/blenderbot-90M")
+        model = AutoModelForSeq2SeqLM.from_pretrained("facebook/blenderbot-90M")
+        nlp = ConversationalPipeline(model=model, tokenizer=tokenizer, device=DEFAULT_DEVICE_NUM)
+
+        conversation_1 = Conversation("My name is Sarah and I live in London")
+        conversation_2 = Conversation("Going to the movies tonight, What movie would you recommend? ")
+        # Then
+        self.assertEqual(len(conversation_1.past_user_inputs), 0)
+        self.assertEqual(len(conversation_2.past_user_inputs), 0)
+        # When
+        result = nlp([conversation_1, conversation_2], do_sample=False, max_length=1000)
+        # Then
+        self.assertEqual(result, [conversation_1, conversation_2])
+        self.assertEqual(len(result[0].past_user_inputs), 1)
+        self.assertEqual(len(result[1].past_user_inputs), 1)
+        self.assertEqual(len(result[0].generated_responses), 1)
+        self.assertEqual(len(result[1].generated_responses), 1)
+        self.assertEqual(result[0].past_user_inputs[0], "My name is Sarah and I live in London")
+        self.assertEqual(
+            result[0].generated_responses[0],
+            "hi sarah, i live in london as well. do you have any plans for the weekend?",
+        )
+        self.assertEqual(
+            result[1].past_user_inputs[0], "Going to the movies tonight, What movie would you recommend? "
+        )
+        self.assertEqual(
+            result[1].generated_responses[0], "i don't know... i'm not really sure. what movie are you going to see?"
+        )
+        # When
+        conversation_1.add_user_input("Not yet, what about you?")
+        conversation_2.add_user_input("What's your name?")
+        result = nlp([conversation_1, conversation_2], do_sample=False, max_length=1000)
+        # Then
+        self.assertEqual(result, [conversation_1, conversation_2])
+        self.assertEqual(len(result[0].past_user_inputs), 2)
+        self.assertEqual(len(result[1].past_user_inputs), 2)
+        self.assertEqual(len(result[0].generated_responses), 2)
+        self.assertEqual(len(result[1].generated_responses), 2)
+        self.assertEqual(result[0].past_user_inputs[1], "Not yet, what about you?")
+        self.assertEqual(result[0].generated_responses[1], "i don't have any plans yet. i'm not sure what to do yet.")
+        self.assertEqual(result[1].past_user_inputs[1], "What's your name?")
+        self.assertEqual(result[1].generated_responses[1], "i don't have a name, but i'm going to see a horror movie.")