class ModuleFusedSDPA(torch.nn.Module):
    """Expose a fused scaled dot-product attention kernel as an ``nn.Module``.

    The wrapped object only needs to provide an ``apply`` callable; every
    ``forward`` argument is handed to it unchanged and in the same order.
    """

    def __init__(self, fusedSDPA):
        """Store the kernel object.

        Args:
            fusedSDPA: Object exposing ``apply(query, key, value, attn_mask,
                dropout_p, is_casual, scale)``.
        """
        super().__init__()
        # Attribute name kept as-is: code outside this block may look it up.
        self._hpu_kernel_fsdpa = fusedSDPA

    def forward(self, query, key, value, attn_mask, dropout_p, is_casual, scale):
        """Run the wrapped kernel and return whatever it returns.

        Args:
            query: Query states.
            key: Key states.
            value: Value states.
            attn_mask: Attention mask, or ``None``.
            dropout_p: Dropout probability forwarded to the kernel.
            is_casual: Causal-masking flag. The spelling (sic) is kept so that
                existing keyword callers continue to work.
            scale: Scaling value forwarded to the kernel, or ``None``.
                NOTE(review): presumably ``None`` selects the kernel's default
                scaling -- confirm against the kernel documentation.
        """
        kernel_args = (query, key, value, attn_mask, dropout_p, is_casual, scale)
        return self._hpu_kernel_fsdpa.apply(*kernel_args)
