diff --git a/optimum/habana/transformers/models/mistral/modeling_mistral.py b/optimum/habana/transformers/models/mistral/modeling_mistral.py index 6075524b5f..f7d28b2fe0 100644 --- a/optimum/habana/transformers/models/mistral/modeling_mistral.py +++ b/optimum/habana/transformers/models/mistral/modeling_mistral.py @@ -19,6 +19,7 @@ # limitations under the License. """PyTorch Mistral model.""" +import os import math from typing import List, Optional, Tuple, Union @@ -47,7 +48,6 @@ _gaudi_prepare_4d_causal_attention_mask, ) - try: from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingHelperV2 as FusedRoPE @@ -68,6 +68,15 @@ print("Not using HPU fused scaled dot-product attention kernel.") FusedSDPA = None +# FusedScaledDotProductAttention +class ModuleFusedSDPA(torch.nn.Module): + def __init__(self, fusedSDPA): + super().__init__() + self._hpu_kernel_fsdpa = fusedSDPA + + def forward(self, query, key, value, attn_mask, dropout_p, is_casual, scale): + return self._hpu_kernel_fsdpa.apply(query, key, value, attn_mask, dropout_p, is_casual, scale) + class Matmul(torch.nn.Module): def __init__(self): super().__init__() @@ -113,16 +122,6 @@ def gaudi_mistral_repeat_kv( return query_states, key_states, value_states, attention_mask - -def update_sincos_cache(self, seq_len): - # Call rotary emb forward() to update cos/sin cache when infering more than self.max_position_embeddings - # This helps in avoiding creation of these caches during actual model forward pass and - # reduce memory consumption and improve performance. - if seq_len > self.max_position_embeddings: - self.max_position_embeddings = seq_len - _, _ = self.rotary_emb(self.k_proj.weight, seq_len=seq_len) - - def gaudi_mistral_rmsnorm_forward(self, hidden_states): """ Copied from MistralRMSNorm.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py @@ -153,6 +152,7 @@ def __init__(self, config: MistralConfig, layer_idx: Optional[int] = None): self.v_cache = KVCache() self.matmul_qk = Matmul() self.matmul_av = Matmul() + self.fused_scaled_dot_product_attention = ModuleFusedSDPA(FusedSDPA) if FusedSDPA else None self.inp_seq_len = -1 self.norm_factor = 1.0 / math.sqrt(self.head_dim) @@ -177,6 +177,14 @@ def reorder_kv_cache(self, beam_idx: torch.LongTensor): self.reorder(self.v_cache.cache, beam_idx, seq_length, head_dim) return (self.k_cache.cache.shape, self.v_cache.cache.shape) + def update_sincos_cache(self, seq_len): + # Call rotary emb forward() to update cos/sin cache when infering more than self.max_position_embeddings + # This helps in avoiding creation of these caches during actual model forward pass and + # reduce memory consumption and improve performance. + if seq_len > self.max_position_embeddings: + self.max_position_embeddings = seq_len + _, _ = self.rotary_emb(self.k_proj.weight, seq_len=seq_len) + def forward( self, hidden_states: torch.Tensor, @@ -263,8 +271,9 @@ def forward( import habana_frameworks.torch.hpu as ht if q_len == 1: # next token - with ht.sdp_kernel(enable_recompute=False): - attn_output = FusedSDPA.apply( + use_recompute = True if os.getenv("QUANT_CONFIG", "") else False + with ht.sdp_kernel(enable_recompute=use_recompute):#False): + attn_output = self.fused_scaled_dot_product_attention( query_states, key_states, value_states, attention_mask, 0.0, False, None ) else: @@ -272,10 +281,10 @@ def forward( if flash_attention_causal_mask: # causal masking on first token requires inputs to be of the same length with ht.sdp_kernel(enable_recompute=flash_attention_recompute): - attn_output = FusedSDPA.apply(query_states, key_states, value_states, None, 0.0, True, None) + attn_output = self.fused_scaled_dot_product_attention(query_states, key_states, value_states, None, 0.0, True, None) else: with ht.sdp_kernel(enable_recompute=flash_attention_recompute): - attn_output = FusedSDPA.apply( + attn_output = self.fused_scaled_dot_product_attention( query_states, key_states, value_states, attention_mask, 0.0, False, None ) else: @@ -421,6 +430,10 @@ def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): def reorder_kv_cache(self, beam_idx: torch.LongTensor): return tuple(layer.reorder_kv_cache(beam_idx) for layer in self.layers) + def update_sincos_cache(self, seq_len): + for layer in self.layers: + layer.update_sincos_cache(seq_len) + def forward( self, input_ids: torch.LongTensor = None, @@ -535,6 +548,7 @@ def forward( and (torch.distributed.is_initialized() is False or torch.distributed.get_world_size() == 1) ): htcore.mark_step() + if output_hidden_states: all_hidden_states += (hidden_states,)