Skip to content

Conversation

@ArthurZucker
Copy link
Collaborator

What does this PR do?

Adds support for OPT in Flax and TF.
Also clean Pytorch code a bit.

Who can review?

@LysandreJik, @patrickvonplaten, @patil-suraj, @sgugger

Sorry for the two pull requests in a row, pulled from main instead of rebasing and had the entire commit history. Created a new branch to clean a bit.

Copy link
Collaborator

@sgugger sgugger left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for adding those! Without the TFOPTForCausalLM, I don't see the point of adding the TF version of OPT since it can't really be used, so would either not add TF yet or make sure this model is added before merging the PR.

# OPT is set up so that if padding_idx is specified then offset the embedding ids by 2
# and adjust num_embeddings appropriately. Other models don't have this hack
self.offset = 2
# TODO Check if that needs reimplemetation similar to OPTLearnedPositionalEmbedding
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Flagging this here, not sure if it has been checked or not. The comment should be removed before merging.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also here, please only resolve comments once they have been corrected

)
class TFOPTPretrainedModel(TFPreTrainedModel):
"""
TFOPT Pretrained Model that inheritates from transformers.TFPreTrainedModel
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That docstring is not very informative 😆

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ArthurZucker please only resolve a comment if it has been treated ;-)

@ArthurZucker
Copy link
Collaborator Author

Thanks for adding those! Without the TFOPTForCausalLM, I don't see the point of adding the TF version of OPT since it can't really be used, so would either not add TF yet or make sure this model is added before merging the PR.

Yes I am not done yet! Sorry if I pinged you a bit early

@ArthurZucker ArthurZucker marked this pull request as draft May 13, 2022 15:21
self.dropout = config.dropout
self.activation_fn = ACT2FN[config.activation_function]

self.activation_dropout = config.activation_dropout
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we revert all changes in this file? This should go into a different PR ;-)


output_string = tokenizer.batch_decode(output_sequences, skip_special_tokens=True)

self.assertIsNotNone(output_string, EXPECTED_OUTPUTS)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does this test pass?

def prepare_opt_inputs_dict(
config,
input_ids,
decoder_input_ids=None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey @ArthurZucker,

I would also recommend reverting this as it's not related to OPT TF or Flax. Happy to correct it in a future PR :-)

xla_generate = tf.function(model.generate, jit_compile=True)
output_sequences = xla_generate(self.prompts).sequences
output_string = tokenizer.batch_decode(output_sequences, skip_special_tokens=True)
self.assertIsNotNone(output_string, EXPECTED_OUTPUTS)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice do the tests pass? Tests are looking good!

This module learns positional embeddings up to a fixed maximum size.
"""

def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int = 1, **kwargs):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here, let's remove the padding_idx (position ids don't need to have a padding idx).

To explain, consider the following:
You want to generate in batches:

[, , Hello]
[Hey, my, name]

=> now the attention mask looks as follows:
[0, 0, 1]
[1, 1, 1]
=> this means that the position ids should look as follows:
[0, 0, 0]
[0, 1, 2]

-> there is no need to give the embeddings a special token for the padding tokens. It doesn't really make sense (they are not word embeddings)

@patrickvonplaten
Copy link
Contributor

@ArthurZucker could we add a test similar to this one: #17359 to both Flax and TF?

@Rocketknight1 @gante could you check the TF version here as well?

def __init__(self, config: OPTConfig, **kwargs):
super().__init__(config, **kwargs)
self.config = config
self.decoder = TFOPTMainLayer(config, name="decoder")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why decoder ? think it should be called "model"

Copy link
Collaborator Author

@ArthurZucker ArthurZucker May 20, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with the previous one, but here it does not look natural for me to put model as the class is TFOPTModel. If we say that it has a .model attribute we would have to access it using model.model which is does not look good to me. WDYT

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The important part here is that the automatic conversion between PyTorch and Tensorflow works correctly and that it's aligned with PyTorch.

I don't really see why we would do "decoder" is better than "model" here.
E.g. you would do:

opt = TFOPTModel.from_pretrained(....)
opt.model = <to/access/main/layer>

no?

We should not change weight names as it would break the automatic conversion which is very important IMO.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I think we were not 100% on the same page here. I think we should make sure to align TF and PT here as much as possible and the way to go is not (as I said before) to just rename it, but I think we need to create a TFOPTDecoder class in Tensorflow and then use this one in TFOPTMainLayer which then should be used with self.model = TFOPTMainLayer(...) in both TFOPTModel and TFOPTForCausalLM. Think the best reference here is modeling_tf_bart.py IMO.

