diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index 56f439e093..13cf2c7223 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -318,7 +318,7 @@ def gaudi_llama_rmsnorm_forward(self, hidden_states): The only differences are: - override RMSNorm with Habana fused RMSNorm """ - if hidden_states.device.type == "hpu" and FusedRMSNorm: + if not self.training and hidden_states.device.type == "hpu" and FusedRMSNorm: orig_dtype = hidden_states.dtype hidden_states = FusedRMSNorm.apply(hidden_states.float(), self.weight.float(), self.variance_epsilon) return hidden_states.to(orig_dtype)