Merged

Commits: 846 (changes shown from 250 commits)
8e28ef9
Fix doc examples: cannot import name (#14698)
ydshieh Dec 13, 2021
577febd
Fix: change tooslow to slow (#14734)
ydshieh Dec 13, 2021
87b6ac9
Small fixes for the doc (#14751)
sgugger Dec 13, 2021
0f8502c
Update transformers metadata (#14724)
sgugger Dec 13, 2021
a942f6e
Fix name
sgugger Dec 13, 2021
bf6901e
Mention no images added to repository (#14738)
LysandreJik Dec 13, 2021
7166ac3
Avoid using tf.tile in embeddings for TF models (#14735)
ydshieh Dec 13, 2021
ea5379b
Change how to load config of XLNetLMHeadModel (#14746)
josutk Dec 13, 2021
9c53194
Improve perceiver (#14750)
NielsRogge Dec 13, 2021
9f91c60
Convert Trainer doc page to MarkDown (#14753)
sgugger Dec 13, 2021
5ae45e9
Update Table of Contents (#14755)
sgugger Dec 13, 2021
7c52751
Fixing tests for Perceiver (#14739)
Narsil Dec 14, 2021
95233cf
Make data shuffling in `run_clm_flax.py` respect global seed (#13410)
bminixhofer Dec 14, 2021
2cdcb95
Adding support for multiple mask tokens. (#14716)
Narsil Dec 14, 2021
1c39882
Fix broken links to distillation on index page of documentation (#14722)
amitness Dec 15, 2021
dd6a480
[doc] performance: groups of operations by compute-intensity (#14757)
stas00 Dec 15, 2021
4a07840
Fix the doc_build_test job (#14774)
sgugger Dec 15, 2021
0d557f0
Fix preprocess_function in run_summarization_flax.py (#14769)
ydshieh Dec 15, 2021
56445bd
Update t5.rst (#14776)
Dec 15, 2021
d06ca88
TF model cards (#14720)
Rocketknight1 Dec 15, 2021
17c0422
Update Perceiver code examples (#14783)
NielsRogge Dec 15, 2021
861e7a6
Improve Perceiver docs (#14786)
NielsRogge Dec 15, 2021
80b91bb
Release: v4.14.0
LysandreJik Dec 15, 2021
995b2c4
Docs for v4.14.0
LysandreJik Dec 15, 2021
648c3cd
Move import (#14787)
sgugger Dec 15, 2021
091616d
PoC for conserving old links (#14754)
sgugger Dec 15, 2021
7568bfb
Removes images to put them in a dataset (#14781)
LysandreJik Dec 16, 2021
ce7480a
Post sphinx-clean up and contributing guide updates (#14790)
sgugger Dec 16, 2021
6d1c080
Fix the build documentation job (#14788)
sgugger Dec 16, 2021
6399167
Update CONTRIBUTING.md (#14799)
kamalkraj Dec 16, 2021
482622f
Update CONTRIBUTING.md (#14800)
kamalkraj Dec 16, 2021
e33c504
Train step fix (#14796)
Rocketknight1 Dec 16, 2021
9aa9e8f
Add Speaker Diarization and Verification heads (#14723)
anton-l Dec 16, 2021
48870ec
[Generate] Make generate multi-modal (#14784)
patrickvonplaten Dec 16, 2021
83c7acb
Add WavLM (#14354)
patrickvonplaten Dec 16, 2021
cd30477
Remove datasets requirement (#14795)
LysandreJik Dec 16, 2021
c788e32
[WavLM] Correct position bias computation (#14805)
patrickvonplaten Dec 16, 2021
8e3e4d7
Add test (#14810)
NielsRogge Dec 17, 2021
e007aae
[WavLM] Layerdrop is not allowed for first layer (#14811)
patrickvonplaten Dec 17, 2021
2716d63
[Generate] Correct input_ids detection (#14815)
patrickvonplaten Dec 17, 2021
5c57df0
Implement head_mask for Flax BERT and other models copied from BERT (…
stancld Dec 17, 2021
afb1b3e
Convert rst to mdx bert (#14806)
LysandreJik Dec 17, 2021
c627eca
Wav2Vec2 meets phonemes (#14353)
patrickvonplaten Dec 17, 2021
2ef85df
[ImageGPT] Deprecate pixel_values input name to input_ids (#14801)
patrickvonplaten Dec 17, 2021
be5e54b
[Seq2SeqTrainer] Remove model input name hack (#14802)
patrickvonplaten Dec 20, 2021
20c43f2
up (#14829)
patrickvonplaten Dec 20, 2021
36d9e3a
[WavLM] Fix slow tests (#14845)
patrickvonplaten Dec 20, 2021
a348504
Add SD and SV heads for WavLM (#14847)
anton-l Dec 20, 2021
071b0be
Add an argument to set bucket_cap_mb for PyTorch DDP (#14756)
changlan Dec 20, 2021
f67c476
Update CONTRIBUTING.md (#14835)
kamalkraj Dec 20, 2021
cbc7880
Fix dead link to benchmarks.ipynb (#14842)
DerekChia Dec 20, 2021
3a12d24
[Perceiver] Skip multi-gpu tests for now (#14813)
patrickvonplaten Dec 20, 2021
458f428
Add 'with torch.no_grad()' to integration test forward pass (#14821)
h-holm Dec 20, 2021
d05fd23
Add 'with torch.no_grad()' to integration test forward pass (#14820)
h-holm Dec 20, 2021
78436cb
Add a main_input_name attribute to all models (#14803)
sgugger Dec 20, 2021
a388707
[doc] typo (#14849)
stas00 Dec 20, 2021
c60ee09
[logging] implement warning_advice / TRANSFORMERS_NO_ADVISORY_WARNING…
stas00 Dec 21, 2021
575dcd0
Make the onnx submodule init lazy (#14855)
sgugger Dec 21, 2021
507a2ca
Convert docstrings of modeling files (#14850)
sgugger Dec 21, 2021
088a610
[Bart] better error message (#14854)
patrickvonplaten Dec 21, 2021
a250e75
Only create the model card on process 0 (#14857)
sgugger Dec 21, 2021
72c0e5a
[ASR example] Improve example + add more examples (#14848)
patrickvonplaten Dec 21, 2021
02bafb3
Fix the value error typo of AdamW's betas' valid values checking (#14…
dourgey Dec 21, 2021
726cecd
Add custom `stopping_criteria` and `logits_processor` to `generate` …
lvwerra Dec 21, 2021
0dbf139
Replace commit sha by commit url for update jobs (#14852)
sgugger Dec 21, 2021
852884c
[examples/summarization] deal with None in data records (#14816)
stas00 Dec 21, 2021
336fc02
[doc porting] several docs (#14858)
stas00 Dec 21, 2021
19048ed
Mass conversion of documentation from rst to Markdown (#14866)
sgugger Dec 21, 2021
b14a59f
Skip failing test
sgugger Dec 21, 2021
82ab4bf
Fix FLAX_MULTIPLE_CHOICE_SAMPLE typo (#14871)
mishig25 Dec 21, 2021
e8d789f
Fixes in marian doc (#14872)
sgugger Dec 21, 2021
114a940
Fix `FlaxMarianMTModel` return block. (#14873)
sgugger Dec 21, 2021
d6ab747
Fix doc mistakes (#14874)
sgugger Dec 21, 2021
2279d46
Convert model files from rst to mdx (#14865)
LysandreJik Dec 22, 2021
244aaaf
update the arguments `add_prefix_space` and `trim_offsets` in `backen…
SaulLu Dec 22, 2021
0eb6d69
Feature/fix slow test in mluke (#14749)
ryokan0123 Dec 22, 2021
9c90171
Updated deberta attention (#14625)
guillaume-be Dec 22, 2021
9fd3d98
IterableDatasetShard should use per device batch size instead of real…
SysuCharon Dec 22, 2021
4142bac
Fix typo in error message
sgugger Dec 22, 2021
7d10dbf
Fix Perceiver docs (#14879)
NielsRogge Dec 22, 2021
8f3f725
Fix pytorch image classification example (#14883)
mariosasko Dec 22, 2021
6a9f474
Onnx enable tasks for supported models (part 2) (#14700)
michaelbenayoun Dec 22, 2021
9c2a1ac
Properly indent return block (#14887)
sgugger Dec 22, 2021
61c18f9
Release: v4.15.0
patrickvonplaten Dec 22, 2021
3befc00
Docs for v4.16.0dev0
patrickvonplaten Dec 22, 2021
41f987c
Keras metric callback (#14867)
Rocketknight1 Dec 22, 2021
bf81baf
Convert rst files (#14888)
sgugger Dec 22, 2021
353915d
Fix installation instructions for BART ONNX example (#14885)
lewtun Dec 23, 2021
9a738d6
Fix doc examples: ... takes no keyword arguments (#14701)
ydshieh Dec 23, 2021
edcefc9
Fix AttributeError from PreTrainedTokenizerFast.decoder (#14691)
aphedges Dec 23, 2021
a3e6e65
Add 'with torch.no_grad()' to integration test forward pass (#14808)
h-holm Dec 23, 2021
bac4566
Add ONNX support for MarianMT models (#14586)
lewtun Dec 23, 2021
38f74d9
add custom stopping criteria to human eval script (#14897)
lvwerra Dec 23, 2021
3cb0ffc
Set `run_name` in MLflowCallback (#14894)
yangdong02 Dec 23, 2021
caa166d
Add TFCLIPModel (#13967)
ydshieh Dec 23, 2021
f0a2374
[AutoTokenizer] Fix incorrect from pretrained (#14900)
patrickvonplaten Dec 23, 2021
d93bf3f
Update diarization and WavLM tolerances (#14902)
anton-l Dec 23, 2021
eed928c
[doc] post-porting (#14890)
stas00 Dec 23, 2021
226576c
[Generate] Remove attention_mask and integrate model_main_input_name …
patrickvonplaten Dec 23, 2021
2f4d7b9
Fix failing GPU trainer tests (#14903)
sgugger Dec 23, 2021
bf3b4e3
Better logic for getting tokenizer config in AutoTokenizer (#14906)
sgugger Dec 23, 2021
9a6f8a0
[doc] install - add jax (#14912)
stas00 Dec 23, 2021
a3c9be1
[WavLM] fix wavlm docs (#14910)
patrickvonplaten Dec 23, 2021
a3689dd
Fix Perceiver docs (#14917)
Sanster Dec 24, 2021
85227a8
ChunkPipeline (batch_size enabled on `zero-cls` and `qa` pipelines. (…
Narsil Dec 27, 2021
51b5f5a
Add `ElectraForCausalLM` -> Enable Electra encoder-decoder model (#14…
stancld Dec 27, 2021
a4b5219
fix to issue #14833 in data_collator - consider no labels (#14930)
Dec 27, 2021
26af612
Fix duplicate call to save_checkpoint when using deepspeed (#14946)
MihaiBalint Dec 27, 2021
120ed8c
Doc styler v2 (#14950)
sgugger Dec 27, 2021
17073ba
Convert last rst file (#14952)
sgugger Dec 27, 2021
545663b
[doc] consistent True/False/None default format (#14951)
stas00 Dec 27, 2021
85212f0
[doc] :obj: hunt (#14954)
stas00 Dec 27, 2021
eeca7e4
Doc styler examples (#14953)
sgugger Dec 28, 2021
d8f0486
Style
sgugger Dec 28, 2021
86c9852
[doc] :class: hunt (#14955)
stas00 Dec 28, 2021
b1933d5
Add Speech Seq2Seq Training script (#14792)
patrickvonplaten Dec 28, 2021
f24b430
[WavLM] give model for precision (#14958)
patrickvonplaten Dec 28, 2021
52c780d
Update README.md (#14965)
patrickvonplaten Dec 28, 2021
59ba84c
[Tests] Speed up tokenizer tests (#14964)
patrickvonplaten Dec 28, 2021
64fbbe0
[Wav2Vec2] Rename model's feature extractor to feature encoder (#14959)
patrickvonplaten Dec 28, 2021
8ef1f8e
refactor: replace `assert` with `ValueError` (#14970)
jaketae Dec 29, 2021
11fd216
remove absl workaround as it's no longer needed (#14909)
stas00 Dec 29, 2021
0f64f4d
Fixing a pathological case for slow tokenizers (#14981)
Narsil Dec 30, 2021
85a5772
[AutoProcessor] Correct AutoProcessor and automatically add processor…
patrickvonplaten Dec 30, 2021
c2873d5
[Generate] correct encoder_outputs are passed without attention_mask …
patrickvonplaten Dec 30, 2021
debac54
Adding `num_return_sequences` support for text2text generation. (#14988)
Narsil Dec 30, 2021
118cc26
Enabling `tokenizers` upgrade. (#14941)
Narsil Dec 30, 2021
e8bcddf
Allow training to resume even if RNG states are not properly loaded (…
sgugger Dec 30, 2021
b53c3c8
Map model_type and doc pages names (#14944)
sgugger Jan 3, 2022
3e75cdd
Fixing t2t pipelines lists outputs. (#15008)
Narsil Jan 3, 2022
9455666
Improve truncation_side (#14947)
Narsil Jan 3, 2022
297abeb
Large audio chunking for the existing ASR pipeline (#14896)
anton-l Jan 3, 2022
30bb7cc
fix missing import (#15016)
ydshieh Jan 3, 2022
e7bbc82
[Tests] Correct Wav2Vec2 & WavLM tests (#15015)
patrickvonplaten Jan 3, 2022
9323670
Update parallelism.mdx (#15013)
hyunwoongko Jan 3, 2022
4296514
Fix Code block (#14983)
flozi00 Jan 4, 2022
22aa017
Fix a little typo (#15002)
milyiyo Jan 4, 2022
5976de7
Add Flax RoFormer (#15005)
stancld Jan 4, 2022
6795780
Hotfix `chunk_length_s` instead of `_ms`. (#15029)
Narsil Jan 4, 2022
635a3bd
[doc] Update parallelism.mdx (#15018)
hyunwoongko Jan 4, 2022
40081ba
[megatron convert] PYTHONPATH requirements (#14956)
stas00 Jan 5, 2022
38f089f
Fix doc example: mask_time_indices (numpy) has no attribute 'to' (#15…
ydshieh Jan 5, 2022
0c17cea
Adding QoL for `batch_size` arg (like others enabled everywhere). (#1…
Narsil Jan 5, 2022
e5902d2
[CLIP] Fix PT test (#15041)
patrickvonplaten Jan 5, 2022
0c98bdf
[SpeechEncoderDecoder] Fix from pretrained (#15043)
patrickvonplaten Jan 5, 2022
59db73c
[CLIP] Fix TF test (#15042)
patil-suraj Jan 5, 2022
dd859ae
Add Flax image captioning example (#14864)
ydshieh Jan 6, 2022
f4339ee
Enabling `TF` on `image-classification` pipeline. (#15030)
Narsil Jan 6, 2022
1e4c3f8
wrapped forward passes in torch.no_grad() (#15037)
mattchurgin Jan 6, 2022
f172978
Add detectron2 to Github actions (#15053)
NielsRogge Jan 6, 2022
6a4d553
Remove old asserts. (#15012)
Narsil Jan 6, 2022
39b4223
Add 'with torch.no_grad()' to BertGeneration integration test forward…
itsTurner Jan 6, 2022
3228914
Update run_speech_recognition_seq2seq.py (#14967)
flozi00 Jan 6, 2022
36d0eb4
[VisionTextDualEncoder] Fix doc example
ydshieh Jan 6, 2022
cde0a3e
Resubmit changes after rebase to master (#14982)
kct22aws Jan 7, 2022
8a31277
[Fix doc examples] Add missing from_pretrained (#15044)
ydshieh Jan 7, 2022
4a9de35
[VisionTextDualEncoder] Add token_type_ids param (#15073)
ydshieh Jan 7, 2022
3e0ca20
Fix convert for newer megatron-lm bert model (#14082)
yoquankara Jan 8, 2022
1747d36
[Wav2Vec2 Speech Event] Add speech event v2 (#15083)
patrickvonplaten Jan 10, 2022
ddd5ca7
fix model table cell text alignment (#14999)
ydshieh Jan 10, 2022
b280c18
Update check_repo.py (#15014)
kamalkraj Jan 10, 2022
e7761c9
Make OpenAIGPTTokenizer work with SpaCy 2.x and 3.x (#15019)
cody-moveworks Jan 10, 2022
763b933
Change assignee for tokenizers (#15088)
LysandreJik Jan 10, 2022
2608566
support the trocr small models (#14893)
liminghao1630 Jan 10, 2022
7fa9cd2
fix doc example - AttributeError: type object 'RagModel' has no attri…
ydshieh Jan 10, 2022
bbdebd7
Fix style
sgugger Jan 10, 2022
b3f48cb
Model summary horizontal banners (#15058)
mishig25 Jan 10, 2022
03d4b51
Use tqdm.auto in Pipeline docs (#14920)
bryant1410 Jan 10, 2022
21ff6a0
[doc] normalize HF Transformers string (#15023)
stas00 Jan 10, 2022
9c405e9
Happy New Year! (#15094)
sgugger Jan 10, 2022
eed0e28
[DOC] fix doc examples for bart-like models (#15093)
patil-suraj Jan 10, 2022
eae526d
[performance doc] Power and Cooling (#14935)
stas00 Jan 10, 2022
1a030d1
Add TFVisionEncoderDecoderModel (#14148)
ydshieh Jan 10, 2022
12dcb3f
Add test to check reported training loss (#15096)
sgugger Jan 11, 2022
ee2d797
Take gradient accumulation into account when defining samplers (#15095)
sgugger Jan 11, 2022
dafa760
fix doc example - TypeError: forward() got an unexpected keyword argu…
ydshieh Jan 11, 2022
7e507c7
Fix cookiecutter (#15100)
NielsRogge Jan 11, 2022
1cb5545
[Wav2Vec2ProcessorWithLM] improve decoder download (#15040)
patrickvonplaten Jan 11, 2022
c8255a6
Adds IBERT to models exportable with ONNX (#14868)
MaximovaIrina Jan 11, 2022
73ac1bf
change metric_key_prefix in seq2seq_trainer.py (#15099)
JejuWayfarer Jan 11, 2022
fa58538
Print out durations of all scheduled tests (#15102)
LysandreJik Jan 11, 2022
03d769f
Add Nystromformer (#14659)
novice03 Jan 11, 2022
da5865e
Fix failing test (#15104)
LysandreJik Jan 11, 2022
5f6ad60
add spaces badges
AK391 Jan 4, 2022
71d6dc0
Transformer-XL badge
AK391 Jan 6, 2022
7965998
Reformer Spaces badge
AK391 Jan 6, 2022
aaf11f9
XLNet spaces badge
AK391 Jan 6, 2022
847fc2a
BERT spaces badge
AK391 Jan 6, 2022
448dc85
ALBERT spaces badge
AK391 Jan 6, 2022
29259ad
Roberta spaces badge
AK391 Jan 6, 2022
457ed1c
Distilbert spaces badge
AK391 Jan 7, 2022
18ceeb6
ConvBERT spaces badge
AK391 Jan 10, 2022
f17631a
XLM Spaces badge
AK391 Jan 10, 2022
9da68a0
XLM-Roberta Spaces badge
AK391 Jan 10, 2022
2fd3803
FlauBERT spaces badge
AK391 Jan 10, 2022
a691b79
ELECTRA Spaces badge
AK391 Jan 10, 2022
4c2e966
Funnel Transformer spaces badge
AK391 Jan 10, 2022
673b626
Longformer Spaces badge
AK391 Jan 10, 2022
0b52bbf
BART Spaces badge
AK391 Jan 10, 2022
022cb07
Pegasus Spaces badge
AK391 Jan 10, 2022
670456f
MarianMT Spaces badge
AK391 Jan 10, 2022
15fc0bd
T5 Spaces badge
AK391 Jan 10, 2022
1c51a90
MT5 Spaces badge
AK391 Jan 10, 2022
bbf6d9d
MBART spaces badge
AK391 Jan 10, 2022
430859f
ProphetNet spaces badge
AK391 Jan 10, 2022
228cee9
DPR Spaces badge
AK391 Jan 10, 2022
33e9cdd
XLM-ProphetNet Spaces badge
AK391 Jan 11, 2022
3a585ef
Doc styler tip (#15105)
sgugger Jan 11, 2022
822a090
Update ONNX docs (#14904)
lewtun Jan 11, 2022
35d57e8
Fix saving FlaubertTokenizer configs (#14991)
vmaryasin Jan 11, 2022
c6c0d64
Update TF test_step to match train_step (#15111)
Rocketknight1 Jan 11, 2022
4193e90
Fix typo in doc template
sgugger Jan 11, 2022
374e00f
Pipeline ASR with LM. (#15071)
Narsil Jan 12, 2022
ce27774
use block_size instead of max_seq_length in tf run_clm example (#15036)
riklopfer Jan 12, 2022
ffb9284
fix: switch from slow to generic tokenizer class (#15122)
lvwerra Jan 12, 2022
f1dc7e0
Fix #14357 (#15001)
ydshieh Jan 12, 2022
6ab78a5
Fix link to deepspeed config
sgugger Jan 12, 2022
da7e8e5
Add ONNX configuration classes to docs (#15121)
lewtun Jan 12, 2022
c58eb18
Add `with torch.no_grad()` to DistilBERT integration test forward pas…
jaketae Jan 12, 2022
b2266db
mBART support for run_summarization.py (#15125)
banda-larga Jan 12, 2022
8b3353d
doc-builder -> doc-build (#15134)
LysandreJik Jan 13, 2022
b3184b0
Update src/transformers/onnx/convert.py
Albertobegue Jan 27, 2022
6f23e41
Merge branch 'onnx-conversion-with-tensorflow' of https://github.com/…
Jan 27, 2022
a65c8bd
Integrate tensorflow onnx conversion with the changes in master
Jan 27, 2022
04f25a0
Change documentation to mention conversion with Tensorflow
Jan 27, 2022
9bab135
Define global variables even if tf or pt not available
Albertobegue Jan 28, 2022
0cb896e
Merge branch 'master' into onnx-conversion-with-tensorflow
lewtun Feb 7, 2022
92f86a8
Fix setup.py
lewtun Feb 7, 2022
87859df
Fix check_table.py
lewtun Feb 7, 2022
0cd8783
Restore files to master
lewtun Feb 7, 2022
893499b
Reinstate torch-nightly fix
lewtun Feb 7, 2022
4bbb806
Remove keras2onnx dependency as project is archived
lewtun Feb 7, 2022
2821592
Refactor to support TensorFlow in FeaturesManager
lewtun Feb 7, 2022
c0a0456
Add seq2seq test for TF export
lewtun Feb 8, 2022
1877023
Fix TF test decorator
lewtun Feb 8, 2022
2315244
Refactor export into export_pytorch and export_tensorflow
Albertobegue Feb 8, 2022
019083c
Merge branch 'onnx-conversion-with-tensorflow' of https://github.com/…
Albertobegue Feb 8, 2022
394a35f
Update src/transformers/onnx/convert.py
Albertobegue Feb 8, 2022
d4ffce7
Add documentation
Albertobegue Feb 8, 2022
79d4945
Merge branch 'onnx-conversion-with-tensorflow' of https://github.com/…
Albertobegue Feb 8, 2022
b1e50f6
Update src/transformers/onnx/convert.py
Albertobegue Feb 8, 2022
79367d3
Update src/transformers/onnx/convert.py
Albertobegue Feb 8, 2022
69b5a23
Update src/transformers/onnx/convert.py
Albertobegue Feb 8, 2022
94107d8
Update src/transformers/onnx/convert.py
Albertobegue Feb 8, 2022
071c21e
Fix style
lewtun Feb 8, 2022
d201dce
Check model's type instead of framework installation to choose betwee…
Albertobegue Feb 8, 2022
b605bb6
Merge branch 'onnx-conversion-with-tensorflow' of https://github.com/…
Albertobegue Feb 8, 2022
b4816f9
Fix test checking pytorch version
Albertobegue Feb 8, 2022
42d0270
Update src/transformers/onnx/convert.py
Albertobegue Feb 10, 2022
fe6c0d0
Use alias of onnx.export
Albertobegue Feb 10, 2022
6 changes: 2 additions & 4 deletions docs/source/serialization.mdx
Expand Up @@ -61,10 +61,6 @@ Ready-made configurations include the following architectures:
- XLM-RoBERTa
- XLM-RoBERTa-XL

The ONNX conversion is supported for the PyTorch versions of the models. If you
would like to be able to convert a TensorFlow model, please let us know by
opening an issue.

In the next two sections, we'll show you how to:

* Export a supported model using the `transformers.onnx` package.
Expand Down Expand Up @@ -149,6 +145,8 @@ DistilBERT we have:
["last_hidden_state"]
```

The approach is similar for TensorFlow models.

### Selecting features for different model topologies

Each ready-made configuration comes with a set of _features_ that enable you to
Expand Down
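
The doc change above says TensorFlow exports now follow the same flow as PyTorch ones. As a concrete illustration, here is a minimal sketch of a programmatic TensorFlow export through the refactored entry point; the checkpoint name, the `DistilBertOnnxConfig` import path, and the presence of `tf2onnx`/`onnx` are assumptions for illustration, not part of this diff.

```python
# Minimal sketch (not part of the PR): export a TensorFlow checkpoint to
# ONNX via the refactored export() entry point from this diff.
# Assumptions: transformers with TF support plus tf2onnx and onnx installed,
# and DistilBertOnnxConfig importable from the distilbert model module.
from pathlib import Path

from transformers import AutoTokenizer, TFAutoModel
from transformers.models.distilbert import DistilBertOnnxConfig
from transformers.onnx import export

checkpoint = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = TFAutoModel.from_pretrained(checkpoint)

onnx_config = DistilBertOnnxConfig(model.config)
onnx_inputs, onnx_outputs = export(
    tokenizer, model, onnx_config, onnx_config.default_onnx_opset, Path("model.onnx")
)
print(onnx_inputs, onnx_outputs)
```

Note that the version of `export()` shown further down in this diff picks the backend by framework availability, so in an environment with both torch and TensorFlow installed the intended dispatch only holds after the later commit "Check model's type instead of framework installation to choose betwee…" (d201dce).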
259 changes: 184 additions & 75 deletions src/transformers/onnx/convert.py
Expand Up @@ -21,7 +21,7 @@
from packaging.version import Version, parse

from transformers import PreTrainedModel, PreTrainedTokenizer, TensorType, TFPreTrainedModel, is_torch_available
from transformers.file_utils import is_torch_onnx_dict_inputs_support_available
from transformers.file_utils import is_tf_available, is_torch_onnx_dict_inputs_support_available
from transformers.onnx.config import OnnxConfig
from transformers.utils import logging

Expand Down Expand Up @@ -62,90 +62,190 @@ def check_onnxruntime_requirements(minimum_version: Version):
)


def export(
tokenizer: PreTrainedTokenizer, model: PreTrainedModel, config: OnnxConfig, opset: int, output: Path
def export_pytorch(
Review comment: Note: we decided to refactor the single export() function into dedicated functions for PyTorch and TensorFlow

Review comment: As long as the method export still exists with the same API, then it's good to me! (And it seems to be the case)

tokenizer: PreTrainedTokenizer,
model: PreTrainedModel,
config: OnnxConfig,
opset: int,
output: Path,
) -> Tuple[List[str], List[str]]:
"""
Export a PyTorch backed pipeline to ONNX Intermediate Representation (IR
Export a PyTorch model to an ONNX Intermediate Representation (IR)

Args:
tokenizer:
model:
config:
opset:
output:
tokenizer ([`PreTrainedTokenizer`]):
The tokenizer used for encoding the data.
model ([`PreTrainedModel`]):
The model to export.
config ([`~onnx.config.OnnxConfig`]):
The ONNX configuration associated with the exported model.
opset (`int`):
The version of the ONNX operator set to use.
output (`Path`):
Directory to store the exported ONNX model.

Returns:
`Tuple[List[str], List[str]]`: A tuple with an ordered list of the model's inputs, and the named inputs from
the ONNX configuration.
"""
if is_torch_available():
from transformers.file_utils import torch_version

if not is_torch_onnx_dict_inputs_support_available():
raise AssertionError(f"Unsupported PyTorch version, minimum required is 1.8.0, got: {torch_version}")

if issubclass(type(model), PreTrainedModel):
import torch
from torch.onnx import export

logger.info(f"Using framework PyTorch: {torch.__version__}")
with torch.no_grad():
model.config.return_dict = True
model.eval()

# Check if we need to override certain configuration item
if config.values_override is not None:
logger.info(f"Overriding {len(config.values_override)} configuration item(s)")
for override_config_key, override_config_value in config.values_override.items():
logger.info(f"\t- {override_config_key} -> {override_config_value}")
setattr(model.config, override_config_key, override_config_value)

# Ensure inputs match
# TODO: Check when exporting QA we provide "is_pair=True"
model_inputs = config.generate_dummy_inputs(tokenizer, framework=TensorType.PYTORCH)
inputs_match, matched_inputs = ensure_model_and_config_inputs_match(model, model_inputs.keys())
onnx_outputs = list(config.outputs.keys())

if not inputs_match:
raise ValueError("Model and config inputs doesn't match")

config.patch_ops()

# PyTorch deprecated the `enable_onnx_checker` and `use_external_data_format` arguments in v1.11,
# so we check the torch version for backwards compatibility
if parse(torch.__version__) <= parse("1.10.99"):
# export can work with named args but the dict containing named args
# has to be the last element of the args tuple.
export(
model,
(model_inputs,),
f=output.as_posix(),
input_names=list(config.inputs.keys()),
output_names=onnx_outputs,
dynamic_axes={name: axes for name, axes in chain(config.inputs.items(), config.outputs.items())},
do_constant_folding=True,
use_external_data_format=config.use_external_data_format(model.num_parameters()),
enable_onnx_checker=True,
opset_version=opset,
)
else:
export(
model,
(model_inputs,),
f=output.as_posix(),
input_names=list(config.inputs.keys()),
output_names=onnx_outputs,
dynamic_axes={name: axes for name, axes in chain(config.inputs.items(), config.outputs.items())},
do_constant_folding=True,
opset_version=opset,
)

config.restore_ops()

return matched_inputs, onnx_outputs
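
The comment in `export_pytorch` about named args deserves a concrete illustration. Below is a self-contained sketch, not from the PR, of the same `torch.onnx.export` pattern: a toy module, the named inputs passed as a dict in the last element of the args tuple, and `dynamic_axes` marking the batch and sequence dimensions as variable (the axis and output names here are illustrative).

```python
# Self-contained sketch (toy module; names are illustrative) of the
# torch.onnx.export pattern used by export_pytorch above.
import torch


class Toy(torch.nn.Module):
    def forward(self, input_ids):
        return input_ids.float().mean(dim=-1)


torch.onnx.export(
    Toy(),
    ({"input_ids": torch.ones(2, 8, dtype=torch.long)},),  # named args as trailing dict
    f="toy.onnx",
    input_names=["input_ids"],
    output_names=["last_hidden_state"],
    dynamic_axes={"input_ids": {0: "batch", 1: "sequence"}, "last_hidden_state": {0: "batch"}},
    do_constant_folding=True,
    opset_version=11,
)
```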


def export_tensorflow(
tokenizer: PreTrainedTokenizer,
model: TFPreTrainedModel,
config: OnnxConfig,
opset: int,
output: Path,
) -> Tuple[List[str], List[str]]:
"""
Export a TensorFlow model to an ONNX Intermediate Representation (IR)

Args:
tokenizer ([`PreTrainedTokenizer`]):
The tokenizer used for encoding the data.
model ([`TFPreTrainedModel`]):
The model to export.
config ([`~onnx.config.OnnxConfig`]):
The ONNX configuration associated with the exported model.
opset (`int`):
The version of the ONNX operator set to use.
output (`Path`):
Directory to store the exported ONNX model.

Returns:
`Tuple[List[str], List[str]]`: A tuple with an ordered list of the model's inputs, and the named inputs from
the ONNX configuration.
"""
import tensorflow as tf

import onnx
import tf2onnx

model.config.return_dict = True

# Check if we need to override certain configuration item
if config.values_override is not None:
logger.info(f"Overriding {len(config.values_override)} configuration item(s)")
for override_config_key, override_config_value in config.values_override.items():
logger.info(f"\t- {override_config_key} -> {override_config_value}")
setattr(model.config, override_config_key, override_config_value)

# Ensure inputs match
model_inputs = config.generate_dummy_inputs(tokenizer, framework=TensorType.TENSORFLOW)
inputs_match, matched_inputs = ensure_model_and_config_inputs_match(model, model_inputs.keys())
onnx_outputs = list(config.outputs.keys())

input_signature = [tf.TensorSpec.from_tensor(tensor, name=key) for key, tensor in model_inputs.items()]
onnx_model, _ = tf2onnx.convert.from_keras(model, input_signature, opset=opset)
onnx.save(onnx_model, output.as_posix())
config.restore_ops()
Review comment: I think ops patching might need some changes to work for both backends, but I guess that can be figured out later.


return matched_inputs, onnx_outputs
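
For readers unfamiliar with tf2onnx, the core of `export_tensorflow` reduces to three calls: build a `tf.TensorSpec` input signature from dummy tensors, convert the Keras model, and save the resulting proto. A stripped-down sketch, with a hypothetical checkpoint and assuming `tf2onnx` and `onnx` are installed:

```python
# Stripped-down sketch of the tf2onnx path wrapped by export_tensorflow.
# Assumptions (not from the PR): checkpoint name; tf2onnx and onnx installed.
import tensorflow as tf

import onnx
import tf2onnx
from transformers import AutoTokenizer, TFAutoModel

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
model = TFAutoModel.from_pretrained("distilbert-base-uncased")

# Dummy inputs stand in for config.generate_dummy_inputs(...)
dummy_inputs = tokenizer("Hello, world!", return_tensors="tf")
input_signature = [tf.TensorSpec.from_tensor(t, name=k) for k, t in dummy_inputs.items()]

onnx_model, _ = tf2onnx.convert.from_keras(model, input_signature, opset=11)
onnx.save(onnx_model, "model.onnx")
```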


def export(
tokenizer: PreTrainedTokenizer,
model: Union[PreTrainedModel, TFPreTrainedModel],
config: OnnxConfig,
opset: int,
output: Path,
) -> Tuple[List[str], List[str]]:
"""
Export a PyTorch or TensorFlow model to an ONNX Intermediate Representation (IR)

Args:
tokenizer ([`PreTrainedTokenizer`]):
The tokenizer used for encoding the data.
model ([`PreTrainedModel`] or [`TFPreTrainedModel`]):
The model to export.
config ([`~onnx.config.OnnxConfig`]):
The ONNX configuration associated with the exported model.
opset (`int`):
The version of the ONNX operator set to use.
output (`Path`):
Directory to store the exported ONNX model.

Returns:
`Tuple[List[str], List[str]]`: A tuple with an ordered list of the model's inputs, and the named inputs from
the ONNX configuration.
"""
if not (is_torch_available() or is_tf_available()):
raise ImportError(
"Cannot convert because neither PyTorch nor TensorFlow are not installed. "
"Please install torch or tensorflow first."
)

if is_torch_available():
return export_pytorch(tokenizer, model, config, opset, output)
elif is_tf_available():
return export_tensorflow(tokenizer, model, config, opset, output)
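
As written here, `export()` routes to `export_pytorch` whenever torch is importable, even if the caller passed a TF model. A later commit in this PR ("Check model's type instead of framework installation to choose betwee…", d201dce) moves the decision onto the model's class; a minimal sketch of that idea, assuming the names already imported in this module:

```python
# Minimal sketch (assumption: mirrors the later commit d201dce) of
# dispatching on the model class rather than on framework availability.
# Relies on the names already imported at the top of convert.py.
def export_by_model_type(tokenizer, model, config, opset, output):
    if is_torch_available() and issubclass(type(model), PreTrainedModel):
        return export_pytorch(tokenizer, model, config, opset, output)
    elif is_tf_available() and issubclass(type(model), TFPreTrainedModel):
        return export_tensorflow(tokenizer, model, config, opset, output)
    raise ValueError("The model type does not match an available backend.")
```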


def validate_model_outputs(
config: OnnxConfig,
tokenizer: PreTrainedTokenizer,
Expand All @@ -160,7 +260,10 @@ def validate_model_outputs(

# TODO: generate inputs with a different batch_size and seq_len than was used for conversion to properly test
# dynamic input shapes.
reference_model_inputs = config.generate_dummy_inputs(tokenizer, framework=TensorType.PYTORCH)
if issubclass(type(reference_model), PreTrainedModel):
reference_model_inputs = config.generate_dummy_inputs(tokenizer, framework=TensorType.PYTORCH)
else:
reference_model_inputs = config.generate_dummy_inputs(tokenizer, framework=TensorType.TENSORFLOW)

# Create ONNX Runtime session
options = SessionOptions()
Expand Down Expand Up @@ -210,7 +313,10 @@ def validate_model_outputs(

# Check the shape and values match
for name, ort_value in zip(onnx_named_outputs, onnx_outputs):
ref_value = ref_outputs_dict[name].detach().numpy()
if issubclass(type(reference_model), PreTrainedModel):
ref_value = ref_outputs_dict[name].detach().numpy()
else:
ref_value = ref_outputs_dict[name].numpy()
logger.info(f'\t- Validating ONNX Model output "{name}":')

# Shape
Expand Down Expand Up @@ -241,7 +347,10 @@ def ensure_model_and_config_inputs_match(

:param model_inputs: :param config_inputs: :return:
"""
forward_parameters = signature(model.forward).parameters
if issubclass(type(model), PreTrainedModel):
forward_parameters = signature(model.forward).parameters
else:
forward_parameters = signature(model.call).parameters
model_inputs_set = set(model_inputs)

# We are fine if config_inputs has more keys than model_inputs
Expand Down
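
`ensure_model_and_config_inputs_match` leans on Python introspection: `inspect.signature` lists the parameter names of `forward` (PyTorch) or `call` (TensorFlow), so the config-declared inputs can be checked against what the model actually accepts. A self-contained toy example (hypothetical `ToyModel`, not from the PR):

```python
# Toy illustration of the signature introspection used above.
from inspect import signature


class ToyModel:
    def forward(self, input_ids, attention_mask=None):
        return input_ids


forward_parameters = signature(ToyModel().forward).parameters
model_inputs = {"input_ids"}

print(set(model_inputs).issubset(forward_parameters))        # True
print([p for p in forward_parameters if p in model_inputs])  # ['input_ids']
```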