def apply_customized_rope(q, k, cos, sin, position_ids, use_fused_rope=True):
    """Apply rotary position embeddings to the query and key tensors.

    The fused kernel (``FusedRoPE.apply``) is used only when all of the
    following hold: ``use_fused_rope`` is true, ``q`` is on an ``"hpu"``
    device, and ``FusedRoPE`` is truthy. In every other case the call is
    forwarded unchanged to ``apply_rotary_pos_emb``.

    Args:
        q: Query tensor; its device selects the code path and its dtype
            decides whether ``cos``/``sin`` are cast to bfloat16 for it.
        k: Key tensor; its dtype is checked independently of ``q``'s.
        cos: Cosine table. Two leading singleton dimensions are prepended
            before it is handed to the fused kernel.
        sin: Sine table, treated the same way as ``cos``.
        position_ids: Forwarded as-is to whichever implementation runs.
        use_fused_rope: Set to ``False`` to force the
            ``apply_rotary_pos_emb`` path even when the fused kernel is
            available.

    Returns:
        On the fused path, the tuple ``(rope_q, rope_k)``. Otherwise,
        whatever ``apply_rotary_pos_emb`` returns.
    """
    # Honour the caller's opt-out first: without this check the
    # `use_fused_rope` argument would be accepted but silently ignored.
    if not (use_fused_rope and q.device.type == "hpu" and FusedRoPE):
        return apply_rotary_pos_emb(q, k, cos, sin, position_ids)

    # Build the expanded tables once; they are shared by q and k.
    cos_expanded = cos.unsqueeze(0).unsqueeze(0)
    sin_expanded = sin.unsqueeze(0).unsqueeze(0)

    def _rotate(states):
        # For bfloat16 inputs the tables are cast to bfloat16 as well;
        # q and k are handled separately because their dtypes may differ.
        if states.dtype == torch.bfloat16:
            return FusedRoPE.apply(
                states,
                cos_expanded.to(torch.bfloat16),
                sin_expanded.to(torch.bfloat16),
                position_ids,
            )
        return FusedRoPE.apply(states, cos_expanded, sin_expanded, position_ids)

    return _rotate(q), _rotate(k)