diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index 86ce73f50a..a850541cca 100755 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -718,7 +718,6 @@ def generate( # determine whether flash attention needs to be used model_kwargs["use_flash_attention"] = generation_config.use_flash_attention model_kwargs["flash_attention_recompute"] = True if generation_config.flash_attention_recompute else False - model_kwargs["use_fused_rope"] = False if generation_config.use_fused_rope is False else True if not self.config.is_encoder_decoder: calculated_max_length = input_ids.shape[-1] diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index 2a07427bb7..778786e205 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -25,15 +25,19 @@ try: from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingHelperV2 as FusedRoPE + + has_fused_rope = True except ImportError: + has_fused_rope = False print("Not using HPU fused kernel for apply_rotary_pos_emb") - FusedRoPE = None try: from habana_frameworks.torch.hpex.normalization import FusedRMSNorm as FusedRMSNorm + + has_fused_rms_norm = True except ImportError: + has_fused_rms_norm = False print("Not using HPU fused kernel for RMSNorm") - FusedRMSNorm = None try: from habana_frameworks.torch.hpex.kernels import FusedSDPA @@ -67,7 +71,7 @@ def gaudi_llama_rmsnorm_forward(self, hidden_states): The only differences are: - override RMSNorm with Habana fused RMSNorm """ - if hidden_states.device.type == "hpu" and FusedRMSNorm: + if hidden_states.device.type == "hpu" and has_fused_rms_norm: # mixed dtypes are not good for FusedRMSNorm, both inputs need to have same dtype if hidden_states.dtype != self.weight.dtype: orig_dtype = hidden_states.dtype @@ -313,7 +317,6 @@ def pre_attn_forward( use_flash_attention: Optional[bool] = False, flash_attention_recompute: Optional[bool] = False, cache_idx: int = None, - use_fused_rope: Optional[bool] = True, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """ @@ -368,9 +371,7 @@ def pre_attn_forward( kv_seq_len = past_key_value[0].shape[-2] cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_customized_rope( - query_states, key_states, cos, sin, position_ids, use_fused_rope=use_fused_rope - ) + query_states, key_states = apply_customized_rope(query_states, key_states, cos, sin, position_ids) if past_key_value is not None or reuse_cache: # reuse k, v, self_attention @@ -498,7 +499,6 @@ def forward( use_flash_attention: Optional[bool] = False, flash_attention_recompute: Optional[bool] = False, cache_idx: int = None, - use_fused_rope: Optional[bool] = True, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ @@ -530,7 +530,6 @@ def forward( use_flash_attention=use_flash_attention, flash_attention_recompute=flash_attention_recompute, cache_idx=cache_idx, - use_fused_rope=use_fused_rope, **kwargs, ) @@ -563,7 +562,6 @@ def pre_attn( use_flash_attention: Optional[bool] = False, flash_attention_recompute: Optional[bool] = False, cache_idx: int = None, - use_fused_rope: Optional[bool] = True, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: hidden_states = self.input_layernorm(hidden_states) output_attn, attn_weights, present_key_value = self.self_attn.pre_attn_forward( @@ -580,7 +578,6 @@ def pre_attn( use_flash_attention, flash_attention_recompute, cache_idx=cache_idx, - use_fused_rope=use_fused_rope, ) return output_attn, attn_weights, present_key_value @@ -631,7 +628,6 @@ def forward( use_flash_attention: Optional[bool] = False, flash_attention_recompute: Optional[bool] = False, cache_idx: int = None, - use_fused_rope: Optional[bool] = True, ) -> Union[Tuple, BaseModelOutputWithPast]: """ Copied from LlamaModel.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py @@ -736,7 +732,6 @@ def forward( False, use_flash_attention, flash_attention_recompute, - use_fused_rope, ) else: layer_outputs = decoder_layer( @@ -753,7 +748,6 @@ def forward( use_flash_attention=use_flash_attention, flash_attention_recompute=flash_attention_recompute, cache_idx=cache_idx, - use_fused_rope=use_fused_rope, ) hidden_states = layer_outputs[0] @@ -826,13 +820,16 @@ def forward( use_flash_attention: Optional[bool] = False, flash_attention_recompute: Optional[bool] = False, cache_idx: int = None, - use_fused_rope: Optional[bool] = True, ) -> Union[Tuple, CausalLMOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if self.generation_config.use_fused_rope is False: + global has_fused_rope + has_fused_rope = False + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( input_ids=input_ids, @@ -851,7 +848,6 @@ def forward( use_flash_attention=use_flash_attention, flash_attention_recompute=flash_attention_recompute, cache_idx=cache_idx, - use_fused_rope=use_fused_rope, ) hidden_states = outputs[0] _, seq_len, _ = hidden_states.shape @@ -990,8 +986,8 @@ def prepare_inputs_for_generation( return model_inputs -def apply_customized_rope(q, k, cos, sin, position_ids, use_fused_rope=True): - if q.device.type == "hpu" and FusedRoPE and use_fused_rope: +def apply_customized_rope(q, k, cos, sin, position_ids): + if q.device.type == "hpu" and has_fused_rope: # TODO: remove `.clone()` when it is fixed in SynapseAI return FusedRoPE.apply( q, cos.unsqueeze(0).unsqueeze(0).clone(), sin.unsqueeze(0).unsqueeze(0).clone(), position_ids diff --git a/optimum/habana/transformers/trainer.py b/optimum/habana/transformers/trainer.py index 274ca2a646..ee1427227d 100644 --- a/optimum/habana/transformers/trainer.py +++ b/optimum/habana/transformers/trainer.py @@ -925,8 +925,6 @@ def hpu_deepspeed_checkpointing(function, *checkpoint_args): inputs["use_flash_attention"] = True if self.model.generation_config.flash_attention_recompute: inputs["flash_attention_recompute"] = True - if self.model.generation_config.use_fused_rope is False: - inputs["use_fused_rope"] = False # TODO: keep syncs for fast DDP? with self.accelerator.accumulate(model): @@ -1731,8 +1729,6 @@ def evaluation_loop( inputs["use_flash_attention"] = True if self.model.generation_config.flash_attention_recompute: inputs["flash_attention_recompute"] = True - if self.model.generation_config.use_fused_rope is False: - inputs["use_fused_rope"] = False # Prediction step loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)