From dc5ea5f61cc0d41fcc8a4227692f198b2a0d5cfc Mon Sep 17 00:00:00 2001 From: Yi Dong Date: Wed, 19 Jan 2022 14:21:16 -0800 Subject: [PATCH 1/2] fix checkpoint loading and model config file Signed-off-by: Yi Dong --- nemo/collections/nlp/models/nlp_model.py | 3 ++- tutorials/nlp/Relation_Extraction-BioMegatron.ipynb | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/nemo/collections/nlp/models/nlp_model.py b/nemo/collections/nlp/models/nlp_model.py index 9d3575944e97..1a2c7766ba25 100644 --- a/nemo/collections/nlp/models/nlp_model.py +++ b/nemo/collections/nlp/models/nlp_model.py @@ -253,8 +253,9 @@ def load_from_checkpoint( if 'cfg' in kwargs: model = cls._load_model_state(checkpoint, strict=strict, **kwargs) else: + cfg = checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY].get('cfg', checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY]) model = cls._load_model_state( - checkpoint, strict=strict, cfg=checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY], **kwargs + checkpoint, strict=strict, cfg=cfg, **kwargs ) checkpoint = model diff --git a/tutorials/nlp/Relation_Extraction-BioMegatron.ipynb b/tutorials/nlp/Relation_Extraction-BioMegatron.ipynb index 1f74c308922f..e1bfc234e9d0 100644 --- a/tutorials/nlp/Relation_Extraction-BioMegatron.ipynb +++ b/tutorials/nlp/Relation_Extraction-BioMegatron.ipynb @@ -363,6 +363,7 @@ "outputs": [], "source": [ "# download the model's configuration file \n", + "MODEL_CONFIG = 'text_classification_config.yaml'\n", "config_dir = WORK_DIR + '/configs/'\n", "os.makedirs(config_dir, exist_ok=True)\n", "if not os.path.exists(config_dir + MODEL_CONFIG):\n", From 0897722d5de133d9ff364734b74ba080dbfa18c4 Mon Sep 17 00:00:00 2001 From: Yi Dong Date: Wed, 19 Jan 2022 14:39:41 -0800 Subject: [PATCH 2/2] fix style Signed-off-by: Yi Dong --- nemo/collections/nlp/models/nlp_model.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/nemo/collections/nlp/models/nlp_model.py b/nemo/collections/nlp/models/nlp_model.py index 1a2c7766ba25..e6510878a8c4 100644 --- a/nemo/collections/nlp/models/nlp_model.py +++ b/nemo/collections/nlp/models/nlp_model.py @@ -253,10 +253,10 @@ def load_from_checkpoint( if 'cfg' in kwargs: model = cls._load_model_state(checkpoint, strict=strict, **kwargs) else: - cfg = checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY].get('cfg', checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY]) - model = cls._load_model_state( - checkpoint, strict=strict, cfg=cfg, **kwargs + cfg = checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY].get( + 'cfg', checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY] ) + model = cls._load_model_state(checkpoint, strict=strict, cfg=cfg, **kwargs) checkpoint = model finally: