# Post-patch form of `apply_customized_rope` from
# optimum/habana/transformers/models/gpt_neox/modeling_gpt_neox.py
def apply_customized_rope(q, k, cos, sin, position_ids):
    """Apply rotary position embeddings to ``q`` and ``k``.

    When ``q`` lives on an ``hpu`` device and ``FusedRoPE`` is truthy, both
    tensors go through ``FusedRoPE.apply``; otherwise the call is forwarded
    unchanged to ``apply_rotary_pos_emb``.

    Args:
        q: Query tensor; its device type selects the code path.
        k: Key tensor.
        cos: Cosine table; two leading singleton dimensions are added before
            it is handed to the fused op.
        sin: Sine table; treated exactly like ``cos``.
        position_ids: Forwarded as-is to whichever implementation runs.

    Returns:
        A ``(rotated_q, rotated_k)`` tuple on the fused path, or whatever
        ``apply_rotary_pos_emb`` returns on the fallback path.
    """
    if q.device.type == "hpu" and FusedRoPE:
        # Add the two leading dims once instead of repeating the
        # unsqueeze chain for every FusedRoPE call.
        cos_expanded = cos.unsqueeze(0).unsqueeze(0)
        sin_expanded = sin.unsqueeze(0).unsqueeze(0)

        def _fused(states):
            # The tables are cast only when this tensor is bfloat16; every
            # other dtype receives them untouched.
            # NOTE(review): presumably the fused op needs cos/sin to match a
            # bfloat16 input -- confirm against the FusedRoPE kernel.
            if states.dtype == torch.bfloat16:
                return FusedRoPE.apply(
                    states,
                    cos_expanded.to(torch.bfloat16),
                    sin_expanded.to(torch.bfloat16),
                    position_ids,
                )
            return FusedRoPE.apply(states, cos_expanded, sin_expanded, position_ids)

        # q and k are checked independently, so mixed dtypes are handled
        # per tensor.
        return _fused(q), _fused(k)

    return apply_rotary_pos_emb(q, k, cos, sin, position_ids)