Adding FA2 support for MusicGen #27924
Conversation
Hey @staghado, thanks for taking care of this, let us know when it's ready to be reviewed!
circuluspibo left a comment
working test
I have conducted some tests on an A10 GPU:
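For reference, a minimal sketch of how such a timing comparison could be run; the checkpoint, prompt, and token budget below are illustrative, not necessarily the exact setup used for these tests:

```python
import time

import torch
from transformers import AutoProcessor, MusicgenForConditionalGeneration

def time_generation(attn_implementation: str) -> float:
    # Load the small MusicGen checkpoint in fp16 (FA2 requires fp16/bf16).
    model = MusicgenForConditionalGeneration.from_pretrained(
        "facebook/musicgen-small",
        torch_dtype=torch.float16,
        attn_implementation=attn_implementation,
    ).to("cuda")
    processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
    inputs = processor(text=["80s pop track with synths"], return_tensors="pt").to("cuda")

    torch.cuda.synchronize()
    start = time.perf_counter()
    model.generate(**inputs, max_new_tokens=256)
    torch.cuda.synchronize()
    return time.perf_counter() - start

for impl in ("eager", "flash_attention_2"):
    print(impl, f"{time_generation(impl):.2f}s")
```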
cc @ylacombe could you possibly circle back here when you get the chance!
ylacombe left a comment
Hi @staghado, thanks for the update here!
I believe you have to add `_supports_flash_attn_2` to `MusicgenForConditionalGeneration` and `MusicgenForCausalLM`, otherwise they won't be flagged as supporting FA2!
Also, did you make sure to pass the flag `attn_implementation="flash_attention_2"` to `from_pretrained`, as indicated here? Let me know.
Besides that, the modeling code looks good, except for some formatting changes that shouldn't happen, where a comma is added at the end of the line throughout the modeling file.
You can also add Musicgen to the list of supported models here.
Let me know if you need further help!
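A minimal sketch of the two pieces this comment points at, assuming the class names above (not the PR's exact diff):

```python
import torch
from transformers import MusicgenForConditionalGeneration

# 1) In the modeling file, the classes opt in to FA2 via a class attribute:
#
#    class MusicgenForCausalLM(MusicgenPreTrainedModel):
#        _supports_flash_attn_2 = True
#
#    (and the same flag on MusicgenForConditionalGeneration)

# 2) At load time, FA2 is requested explicitly through from_pretrained:
model = MusicgenForConditionalGeneration.from_pretrained(
    "facebook/musicgen-small",            # illustrative checkpoint
    torch_dtype=torch.float16,            # FA2 only runs in fp16/bf16
    attn_implementation="flash_attention_2",
)
```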
Force-pushed from 335aca4 to 7fe8027
Co-authored-by: Yoach Lacombe <52246514+ylacombe@users.noreply.github.com>
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
ylacombe left a comment
Hey @staghado, thanks for iterating here! And sorry for the delay!
There are still some changes that appeared because of added commas, but otherwise this looks good to me!
cc @amyeroberts for a review! (btw, there's no need for additional tests, right?)
for v in [
    hidden_states,
    next_cache,
    all_hidden_states,
    all_self_attns,
    all_cross_attentions,
]
Should be a one-liner
Suggested change:
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
torch.ones(
    (bsz, num_codebooks, max_length),
    dtype=torch.long,
    device=input_ids.device,
)
* -1
Same
torch.ones((channel_codebooks, max_length), dtype=torch.bool),
diagonal=max_length - channel_codebooks + 1,
Same
input_ids,
generation_config.pad_token_id,
generation_config.eos_token_id,
Same
self.text_encoder,
self.decoder._modules[decoder_base_model_prefix],
self.decoder.base_model_prefix,
Same
text_encoder_pretrained_model_name_or_path,
**kwargs_text_encoder,
return_unused_kwargs=True,
Same here and for all the other changes below

What does this PR do?
This PR adds Flash Attention 2 support for the MusicGen model. It is based on the Bart example and is a WIP for now.
I could not test the model because FA2 is not yet supported on T4 GPUs.
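For context, a condensed, hypothetical sketch of the FA2 forward pattern that transformers models follow (in the spirit of the Bart example mentioned above); the module attributes (`q_proj`, `k_proj`, `v_proj`, `out_proj`, `num_heads`, `head_dim`) mirror the existing attention layout, but this is not the PR's exact code:

```python
import torch
from flash_attn import flash_attn_func  # requires fp16/bf16 tensors on a supported GPU

def flash_attention_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
    bsz, q_len, _ = hidden_states.size()
    # FA2 expects (batch, seq_len, num_heads, head_dim), unlike the
    # (batch, num_heads, seq_len, head_dim) layout used by eager attention.
    q = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
    k = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
    v = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
    # causal=True for decoder self-attention; cross-attention would pass causal=False.
    attn_output = flash_attn_func(q, k, v, dropout_p=0.0, causal=True)
    return self.out_proj(attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim))
```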
Fixes #27552
@sanchit-gandhi @ylacombe