diff --git a/src/transformers/pipelines.py b/src/transformers/pipelines.py index 27ad1d462505..92b62be8ea47 100755 --- a/src/transformers/pipelines.py +++ b/src/transformers/pipelines.py @@ -1420,9 +1420,14 @@ def __call__(self, inputs: Union[str, List[str]], **kwargs): - **word** (:obj:`str`) -- The token/word classified. - **score** (:obj:`float`) -- The corresponding probability for :obj:`entity`. - - **entity** (:obj:`str`) -- The entity predicted for that token/word. + - **entity** (:obj:`str`) -- The entity predicted for that token/word (it is named `entity_group` when + `grouped_entities` is set to True). - **index** (:obj:`int`, only present when ``self.grouped_entities=False``) -- The index of the corresponding token in the sentence. + - **start** (:obj:`int`, `optional`) -- The index of the start of the corresponding entity in the sentence. + Only exists if the offsets are available within the tokenizer + - **end** (:obj:`int`, `optional`) -- The index of the end of the corresponding entity in the sentence. 
+ Only exists if the offsets are available within the tokenizer """ inputs, offset_mappings = self._args_parser(inputs, **kwargs) @@ -1486,11 +1491,16 @@ def __call__(self, inputs: Union[str, List[str]], **kwargs): else: word = self.tokenizer.convert_ids_to_tokens(int(input_ids[idx])) + start_ind = None + end_ind = None + entity = { "word": word, "score": score[idx][label_idx].item(), "entity": self.model.config.id2label[label_idx], "index": idx, + "start": start_ind, + "end": end_ind, } if self.grouped_entities and self.ignore_subwords: @@ -1524,6 +1534,8 @@ def group_sub_entities(self, entities: List[dict]) -> dict: "entity_group": entity, "score": np.mean(scores), "word": self.tokenizer.convert_tokens_to_string(tokens), + "start": entities[0]["start"], + "end": entities[-1]["end"], } return entity_group diff --git a/tests/test_pipelines_ner.py b/tests/test_pipelines_ner.py index 58da4aded63e..b4b10c48ff47 100644 --- a/tests/test_pipelines_ner.py +++ b/tests/test_pipelines_ner.py @@ -2,7 +2,7 @@ from transformers import AutoTokenizer, pipeline from transformers.pipelines import Pipeline, TokenClassificationArgumentHandler -from transformers.testing_utils import require_tf, require_torch +from transformers.testing_utils import require_tf, require_torch, slow from .test_pipelines_common import CustomInputPipelineCommonMixin @@ -18,55 +18,207 @@ class NerPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase): large_models = [] # Models tested with the @slow decorator def _test_pipeline(self, nlp: Pipeline): - output_keys = {"entity", "word", "score"} + output_keys = {"entity", "word", "score", "start", "end"} if nlp.grouped_entities: - output_keys = {"entity_group", "word", "score"} + output_keys = {"entity_group", "word", "score", "start", "end"} ungrouped_ner_inputs = [ [ - {"entity": "B-PER", "index": 1, "score": 0.9994944930076599, "is_subword": False, "word": "Cons"}, - {"entity": "B-PER", "index": 2, "score": 0.8025449514389038, "is_subword": True, 
"word": "##uelo"}, - {"entity": "I-PER", "index": 3, "score": 0.9993102550506592, "is_subword": False, "word": "Ara"}, - {"entity": "I-PER", "index": 4, "score": 0.9993743896484375, "is_subword": True, "word": "##új"}, - {"entity": "I-PER", "index": 5, "score": 0.9992871880531311, "is_subword": True, "word": "##o"}, - {"entity": "I-PER", "index": 6, "score": 0.9993029236793518, "is_subword": False, "word": "No"}, - {"entity": "I-PER", "index": 7, "score": 0.9981776475906372, "is_subword": True, "word": "##guera"}, - {"entity": "B-PER", "index": 15, "score": 0.9998136162757874, "is_subword": False, "word": "Andrés"}, - {"entity": "I-PER", "index": 16, "score": 0.999740719795227, "is_subword": False, "word": "Pas"}, - {"entity": "I-PER", "index": 17, "score": 0.9997414350509644, "is_subword": True, "word": "##tran"}, - {"entity": "I-PER", "index": 18, "score": 0.9996136426925659, "is_subword": True, "word": "##a"}, - {"entity": "B-ORG", "index": 28, "score": 0.9989739060401917, "is_subword": False, "word": "Far"}, - {"entity": "I-ORG", "index": 29, "score": 0.7188422083854675, "is_subword": True, "word": "##c"}, + { + "entity": "B-PER", + "index": 1, + "score": 0.9994944930076599, + "is_subword": False, + "word": "Cons", + "start": 0, + "end": 4, + }, + { + "entity": "B-PER", + "index": 2, + "score": 0.8025449514389038, + "is_subword": True, + "word": "##uelo", + "start": 4, + "end": 8, + }, + { + "entity": "I-PER", + "index": 3, + "score": 0.9993102550506592, + "is_subword": False, + "word": "Ara", + "start": 9, + "end": 11, + }, + { + "entity": "I-PER", + "index": 4, + "score": 0.9993743896484375, + "is_subword": True, + "word": "##új", + "start": 11, + "end": 13, + }, + { + "entity": "I-PER", + "index": 5, + "score": 0.9992871880531311, + "is_subword": True, + "word": "##o", + "start": 13, + "end": 14, + }, + { + "entity": "I-PER", + "index": 6, + "score": 0.9993029236793518, + "is_subword": False, + "word": "No", + "start": 15, + "end": 17, + }, + { + "entity": 
"I-PER", + "index": 7, + "score": 0.9981776475906372, + "is_subword": True, + "word": "##guera", + "start": 17, + "end": 22, + }, + { + "entity": "B-PER", + "index": 15, + "score": 0.9998136162757874, + "is_subword": False, + "word": "Andrés", + "start": 23, + "end": 28, + }, + { + "entity": "I-PER", + "index": 16, + "score": 0.999740719795227, + "is_subword": False, + "word": "Pas", + "start": 29, + "end": 32, + }, + { + "entity": "I-PER", + "index": 17, + "score": 0.9997414350509644, + "is_subword": True, + "word": "##tran", + "start": 32, + "end": 36, + }, + { + "entity": "I-PER", + "index": 18, + "score": 0.9996136426925659, + "is_subword": True, + "word": "##a", + "start": 36, + "end": 37, + }, + { + "entity": "B-ORG", + "index": 28, + "score": 0.9989739060401917, + "is_subword": False, + "word": "Far", + "start": 39, + "end": 42, + }, + { + "entity": "I-ORG", + "index": 29, + "score": 0.7188422083854675, + "is_subword": True, + "word": "##c", + "start": 42, + "end": 43, + }, ], [ - {"entity": "I-PER", "index": 1, "score": 0.9968166351318359, "is_subword": False, "word": "En"}, - {"entity": "I-PER", "index": 2, "score": 0.9957635998725891, "is_subword": True, "word": "##zo"}, - {"entity": "I-ORG", "index": 7, "score": 0.9986497163772583, "is_subword": False, "word": "UN"}, + { + "entity": "I-PER", + "index": 1, + "score": 0.9968166351318359, + "is_subword": False, + "word": "En", + "start": 0, + "end": 2, + }, + { + "entity": "I-PER", + "index": 2, + "score": 0.9957635998725891, + "is_subword": True, + "word": "##zo", + "start": 2, + "end": 4, + }, + { + "entity": "I-ORG", + "index": 7, + "score": 0.9986497163772583, + "is_subword": False, + "word": "UN", + "start": 11, + "end": 13, + }, ], ] expected_grouped_ner_results = [ [ - {"entity_group": "PER", "score": 0.999369223912557, "word": "Consuelo Araújo Noguera"}, - {"entity_group": "PER", "score": 0.9997771680355072, "word": "Andrés Pastrana"}, - {"entity_group": "ORG", "score": 0.9989739060401917, "word": 
"Farc"}, + { + "entity_group": "PER", + "score": 0.999369223912557, + "word": "Consuelo Araújo Noguera", + "start": 0, + "end": 22, + }, + { + "entity_group": "PER", + "score": 0.9997771680355072, + "word": "Andrés Pastrana", + "start": 23, + "end": 37, + }, + {"entity_group": "ORG", "score": 0.9989739060401917, "word": "Farc", "start": 39, "end": 43}, ], [ - {"entity_group": "PER", "score": 0.9968166351318359, "word": "Enzo"}, - {"entity_group": "ORG", "score": 0.9986497163772583, "word": "UN"}, + {"entity_group": "PER", "score": 0.9968166351318359, "word": "Enzo", "start": 0, "end": 4}, + {"entity_group": "ORG", "score": 0.9986497163772583, "word": "UN", "start": 11, "end": 13}, ], ] expected_grouped_ner_results_w_subword = [ [ - {"entity_group": "PER", "score": 0.9994944930076599, "word": "Cons"}, - {"entity_group": "PER", "score": 0.9663328925768534, "word": "##uelo Araújo Noguera"}, - {"entity_group": "PER", "score": 0.9997273534536362, "word": "Andrés Pastrana"}, - {"entity_group": "ORG", "score": 0.8589080572128296, "word": "Farc"}, + {"entity_group": "PER", "score": 0.9994944930076599, "word": "Cons", "start": 0, "end": 4}, + { + "entity_group": "PER", + "score": 0.9663328925768534, + "word": "##uelo Araújo Noguera", + "start": 4, + "end": 22, + }, + { + "entity_group": "PER", + "score": 0.9997273534536362, + "word": "Andrés Pastrana", + "start": 23, + "end": 37, + }, + {"entity_group": "ORG", "score": 0.8589080572128296, "word": "Farc", "start": 39, "end": 43}, ], [ - {"entity_group": "PER", "score": 0.9962901175022125, "word": "Enzo"}, - {"entity_group": "ORG", "score": 0.9986497163772583, "word": "UN"}, + {"entity_group": "PER", "score": 0.9962901175022125, "word": "Enzo", "start": 0, "end": 4}, + {"entity_group": "ORG", "score": 0.9986497163772583, "word": "UN", "start": 11, "end": 13}, ], ] @@ -164,6 +316,34 @@ def test_pt_defaults(self): nlp = pipeline(task="ner", model=model_name) self._test_pipeline(nlp) + @slow + @require_torch + def 
test_simple(self): + nlp = pipeline(task="ner", model="dslim/bert-base-NER", grouped_entities=True) + output = nlp("Hello Sarah Jessica Parker who Jessica lives in New York") + + def simplify(output): + for i in range(len(output)): + output[i]["score"] = round(output[i]["score"], 3) + return output + + output = simplify(output) + + self.assertEqual( + output, + [ + { + "entity_group": "PER", + "score": 0.996, + "word": "Sarah Jessica Parker", + "start": 6, + "end": 26, + }, + {"entity_group": "PER", "score": 0.977, "word": "Jessica", "start": 31, "end": 38}, + {"entity_group": "LOC", "score": 0.999, "word": "New York", "start": 48, "end": 56}, + ], + ) + @require_torch def test_pt_small_ignore_subwords_available_for_fast_tokenizers(self): for model_name in self.small_models: