diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index 778786e205..6178bb1338 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -33,11 +33,9 @@ try: from habana_frameworks.torch.hpex.normalization import FusedRMSNorm as FusedRMSNorm - - has_fused_rms_norm = True except ImportError: - has_fused_rms_norm = False print("Not using HPU fused kernel for RMSNorm") + FusedRMSNorm = None try: from habana_frameworks.torch.hpex.kernels import FusedSDPA @@ -71,7 +69,7 @@ def gaudi_llama_rmsnorm_forward(self, hidden_states): The only differences are: - override RMSNorm with Habana fused RMSNorm """ - if hidden_states.device.type == "hpu" and has_fused_rms_norm: + if hidden_states.device.type == "hpu" and FusedRMSNorm: # mixed dtypes are not good for FusedRMSNorm, both inputs need to have same dtype if hidden_states.dtype != self.weight.dtype: orig_dtype = hidden_states.dtype