diff --git a/nemo/collections/nlp/models/nlp_model.py b/nemo/collections/nlp/models/nlp_model.py index 9d3575944e97..e6510878a8c4 100644 --- a/nemo/collections/nlp/models/nlp_model.py +++ b/nemo/collections/nlp/models/nlp_model.py @@ -253,9 +253,10 @@ def load_from_checkpoint( if 'cfg' in kwargs: model = cls._load_model_state(checkpoint, strict=strict, **kwargs) else: - model = cls._load_model_state( - checkpoint, strict=strict, cfg=checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY], **kwargs + cfg = checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY].get( + 'cfg', checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY] ) + model = cls._load_model_state(checkpoint, strict=strict, cfg=cfg, **kwargs) checkpoint = model finally: diff --git a/tutorials/nlp/Relation_Extraction-BioMegatron.ipynb b/tutorials/nlp/Relation_Extraction-BioMegatron.ipynb index 1f74c308922f..e1bfc234e9d0 100644 --- a/tutorials/nlp/Relation_Extraction-BioMegatron.ipynb +++ b/tutorials/nlp/Relation_Extraction-BioMegatron.ipynb @@ -363,6 +363,7 @@ "outputs": [], "source": [ "# download the model's configuration file \n", + "MODEL_CONFIG = 'text_classification_config.yaml'\n", "config_dir = WORK_DIR + '/configs/'\n", "os.makedirs(config_dir, exist_ok=True)\n", "if not os.path.exists(config_dir + MODEL_CONFIG):\n",