Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions src/transformers/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1824,6 +1824,17 @@ def _get_module(self, module_name: str) -> ModuleType:
raise NotImplementedError


class AggregationStrategy(ExplicitEnum):
    """
    Values accepted by the ``aggregation_strategy`` argument of :meth:`TokenClassificationPipeline.__init__`.

    Declared as an enum so that the valid choices can be tab-completed in an IDE.
    """

    # Label a word from the scores of its first sub-token only.
    FIRST = "first"
    # Label a word with the label that holds the single highest score across all of its sub-tokens.
    MAX = "max"
    # Average the scores over a word's sub-tokens, then keep the best-scoring label.
    AVERAGE = "average"


def copy_func(f):
"""Returns a copy of a function f."""
# Based on http://stackoverflow.com/a/6528148/190597 (Glenn Maynard)
Expand Down
116 changes: 89 additions & 27 deletions src/transformers/pipelines/token_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import numpy as np

from ..file_utils import add_end_docstrings, is_tf_available, is_torch_available
from ..file_utils import AggregationStrategy, add_end_docstrings, is_tf_available, is_torch_available
from ..modelcard import ModelCard
from ..models.bert.tokenization_bert import BasicTokenizer
from ..tokenization_utils import PreTrainedTokenizer
Expand Down Expand Up @@ -85,6 +85,7 @@ def __init__(
ignore_labels=["O"],
task: str = "",
grouped_entities: bool = False,
aggregation_strategy: Union[str, AggregationStrategy] = AggregationStrategy.FIRST,
ignore_subwords: bool = False,
):
super().__init__(
Expand All @@ -107,6 +108,10 @@ def __init__(
self._args_parser = args_parser
self.ignore_labels = ignore_labels
self.grouped_entities = grouped_entities
if isinstance(aggregation_strategy, str):
self.aggregation_strategy = AggregationStrategy[aggregation_strategy.upper()]
else:
self.aggregation_strategy = aggregation_strategy
self.ignore_subwords = ignore_subwords

if self.ignore_subwords and not self.tokenizer.is_fast:
Expand Down Expand Up @@ -177,47 +182,46 @@ def __call__(self, inputs: Union[str, List[str]], **kwargs):
input_ids = tokens["input_ids"].cpu().numpy()[0]

score = np.exp(entities) / np.exp(entities).sum(-1, keepdims=True)
labels_idx = score.argmax(axis=-1)

entities = []
# Filter to labels not in `self.ignore_labels`
# Filter special_tokens
filtered_labels_idx = [
(idx, label_idx)
for idx, label_idx in enumerate(labels_idx)
if (self.model.config.id2label[label_idx] not in self.ignore_labels) and not special_tokens_mask[idx]
]
filtered_labels_idx = [idx for idx in range(score.shape[0]) if not special_tokens_mask[idx]]

for idx, label_idx in filtered_labels_idx:
for idx in filtered_labels_idx:
entity = {}
if offset_mapping is not None:
start_ind, end_ind = offset_mapping[idx]
entity["start"], entity["end"] = (start_ind, end_ind)
word_ref = sentence[start_ind:end_ind]
word = self.tokenizer.convert_ids_to_tokens([int(input_ids[idx])])[0]
is_subword = len(word_ref) != len(word)
entity["word"] = word
entity["is_subword"] = len(word_ref) != len(word)

if int(input_ids[idx]) == self.tokenizer.unk_token_id:
word = word_ref
is_subword = False
entity["word"] = word_ref
entity["is_subword"] = False
else:
word = self.tokenizer.convert_ids_to_tokens(int(input_ids[idx]))

start_ind = None
end_ind = None
entity["word"] = self.tokenizer.convert_ids_to_tokens(int(input_ids[idx]))

entity = {
"word": word,
"score": score[idx][label_idx].item(),
"entity": self.model.config.id2label[label_idx],
"index": idx,
"start": start_ind,
"end": end_ind,
}
entity["start"] = None
entity["end"] = None

if self.grouped_entities and self.ignore_subwords:
entity["is_subword"] = is_subword
entity["score"] = score[idx]
entity["index"] = idx

entities += [entity]

if self.ignore_subwords:
entities = self.set_subwords_label(entities)
else:
for entity in entities:
label_idx = entity["score"].argmax()
label = self.model.config.id2label[label_idx]
entity["entity"] = label
entity["score"] = entity["score"][label_idx].item()

# filter out ignored labels
entities = [entity for entity in entities if entity["entity"] not in self.ignore_labels]
if self.grouped_entities:
answers += [self.group_entities(entities)]
# Append ungrouped entities
Expand All @@ -228,6 +232,65 @@ def __call__(self, inputs: Union[str, List[str]], **kwargs):
return answers[0]
return answers

def set_subwords_label(self, entities: List[dict]) -> List[dict]:
    """
    Give every word a single label and score by aggregating the scores of its sub-tokens.

    An entity flagged ``is_subword`` is attached to the closest preceding entity that is not flagged; together
    they form one word. The scores of the word's tokens are combined according to
    ``self.aggregation_strategy`` and the resulting label and score are written to every token of the word.
    When ``self.ignore_subwords`` is set, each multi-token word is collapsed into its first token, whose
    ``"word"`` becomes the concatenation of all the pieces.

    Args:
        entities (:obj:`List[dict]`):
            Token-level entities in sentence order. Each one must provide ``"word"``, ``"is_subword"`` and
            ``"score"`` (an array with one score per label). The dictionaries are updated in place:
            ``"entity"`` is added and ``"score"`` is replaced by a float.

    Returns:
        :obj:`List[dict]`: The labelled entities, in the original order.

    Raises:
        ValueError: If ``self.aggregation_strategy`` is not a member of :class:`AggregationStrategy`.
    """
    strategy = self.aggregation_strategy
    id2label = self.model.config.id2label
    continuation_prefix = "##"

    def set_labels(sub_words: List[dict]) -> List[dict]:
        # One row per token of the word, one column per label.
        scores = np.stack([sub["score"] for sub in sub_words])
        if strategy == AggregationStrategy.FIRST:
            # The first token alone decides the label.
            max_label_idx = scores[0].argmax()
            score = scores[0][max_label_idx]
        elif strategy == AggregationStrategy.MAX:
            # Column (label) of the single highest score found in any token of the word.
            max_label_idx = np.unravel_index(np.argmax(scores, axis=None), scores.shape)[1]
            score = scores.max()
        elif strategy == AggregationStrategy.AVERAGE:
            avg_scores = np.mean(scores, axis=0)
            max_label_idx = avg_scores.argmax()
            score = avg_scores[max_label_idx]
        else:
            raise ValueError(f"Invalid value {strategy} for option `aggregation_strategy`")

        label = id2label[max_label_idx]
        score_value = score.item()
        for sub in sub_words:
            sub["entity"] = label
            sub["score"] = score_value
        return sub_words

    def strip_continuation(piece: str) -> str:
        # Only drop a leading marker; pieces without one are kept verbatim instead of raising IndexError.
        if piece.startswith(continuation_prefix):
            return piece[len(continuation_prefix) :]
        return piece

    # Split the token stream into words, e.g.
    # ['Sir', 'Test', '##y', 'M', '##c', '##T', '##est', 'is', 'test', '##iful']
    # --> [['Sir'], ['Test', '##y'], ['M', '##c', '##T', '##est'], ['is'], ['test', '##iful']]
    # A sub-token with no preceding token (first position) starts its own word rather than wrapping around
    # to the last entity through a negative index.
    words = []
    for entity in entities:
        if entity["is_subword"] and words:
            words[-1].append(entity)
        else:
            words.append([entity])

    entities_with_label = []
    for word_tokens in words:
        word_tokens = set_labels(word_tokens)
        if self.ignore_subwords and len(word_tokens) > 1:
            head = word_tokens[0]
            # NOTE(review): only "word" is merged; "start"/"end"/"index" stay those of the first token.
            head["word"] += "".join(strip_continuation(sub["word"]) for sub in word_tokens[1:])
            word_tokens = [head]
        entities_with_label += word_tokens

    return entities_with_label

def group_sub_entities(self, entities: List[dict]) -> dict:
"""
Group together the adjacent tokens with the same entity predicted.
Expand Down Expand Up @@ -256,7 +319,6 @@ def group_entities(self, entities: List[dict]) -> List[dict]:
Args:
entities (:obj:`dict`): The entities predicted by the pipeline.
"""

entity_groups = []
entity_group_disagg = []

Expand Down
197 changes: 197 additions & 0 deletions tests/fixtures/ner_pipeline_aligned.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
{
"first": [
{
"word": "THe",
"score": 0.929590046,
"index": 1,
"start": 0,
"end": 1,
"is_subword": false,
"entity": "O"
},
{
"word": "quasiamerican",
"score": 0.62320358,
"index": 4,
"start": 4,
"end": 9,
"is_subword": false,
"entity": "B-LOC"
},
{
"word": "Mark",
"score": 0.9833675,
"index": 8,
"start": 18,
"end": 22,
"is_subword": false,
"entity": "B-MISC"
},
{
"word": "Musterman",
"score": 0.998437107,
"index": 9,
"start": 23,
"end": 27,
"is_subword": false,
"entity": "I-MISC"
},
{
"word": "works",
"score": 0.99952888,
"index": 11,
"start": 33,
"end": 38,
"is_subword": false,
"entity": "O"
},
{
"word": "at",
"score": 0.99974287,
"index": 12,
"start": 39,
"end": 41,
"is_subword": false,
"entity": "O"
},
{
"word": "Madeuppity",
"score": 0.7194587,
"index": 13,
"start": 42,
"end": 46,
"is_subword": false,
"entity": "B-ORG"
}
],
"max": [
{
"word": "THe",
"score": 0.96568632,
"index": 1,
"start": 0,
"end": 1,
"is_subword": false,
"entity": "O"
},
{
"word": "quasiamerican",
"score": 0.89459532,
"index": 4,
"start": 4,
"end": 9,
"is_subword": false,
"entity": "I-LOC"
},
{
"word": "Mark",
"score": 0.9833675,
"index": 8,
"start": 18,
"end": 22,
"is_subword": false,
"entity": "B-MISC"
},
{
"word": "Musterman",
"score": 0.998437107,
"index": 9,
"start": 23,
"end": 27,
"is_subword": false,
"entity": "I-MISC"
},
{
"word": "works",
"score": 0.99952888,
"index": 11,
"start": 33,
"end": 38,
"is_subword": false,
"entity": "O"
},
{
"word": "at",
"score": 0.99974287,
"index": 12,
"start": 39,
"end": 41,
"is_subword": false,
"entity": "O"
},
{
"word": "Madeuppity",
"score": 0.7194587,
"index": 13,
"start": 42,
"end": 46,
"is_subword": false,
"entity": "B-ORG"
}
],
"average": [
{
"word": "THe",
"score": 0.821620602,
"index": 1,
"start": 0,
"end": 1,
"is_subword": false,
"entity": "O"
},
{
"word": "quasiamerican",
"score": 0.49670231500000006,
"index": 4,
"start": 4,
"end": 9,
"is_subword": false,
"entity": "I-LOC"
},
{
"word": "Mark",
"score": 0.9833675,
"index": 8,
"start": 18,
"end": 22,
"is_subword": false,
"entity": "B-MISC"
},
{
"word": "Musterman",
"score": 0.9982705734999999,
"index": 9,
"start": 23,
"end": 27,
"is_subword": false,
"entity": "I-MISC"
},
{
"word": "works",
"score": 0.99952888,
"index": 11,
"start": 33,
"end": 38,
"is_subword": false,
"entity": "O"
},
{
"word": "at",
"score": 0.99974287,
"index": 12,
"start": 39,
"end": 41,
"is_subword": false,
"entity": "O"
},
{
"word": "Madeuppity",
"score": 0.6232144095,
"index": 13,
"start": 42,
"end": 46,
"is_subword": false,
"entity": "B-ORG"
}
]
}
Loading