Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions src/transformers/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -836,6 +836,8 @@ def __call__(
-- The token ids of the generated text.
"""

if isinstance(text_inputs, str):
text_inputs = [text_inputs]
results = []
for prompt_text in text_inputs:
# Manage correct placement of the tensors
Expand Down Expand Up @@ -2382,6 +2384,8 @@ def __call__(
updated generated responses for those containing a new user input.
"""

if isinstance(conversations, Conversation):
conversations = [conversations]
# Input validation
if isinstance(conversations, list):
for conversation in conversations:
Expand All @@ -2398,8 +2402,6 @@ def __call__(
assert (
self.tokenizer.pad_token_id is not None or self.tokenizer.eos_token_id is not None
), "Please make sure that the tokenizer has a pad_token_id or eos_token_id when using a batch input"
elif isinstance(conversations, Conversation):
conversations = [conversations]
else:
raise ValueError("DialoguePipeline expects a Conversation or list of Conversations as an input")

Expand Down
20 changes: 12 additions & 8 deletions tests/test_pipelines_conversational.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,26 +9,30 @@
DEFAULT_DEVICE_NUM = -1 if torch_device == "cpu" else 0


class TextGenerationPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
class ConversationalPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
pipeline_task = "conversational"
small_models = [] # Models tested without the @slow decorator
large_models = ["microsoft/DialoGPT-medium"] # Models tested with the @slow decorator
valid_inputs = [Conversation("Hi there!"), [Conversation("Hi there!"), Conversation("How are you?")]]
invalid_inputs = ["Hi there!", Conversation()]

def _test_pipeline(
self, nlp
): # We override the default test method to check that the output is a `Conversation` object
def _test_pipeline(self, nlp):
# We override the default test method to check that the output is a `Conversation` object
self.assertIsNotNone(nlp)

mono_result = nlp(self.valid_inputs[0])
# We need to recreate conversation for successive tests to pass as
# Conversation objects get *consumed* by the pipeline
conversation = Conversation("Hi there!")
mono_result = nlp(conversation)
self.assertIsInstance(mono_result, Conversation)

multi_result = nlp(self.valid_inputs[1])
conversations = [Conversation("Hi there!"), Conversation("How are you?")]
multi_result = nlp(conversations)
self.assertIsInstance(multi_result, list)
self.assertIsInstance(multi_result[0], Conversation)
# Conversations have been consumed and are not valid anymore
# Inactive conversations passed to the pipeline raise a ValueError
self.assertRaises(ValueError, nlp, self.valid_inputs[1])
self.assertRaises(ValueError, nlp, conversation)
self.assertRaises(ValueError, nlp, conversations)

for bad_input in self.invalid_inputs:
self.assertRaises(Exception, nlp, bad_input)
Expand Down
19 changes: 19 additions & 0 deletions tests/test_pipelines_text_generation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import unittest

from transformers import pipeline

from .test_pipelines_common import MonoInputPipelineCommonMixin


Expand All @@ -8,3 +10,20 @@ class TextGenerationPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCas
pipeline_running_kwargs = {"prefix": "This is "}
small_models = ["sshleifer/tiny-ctrl"] # Models tested without the @slow decorator
large_models = [] # Models tested with the @slow decorator

    def test_simple_generation(self):
        """Smoke-test the text-generation pipeline on a tiny model.

        text-generation is non-deterministic by nature, so we can't fully
        test the output; we only verify the structure of what comes back,
        not the generated content itself.
        """
        nlp = pipeline(task="text-generation", model=self.small_models[0])

        # A single prompt returns a list of generation dicts (one per
        # returned sequence; default is one).
        outputs = nlp("This is a test")

        self.assertEqual(len(outputs), 1)
        self.assertEqual(list(outputs[0].keys()), ["generated_text"])
        self.assertEqual(type(outputs[0]["generated_text"]), str)

        # A batch of prompts returns one list of generation dicts per
        # prompt, i.e. a list of lists.
        outputs = nlp(["This is a test", "This is a second test"])
        self.assertEqual(len(outputs[0]), 1)
        self.assertEqual(list(outputs[0][0].keys()), ["generated_text"])
        self.assertEqual(type(outputs[0][0]["generated_text"]), str)
        self.assertEqual(list(outputs[1][0].keys()), ["generated_text"])
        self.assertEqual(type(outputs[1][0]["generated_text"]), str)