diff --git a/nemo/collections/nlp/models/nlp_model.py b/nemo/collections/nlp/models/nlp_model.py
index 032a7449c27e..d739efa88485 100644
--- a/nemo/collections/nlp/models/nlp_model.py
+++ b/nemo/collections/nlp/models/nlp_model.py
@@ -16,7 +16,7 @@
 import hashlib
 import json
 import os
-from typing import Any, Optional
+from typing import Any, Mapping, Optional
 
 from omegaconf import DictConfig, OmegaConf
 from pytorch_lightning import Trainer
@@ -385,3 +385,13 @@ def load_from_checkpoint(
         finally:
             cls._set_model_restore_state(is_being_restored=False)
         return checkpoint
+
+    def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
+        # starting with transformers v4.31.0, the position_ids buffer is persistent=False; drop the stale key from older checkpoints
+        if (
+            self.bert_model is not None
+            and "position_ids" not in self.bert_model.embeddings._modules
+            and "bert_model.embeddings.position_ids" in state_dict
+        ):
+            del state_dict["bert_model.embeddings.position_ids"]
+        return super(NLPModel, self).load_state_dict(state_dict, strict=strict)
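
For reviewers, a minimal standalone sketch of the behavior this override works around (plain PyTorch; the OldEmbeddings / NewEmbeddings names are hypothetical stand-ins for the embeddings module before and after transformers v4.31.0, not actual NeMo or transformers classes). Since v4.31.0 the position_ids buffer is registered with persistent=False, so it no longer appears in state_dict(), and checkpoints saved with older versions fail strict loading unless the stale key is removed first:

import torch
import torch.nn as nn

class OldEmbeddings(nn.Module):
    # mimics transformers < 4.31.0: position_ids is a persistent buffer,
    # so it gets saved into state_dict()
    def __init__(self, max_len=8):
        super().__init__()
        self.register_buffer("position_ids", torch.arange(max_len))

class NewEmbeddings(nn.Module):
    # mimics transformers >= 4.31.0: persistent=False keeps the buffer
    # out of state_dict(), so old checkpoints carry an unexpected key
    def __init__(self, max_len=8):
        super().__init__()
        self.register_buffer("position_ids", torch.arange(max_len), persistent=False)

old_ckpt = OldEmbeddings().state_dict()  # contains "position_ids"
new_model = NewEmbeddings()

try:
    new_model.load_state_dict(old_ckpt, strict=True)
except RuntimeError as err:
    print(err)  # Unexpected key(s) in state_dict: "position_ids"

# As in the override above: drop the stale key, then strict loading succeeds.
old_ckpt = dict(old_ckpt)  # copy into a plain mutable dict
old_ckpt.pop("position_ids", None)
new_model.load_state_dict(old_ckpt, strict=True)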