Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion src/transformers/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -1420,9 +1420,14 @@ def __call__(self, inputs: Union[str, List[str]], **kwargs):

- **word** (:obj:`str`) -- The token/word classified.
- **score** (:obj:`float`) -- The corresponding probability for :obj:`entity`.
- **entity** (:obj:`str`) -- The entity predicted for that token/word.
- **entity** (:obj:`str`) -- The entity predicted for that token/word (it is named `entity_group` when
`grouped_entities` is set to True.
- **index** (:obj:`int`, only present when ``self.grouped_entities=False``) -- The index of the
corresponding token in the sentence.
- **start** (:obj:`int`, `optional`) -- The index of the start of the corresponding entity in the sentence.
Only exists if the offsets are available within the tokenizer
- **end** (:obj:`int`, `optional`) -- The index of the end of the corresponding entity in the sentence.
Only exists if the offsets are available within the tokenizer
"""

inputs, offset_mappings = self._args_parser(inputs, **kwargs)
Expand Down Expand Up @@ -1486,11 +1491,16 @@ def __call__(self, inputs: Union[str, List[str]], **kwargs):
else:
word = self.tokenizer.convert_ids_to_tokens(int(input_ids[idx]))

start_ind = None
end_ind = None

entity = {
"word": word,
"score": score[idx][label_idx].item(),
"entity": self.model.config.id2label[label_idx],
"index": idx,
"start": start_ind,
"end": end_ind,
}

if self.grouped_entities and self.ignore_subwords:
Expand Down Expand Up @@ -1524,6 +1534,8 @@ def group_sub_entities(self, entities: List[dict]) -> dict:
"entity_group": entity,
"score": np.mean(scores),
"word": self.tokenizer.convert_tokens_to_string(tokens),
"start": entities[0]["start"],
"end": entities[-1]["end"],
}
return entity_group

Expand Down
240 changes: 210 additions & 30 deletions tests/test_pipelines_ner.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from transformers import AutoTokenizer, pipeline
from transformers.pipelines import Pipeline, TokenClassificationArgumentHandler
from transformers.testing_utils import require_tf, require_torch
from transformers.testing_utils import require_tf, require_torch, slow

from .test_pipelines_common import CustomInputPipelineCommonMixin

Expand All @@ -18,55 +18,207 @@ class NerPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase):
large_models = [] # Models tested with the @slow decorator

def _test_pipeline(self, nlp: Pipeline):
output_keys = {"entity", "word", "score"}
output_keys = {"entity", "word", "score", "start", "end"}
if nlp.grouped_entities:
output_keys = {"entity_group", "word", "score"}
output_keys = {"entity_group", "word", "score", "start", "end"}

ungrouped_ner_inputs = [
[
{"entity": "B-PER", "index": 1, "score": 0.9994944930076599, "is_subword": False, "word": "Cons"},
{"entity": "B-PER", "index": 2, "score": 0.8025449514389038, "is_subword": True, "word": "##uelo"},
{"entity": "I-PER", "index": 3, "score": 0.9993102550506592, "is_subword": False, "word": "Ara"},
{"entity": "I-PER", "index": 4, "score": 0.9993743896484375, "is_subword": True, "word": "##új"},
{"entity": "I-PER", "index": 5, "score": 0.9992871880531311, "is_subword": True, "word": "##o"},
{"entity": "I-PER", "index": 6, "score": 0.9993029236793518, "is_subword": False, "word": "No"},
{"entity": "I-PER", "index": 7, "score": 0.9981776475906372, "is_subword": True, "word": "##guera"},
{"entity": "B-PER", "index": 15, "score": 0.9998136162757874, "is_subword": False, "word": "Andrés"},
{"entity": "I-PER", "index": 16, "score": 0.999740719795227, "is_subword": False, "word": "Pas"},
{"entity": "I-PER", "index": 17, "score": 0.9997414350509644, "is_subword": True, "word": "##tran"},
{"entity": "I-PER", "index": 18, "score": 0.9996136426925659, "is_subword": True, "word": "##a"},
{"entity": "B-ORG", "index": 28, "score": 0.9989739060401917, "is_subword": False, "word": "Far"},
{"entity": "I-ORG", "index": 29, "score": 0.7188422083854675, "is_subword": True, "word": "##c"},
{
"entity": "B-PER",
"index": 1,
"score": 0.9994944930076599,
"is_subword": False,
"word": "Cons",
"start": 0,
"end": 4,
},
{
"entity": "B-PER",
"index": 2,
"score": 0.8025449514389038,
"is_subword": True,
"word": "##uelo",
"start": 4,
"end": 8,
},
{
"entity": "I-PER",
"index": 3,
"score": 0.9993102550506592,
"is_subword": False,
"word": "Ara",
"start": 9,
"end": 11,
},
{
"entity": "I-PER",
"index": 4,
"score": 0.9993743896484375,
"is_subword": True,
"word": "##új",
"start": 11,
"end": 13,
},
{
"entity": "I-PER",
"index": 5,
"score": 0.9992871880531311,
"is_subword": True,
"word": "##o",
"start": 13,
"end": 14,
},
{
"entity": "I-PER",
"index": 6,
"score": 0.9993029236793518,
"is_subword": False,
"word": "No",
"start": 15,
"end": 17,
},
{
"entity": "I-PER",
"index": 7,
"score": 0.9981776475906372,
"is_subword": True,
"word": "##guera",
"start": 17,
"end": 22,
},
{
"entity": "B-PER",
"index": 15,
"score": 0.9998136162757874,
"is_subword": False,
"word": "Andrés",
"start": 23,
"end": 28,
},
{
"entity": "I-PER",
"index": 16,
"score": 0.999740719795227,
"is_subword": False,
"word": "Pas",
"start": 29,
"end": 32,
},
{
"entity": "I-PER",
"index": 17,
"score": 0.9997414350509644,
"is_subword": True,
"word": "##tran",
"start": 32,
"end": 36,
},
{
"entity": "I-PER",
"index": 18,
"score": 0.9996136426925659,
"is_subword": True,
"word": "##a",
"start": 36,
"end": 37,
},
{
"entity": "B-ORG",
"index": 28,
"score": 0.9989739060401917,
"is_subword": False,
"word": "Far",
"start": 39,
"end": 42,
},
{
"entity": "I-ORG",
"index": 29,
"score": 0.7188422083854675,
"is_subword": True,
"word": "##c",
"start": 42,
"end": 43,
},
],
[
{"entity": "I-PER", "index": 1, "score": 0.9968166351318359, "is_subword": False, "word": "En"},
{"entity": "I-PER", "index": 2, "score": 0.9957635998725891, "is_subword": True, "word": "##zo"},
{"entity": "I-ORG", "index": 7, "score": 0.9986497163772583, "is_subword": False, "word": "UN"},
{
"entity": "I-PER",
"index": 1,
"score": 0.9968166351318359,
"is_subword": False,
"word": "En",
"start": 0,
"end": 2,
},
{
"entity": "I-PER",
"index": 2,
"score": 0.9957635998725891,
"is_subword": True,
"word": "##zo",
"start": 2,
"end": 4,
},
{
"entity": "I-ORG",
"index": 7,
"score": 0.9986497163772583,
"is_subword": False,
"word": "UN",
"start": 11,
"end": 13,
},
],
]

expected_grouped_ner_results = [
[
{"entity_group": "PER", "score": 0.999369223912557, "word": "Consuelo Araújo Noguera"},
{"entity_group": "PER", "score": 0.9997771680355072, "word": "Andrés Pastrana"},
{"entity_group": "ORG", "score": 0.9989739060401917, "word": "Farc"},
{
"entity_group": "PER",
"score": 0.999369223912557,
"word": "Consuelo Araújo Noguera",
"start": 0,
"end": 22,
},
{
"entity_group": "PER",
"score": 0.9997771680355072,
"word": "Andrés Pastrana",
"start": 23,
"end": 37,
},
{"entity_group": "ORG", "score": 0.9989739060401917, "word": "Farc", "start": 39, "end": 43},
],
[
{"entity_group": "PER", "score": 0.9968166351318359, "word": "Enzo"},
{"entity_group": "ORG", "score": 0.9986497163772583, "word": "UN"},
{"entity_group": "PER", "score": 0.9968166351318359, "word": "Enzo", "start": 0, "end": 4},
{"entity_group": "ORG", "score": 0.9986497163772583, "word": "UN", "start": 11, "end": 13},
],
]

expected_grouped_ner_results_w_subword = [
[
{"entity_group": "PER", "score": 0.9994944930076599, "word": "Cons"},
{"entity_group": "PER", "score": 0.9663328925768534, "word": "##uelo Araújo Noguera"},
{"entity_group": "PER", "score": 0.9997273534536362, "word": "Andrés Pastrana"},
{"entity_group": "ORG", "score": 0.8589080572128296, "word": "Farc"},
{"entity_group": "PER", "score": 0.9994944930076599, "word": "Cons", "start": 0, "end": 4},
{
"entity_group": "PER",
"score": 0.9663328925768534,
"word": "##uelo Araújo Noguera",
"start": 4,
"end": 22,
},
{
"entity_group": "PER",
"score": 0.9997273534536362,
"word": "Andrés Pastrana",
"start": 23,
"end": 37,
},
{"entity_group": "ORG", "score": 0.8589080572128296, "word": "Farc", "start": 39, "end": 43},
],
[
{"entity_group": "PER", "score": 0.9962901175022125, "word": "Enzo"},
{"entity_group": "ORG", "score": 0.9986497163772583, "word": "UN"},
{"entity_group": "PER", "score": 0.9962901175022125, "word": "Enzo", "start": 0, "end": 4},
{"entity_group": "ORG", "score": 0.9986497163772583, "word": "UN", "start": 11, "end": 13},
],
]

Expand Down Expand Up @@ -164,6 +316,34 @@ def test_pt_defaults(self):
nlp = pipeline(task="ner", model=model_name)
self._test_pipeline(nlp)

@slow
@require_torch
def test_simple(self):
nlp = pipeline(task="ner", model="dslim/bert-base-NER", grouped_entities=True)
output = nlp("Hello Sarah Jessica Parker who Jessica lives in New York")

def simplify(output):
for i in range(len(output)):
output[i]["score"] = round(output[i]["score"], 3)
return output

output = simplify(output)

self.assertEqual(
output,
[
{
"entity_group": "PER",
"score": 0.996,
"word": "Sarah Jessica Parker",
"start": 6,
"end": 26,
},
{"entity_group": "PER", "score": 0.977, "word": "Jessica", "start": 31, "end": 38},
{"entity_group": "LOC", "score": 0.999, "word": "New York", "start": 48, "end": 56},
],
)

@require_torch
def test_pt_small_ignore_subwords_available_for_fast_tokenizers(self):
for model_name in self.small_models:
Expand Down