Skip to content

Commit

Permalink
fix pos id - hf update (#7075)
Browse files Browse the repository at this point in the history
* fix pos id - hf update

Signed-off-by: Evelina <[email protected]>

* add missing import

Signed-off-by: Evelina <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Evelina <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
ekmb and pre-commit-ci[bot] committed Jul 19, 2023
1 parent 8f3957f commit c90625e
Showing 1 changed file with 11 additions and 1 deletion.
12 changes: 11 additions & 1 deletion nemo/collections/nlp/models/nlp_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import hashlib
import json
import os
from typing import Any, Optional
from typing import Any, Mapping, Optional

from omegaconf import DictConfig, OmegaConf
from pytorch_lightning import Trainer
Expand Down Expand Up @@ -385,3 +385,13 @@ def load_from_checkpoint(
finally:
cls._set_model_restore_state(is_being_restored=False)
return checkpoint

def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
    """Load ``state_dict`` into the model, tolerating a stale BERT ``position_ids`` entry.

    Starting with transformers v4.31.0 the ``position_ids`` buffer of the BERT
    embeddings is registered with ``persistent=False``, so it is no longer part
    of the module's state dict.  Checkpoints saved with an older transformers
    release still carry ``bert_model.embeddings.position_ids``; with
    ``strict=True`` that entry would be reported as an unexpected key.  The
    entry is therefore dropped when (and only when) the current embeddings
    module does not expect it.

    Args:
        state_dict: Mapping of parameter/buffer names to values.  It is never
            modified; a shallow copy is made if an entry has to be dropped.
        strict: Forwarded unchanged to the parent ``load_state_dict``.

    Returns:
        Whatever the parent ``load_state_dict`` returns (for ``torch.nn.Module``
        this is the named tuple of ``missing_keys`` and ``unexpected_keys``).
    """
    position_ids_key = "bert_model.embeddings.position_ids"

    # Not every model necessarily defines ``bert_model`` (or an ``embeddings``
    # submodule on it); in that case there is nothing to clean up.
    bert_model = getattr(self, "bert_model", None)
    embeddings = getattr(bert_model, "embeddings", None)

    if (
        embeddings is not None
        and position_ids_key in state_dict
        # ``position_ids`` is a buffer, not a submodule, so the question to ask
        # is whether the embeddings' own state dict still contains it
        # (persistent buffer -> keep the entry, non-persistent -> drop it).
        and "position_ids" not in embeddings.state_dict()
    ):
        from collections import OrderedDict

        # Shallow-copy so the caller's mapping is left untouched, and carry
        # over the ``_metadata`` attribute that torch attaches to state dicts
        # for versioned loading.
        metadata = getattr(state_dict, "_metadata", None)
        state_dict = OrderedDict(state_dict)
        if metadata is not None:
            state_dict._metadata = metadata
        del state_dict[position_ids_key]

    # Propagate the parent's result instead of discarding it.
    return super(NLPModel, self).load_state_dict(state_dict, strict=strict)

0 comments on commit c90625e

Please sign in to comment.