Add FlaxWhisperForAudioClassification model #21894

Closed
raghavanone wants to merge 277 commits into huggingface:main from raghavanone:fix_issue_21779

Conversation

@raghavanone

@raghavanone raghavanone commented Mar 2, 2023

Contributor

What does this PR do?

Fixes #21779

Please review and let me know what changes are needed, @sanchit-gandhi.

@HuggingFaceDocBuilderDev


The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@sanchit-gandhi sanchit-gandhi left a comment

Contributor

Modelling code looks good @raghavanone! Nice one on getting this working so quickly 🙌 Do you want to have a go at adding the encoder-only tests? See the PyTorch WhisperForAudioClassification PR for details, think you can also add these quite quickly :)

@raghavanone

Contributor Author

> Modelling code looks good @raghavanone! Nice one on getting this working so quickly 🙌 Do you want to have a go at adding the encoder-only tests? See the PyTorch WhisperForAudioClassification PR for details, think you can also add these quite quickly :)

I have added the encoder tests, but some tests are failing. The FlaxWhisperForAudioClassification class extends FlaxWhisperPreTrainedModel, and because of this inheritance the `__call__` method expects decoder-related params.

Should FlaxWhisperForAudioClassification stop extending FlaxWhisperPreTrainedModel and use a new pretrained base class instead?
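
One way to picture the question, as a rough sketch only: the subclass can keep the shared base class and override the call signature so that no decoder arguments are required. The classes below are plain-Python stand-ins with hypothetical names (`PreTrainedBase`, `AudioClassifier`); they are not the real Flax classes and do not reflect what the PR actually implemented.

```python
class PreTrainedBase:
    """Hypothetical stand-in for a seq2seq pretrained base class."""

    def __call__(self, input_features, decoder_input_ids, decoder_attention_mask=None):
        # The seq2seq entry point requires decoder inputs.
        return self.forward(input_features, decoder_input_ids=decoder_input_ids)

    def forward(self, input_features, decoder_input_ids=None):
        raise NotImplementedError


class AudioClassifier(PreTrainedBase):
    """Hypothetical encoder-only head that narrows the inherited call signature."""

    def __init__(self, num_labels):
        self.num_labels = num_labels

    def __call__(self, input_features):
        # Override: decoder arguments are neither accepted nor forwarded.
        return self.forward(input_features)

    def forward(self, input_features, decoder_input_ids=None):
        # Toy "encode, pool, project": mean over frames, then one value per label.
        pooled = sum(input_features) / len(input_features)
        return [pooled * (i + 1) for i in range(self.num_labels)]


clf = AudioClassifier(num_labels=3)
print(clf([1.0, 2.0, 3.0]))  # [2.0, 4.0, 6.0]
```

The alternative raised in the question, a separate base class, would avoid the override at the cost of duplicating whatever loading and initialisation logic the shared base provides.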

@sanchit-gandhi

sanchit-gandhi commented Mar 7, 2023

Contributor

Hey @raghavanone! The PyTorch model has just been merged (#21754), so you can rebase onto main to get the required config changes:

git fetch upstream
git rebase upstream/main

This will fix the failing Flax tests we're getting here: https://app.circleci.com/pipelines/github/huggingface/transformers/58972/workflows/2388bd70-553e-412f-9ee7-0599cace5639/jobs/719829

The only thing to make sure is that the first time you push after rebasing, you force push to origin:

git add .
git commit -m "Some new changes after rebase"
git push -f origin fix_issue_21779

You only have to force push once; after that, a regular push is enough:

git add .
git commit -m "Some more changes"
git push -u origin fix_issue_21779
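
Why the first push after a rebase has to be forced can be illustrated with a toy model. This is a simplified sketch, not how git is implemented: commits are modelled as hashes of their parent plus message, and a plain push is treated as acceptable only when the remote history is a prefix of the local one. The helper names (`commit_id`, `build`, `is_fast_forward`) and the branch labels are made up for the illustration.

```python
import hashlib


def commit_id(parent, message):
    """Toy commit id: it depends on the parent, so replaying a commit changes its id."""
    return hashlib.sha1(f"{parent}:{message}".encode()).hexdigest()


def build(base, messages):
    """Stack the given commit messages on top of a base and return the resulting ids."""
    history, parent = [], base
    for message in messages:
        parent = commit_id(parent, message)
        history.append(parent)
    return history


def is_fast_forward(remote, local):
    """Simplified rule: a plain push succeeds only if remote is a prefix of local."""
    return local[: len(remote)] == remote


work = ["add model", "add tests"]
remote_branch = build("old-main", work)   # what origin holds before the rebase
rebased_branch = build("new-main", work)  # the same changes replayed onto updated main

print(is_fast_forward(remote_branch, rebased_branch))  # False: a plain push is rejected

# After one force push the remote matches the rebased branch, so later commits
# extend it and a regular push is enough again.
remote_branch = list(rebased_branch)
local_branch = rebased_branch + [commit_id(rebased_branch[-1], "more changes")]
print(is_fast_forward(remote_branch, local_branch))  # True
```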

Comment thread src/transformers/models/whisper/modeling_flax_whisper.py Outdated
@raghavanone

Contributor Author

@sanchit-gandhi There are 2 tests failing here, but I am unable to reproduce the same failures locally on my machine. Any pointers on how to replicate the failing tests and fix them?

Comment thread src/transformers/models/whisper/modeling_flax_whisper.py Outdated
Comment thread src/transformers/models/whisper/modeling_flax_whisper.py Outdated
Comment thread src/transformers/models/whisper/modeling_flax_whisper.py Outdated
Comment thread src/transformers/models/whisper/modeling_flax_whisper.py Outdated
Comment thread src/transformers/models/whisper/modeling_flax_whisper.py Outdated
Comment thread src/transformers/models/whisper/modeling_flax_whisper.py Outdated
output_hidden_states=output_hidden_states,
return_dict=return_dict,
rngs=rngs,
# method=_encoder_forward,

Contributor

Can remove this commented line too

Comment thread tests/models/whisper/test_modeling_flax_whisper.py Outdated
Comment thread tests/models/whisper/test_modeling_flax_whisper.py Outdated
Comment thread tests/models/whisper/test_modeling_flax_whisper.py Outdated

@sanchit-gandhi sanchit-gandhi left a comment

Contributor

Nice work @raghavanone! Mainly just some clean-up before we can get this merged!

@sanchit-gandhi sanchit-gandhi left a comment

Contributor

Nice one @raghavanone! Mainly just some code clean-up, then we can get this merged!

@sanchit-gandhi

Contributor

Hey @raghavanone! Would you mind going through the previous review comments and marking them as resolved where you've addressed them? I'll then get you a final review asap! Thanks!

abhiwand and others added 19 commits April 5, 2023 12:47
* Add BridgeTower for ITC

* Fix review feedback

* Rename BridgeTowerForITC, cleanup

* Fix style and quality

* implement tests

---------

Co-authored-by: Tiep Le <97980157+tileintel@users.noreply.github.com>
Co-authored-by: Tiep Le <tiep.le@intel.com>
…line (huggingface#22031)

add tokenize_kwargs doc in the FeatureExtractionPipeline
…on_seq2seq.py (huggingface#21942)

* Add specaugment to run_speech_recognition_seq2seq.py

* Remove useless argument: text_column

* Fix quality

* Update return_attention_mask condition

* Update specaugment arguments only for whisper models

* Remove SpecAugment arguments from ModelArguments, only leave default values for simplicity

* Apply suggestions from code review

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* Update apply_spec_augment only for whisper models

* Apply suggestions from code review

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* Rename return_attention_mask to forward_attention_mask to avoid confusion with wav2vec2 models

---------

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
* fixing

* Update modeling_whisper.py

* Update modeling_whisper.py

* Update src/transformers/models/whisper/modeling_whisper.py

---------

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
…P-like models (huggingface#22035)

* Avoid text_config_dict and vision_config_dict being saved

* for other CLIP-like models

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
* slow me

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
* Fix typos and add code examples, resources
* [21737][T5]: Fix gradient checkpoint bug

* [21737][T5]: Fix gradient checkpoint bug

* [21737][T5]: Fix gradient checkpoint bug

* Update src/transformers/models/mt5/modeling_mt5.py

* Update src/transformers/models/t5/modeling_t5.py

---------

Co-authored-by: njindal <njindal@adobe.com>
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
…x it (huggingface#22045)

In ZSH, not using ' ' around pip install fails

Running 
```
pip install transformers[torch]
```
in the default ZSH terminal will fail with the error `zsh: no matches found: transformers[torch]`

The solution is to wrap the installation path in ' ' like 
```
pip install 'transformers[torch]'
```

Relevant StackOverflow: https://stackoverflow.com/questions/30539798/zsh-no-matches-found-requestssecurity
…ace#22051)

* Remove set_access_token usage + fail tests if FutureWarning

* do not fail on FutureWarning in CI

---------

Co-authored-by: testbot <lucainp@hf.co>
…ce#22054)

* show hfh warnings

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
…ce#22040)

* return analysis for hyperparameter_search with ray backend

* Revert "return analysis for hyperparameter_search with ray backend"

This reverts commit cd51790.

* add run_summary attribute to BestRun and return analysis for ray backend

* fix typo

* add doc for run_summary for ray backend
* Add an argument to pt-to-tf to allow overriding the model class

* make fixup

* Minor fix to error message

* Remove unused extra conversion from the script
rm $ symbol from code block 

Removed the $ symbol from the code block to make copy-pasting easier.
)

* [deepspeed] offload + non-cpuadam optimizer exception

* flip

* revert min version
…face#22033)

* Edit the docstring of `image_processing_donut` to match code

* improve style

* more style improvement after installing quality
python273 and others added 26 commits April 5, 2023 12:47
…e#21695)

LayoutLMv3TokenizerFast produces empty 'Ġ' token with `offset_mapping = (0, 0)`.
Next token is wrongly assumed to also be beginning of word and isn't
correctly assigned `pad_token_label`.
Modify test with text that produce 'Ġ' token.
Remove copy check from LayoutLMv2TokenizerFast for `_batch_encode_plus`.

solves issue: huggingface#19978
…e GPUs using `accelerate` (huggingface#22532)

* add `is_model_parallel` arg on Trainer

* add warning

* adapt from suggestions

* revert t5 changes

* remove commas

* adapt from suggestions
…#22535)

* enable PP for T5

* make fixup

* fix failing tests
* [setup] drop deprecated `distutils` usage

* drop deprecated `distutils.util.strtobool` usage

* fix import order

* reformat docstring by `doc-builder`
* [setup] migrate setup script to `pyproject.toml`

* [setup] cleanup configurations

* remove unused imports
* Fix OPTForQuestionAnswering doc string

for more adequate model answer decoding

* black style fix

* doc-builder style
…length of past_key_values when generating as a decoder (huggingface#22416)

* fix RoFormerEncoder postion embedding when generate as decoder

* make fixup

* add test case for check generate with past key values

* remove duplicating code
* fix the prefix tokens

* update fast and test values

* add legacy behaviour

Co-authored-by: sgugger <sylvain.gugger@gmail.com>

* update disclaimer, linkissue PR and behaviral changes

* Apply suggestions from code review

Co-authored-by: Lysandre Debut <hi@lysand.re>

* styling

* make a quote

* quote this time

---------

Co-authored-by: sgugger <sylvain.gugger@gmail.com>
Co-authored-by: Lysandre Debut <hi@lysand.re>
…e#22498)

* implemented safetensors save/load

* remove duplicated file

* added tests

* more tests

* style fix

* fix tf tests

* change to list comprehension

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* review fixes + safe load for sharded checkpoint

* style fix

* remove rogue import

* remove partial to avoid undefined exception

* use naming alias instead of safetensors.torch

* fix safe sharding in tests

* grammar

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* update docs

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* update docs

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* minor corrections

* style

---------

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Update modeling_utils.py
* Soft error whisper.

* Fix format.

---------

Co-authored-by: Ubuntu <ubuntu@ip-172-31-34-94.taildb5d.ts.net>
* Initial commit

* more stash commit

* Yet another stash commit

* yet more stash commit

* Mostly working except for docs / repo consistency

* Stop importing model list from torch file

* Add TF BLIP models to docs

* Add auto classes

* Move get_text_features and get_image_features

* Update src/transformers/models/blip/modeling_tf_blip.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/blip/modeling_tf_blip.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/blip/modeling_tf_blip.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/blip/modeling_tf_blip_text.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/blip/modeling_tf_blip.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/blip/modeling_tf_blip.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/blip/modeling_tf_blip.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/blip/modeling_tf_blip.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/blip/modeling_tf_blip.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update tests/models/blip/test_modeling_tf_blip.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update tests/models/blip/test_modeling_tf_blip.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/blip/modeling_tf_blip.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/blip/modeling_tf_blip.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Update tests/models/blip/test_modeling_tf_blip_text.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/blip/modeling_tf_blip_text.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Update src/transformers/models/blip/modeling_tf_blip.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Use channels_last convolutions in TF (better performance + compatibility)

* Remove _shape function

* Move multi-line statement to one line in PT + TF

* Specify tf.keras.layers instead of importing from it

* Remove test_gradient_checkpointing and empty test_training methods

* move some multi-line statements to one line

* Update docstring for generate

* Remove pruned heads set

* Remove self.seq_len_dim

* Fixed issues with loss computation, should resolve some tests. Also ensured that the PT version follows the config for output_attentions and output_hidden_states

* ensure original model follows config in more cases

* Skip the same cross-attention tests in the PT tests - didn't realize we did it twice!

* Add training args throughout the models and layers

* make fixup

* Fix docstring for inputs_embeds

* Add docstring for is_decoder

* Add docstrings to text models

* Remove redundant computation

* Add unpack_inputs / keras_serializable

* Add modeling_tf_blip to doctests

* Add config classes for keras serialization

* Changes to allow model porting with pt-to-tf

* Quick fix to decoder head and test tweaks

* Revert an issue with masking the embeddings outputs

* Allow missing keys in some equivalence tests (for unused layers)

* Add tf-pt equivalence tests back in

* Update src/transformers/models/blip/modeling_tf_blip.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/models/blip/modeling_tf_blip_text.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/models/blip/modeling_tf_blip_text.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* make fixup

* Refactor invert_attention_mask out into tf_utils

* Re-enable cross-tests on the PT side too

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
…_indices (huggingface#22557)

* corrected/clarified the code comment of find_pruneable_heads_and_indices

* have run make style
* initial commit

* review changes

* post model PR merge

* updating doc
* Fix inverted conditional in TF common test!

* Make the same change in the PT tests file

* Make sure hidden states for GPT2 have the same output shape in PT/TF

* Minor fix to PT implementation of token classification loss

* Skip loss equivalence test for TFHubert because it keeps overflowing to inf

* Compute LM loss for TF the (weird) way it's computed in PT

* Skip loss equivalence test for Wav2Vec2 for the same reason as Hubert

* Fix - don't try to access the hidden states property when output is a tuple
@sanchit-gandhi

Contributor

Hey @raghavanone - I think the commit history has been corrupted for this PR? Gentle reminder that one must force push after rebasing: #21894 (comment) Think this is probably the culprit for the 250 extra commits!

In this instance, it's probably best to close this PR in favour of a new one that only contains the new changes you wish to merge. Sorry about that!

@sanchit-gandhi

Contributor

Closing in favour of #22883

Development

Successfully merging this pull request may close these issues.

Add Flax Whisper for audio classification