From c78a31694d0eff59cfd7705849596808eed5793f Mon Sep 17 00:00:00 2001
From: Adam Stachowicz
Date: Fri, 22 Nov 2024 19:16:10 +0200
Subject: [PATCH 1/2] [SW_208086] implement fused sdpa for wav2vec2 (#18)

---
 optimum/habana/transformers/modeling_utils.py |   2 +
 .../habana/transformers/models/__init__.py    |   1 +
 .../transformers/models/wav2vec2/__init__.py  |   1 +
 .../models/wav2vec2/modeling_wav2vec2.py      | 184 +++++++++++++++++-
 4 files changed, 187 insertions(+), 1 deletion(-)

diff --git a/optimum/habana/transformers/modeling_utils.py b/optimum/habana/transformers/modeling_utils.py
index b43e283595..937dc7ec23 100644
--- a/optimum/habana/transformers/modeling_utils.py
+++ b/optimum/habana/transformers/modeling_utils.py
@@ -133,6 +133,7 @@
     GaudiStarcoder2DecoderLayer,
     GaudiStarcoder2ForCausalLM,
     GaudiStarcoder2Model,
+    GaudiWav2Vec2SdpaAttention,
     GaudiWhisperDecoder,
     GaudiWhisperDecoderLayer,
     GaudiWhisperForConditionalGeneration,
@@ -270,6 +271,7 @@ def adapt_transformers_to_gaudi():
     transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Encoder.forward = gaudi_wav2vec2_encoder_forward
     transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC.forward = gaudi_wav2vec2forctc_forward
     transformers.models.wav2vec2.modeling_wav2vec2.TDNNLayer.forward = gaudi_wav2vec2_tdnnlayer_forward
+    transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2SdpaAttention = GaudiWav2Vec2SdpaAttention
 
     # Generation is modified to run faster in lazy mode
     transformers.generation.GenerationMixin.generate = GaudiGenerationMixin.generate
diff --git a/optimum/habana/transformers/models/__init__.py b/optimum/habana/transformers/models/__init__.py
index c02e9588f3..5e8ffb0b07 100644
--- a/optimum/habana/transformers/models/__init__.py
+++ b/optimum/habana/transformers/models/__init__.py
@@ -265,6 +265,7 @@
 from .vit import gaudi_vit_self_attention_forward
 from .vits import gaudi_unconstrained_rational_quadratic_spline
 from .wav2vec2 import (
+    GaudiWav2Vec2SdpaAttention,
     _gaudi_wav2vec2_compute_mask_indices,
     _gaudi_wav2vec2_mask_hidden_states,
     _gaudi_wav2vec2_sample_negative_indices,
diff --git a/optimum/habana/transformers/models/wav2vec2/__init__.py b/optimum/habana/transformers/models/wav2vec2/__init__.py
index 84372061b6..b8cb45c1dd 100644
--- a/optimum/habana/transformers/models/wav2vec2/__init__.py
+++ b/optimum/habana/transformers/models/wav2vec2/__init__.py
@@ -1,4 +1,5 @@
 from .modeling_wav2vec2 import (
+    GaudiWav2Vec2SdpaAttention,
     _gaudi_wav2vec2_compute_mask_indices,
     _gaudi_wav2vec2_mask_hidden_states,
     _gaudi_wav2vec2_sample_negative_indices,
diff --git a/optimum/habana/transformers/models/wav2vec2/modeling_wav2vec2.py b/optimum/habana/transformers/models/wav2vec2/modeling_wav2vec2.py
index eded656119..ae812ea5a3 100644
--- a/optimum/habana/transformers/models/wav2vec2/modeling_wav2vec2.py
+++ b/optimum/habana/transformers/models/wav2vec2/modeling_wav2vec2.py
@@ -14,6 +14,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import os
+import random
 from typing import Optional, Tuple, Union
 
 import torch
@@ -24,7 +26,8 @@
     CausalLMOutput,
     Wav2Vec2BaseModelOutput,
 )
-from transformers.models.wav2vec2.modeling_wav2vec2 import _HIDDEN_STATES_START_POSITION
+from transformers.models.wav2vec2.configuration_wav2vec2 import Wav2Vec2Config
+from transformers.models.wav2vec2.modeling_wav2vec2 import _HIDDEN_STATES_START_POSITION, Wav2Vec2Attention, logger
 
 
 try:
@@ -35,6 +38,12 @@
     print("Could not import Custom CTCLoss kernel. This Kernel is available only for SynapseAI >= 1.15.0")
     custom_ctc_loss_fwd = None
 
+try:
+    from habana_frameworks.torch.hpex.kernels import FusedSDPA
+except ImportError:
+    print("Not using HPU fused scaled dot-product attention kernel.")
+    FusedSDPA = None
+
 
 def _gaudi_wav2vec2_compute_mask_indices(
     shape: Tuple[int, int],
@@ -476,3 +485,176 @@ def gaudi_wav2vec2_tdnnlayer_forward(self, hidden_states: torch.Tensor) -> torch
 
     hidden_states = self.activation(hidden_states)
     return hidden_states
+
+
+class GaudiWav2Vec2SdpaAttention(Wav2Vec2Attention):
+    """
+    Wav2Vec2 SDPA attention that can dispatch to the HPU fused SDPA kernel.
+
+    `FusedSDPA` is used when it could be imported and `USE_FLASH_ATTENTION=1` is set in the environment; otherwise
+    `torch.nn.functional.scaled_dot_product_attention` is used. `FLASH_ATTENTION_FAST_SOFTMAX=1` and
+    `FLASH_ATTENTION_RECOMPUTE=1` select the softmax and recompute modes passed to the fused kernel (for
+    `tgt_len != 1`). The three variables are read once, when the module is constructed.
+    """
+
+    # Copied from transformers.models.bart.modeling_bart.BartSdpaAttention.forward with Bart->Wav2Vec2
+    def __init__(
+        self,
+        embed_dim: int,
+        num_heads: int,
+        dropout: float = 0.0,
+        is_decoder: bool = False,
+        bias: bool = True,
+        is_causal: bool = False,
+        config: Optional[Wav2Vec2Config] = None,
+    ):
+        super().__init__(
+            embed_dim,
+            num_heads,
+            dropout,
+            is_decoder,
+            bias,
+            is_causal,
+            config,
+        )
+        # Environment switches are read once here; only the exact value "1" enables a switch.
+        self.use_flash_attention = os.getenv("USE_FLASH_ATTENTION") == "1"
+        self.flash_attention_fast_softmax = os.getenv("FLASH_ATTENTION_FAST_SOFTMAX") == "1"
+        self.flash_attention_recompute = os.getenv("FLASH_ATTENTION_RECOMPUTE") == "1"
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        key_value_states: Optional[torch.Tensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        layer_head_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        """
+        Input shape: Batch x Time x Channel
+
+        Copied from Wav2Vec2SdpaAttention.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/wav2vec2/modeling_wav2vec2.py
+        The only difference is that, if the `USE_FLASH_ATTENTION` switch is enabled, the HPU's fused SDPA is used;
+        otherwise PyTorch's native SDPA is used.
+
+        The SDPA paths return `(attn_output, None, past_key_value)`. When `output_attentions` is set or
+        `layer_head_mask` is given, the call is delegated to `Wav2Vec2Attention.forward`.
+        """
+        if output_attentions or layer_head_mask is not None:
+            # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented.
+            logger.warning_once(
+                "Wav2Vec2Model is using Wav2Vec2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention"
+                ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+            )
+            return super().forward(
+                hidden_states,
+                key_value_states=key_value_states,
+                past_key_value=past_key_value,
+                attention_mask=attention_mask,
+                layer_head_mask=layer_head_mask,
+                output_attentions=output_attentions,
+            )
+
+        # if key_value_states are provided this layer is used as a cross-attention layer
+        # for the decoder
+        is_cross_attention = key_value_states is not None
+
+        bsz, tgt_len, _ = hidden_states.size()
+
+        # get query proj
+        query_states = self.q_proj(hidden_states)
+        # get key, value proj
+        # `past_key_value[0].shape[2] == key_value_states.shape[1]`
+        # is checking that the `sequence_length` of the `past_key_value` is the same as
+        # the provided `key_value_states` to support prefix tuning
+        if (
+            is_cross_attention
+            and past_key_value is not None
+            and past_key_value[0].shape[2] == key_value_states.shape[1]
+        ):
+            # reuse k,v, cross_attentions
+            key_states = past_key_value[0]
+            value_states = past_key_value[1]
+        elif is_cross_attention:
+            # cross_attentions
+            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
+            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
+        elif past_key_value is not None:
+            # reuse k, v, self_attention
+            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+            key_states = torch.cat([past_key_value[0], key_states], dim=2)
+            value_states = torch.cat([past_key_value[1], value_states], dim=2)
+        else:
+            # self_attention
+            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+
+        if self.is_decoder:
+            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
+            # Further calls to cross_attention layer can then reuse all cross-attention
+            # key/value_states (first "if" case)
+            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
+            # all previous decoder key/value_states. Further calls to uni-directional self-attention
+            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+            # if encoder bi-directional self-attention `past_key_value` is always `None`
+            past_key_value = (key_states, value_states)
+
+        query_states = self._shape(query_states, tgt_len, bsz)
+
+        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+        # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
+        is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False
+
+        if self.use_flash_attention and FusedSDPA:
+            if tgt_len == 1:
+                # next token
+                # NOTE(review): `softmax_mode` is a bool here but a str ("fast"/"None") in the other branch; confirm
+                # which type `FusedSDPA.apply` expects before relying on this path.
+                softmax_mode = bool(os.getenv("QUANT_CONFIG", ""))
+                recompute_mode = False
+            else:
+                # first token
+                softmax_mode = "fast" if self.flash_attention_fast_softmax else "None"
+                recompute_mode = self.flash_attention_recompute
+
+            attn_output = FusedSDPA.apply(
+                query_states,
+                key_states,
+                value_states,
+                attention_mask,
+                self.dropout if self.training else 0.0,
+                is_causal,
+                None,  # scale; None -> default scale
+                softmax_mode,  # "None" -> non-fast softmax
+                recompute_mode,  # False -> no recompute
+            )
+        else:
+            # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
+            # but we are fine here as `_shape` does call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577
+            attn_output = torch.nn.functional.scaled_dot_product_attention(
+                query_states,
+                key_states,
+                value_states,
+                attn_mask=attention_mask,
+                dropout_p=self.dropout if self.training else 0.0,
+                is_causal=is_causal,
+            )
+
+        if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):
+            raise ValueError(
+                f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+                f" {attn_output.size()}"
+            )
+
+        attn_output = attn_output.transpose(1, 2)
+
+        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
+        # partitioned across GPUs when using tensor-parallelism.
+        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+
+        attn_output = self.out_proj(attn_output)
+
+        return attn_output, None, past_key_value

From 3db64458a48ab77cc1715398a548d8d4066445fd Mon Sep 17 00:00:00 2001
From: Adam Stachowicz
Date: Tue, 26 Nov 2024 13:34:06 +0200
Subject: [PATCH 2/2] Fix style

---
 optimum/habana/transformers/models/wav2vec2/modeling_wav2vec2.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/optimum/habana/transformers/models/wav2vec2/modeling_wav2vec2.py b/optimum/habana/transformers/models/wav2vec2/modeling_wav2vec2.py
index ae812ea5a3..0e1378ee57 100644
--- a/optimum/habana/transformers/models/wav2vec2/modeling_wav2vec2.py
+++ b/optimum/habana/transformers/models/wav2vec2/modeling_wav2vec2.py
@@ -15,7 +15,6 @@
 # limitations under the License.
 
 import os
-import random
 from typing import Optional, Tuple, Union
 
 import torch