diff --git a/src/transformers/pipelines.py b/src/transformers/pipelines.py index e37607c136b8..4ec85e1f4c19 100755 --- a/src/transformers/pipelines.py +++ b/src/transformers/pipelines.py @@ -22,6 +22,7 @@ import uuid import warnings from abc import ABC, abstractmethod +from collections.abc import Iterable from contextlib import contextmanager from os.path import abspath, exists from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union @@ -1597,55 +1598,52 @@ class QuestionAnsweringArgumentHandler(ArgumentHandler): command-line supplied arguments. """ + def normalize(self, item): + if isinstance(item, SquadExample): + return item + elif isinstance(item, dict): + for k in ["question", "context"]: + if k not in item: + raise KeyError("You need to provide a dictionary with keys {question:..., context:...}") + elif item[k] is None: + raise ValueError("`{}` cannot be None".format(k)) + elif isinstance(item[k], str) and len(item[k]) == 0: + raise ValueError("`{}` cannot be empty".format(k)) + + return QuestionAnsweringPipeline.create_sample(**item) + raise ValueError("{} argument needs to be of type (SquadExample, dict)".format(item)) + def __call__(self, *args, **kwargs): - # Position args, handling is sensibly the same as X and data, so forwarding to avoid duplicating + # Detect where the actual inputs are if args is not None and len(args) > 0: if len(args) == 1: - kwargs["X"] = args[0] + inputs = args[0] + elif len(args) == 2 and {type(el) for el in args} == {str}: + inputs = [{"question": args[0], "context": args[1]}] else: - kwargs["X"] = list(args) - + inputs = list(args) # Generic compatibility with sklearn and Keras # Batched data - if "X" in kwargs or "data" in kwargs: - inputs = kwargs["X"] if "X" in kwargs else kwargs["data"] - - if isinstance(inputs, dict): - inputs = [inputs] - else: - # Copy to avoid overriding arguments - inputs = [i for i in inputs] - - for i, item in enumerate(inputs): - if isinstance(item, dict): - if any(k not in item for k in ["question", "context"]): - raise KeyError("You need to provide a dictionary with keys {question:..., context:...}") - - inputs[i] = QuestionAnsweringPipeline.create_sample(**item) - - elif not isinstance(item, SquadExample): - raise ValueError( - "{} argument needs to be of type (list[SquadExample | dict], SquadExample, dict)".format( - "X" if "X" in kwargs else "data" - ) - ) - - # Tabular input + elif "X" in kwargs: + inputs = kwargs["X"] + elif "data" in kwargs: + inputs = kwargs["data"] elif "question" in kwargs and "context" in kwargs: - if isinstance(kwargs["question"], str): - kwargs["question"] = [kwargs["question"]] - - if isinstance(kwargs["context"], str): - kwargs["context"] = [kwargs["context"]] - - inputs = [ - QuestionAnsweringPipeline.create_sample(q, c) for q, c in zip(kwargs["question"], kwargs["context"]) - ] + inputs = [{"question": kwargs["question"], "context": kwargs["context"]}] else: raise ValueError("Unknown arguments {}".format(kwargs)) - if not isinstance(inputs, list): + # Normalize inputs + if isinstance(inputs, dict): inputs = [inputs] + elif isinstance(inputs, Iterable): + # Copy to avoid overriding arguments + inputs = [i for i in inputs] + else: + raise ValueError("Invalid arguments {}".format(inputs)) + + for i, item in enumerate(inputs): + inputs[i] = self.normalize(item) return inputs diff --git a/tests/test_pipelines_question_answering.py b/tests/test_pipelines_question_answering.py index 3f3f6dc83a72..54b306c09d88 100644 --- a/tests/test_pipelines_question_answering.py +++ b/tests/test_pipelines_question_answering.py @@ -1,6 +1,7 @@ import unittest -from transformers.pipelines import Pipeline +from transformers.data.processors.squad import SquadExample +from transformers.pipelines import Pipeline, QuestionAnsweringArgumentHandler from .test_pipelines_common import CustomInputPipelineCommonMixin @@ -43,5 +44,116 @@ def _test_pipeline(self, nlp: Pipeline): for key in output_keys: self.assertIn(key, result) for bad_input in invalid_inputs: - self.assertRaises(Exception, nlp, bad_input) - self.assertRaises(Exception, nlp, invalid_inputs) + self.assertRaises(ValueError, nlp, bad_input) + self.assertRaises(ValueError, nlp, invalid_inputs) + + def test_argument_handler(self): + qa = QuestionAnsweringArgumentHandler() + + Q = "Where was HuggingFace founded ?" + C = "HuggingFace was founded in Paris" + + normalized = qa(Q, C) + self.assertEqual(type(normalized), list) + self.assertEqual(len(normalized), 1) + self.assertEqual({type(el) for el in normalized}, {SquadExample}) + + normalized = qa(question=Q, context=C) + self.assertEqual(type(normalized), list) + self.assertEqual(len(normalized), 1) + self.assertEqual({type(el) for el in normalized}, {SquadExample}) + + normalized = qa(question=Q, context=C) + self.assertEqual(type(normalized), list) + self.assertEqual(len(normalized), 1) + self.assertEqual({type(el) for el in normalized}, {SquadExample}) + + normalized = qa({"question": Q, "context": C}) + self.assertEqual(type(normalized), list) + self.assertEqual(len(normalized), 1) + self.assertEqual({type(el) for el in normalized}, {SquadExample}) + + normalized = qa([{"question": Q, "context": C}]) + self.assertEqual(type(normalized), list) + self.assertEqual(len(normalized), 1) + self.assertEqual({type(el) for el in normalized}, {SquadExample}) + + normalized = qa([{"question": Q, "context": C}, {"question": Q, "context": C}]) + self.assertEqual(type(normalized), list) + self.assertEqual(len(normalized), 2) + self.assertEqual({type(el) for el in normalized}, {SquadExample}) + + normalized = qa(X={"question": Q, "context": C}) + self.assertEqual(type(normalized), list) + self.assertEqual(len(normalized), 1) + self.assertEqual({type(el) for el in normalized}, {SquadExample}) + + normalized = qa(X=[{"question": Q, "context": C}]) + self.assertEqual(type(normalized), list) + self.assertEqual(len(normalized), 1) + self.assertEqual({type(el) for el in normalized}, {SquadExample}) + + normalized = qa(data={"question": Q, "context": C}) + self.assertEqual(type(normalized), list) + self.assertEqual(len(normalized), 1) + self.assertEqual({type(el) for el in normalized}, {SquadExample}) + + def test_argument_handler_error_handling(self): + qa = QuestionAnsweringArgumentHandler() + + Q = "Where was HuggingFace founded ?" + C = "HuggingFace was founded in Paris" + + with self.assertRaises(KeyError): + qa({"context": C}) + with self.assertRaises(KeyError): + qa({"question": Q}) + with self.assertRaises(KeyError): + qa([{"context": C}]) + with self.assertRaises(ValueError): + qa(None, C) + with self.assertRaises(ValueError): + qa("", C) + with self.assertRaises(ValueError): + qa(Q, None) + with self.assertRaises(ValueError): + qa(Q, "") + + with self.assertRaises(ValueError): + qa(question=None, context=C) + with self.assertRaises(ValueError): + qa(question="", context=C) + with self.assertRaises(ValueError): + qa(question=Q, context=None) + with self.assertRaises(ValueError): + qa(question=Q, context="") + + with self.assertRaises(ValueError): + qa({"question": None, "context": C}) + with self.assertRaises(ValueError): + qa({"question": "", "context": C}) + with self.assertRaises(ValueError): + qa({"question": Q, "context": None}) + with self.assertRaises(ValueError): + qa({"question": Q, "context": ""}) + + with self.assertRaises(ValueError): + qa([{"question": Q, "context": C}, {"question": None, "context": C}]) + with self.assertRaises(ValueError): + qa([{"question": Q, "context": C}, {"question": "", "context": C}]) + + with self.assertRaises(ValueError): + qa([{"question": Q, "context": C}, {"question": Q, "context": None}]) + with self.assertRaises(ValueError): + qa([{"question": Q, "context": C}, {"question": Q, "context": ""}]) + + def test_argument_handler_error_handling_odd(self): + qa = QuestionAnsweringArgumentHandler() + with self.assertRaises(ValueError): + qa(None) + + with self.assertRaises(ValueError): + qa(Y=None) + + with self.assertRaises(ValueError): + qa(1)