Add FlaxWhisperForAudioClassification model #21894
Closed
Changes from 20 commits
Commits
277 commits
2f66a42
[WIP] Add BridgeTowerForContrastiveLearning (#21964)
abhiwand fe03e51
Fix test for torchneuroncore in Trainer (#22028)
sgugger 781edc5
Add tokenize_kwargs parameter definition in the FeatureExtractionPipe…
anruijian 57b3a97
[examples/speech-recognition] Add SpecAugment to run_speech_recogniti…
bofenghuang 9a10d4b
fixes the gradient checkpointing of whisper (#22019)
soma2000-lang 65f56ec
Avoid `text_config_dict` and `vision_config_dict` being saved for CL…
ydshieh 3f29db2
Mark all `BridgeTower` tests slow for now (#22039)
ydshieh 2a5f185
Bug fix: token classification pipeline while passing offset_mapping (…
cceyda f89b95c
Update ALIGN docs (#22025)
alaradirik 0d08783
[21737][T5]: Fix gradient checkpoint bug (#22036)
nipunjindal 67bd8ee
Docs Improvement - In ZSH, not using ' ' around pip install fails, fi…
shaun-scale c2420fd
Can't install tf2 on M1 Chip by default (#22046)
shaun-scale da98339
Remove set_access_token usage + fail tests if FutureWarning (#22051)
Wauplin 6fd5f2f
Show the number of `huggingface_hub` warnings in CI report (#22054)
ydshieh d95d2d9
Return analysis for hyperparameter_search with Ray backend (#22040)
anruijian a7dacfb
pt-to-tf model architecture override (#22055)
Rocketknight1 c8437ea
rm $ symbol from code block from contributing.md (#22057)
kamalkraj 09e9344
[deepspeed] offload + non-cpuadam optimizer exception (#22043)
stas00 e39722e
Edit the docstring of `image_processing_donut` to match code (#22033)
vermouthmjl bf97e22
Skip 3 tests for `WhisperEncoderModelTest` (#22060)
ydshieh 4405fa4
Add setters by type of args to TrainingArguments (#21570)
sgugger a4e1c7b
Update tiny model creation script (#22058)
ydshieh 777f921
Fix case when using --gradient_accumulation_steps with DDP disabled. …
sangeethabal 1e9ccb6
Add a progress bar for the total download of shards (#22062)
sgugger b84287e
Fix gradient checkpointing bug in Speech2Text (#22079)
KMFODA ae930f0
Fix gradient checkpointing bug in switch transformer (#22081)
KMFODA 3d42b72
[GPT2] Propose fix for #21080 (#21853)
ArthurZucker 13289d4
Fix small typo in flan-ul2.mdx (#22068)
kevin51jiang c78d87a
Generate - Fix broken documentation links (#22078)
gante 0993014
Fix gradient checkpointing bug in Speecht5 (#22080)
KMFODA 9199bf7
Fix hint in src/transformers/modeling_utils.py (#22074)
J-shang 6a42232
handle numpy inputs in whole word mask data collator (#22032)
dwyatte 18fa8e9
GPT-J specific half precision on CPU note (#22086)
MKhalusova 9fd5b5e
Fix imports of TF MobileViT (#22065)
sgugger ffddc14
Revert "[GPT2] Propose fix for #21080" (#22093)
ydshieh 30ccdfa
[Whisper] Remove embed_tokens from encoder docstring (#21996)
sanchit-gandhi 733b396
Add AutoModelForZeroShotImageClassification (#22087)
alaradirik 010d238
add new model of MGP-STR (#21418)
wdp-007 c211bd5
Add pr_checks.mdx Italian translation (#17459) (#22116)
alexcalabrese 6048e88
Fix gradient checkpointing bug in xglm (#22127)
KMFODA 27ca366
Fix gradient checkpointing bug in Trajectory Transformer (#22125)
KMFODA 3e32345
Fix gradient checkpointing bug in xlm_roberta_xl (#22128)
KMFODA 9c13edb
Added big_models.mdx italian translation #17600 (#22115)
nickprock 718ec35
[`Blip2`] skip accelerate test (#22124)
younesbelkada d56076d
Fix gradient checkpointing bug in xmod (#22129)
KMFODA b22b7b1
Fix gradient checkpointing bug in LongT5 (#22130)
KMFODA c349eda
Fix gradient checkpointing bug in trocr (#22126)
KMFODA 165bd4a
Zero-shot image classification task guide (#22132)
MKhalusova 798d110
Fix doc link for MGP-STR (#22138)
sgugger 729251a
Adding Type Hints to TF_Pegasus model (#21941)
pmollerus23 d56ad12
Add a new script to check model testers' config (#22063)
ydshieh c53441c
Update configuration_align.py (projected_dim=640) (#22139)
bishmdl76 f3067cc
[`Whiper`] add `get_input_embeddings` to `WhisperForAudioClassificati…
younesbelkada e6bbfa8
Trainer: let generate pick its inputs (#22108)
gante 382b5ec
Enforce same behavior as PyTorch 2.0 for older versions (#22136)
sgugger 2e0898b
[trainer] fix bug in grad accum with multiple epochs (#22098)
stas00 a3cb682
[deepspeed docs] Activation Checkpointing (#22099)
stas00 5a1f4f8
Remove backend check for torch.compile (#22140)
sgugger 7703f02
[Safetensors] Add explicit flag to from pretrained (#22083)
patrickvonplaten 4c876a3
Prepare daily CI for torch 2.0.0 (#22135)
ydshieh f3d55b5
docs: New terms and updates to glossary (#21982)
MichaelRipa 5c055e4
[🛠️] Fix-whisper-breaking-changes (#21965)
ArthurZucker cc22fe6
Move `is_pipeline_test_to_skip` to specific model test classes (#21999)
ydshieh 86b33f3
Add ConvNeXT V2 (#21679)
alaradirik 6d0b040
Update 2 doctest expected values for torch 2.0.0 (#22148)
ydshieh 4501ebe
Translation Italian: perf_train_cpu and perf_train_cpu_many (#22151)
nickprock 4f9f190
Fix big model inference for T5 models in float16 (#22095)
sgugger 2e5b600
Create MaskedImageCompletionOutput and fix ViT docs (#22152)
alaradirik 730bf0c
to_pil - don't rescale if int and in range 0-255 (#22158)
amyeroberts 9515690
[trainer] add `--optim adamw_torch_fused` for pt-2.0+ (#22144)
stas00 a115f40
Revert "Enforce same behavior as PyTorch 2.0 for older versions" (#22…
sgugger 7989275
v4.28.0.dev0
sgugger 741ec6b
Load optimizer state on CPU to avoid CUDA OOM (#22159)
sgugger 350740e
Run all tests by default (#22162)
sgugger 81b794f
Fix: unfinished_sequences with correct device (#22184)
Stxr 3704f82
Revert 22152 MaskedImageCompletionOutput changes (#22187)
amyeroberts f4c7551
Regression pipeline device (#22190)
sgugger 1222535
Update BridgeTowerForContrastiveLearning (#22145)
abhiwand 860a3cd
t5 remove data dependency (#22097)
prathikr cd0d499
Fix DeepSpeed CI (#22194)
ydshieh 6ac5880
Fix typo in Align docs (#22199)
alaradirik 3aa4b1d
Update expected values in `MgpstrModelIntegrationTest` (#22195)
ydshieh bd9f26a
Italian Translation of migration.mdx (#22183)
Baelish03 5e10dfe
LLaMA Implementation (#21955)
zphang 45d7741
LLaMA Implementation (#21955)
zphang fdce70e
Update tiny model creation script (#22202)
ydshieh 0bb5eb9
Temporarily fix ONNX model exporting error (#21830)
SatyaJandhyalaAtMS a0e90ae
[`XGLM`] Add `accelerate` support for XGLM (#22207)
younesbelkada ff4142a
fixes a typo in WhisperFeatureExtractor docs. (#22208)
susnato b24b1a5
🔥py38 + torch 2 🔥🔥🔥🚀 (#22204)
ydshieh d10e67f
Hotfix for natten issue with torch 2.0.0 on CircleCI (#22218)
ydshieh bbb5691
fix typos in llama.mdx (#22223)
keturn 0a87930
fix code example in mgp-str doc (#22219)
wdp-007 e057108
Use `dash==2.8.1` for now for daily CI (#22227)
ydshieh 98613dc
Depth estimation task guide (#22205)
MKhalusova a6a030d
LLaMA house-keeping (#22216)
sgugger 2c73d7b
fix AutoTP in deepspeed could not work for bloom (#22196)
sywangyi a09ef84
Add LlamaForSequenceClassification (#22209)
lewtun dae8b3b
Removed .mdx extension in two links (#22230)
MKhalusova b923c05
fix(docs): fix task guide links in model docs (#22226)
Seb0 52e69dd
Fix natten (#22229)
alihassanijr 9607a01
Revert "Use `dash==2.8.1` for now for daily CI" (#22233)
ydshieh 59090ff
Fix Unnecessary move of tensors from CPU to GPU in LlamaRotaryEmbeddi…
ma787639046 0a9926d
[trainer] param count for deepspeed zero3 (#22193)
stas00 12afa8b
Update training_args.py -- a nightly install is not required anymore …
pminervini bf2fc80
[Docs] fix typos in some tokenizer docs (#22256)
yesinkim 177fd50
Italian translation perf_infer_cpu (#22243)
nickprock 1c588fe
[Trainer] Add optional communication backends for torch.distributed w…
heya5 484c993
Fix the gradient checkpointing bug of the llama model (#22270)
yqy2001 ef5f2a5
Fix balanced and auto device_map (#22271)
sgugger b2543cc
Rework a bit the LLaMA conversion script (#22236)
sgugger 71b06fa
Proper map location for optimizer load (#22273)
sgugger 05fa5c7
Fix doc links (#22274)
amyeroberts 753db9f
Move torch.compile() wrapping after DDP/FSDP wrapping to ensure corre…
ani300 3feec97
Example of pad_to_multiple_of for padding and truncation guide & docs…
MKhalusova 061d16b
Update vision docstring bool masked pos (#22237)
amyeroberts 1703d08
replace_8bit_linear modules_to_not_convert default value fix (#22238)
6f9eef8
Fix error in mixed precision training of `TFCvtModel` (#22267)
gcuder afc258d
More doctests (#22268)
ydshieh b9955aa
fix more doctests (#22292)
ydshieh f282b2d
Add translation perf_infer_gpu_one for it (#22296)
davidegazze 0919d62
Time to Say Goodbye, torch 1.7 and 1.8 (#22291)
ydshieh 5d15c1b
Restore fp16 support on xla gpu device (#22300)
ymwangg a644393
Correct NATTEN function signatures and force new version (#22298)
alihassanijr e3eb222
[deepspeed] offload + non-cpuadam optimizer exception doc (#22044)
stas00 9a79244
Final update of doctest (#22299)
ydshieh 79468e5
Add MaskedImageModelingOutput (#22212)
alaradirik 378ce25
Enable traced model for text-generation task (#22265)
jiqing-feng 2c37c0f
add low_cpu_mem_usage option in run_clm.py example which will benefit…
sywangyi 9b1c935
fix: Allow only test_file in pytorch and flax summarization (#22293)
4072d18
Fix position embeddings for GPT-J and CodeGen (#22069)
njhill 963539e
Fixed bug to calculate correct xpath_sub_list in MarkupLMTokenizer (#…
silentghoul-spec eb6409d
Enforce `max_memory` for device_map strategies (#22311)
sgugger e0b2ba2
Generate: Export TF generate with a TF tokenizer (#22310)
gante 88ce6d5
Beef up Llama tests (#22314)
gante f9d07cc
Add Pix2Struct (#21400)
younesbelkada 561f83c
docs: Resolve incorrect type typo in trainer methods (#22316)
tomaarsen 197a6af
Chunkable token classification pipeline (#21771)
luccailliau b4baf0a
Fix PipelineTests skip conditions (#22320)
ydshieh 87c6bb5
[deepspeed zero3] need `generate(synced_gpus=True, ...)` (#22242)
stas00 4e01955
Fix quality due to ruff release
sgugger 9600073
Really fix quality due to ruff release
sgugger a33d4aa
[gptj] support older pytorch version (#22325)
stas00 c92d679
[`MBart`] Add `accelerate` support for MBart (#22309)
younesbelkada a53fe3d
Fixed gradient checkpoint bug for TimeSeriesTransformer (#22272)
pmollerus23 4e6a9de
Mention why one needs to specify max_steps in Trainer (#22333)
lhoestq 6744e67
Fix various imports (#22281)
sgugger 34583c1
Minor typo in pipeline FillMaskPipeline's documentation. (#22339)
SamuelLarkin e5cd789
Added type hints to TFDeiTModel (#22327)
Batese2001 892ec86
Fix --bf16 option support for Neuron after PR #22300 (#22307)
jeffhataws 505628f
Generate: add test for left-padding support (#22322)
gante 8c55fe7
Enable training Llama with model or pipeline parallelism (#22329)
kooshi 63d4c0e
Automatically create/update tiny models (#22275)
ydshieh 7e64776
[HFTracer] Make embeddings ops take on the dtype of the weight (#22347)
jamesr66a 4e53190
Fix typo in Greedy Search Description (#22345)
awinml 75a00cf
Generate: Add GPTNeoX integration test (#22346)
gante f876ad9
Add Mega: Moving Average Equipped Gated Attention (#21766)
mnaylor5 5543b2e
Update docker files to use official torch 2.0.0 (#22357)
ydshieh b43dac8
Pin tensorflow-text to go with tensorflow (#22362)
sgugger 72277c2
Improve error message (#22361)
Mahrkeenerh 6f7580b
TensorFlow: pin maximum version to 2.12 (#22364)
gante 9b89cf6
Resnet flax (#21472)
Shubhamai 71aa651
[Trainer] add disclaimer that full_determinism is slow (#22368)
stas00 ada5269
Fix TF pipeline job
sgugger 7177ed1
[safetensors] don't use in `torch<1.10` (#22370)
stas00 3252dc3
TensorFlow: additional missing `cmake` dependencies in CI (#22383)
gante 2fee8e6
Changed world_size() to get_world_size() bugfix (#22381)
Charlie-Bell 69409c5
Translated documentation in italian (#22388)
nickprock 685e37b
Adapt find_tied_parameters to handle breaking change in Accelerate (#…
sgugger c4b9338
load_in_8bit now respects 'balanced' device maps in multi-gpu environ…
kooshi 5374138
Wav2Vec2ProcessorWithLM can return N best hypotheses now (#22235)
vsokolovskii 659920a
Seq2seq trainer generation config arg (#22323)
Natooz 791064c
Generate: support for left-padding on GPTNeoX and Llama (#22382)
gante 87a1a3c
[`bnb`] Force `requires_grad` to be `False` (#22396)
younesbelkada 19eacee
Transformers env safetensors (#22400)
sgugger e7e9246
[Pix2Struct] Add support to resize embeddings (#22394)
NielsRogge 10a7ea7
Trainer: move Seq2SeqTrainer imports under the typing guard (#22401)
gante 34b4d50
Trainer: missing None check (#22404)
gante d144ad4
Hardware Auto-Setup for Examples (#22319)
dongreenberg 2ec54c7
Fix quality
sgugger 0cab406
[WIP]`NLLB-MoE` Adds the moe model (#22024)
ArthurZucker 7ea9876
[neptune] fix checkpoint bug with relative out_dir (#22102)
kshitij12345 3d30610
Bump redis from 4.1.4 to 4.5.3 in /examples/research_projects/decisio…
dependabot[bot] f7bc749
Fix bug in perplexity guide calculations and update perplexity number…
fpgaminer c426a6d
[performance] ensure `causal_mask` is created directly on device (#22…
jeffra 0755419
MBart: Fix docs and doctests (#22422)
gante dc38b93
Add clean_up_tokenization_spaces to config (#22341)
ArthurZucker fc3f40d
Hyperparameter search reporting to W&B (#22440)
NoB0 1ead501
[`bnb`] fix bnb failing test (#22439)
younesbelkada 5ff4808
[`Generate`] Add conditional generation for multimodal models (#22424)
younesbelkada 93187ce
Don't hard error when cache version can't be converted to int (#22427)
sgugger 80b7e6c
Use real tokenizers if tiny version(s) creation has issue(s) (#22428)
ydshieh 9ac11ec
Revert "Error (also in original) model, scaling only q matrix not qk.…
sgugger deb36f0
[`Pix2Struct`] Fix slow test (#22448)
younesbelkada e832063
Revert "Fix --bf16 option support for Neuron after PR #22300" (#22451)
jeffhataws 78293e2
Update Neptune docs (#22452)
85cbc37
Avoid using personal HF token in CI (#22453)
ydshieh b35dcee
Update release instructions (#22454)
sgugger 00dd0d1
Pin ruff (#22455)
sgugger 720a290
Update: ignore padding support for TransfoXL training when n_clusters…
StefanHeng 817428d
Move common properties to BackboneMixin (#21855)
amyeroberts 11d5ddf
Rescale image back if it was scaled during PIL conversion (#22458)
amyeroberts ccee612
Skip flaky NLLB Moe test for now (#22463)
amyeroberts 8e0179a
Generate: basic token streaming (#22449)
gante 944ff59
🚨🚨🚨 Fix ordering of height, width for BLIP image processor (#22466)
amyeroberts da76cec
Guard imports of PreTrainedTokenizerFast on is_tokenizers_available (…
6a7b834
[NLLB-MoE] `model_type` update for auto mapping (#22470)
ArthurZucker 1d2f628
Llama: support for `max_position_embeddings` (#22471)
gante 49ae6c8
Docs fix: Multinomial sampling decoding needs "num_beams=1", since by…
manueldeprada 7492a1e
(Re-)Enable Nightly + Past CI (#22393)
ydshieh 71422fc
Relax `eos_token_id < 0` checks in `generate()` from `ValueError` to …
lewtun b37668d
Update `Wav2Vec2ProcessorWithLM` doc example (#22474)
ydshieh 3c815cb
Making sure we can use safetensors to serialize all the time. (#22437)
Narsil 76d9d9d
Bump redis from 4.5.3 to 4.5.4 in /examples/research_projects/decisio…
dependabot[bot] 67308c4
Update Neptune callback docstring (#22497)
6cca3e4
Test fetch v2 (#22367)
sgugger 29e8140
Update convert_llama_weights_to_hf.py (#22525)
Ricardokevins 89c2557
Backbone add out indices (#22493)
amyeroberts c7781fe
[Time-Series] fix past_observed_mask type (#22076)
elisim a2e3570
Fix llama tokenizer (#22402)
ArthurZucker aa1a366
[WIP] docs: ko: sagemaker.mdx (#22509)
jungnerd fb09652
added biogpt token classifier (#22447)
upjabir bbad89b
Generate: `TextIteratorStreamer` (streamer for gradio) (#22501)
gante a08de03
Fix convert_opt_original_pytorch_checkpoint_to_pytorch.py typo (#22526)
larekrow af961bf
llama docs: fix conversion script url (#22514)
python273 5f2e5cd
fix LayoutLMv3TokenizerFast subword label after 'Ġ' token (#21695)
thibaultdouzon fe24b03
[BLIP] fix cross attentions for BlipTextEncoder (#22515)
zhbh01 0d419e1
[`Trainer`] Force `is_model_parallel` when model is loaded in multipl…
younesbelkada 1376d60
[`T5`] Enable naive Pipeline Parallelism training for T5 (#22535)
younesbelkada 7adc163
Fix missing metrics with multiple eval datasets (#22536)
hawkeoni 2cccab4
[setup] drop deprecated `distutils` usage (#22531)
XuehaiPan a02907d
Generate: Enable easier TextStreamer customization (#22516)
vblagoje ecac59c
[setup] migrate setup script to `pyproject.toml` (#22539)
XuehaiPan d5372fd
Skip failing test
sgugger 14f22d1
Update test_image_processing_pix2struct.py (#22543)
younesbelkada 889607f
Fix OPTForQuestionAnswering doc string (#22481)
curlup 995ad7b
Generate: Add text streamer decoding options (#22544)
gante b8edf87
[Roformer] Fixing a bug in RoFormerEncoder where it was ignoring the …
TheWall9 54e0182
🚨🚨🚨 `[NLLB Tokenizer]` Fix the prefix tokens 🚨🚨🚨 (#22313)
ArthurZucker 53e5389
Implemented safetensors checkpoints save/load for Trainer (#22498)
viktor-shcherb 19ce467
Remove hack for dynamic modules and use Python functions instead (#22…
sgugger 3a1e4b7
[`bnb`] Fix typo (#22556)
younesbelkada 763a78e
Add id2label and label2id to model's config in run_xnil (#22558)
maziyarpanahi dad2f7f
Soft error whisper. (#22475)
Narsil 8b5ee45
Add TF port of BLIP (#22090)
Rocketknight1 271bc94
corrected the code comment for the output of find_pruneable_heads_and…
SunHaozhe 10dc0e1
Flax Regnet (#21867)
Shubhamai f4070f7
fix `_no_split_modules` for Whisper model (#22486)
pacman100 63db711
Fix inverted conditional in TF common test! (#22540)
Rocketknight1 05efd03
Skip failing test
sgugger
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -35,6 +35,7 @@ | |
| FlaxCausalLMOutputWithCrossAttentions, | ||
| FlaxSeq2SeqLMOutput, | ||
| FlaxSeq2SeqModelOutput, | ||
| FlaxSequenceClassifierOutput, | ||
| ) | ||
| from ...modeling_flax_utils import ( | ||
| ACT2FN, | ||
|
|
@@ -1468,3 +1469,177 @@ def update_inputs_for_generation(self, model_outputs, model_kwargs): | |
| append_replace_return_docstrings( | ||
| FlaxWhisperForConditionalGeneration, output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC | ||
| ) | ||
|
|
||
|
|
||
| class FlaxWhisperForAudioClassificationModule(nn.Module): | ||
| config: WhisperConfig | ||
| dtype: jnp.dtype = jnp.float32 | ||
| gradient_checkpointing: bool = False | ||
|
|
||
| def setup(self) -> None: | ||
| self.encoder = FlaxWhisperEncoder(config=self.config, dtype=self.dtype) | ||
| self.config.is_encoder_decoder = False | ||
| num_layers = self.config.num_hidden_layers + 1 | ||
| if self.config.use_weighted_layer_sum: | ||
| self.layer_weights = jnp.repeat(1 / num_layers, num_layers) | ||
| self.projector = nn.Dense(self.config.classifier_proj_size) | ||
|
|
||
| self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype) | ||
|
|
||
| def __call__( | ||
| self, | ||
| input_features, | ||
| encoder_outputs=None, | ||
| output_attentions=None, | ||
| output_hidden_states: bool = True, | ||
| return_dict: bool = True, | ||
| ): | ||
| output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions | ||
| output_hidden_states = ( | ||
| output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states | ||
| ) | ||
| return_dict = return_dict if return_dict is not None else self.config.use_return_dict | ||
|
|
||
| if encoder_outputs is None: | ||
| encoder_outputs = self.encoder( | ||
| input_features, | ||
| output_attentions=output_attentions, | ||
| output_hidden_states=output_hidden_states, | ||
| return_dict=return_dict, | ||
| ) | ||
|
|
||
| if self.config.use_weighted_layer_sum: | ||
| hidden_states = jnp.stack(encoder_outputs, axis=1) | |
|
|
||
| norm_weights = jax.nn.softmax(self.layer_weights, axis=-1) | ||
| hidden_states = jnp.sum(hidden_states * jnp.reshape(norm_weights, [-1, 1, 1]), axis=1) | ||
| else: | ||
| hidden_states = encoder_outputs[0] | ||
|
|
||
| hidden_states = self.projector(hidden_states) | ||
| pooled_output = jnp.mean(hidden_states, axis=1) | ||
|
|
||
| logits = self.classifier(pooled_output) | ||
|
|
||
| if not return_dict: | ||
| return (logits,) + encoder_outputs[1:] | ||
|
|
||
| return FlaxSequenceClassifierOutput( | ||
| logits=logits, | ||
| hidden_states=encoder_outputs.hidden_states, | ||
| attentions=encoder_outputs.attentions, | ||
| ) | ||
|
|
||
|
|
||
| @add_start_docstrings("The Whisper Model with an audio classification head on top.", WHISPER_START_DOCSTRING) | ||
| class FlaxWhisperForAudioClassification(FlaxWhisperPreTrainedModel): | ||
| module_class = FlaxWhisperForAudioClassificationModule | ||
| dtype: jnp.dtype = jnp.float32 | ||
|
|
||
| def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: | ||
| # init input tensors | ||
| input_features = jnp.zeros(input_shape, dtype="f4") | ||
| input_features = input_features.at[(..., -1)].set(self.config.eos_token_id) | ||
|
|
||
| decoder_input_ids = jnp.zeros((input_shape[0], 1), dtype="i4") | ||
|
|
||
| jnp.ones_like(decoder_input_ids) | ||
|
|
||
|
|
||
| batch_size, sequence_length = decoder_input_ids.shape | ||
| jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) | ||
|
|
||
| params_rng, dropout_rng = jax.random.split(rng) | ||
| rngs = {"params": params_rng, "dropout": dropout_rng} | ||
|
|
||
| random_params = self.module.init( | ||
| rngs, | ||
| input_features=input_features, | ||
| )["params"] | ||
|
|
||
| if params is not None: | ||
| random_params = flatten_dict(unfreeze(random_params)) | ||
| params = flatten_dict(unfreeze(params)) | ||
| for missing_key in self._missing_keys: | ||
| params[missing_key] = random_params[missing_key] | ||
| self._missing_keys = set() | ||
| return freeze(unflatten_dict(params)) | ||
| else: | ||
| return random_params | ||
|
|
||
| @add_start_docstrings_to_model_forward(WHISPER_INPUTS_DOCSTRING) | ||
| def __call__( | ||
| self, | ||
| input_features: jnp.ndarray, | ||
| attention_mask: Optional[jnp.ndarray] = None, | ||
| output_attentions: Optional[bool] = None, | ||
| output_hidden_states: Optional[bool] = None, | ||
| return_dict: Optional[bool] = None, | ||
| train: bool = False, | ||
| params: dict = None, | ||
| dropout_rng: PRNGKey = None, | ||
| **kwargs, | ||
| ): | ||
| r""" | ||
|
|
||
| Returns: | ||
|
|
||
| Example: | ||
|
|
||
| ```python | ||
| >>> from transformers import WhisperProcessor, FlaxWhisperForConditionalGeneration | ||
| >>> from datasets import load_dataset | ||
|
|
||
| >>> processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en") | ||
| >>> model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", from_pt=True) | ||
| >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") | ||
| >>> inputs = processor(ds[0]["audio"]["array"], return_tensors="np") | ||
| >>> input_features = inputs.input_features | ||
| >>> encoder_outputs = model.encode(input_features=input_features) | ||
| ```""" | ||
| output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions | ||
| output_hidden_states = ( | ||
| output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states | ||
| ) | ||
| return_dict = return_dict if return_dict is not None else self.config.return_dict | ||
|
|
||
| # Handle any PRNG if needed | ||
| rngs = {} | ||
| if dropout_rng is not None: | ||
| rngs["dropout"] = dropout_rng | ||
|
|
||
| def _encoder_forward(module, input_features, **kwargs): | ||
|
|
||
| encode_module = module._get_encoder_module() | ||
| return encode_module(input_features, **kwargs) | ||
|
|
||
| return self.module.apply( | ||
| {"params": params or self.params}, | ||
| input_features=jnp.array(input_features, dtype="f4"), | ||
| output_attentions=output_attentions, | ||
| output_hidden_states=output_hidden_states, | ||
| return_dict=return_dict, | ||
| rngs=rngs, | ||
| # method=_encoder_forward, | ||
|
Contributor
Can remove this commented line too |
||
| ) | ||
|
|
||
|
|
||
| FLAX_WHISPER_AUDIO_CLASSIFICATION_DOCSTRING = r""" | ||
| Returns: | ||
|
|
||
| Audio classification example: | |
|
|
||
| ```python | ||
| >>> from transformers import WhisperProcessor, FlaxWhisperForAudioClassification | ||
| >>> from datasets import load_dataset | ||
|
|
||
| >>> processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en") | ||
| >>> model = FlaxWhisperForAudioClassification.from_pretrained("openai/whisper-tiny.en", from_pt=True) | ||
| >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") | ||
|
|
||
| >>> inputs = processor(ds[0]["audio"]["array"], return_tensors="np") | ||
| >>> input_features = inputs.input_features | ||
| >>> outputs = model(input_features=input_features) | ||
| >>> logits = outputs.logits | |
| ``` | ||
| """ | ||
|
|
||
| overwrite_call_docstring( | ||
| FlaxWhisperForAudioClassification, WHISPER_INPUTS_DOCSTRING + FLAX_WHISPER_AUDIO_CLASSIFICATION_DOCSTRING | ||
| ) | ||
| append_replace_return_docstrings( | ||
| FlaxWhisperForAudioClassification, output_type=FlaxSequenceClassifierOutput, config_class=_CONFIG_FOR_DOC | ||
| ) | ||
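For readers following the diff, the head computed in `FlaxWhisperForAudioClassificationModule.__call__` (softmax-weighted layer sum, projection, mean pooling over time, classifier) can be sketched in plain JAX. This is only an illustration under stated assumptions, not the code from the PR: the function `classification_head` and its parameters `proj_w`, `proj_b`, `cls_w`, `cls_b` are hypothetical stand-ins for the two `nn.Dense` layers, and the input is assumed to be a tuple of per-layer hidden states of identical shape. `jnp.full` is used here in place of the `jnp.repeat` call in the diff to build the uniform initial layer weights.

```python
import jax
import jax.numpy as jnp


def classification_head(layer_hidden_states, layer_weights, proj_w, proj_b, cls_w, cls_b):
    """Hypothetical stand-in for the head in the diff.

    layer_hidden_states: tuple of (batch, seq_len, hidden) arrays, one per layer.
    layer_weights: (num_layers,) unnormalized per-layer weights.
    """
    # Stack along a new layer axis -> (batch, num_layers, seq_len, hidden).
    # Note that jnp.stack takes `axis`, not `dim`.
    stacked = jnp.stack(layer_hidden_states, axis=1)
    # Normalize the per-layer weights so they sum to 1.
    norm_weights = jax.nn.softmax(layer_weights, axis=-1)
    # Broadcast the weights over (seq_len, hidden) and sum out the layer axis.
    hidden = jnp.sum(stacked * jnp.reshape(norm_weights, (-1, 1, 1)), axis=1)
    # Project each frame, then mean-pool over the time axis.
    hidden = hidden @ proj_w + proj_b
    pooled = jnp.mean(hidden, axis=1)
    # Final linear classifier -> (batch, num_labels).
    return pooled @ cls_w + cls_b


# Toy shapes chosen for illustration only.
batch, seq_len, hidden, proj, num_labels, num_layers = 2, 5, 8, 4, 3, 3
layer_hidden_states = tuple(
    jnp.full((batch, seq_len, hidden), float(i)) for i in range(num_layers)
)
layer_weights = jnp.full((num_layers,), 1.0 / num_layers)

logits = classification_head(
    layer_hidden_states,
    layer_weights,
    proj_w=jnp.ones((hidden, proj)),
    proj_b=jnp.zeros((proj,)),
    cls_w=jnp.ones((proj, num_labels)),
    cls_b=jnp.zeros((num_labels,)),
)
print(logits.shape)  # (2, 3)
```

With uniform layer weights the softmax leaves them uniform, so the weighted sum reduces to a plain average over layers; learned weights would shift the mix toward the layers most useful for classification.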