Conversation

@cceyda (Contributor) commented Jul 22, 2020

There are many issues with the NER pipeline when using grouped_entities=True:
#5077
#4816
#5730
#5609
#6514
#5541

  • [Bug Fix] add an option ignore_subwords to ignore subsequent ##wordpieces in predictions. Some models train on only the first token of a word and not on the subsequent wordpieces (the BERT NER default), so it makes sense to do the same thing at inference time.

    • The simplest fix is to just group the subwords with the first wordpiece (see the sketch after this list).
      • [TODO] How to handle ignored scores? Just set them to 0 and calculate a zero-invariant mean?
      • [TODO] Handle a wordpiece_prefix other than ##? Possible approaches:
        get it from the tokenizer? But currently most tokenizers don't have a wordpiece_prefix property.
        have an _is_subword(token) helper.
  • [Feature add] added an option to skip_special_tokens, because it is harder to remove them after grouping.

  • [Additional Changes] remove B/I prefix on returned grouped_entities
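
A rough sketch of the grouping idea (illustrative only, not the pipeline's actual code; the nan/nanmean scoring follows the Edit notes below):

import numpy as np

def _is_subword(token, prefix="##"):
    # WordPiece continuation pieces carry a prefix ("##" for BERT-style vocabularies)
    return token.startswith(prefix)

def merge_subwords(tokens, scores, prefix="##"):
    # Glue the ##-pieces onto the first wordpiece; subword scores are masked
    # to nan so nanmean leaves the word's score to the first piece.
    word = tokens[0] + "".join(t[len(prefix):] for t in tokens[1:])
    masked = [np.nan if _is_subword(t, prefix) else s for t, s in zip(tokens, scores)]
    return word, float(np.nanmean(masked))

# merge_subwords(["Ost", "##er", "##lo", "##h"], [0.99, 0.2, 0.3, 0.1]) -> ("Osterloh", 0.99)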

Edit: Ignored subwords' scores are also ignored, by setting them to nan and using nanmean.
Edit: B entities of different types are separated (as per the BIO tag definition).
Edit: skip_special_tokens is now the default behavior.
Edit: ignore_subwords is now the default behavior.
Edit: more flexibility for custom non-standard tokenizers through tokenizer.is_subword_fn and tokenizer.convert_tokens_to_string.
Edit: [fix UNK-token-related bugs by mapping UNK tokens back to the correct original string] Use a fast tokenizer or pass offset_mapping.

Usage

pipeline('ner', model=model, tokenizer=tokenizer, ignore_labels=[], grouped_entities=True, ignore_subwords=True)
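
A minimal end-to-end sketch of the above (model_name is a placeholder for any token-classification checkpoint, not part of the original snippet):

from transformers import AutoModelForTokenClassification, AutoTokenizer, pipeline

model_name = "..."  # placeholder: any NER/token-classification checkpoint
model = AutoModelForTokenClassification.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)  # fast tokenizer handles [UNK] mapping
nlp = pipeline('ner', model=model, tokenizer=tokenizer, ignore_labels=[], grouped_entities=True, ignore_subwords=True)
print(nlp("My name is Alex and I live in New York"))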

Ceyda Cinarel added 3 commits July 22, 2020 19:35
… same type

	(B-type1 B-type1) != (B-type1 I-type1)
[Feature Request/TODO] Return indexes?
[Bug TODO] can't use a fast tokenizer with grouped_entities ('BertTokenizerFast' object has no attribute 'convert_tokens_to_string')
@enzoampil (Contributor) left a comment


Thanks for this @cceyda, this generally looks good to me, with just some comments around:

  1. I think we can hard-code skip_special_tokens=True, similar to the approach in TextGenerationPipeline.
  2. I'm wondering why we shouldn't group B-type entities? Hoping you can elaborate your rationale here.
  3. Please make sure to add the test cases to NerPipelineTests in test_pipelines. These should be similar to the test cases referenced in this PR (ones that used to fail, but now pass with these changes).
  4. Please update the tests to account for these changes. The current failures seem to be due to the removal of the prefixes in "entity type".

Thanks!

@HHoofs commented Aug 3, 2020

I'm wondering whether the B/I part should be separated from the entity-type part? In the sense that you average over the entity types (disregarding the B/I part) and vice versa. I now have the feeling that only the first subtoken decides whether the complete word is a B or an I.

@cceyda (Contributor, Author) commented Sep 1, 2020

I want to complete this but ran into another issue while working on it:

All [UNK] tokens get mapped to [UNK] in the output instead of the actual input token (because the code goes from ids to tokens). Also, [UNK]s get lost when using skip_special_tokens (#6863).
This is a simple token-alignment issue and can be solved by using offset_mapping, but offset_mapping is only available with fast tokenizers. I'm wondering what a more general approach to solving this would be?
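
For reference, a minimal sketch of how offset_mapping recovers the original surface form (fast tokenizers only; bert-base-cased here is just an example checkpoint):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased", use_fast=True)
sentence = "My name is Alex"
enc = tokenizer(sentence, return_offsets_mapping=True)
# Each (start, end) span indexes into the original string, so even a token the
# vocab can only render as [UNK] maps back to its actual input text.
words = [sentence[start:end] for start, end in enc["offset_mapping"]]
# Special tokens like [CLS]/[SEP] get (0, 0) spans, i.e. empty strings.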

@Monique497 commented Sep 13, 2020

Dear @cceyda,

In the last couple of days I started to work with Hugging Face's transformers, and especially NER classification. I ran into issues that have been previously addressed in the other issues you mentioned at the beginning, especially that subtokens classified as 'O' were not properly merged with the full token.

For example (Dutch):
sentence = "Als we Volkswagens OR-voorzitter Bernd Osterloh moeten geloven, dan moet dat binnen drie jaar het geval zijn."

Gives me these grouped entities:
[{'entity_group': 'B-per', 'score': 0.9999980926513672, 'word': 'Bern'},
{'entity_group': 'I-per', 'score': 0.9999990463256836, 'word': 'Ost'}]

I expect:
[{'entity_group': 'B-per', 'score': 0.9999980926513672, 'word': 'Bernd'},
{'entity_group': 'I-per', 'score': 0.9999990463256836, 'word': 'Osterloh'}]

However, the subtokens in question are classified as 'O':

{'word': '[CLS]', 'score': 0.9999999403953552, 'entity': 'O', 'index': 0}
{'word': 'Als', 'score': 0.9999999403953552, 'entity': 'O', 'index': 1}
{'word': 'we', 'score': 0.9999999403953552, 'entity': 'O', 'index': 2}
{'word': 'Volkswagen', 'score': 0.9999955296516418, 'entity': 'B-misc', 'index': 3}
{'word': '##s', 'score': 0.9999999403953552, 'entity': 'O', 'index': 4}
{'word': 'O', 'score': 0.9981945157051086, 'entity': 'I-misc', 'index': 5}
{'word': '##R', 'score': 0.9999998807907104, 'entity': 'O', 'index': 6}
{'word': '-', 'score': 0.9999999403953552, 'entity': 'O', 'index': 7}
{'word': 'voorzitter', 'score': 0.9999998807907104, 'entity': 'O', 'index': 8}
{'word': 'Bern', 'score': 0.9999980926513672, 'entity': 'B-per', 'index': 9}
{'word': '##d', 'score': 0.9999998807907104, 'entity': 'O', 'index': 10}
{'word': 'Ost', 'score': 0.9999990463256836, 'entity': 'I-per', 'index': 11}
{'word': '##er', 'score': 0.9999998807907104, 'entity': 'O', 'index': 12}
{'word': '##lo', 'score': 0.9999997615814209, 'entity': 'O', 'index': 13}
{'word': '##h', 'score': 0.9999998807907104, 'entity': 'O', 'index': 14}

{'word': 'moeten', 'score': 0.9999999403953552, 'entity': 'O', 'index': 15}
{'word': 'geloven', 'score': 0.9999998807907104, 'entity': 'O', 'index': 16}
{'word': ',', 'score': 0.9999999403953552, 'entity': 'O', 'index': 17}
{'word': 'dan', 'score': 0.9999999403953552, 'entity': 'O', 'index': 18}
{'word': 'moet', 'score': 0.9999999403953552, 'entity': 'O', 'index': 19}
{'word': 'dat', 'score': 0.9999999403953552, 'entity': 'O', 'index': 20}
{'word': 'binnen', 'score': 0.9999999403953552, 'entity': 'O', 'index': 21}
{'word': 'drie', 'score': 0.9999999403953552, 'entity': 'O', 'index': 22}
{'word': 'jaar', 'score': 0.9999999403953552, 'entity': 'O', 'index': 23}
{'word': 'het', 'score': 0.9999999403953552, 'entity': 'O', 'index': 24}
{'word': 'geval', 'score': 0.9999999403953552, 'entity': 'O', 'index': 25}
{'word': 'zijn', 'score': 0.9999999403953552, 'entity': 'O', 'index': 26}
{'word': '.', 'score': 0.9999999403953552, 'entity': 'O', 'index': 27}
{'word': '[SEP]', 'score': 0.9999999403953552, 'entity': 'O', 'index': 28}

I believe your pull request addresses these issues properly.
However, I saw the merge did not complete since some checks failed.

I was wondering whether there is still the intention to solve these issues.

Disclaimer: I am a total newbie to git (just set up an account), so please be mild, haha.
Any help is much appreciated!

Thank you in advance,

Monique

@enzoampil (Contributor):
@cceyda I actually want this PR to move forward. Are you okay collaborating on your fork (you can add me as a collaborator)? I can help out with some of the failing issues so we can get this merged 😄

@cceyda (Contributor, Author) commented Sep 14, 2020

@enzoampil I have added you as a collaborator.
Also pushed some additional changes addressing the [UNK] token mapping problem I mentioned before.
There are still some things I'm not very satisfied with:

  1. The subword prefix was fixed to '##' before. With the latest change I added a check for whether the tokenizer has an is_subword_fn defined (I still don't like handling it this way). I know some tokenizers have subword_prefix, but most don't, and this was the most flexible solution for now.
  2. offset_mapping is needed to resolve [UNK] tokens, but is only available with fast tokenizers. Fast tokenizers don't have convert_ids_to_tokens, so I had to implement a hacky solution for those as well.
  3. skip_special_tokens also dropped [UNK] tokens, so I had to change things and rely on special_tokens_mask.

It is not optimal, but it worked for my use cases.
Haven't had a chance to look at the failing tests yet :/

@cceyda (Contributor, Author) commented Sep 16, 2020

I have changed the ignore_subwords default to True, which covers cases like:

[
{'word': 'Cons', 'score': 0.9994944930076599, 'entity': 'B-PER', 'index': 1},
{'word': '##uelo', 'score': 0.802545428276062, 'entity': 'B-PER', 'index': 2}
]
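
Under the new default, these two pieces are merged into a single group (the subword's score is ignored per the nanmean rule), giving something like:

[
{'entity_group': 'PER', 'score': 0.9994944930076599, 'word': 'Consuelo'}
]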

And honestly, I don't know why subwords shouldn't be ignored in most cases. (Unless there is a need for some custom logic that determines a word's tag, e.g. by averaging the wordpieces, in which case grouped_entities shouldn't be used 🤔)
IMO, mid-word inconsistencies made by the model while ignore_subwords=False shouldn't affect the pipeline's output logic.

[todo]

  • torch tests are passing for now, but probably should add more cases? (I can't see why the TF tests are failing, though; I don't have a dev env for that)
  • should add the new parameters to the docstrings.

@codecov bot commented Sep 21, 2020

Codecov Report

Merging #5970 into master will increase coverage by 26.30%.
The diff coverage is 71.87%.


@@             Coverage Diff             @@
##           master    #5970       +/-   ##
===========================================
+ Coverage   52.05%   78.36%   +26.30%     
===========================================
  Files         236      168       -68     
  Lines       43336    32338    -10998     
===========================================
+ Hits        22560    25341     +2781     
+ Misses      20776     6997    -13779     
Impacted Files Coverage Δ
src/transformers/pipelines.py 80.59% <71.87%> (+61.46%) ⬆️
src/transformers/modeling_tf_xlm.py 18.94% <0.00%> (-60.01%) ⬇️
src/transformers/modeling_tf_flaubert.py 24.53% <0.00%> (-56.29%) ⬇️
src/transformers/tokenization_camembert.py 37.03% <0.00%> (-29.23%) ⬇️
src/transformers/modeling_tf_gpt2.py 71.84% <0.00%> (-11.00%) ⬇️
src/transformers/data/__init__.py 100.00% <0.00%> (ø)
src/transformers/modeling_mmbt.py 23.47% <0.00%> (ø)
src/transformers/modeling_mbart.py 100.00% <0.00%> (ø)
src/transformers/modeling_outputs.py 100.00% <0.00%> (ø)
src/transformers/modeling_pegasus.py 100.00% <0.00%> (ø)
... and 217 more

Legend:
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 7087d9b...47797d1.

@Monique497:
Dear @cceyda,

Over the last two days I worked with your branch to see how it performs on my own input texts.
However, I came across the following issue, which I would like to point out to you:

When I use the following line of code (as you suggest under 'Usage' above):

pipeline('ner', model=model, tokenizer=tokenizer, ignore_labels=[], grouped_entities=True, skip_special_tokens=True, ignore_subwords=True)

I get the error:

TypeError: __init__() got an unexpected keyword argument 'skip_special_tokens'.

When looking in the file transformers.pipelines, and specifically at TokenClassificationPipeline, it seems that it is not yet implemented. Or am I missing something?

Best,

Monique

@cceyda (Contributor, Author) commented Sep 28, 2020

@Monique497 sorry for the delay.
A couple of things have changed since I first wrote that example:

  • special tokens ([CLS], [PAD], [SEP]) are always skipped (per the comments above), so you don't need that kwarg. This also applies to grouped_entities=False:
from transformers import (
    AutoModelForTokenClassification,
    AutoTokenizer,
    pipeline,
)

model_name = "..."  # placeholder: any token-classification checkpoint
model = AutoModelForTokenClassification.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)  # note the fast tokenizer use
# ignore_subwords=True by default
nlp = pipeline("ner", model=model, tokenizer=tokenizer, grouped_entities=True)
inputs = "test sentence"
output = nlp(inputs)
  • Another important thing: you have to use a fast tokenizer OR pass offset_mapping as a parameter, because [UNK] token resolution depends on it (maybe I should rename this to offset_mappings). This also applies to grouped_entities=False:
# you can pass it like this
nlp(inputs, offset_mapping=mappings_you_calculate)
  • If you are using a custom tokenizer that treats subwords differently (i.e. not starting with '##'), you can pass functions implementing your custom logic through tokenizer.is_subword_fn and tokenizer.convert_tokens_to_string.
    I don't know if this is the best way to handle non-standard tokenizations, but I use some custom non-standard tokenizers for Korean and this solution gave me enough flexibility.

something like this:

def sub_fn(token):
    return token.startswith("%%")  # note: str.startswith, not starts_with

tokenizer.is_subword_fn = sub_fn

def convert_tokens_to_string(tokens):  # no `self`: assigned to the instance, so it is not bound
    out_string = " ".join(tokens).replace(" %%", "").strip()
    return out_string

tokenizer.convert_tokens_to_string = convert_tokens_to_string
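
A quick sanity check of those hooks (hypothetical %%-prefixed tokens):

print(sub_fn("%%lo"))                             # True
print(convert_tokens_to_string(["hel", "%%lo"]))  # "hello"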

@enzoampil what are your thoughts on this?


for ungrouped_input, grouped_result in zip(ungrouped_ner_inputs, expected_grouped_ner_results):
self.assertEqual(nlp.group_entities(ungrouped_input), grouped_result)
if nlp.grouped_entities:
@cceyda (Contributor, Author) commented:

Conditioned so that grouped_entities=False tests won't fail because of grouped_entities=True.

Reply (Contributor):

LGTM

for ungrouped_input, grouped_result in zip(ungrouped_ner_inputs, expected_grouped_ner_results):
self.assertEqual(nlp.group_entities(ungrouped_input), grouped_result)
if nlp.grouped_entities:
if nlp.ignore_subwords:
@cceyda (Contributor, Author) commented:

Added cases for ignore_subwords=True and ignore_subwords=False.

@cceyda (Contributor, Author) left a comment

What else can I do to get this PR merged? It has been a while.

{"entity_group": "B-PER", "score": 0.9997273534536362, "word": "Andrés Pastrana"},
{"entity_group": "B-ORG", "score": 0.8589080572128296, "word": "Farc"},
{"entity_group": "PER", "score": 0.999369223912557, "word": "Consuelo Araújo Noguera"},
{"entity_group": "PER", "score": 0.9997771680355072, "word": "Andrés Pastrana"},
A reviewer commented:

A test for hyphenated names (e.g. Juanita Gomez-Cortez) would be useful, especially given that the fast and slow tokenizers have different code paths for reconstructing the original text. I had to implement grouping of named entities myself recently and was tripped up by that corner case.

corresponding token in the sentence.
"""
inputs = self._args_parser(*args, **kwargs)

@cceyda (Contributor, Author) commented:

Changed the is_subword detection logic: it now compares the length of the token (##token) with the original text-span mapping (assuming subword pieces get prefixed by something).
In case the user wants some other logic, they can first get the ungrouped entities, add an is_subword: bool field to each entity, and call pipeline.group_entities themselves.

word_ref = sentence[start_ind:end_ind]
word = self.tokenizer.convert_ids_to_tokens([int(input_ids[idx])])[0]
is_subword = len(word_ref) != len(word)
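For illustration, with hypothetical values from the Dutch example above:

word_ref = "d"    # original text span recovered via the offset mapping
word = "##d"      # token string recovered from the input id
is_subword = len(word_ref) != len(word)  # True: the "##" prefix inflates the token length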

@cceyda (Contributor, Author) commented Oct 26, 2020:

Fixed as suggested! I agree it is much cleaner this way. Umm, it looks like fast tokenizers have a convert_tokens_to_string method now? 😕

task identifier: :obj:`"sentiment-analysis"` (for classifying sequences according to positive or negative
sentiments).
If multiple classification labels are available (:obj:`model.config.num_labels >= 2`), the pipeline will run
@cceyda (Contributor, Author) commented:

Added this to check offset_mapping if provided (it does a simple batch_size check).

for model_name in self.small_models:
nlp = pipeline(
task="ner", model=model_name, tokenizer=model_name, grouped_entities=True, ignore_subwords=False
task="ner", model=model_name, tokenizer=tokenizer, grouped_entities=True, ignore_subwords=True
@cceyda (Contributor, Author) commented Oct 26, 2020:

Don't know why the tests are failing, I get normal results when running them outside of the test suite 😕? I was just being careless 🤦. Should still add cases for non-fast tokenizers.

@LysandreJik (Member):

Thanks for iterating! I'll check this today.

@LysandreJik (Member) left a comment

Played around with it, works well, and the implementation seems robust. Looks good, LGTM! Thanks for iterating.

@stefan-it, do you want to take a quick look?

@LysandreJik (Member):

Merging this as soon as it's green, thank you for iterating on the PR! Sorry this took so long to merge.

@LysandreJik LysandreJik merged commit 29b536a into huggingface:master Nov 3, 2020
@enzoampil (Contributor):

Thanks @LysandreJik and congrats @cceyda !! 😄

@LysandreJik (Member) commented Nov 6, 2020

FYI this broke the NER pipeline:

from transformers import pipeline

nlp = pipeline("ner")

nlp("My name is Alex and I live in New York")

crashes with the following error:

    raise Exception("To decode [UNK] tokens use a fast tokenizer or provide offset_mapping parameter")
Exception: To decode [UNK] tokens use a fast tokenizer or provide offset_mapping parameter

Trying to see if this can be quickly patched, otherwise we'll revert the PR while we patch this.

@cceyda (Contributor, Author) commented Nov 6, 2020

Oops! Although returning [UNK] tokens with slow tokenizers isn't ideal, I agree that not forcing a fast tokenizer with a default of ignore_subwords=True is better for keeping compatibility. I saw a bit late that the _args_parser line was mis-merged when this PR was merged, and I see it is fixed/improved in the patch. I wasn't sure how to test for the offset_mapping argument with the new test structure (which looks good in the patch). Sorry for the trouble 😅 @LysandreJik

@LysandreJik (Member):
No worries, thanks for taking a look at the patch!

@cceyda mentioned this pull request Apr 6, 2021