From 31b4552940ebad87fcc36e52d3ffb23fcf450e47 Mon Sep 17 00:00:00 2001 From: Yeonsil Yoon Date: Mon, 10 Mar 2025 16:22:16 -0700 Subject: [PATCH] Revert logits to float change for accuracy --- optimum/habana/transformers/models/mixtral/modeling_mixtral.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/optimum/habana/transformers/models/mixtral/modeling_mixtral.py b/optimum/habana/transformers/models/mixtral/modeling_mixtral.py index 5a5226ae0b..2c9e6ba2f1 100644 --- a/optimum/habana/transformers/models/mixtral/modeling_mixtral.py +++ b/optimum/habana/transformers/models/mixtral/modeling_mixtral.py @@ -790,7 +790,7 @@ def forward( hidden_states = outputs[0] # Only compute necessary logits, and do not upcast them to float if we are not computing the loss slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep - logits = self.lm_head(hidden_states[:, slice_indices, :]) + logits = self.lm_head(hidden_states[:, slice_indices, :]).float() loss = None if labels is not None: