From 7fe80272b3a468047395eeb2ca6df7ae27669718 Mon Sep 17 00:00:00 2001
From: staghado
Date: Sat, 9 Dec 2023 16:14:21 +0100
Subject: [PATCH 1/5] First try at adding FA2 support for Musicgen

---
 .../models/musicgen/modeling_musicgen.py      | 367 ++++++++++++++++--
 1 file changed, 338 insertions(+), 29 deletions(-)

diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py
index 8cca8108efd0..d631b7d797d1 100644
--- a/src/transformers/models/musicgen/modeling_musicgen.py
+++ b/src/transformers/models/musicgen/modeling_musicgen.py
@@ -22,13 +22,20 @@
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 from torch.nn import CrossEntropyLoss
 
 from ...activations import ACT2FN
 from ...generation.configuration_utils import GenerationConfig
-from ...generation.logits_process import ClassifierFreeGuidanceLogitsProcessor, LogitsProcessorList
+from ...generation.logits_process import (
+    ClassifierFreeGuidanceLogitsProcessor,
+    LogitsProcessorList,
+)
 from ...generation.stopping_criteria import StoppingCriteriaList
-from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
+from ...modeling_attn_mask_utils import (
+    _prepare_4d_attention_mask,
+    _prepare_4d_causal_attention_mask,
+)
 from ...modeling_outputs import (
     BaseModelOutput,
     BaseModelOutputWithPastAndCrossAttentions,
@@ -40,6 +47,8 @@
 from ...utils import (
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
+    is_flash_attn_2_available,
+    is_flash_attn_greater_or_equal_2_10,
     logging,
     replace_return_docstrings,
 )
@@ -48,6 +57,11 @@
 from .configuration_musicgen import MusicgenConfig, MusicgenDecoderConfig
 
 
+if is_flash_attn_2_available():
+    from flash_attn import flash_attn_func, flash_attn_varlen_func
+    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
+
+
 if TYPE_CHECKING:
     from ...generation.streamers import BaseStreamer
 
@@ -62,6 +76,19 @@
 ]
 
 
+# Copied from transformers.models.llama.modeling_llama._get_unpad_data
+def _get_unpad_data(attention_mask):
+    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
+    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
+    max_seqlen_in_batch = seqlens_in_batch.max().item()
+    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
+    return (
+        indices,
+        cu_seqlens,
+        max_seqlen_in_batch,
+    )
+
+
 @dataclass
 class MusicgenUnconditionalInput(ModelOutput):
     """
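As a quick sanity check (an aside, not part of the patch), here is what `_get_unpad_data` computes for a toy right-padded batch; `cu_seqlens` marks the sequence boundaries once the batch is flattened into the packed layout that `flash_attn_varlen_func` consumes:

```python
import torch
import torch.nn.functional as F

attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 0, 0, 0]])  # 3 real tokens, then 1

seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)             # tensor([3, 1])
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()  # tensor([0, 1, 2, 4])
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))  # tensor([0, 3, 4])
max_seqlen_in_batch = seqlens_in_batch.max().item()                          # 3
```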
@@ -304,12 +331,236 @@ def forward(
         return attn_output, attn_weights_reshaped, past_key_value
 
 
+# Copied from transformers.models.bart.modeling_bart.BartFlashAttention2 with Bart->Musicgen
+class MusicgenFlashAttention2(MusicgenAttention):
+    """
+    Musicgen flash attention module. This module inherits from `MusicgenAttention`, as the weights of the module stay
+    untouched. The only required change is in the forward pass, where it needs to correctly call the public API of
+    flash attention and deal with padding tokens in case the input contains any of them.
+    """
+
+    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+        # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
+        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+
+    def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        key_value_states: Optional[torch.Tensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        layer_head_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        # MusicgenFlashAttention2 attention does not support output_attentions
+        if output_attentions:
+            raise ValueError("MusicgenFlashAttention2 attention does not support output_attentions")
+
+        # if key_value_states are provided this layer is used as a cross-attention layer
+        # for the decoder
+        is_cross_attention = key_value_states is not None
+
+        bsz, q_len, _ = hidden_states.size()
+
+        # get query proj
+        query_states = self._reshape(self.q_proj(hidden_states), -1, bsz)
+        # get key, value proj
+        # `past_key_value[0].shape[2] == key_value_states.shape[1]`
+        # is checking that the `sequence_length` of the `past_key_value` is the same as
+        # the provided `key_value_states` to support prefix tuning
+        if (
+            is_cross_attention
+            and past_key_value is not None
+            and past_key_value[0].shape[2] == key_value_states.shape[1]
+        ):
+            # reuse k, v, cross_attentions
+            key_states = past_key_value[0].transpose(1, 2)
+            value_states = past_key_value[1].transpose(1, 2)
+        elif is_cross_attention:
+            # cross_attentions
+            key_states = self._reshape(self.k_proj(key_value_states), -1, bsz)
+            value_states = self._reshape(self.v_proj(key_value_states), -1, bsz)
+        elif past_key_value is not None:
+            # reuse k, v, self_attention
+            key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
+            value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
+            key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1)
+            value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1)
+        else:
+            # self_attention
+            key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
+            value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
+
+        if self.is_decoder:
+            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
+            # Further calls to cross_attention layer can then reuse all cross-attention
+            # key/value_states (first "if" case)
+            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
+            # all previous decoder key/value_states. Further calls to uni-directional self-attention
+            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+            # if encoder bi-directional self-attention `past_key_value` is always `None`
+            past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2))
+
+        kv_seq_len = key_states.shape[-2]
+        if past_key_value is not None:
+            kv_seq_len += past_key_value[0].shape[-2]
+
+        # In PEFT, we usually cast the layer norms to float32 for training stability reasons;
+        # therefore the input hidden states may have been silently cast to float32. Hence, we need to
+        # cast them back to the correct dtype just to be sure everything works as expected.
+        # This might slow down training & inference, so it is recommended not to cast the LayerNorms
+        # to fp32. (LlamaRMSNorm handles it correctly)
+
+        input_dtype = query_states.dtype
+        if input_dtype == torch.float32:
+            # Handle the case where the model is quantized
+            if hasattr(self.config, "_pre_quantization_dtype"):
+                target_dtype = self.config._pre_quantization_dtype
+            else:
+                target_dtype = self.q_proj.weight.dtype
+
+            logger.warning_once(
+                f"The input hidden states seem to be silently cast to float32; this might be related to"
+                f" the fact that you have upcast embedding or layer norm layers to float32. We will cast the input"
+                f" back to {target_dtype}."
+            )
+
+            query_states = query_states.to(target_dtype)
+            key_states = key_states.to(target_dtype)
+            value_states = value_states.to(target_dtype)
+
+        attn_output = self._flash_attention_forward(
+            query_states, key_states, value_states, attention_mask, q_len, dropout=self.dropout
+        )
+
+        attn_output = attn_output.reshape(bsz, q_len, -1)
+        attn_output = self.out_proj(attn_output)
+
+        if not output_attentions:
+            attn_weights = None
+
+        return attn_output, attn_weights, past_key_value
+
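To make the mask-alignment distinction guarded by `_flash_attn_uses_top_left_mask` concrete (an illustration, not part of the patch): for `q_len=2`, `k_len=4`, bottom-right alignment lets the last query row attend to every key, which is what a decoder stepping with a KV cache needs, while top-left alignment would cut the cached keys off:

```python
import torch

q_len, k_len = 2, 4
bottom_right = torch.ones(q_len, k_len, dtype=torch.bool).tril(diagonal=k_len - q_len)
top_left = torch.ones(q_len, k_len, dtype=torch.bool).tril()
# bottom_right: [[True, True, True, False],   top_left: [[True, False, False, False],
#                [True, True, True, True ]]              [True, True,  False, False]]
```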
+    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
+    def _flash_attention_forward(
+        self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
+    ):
+        """
+        Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token,
+        first unpad the input, then compute the attention scores, and finally pad the attention scores back.
+
+        Args:
+            query_states (`torch.Tensor`):
+                Input query states to be passed to Flash Attention API
+            key_states (`torch.Tensor`):
+                Input key states to be passed to Flash Attention API
+            value_states (`torch.Tensor`):
+                Input value states to be passed to Flash Attention API
+            attention_mask (`torch.Tensor`):
+                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
+                position of padding tokens and 1 for the position of non-padding tokens.
+            dropout (`float`, *optional*):
+                Attention dropout
+            softmax_scale (`float`, *optional*):
+                The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim)
+        """
+        if not self._flash_attn_uses_top_left_mask:
+            causal = self.is_causal
+        else:
+            # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
+ causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + attn_output = flash_attn_func( + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal + ) + + return attn_output + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. 
+            attention_mask = attention_mask[:, -query_length:]
+            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
+
+        return (
+            query_layer,
+            key_layer,
+            value_layer,
+            indices_q,
+            (cu_seqlens_q, cu_seqlens_k),
+            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
+        )
+
+
+MUSICGEN_ATTENTION_CLASSES = {
+    "eager": MusicgenAttention,
+    "flash_attention_2": MusicgenFlashAttention2,
+}
+
+
 class MusicgenDecoderLayer(nn.Module):
     def __init__(self, config: MusicgenDecoderConfig):
         super().__init__()
         self.embed_dim = config.hidden_size
 
-        self.self_attn = MusicgenAttention(
+        self.self_attn = MUSICGEN_ATTENTION_CLASSES[config._attn_implementation](
             embed_dim=self.embed_dim,
             num_heads=config.num_attention_heads,
             dropout=config.attention_dropout,
@@ -321,7 +572,7 @@ def __init__(self, config: MusicgenDecoderConfig):
         self.activation_dropout = config.activation_dropout
         self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
 
-        self.encoder_attn = MusicgenAttention(
+        self.encoder_attn = MUSICGEN_ATTENTION_CLASSES[config._attn_implementation](
             self.embed_dim,
             config.num_attention_heads,
             dropout=config.attention_dropout,
@@ -434,6 +685,7 @@ class MusicgenPreTrainedModel(PreTrainedModel):
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
     _no_split_modules = ["MusicgenDecoderLayer", "MusicgenAttention"]
+    _supports_flash_attn_2 = True
 
     def _init_weights(self, module):
         std = self.config.initializer_factor
@@ -666,6 +918,7 @@ def __init__(self, config: MusicgenDecoderConfig):
         )
 
         self.layers = nn.ModuleList([MusicgenDecoderLayer(config) for _ in range(config.num_hidden_layers)])
+        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
         self.layer_norm = nn.LayerNorm(config.hidden_size)
 
         self.gradient_checkpointing = False
@@ -721,16 +974,22 @@ def forward(
         if inputs_embeds is None:
             inputs_embeds = sum([self.embed_tokens[codebook](input[:, codebook]) for codebook in range(num_codebooks)])
 
-        attention_mask = _prepare_4d_causal_attention_mask(
-            attention_mask, input_shape, inputs_embeds, past_key_values_length
-        )
+        if self._use_flash_attention_2:
+            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+        else:
+            attention_mask = _prepare_4d_causal_attention_mask(
+                attention_mask, input_shape, inputs_embeds, past_key_values_length
+            )
 
         # expand encoder attention mask
         if encoder_hidden_states is not None and encoder_attention_mask is not None:
-            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-            encoder_attention_mask = _prepare_4d_attention_mask(
-                encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
-            )
+            if self._use_flash_attention_2:
+                encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
+            else:
+                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+                encoder_attention_mask = _prepare_4d_attention_mask(
+                    encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+                )
 
         # embed positions
         positions = self.embed_positions(input, past_key_values_length)
@@ -818,7 +1077,13 @@ def forward(
 
         if not return_dict:
             return tuple(
                 v
-                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
+                for v in [
+                    hidden_states,
+                    next_cache,
+                    all_hidden_states,
+                    all_self_attns,
+                    all_cross_attentions,
+                ]
                 if v is not None
             )
         return BaseModelOutputWithPastAndCrossAttentions(
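The decoder-side dispatch above boils down to the following convention (a sketch with illustrative values, not part of the patch): the flash attention path consumes the raw 2D padding mask, and only when it actually masks something, while the eager path expands it to a 4D additive mask:

```python
import torch

attention_mask = torch.tensor([[1, 1, 1, 0]])  # (bsz, seq_len) padding mask

# flash_attention_2: keep the 2D mask only if it contains a masked position;
# an all-ones mask becomes None so the faster fixed-length kernel can be used.
fa2_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None

# eager: _prepare_4d_causal_attention_mask instead returns a float mask of shape
# (bsz, 1, tgt_seq_len, src_seq_len) with large negative values at masked slots.
```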
@@ -1074,7 +1339,12 @@ def build_delay_pattern_mask(self, input_ids: torch.LongTensor, pad_token_id: in
         max_length = max_length if max_length is not None else self.generation_config.max_length
 
         input_ids_shifted = (
-            torch.ones((bsz, num_codebooks, max_length), dtype=torch.long, device=input_ids.device) * -1
+            torch.ones(
+                (bsz, num_codebooks, max_length),
+                dtype=torch.long,
+                device=input_ids.device,
+            )
+            * -1
         )
 
         channel_codebooks = num_codebooks // 2 if self.config.audio_channels == 2 else num_codebooks
@@ -1095,7 +1365,8 @@ def build_delay_pattern_mask(self, input_ids: torch.LongTensor, pad_token_id: in
         # construct a pattern mask that indicates the positions of padding tokens for each codebook
         # first fill the upper triangular part (the EOS padding)
         delay_pattern = torch.triu(
-            torch.ones((channel_codebooks, max_length), dtype=torch.bool), diagonal=max_length - channel_codebooks + 1
+            torch.ones((channel_codebooks, max_length), dtype=torch.bool),
+            diagonal=max_length - channel_codebooks + 1,
         )
         # then fill the lower triangular part (the BOS padding)
         delay_pattern = delay_pattern + torch.tril(torch.ones((channel_codebooks, max_length), dtype=torch.bool))
@@ -1125,7 +1396,8 @@ def build_delay_pattern_mask(self, input_ids: torch.LongTensor, pad_token_id: in
     @staticmethod
     def apply_delay_pattern_mask(input_ids, decoder_pad_token_mask):
         """Apply a delay pattern mask to the decoder input ids, only preserving predictions where
-        the mask is set to -1, and otherwise setting to the value detailed in the mask."""
+        the mask is set to -1, and otherwise setting to the value detailed in the mask.
+        """
         seq_len = input_ids.shape[-1]
         decoder_pad_token_mask = decoder_pad_token_mask[..., :seq_len]
         input_ids = torch.where(decoder_pad_token_mask == -1, input_ids, decoder_pad_token_mask)
@@ -1252,7 +1524,9 @@ def generate(
         requires_attention_mask = "encoder_outputs" not in model_kwargs
         if model_kwargs.get("attention_mask", None) is None and requires_attention_mask:
             model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
-                input_ids, generation_config.pad_token_id, generation_config.eos_token_id
+                input_ids,
+                generation_config.pad_token_id,
+                generation_config.eos_token_id,
             )
 
         # 5. Prepare `max_length` depending on other stopping criteria.
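The `torch.where` in `apply_delay_pattern_mask` has simple semantics, illustrated here with invented values (not part of the patch): positions where the pattern mask is -1 keep the model's prediction, and everything else is forced to the mask value (e.g. the BOS/EOS padding token):

```python
import torch

input_ids = torch.tensor([[205, 391, 647]])  # hypothetical codebook predictions
pattern = torch.tensor([[2048, -1, -1]])     # 2048 = pad token id, -1 = keep prediction

print(torch.where(pattern == -1, input_ids, pattern))
# tensor([[2048,  391,  647]])
```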
@@ -1510,7 +1784,9 @@ def tie_weights(self):
             # tie text encoder and decoder base model
             decoder_base_model_prefix = self.decoder.base_model_prefix
             self._tie_encoder_decoder_weights(
-                self.text_encoder, self.decoder._modules[decoder_base_model_prefix], self.decoder.base_model_prefix
+                self.text_encoder,
+                self.decoder._modules[decoder_base_model_prefix],
+                self.decoder.base_model_prefix,
             )
 
     def get_audio_encoder(self):
@@ -1671,7 +1947,9 @@ def from_sub_models_pretrained(
 
         if "config" not in kwargs_text_encoder:
             encoder_config, kwargs_text_encoder = AutoConfig.from_pretrained(
-                text_encoder_pretrained_model_name_or_path, **kwargs_text_encoder, return_unused_kwargs=True
+                text_encoder_pretrained_model_name_or_path,
+                **kwargs_text_encoder,
+                return_unused_kwargs=True,
             )
 
             if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
@@ -1685,7 +1963,9 @@ def from_sub_models_pretrained(
             kwargs_text_encoder["config"] = encoder_config
 
             text_encoder = AutoModel.from_pretrained(
-                text_encoder_pretrained_model_name_or_path, *model_args, **kwargs_text_encoder
+                text_encoder_pretrained_model_name_or_path,
+                *model_args,
+                **kwargs_text_encoder,
             )
 
         audio_encoder = kwargs_audio_encoder.pop("model", None)
@@ -1698,7 +1978,9 @@ def from_sub_models_pretrained(
 
         if "config" not in kwargs_audio_encoder:
             encoder_config, kwargs_audio_encoder = AutoConfig.from_pretrained(
-                audio_encoder_pretrained_model_name_or_path, **kwargs_audio_encoder, return_unused_kwargs=True
+                audio_encoder_pretrained_model_name_or_path,
+                **kwargs_audio_encoder,
+                return_unused_kwargs=True,
             )
 
             if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
@@ -1712,7 +1994,9 @@ def from_sub_models_pretrained(
             kwargs_audio_encoder["config"] = encoder_config
 
             audio_encoder = AutoModel.from_pretrained(
-                audio_encoder_pretrained_model_name_or_path, *model_args, **kwargs_audio_encoder
+                audio_encoder_pretrained_model_name_or_path,
+                *model_args,
+                **kwargs_audio_encoder,
             )
 
         decoder = kwargs_decoder.pop("model", None)
@@ -1725,7 +2009,9 @@ def from_sub_models_pretrained(
 
         if "config" not in kwargs_decoder:
             decoder_config, kwargs_decoder = AutoConfig.from_pretrained(
-                decoder_pretrained_model_name_or_path, **kwargs_decoder, return_unused_kwargs=True
+                decoder_pretrained_model_name_or_path,
+                **kwargs_decoder,
+                return_unused_kwargs=True,
             )
 
             if isinstance(decoder_config, MusicgenConfig):
@@ -1757,7 +2043,12 @@ def from_sub_models_pretrained(
         config = MusicgenConfig.from_sub_models_config(
             text_encoder.config, audio_encoder.config, decoder.config, **kwargs
         )
-        return cls(text_encoder=text_encoder, audio_encoder=audio_encoder, decoder=decoder, config=config)
+        return cls(
+            text_encoder=text_encoder,
+            audio_encoder=audio_encoder,
+            decoder=decoder,
+            config=config,
+        )
 
     @add_start_docstrings_to_model_forward(MUSICGEN_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
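For context, the `from_sub_models_pretrained` hunks above touch the composition API that assembles the full model out of three pretrained sub-models. It is typically used like this (a sketch with illustrative checkpoint names, mirroring the example in the MusicGen docs):

```python
from transformers import MusicgenForConditionalGeneration

# Hypothetical composition of a MusicGen model from three pretrained sub-models.
model = MusicgenForConditionalGeneration.from_sub_models_pretrained(
    text_encoder_pretrained_model_name_or_path="t5-base",
    audio_encoder_pretrained_model_name_or_path="facebook/encodec_32khz",
    decoder_pretrained_model_name_or_path="facebook/musicgen-small",
)
```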
@@ -1930,7 +2221,10 @@ def prepare_inputs_for_generation(
         **kwargs,
     ):
         if decoder_delay_pattern_mask is None:
-            decoder_input_ids, decoder_delay_pattern_mask = self.decoder.build_delay_pattern_mask(
+            (
+                decoder_input_ids,
+                decoder_delay_pattern_mask,
+            ) = self.decoder.build_delay_pattern_mask(
                 decoder_input_ids,
                 self.generation_config.pad_token_id,
                 max_length=self.generation_config.max_length,
@@ -1996,7 +2290,11 @@ def _prepare_decoder_input_ids_for_generation(
         if device is None:
             device = self.device
         decoder_input_ids_start = (
-            torch.ones((batch_size * self.decoder.num_codebooks, 1), dtype=torch.long, device=device)
+            torch.ones(
+                (batch_size * self.decoder.num_codebooks, 1),
+                dtype=torch.long,
+                device=device,
+            )
             * decoder_start_token_id
         )
 
@@ -2011,7 +2309,10 @@ def _prepare_decoder_input_ids_for_generation(
         if "decoder_attention_mask" in model_kwargs:
             decoder_attention_mask = model_kwargs["decoder_attention_mask"]
             decoder_attention_mask = torch.cat(
-                (torch.ones_like(decoder_attention_mask)[:, :1], decoder_attention_mask),
+                (
+                    torch.ones_like(decoder_attention_mask)[:, :1],
+                    decoder_attention_mask,
+                ),
                 dim=-1,
             )
             model_kwargs["decoder_attention_mask"] = decoder_attention_mask
@@ -2057,7 +2358,11 @@ def _prepare_text_encoder_kwargs_for_generation(
             last_hidden_state = torch.concatenate([last_hidden_state, torch.zeros_like(last_hidden_state)], dim=0)
             if "attention_mask" in model_kwargs:
                 model_kwargs["attention_mask"] = torch.concatenate(
-                    [model_kwargs["attention_mask"], torch.zeros_like(model_kwargs["attention_mask"])], dim=0
+                    [
+                        model_kwargs["attention_mask"],
+                        torch.zeros_like(model_kwargs["attention_mask"]),
+                    ],
+                    dim=0,
                 )
 
         model_kwargs["encoder_outputs"] = BaseModelOutput(last_hidden_state=last_hidden_state)
@@ -2304,7 +2609,9 @@ def generate(
 
         if model_kwargs.get("attention_mask", None) is None and requires_attention_mask:
             model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
-                inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id
+                inputs_tensor,
+                generation_config.pad_token_id,
+                generation_config.eos_token_id,
             )
 
         if "encoder_outputs" not in model_kwargs:
@@ -2523,7 +2830,9 @@ def get_unconditional_inputs(self, num_samples=1):
         >>> audio_samples = model.generate(**unconditional_inputs, max_new_tokens=256)
         ```"""
         last_hidden_state = torch.zeros(
-            (num_samples, 1, self.config.text_encoder.hidden_size), device=self.device, dtype=self.dtype
+            (num_samples, 1, self.config.text_encoder.hidden_size),
+            device=self.device,
+            dtype=self.dtype,
         )
 
         attention_mask = torch.zeros((num_samples, 1), device=self.device, dtype=torch.long)

From f161afe28c1ef4de2256421d829cbbef7c4f0a70 Mon Sep 17 00:00:00 2001
From: Said Taghadouini <84044788+staghado@users.noreply.github.com>
Date: Mon, 15 Jan 2024 13:10:08 +0100
Subject: [PATCH 2/5] Update src/transformers/models/musicgen/modeling_musicgen.py

Co-authored-by: Yoach Lacombe <52246514+ylacombe@users.noreply.github.com>
---
 src/transformers/models/musicgen/modeling_musicgen.py | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py
index d631b7d797d1..156a91837db1 100644
--- a/src/transformers/models/musicgen/modeling_musicgen.py
+++ b/src/transformers/models/musicgen/modeling_musicgen.py
@@ -32,10 +32,7 @@
     LogitsProcessorList,
 )
 from ...generation.stopping_criteria import StoppingCriteriaList
-from ...modeling_attn_mask_utils import (
-    _prepare_4d_attention_mask,
-    _prepare_4d_causal_attention_mask,
-)
+from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
 from ...modeling_outputs import (
     BaseModelOutput,
     BaseModelOutputWithPastAndCrossAttentions,

From 5624359eff33efa9fc36c3fd5f3a532421fbb61a Mon Sep 17 00:00:00 2001
From: Said Taghadouini <84044788+staghado@users.noreply.github.com>
Date: Mon, 15 Jan 2024 13:10:18 +0100
Subject: [PATCH 3/5] Update src/transformers/models/musicgen/modeling_musicgen.py

Co-authored-by: Yoach Lacombe <52246514+ylacombe@users.noreply.github.com>
---
src/transformers/models/musicgen/modeling_musicgen.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py index 156a91837db1..6e32ca75c98d 100644 --- a/src/transformers/models/musicgen/modeling_musicgen.py +++ b/src/transformers/models/musicgen/modeling_musicgen.py @@ -27,10 +27,7 @@ from ...activations import ACT2FN from ...generation.configuration_utils import GenerationConfig -from ...generation.logits_process import ( - ClassifierFreeGuidanceLogitsProcessor, - LogitsProcessorList, -) +from ...generation.logits_process import ClassifierFreeGuidanceLogitsProcessor, LogitsProcessorList from ...generation.stopping_criteria import StoppingCriteriaList from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask from ...modeling_outputs import ( From a3c900b1ef6ac5e06151441f2c19064be15ba934 Mon Sep 17 00:00:00 2001 From: staghado Date: Mon, 15 Jan 2024 13:55:02 +0100 Subject: [PATCH 4/5] add flash attention 2 flag to MusicgenForConditionalGeneration --- src/transformers/models/musicgen/modeling_musicgen.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py index 6e32ca75c98d..26fa28ba244c 100644 --- a/src/transformers/models/musicgen/modeling_musicgen.py +++ b/src/transformers/models/musicgen/modeling_musicgen.py @@ -1681,6 +1681,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel): base_model_prefix = "encoder_decoder" main_input_name = "input_ids" supports_gradient_checkpointing = True + _supports_flash_attn_2 = True def __init__( self, From 1e23a61e309c976ca63a66c229abd6a3edda8a81 Mon Sep 17 00:00:00 2001 From: staghado Date: Mon, 15 Jan 2024 14:01:48 +0100 Subject: [PATCH 5/5] Add MusicGen to the list of models supporting FA2 in the doc --- docs/source/en/perf_infer_gpu_one.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md index b12670584a4e..2ec912636d9f 100644 --- a/docs/source/en/perf_infer_gpu_one.md +++ b/docs/source/en/perf_infer_gpu_one.md @@ -51,6 +51,8 @@ FlashAttention-2 is currently supported for the following architectures: * [OPT](https://huggingface.co/docs/transformers/model_doc/opt#transformers.OPTModel) * [Phi](https://huggingface.co/docs/transformers/model_doc/phi#transformers.PhiModel) * [Whisper](https://huggingface.co/docs/transformers/model_doc/whisper#transformers.WhisperModel) +* [MusicGen](https://huggingface.co/docs/transformers/model_doc/musicgen#transformers.MusicgenModel) + You can request to add FlashAttention-2 support for another model by opening a GitHub Issue or Pull Request.
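Taken together, the series enables loading MusicGen with Flash Attention 2. A minimal usage sketch, assuming a CUDA device, `flash-attn>=2.1` installed, and the `attn_implementation` loading argument of recent transformers releases (the kernels require fp16/bf16, not fp32):

```python
import torch
from transformers import AutoProcessor, MusicgenForConditionalGeneration

processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
model = MusicgenForConditionalGeneration.from_pretrained(
    "facebook/musicgen-small",
    torch_dtype=torch.float16,                # FA2 kernels do not support fp32
    attn_implementation="flash_attention_2",
).to("cuda")

inputs = processor(text=["80s pop track with bassy drums"], padding=True, return_tensors="pt").to("cuda")
audio_values = model.generate(**inputs, max_new_tokens=256)
```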