Fix inverted conditional in TF common test! #22540

Merged
Rocketknight1 merged 8 commits into main from fix_pt_tf_equivalence_test
Apr 4, 2023

@Rocketknight1

Member

I noticed a rather alarming backwards conditional in the test_pt_tf_model_equivalence common test. This probably resulted in a lot of tests being skipped!
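
The description does not quote the conditional itself, so the following is only a hypothetical sketch of how a backwards guard can silently skip checks. The helper name `label_checks_that_run` and the case names are invented for illustration and are not taken from the test suite.

```python
def label_checks_that_run(cases, inverted=False):
    """Return the names of the cases whose with-labels check actually executes.

    `cases` is a list of (name, has_labels) pairs. This is a hypothetical
    helper for illustration, not the real common test.
    """
    ran = []
    for name, has_labels in cases:
        should_run = has_labels
        if inverted:
            # The bug: the guard is negated, so the branch runs for the
            # wrong cases and is skipped for the ones it was meant for.
            should_run = not should_run
        if should_run:
            ran.append(name)
    return ran


cases = [("base_model", False), ("token_classifier", True), ("lm_head", True)]
print(label_checks_that_run(cases))                 # ['token_classifier', 'lm_head']
print(label_checks_that_run(cases, inverted=True))  # ['base_model']
```

Nothing fails in the inverted case; the checks that matter simply never run, which is why this kind of bug can stay silent.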

@Rocketknight1 Rocketknight1 requested review from gante and ydshieh April 3, 2023 16:02
@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Apr 3, 2023

The documentation is not available anymore as the PR was closed or merged.

@Rocketknight1

Member Author

As expected, this has surfaced a few bugs in the cross-test that were silent before. I'll see what I can do in this PR.

@gante gante left a comment

Contributor

The change makes sense!

Re broken tests (which probably need to be fixed/skipped before merging) -- it means that the loss calculation has issues, correct?

@Rocketknight1

Member Author

Most likely - I'll investigate them all soon!

@Rocketknight1 Rocketknight1 force-pushed the fix_pt_tf_equivalence_test branch from df36a70 to 12cbb74 Compare April 4, 2023 16:09
@Rocketknight1

Rocketknight1 commented Apr 4, 2023

Member Author

Quick summary of the fixes needed:

ESM: TFEsmForTokenClassification copied its loss computation from TFBertForTokenClassification, but that code has some slightly odd BERT-specific behaviour and doesn't mask -100 labels in the same way as other models. I replaced it with the loss block from TFRobertaForTokenClassification, and all tests pass.
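
As a rough illustration of what masking -100 means here, the sketch below computes a token-level cross-entropy in plain Python and leaves out every position whose label is -100. It is not the loss block from either TF model; the function name is hypothetical, and averaging over the unmasked tokens is an assumption.

```python
import math

IGNORE_INDEX = -100  # label value that marks positions excluded from the loss


def masked_token_cross_entropy(logits, labels, ignore_index=IGNORE_INDEX):
    """Mean cross-entropy over tokens whose label is not `ignore_index`.

    `logits` is a list of per-token score lists, `labels` a list of ints.
    Hypothetical helper for illustration only.
    """
    total, count = 0.0, 0
    for scores, label in zip(logits, labels):
        if label == ignore_index:
            continue  # masked token: contributes nothing to the loss
        peak = max(scores)
        # Numerically stable log-sum-exp of the scores.
        log_z = peak + math.log(sum(math.exp(s - peak) for s in scores))
        total += log_z - scores[label]
        count += 1
    # Assumption: return 0.0 when every token is masked.
    return total / count if count else 0.0


# The second token is masked, so only the first one contributes (loss = ln 2).
loss = masked_token_cross_entropy([[0.0, 0.0], [3.0, -1.0]], [0, -100])
```

If two frameworks disagree on whether masked positions count toward the sum or the denominator, their losses diverge even when the logits are identical.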

GPT2: For model classes that take rank-3 inputs (e.g. MultipleChoice or DoubleHeads), the inputs have their first two dims flattened internally in the main model stem. This means that, when output_hidden_states=True, the output hidden_states are rank-3 (bsz * num_choices, seq_len, hidden_dim) and not rank-4 (bsz, num_choices, seq_len, hidden_dim). However, the PT model un-flattens the final hidden state, so the last hidden state is rank-4 while the others remain rank-3. In the old TF model, all hidden states were rank-3. I modified the TF code to un-flatten the last hidden state in the same way.
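
The shape bookkeeping can be sketched without any framework. The helpers below are hypothetical and only track shapes; they assume the stem merges every leading axis of the input into one axis in front of the sequence axis, and that only the final hidden state is restored to the original input layout.

```python
def flattened_shape(input_shape):
    """Merge all leading axes: (bsz, num_choices, seq_len) -> (bsz * num_choices, seq_len)."""
    *leading, seq_len = input_shape
    rows = 1
    for dim in leading:
        rows *= dim
    return (rows, seq_len)


def hidden_state_shapes(input_shape, hidden_dim, num_layers, unflatten_last):
    """Shapes of the embedding output plus one hidden state per layer.

    Hypothetical helper: with `unflatten_last=True` only the final entry is
    reshaped back to the input layout; the rest stay flattened.
    """
    flat = flattened_shape(input_shape) + (hidden_dim,)
    shapes = [flat] * (num_layers + 1)
    if unflatten_last:
        shapes[-1] = tuple(input_shape) + (hidden_dim,)
    return shapes


# bsz=2, num_choices=4, seq_len=8, hidden_dim=16, two layers
print(hidden_state_shapes((2, 4, 8), 16, 2, unflatten_last=True))
# [(8, 8, 16), (8, 8, 16), (2, 4, 8, 16)]
```

An equivalence test that compares hidden states pairwise will fail on the last entry unless both frameworks apply the same un-flattening.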

HUBERT: The loss computation, especially for CTC, overflows a lot with the default labels, which creates lots of inf values and makes it very hard to compare TF and PT losses. I skipped PT-TF equivalence testing for the losses, but kept it for all non-loss outputs.

Wav2Vec2: Same as HUBERT
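
To show why inf losses are awkward to compare, here is a minimal sketch of a tolerance check that skips selected output keys. The function names, the `skip_keys` parameter, and the dict layout are assumptions made for illustration, not the helpers used in the test suite.

```python
import math


def all_close(a, b, tol):
    """True if every pair of values differs by at most `tol`.

    inf - inf is nan, and nan <= tol is False, so two infinite losses
    never count as close under this kind of check.
    """
    return all(abs(x - y) <= tol for x, y in zip(a, b))


def outputs_match(pt_outputs, tf_outputs, tol=1e-5, skip_keys=("loss",)):
    """Compare every output except the skipped keys (hypothetical helper)."""
    for key, pt_values in pt_outputs.items():
        if key in skip_keys:
            continue
        if not all_close(pt_values, tf_outputs[key], tol):
            return False
    return True


pt = {"loss": [math.inf], "logits": [0.10, 0.20]}
tf = {"loss": [math.inf], "logits": [0.10, 0.20 + 1e-7]}
print(outputs_match(pt, tf))                # True: loss skipped, logits agree
print(outputs_match(pt, tf, skip_keys=()))  # False: inf - inf is nan
```

Skipping only the loss keeps the comparison meaningful for logits and hidden states, which do not overflow.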

XGLM: The PT XGLM model does a weird thing where it shifts the labels by 1 and then appends pad_token_id as the final label for every sample. I'm not sure this is correct, but I modified the TF code to do the same. It's possible the TF code was the right one here though, in which case we should revert this change and fix the PT code instead.
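
The label handling described for XGLM can be sketched on plain lists. Both helpers are hypothetical; the second one shows, for contrast, a plain shift that drops the first label and adds nothing, which is an assumption about what the alternative would look like rather than a quote of the earlier TF code.

```python
def shift_labels_with_pad(labels, pad_token_id):
    """Shift each row left by one and append `pad_token_id` as the final label."""
    return [row[1:] + [pad_token_id] for row in labels]


def shift_labels_truncate(labels):
    """Contrast case (assumption): shift left by one and leave the row one shorter."""
    return [row[1:] for row in labels]


labels = [[5, 6, 7], [8, 9, 10]]
print(shift_labels_with_pad(labels, pad_token_id=1))  # [[6, 7, 1], [9, 10, 1]]
print(shift_labels_truncate(labels))                  # [[6, 7], [9, 10]]
```

With the padded variant the final position is scored against pad_token_id unless that label is masked elsewhere, so the two variants generally give different loss values.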

@Rocketknight1

Member Author

@gante I fixed all the bugs that this surfaced, explained above ^

cc @sgugger for final review too

@Rocketknight1 Rocketknight1 requested a review from sgugger April 4, 2023 17:33

@sgugger sgugger left a comment

Collaborator

Thanks a lot for fixing the condition in the base test and all the subsequent failures.

@Rocketknight1 Rocketknight1 merged commit edb704b into main Apr 4, 2023
@Rocketknight1 Rocketknight1 deleted the fix_pt_tf_equivalence_test branch April 4, 2023 20:59
raghavanone pushed a commit to raghavanone/transformers that referenced this pull request Apr 5, 2023
* Fix inverted conditional in TF common test!

* Make the same change in the PT tests file

* Make sure hidden states for GPT2 have the same output shape in PT/TF

* Minor fix to PT implementation of token classification loss

* Skip loss equivalence test for TFHubert because it keeps overflowing to inf

* Compute LM loss for TF the (weird) way it's computed in PT

* Skip loss equivalence test for Wav2Vec2 for the same reason as Hubert

* Fix - don't try to access the hidden states property when output is a tuple
@ydshieh

ydshieh commented Apr 5, 2023

Collaborator

Thank you for the fix @Rocketknight1 ❤️. And I apologize for the mistake I introduced...

novice03 pushed a commit to novice03/transformers that referenced this pull request Jun 23, 2023

5 participants