Again the most important is to make sure the automatic conversion works fine which I think it should with the way described above.

Maybe @Rocketknight1 @gante could you verify this real quick?

# shift labels to the left and cut last logit token
shifted_logits = logits[:, :-1]
labels = labels[:, 1:]
loss = self.hf_compute_loss(labels, shifted_logits)
Copy link
Contributor

@patrickvonplaten patrickvonplaten May 19, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Guess this is the standard but not a huge fan of this design generally -> why do we inherit the loss from a difference class?

cc @gante @Rocketknight1

Copy link
Member

@Rocketknight1 Rocketknight1 May 20, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure, honestly! That design decision happened before I started here - my guess is the intention was to separate the losses out so they could be used as Keras losses, but that was never completed, and so that never really became available to users. Using loss mixins does cut down on a lot of code duplication, at the cost of more abstraction, so it's not the worst thing ever imo.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mean it does use abstraction by inheriting from the loss class 😅 . Also not 100% in line with our policy of single-file, but not a big deal either IMO (if it's more Keras-idiomatic then totally fine for me).

commit 5419205
Author: Patrick von Platen <[email protected]>
Date:   Thu May 19 23:46:26 2022 +0200

    [Test OPT] Add batch generation test opt (huggingface#17359)

    * up

    * up

commit 48c2269
Author: ddobokki <[email protected]>
Date:   Fri May 20 05:42:44 2022 +0900

    Fix bug in Wav2Vec2 pretrain example (huggingface#17326)

commit 5d6feec
Author: Nathan Dahlberg <[email protected]>
Date:   Thu May 19 16:21:19 2022 -0400

    fix for 17292 (huggingface#17293)

commit 518bd02
Author: Patrick von Platen <[email protected]>
Date:   Thu May 19 22:17:02 2022 +0200

    [Generation] Fix Transition probs (huggingface#17311)

    * [Draft] fix transition probs

    * up

    * up

    * up

    * make it work

    * fix

    * finish

    * update

commit e8714c0
Author: Patrick von Platen <[email protected]>
Date:   Thu May 19 22:15:36 2022 +0200

    [OPT] Run test in lower precision on GPU (huggingface#17353)

    * [OPT] Run test only in half precision

    * up

    * up

    * up

    * up

    * finish

    * fix on GPU

    * Update tests/models/opt/test_modeling_opt.py

commit 2b28229
Author: Nicolas Patry <[email protected]>
Date:   Thu May 19 20:28:12 2022 +0200

    Adding `batch_size` test to QA pipeline. (huggingface#17330)

commit a4386d7
Author: Nicolas Patry <[email protected]>
Date:   Thu May 19 10:29:16 2022 +0200

    [BC] Fixing usage of text pairs (huggingface#17324)

    * [BC] Fixing usage of text pairs

    The BC is actually preventing users from misusing the pipeline since
    users could have been willing to send text pairs and the pipeline would
    instead understand the thing as a batch returning bogus results.

    The correct usage of text pairs is preserved in this PR even when that
    makes the code clunky.

    Adds support for {"text":..,, "text_pair": ...} inputs for both dataset
    iteration and more explicit usage to pairs.

    * Updating the doc.

    * Update src/transformers/pipelines/text_classification.py

    Co-authored-by: Sylvain Gugger <[email protected]>

    * Update src/transformers/pipelines/text_classification.py

    Co-authored-by: Sylvain Gugger <[email protected]>

    * Update tests/pipelines/test_pipelines_text_classification.py

    Co-authored-by: Lysandre Debut <[email protected]>

    * quality.

    Co-authored-by: Sylvain Gugger <[email protected]>
    Co-authored-by: Lysandre Debut <[email protected]>

commit 3601aa8
Author: Stas Bekman <[email protected]>
Date:   Wed May 18 16:00:47 2022 -0700

    [tests] fix copy-n-paste error (huggingface#17312)

    * [tests] fix copy-n-paste error

    * fix

commit 1b20c97
Author: Yih-Dar <[email protected]>
Date:   Wed May 18 21:49:08 2022 +0200

    Fix ci_url might be None (huggingface#17332)

    * fix

    * Update utils/notification_service.py

    Co-authored-by: Lysandre Debut <[email protected]>

    Co-authored-by: ydshieh <[email protected]>
    Co-authored-by: Lysandre Debut <[email protected]>

commit 6aad387
Author: Yih-Dar <[email protected]>
Date:   Wed May 18 21:26:44 2022 +0200

    fix (huggingface#17337)

    Co-authored-by: ydshieh <[email protected]>

commit 1762ded
Author: Zachary Mueller <[email protected]>
Date:   Wed May 18 14:17:40 2022 -0400

    Fix metric calculation in examples and setup tests to run on multi-gpu for no_trainer scripts (huggingface#17331)

    * Fix length in no_trainer examples

    * Add setup and teardown

    * Use new accelerator config generator to automatically make tests able to run based on environment

commit 6e195eb
Author: Jader Martins <[email protected]>
Date:   Wed May 18 14:18:43 2022 -0300

    docs for typical decoding (huggingface#17186)

    Co-authored-by: Jader Martins <[email protected]>

commit 060fe61
Author: Yih-Dar <[email protected]>
Date:   Wed May 18 19:07:48 2022 +0200

    Not send successful report (huggingface#17329)

    * send report only if there is any failure

    Co-authored-by: ydshieh <[email protected]>

commit b3b9f99
Author: Yih-Dar <[email protected]>
Date:   Wed May 18 17:57:23 2022 +0200

    Fix test_t5_decoder_model_past_large_inputs (huggingface#17320)

    Co-authored-by: ydshieh <[email protected]>

commit 6da76b9
Author: Jingya HUANG <[email protected]>
Date:   Wed May 18 17:52:13 2022 +0200

    Add onnx export cuda support (huggingface#17183)

    Co-authored-by: Lysandre Debut <[email protected]>

    Co-authored-by: lewtun <[email protected]>

commit adc0ff2
Author: NielsRogge <[email protected]>
Date:   Wed May 18 17:47:18 2022 +0200

    Add CvT (huggingface#17299)

    * Adding cvt files

    * Adding cvt files

    * changes in init file

    * Adding cvt files

    * changes in init file

    * Style fixes

    * Address comments from code review

    * Apply suggestions from code review

    Co-authored-by: Sylvain Gugger <[email protected]>

    * Format lists in docstring

    * Fix copies

    * Apply suggestion from code review

    Co-authored-by: AnugunjNaman <[email protected]>
    Co-authored-by: Ayushman Singh <[email protected]>
    Co-authored-by: Niels Rogge <[email protected]>
    Co-authored-by: Sylvain Gugger <[email protected]>

commit 4710702
Author: Sylvain Gugger <[email protected]>
Date:   Wed May 18 10:46:40 2022 -0400

    Fix style

commit 5fdb54e
Author: mraunak <[email protected]>
Date:   Wed May 18 10:39:02 2022 -0400

    Add Information Gain Filtration algorithm (huggingface#16953)

    * Add information gain filtration algorithm

    * Complying with black requirements

    * Added author

    * Fixed import order

    * flake8 corrections

    Co-authored-by: Javier Turek <[email protected]>

commit 91ede48
Author: Kamal Raj <[email protected]>
Date:   Wed May 18 19:59:53 2022 +0530

    Fix typo (huggingface#17328)

commit fe28eb9
Author: Yih-Dar <[email protected]>
Date:   Wed May 18 16:06:41 2022 +0200

    remove (huggingface#17325)

    Co-authored-by: ydshieh <[email protected]>

commit 2cb2ea3
Author: Nicolas Patry <[email protected]>
Date:   Wed May 18 16:06:24 2022 +0200

    Accepting real pytorch device as arguments. (huggingface#17318)

    * Accepting real pytorch device as arguments.

    * is_torch_available.

commit 1c9d1f4
Author: Nicolas Patry <[email protected]>
Date:   Wed May 18 15:46:12 2022 +0200

    Updating the docs for `max_seq_len` in QA pipeline (huggingface#17316)

commit 60ad734
Author: Patrick von Platen <[email protected]>
Date:   Wed May 18 15:08:56 2022 +0200

    [T5] Fix init in TF and Flax for pretraining (huggingface#17294)

    * fix init

    * Apply suggestions from code review

    * fix

    * finish

    * Update src/transformers/modeling_tf_utils.py

    Co-authored-by: Sylvain Gugger <[email protected]>

    Co-authored-by: Sylvain Gugger <[email protected]>

commit 7ba1d4e
Author: Joaq <[email protected]>
Date:   Wed May 18 09:23:47 2022 -0300

    Add type hints for ProphetNet (Pytorch) (huggingface#17223)

    * added type hints to prophetnet

    * reformatted with black

    * fix bc black misformatted some parts

    * fix imports

    * fix imports

    * Update src/transformers/models/prophetnet/configuration_prophetnet.py

    Co-authored-by: Matt <[email protected]>

    * update OPTIONAL type hint and docstring

    Co-authored-by: Matt <[email protected]>

commit d6b8e9c
Author: Carl <[email protected]>
Date:   Wed May 18 01:07:43 2022 +0200

    Add trajectory transformer (huggingface#17141)

    * Add trajectory transformer

    Fix model init

    Fix end of lines for .mdx files

    Add trajectory transformer model to toctree

    Add forward input docs

    Fix docs, remove prints, simplify prediction test

    Apply suggestions from code review

    Co-authored-by: Sylvain Gugger <[email protected]>
    Apply suggestions from code review

    Co-authored-by: Lysandre Debut <[email protected]>
    Co-authored-by: Sylvain Gugger <[email protected]>
    Update docs, more descriptive comments

    Apply suggestions from code review

    Co-authored-by: Sylvain Gugger <[email protected]>
    Update readme

    Small comment update and add conversion script

    Rebase and reformat

    Fix copies

    Fix rebase, remove duplicates

    Fix rebase, remove duplicates

    * Remove tapex

    * Remove tapex

    * Remove tapex

commit c352640
Author: Patrick von Platen <[email protected]>
Date:   Wed May 18 00:34:31 2022 +0200

    fix (huggingface#17310)

commit d9050dc
Author: Cesare Campagnano <[email protected]>
Date:   Tue May 17 23:44:37 2022 +0200

    [LED] fix global_attention_mask not being passed for generation and docs clarification about grad checkpointing (huggingface#17112)

    * [LED] fixed global_attention_mask not passed for generation + docs clarification for gradient checkpointing

    * LED docs clarification

    Co-authored-by: Patrick von Platen <[email protected]>

    * [LED] gradient_checkpointing=True should be passed to TrainingArguments

    Co-authored-by: Patrick von Platen <[email protected]>

    * [LED] docs: remove wrong word

    Co-authored-by: Patrick von Platen <[email protected]>

    * [LED] docs fix typo

    Co-authored-by: Patrick von Platen <[email protected]>

    Co-authored-by: Patrick von Platen <[email protected]>

commit bad3583
Author: Jean Vancoppenolle <[email protected]>
Date:   Tue May 17 23:42:14 2022 +0200

    Add support for pretraining recurring span selection to Splinter (huggingface#17247)

    * Add SplinterForSpanSelection for pre-training recurring span selection.

    * Formatting.

    * Rename SplinterForSpanSelection to SplinterForPreTraining.

    * Ensure repo consistency

    * Fixup changes

    * Address SplinterForPreTraining PR comments

    * Incorporate feedback and derive multiple question tokens per example.

    * Update src/transformers/models/splinter/modeling_splinter.py

    Co-authored-by: Patrick von Platen <[email protected]>

    * Update src/transformers/models/splinter/modeling_splinter.py

    Co-authored-by: Patrick von Platen <[email protected]>

    Co-authored-by: Jean Vancoppenole <[email protected]>
    Co-authored-by: Tobias Günther <[email protected]>
    Co-authored-by: Tobias Günther <[email protected]>
    Co-authored-by: Patrick von Platen <[email protected]>

commit 0511305
Author: Yih-Dar <[email protected]>
Date:   Tue May 17 18:56:58 2022 +0200

    Add PR author in CI report + merged by info (huggingface#17298)

    * Add author info to CI report

    * Add merged by info

    * update

    Co-authored-by: ydshieh <[email protected]>

commit 032d63b
Author: Sylvain Gugger <[email protected]>
Date:   Tue May 17 12:56:24 2022 -0400

    Fix dummy creation script (huggingface#17304)

commit 986dd5c
Author: Sylvain Gugger <[email protected]>
Date:   Tue May 17 12:50:14 2022 -0400

    Fix style

commit 38ddab1
Author: Karim Foda <[email protected]>
Date:   Tue May 17 09:32:12 2022 -0700

    Doctest longformer (huggingface#16441)

    * Add initial doctring changes

    * make fixup

    * Add TF doc changes

    * fix seq classifier output

    * fix quality errors

    * t

    * swithc head to random init

    * Fix expected outputs

    * Update src/transformers/models/longformer/modeling_longformer.py

    Co-authored-by: Yih-Dar <[email protected]>

    Co-authored-by: Yih-Dar <[email protected]>

commit 10704e1
Author: Patrick von Platen <[email protected]>
Date:   Tue May 17 18:20:36 2022 +0200

    [Test] Fix W2V-Conformer integration test (huggingface#17303)

    * [Test] Fix W2V-Conformer integration test

    * correct w2v2

    * up

commit 28a0811
Author: regisss <[email protected]>
Date:   Tue May 17 17:58:14 2022 +0200

    Improve mismatched sizes management when loading a pretrained model (huggingface#17257)

    - Add --ignore_mismatched_sizes argument to classification examples

    - Expand the error message when loading a model whose head dimensions are different from expected dimensions

commit 1f13ba8
Author: Patrick von Platen <[email protected]>
Date:   Tue May 17 15:48:23 2022 +0200

    correct opt (huggingface#17301)

commit 349f1c8
Author: Matt <[email protected]>
Date:   Tue May 17 14:36:23 2022 +0100

    Rewrite TensorFlow train_step and test_step (huggingface#17057)

    * Initial commit

    * Better label renaming

    * Remove breakpoint before pushing (this is your job)

    * Test a lot more in the Keras fit() test

    * make fixup

    * Clarify the case where we flatten y dicts into tensors

    * Clarify the case where we flatten y dicts into tensors

    * Extract label name remapping to a method

commit 651e48e
Author: Matt <[email protected]>
Date:   Tue May 17 14:14:17 2022 +0100

    Fix tests of mixed precision now that experimental is deprecated (huggingface#17300)

    * Fix tests of mixed precision now that experimental is deprecated

    * Fix mixed precision in training_args_tf.py too

commit 6d21142
Author: SaulLu <[email protected]>
Date:   Tue May 17 14:33:13 2022 +0200

    fix retribert's `test_torch_encode_plus_sent_to_model` (huggingface#17231)
1. **[ConvBERT](https://huggingface.co/docs/transformers/model_doc/convbert)** (from YituTech) released with the paper [ConvBERT: Improving BERT with Span-based Dynamic Convolution](https://arxiv.org/abs/2008.02496) by Zihang Jiang, Weihao Yu, Daquan Zhou, Yunpeng Chen, Jiashi Feng, Shuicheng Yan.
1. **[CPM](https://huggingface.co/docs/transformers/model_doc/cpm)** (from Tsinghua University) released with the paper [CPM: A Large-scale Generative Chinese Pre-trained Language Model](https://arxiv.org/abs/2012.00413) by Zhengyan Zhang, Xu Han, Hao Zhou, Pei Ke, Yuxian Gu, Deming Ye, Yujia Qin, Yusheng Su, Haozhe Ji, Jian Guan, Fanchao Qi, Xiaozhi Wang, Yanan Zheng, Guoyang Zeng, Huanqi Cao, Shengqi Chen, Daixuan Li, Zhenbo Sun, Zhiyuan Liu, Minlie Huang, Wentao Han, Jie Tang, Juanzi Li, Xiaoyan Zhu, Maosong Sun.
1. **[CTRL](https://huggingface.co/docs/transformers/model_doc/ctrl)** (from Salesforce) released with the paper [CTRL: A Conditional Transformer Language Model for Controllable Generation](https://arxiv.org/abs/1909.05858) by Nitish Shirish Keskar*, Bryan McCann*, Lav R. Varshney, Caiming Xiong and Richard Socher.
1. **[CvT](https://huggingface.co/docs/transformers/main/model_doc/cvt)** (from Microsoft) released with the paper [CvT: Introducing Convolutions to Vision Transformers](https://arxiv.org/abs/1909.05858) by Nitish Shirish Keskar*, Bryan McCann*, Lav R. Varshney, Caiming Xiong and Richard Socher.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should not be here - maybe a problem with git pull or git merge?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes most probably, I merged with main, which gave me issues with the torch tests. Could not rebase here

""",
OPT_START_DOCSTRING,
)
class TFOPTForCausalLM(TFOPTPreTrainedModel, TFCausalLanguageModelingLoss):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please double check that the loss is correct here - see: #17237

@patrickvonplaten
Copy link
Contributor

@ArthurZucker,

Do you think we could fix the PR (I think the PR history is a bit messed up). Also totally fine to close this PR and just open a new PR (move all the relevant files to a new PR) if the git correction is too difficult

@ArthurZucker
Copy link
Collaborator Author

ArthurZucker commented May 23, 2022

@ArthurZucker,

Do you think we could fix the PR (I think the PR history is a bit messed up). Also totally fine to close this PR and just open a new PR (move all the relevant files to a new PR) if the git correction is too difficult

Hey, I think we can close it.
Will create a new clean branch

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants