diff --git a/optimum/habana/transformers/trainer.py b/optimum/habana/transformers/trainer.py index f331ad7a00..b36a637b98 100644 --- a/optimum/habana/transformers/trainer.py +++ b/optimum/habana/transformers/trainer.py @@ -1861,12 +1861,8 @@ def evaluation_loop( if loss is not None: # Handle NaN loss if args.logging_nan_inf_filter and (torch.isnan(loss) or torch.isinf(loss)): - # If loss is NaN or Inf, use the average of previous losses or a small constant - if losses_host is not None and len(losses_host) > 0: - avg_loss = torch.mean(losses_host) - loss = avg_loss - else: - loss = torch.tensor(1e-8, device=loss.device) + # If loss is NaN or Inf, use a small constant + loss = torch.tensor(1e-8, device=loss.device) losses = self.gather_function((loss.repeat(batch_size))) all_losses.add(losses)