Adds support for OPT in Flax and TF. #17227
Conversation
sgugger
left a comment
Thanks for adding those! Without the TFOPTForCausalLM, I don't see the point of adding the TF version of OPT since it can't really be used, so I would either not add TF yet or make sure this model is added before merging the PR.
# OPT is set up so that if padding_idx is specified then offset the embedding ids by 2
# and adjust num_embeddings appropriately. Other models don't have this hack
self.offset = 2
# TODO Check if that needs reimplementation similar to OPTLearnedPositionalEmbedding
Flagging this here, not sure if it has been checked or not. The comment should be removed before merging.
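For context, here is a minimal TF sketch of the "offset by 2" trick the TODO refers to; the class name and details are illustrative, not the PR's final code. OPT reserves the first two slots of the position-embedding table, so the table is allocated larger and every position id is shifted by the offset before the lookup.

```python
import tensorflow as tf

class LearnedPositionalEmbedding(tf.keras.layers.Embedding):
    # OPT reserves the first two embedding slots, so the table is made larger
    # by this offset and every position id is shifted by it before the lookup.
    OFFSET = 2

    def __init__(self, num_embeddings: int, embedding_dim: int, **kwargs):
        super().__init__(num_embeddings + self.OFFSET, embedding_dim, **kwargs)

    def call(self, position_ids):
        return super().call(position_ids + self.OFFSET)
```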
Also here, please only resolve comments once they have been corrected
)
class TFOPTPretrainedModel(TFPreTrainedModel):
    """
    TFOPT Pretrained Model that inheritates from transformers.TFPreTrainedModel
That docstring is not very informative 😆
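A possible rewrite, mirroring the boilerplate phrasing other models in the library use; the attribute values below are illustrative, not necessarily what the PR ends up with:

```python
from transformers import OPTConfig, TFPreTrainedModel

class TFOPTPreTrainedModel(TFPreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for
    downloading and loading pretrained models.
    """

    config_class = OPTConfig
    base_model_prefix = "model"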
@ArthurZucker please only resolve a comment once it has been addressed ;-)
Yes I am not done yet! Sorry if I pinged you a bit early
self.dropout = config.dropout
self.activation_fn = ACT2FN[config.activation_function]
self.activation_dropout = config.activation_dropout
Can we revert all changes in this file? This should go into a different PR ;-)
output_string = tokenizer.batch_decode(output_sequences, skip_special_tokens=True)
self.assertIsNotNone(output_string, EXPECTED_OUTPUTS)
does this test pass?
def prepare_opt_inputs_dict(
    config,
    input_ids,
    decoder_input_ids=None,
Hey @ArthurZucker,
I would also recommend reverting this as it's not related to OPT TF or Flax. Happy to correct it in a future PR :-)
xla_generate = tf.function(model.generate, jit_compile=True)
output_sequences = xla_generate(self.prompts).sequences
output_string = tokenizer.batch_decode(output_sequences, skip_special_tokens=True)
self.assertIsNotNone(output_string, EXPECTED_OUTPUTS)
Nice, do the tests pass? Tests are looking good!
This module learns positional embeddings up to a fixed maximum size.
"""

def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int = 1, **kwargs):
Same here, let's remove the padding_idx (position ids don't need to have a padding idx).
To explain, consider the following:
You want to generate in batches:
[&lt;pad&gt;, &lt;pad&gt;, Hello]
[Hey, my, name]
=> now the attention mask looks as follows:
[0, 0, 1]
[1, 1, 1]
=> this means that the position ids should look as follows:
[0, 0, 0]
[0, 1, 2]
-> there is no need to give the embeddings a special token for the padding tokens. It doesn't really make sense (they are not word embeddings)
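A minimal sketch of how those position ids can be derived from the attention mask alone; this is illustrative, not the PR's exact implementation:

```python
import tensorflow as tf

attention_mask = tf.constant([[0, 0, 1],
                              [1, 1, 1]])

# cumulative sum gives 1-based positions for real tokens; masking and
# clamping sends every padded slot to position 0
position_ids = tf.cumsum(attention_mask, axis=-1) * attention_mask
position_ids = tf.maximum(position_ids - 1, 0)
# -> [[0, 0, 0],
#     [0, 1, 2]]
```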
@ArthurZucker could we add a test similar to this one: #17359 to both Flax and TF? @Rocketknight1 @gante could you check the TF version here as well?
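For reference, a hedged sketch of the kind of batched-generation test being asked for, mirroring the PyTorch test from #17359; the checkpoint name, prompts, and assertion below are illustrative only, and TFOPTForCausalLM is the class this PR is adding:

```python
import unittest
from transformers import GPT2Tokenizer, TFOPTForCausalLM  # TFOPTForCausalLM is added in this PR

class TFOPTBatchGenerationTest(unittest.TestCase):
    def test_batch_generation(self):
        checkpoint = "facebook/opt-350m"
        tokenizer = GPT2Tokenizer.from_pretrained(checkpoint)
        tokenizer.padding_side = "left"  # decoder-only models should be padded on the left
        model = TFOPTForCausalLM.from_pretrained(checkpoint)

        prompts = ["Hello, my dog is a little", "Today, I"]
        inputs = tokenizer(prompts, return_tensors="tf", padding=True)
        output_ids = model.generate(
            input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"]
        )
        texts = tokenizer.batch_decode(output_ids, skip_special_tokens=True)

        # with correct position-id handling, padding the shorter prompt must not
        # change its continuation compared to generating it on its own
        self.assertEqual(len(texts), len(prompts))
```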
def __init__(self, config: OPTConfig, **kwargs):
    super().__init__(config, **kwargs)
    self.config = config
    self.decoder = TFOPTMainLayer(config, name="decoder")
Why "decoder"? I think it should be called "model".
I agree with the previous one, but here it does not feel natural to me to use "model", as the class is TFOPTModel. If we say that it has a .model attribute, we would have to access it using model.model, which does not look good to me. WDYT?
The important part here is that the automatic conversion between PyTorch and TensorFlow works correctly and that it's aligned with PyTorch.
I don't really see why "decoder" would be better than "model" here.
E.g. you would do:
opt = TFOPTModel.from_pretrained(...)
opt.model  # to access the main layer, no?
We should not change weight names, as it would break the automatic conversion, which is very important IMO.
Sorry, I think we were not 100% on the same page here. We should align TF and PT as much as possible, and the way to go is not (as I said before) to just rename the attribute: I think we need to create a TFOPTDecoder class in TensorFlow, use it inside TFOPTMainLayer, and then use self.model = TFOPTMainLayer(...) in both TFOPTModel and TFOPTForCausalLM. The best reference here is modeling_tf_bart.py IMO.
Again, the most important thing is to make sure the automatic conversion works fine, which I think it should with the approach described above.
Maybe @Rocketknight1 @gante could you verify this real quick?
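To make the suggestion concrete, here is a rough sketch of the layering being proposed (following modeling_tf_bart.py); the class bodies are placeholders, and in the PR the outer classes would inherit from TFOPTPreTrainedModel rather than tf.keras.Model:

```python
import tensorflow as tf

class TFOPTDecoder(tf.keras.layers.Layer):
    """Would hold the token/positional embeddings and the stack of decoder layers."""
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.config = config

class TFOPTMainLayer(tf.keras.layers.Layer):
    """Thin wrapper so the TF weight names line up with PyTorch's OPTModel.decoder."""
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.decoder = TFOPTDecoder(config, name="decoder")

class TFOPTModel(tf.keras.Model):          # in the PR: TFOPTPreTrainedModel
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.model = TFOPTMainLayer(config, name="model")

class TFOPTForCausalLM(tf.keras.Model):    # in the PR: TFOPTPreTrainedModel + loss mixin
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.model = TFOPTMainLayer(config, name="model")
```

Keeping the "model" / "decoder" prefixes identical to the PyTorch module names is what lets the PT↔TF weight conversion map variables automatically.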
# shift labels to the left and cut last logit token
shifted_logits = logits[:, :-1]
labels = labels[:, 1:]
loss = self.hf_compute_loss(labels, shifted_logits)
Guess this is the standard, but I'm not a huge fan of this design generally -> why do we inherit the loss from a different class?
I'm not sure, honestly! That design decision happened before I started here - my guess is the intention was to separate the losses out so they could be used as Keras losses, but that was never completed, and so that never really became available to users. Using loss mixins does cut down on a lot of code duplication, at the cost of more abstraction, so it's not the worst thing ever imo.
I mean, it does use abstraction by inheriting from the loss class 😅. Also not 100% in line with our single-file policy, but not a big deal either IMO (if it's more Keras-idiomatic then it's totally fine for me).
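For readers who haven't seen the mixin, this is roughly what the shifted causal-LM loss computes, written as a standalone sketch; the real hf_compute_loss lives in the shared TF loss mixins and handles a few more cases:

```python
import tensorflow as tf

def causal_lm_loss(labels, logits):
    # shift so that token t is predicted from the tokens before it
    shifted_logits = logits[:, :-1, :]
    labels = labels[:, 1:]

    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True, reduction=tf.keras.losses.Reduction.NONE
    )
    # positions labelled -100 are ignored, as in the PyTorch models
    mask = tf.cast(labels != -100, tf.float32)
    safe_labels = tf.where(labels == -100, tf.zeros_like(labels), labels)
    per_token_loss = loss_fn(safe_labels, shifted_logits)
    return tf.reduce_sum(per_token_loss * mask) / tf.reduce_sum(mask)
```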
1. **[ConvBERT](https://huggingface.co/docs/transformers/model_doc/convbert)** (from YituTech) released with the paper [ConvBERT: Improving BERT with Span-based Dynamic Convolution](https://arxiv.org/abs/2008.02496) by Zihang Jiang, Weihao Yu, Daquan Zhou, Yunpeng Chen, Jiashi Feng, Shuicheng Yan.
1. **[CPM](https://huggingface.co/docs/transformers/model_doc/cpm)** (from Tsinghua University) released with the paper [CPM: A Large-scale Generative Chinese Pre-trained Language Model](https://arxiv.org/abs/2012.00413) by Zhengyan Zhang, Xu Han, Hao Zhou, Pei Ke, Yuxian Gu, Deming Ye, Yujia Qin, Yusheng Su, Haozhe Ji, Jian Guan, Fanchao Qi, Xiaozhi Wang, Yanan Zheng, Guoyang Zeng, Huanqi Cao, Shengqi Chen, Daixuan Li, Zhenbo Sun, Zhiyuan Liu, Minlie Huang, Wentao Han, Jie Tang, Juanzi Li, Xiaoyan Zhu, Maosong Sun.
1. **[CTRL](https://huggingface.co/docs/transformers/model_doc/ctrl)** (from Salesforce) released with the paper [CTRL: A Conditional Transformer Language Model for Controllable Generation](https://arxiv.org/abs/1909.05858) by Nitish Shirish Keskar*, Bryan McCann*, Lav R. Varshney, Caiming Xiong and Richard Socher.
1. **[CvT](https://huggingface.co/docs/transformers/main/model_doc/cvt)** (from Microsoft) released with the paper [CvT: Introducing Convolutions to Vision Transformers](https://arxiv.org/abs/1909.05858) by Nitish Shirish Keskar*, Bryan McCann*, Lav R. Varshney, Caiming Xiong and Richard Socher.
This should not be here - maybe a problem with git pull or git merge?
Yes most probably, I merged with main, which gave me issues with the torch tests. Could not rebase here
| """, | ||
| OPT_START_DOCSTRING, | ||
| ) | ||
| class TFOPTForCausalLM(TFOPTPreTrainedModel, TFCausalLanguageModelingLoss): |
Please double check that the loss is correct here - see: #17237
Do you think we could fix the PR? (I think the PR history is a bit messed up.) Also totally fine to close this PR and just open a new PR (move all the relevant files to a new PR) if the git correction is too difficult.
Hey, I think we can close it.
What does this PR do?
Adds support for OPT in Flax and TF.
Also cleans up the PyTorch code a bit.
Who can review?
@LysandreJik, @patrickvonplaten, @patil-suraj, @sgugger
Sorry for the two pull requests in a row; I pulled from main instead of rebasing and ended up with the entire commit history. Created a new branch to clean things up.