diff --git a/src/transformers/pipelines/token_classification.py b/src/transformers/pipelines/token_classification.py index b4f4428eb58c..2ce95185bb42 100644 --- a/src/transformers/pipelines/token_classification.py +++ b/src/transformers/pipelines/token_classification.py @@ -208,7 +208,7 @@ def __call__(self, inputs: Union[str, List[str]], **kwargs): "end": end_ind, } - if self.grouped_entities or self.subword_label_re_alignment: + if self.grouped_entities and (self.subword_label_re_alignment or self.ignore_subwords): entity["is_subword"] = is_subword entities += [entity] @@ -220,10 +220,10 @@ def __call__(self, inputs: Union[str, List[str]], **kwargs): label_idx = entity["score"].argmax() label = self.model.config.id2label[label_idx] entity["entity"] = label - entity["score"] = entity["score"][label_idx] + entity["score"] = entity["score"][label_idx].item() - # I think we should check self.subword_label_re_alignment here too - # because we can't use self.grouped_entities if self.subword_label_re_alignment is false + # filter out ignored labels + entities = [entity for entity in entities if entity["entity"] not in self.ignore_labels] if self.grouped_entities: answers += [self.group_entities(entities)] # Append ungrouped entities @@ -258,7 +258,7 @@ def sub_words_label(sub_words: List[dict]) -> dict: for idx, sub in enumerate(sub_words): sub["entity"] = label - sub["score"] = score[idx][max_label_idx] + sub["score"] = score[idx][max_label_idx].item() return sub_words @@ -278,7 +278,7 @@ def sub_words_label(sub_words: List[dict]) -> dict: label_idx = begin_sub["score"].argmax() label = self.model.config.id2label[label_idx] begin_sub["entity"] = label - begin_sub["score"] = begin_sub["score"][label_idx] + begin_sub["score"] = begin_sub["score"][label_idx].item() entities_with_label.append(begin_sub) word_group_disagg = [] @@ -314,7 +314,6 @@ def group_entities(self, entities: List[dict]) -> List[dict]: Args: entities (:obj:`dict`): The entities predicted by the pipeline. """ - entity_groups = [] entity_group_disagg = [] @@ -324,7 +323,7 @@ def group_entities(self, entities: List[dict]) -> List[dict]: for entity in entities: is_last_idx = entity["index"] == last_idx - is_subword = self.ignore_subwords and entity["is_subword"] + is_subword = (self.ignore_subwords or self.subword_label_re_alignment) and entity["is_subword"] if not entity_group_disagg: entity_group_disagg += [entity] if is_last_idx: