diff --git a/examples/language-modeling/run_lora_clm.py b/examples/language-modeling/run_lora_clm.py index c7ba0ebeae..5eaea49794 100644 --- a/examples/language-modeling/run_lora_clm.py +++ b/examples/language-modeling/run_lora_clm.py @@ -156,6 +156,14 @@ class ModelArguments: ) }, ) + use_fused_rope: bool = field( + default=True, + metadata={ + "help": ( + "Whether to use Habana fused-rope for fine-tuning. The current support is limited to Llama only.", + ) + }, + ) load_meta_device: bool = field( default=False, metadata={ @@ -540,6 +548,8 @@ def main(): if model_args.use_flash_attention: model.generation_config.use_flash_attention = True model.generation_config.flash_attention_recompute = model_args.flash_attention_recompute + if model_args.use_fused_rope is False: + model.generation_config.use_fused_rope = False if hasattr(model.generation_config, "pad_token_id") and model.generation_config.pad_token_id is not None: tokenizer.pad_token_id = model.generation_config.pad_token_id diff --git a/optimum/habana/transformers/generation/configuration_utils.py b/optimum/habana/transformers/generation/configuration_utils.py index de537dbe4f..e75e48a7c7 100644 --- a/optimum/habana/transformers/generation/configuration_utils.py +++ b/optimum/habana/transformers/generation/configuration_utils.py @@ -51,3 +51,4 @@ def __init__(self, **kwargs): self.kv_cache_fp8 = kwargs.get("kv_cache_fp8", None) self.use_flash_attention = kwargs.get("use_flash_attention", None) self.flash_attention_recompute = kwargs.get("flash_attention_recompute", None) + self.use_fused_rope = kwargs.get("use_fused_rope", None) diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index 42da95552e..c5c23b8f78 100755 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -683,6 +683,7 @@ def generate( # determine whether flash attention needs to be used model_kwargs["use_flash_attention"] = generation_config.use_flash_attention model_kwargs["flash_attention_recompute"] = True if generation_config.flash_attention_recompute else False + model_kwargs["use_fused_rope"] = False if generation_config.use_fused_rope is False else True if not self.config.is_encoder_decoder: calculated_max_length = input_ids.shape[-1] diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index b6411e6b54..f811bc38e5 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -208,6 +208,7 @@ def pre_attn_forward( use_flash_attention: Optional[bool] = False, flash_attention_recompute: Optional[bool] = False, cache_idx: int = None, + use_fused_rope: Optional[bool] = True, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """ @@ -274,7 +275,9 @@ def pre_attn_forward( kv_seq_len = past_key_value[0].shape[-2] cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_customized_rope(query_states, key_states, cos, sin, position_ids) + query_states, key_states = apply_customized_rope( + query_states, key_states, cos, sin, position_ids, use_fused_rope=use_fused_rope + ) if past_key_value is not None or reuse_cache: # reuse k, v, self_attention @@ -456,6 +459,7 @@ def forward( use_flash_attention: Optional[bool] = False, flash_attention_recompute: Optional[bool] = False, cache_idx: int = None, + use_fused_rope: Optional[bool] = True, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ @@ -486,6 +490,7 @@ def forward( use_flash_attention=use_flash_attention, flash_attention_recompute=flash_attention_recompute, cache_idx=cache_idx, + use_fused_rope=use_fused_rope, **kwargs, ) self.self_attn.attention_all_reduce(output_pre_attn) @@ -516,6 +521,7 @@ def pre_attn( use_flash_attention: Optional[bool] = False, flash_attention_recompute: Optional[bool] = False, cache_idx: int = None, + use_fused_rope: Optional[bool] = True, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: hidden_states = self.input_layernorm(hidden_states) output_attn, attn_weights, present_key_value = self.self_attn.pre_attn_forward( @@ -531,6 +537,7 @@ def pre_attn( use_flash_attention, flash_attention_recompute, cache_idx=cache_idx, + use_fused_rope=use_fused_rope, ) return output_attn, attn_weights, present_key_value @@ -580,6 +587,7 @@ def forward( use_flash_attention: Optional[bool] = False, flash_attention_recompute: Optional[bool] = False, cache_idx: int = None, + use_fused_rope: Optional[bool] = True, ) -> Union[Tuple, BaseModelOutputWithPast]: """ Copied from LlamaModel.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py @@ -682,6 +690,7 @@ def forward( False, use_flash_attention, flash_attention_recompute, + use_fused_rope, ) else: layer_outputs = decoder_layer( @@ -697,6 +706,7 @@ def forward( use_flash_attention=use_flash_attention, flash_attention_recompute=flash_attention_recompute, cache_idx=cache_idx, + use_fused_rope=use_fused_rope, ) hidden_states = layer_outputs[0] @@ -771,6 +781,7 @@ def forward( use_flash_attention: Optional[bool] = False, flash_attention_recompute: Optional[bool] = False, cache_idx: int = None, + use_fused_rope: Optional[bool] = True, ) -> Union[Tuple, CausalLMOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -794,6 +805,7 @@ def forward( use_flash_attention=use_flash_attention, flash_attention_recompute=flash_attention_recompute, cache_idx=cache_idx, + use_fused_rope=use_fused_rope, ) hidden_states = outputs[0] _, seq_len, _ = hidden_states.shape @@ -911,8 +923,8 @@ def prepare_inputs_for_generation( return model_inputs -def apply_customized_rope(q, k, cos, sin, position_ids): - if q.device.type == "hpu" and FusedRoPE: +def apply_customized_rope(q, k, cos, sin, position_ids, use_fused_rope=True): + if q.device.type == "hpu" and FusedRoPE and use_fused_rope: # TODO: remove `.clone()` when it is fixed in SynapseAI return FusedRoPE.apply( q, cos.unsqueeze(0).unsqueeze(0).clone(), sin.unsqueeze(0).unsqueeze(0).clone(), position_ids diff --git a/optimum/habana/transformers/trainer.py b/optimum/habana/transformers/trainer.py index 3f81a10b25..9dd391a5e3 100644 --- a/optimum/habana/transformers/trainer.py +++ b/optimum/habana/transformers/trainer.py @@ -909,6 +909,8 @@ def hpu_deepspeed_checkpointing(function, *checkpoint_args): inputs["use_flash_attention"] = True if self.model.generation_config.flash_attention_recompute: inputs["flash_attention_recompute"] = True + if self.model.generation_config.use_fused_rope is False: + inputs["use_fused_rope"] = False # TODO: keep syncs for fast DDP? with self.accelerator.accumulate(model): @@ -1684,6 +1686,8 @@ def evaluation_loop( inputs["use_flash_attention"] = True if self.model.generation_config.flash_attention_recompute: inputs["flash_attention_recompute"] = True + if self.model.generation_config.use_fused_rope is False: + inputs["use_fused_rope"] = False # Prediction step loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)