Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 7 additions & 8 deletions src/transformers/pipelines/token_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def __call__(self, inputs: Union[str, List[str]], **kwargs):
"end": end_ind,
}

if self.grouped_entities or self.subword_label_re_alignment:
if self.grouped_entities and (self.subword_label_re_alignment or self.ignore_subwords):
entity["is_subword"] = is_subword

entities += [entity]
Expand All @@ -220,10 +220,10 @@ def __call__(self, inputs: Union[str, List[str]], **kwargs):
label_idx = entity["score"].argmax()
label = self.model.config.id2label[label_idx]
entity["entity"] = label
entity["score"] = entity["score"][label_idx]
entity["score"] = entity["score"][label_idx].item()

# I think we should check self.subword_label_re_alignment here too
# because we can't use self.grouped_entities if self.subword_label_re_alignment is false
# filter out ignored labels
entities = [entity for entity in entities if entity["entity"] not in self.ignore_labels]
if self.grouped_entities:
answers += [self.group_entities(entities)]
# Append ungrouped entities
Expand Down Expand Up @@ -258,7 +258,7 @@ def sub_words_label(sub_words: List[dict]) -> dict:

for idx, sub in enumerate(sub_words):
sub["entity"] = label
sub["score"] = score[idx][max_label_idx]
sub["score"] = score[idx][max_label_idx].item()

return sub_words

Expand All @@ -278,7 +278,7 @@ def sub_words_label(sub_words: List[dict]) -> dict:
label_idx = begin_sub["score"].argmax()
label = self.model.config.id2label[label_idx]
begin_sub["entity"] = label
begin_sub["score"] = begin_sub["score"][label_idx]
begin_sub["score"] = begin_sub["score"][label_idx].item()

entities_with_label.append(begin_sub)
word_group_disagg = []
Expand Down Expand Up @@ -314,7 +314,6 @@ def group_entities(self, entities: List[dict]) -> List[dict]:
Args:
entities (:obj:`dict`): The entities predicted by the pipeline.
"""

entity_groups = []
entity_group_disagg = []

Expand All @@ -324,7 +323,7 @@ def group_entities(self, entities: List[dict]) -> List[dict]:
for entity in entities:

is_last_idx = entity["index"] == last_idx
is_subword = self.ignore_subwords and entity["is_subword"]
is_subword = (self.ignore_subwords or self.subword_label_re_alignment) and entity["is_subword"]
if not entity_group_disagg:
entity_group_disagg += [entity]
if is_last_idx:
Expand Down