diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index be5d60e4b7..5a0ae211f5 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -48,7 +48,7 @@ def gaudi_llama_rmsnorm_forward(self, hidden_states): The only differences are: - override RMSNorm with Habana fused RMSNorm """ - if not self.training and hidden_states.device.type == "hpu" and FusedRMSNorm: + if hidden_states.device.type == "hpu" and FusedRMSNorm: orig_dtype = hidden_states.dtype hidden_states = FusedRMSNorm.apply(hidden_states.float(), self.weight.float(), self.variance_epsilon) return hidden_states.to(orig_dtype)