Opt in flax and tf #17388
Conversation
The documentation is not available anymore as the PR was closed or merged.
Should we close the other PR? Let me know once it's ready for a review :-)
Cool, very nice job @ArthurZucker! Could you, as a final safety guard, also add TFOPT and FlaxOPT to the documentation test suite? See: https://github.com/huggingface/transformers/tree/main/docs#docstring-testing
patrickvonplaten
left a comment
Looks very nice from my side!
EXPECTED_OUTPUTS = [
    "Today is a beautiful day and I want to thank",
    "Today is a beautiful day and I want everyone",
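The snippet above records reference generations for an integration test. A minimal sketch of how such expected outputs might be asserted against decoded generations (the function name and structure here are illustrative, not the actual transformers test code):

```python
# Illustrative sketch only: `check_generations` is a hypothetical helper,
# not part of the transformers test suite.
def check_generations(decoded_outputs, expected_outputs):
    """Assert each decoded generation matches its recorded reference."""
    assert len(decoded_outputs) == len(expected_outputs)
    for decoded, expected in zip(decoded_outputs, expected_outputs):
        # generations are compared verbatim against the stored references
        assert decoded == expected, f"{decoded!r} != {expected!r}"

EXPECTED_OUTPUTS = [
    "Today is a beautiful day and I want to thank",
    "Today is a beautiful day and I want everyone",
]
# in a real test, `decoded` would come from tokenizer.batch_decode(model.generate(...))
check_generations(EXPECTED_OUTPUTS, EXPECTED_OUTPUTS)
```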
great thanks!
Can I merge @LysandreJik @sgugger? (The failing tests are not related to OPT.)
src/transformers/models/mobilebert/modeling_mobilebert.py
src/transformers/models/mobilebert/modeling_tf_mobilebert.py
src/transformers/models/opt/modeling_opt.py
src/transformers/models/opt/modeling_tf_opt.py
Thanks!
@patil-suraj could you quickly check Flax, and maybe @gante go over TF OPT?
There was a lot of work put into this, and it is close to completion 💪 I've added a few questions, suggestions, and corrections on the TF side, but they should be straightforward. Great work!
P.S.: double-checking: have you run the slow tests locally? If you did, it implies that TF XLA generation is working for OPT 🎉
for idx, decoder_layer in enumerate(self.layers):
    if output_hidden_states:
        all_hidden_states += (hidden_states,)
LayerDrop, present in the PT version (L640), is missing here.
Yes, it will be removed in PT as well (that will probably need another PR, WDYT @patrickvonplaten?)
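For context, LayerDrop (Fan et al., 2019) randomly skips entire decoder layers during training while running every layer at inference. A hedged sketch of the logic being discussed, in plain Python (names and structure are illustrative, not the actual modeling code):

```python
import random

def apply_layerdrop(layers, hidden_state, layerdrop_prob, training):
    """Sketch of LayerDrop: during training, skip each layer with probability
    `layerdrop_prob`; at inference every layer runs. Illustrative only."""
    for layer in layers:
        if training and random.random() < layerdrop_prob:
            continue  # drop the whole layer for this forward pass
        hidden_state = layer(hidden_state)
    return hidden_state

# toy usage: each "layer" just increments a counter, so the result counts
# how many layers actually ran
layers = [lambda x: x + 1] * 4
print(apply_layerdrop(layers, 0, layerdrop_prob=0.0, training=True))  # 4
```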
)
self.assertTrue(np.allclose(output[:, :3, :3], expected_slice, atol=4e-3))
xla_generate = tf.function(model, jit_compile=True)
😱 does this run? If it does, ignore my comment above about the _update_model_kwargs_for_xla_generation function
I did not try it yet (XLA is not compatible with the M1 chip, will try on brutasse soon).
If it doesn't, remove the XLA compilation in the tests. It's still very brittle :)
I tested it and it seems to be working fine! 🥳
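Background on why XLA generation needs care: XLA compiles one graph per input shape, so an XLA-friendly generation loop keeps shapes static, typically by padding inputs and attention masks to a fixed maximum length instead of growing them each step. A hedged numpy sketch of that idea (this is not the actual `_update_model_kwargs_for_xla_generation` implementation, just the shape-padding concept):

```python
import numpy as np

def pad_attention_mask(mask, max_length):
    """Right-pad a 2D attention mask with zeros to a fixed `max_length`, so a
    compiled graph sees one static shape at every decoding step. Sketch only."""
    batch, cur_len = mask.shape
    if cur_len >= max_length:
        return mask[:, :max_length]
    padding = np.zeros((batch, max_length - cur_len), dtype=mask.dtype)
    return np.concatenate([mask, padding], axis=1)

# 2 sequences of current length 3, padded to a static length of 8
mask = np.ones((2, 3), dtype=np.int32)
padded = pad_attention_mask(mask, 8)
print(padded.shape)  # (2, 8)
```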
🔥
gante
left a comment
Good to go, from the TF end 👍
sgugger
left a comment
A few more nits, but LGTM otherwise. Thanks a lot!
cached_value.value = value
num_updated_cache_vectors = query.shape[1]
cache_index.value = cache_index.value + num_updated_cache_vectors
# causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements.
Can we respect the 119 char limit here and split that comment on several lines? ;-)
Okay :)
The long comment comes from BART; it has to be kept like that, otherwise the repo-consistency check fails. We should leave it as is.
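The comment under discussion describes the cached-decoder causal mask: a query written at position `cache_index` may attend only to key positions already filled in the cache, never to the still-empty tail. A hedged numpy sketch of constructing such a mask (illustrative only, not the actual Flax BART/OPT code):

```python
import numpy as np

def cached_causal_mask(cache_index, num_updated, max_length):
    """Mask for `num_updated` query positions starting at absolute position
    `cache_index`: each query attends to key positions at or before its own
    absolute position, and never to the empty tail of the cache. Sketch only."""
    # absolute position of each query, as a column vector for broadcasting
    query_positions = cache_index + np.arange(num_updated)[:, None]
    key_positions = np.arange(max_length)[None, :]
    return (key_positions <= query_positions).astype(np.int32)

# single query at absolute position 3 in a cache of max length 6:
print(cached_causal_mask(3, 1, 6))  # [[1 1 1 1 0 0]]
```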
Co-authored-by: Sylvain Gugger <[email protected]>
Thanks all for the reviews 😄 🥳
* initial commit
* add init file
* update globakl init
* update index and dummy objects
* style
* update modelling auto
* fix initi typo in src/transformers
* fix typo in modeling tf auto, opt was in wrong mapping name
* fixed a slow test : saved_model
* style
* fix positionnal embedding if no position id is provided
* update tf test
* update test flax requirements
* fixed serialization
* update
* update tf name to allow smooth convertion
* update flax tests
* style
* fix test typo
* fix tf typo test
* add xla for generate support in causal LM
* fixed bug
* cleaned tf tests
* style
* removed from PT for slow tests
* fix typp
* opt test as slow
* trying to fix GPT2 undefined
* correct documentation and add to test doc
* update tf doc
* fix doc
* fake commit
* Apply suggestions from code review

Co-authored-by: Joao Gante <[email protected]>

* update test based on review
* merged main layer for functionning test
* fixup + quality
* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <[email protected]>

* update long comment
* make fix copies

Co-authored-by: Arthur <[email protected]>
Co-authored-by: Joao Gante <[email protected]>
Co-authored-by: Sylvain Gugger <[email protected]>
What does this PR do?
Adds support for OPT in both Flax and TF.
Who can review?
@patrickvonplaten, @LysandreJik, @younesbelkada, @patil-suraj, @sgugger