KVCache() self.matmul_qk = Matmul() self.matmul_av = Matmul() + self.fused_scaled_dot_product_attention = ModuleFusedSDPA(FusedSDPA) if FusedSDPA else None self.inp_seq_len = -1 self._init_rope() + self.norm_factor = 1.0 / math.sqrt(self.head_dim) def _init_rope(self): """ @@ -255,7 +273,9 @@ def forward( reuse_cache: Optional[bool] = False, cache_idx: Optional[int] = None, attn_softmax_bf16: Optional[bool] = False, - use_fused_rope: Optional[bool] = True, + use_flash_attention: Optional[bool] = False, + flash_attention_recompute: Optional[bool] = False, + flash_attention_causal_mask: Optional[bool] = False, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """ @@ -322,39 +342,63 @@ def forward( else: past_key_value = None - # repeat k/v heads if n_kv_heads < n_heads - query_states, key_states, value_states, attention_mask = gaudi_mistral_repeat_kv( - query_states, key_states, value_states, attention_mask, self.num_key_value_groups - ) - attn_weights = self.matmul_qk(query_states, key_states.transpose(-2, -1)) / math.sqrt(self.head_dim) + if use_flash_attention and FusedSDPA: + import habana_frameworks.torch.hpu as ht - if attn_weights.size() not in [ - (bsz, self.num_heads, q_len, kv_seq_len), - (bsz, self.num_key_value_heads, self.num_key_value_groups, q_len, kv_seq_len), - ]: - raise ValueError( - f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)} or" - f" {(bsz, self.num_key_value_heads, self.num_key_value_groups, q_len, kv_seq_len)}, but is" - f" {attn_weights.size()}" + if q_len == 1: + # next token + use_recompute = True if os.getenv("QUANT_CONFIG", "") else False + with ht.sdp_kernel(enable_recompute=use_recompute): + attn_output = self.fused_scaled_dot_product_attention( + query_states, key_states, value_states, attention_mask, 0.0, False, None + ) + else: + # first token + if flash_attention_causal_mask: + # causal masking on first token requires inputs to be of the same length + with 
ht.sdp_kernel(enable_recompute=flash_attention_recompute): + attn_output = self.fused_scaled_dot_product_attention( + query_states, key_states, value_states, None, 0.0, True, None + ) + else: + with ht.sdp_kernel(enable_recompute=flash_attention_recompute): + attn_output = self.fused_scaled_dot_product_attention( + query_states, key_states, value_states, attention_mask, 0.0, False, None + ) + else: + # repeat k/v heads if n_kv_heads < n_heads + query_states, key_states, value_states, attention_mask = gaudi_mistral_repeat_kv( + query_states, key_states, value_states, attention_mask, self.num_key_value_groups ) + attn_weights = self.matmul_qk(query_states, key_states.transpose(-2, -1)) * self.norm_factor - if attention_mask is not None: - if attention_mask.size() not in [(bsz, 1, q_len, kv_seq_len), (bsz, 1, 1, q_len, kv_seq_len)]: + if attn_weights.size() not in [ + (bsz, self.num_heads, q_len, kv_seq_len), + (bsz, self.num_key_value_heads, self.num_key_value_groups, q_len, kv_seq_len), + ]: raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)} or {(bsz, 1, 1, q_len, kv_seq_len)}," - f" but is {attention_mask.size()}" + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)} or" + f" {(bsz, self.num_key_value_heads, self.num_key_value_groups, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" ) - attn_weights = attn_weights + attention_mask + if attention_mask is not None: + if attention_mask.size() not in [(bsz, 1, q_len, kv_seq_len), (bsz, 1, 1, q_len, kv_seq_len)]: + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)} or {(bsz, 1, 1, q_len, kv_seq_len)}," + f" but is {attention_mask.size()}" + ) - if attn_softmax_bf16: - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=query_states.dtype) - else: - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = 
def update_sincos_cache(self, seq_len):
    """Forward a sin/cos cache update request to every layer of the model.

    Calls ``update_sincos_cache(seq_len)`` on each element of ``self.layers``
    in order; nothing is returned.

    Args:
        seq_len: Sequence length passed through unchanged to each layer.
    """
    for decoder_layer in self.layers:
        decoder_layer.update_sincos_cache(seq_len)
None, @@ -473,8 +525,10 @@ def forward( reuse_cache: Optional[bool] = False, cache_idx: Optional[int] = None, attn_softmax_bf16: Optional[bool] = False, - use_fused_rope: Optional[bool] = True, lazy_mode: Optional[bool] = True, + use_flash_attention: Optional[bool] = False, + flash_attention_recompute: Optional[bool] = False, + flash_attention_causal_mask: Optional[bool] = False, ) -> Union[Tuple, BaseModelOutputWithPast]: """ Copied from MistralModel.forward: https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/mistral/modeling_mistral.py @@ -577,7 +631,12 @@ def forward( output_attentions, use_cache, None, - use_fused_rope, + False, + cache_idx, + attn_softmax_bf16, + use_flash_attention, + flash_attention_recompute, + flash_attention_causal_mask, ) else: layer_outputs = decoder_layer( @@ -591,7 +650,9 @@ def forward( reuse_cache=reuse_cache, cache_idx=cache_idx, attn_softmax_bf16=attn_softmax_bf16, - use_fused_rope=use_fused_rope, + use_flash_attention=use_flash_attention, + flash_attention_recompute=flash_attention_recompute, + flash_attention_causal_mask=flash_attention_causal_mask, ) hidden_states = layer_outputs[0] @@ -652,8 +713,10 @@ def forward( trim_logits: Optional[bool] = False, cache_idx: Optional[int] = None, attn_softmax_bf16: Optional[bool] = False, - use_fused_rope: Optional[bool] = True, lazy_mode: Optional[bool] = True, + use_flash_attention: Optional[bool] = False, + flash_attention_recompute: Optional[bool] = False, + flash_attention_causal_mask: Optional[bool] = False, ) -> Union[Tuple, CausalLMOutputWithPast]: """ Inherits from MistralForCausalLM: https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/mistral/modeling_mistral.py @@ -686,8 +749,10 @@ def forward( reuse_cache=reuse_cache, cache_idx=cache_idx, attn_softmax_bf16=attn_softmax_bf16, - use_fused_rope=use_fused_rope, lazy_mode=lazy_mode, + use_flash_attention=use_flash_attention, + 
flash_attention_recompute=flash_attention_recompute, + flash_attention_causal_mask=flash_attention_causal_mask, ) hidden_states = outputs[0] _, seq_len, _ = hidden_states.shape @@ -799,6 +864,9 @@ def prepare_inputs_for_generation( "cache_idx": kwargs.get("cache_idx"), "attn_softmax_bf16": kwargs.get("attn_softmax_bf16"), "lazy_mode": kwargs.get("lazy_mode"), + "use_flash_attention": kwargs.get("use_flash_attention"), + "flash_attention_recompute": kwargs.get("flash_attention_recompute"), + "flash_attention_causal_mask": kwargs.get("flash_attention_causal_mask"), } ) return model_inputs