Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions optimum/habana/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@
GaudiStarcoder2DecoderLayer,
GaudiStarcoder2ForCausalLM,
GaudiStarcoder2Model,
GaudiWav2Vec2SdpaAttention,
GaudiWhisperDecoder,
GaudiWhisperDecoderLayer,
GaudiWhisperForConditionalGeneration,
Expand Down Expand Up @@ -270,6 +271,7 @@ def adapt_transformers_to_gaudi():
transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Encoder.forward = gaudi_wav2vec2_encoder_forward
transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC.forward = gaudi_wav2vec2forctc_forward
transformers.models.wav2vec2.modeling_wav2vec2.TDNNLayer.forward = gaudi_wav2vec2_tdnnlayer_forward
transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2SdpaAttention = GaudiWav2Vec2SdpaAttention

# Generation is modified to run faster in lazy mode
transformers.generation.GenerationMixin.generate = GaudiGenerationMixin.generate
Expand Down
1 change: 1 addition & 0 deletions optimum/habana/transformers/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,7 @@
from .vit import gaudi_vit_self_attention_forward
from .vits import gaudi_unconstrained_rational_quadratic_spline
from .wav2vec2 import (
GaudiWav2Vec2SdpaAttention,
_gaudi_wav2vec2_compute_mask_indices,
_gaudi_wav2vec2_mask_hidden_states,
_gaudi_wav2vec2_sample_negative_indices,
Expand Down
1 change: 1 addition & 0 deletions optimum/habana/transformers/models/wav2vec2/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .modeling_wav2vec2 import (
GaudiWav2Vec2SdpaAttention,
_gaudi_wav2vec2_compute_mask_indices,
_gaudi_wav2vec2_mask_hidden_states,
_gaudi_wav2vec2_sample_negative_indices,
Expand Down
166 changes: 165 additions & 1 deletion optimum/habana/transformers/models/wav2vec2/modeling_wav2vec2.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from typing import Optional, Tuple, Union

import torch
Expand All @@ -24,7 +25,8 @@
CausalLMOutput,
Wav2Vec2BaseModelOutput,
)
from transformers.models.wav2vec2.modeling_wav2vec2 import _HIDDEN_STATES_START_POSITION
from transformers.models.wav2vec2.configuration_wav2vec2 import Wav2Vec2Config
from transformers.models.wav2vec2.modeling_wav2vec2 import _HIDDEN_STATES_START_POSITION, Wav2Vec2Attention, logger


try:
Expand All @@ -35,6 +37,12 @@
print("Could not import Custom CTCLoss kernel. This Kernel is available only for SynapseAI >= 1.15.0")
custom_ctc_loss_fwd = None

try:
from habana_frameworks.torch.hpex.kernels import FusedSDPA
except ImportError:
print("Not using HPU fused scaled dot-product attention kernel.")
FusedSDPA = None


def _gaudi_wav2vec2_compute_mask_indices(
shape: Tuple[int, int],
Expand Down Expand Up @@ -476,3 +484,159 @@ def gaudi_wav2vec2_tdnnlayer_forward(self, hidden_states: torch.Tensor) -> torch

hidden_states = self.activation(hidden_states)
return hidden_states


class GaudiWav2Vec2SdpaAttention(Wav2Vec2Attention):
# Copied from transformers.models.bart.modeling_bart.BartSdpaAttention.forward with Bart->Wav2Vec2
def __init__(
self,
embed_dim: int,
num_heads: int,
dropout: float = 0.0,
is_decoder: bool = False,
bias: bool = True,
is_causal: bool = False,
config: Optional[Wav2Vec2Config] = None,
):
super().__init__(
embed_dim,
num_heads,
dropout,
is_decoder,
bias,
is_causal,
config,
)
self.use_flash_attention = True if os.getenv("USE_FLASH_ATTENTION") == "1" else False
self.flash_attention_fast_softmax = True if os.getenv("FLASH_ATTENTION_FAST_SOFTMAX") == "1" else False
self.flash_attention_recompute = True if os.getenv("FLASH_ATTENTION_RECOMPUTE") == "1" else False

def forward(
self,
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""
Copied from Wav2Vec2SdpaAttention.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/wav2vec2/modeling_wav2vec2.py
The only difference is If the `USE_FLASH_ATTENTION` switch is enabled, then use the HPU's fused SDPA; otherwise, use PyTorch's native SDPA
"""
"""Input shape: Batch x Time x Channel"""
if output_attentions or layer_head_mask is not None:
# TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented.
logger.warning_once(
"Wav2Vec2Model is using Wav2Vec2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention"
' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states,
key_value_states=key_value_states,
past_key_value=past_key_value,
attention_mask=attention_mask,
layer_head_mask=layer_head_mask,
output_attentions=output_attentions,
)

# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None

bsz, tgt_len, _ = hidden_states.size()

# get query proj
query_states = self.q_proj(hidden_states)
# get key, value proj
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
# is checking that the `sequence_length` of the `past_key_value` is the same as
# the provided `key_value_states` to support prefix tuning
if (
is_cross_attention
and past_key_value is not None
and past_key_value[0].shape[2] == key_value_states.shape[1]
):
# reuse k,v, cross_attentions
key_states = past_key_value[0]
value_states = past_key_value[1]
elif is_cross_attention:
# cross_attentions
key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
elif past_key_value is not None:
# reuse k, v, self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
else:
# self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states)

query_states = self._shape(query_states, tgt_len, bsz)

# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
# The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False

if self.use_flash_attention and FusedSDPA:
if tgt_len == 1:
# next token
softmax_mode = True if os.getenv("QUANT_CONFIG", "") else False
recompute_mode = False
else:
# first token
softmax_mode = "fast" if self.flash_attention_fast_softmax else "None"
recompute_mode = self.flash_attention_recompute

attn_output = FusedSDPA.apply(
query_states,
key_states,
value_states,
attention_mask,
self.dropout if self.training else 0.0,
is_causal,
None, # scale ; Non -> default scale
softmax_mode, # "None" -> Non fast softmax
recompute_mode, # -> recompute_mode = false
)
else:
# NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
# but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=attention_mask,
dropout_p=self.dropout if self.training else 0.0,
is_causal=is_causal,
)

if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)

attn_output = attn_output.transpose(1, 2)

# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
# partitioned across GPUs when using tensor-parallelism.
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

attn_output = self.out_proj(attn_output)

return attn_output, None, past_key_value