From a27f6d37b7422e8177e644d6b2a814f951ad146f Mon Sep 17 00:00:00 2001 From: Sanju C Sudhakaran Date: Thu, 22 Feb 2024 16:39:57 +0200 Subject: [PATCH] Fix graph breaks in torch compile mode Signed-off-by: Sanju C Sudhakaran --- .../transformers/models/llama/modeling_llama.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index a4e0130bab..dbb8b18f4e 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -26,15 +26,19 @@ try: from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingHelperV2 as FusedRoPE + + has_fused_rope = True except ImportError: + has_fused_rope = False print("Not using HPU fused kernel for apply_rotary_pos_emb") - FusedRoPE = None try: from habana_frameworks.torch.hpex.normalization import FusedRMSNorm as FusedRMSNorm + + has_fused_rms_norm = True except ImportError: + has_fused_rms_norm = False print("Not using HPU fused kernel for RMSNorm") - FusedRMSNorm = None try: from habana_frameworks.torch.hpex.kernels import FusedSDPA @@ -52,7 +56,7 @@ def gaudi_llama_rmsnorm_forward(self, hidden_states): The only differences are: - override RMSNorm with Habana fused RMSNorm """ - if hidden_states.device.type == "hpu" and FusedRMSNorm: + if hidden_states.device.type == "hpu" and has_fused_rms_norm: # mixed dtypes are not good for FusedRMSNorm, both inputs need to have same dtype if hidden_states.dtype != self.weight.dtype: orig_dtype = hidden_states.dtype @@ -952,7 +956,7 @@ def prepare_inputs_for_generation( def apply_customized_rope(q, k, cos, sin, position_ids, use_fused_rope=True): - if q.device.type == "hpu" and FusedRoPE and use_fused_rope: + if q.device.type == "hpu" and has_fused_rope and use_fused_rope: # TODO: remove `.clone()` when SynapseAI v1.15 is released return FusedRoPE.apply( q, cos.unsqueeze(0).unsqueeze(0).clone(), sin.unsqueeze(0).unsqueeze(0).clone(), position_ids