diff --git a/sentence_transformers/losses/DenoisingAutoEncoderLoss.py b/sentence_transformers/losses/DenoisingAutoEncoderLoss.py index 8cdb607df..b2c94b95a 100644 --- a/sentence_transformers/losses/DenoisingAutoEncoderLoss.py +++ b/sentence_transformers/losses/DenoisingAutoEncoderLoss.py @@ -113,9 +113,21 @@ def __init__(self, model: SentenceTransformer, decoder_name_or_path: str = None, "Since the encoder vocabulary has been changed and --tie_encoder_decoder=True, now the new vocabulary has also been used for the decoder." ) decoder_base_model_prefix = self.decoder.base_model_prefix - PreTrainedModel._tie_encoder_decoder_weights( - model[0].auto_model, self.decoder._modules[decoder_base_model_prefix], self.decoder.base_model_prefix - ) + try: + # Compatibility with transformers <4.40.0 + PreTrainedModel._tie_encoder_decoder_weights( + model[0].auto_model, + self.decoder._modules[decoder_base_model_prefix], + self.decoder.base_model_prefix, + ) + except TypeError: + # Compatibility with transformers >=4.40.0 + PreTrainedModel._tie_encoder_decoder_weights( + model[0].auto_model, + self.decoder._modules[decoder_base_model_prefix], + self.decoder.base_model_prefix, + encoder_name_or_path, + ) def retokenize(self, sentence_features): input_ids = sentence_features["input_ids"]