diff --git a/optimum/habana/sentence_transformers/st_gaudi_trainer.py b/optimum/habana/sentence_transformers/st_gaudi_trainer.py index b8f52b6e16..6f852ac924 100644 --- a/optimum/habana/sentence_transformers/st_gaudi_trainer.py +++ b/optimum/habana/sentence_transformers/st_gaudi_trainer.py @@ -263,6 +263,8 @@ def override_model_in_loss(self, loss: torch.nn.Module, model: "SentenceTransfor from sentence_transformers import SentenceTransformer for name, child in loss.named_children(): + if _is_peft_model(child): + child = child.get_base_model() if name == "model" and isinstance(child, SentenceTransformer): loss.model = model elif isinstance(child, torch.nn.Module):