diff --git a/optimum/habana/transformers/models/mistral/modeling_mistral.py b/optimum/habana/transformers/models/mistral/modeling_mistral.py index 3758e422a5..c654b3a04b 100644 --- a/optimum/habana/transformers/models/mistral/modeling_mistral.py +++ b/optimum/habana/transformers/models/mistral/modeling_mistral.py @@ -46,6 +46,14 @@ ) import habana_frameworks.torch.core as htcore +try: + from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingHelperV2 as FusedRoPE + + has_fused_rope = True +except ImportError: + has_fused_rope = False + print("Not using HPU fused kernel for apply_rotary_pos_emb") + try: from habana_frameworks.torch.hpex.normalization import FusedRMSNorm as FusedRMSNorm except ImportError: @@ -166,6 +174,7 @@ def forward( reuse_cache: Optional[bool] = False, cache_idx: Optional[int] = None, attn_softmax_bf16: Optional[bool] = False, + use_fused_rope: Optional[bool] = True, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """ @@ -207,7 +216,9 @@ def forward( else: kv_seq_len += kv_shape cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + query_states, key_states = apply_customized_rope( + query_states, key_states, cos, sin, position_ids, use_fused_rope=use_fused_rope + ) if use_cache: # reuse k, v, self_attention @@ -315,6 +326,7 @@ def forward( reuse_cache: Optional[bool] = False, cache_idx: Optional[int] = None, attn_softmax_bf16: Optional[bool] = False, + use_fused_rope: Optional[bool] = True, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ @@ -343,6 +355,7 @@ def forward( reuse_cache=reuse_cache, cache_idx=cache_idx, attn_softmax_bf16=attn_softmax_bf16, + use_fused_rope=use_fused_rope, ) hidden_states = residual + hidden_states @@ -386,6 +399,7 @@ def forward( reuse_cache: Optional[bool] = False, cache_idx: Optional[int] = None, attn_softmax_bf16: Optional[bool] = False, + use_fused_rope: Optional[bool] = True, lazy_mode: Optional[bool] = True, ) -> Union[Tuple, BaseModelOutputWithPast]: """ @@ -492,6 +506,7 @@ def forward( output_attentions, use_cache, None, + use_fused_rope, ) else: layer_outputs = decoder_layer( @@ -505,6 +520,7 @@ def forward( reuse_cache=reuse_cache, cache_idx=cache_idx, attn_softmax_bf16=attn_softmax_bf16, + use_fused_rope=use_fused_rope, ) hidden_states = layer_outputs[0] @@ -565,6 +581,7 @@ def forward( trim_logits: Optional[bool] = False, cache_idx: Optional[int] = None, attn_softmax_bf16: Optional[bool] = False, + use_fused_rope: Optional[bool] = True, lazy_mode: Optional[bool] = True, ) -> Union[Tuple, CausalLMOutputWithPast]: """ @@ -594,6 +611,7 @@ def forward( reuse_cache=reuse_cache, cache_idx=cache_idx, attn_softmax_bf16=attn_softmax_bf16, + use_fused_rope=use_fused_rope, lazy_mode=lazy_mode, ) hidden_states = outputs[0] @@ -709,3 +727,20 @@ def prepare_inputs_for_generation( } ) return model_inputs + +def apply_customized_rope(q, k, cos, sin, position_ids, use_fused_rope=True): + if q.device.type == "hpu" and has_fused_rope and use_fused_rope: + # TODO: remove `.clone()` when SynapseAI v1.15 is released + if k.dtype==torch.bfloat16: + return FusedRoPE.apply( + q, cos.unsqueeze(0).unsqueeze(0).clone(), sin.unsqueeze(0).unsqueeze(0).clone(), position_ids + ), FusedRoPE.apply( + k, cos.unsqueeze(0).unsqueeze(0).clone().to(torch.bfloat16), sin.unsqueeze(0).unsqueeze(0).clone().to(torch.bfloat16), position_ids + ) + return FusedRoPE.apply( + q, cos.unsqueeze(0).unsqueeze(0).clone(), sin.unsqueeze(0).unsqueeze(0).clone(), position_ids + ), FusedRoPE.apply( + k, cos.unsqueeze(0).unsqueeze(0).clone(), sin.unsqueeze(0).unsqueeze(0).clone(), position_ids + ) + else: + return apply_rotary_pos_emb(q, k, cos, sin, position_ids)