diff --git a/optimum/habana/sentence_transformers/st_gaudi_trainer.py b/optimum/habana/sentence_transformers/st_gaudi_trainer.py index b8f52b6e16..69b476ecf8 100644 --- a/optimum/habana/sentence_transformers/st_gaudi_trainer.py +++ b/optimum/habana/sentence_transformers/st_gaudi_trainer.py @@ -322,6 +322,8 @@ def compute_loss( and model != self.model # Only if the model is wrapped and hasattr(loss_fn, "model") # Only if the loss stores the model and loss_fn.model != model # Only if the wrapped model is not already stored + and hasattr(model, "module") + and loss_fn.model != model.module # wrapped model differs from orig model ): loss_fn = self.override_model_in_loss(loss_fn, model) loss = loss_fn(features, labels)