Skip to content
Closed
Show file tree
Hide file tree
Changes from 20 commits
Commits
Show all changes
277 commits
Select commit Hold shift + click to select a range
2f66a42
[WIP] Add BridgeTowerForContrastiveLearning (#21964)
abhiwand Mar 8, 2023
fe03e51
Fix test for torchneuroncore in Trainer (#22028)
sgugger Mar 8, 2023
781edc5
Add tokenize_kwargs parameter definition in the FeatureExtractionPipe…
anruijian Mar 8, 2023
57b3a97
[examples/speech-recognition] Add SpecAugment to run_speech_recogniti…
bofenghuang Mar 8, 2023
9a10d4b
fixes the gradient checkpointing of whisper (#22019)
soma2000-lang Mar 8, 2023
65f56ec
Avoid `text_config_dict` and `vision_config_dict` being saved for CL…
ydshieh Mar 8, 2023
3f29db2
Mark all `BridgeTower` tests slow for now (#22039)
ydshieh Mar 8, 2023
2a5f185
Bug fix: token classification pipeline while passing offset_mapping (…
cceyda Mar 8, 2023
f89b95c
Update ALIGN docs (#22025)
alaradirik Mar 9, 2023
0d08783
[21737][T5]: Fix gradient checkpoint bug (#22036)
nipunjindal Mar 9, 2023
67bd8ee
Docs Improvement - In ZSH, not using ' ' around pip install fails, fi…
shaun-scale Mar 9, 2023
c2420fd
Can't install tf2 on M1 Chip by default (#22046)
shaun-scale Mar 9, 2023
da98339
Remove set_access_token usage + fail tests if FutureWarning (#22051)
Wauplin Mar 9, 2023
6fd5f2f
Show the number of `huggingface_hub` warnings in CI report (#22054)
ydshieh Mar 9, 2023
d95d2d9
Return analysis for hyperparameter_search with Ray backend (#22040)
anruijian Mar 9, 2023
a7dacfb
pt-to-tf model architecture override (#22055)
Rocketknight1 Mar 9, 2023
c8437ea
rm $ symbol from code block from contributing.md (#22057)
kamalkraj Mar 9, 2023
09e9344
[deepspeed] offload + non-cpuadam optimizer exception (#22043)
stas00 Mar 9, 2023
e39722e
Edit the docstring of `image_processing_donut` to match code (#22033)
vermouthmjl Mar 9, 2023
bf97e22
Skip 3 tests for `WhisperEncoderModelTest` (#22060)
ydshieh Mar 9, 2023
4405fa4
Add setters by type of args to TrainingArguments (#21570)
sgugger Mar 9, 2023
a4e1c7b
Update tiny model creation script (#22058)
ydshieh Mar 9, 2023
777f921
Fix case when using --gradient_accumulation_steps with DDP disabled. …
sangeethabal Mar 9, 2023
1e9ccb6
Add a progress bar for the total download of shards (#22062)
sgugger Mar 9, 2023
b84287e
Fix gradient checkpointing bug in Speech2Text (#22079)
KMFODA Mar 10, 2023
ae930f0
Fix gradient checkpointing bug in switch transformer (#22081)
KMFODA Mar 10, 2023
3d42b72
[GPT2] Propose fix for #21080 (#21853)
ArthurZucker Mar 10, 2023
13289d4
Fix small typo in flan-ul2.mdx (#22068)
kevin51jiang Mar 10, 2023
c78d87a
Generate - Fix broken documentation links (#22078)
gante Mar 10, 2023
0993014
Fix gradient checkpointing bug in Speecht5 (#22080)
KMFODA Mar 10, 2023
9199bf7
Fix hint in src/transformers/modeling_utils.py (#22074)
J-shang Mar 10, 2023
6a42232
handle numpy inputs in whole word mask data collator (#22032)
dwyatte Mar 10, 2023
18fa8e9
GPT-J specific half precision on CPU note (#22086)
MKhalusova Mar 10, 2023
9fd5b5e
Fix imports of TF MobileViT (#22065)
sgugger Mar 10, 2023
ffddc14
Revert "[GPT2] Propose fix for #21080" (#22093)
ydshieh Mar 10, 2023
30ccdfa
[Whisper] Remove embed_tokens from encoder docstring (#21996)
sanchit-gandhi Mar 11, 2023
733b396
Add AutoModelForZeroShotImageClassification (#22087)
alaradirik Mar 13, 2023
010d238
add new model of MGP-STR (#21418)
wdp-007 Mar 13, 2023
c211bd5
Add pr_checks.mdx Italian translation (#17459) (#22116)
alexcalabrese Mar 13, 2023
6048e88
Fix gradient checkpointing bug in xglm (#22127)
KMFODA Mar 13, 2023
27ca366
Fix gradient checkpointing bug in Trajectory Transformer (#22125)
KMFODA Mar 13, 2023
3e32345
Fix gradient checkpointing bug in xlm_roberta_xl (#22128)
KMFODA Mar 13, 2023
9c13edb
Added big_models.mdx italian translation #17600 (#22115)
nickprock Mar 13, 2023
718ec35
[`Blip2`] skip accelerate test (#22124)
younesbelkada Mar 13, 2023
d56076d
Fix gradient checkpointing bug in xmod (#22129)
KMFODA Mar 13, 2023
b22b7b1
Fix gradient checkpointing bug in LongT5 (#22130)
KMFODA Mar 13, 2023
c349eda
Fix gradient checkpointing bug in trocr (#22126)
KMFODA Mar 13, 2023
165bd4a
Zero-shot image classification task guide (#22132)
MKhalusova Mar 13, 2023
798d110
Fix doc link for MGP-STR (#22138)
sgugger Mar 13, 2023
729251a
Adding Type Hints to TF_Pegasus model (#21941)
pmollerus23 Mar 13, 2023
d56ad12
Add a new script to check model testers' config (#22063)
ydshieh Mar 13, 2023
c53441c
Update configuration_align.py (projected_dim=640) (#22139)
bishmdl76 Mar 13, 2023
f3067cc
[`Whiper`] add `get_input_embeddings` to `WhisperForAudioClassificati…
younesbelkada Mar 13, 2023
e6bbfa8
Trainer: let generate pick its inputs (#22108)
gante Mar 13, 2023
382b5ec
Enforce same behavior as PyTorch 2.0 for older versions (#22136)
sgugger Mar 13, 2023
2e0898b
[trainer] fix bug in grad accum with multiple epochs (#22098)
stas00 Mar 13, 2023
a3cb682
[deepspeed docs] Activation Checkpointing (#22099)
stas00 Mar 13, 2023
5a1f4f8
Remove backend check for torch.compile (#22140)
sgugger Mar 13, 2023
7703f02
[Safetensors] Add explicit flag to from pretrained (#22083)
patrickvonplaten Mar 13, 2023
4c876a3
Prepare daily CI for torch 2.0.0 (#22135)
ydshieh Mar 13, 2023
f3d55b5
docs: New terms and updates to glossary (#21982)
MichaelRipa Mar 13, 2023
5c055e4
[🛠️] Fix-whisper-breaking-changes (#21965)
ArthurZucker Mar 14, 2023
cc22fe6
Move `is_pipeline_test_to_skip` to specific model test classes (#21999)
ydshieh Mar 14, 2023
86b33f3
Add ConvNeXT V2 (#21679)
alaradirik Mar 14, 2023
6d0b040
Update 2 doctest expected values for torch 2.0.0 (#22148)
ydshieh Mar 14, 2023
4501ebe
Translation Italian: perf_train_cpu and perf_train_cpu_many (#22151)
nickprock Mar 14, 2023
4f9f190
Fix big model inference for T5 models in float16 (#22095)
sgugger Mar 14, 2023
2e5b600
Create MaskedImageCompletionOutput and fix ViT docs (#22152)
alaradirik Mar 14, 2023
730bf0c
to_pil - don't rescale if int and in range 0-255 (#22158)
amyeroberts Mar 14, 2023
9515690
[trainer] add `--optim adamw_torch_fused` for pt-2.0+ (#22144)
stas00 Mar 14, 2023
a115f40
Revert "Enforce same behavior as PyTorch 2.0 for older versions" (#22…
sgugger Mar 14, 2023
7989275
v4.28.0.dev0
sgugger Mar 14, 2023
741ec6b
Load optimizer state on CPU to avoid CUDA OOM (#22159)
sgugger Mar 14, 2023
350740e
Run all tests by default (#22162)
sgugger Mar 14, 2023
81b794f
Fix: unfinished_sequences with correct device (#22184)
Stxr Mar 15, 2023
3704f82
Revert 22152 MaskedImageCompletionOutput changes (#22187)
amyeroberts Mar 15, 2023
f4c7551
Regression pipeline device (#22190)
sgugger Mar 15, 2023
1222535
Update BridgeTowerForContrastiveLearning (#22145)
abhiwand Mar 15, 2023
860a3cd
t5 remove data dependency (#22097)
prathikr Mar 15, 2023
cd0d499
Fix DeepSpeed CI (#22194)
ydshieh Mar 16, 2023
6ac5880
Fix typo in Align docs (#22199)
alaradirik Mar 16, 2023
3aa4b1d
Update expected values in `MgpstrModelIntegrationTest` (#22195)
ydshieh Mar 16, 2023
bd9f26a
Italian Translation of migration.mdx (#22183)
Baelish03 Mar 16, 2023
5e10dfe
LLaMA Implementation (#21955)
zphang Mar 16, 2023
45d7741
LLaMA Implementation (#21955)
zphang Mar 16, 2023
fdce70e
Update tiny model creation script (#22202)
ydshieh Mar 16, 2023
0bb5eb9
Temporarily fix ONNX model exporting error (#21830)
SatyaJandhyalaAtMS Mar 16, 2023
a0e90ae
[`XGLM`] Add `accelerate` support for XGLM (#22207)
younesbelkada Mar 16, 2023
ff4142a
fixes a typo in WhisperFeatureExtractor docs. (#22208)
susnato Mar 16, 2023
b24b1a5
🔥py38 + torch 2 🔥🔥🔥🚀 (#22204)
ydshieh Mar 16, 2023
d10e67f
Hotfix for natten issue with torch 2.0.0 on CircleCI (#22218)
ydshieh Mar 16, 2023
bbb5691
fix typos in llama.mdx (#22223)
keturn Mar 17, 2023
0a87930
fix code example in mgp-str doc (#22219)
wdp-007 Mar 17, 2023
e057108
Use `dash==2.8.1` for now for daily CI (#22227)
ydshieh Mar 17, 2023
98613dc
Depth estimation task guide (#22205)
MKhalusova Mar 17, 2023
a6a030d
LLaMA house-keeping (#22216)
sgugger Mar 17, 2023
2c73d7b
fix AutoTP in deepspeed could not work for bloom (#22196)
sywangyi Mar 17, 2023
a09ef84
Add LlamaForSequenceClassification (#22209)
lewtun Mar 17, 2023
dae8b3b
Removed .mdx extension in two links (#22230)
MKhalusova Mar 17, 2023
b923c05
fix(docs): fix task guide links in model docs (#22226)
Seb0 Mar 17, 2023
52e69dd
Fix natten (#22229)
alihassanijr Mar 17, 2023
9607a01
Revert "Use `dash==2.8.1` for now for daily CI" (#22233)
ydshieh Mar 17, 2023
59090ff
Fix Unnecessary move of tensors from CPU to GPU in LlamaRotaryEmbeddi…
ma787639046 Mar 17, 2023
0a9926d
[trainer] param count for deepspeed zero3 (#22193)
stas00 Mar 17, 2023
12afa8b
Update training_args.py -- a nightly install is not required anymore …
pminervini Mar 20, 2023
bf2fc80
[Docs] fix typos in some tokenizer docs (#22256)
yesinkim Mar 20, 2023
177fd50
Italian translation perf_infer_cpu (#22243)
nickprock Mar 20, 2023
1c588fe
[Trainer] Add optional communication backends for torch.distributed w…
heya5 Mar 20, 2023
484c993
Fix the gradient checkpointing bug of the llama model (#22270)
yqy2001 Mar 20, 2023
ef5f2a5
Fix balanced and auto device_map (#22271)
sgugger Mar 20, 2023
b2543cc
Rework a bit the LLaMA conversion script (#22236)
sgugger Mar 20, 2023
71b06fa
Proper map location for optimizer load (#22273)
sgugger Mar 20, 2023
05fa5c7
Fix doc links (#22274)
amyeroberts Mar 20, 2023
753db9f
Move torch.compile() wrapping after DDP/FSDP wrapping to ensure corre…
ani300 Mar 20, 2023
3feec97
Example of pad_to_multiple_of for padding and truncation guide & docs…
MKhalusova Mar 20, 2023
061d16b
Update vision docstring bool masked pos (#22237)
amyeroberts Mar 20, 2023
1703d08
replace_8bit_linear modules_to_not_convert default value fix (#22238)
Mar 21, 2023
6f9eef8
Fix error in mixed precision training of `TFCvtModel` (#22267)
gcuder Mar 21, 2023
afc258d
More doctests (#22268)
ydshieh Mar 21, 2023
b9955aa
fix more doctests (#22292)
ydshieh Mar 21, 2023
f282b2d
Add translation perf_infer_gpu_one for it (#22296)
davidegazze Mar 21, 2023
0919d62
Time to Say Goodbye, torch 1.7 and 1.8 (#22291)
ydshieh Mar 21, 2023
5d15c1b
Restore fp16 support on xla gpu device (#22300)
ymwangg Mar 21, 2023
a644393
Correct NATTEN function signatures and force new version (#22298)
alihassanijr Mar 21, 2023
e3eb222
[deepspeed] offload + non-cpuadam optimizer exception doc (#22044)
stas00 Mar 22, 2023
9a79244
Final update of doctest (#22299)
ydshieh Mar 22, 2023
79468e5
Add MaskedImageModelingOutput (#22212)
alaradirik Mar 22, 2023
378ce25
Enable traced model for text-generation task (#22265)
jiqing-feng Mar 22, 2023
2c37c0f
add low_cpu_mem_usage option in run_clm.py example which will benefit…
sywangyi Mar 22, 2023
9b1c935
fix: Allow only test_file in pytorch and flax summarization (#22293)
Mar 22, 2023
4072d18
Fix position embeddings for GPT-J and CodeGen (#22069)
njhill Mar 22, 2023
963539e
Fixed bug to calculate correct xpath_sub_list in MarkupLMTokenizer (#…
silentghoul-spec Mar 22, 2023
eb6409d
Enforce `max_memory` for device_map strategies (#22311)
sgugger Mar 22, 2023
e0b2ba2
Generate: Export TF generate with a TF tokenizer (#22310)
gante Mar 22, 2023
88ce6d5
Beef up Llama tests (#22314)
gante Mar 22, 2023
f9d07cc
Add Pix2Struct (#21400)
younesbelkada Mar 22, 2023
561f83c
docs: Resolve incorrect type typo in trainer methods (#22316)
tomaarsen Mar 22, 2023
197a6af
Chunkable token classification pipeline (#21771)
luccailliau Mar 22, 2023
b4baf0a
Fix PipelineTests skip conditions (#22320)
ydshieh Mar 22, 2023
87c6bb5
[deepspeed zero3] need `generate(synced_gpus=True, ...)` (#22242)
stas00 Mar 22, 2023
4e01955
Fix quality due to ruff release
sgugger Mar 23, 2023
9600073
Really fix quality due to ruff release
sgugger Mar 23, 2023
a33d4aa
[gptj] support older pytorch version (#22325)
stas00 Mar 23, 2023
c92d679
[`MBart`] Add `accelerate` support for MBart (#22309)
younesbelkada Mar 23, 2023
a53fe3d
Fixed gradient checkpoint bug for TimeSeriesTransformer (#22272)
pmollerus23 Mar 23, 2023
4e6a9de
Mention why one needs to specify max_steps in Trainer (#22333)
lhoestq Mar 23, 2023
6744e67
Fix various imports (#22281)
sgugger Mar 23, 2023
34583c1
Minor typo in pipeline FillMaskPipeline's documentation. (#22339)
SamuelLarkin Mar 23, 2023
e5cd789
Added type hints to TFDeiTModel (#22327)
Batese2001 Mar 23, 2023
892ec86
Fix --bf16 option support for Neuron after PR #22300 (#22307)
jeffhataws Mar 23, 2023
505628f
Generate: add test for left-padding support (#22322)
gante Mar 23, 2023
8c55fe7
Enable training Llama with model or pipeline parallelism (#22329)
kooshi Mar 23, 2023
63d4c0e
Automatically create/update tiny models (#22275)
ydshieh Mar 23, 2023
7e64776
[HFTracer] Make embeddings ops take on the dtype of the weight (#22347)
jamesr66a Mar 24, 2023
4e53190
Fix typo in Greedy Search Description (#22345)
awinml Mar 24, 2023
75a00cf
Generate: Add GPTNeoX integration test (#22346)
gante Mar 24, 2023
f876ad9
Add Mega: Moving Average Equipped Gated Attention (#21766)
mnaylor5 Mar 24, 2023
5543b2e
Update docker files to use official torch 2.0.0 (#22357)
ydshieh Mar 24, 2023
b43dac8
Pin tensorflow-text to go with tensorflow (#22362)
sgugger Mar 24, 2023
72277c2
Improve error message (#22361)
Mahrkeenerh Mar 24, 2023
6f7580b
TensorFlow: pin maximum version to 2.12 (#22364)
gante Mar 24, 2023
9b89cf6
Resnet flax (#21472)
Shubhamai Mar 24, 2023
71aa651
[Trainer] add disclaimer that full_determinism is slow (#22368)
stas00 Mar 24, 2023
ada5269
Fix TF pipeline job
sgugger Mar 24, 2023
7177ed1
[safetensors] don't use in `torch<1.10` (#22370)
stas00 Mar 24, 2023
3252dc3
TensorFlow: additional missing `cmake` dependencies in CI (#22383)
gante Mar 27, 2023
2fee8e6
Changed world_size() to get_world_size() bugfix (#22381)
Charlie-Bell Mar 27, 2023
69409c5
Translated documentation in italian (#22388)
nickprock Mar 27, 2023
685e37b
Adapt find_tied_parameters to handle breaking change in Accelerate (#…
sgugger Mar 27, 2023
c4b9338
load_in_8bit now respects 'balanced' device maps in multi-gpu environ…
kooshi Mar 27, 2023
5374138
Wav2Vec2ProcessorWithLM can return N best hypotheses now (#22235)
vsokolovskii Mar 27, 2023
659920a
Seq2seq trainer generation config arg (#22323)
Natooz Mar 27, 2023
791064c
Generate: support for left-padding on GPTNeoX and Llama (#22382)
gante Mar 27, 2023
87a1a3c
[`bnb`] Force `requires_grad` to be `False` (#22396)
younesbelkada Mar 27, 2023
19eacee
Transformers env safetensors (#22400)
sgugger Mar 27, 2023
e7e9246
[Pix2Struct] Add support to resize embeddings (#22394)
NielsRogge Mar 27, 2023
10a7ea7
Trainer: move Seq2SeqTrainer imports under the typing guard (#22401)
gante Mar 27, 2023
34b4d50
Trainer: missing None check (#22404)
gante Mar 27, 2023
d144ad4
Hardware Auto-Setup for Examples (#22319)
dongreenberg Mar 27, 2023
2ec54c7
Fix quality
sgugger Mar 27, 2023
0cab406
[WIP]`NLLB-MoE` Adds the moe model (#22024)
ArthurZucker Mar 27, 2023
7ea9876
[neptune] fix checkpoint bug with relative out_dir (#22102)
kshitij12345 Mar 27, 2023
3d30610
Bump redis from 4.1.4 to 4.5.3 in /examples/research_projects/decisio…
dependabot[bot] Mar 28, 2023
f7bc749
Fix bug in perplexity guide calculations and update perplexity number…
fpgaminer Mar 28, 2023
c426a6d
[performance] ensure `causal_mask` is created directly on device (#22…
jeffra Mar 28, 2023
0755419
MBart: Fix docs and doctests (#22422)
gante Mar 28, 2023
dc38b93
Add clean_up_tokenization_spaces to config (#22341)
ArthurZucker Mar 29, 2023
fc3f40d
Hyperparameter search reporting to W&B (#22440)
NoB0 Mar 29, 2023
1ead501
[`bnb`] fix bnb failing test (#22439)
younesbelkada Mar 29, 2023
5ff4808
[`Generate`] Add conditional generation for multimodal models (#22424)
younesbelkada Mar 29, 2023
93187ce
Don't hard error when cache version can't be converted to int (#22427)
sgugger Mar 29, 2023
80b7e6c
Use real tokenizers if tiny version(s) creation has issue(s) (#22428)
ydshieh Mar 29, 2023
9ac11ec
Revert "Error (also in original) model, scaling only q matrix not qk.…
sgugger Mar 29, 2023
deb36f0
[`Pix2Struct`] Fix slow test (#22448)
younesbelkada Mar 29, 2023
e832063
Revert "Fix --bf16 option support for Neuron after PR #22300" (#22451)
jeffhataws Mar 29, 2023
78293e2
Update Neptune docs (#22452)
Mar 29, 2023
85cbc37
Avoid using personal HF token in CI (#22453)
ydshieh Mar 29, 2023
b35dcee
Update release instructions (#22454)
sgugger Mar 29, 2023
00dd0d1
Pin ruff (#22455)
sgugger Mar 29, 2023
720a290
Update: ignore padding support for TransfoXL training when n_clusters…
StefanHeng Mar 29, 2023
817428d
Move common properties to BackboneMixin (#21855)
amyeroberts Mar 30, 2023
11d5ddf
Rescale image back if it was scaled during PIL conversion (#22458)
amyeroberts Mar 30, 2023
ccee612
Skip flaky NLLB Moe test for now (#22463)
amyeroberts Mar 30, 2023
8e0179a
Generate: basic token streaming (#22449)
gante Mar 30, 2023
944ff59
🚨🚨🚨 Fix ordering of height, width for BLIP image processor (#22466)
amyeroberts Mar 30, 2023
da76cec
Guard imports of PreTrainedTokenizerFast on is_tokenizers_available (…
Mar 30, 2023
6a7b834
[NLLB-MoE] `model_type` update for auto mapping (#22470)
ArthurZucker Mar 30, 2023
1d2f628
Llama: support for `max_position_embeddings` (#22471)
gante Mar 30, 2023
49ae6c8
Docs fix: Multinomial sampling decoding needs "num_beams=1", since by…
manueldeprada Mar 30, 2023
7492a1e
(Re-)Enable Nightly + Past CI (#22393)
ydshieh Mar 30, 2023
71422fc
Relax `eos_token_id < 0` checks in `generate()` from `ValueError` to …
lewtun Mar 31, 2023
b37668d
Update `Wav2Vec2ProcessorWithLM` doc example (#22474)
ydshieh Mar 31, 2023
3c815cb
Making sure we can use safetensors to serialize all the time. (#22437)
Narsil Mar 31, 2023
76d9d9d
Bump redis from 4.5.3 to 4.5.4 in /examples/research_projects/decisio…
dependabot[bot] Mar 31, 2023
67308c4
Update Neptune callback docstring (#22497)
Mar 31, 2023
6cca3e4
Test fetch v2 (#22367)
sgugger Mar 31, 2023
29e8140
Update convert_llama_weights_to_hf.py (#22525)
Ricardokevins Apr 3, 2023
89c2557
Backbone add out indices (#22493)
amyeroberts Apr 3, 2023
c7781fe
[Time-Series] fix past_observed_mask type (#22076)
elisim Apr 3, 2023
a2e3570
Fix llama tokenizer (#22402)
ArthurZucker Apr 3, 2023
aa1a366
[WIP] docs: ko: sagemaker.mdx (#22509)
jungnerd Apr 3, 2023
fb09652
added biogpt token classifier (#22447)
upjabir Apr 3, 2023
bbad89b
Generate: `TextIteratorStreamer` (streamer for gradio) (#22501)
gante Apr 3, 2023
a08de03
Fix convert_opt_original_pytorch_checkpoint_to_pytorch.py typo (#22526)
larekrow Apr 3, 2023
af961bf
llama docs: fix conversion script url (#22514)
python273 Apr 3, 2023
5f2e5cd
fix LayoutLMv3TokenizerFast subword label after 'Ġ' token (#21695)
thibaultdouzon Apr 3, 2023
fe24b03
[BLIP] fix cross attentions for BlipTextEncoder (#22515)
zhbh01 Apr 3, 2023
0d419e1
[`Trainer`] Force `is_model_parallel` when model is loaded in multipl…
younesbelkada Apr 3, 2023
1376d60
[`T5`] Enable naive Pipeline Parallelism training for T5 (#22535)
younesbelkada Apr 3, 2023
7adc163
Fix missing metrics with multiple eval datasets (#22536)
hawkeoni Apr 3, 2023
2cccab4
[setup] drop deprecated `distutils` usage (#22531)
XuehaiPan Apr 3, 2023
a02907d
Generate: Enable easier TextStreamer customization (#22516)
vblagoje Apr 3, 2023
ecac59c
[setup] migrate setup script to `pyproject.toml` (#22539)
XuehaiPan Apr 3, 2023
d5372fd
Skip failing test
sgugger Apr 3, 2023
14f22d1
Update test_image_processing_pix2struct.py (#22543)
younesbelkada Apr 3, 2023
889607f
Fix OPTForQuestionAnswering doc string (#22481)
curlup Apr 4, 2023
995ad7b
Generate: Add text streamer decoding options (#22544)
gante Apr 4, 2023
b8edf87
[Roformer] Fixing a bug in RoFormerEncoder where it was ignoring the …
TheWall9 Apr 4, 2023
54e0182
🚨🚨🚨 `[NLLB Tokenizer]` Fix the prefix tokens 🚨🚨🚨 (#22313)
ArthurZucker Apr 4, 2023
53e5389
Implemented safetensors checkpoints save/load for Trainer (#22498)
viktor-shcherb Apr 4, 2023
19ce467
Remove hack for dynamic modules and use Python functions instead (#22…
sgugger Apr 4, 2023
3a1e4b7
[`bnb`] Fix typo (#22556)
younesbelkada Apr 4, 2023
763a78e
Add id2label and label2id to model's config in run_xnil (#22558)
maziyarpanahi Apr 4, 2023
dad2f7f
Soft error whisper. (#22475)
Narsil Apr 4, 2023
8b5ee45
Add TF port of BLIP (#22090)
Rocketknight1 Apr 4, 2023
271bc94
corrected the code comment for the output of find_pruneable_heads_and…
SunHaozhe Apr 4, 2023
10dc0e1
Flax Regnet (#21867)
Shubhamai Apr 4, 2023
f4070f7
fix `_no_split_modules` for Whisper model (#22486)
pacman100 Apr 4, 2023
63db711
Fix inverted conditional in TF common test! (#22540)
Rocketknight1 Apr 4, 2023
05efd03
Skip failing test
sgugger Apr 5, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/source/en/model_doc/whisper.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -105,3 +105,9 @@ The original code can be found [here](https://github.com/openai/whisper).

[[autodoc]] FlaxWhisperForConditionalGeneration
- __call__

## FlaxWhisperForAudioClassification

[[autodoc]] FlaxWhisperForAudioClassification
- __call__

8 changes: 7 additions & 1 deletion src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3627,6 +3627,7 @@
"FlaxWhisperForConditionalGeneration",
"FlaxWhisperModel",
"FlaxWhisperPreTrainedModel",
"FlaxWhisperForAudioClassification",
]
)
_import_structure["models.xglm"].extend(
Expand Down Expand Up @@ -6620,7 +6621,12 @@
FlaxWav2Vec2Model,
FlaxWav2Vec2PreTrainedModel,
)
from .models.whisper import FlaxWhisperForConditionalGeneration, FlaxWhisperModel, FlaxWhisperPreTrainedModel
from .models.whisper import (
FlaxWhisperForAudioClassification,
FlaxWhisperForConditionalGeneration,
FlaxWhisperModel,
FlaxWhisperPreTrainedModel,
)
from .models.xglm import FlaxXGLMForCausalLM, FlaxXGLMModel, FlaxXGLMPreTrainedModel
from .models.xlm_roberta import (
FLAX_XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
Expand Down
13 changes: 13 additions & 0 deletions src/transformers/models/auto/modeling_flax_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,12 @@
]
)

FLAX_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
("whisper", "FlaxWhisperForAudioClassification"),
]
)


FLAX_MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_MAPPING_NAMES)
FLAX_MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_PRETRAINING_MAPPING_NAMES)
Expand Down Expand Up @@ -255,6 +261,9 @@
FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES
)
FLAX_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, FLAX_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES
)


class FlaxAutoModel(_BaseAutoModelClass):
Expand Down Expand Up @@ -355,6 +364,10 @@ class FlaxAutoModelForSpeechSeq2Seq(_BaseAutoModelClass):
_model_mapping = FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING


class FlaxAutoModelForAudioClassification(_BaseAutoModelClass):
_model_mapping = FLAX_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING


FlaxAutoModelForSpeechSeq2Seq = auto_class_update(
FlaxAutoModelForSpeechSeq2Seq, head_doc="sequence-to-sequence speech-to-text modeling"
)
2 changes: 2 additions & 0 deletions src/transformers/models/whisper/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
"FlaxWhisperForConditionalGeneration",
"FlaxWhisperModel",
"FlaxWhisperPreTrainedModel",
"FlaxWhisperForAudioClassification",
]


Expand Down Expand Up @@ -126,6 +127,7 @@
pass
else:
from .modeling_flax_whisper import (
FlaxWhisperForAudioClassification,
FlaxWhisperForConditionalGeneration,
FlaxWhisperModel,
FlaxWhisperPreTrainedModel,
Expand Down
175 changes: 175 additions & 0 deletions src/transformers/models/whisper/modeling_flax_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
FlaxCausalLMOutputWithCrossAttentions,
FlaxSeq2SeqLMOutput,
FlaxSeq2SeqModelOutput,
FlaxSequenceClassifierOutput,
)
from ...modeling_flax_utils import (
ACT2FN,
Expand Down Expand Up @@ -1468,3 +1469,177 @@ def update_inputs_for_generation(self, model_outputs, model_kwargs):
append_replace_return_docstrings(
FlaxWhisperForConditionalGeneration, output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC
)


class FlaxWhisperForAudioClassificationModule(nn.Module):
config: WhisperConfig
dtype: jnp.dtype = jnp.float32
gradient_checkpointing: bool = False
Comment thread
raghavanone marked this conversation as resolved.
Outdated

def setup(self) -> None:
self.encoder = FlaxWhisperEncoder(config=self.config, dtype=self.dtype)
self.config.is_encoder_decoder = False
num_layers = self.config.num_hidden_layers + 1
if self.config.use_weighted_layer_sum:
self.layer_weights = jnp.repeat(1 / num_layers, num_layers)
self.projector = nn.Dense(self.config.classifier_proj_size)
Comment thread
raghavanone marked this conversation as resolved.
Outdated
self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype)

def __call__(
self,
input_features,
encoder_outputs=None,
output_attentions=None,
output_hidden_states: bool = True,
return_dict: bool = True,
):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

if encoder_outputs is None:
encoder_outputs = self.encoder(
input_features,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)

if self.config.use_weighted_layer_sum:
hidden_states = jnp.stack(encoder_outputs, dim=1)
Comment thread
raghavanone marked this conversation as resolved.
Outdated
norm_weights = jax.nn.softmax(self.layer_weights, axis=-1)
hidden_states = jnp.sum(hidden_states * jnp.reshape(norm_weights, [-1, 1, 1]), axis=1)
else:
hidden_states = encoder_outputs[0]

hidden_states = self.projector(hidden_states)
pooled_output = jnp.mean(hidden_states, axis=1)

logits = self.classifier(pooled_output)

if not return_dict:
return (logits,) + encoder_outputs[1:]

return FlaxSequenceClassifierOutput(
logits=logits,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)


@add_start_docstrings("The Whisper Model with an audio classification head on top.", WHISPER_START_DOCSTRING)
class FlaxWhisperForAudioClassification(FlaxWhisperPreTrainedModel):
module_class = FlaxWhisperForAudioClassificationModule
dtype: jnp.dtype = jnp.float32

def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
# init input tensors
input_features = jnp.zeros(input_shape, dtype="f4")
input_features = input_features.at[(..., -1)].set(self.config.eos_token_id)

decoder_input_ids = jnp.zeros((input_shape[0], 1), dtype="i4")
Comment thread
raghavanone marked this conversation as resolved.
Outdated
jnp.ones_like(decoder_input_ids)
Comment thread
raghavanone marked this conversation as resolved.
Outdated

batch_size, sequence_length = decoder_input_ids.shape
jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))

params_rng, dropout_rng = jax.random.split(rng)
rngs = {"params": params_rng, "dropout": dropout_rng}

random_params = self.module.init(
rngs,
input_features=input_features,
)["params"]

if params is not None:
random_params = flatten_dict(unfreeze(random_params))
params = flatten_dict(unfreeze(params))
for missing_key in self._missing_keys:
params[missing_key] = random_params[missing_key]
self._missing_keys = set()
return freeze(unflatten_dict(params))
else:
return random_params

@add_start_docstrings_to_model_forward(WHISPER_INPUTS_DOCSTRING)
def __call__(
self,
input_features: jnp.ndarray,
attention_mask: Optional[jnp.ndarray] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
train: bool = False,
params: dict = None,
dropout_rng: PRNGKey = None,
**kwargs,
):
r"""
Comment thread
raghavanone marked this conversation as resolved.
Outdated
Returns:

Example:

```python
>>> from transformers import WhisperProcessor, FlaxWhisperForConditionalGeneration
>>> from datasets import load_dataset

>>> processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
>>> model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", from_pt=True)
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> inputs = processor(ds[0]["audio"]["array"], return_tensors="np")
>>> input_features = inputs.input_features
>>> encoder_outputs = model.encode(input_features=input_features)
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.return_dict

# Handle any PRNG if needed
rngs = {}
if dropout_rng is not None:
rngs["dropout"] = dropout_rng

def _encoder_forward(module, input_features, **kwargs):
Comment thread
raghavanone marked this conversation as resolved.
Outdated
encode_module = module._get_encoder_module()
return encode_module(input_features, **kwargs)

return self.module.apply(
{"params": params or self.params},
input_features=jnp.array(input_features, dtype="f4"),
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
rngs=rngs,
# method=_encoder_forward,

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can remove this commented line too

)


FLAX_WHISPER_AUDIO_CLASSIFICATION_DOCSTRING = r"""
Returns:

Transcription example:

```python
>>> from transformers import WhisperProcessor, FlaxWhisperForAudioClassification
>>> from datasets import load_dataset

>>> processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
>>> model = FlaxWhisperForAudioClassification.from_pretrained("openai/whisper-tiny.en", from_pt=True)
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
Comment thread
raghavanone marked this conversation as resolved.
Outdated
>>> inputs = processor(ds[0]["audio"]["array"], return_tensors="np")
>>> input_features = inputs.input_features
>>> outputs = model(input_features=input_features)
>>> logits = outputs.logit
```
"""

overwrite_call_docstring(
FlaxWhisperForAudioClassification, WHISPER_INPUTS_DOCSTRING + FLAX_WHISPER_AUDIO_CLASSIFICATION_DOCSTRING
)
append_replace_return_docstrings(
FlaxWhisperForAudioClassification, output_type=FlaxSequenceClassifierOutput, config_class=_CONFIG_FOR_DOC
)
7 changes: 7 additions & 0 deletions src/transformers/utils/dummy_flax_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -1140,6 +1140,13 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])


class FlaxWhisperForAudioClassification(metaclass=DummyObject):
_backends = ["flax"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])


class FlaxWhisperForConditionalGeneration(metaclass=DummyObject):
_backends = ["flax"]

Expand Down
Loading