
Conversation

@staghado (Contributor) commented Dec 9, 2023

What does this PR do?

This PR adds Flash Attention 2 support for the MusicGen model. It is based on the Bart example and is a WIP for now.
I could not test the model because FA2 is not yet supported on T4 GPUs.
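For reference, the general FA2 pattern (as in the Bart example) looks roughly like the sketch below. This is illustrative only, not the final code: the class name and the simplified self-attention-only forward are assumptions, and the real port also has to handle cross-attention, masks, and caching.

import torch
from flash_attn import flash_attn_func  # requires flash-attn >= 2.0 and an Ampere+ GPU

class MusicgenFlashAttention2(torch.nn.Module):
    def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.dropout = dropout
        self.q_proj = torch.nn.Linear(embed_dim, embed_dim)
        self.k_proj = torch.nn.Linear(embed_dim, embed_dim)
        self.v_proj = torch.nn.Linear(embed_dim, embed_dim)
        self.out_proj = torch.nn.Linear(embed_dim, embed_dim)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # hidden_states must be fp16/bf16 and on CUDA for the flash-attn kernels
        bsz, seq_len, _ = hidden_states.size()
        # flash_attn_func expects (batch, seq_len, num_heads, head_dim)
        q = self.q_proj(hidden_states).view(bsz, seq_len, self.num_heads, self.head_dim)
        k = self.k_proj(hidden_states).view(bsz, seq_len, self.num_heads, self.head_dim)
        v = self.v_proj(hidden_states).view(bsz, seq_len, self.num_heads, self.head_dim)
        attn = flash_attn_func(q, k, v, dropout_p=self.dropout if self.training else 0.0, causal=True)
        return self.out_proj(attn.reshape(bsz, seq_len, -1))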

Fixes #27552

@sanchit-gandhi @ylacombe

@ylacombe (Contributor) commented:

Hey @staghado, thanks for taking care of this, let us know when it's ready to be reviewed!

@circuluspibo left a comment:

working test

@staghado (Contributor, Author) commented Jan 6, 2024

I have conducted some tests on an A10 GPU:
- The code seems to work without errors when _supports_flash_attn_2 is set to True for MusicgenForConditionalGeneration, but the model does not load with FA2 unless this is specified by hand. Maybe the flag needs to be added at the class level in MusicgenForConditionalGeneration?
- There is no difference in generation speed between eager attention and FA2:

(Screenshot from 2024-01-06 21-50-49: generation-speed comparison, eager vs. FA2)
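For reference, a minimal sketch of how such a comparison can be timed; the checkpoint, prompt, and generation length here are assumptions, not necessarily what was actually run:

import time
import torch
from transformers import AutoProcessor, MusicgenForConditionalGeneration

processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
inputs = processor(text=["80s pop track with bassy drums"], return_tensors="pt").to("cuda")

for impl in ["eager", "flash_attention_2"]:
    model = MusicgenForConditionalGeneration.from_pretrained(
        "facebook/musicgen-small",
        torch_dtype=torch.float16,
        attn_implementation=impl,
    ).to("cuda")
    torch.cuda.synchronize()
    start = time.perf_counter()
    model.generate(**inputs, max_new_tokens=256)
    torch.cuda.synchronize()
    print(impl, f"{time.perf_counter() - start:.2f}s")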

@sanchit-gandhi (Contributor) commented:

cc @ylacombe, could you possibly circle back here when you get the chance?

@ylacombe (Contributor) left a comment:

Hi @staghado, thanks for the update here!

I believe you have to add _supports_flash_attn_2 to MusicgenForConditionalGeneration and MusicgenForCausalLM, otherwise they won't be flagged as supporting FA2!

Also, did you make sure to pass attn_implementation="flash_attention_2" to from_pretrained, as indicated here? Let me know.
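Concretely, a rough sketch of both pieces (a fragment of modeling_musicgen.py, not self-contained; the exact base classes follow the existing file, and the checkpoint name is illustrative):

# modeling_musicgen.py: flag each pretrained-model class that supports FA2
class MusicgenForCausalLM(MusicgenPreTrainedModel):
    _supports_flash_attn_2 = True
    ...

class MusicgenForConditionalGeneration(PreTrainedModel):
    _supports_flash_attn_2 = True
    ...

# user side: opt in at load time
model = MusicgenForConditionalGeneration.from_pretrained(
    "facebook/musicgen-small", attn_implementation="flash_attention_2"
)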

Besides that, the modeling code looks good, except for some formatting changes that shouldn't happen: a comma added at the end of the line, all along the modeling file.

You can also add Musicgen to the list of supported models here.

Let me know if you need further help!

staghado and others added 4 commits January 15, 2024 13:10
Co-authored-by: Yoach Lacombe <52246514+ylacombe@users.noreply.github.com>
Co-authored-by: Yoach Lacombe <52246514+ylacombe@users.noreply.github.com>
@staghado (Contributor, Author) commented Jan 15, 2024

Hi @ylacombe,

I confirm that the model was instantiated as described here, with the exception of torch_dtype=torch.float16 instead of torch_dtype=torch.bfloat16, because some operations did not seem to implement bfloat16.
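That is, the load call looked roughly like the following (the checkpoint name here is illustrative):

model = MusicgenForConditionalGeneration.from_pretrained(
    "facebook/musicgen-small",
    torch_dtype=torch.float16,  # bfloat16 hit ops without bf16 support
    attn_implementation="flash_attention_2",
)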

@staghado staghado changed the title from "[WIP] Adding FA2 support for MusicGen" to "Adding FA2 support for MusicGen" on Jan 25, 2024
@github-actions (Contributor) commented:

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@ylacombe (Contributor) left a comment:

Hey @staghado, thanks for iterating here! And sorry for the delay!

There are still some changes that appeared because of added commas, but otherwise this looks good to me!

cc @amyeroberts for a review! (By the way, there's no need for additional tests, right?)

Comment on lines +1074 to +1080
for v in [
    hidden_states,
    next_cache,
    all_hidden_states,
    all_self_attns,
    all_cross_attentions,
]

Should be a one-liner.

Suggested change:
-for v in [
-    hidden_states,
-    next_cache,
-    all_hidden_states,
-    all_self_attns,
-    all_cross_attentions,
-]
+for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]

Comment on lines +1336 to +1341
torch.ones(
    (bsz, num_codebooks, max_length),
    dtype=torch.long,
    device=input_ids.device,
)
* -1

Same
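That is, collapsed to one line (same arguments, unchanged behavior):

torch.ones((bsz, num_codebooks, max_length), dtype=torch.long, device=input_ids.device) * -1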

Comment on lines +1362 to +1363
    torch.ones((channel_codebooks, max_length), dtype=torch.bool),
    diagonal=max_length - channel_codebooks + 1,

Same

Comment on lines +1521 to +1523
    input_ids,
    generation_config.pad_token_id,
    generation_config.eos_token_id,

Same

Comment on lines +1782 to +1784
    self.text_encoder,
    self.decoder._modules[decoder_base_model_prefix],
    self.decoder.base_model_prefix,

Same

Comment on lines +1945 to +1947
    text_encoder_pretrained_model_name_or_path,
    **kwargs_text_encoder,
    return_unused_kwargs=True,

Same here and for all the other changes below

Successfully merging this pull request may close these issues: Flash Attention 2 for audio/musicgen