fix LayoutLMv3TokenizerFast subword label after 'Ġ' token #21695

Merged: sgugger merged 1 commit into huggingface:main from thibaultdouzon:layoutlmv3-bpe-empty-tokens-labels on Apr 3, 2023
@thibaultdouzon

Contributor

LayoutLMv3TokenizerFast produces an empty 'Ġ' token with offset_mapping = (0, 0).
The next token is wrongly assumed to also be the beginning of a word and isn't correctly assigned pad_token_label.
This may lead to misalignment of words and token representations.
Other BPE tokenizers might be affected.

Add a check on whether the previous token had an empty offset_mapping (not including special tokens).
Remove the copy check from LayoutLMv2TokenizerFast for _batch_encode_plus because it is not affected (it uses WordPiece instead of BPE).
Modify the test to use text that produces a 'Ġ' token.
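
The effect of the added check can be illustrated with a minimal, standalone sketch. This is not the code merged in this PR: the function name `align_labels`, the `check_previous_empty` flag, the `-100` padding label, and the per-word offsets below are all assumptions chosen for illustration.

```python
PAD_TOKEN_LABEL = -100  # assumed padding label, chosen for illustration


def align_labels(offsets, word_ids, word_labels, check_previous_empty=True):
    """Label only the first sub-token of each word (hypothetical helper)."""
    labels = []
    previous_token_empty = False
    for offset, word_id in zip(offsets, word_ids):
        if word_id is None:
            # Special tokens never receive a word label.
            labels.append(PAD_TOKEN_LABEL)
            previous_token_empty = False
            continue
        starts_word = offset[0] == 0
        if check_previous_empty and previous_token_empty:
            # The previous token was an empty 'Ġ', so this one continues the word.
            starts_word = False
        labels.append(word_labels[word_id] if starts_word else PAD_TOKEN_LABEL)
        previous_token_empty = offset == (0, 0)
    return labels


# Illustrative encoding of: <s>, Ġpencil, Ġ, 0000000000000000, Ġphone, </s>
offsets = [(0, 0), (0, 6), (0, 0), (0, 16), (0, 5), (0, 0)]
word_ids = [None, 0, 1, 1, 2, None]
word_labels = [1, 2, 3]

without_check = align_labels(offsets, word_ids, word_labels, check_previous_empty=False)
with_check = align_labels(offsets, word_ids, word_labels)
print(without_check)  # the token after 'Ġ' is labeled a second time
print(with_check)     # the token after 'Ġ' gets the padding label
```

In this sketch, disabling the check reproduces the reported symptom: both 'Ġ' and the token that follows it receive the label of the same word.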

Fixes issue: #19978

@NielsRogge
@ArthurZucker

LayoutLMv3TokenizerFast produces empty 'Ġ' token with `offset_mapping = (0, 0)`.
Next token is wrongly assumed to also be beginning of word and isn't
correctly assigned `pad_token_label`.
Modify test with text that produces 'Ġ' token.
Remove copy check from LayoutLMv2TokenizerFast for `_batch_encode_plus`.

solves issue: huggingface#19978
@thibaultdouzon thibaultdouzon marked this pull request as draft February 19, 2023 20:07
@thibaultdouzon thibaultdouzon marked this pull request as ready for review February 19, 2023 20:09
@HuggingFaceDocBuilderDev


The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@huggingface huggingface deleted a comment from github-actions Bot Mar 22, 2023
@sgugger

sgugger commented Mar 27, 2023

Collaborator

Also cc @amyeroberts

@ArthurZucker ArthurZucker left a comment

Collaborator

Hey @thibaultdouzon, thanks a lot for working on this! I am having a look to check whether there might be a problem with the Rust BPE (it may just be an edge case!). In the meantime, I think that this is great.
Edit: here are the results of my investigation:

  1. The BPE is behaving as expected:
    Let the input words be `['pencil', '0000000000000000', 'phone']`. The pretokenization will result in `[('Ġpencil', (0, 6)), ('Ġ0000000000000000', (6, 23)), ('Ġphone', (23, 29))]` because the tokenizer has `add_prefix_space` set to `True`. Then, each of these sequences is given to the actual BPE model (for example, `tokenizer.model.tokenize('Ġ0000000000000000')`), which produces `Ġpencil`, `Ġ`, `0000000000000000`, `Ġphone`. Why? Because `Ġpencil` is part of the vocabulary (`vocab.json`) from the merge of `Ġpen` and `cil` (in the `merges.txt`). However, `Ġ0000000000000000` is not part of the vocabulary, but `0000000000000000` is (idx is 49393).
  2. The post-processor is not supposed to remove the added space. It is only when decoding that we remove what was added by BPE. Note that the space is only added for beginning tokens. However, there is a problem with Rust here: when decoding, the extra space is still there, and the output for '0000000000000000' will be ' 0000000000000000'.
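
To make point 1 concrete, here is a toy BPE merge loop. It is only a sketch: the merge table is hypothetical and far smaller than the real one, the word is shortened to four zeros, and the function `toy_bpe` is not part of `tokenizers` or `transformers`. It shows how a prefix 'Ġ' is left as its own token when no merge rule joins it to what follows.

```python
def toy_bpe(word, merges):
    """Apply ranked merge rules to a word, lowest rank first (toy sketch)."""
    ranks = {pair: rank for rank, pair in enumerate(merges)}
    symbols = list(word)
    while len(symbols) > 1:
        pairs = [(symbols[i], symbols[i + 1]) for i in range(len(symbols) - 1)]
        best = min(pairs, key=lambda pair: ranks.get(pair, float("inf")))
        if best not in ranks:
            break  # no applicable merge rule is left
        merged = []
        i = 0
        while i < len(symbols):
            if i < len(symbols) - 1 and (symbols[i], symbols[i + 1]) == best:
                merged.append(symbols[i] + symbols[i + 1])
                i += 2
            else:
                merged.append(symbols[i])
                i += 1
        symbols = merged
    return symbols


# Hypothetical merge table: digits merge with each other and 'Ġ' merges with
# 'pen', but no rule joins 'Ġ' to a run of zeros.
merges = [("0", "0"), ("00", "00"), ("p", "e"), ("pe", "n"), ("Ġ", "pen")]

print(toy_bpe("Ġpen", merges))   # a single token
print(toy_bpe("Ġ0000", merges))  # 'Ġ' stays separate from the digits
```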

@thibaultdouzon

thibaultdouzon commented Mar 28, 2023

Contributor Author

Hi @ArthurZucker, thanks for your investigations.

This PR fixes the problem for LayoutLMv3, but I expect the problem to exist in other models using fast BPE tokenization. I will take a look when I can to list all impacted models that need a fix.

@ArthurZucker ArthurZucker requested a review from Narsil March 29, 2023 11:19
Comment thread src/transformers/models/layoutlmv3/tokenization_layoutlmv3_fast.py
        labels_example.append(word_labels[original_index][word_id])
    else:
        labels_example.append(self.pad_token_label)
    if offset == (0, 0):

Collaborator

Suggested change
- if offset == (0, 0):
+ if self.decode(id) == "":

I'm not sure offset == (0, 0) is the right way to check this; maybe this is a safer option.

Collaborator

I can confirm this works when testing on a new model (UDOP)

Contributor Author

Yes, I believe it is equivalent. Although it is slower than simply comparing a tuple to (0, 0), it is probably safer and more resilient to future changes.
The offset check works because special tokens are already filtered out on line 648 with word_id is not None. Thus an offset of (0, 0) cannot belong to a special token at this position in the code.

Contributor Author

I just checked for LayoutLMv3TokenizerFast, tok.decode(1437) == " ", where 1437 is the id of the "Ġ" token.

@NielsRogge NielsRogge Apr 3, 2023

Collaborator

Ok, interesting. When working on UdopTokenizerFast (UDOP is a new model for which I'll open a PR soon), I had to use self.decode(id) == "", because it didn't work with offset == (0, 0). UDOP has the same vocabulary as T5, which is based on SentencePiece.

When testing out the following with offset == (0,0) from this branch:

from transformers import UdopTokenizerFast

words = ['a', 'weirdly', 'test', 'hello']
boxes = [[1,2,3,4] for _ in range(len(words))]
labels = [1, 2, 3, 4]

tokenizer = UdopTokenizerFast.from_pretrained("nielsr/udop-large")

encoding = tokenizer(words, boxes=boxes, word_labels=labels)

for id, label in zip(encoding.input_ids, encoding.labels):
    print(tokenizer.decode([id]), label)

it gives me:

 1
a -100
weird 2
ly -100
test 3
hello 4
</s> -100

(I used "weirdly" just to make sure we get multiple tokens). Interestingly it splits the word "a" into 2 tokens: an empty token and "a". I also printed (id, offset, word_id):

 (0, 1) 0
a (0, 1) 0
weird (0, 5) 1
ly (5, 7) 1
test (0, 4) 2
hello (0, 5) 3

and in this case the empty token has offset (0, 1), which explains why offset == (0,0) didn't work.

Contributor Author

This is probably not related to the current issue, but your tokenizer produces weird (pun intended) offset mappings.

 (0, 1)
a (0, 1)
weird (0, 5)
ly (5, 7)

We should be able to derive the word length from the offset_mapping, i.e. "weirdly" has length (5-0) + (7-5) = 7. But this no longer holds when the empty token is not assigned a (0, 0) offset: "a" would get length (1-0) + (1-0) = 2.
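
The length property described here can be checked with a short sketch, using the offsets and word ids printed earlier in this thread. The helper name `covered_length_per_word` is hypothetical and only serves the illustration.

```python
from collections import defaultdict


def covered_length_per_word(offsets, word_ids):
    """Sum the span lengths (end - start) of all tokens belonging to each word."""
    totals = defaultdict(int)
    for (start, end), word_id in zip(offsets, word_ids):
        if word_id is not None:  # skip special tokens
            totals[word_id] += end - start
    return dict(totals)


# Values copied from the UDOP output shown in this thread.
words = ["a", "weirdly", "test", "hello"]
offsets = [(0, 1), (0, 1), (0, 5), (5, 7), (0, 4), (0, 5)]
word_ids = [0, 0, 1, 1, 2, 3]

covered = covered_length_per_word(offsets, word_ids)
mismatched = [word for i, word in enumerate(words) if covered[i] != len(word)]
print(covered)
print(mismatched)  # words whose token spans do not add up to the word length
```

Only the word that starts with the empty token fails the check, because that token is reported with offset (0, 1) instead of (0, 0).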

@NielsRogge

NielsRogge commented Apr 3, 2023

Collaborator

Thanks a lot for this fix, would you be able to take into account my comment such that we can merge it? 🙏

Thanks!

Btw, the same fix could then be applied to LayoutLMv2 and LayoutXLM.

@thibaultdouzon

Contributor Author

LayoutLMv2 uses WordPiece and not BPE. From what I saw, its vocabulary does not contain an empty token and thus cannot produce a (0, 0) offset_mapping when encoding.

@sgugger sgugger left a comment

Collaborator

Thanks for the fix!

@sgugger sgugger merged commit 4e441e5 into huggingface:main Apr 3, 2023
@thibaultdouzon thibaultdouzon deleted the layoutlmv3-bpe-empty-tokens-labels branch April 3, 2023 15:39
raghavanone pushed a commit to raghavanone/transformers that referenced this pull request Apr 5, 2023
novice03 pushed a commit to novice03/transformers that referenced this pull request Jun 23, 2023
6 participants