Commit a038976

MaximumEntropy authored and ericharper committed
Fix providing glue in seq2seq eval (#4843)
* Fix providing glue in seq2seq eval

Signed-off-by: MaximumEntropy <[email protected]>

* Fix

Signed-off-by: MaximumEntropy <[email protected]>

* Style

Signed-off-by: MaximumEntropy <[email protected]>

Signed-off-by: MaximumEntropy <[email protected]>
1 parent 6219e56 commit a038976
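
At a glance, the commit does three things: it stops forcing t5_cfg.masked_softmax_fusion = False, it reads the validation-data overrides with .get(...) so absent keys no longer raise, and it restores the checkpoint with the model class matching the eval config (GLUE/XNLI, T0, or generic fine-tuned T5), passing an explicit NLPSaveRestoreConnector. The .get() change matters because the config is an OmegaConf DictConfig and the script enables struct mode, where plain attribute access to an absent key raises instead of returning None. A minimal sketch of that behavior, independent of NeMo (the config contents below are made up for illustration):

from omegaconf import OmegaConf

# Made-up eval config: defines task_name but no src_file_name,
# mirroring the GLUE case this commit addresses.
cfg = OmegaConf.create({'validation_ds': {'task_name': 'mnli'}})
OmegaConf.set_struct(cfg, True)  # the eval script does the same to t5_cfg

# In struct mode, `cfg.validation_ds.src_file_name` would raise
# (omegaconf's ConfigAttributeError), which appears to be the failure
# the commit title points at for GLUE configs.
# .get() with a default tolerates the absent key, as in the patched check:
if cfg.validation_ds.get('src_file_name', None) is not None:
    print('overriding src_file_name')
else:
    print('no src_file_name to override')  # this branch runs

# The same struct-mode behavior makes the new hasattr() dispatch reliable:
print(hasattr(cfg.validation_ds, 'task_name'))   # True
print(hasattr(cfg.validation_ds, 'file_names'))  # False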

File tree

1 file changed: +31 −7 lines changed


examples/nlp/language_modeling/megatron_t5_seq2seq_eval.py

@@ -20,7 +20,13 @@
 
 from nemo.collections.nlp.models.language_modeling.megatron_finetune_model import MegatronT5FinetuneModel
 from nemo.collections.nlp.models.language_modeling.megatron_glue_model import MegatronT5GLUEModel
-from nemo.collections.nlp.parts.nlp_overrides import GradScaler, MegatronHalfPrecisionPlugin, NLPDDPStrategy
+from nemo.collections.nlp.models.language_modeling.megatron_t0_model import MegatronT0Model
+from nemo.collections.nlp.parts.nlp_overrides import (
+    GradScaler,
+    MegatronHalfPrecisionPlugin,
+    NLPDDPStrategy,
+    NLPSaveRestoreConnector,
+)
 from nemo.core.config import hydra_runner
 from nemo.utils import logging
 from nemo.utils.exp_manager import StatelessTimer, exp_manager
@@ -71,15 +77,14 @@ def main(cfg) -> None:
     # NOTE: Only data can be overriden here since this the file being restored here should already correspond to a GLUE/XNLI finetuned model.
     OmegaConf.set_struct(t5_cfg, True)
     with open_dict(t5_cfg):
-        t5_cfg.masked_softmax_fusion = False
         t5_cfg.precision = cfg.trainer.precision
         # Overwrite data configs
-        if cfg.model.data.validation_ds.src_file_name is not None:
+        if cfg.model.data.validation_ds.get('src_file_name', None) is not None:
             logging.info(
                 'Found validation_ds.src_file_name in the config file. Overriding the finetuned model config file with the values from the new config file.'
             )
             t5_cfg.data.validation_ds.src_file_name = cfg.model.data.validation_ds.src_file_name
-        if cfg.model.data.validation_ds.tgt_file_name is not None:
+        if cfg.model.data.validation_ds.get('tgt_file_name', None) is not None:
             logging.info(
                 'Found validation_ds.tgt_file_name in the config file. Overriding the finetuned model config file with the values from the new config file.'
             )
@@ -88,9 +93,28 @@ def main(cfg) -> None:
         t5_cfg.data.validation_ds.micro_batch_size = cfg.model.data.validation_ds.micro_batch_size
         t5_cfg.data.validation_ds.global_batch_size = cfg.model.data.validation_ds.global_batch_size
 
-    model = MegatronT5FinetuneModel.restore_from(
-        restore_path=cfg.model.restore_from_path, trainer=trainer, override_config_path=t5_cfg
-    )
+    if hasattr(cfg.model.data.validation_ds, 'task_name'):
+        model = MegatronT5GLUEModel.restore_from(
+            restore_path=cfg.model.restore_from_path,
+            trainer=trainer,
+            override_config_path=t5_cfg,
+            save_restore_connector=NLPSaveRestoreConnector(),
+        )
+    elif hasattr(cfg.model.data.validation_ds, 'file_names'):
+        model = MegatronT0Model.restore_from(
+            restore_path=cfg.model.restore_from_path,
+            trainer=trainer,
+            override_config_path=t5_cfg,
+            save_restore_connector=NLPSaveRestoreConnector(),
+        )
+    else:
+        model = MegatronT5FinetuneModel.restore_from(
+            restore_path=cfg.model.restore_from_path,
+            trainer=trainer,
+            override_config_path=t5_cfg,
+            save_restore_connector=NLPSaveRestoreConnector(),
+        )
+
     model.freeze()
     trainer.validate(model)
     if hasattr(cfg.model.data, 'test_ds'):
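
A side note on the new restore logic: the three restore_from() calls differ only in the class they are bound to; the keyword arguments are identical. A behaviorally equivalent formulation (a sketch, not the committed code; it assumes the cfg, trainer, and t5_cfg variables already set up earlier in the script) selects the class first and restores once:

# Sketch only; every name below comes from the patched file's own imports.
from nemo.collections.nlp.models.language_modeling.megatron_finetune_model import MegatronT5FinetuneModel
from nemo.collections.nlp.models.language_modeling.megatron_glue_model import MegatronT5GLUEModel
from nemo.collections.nlp.models.language_modeling.megatron_t0_model import MegatronT0Model
from nemo.collections.nlp.parts.nlp_overrides import NLPSaveRestoreConnector

if hasattr(cfg.model.data.validation_ds, 'task_name'):
    model_cls = MegatronT5GLUEModel      # GLUE/XNLI fine-tuned checkpoint
elif hasattr(cfg.model.data.validation_ds, 'file_names'):
    model_cls = MegatronT0Model          # T0-style multi-task checkpoint
else:
    model_cls = MegatronT5FinetuneModel  # generic fine-tuned T5 checkpoint

model = model_cls.restore_from(
    restore_path=cfg.model.restore_from_path,
    trainer=trainer,
    override_config_path=t5_cfg,
    save_restore_connector=NLPSaveRestoreConnector(),
)

Whether to fold the branches this way is a style call; the committed version keeps the three calls spelled out.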
