20
20
21
21
from nemo .collections .nlp .models .language_modeling .megatron_finetune_model import MegatronT5FinetuneModel
22
22
from nemo .collections .nlp .models .language_modeling .megatron_glue_model import MegatronT5GLUEModel
23
- from nemo .collections .nlp .parts .nlp_overrides import GradScaler , MegatronHalfPrecisionPlugin , NLPDDPStrategy
23
+ from nemo .collections .nlp .models .language_modeling .megatron_t0_model import MegatronT0Model
24
+ from nemo .collections .nlp .parts .nlp_overrides import (
25
+ GradScaler ,
26
+ MegatronHalfPrecisionPlugin ,
27
+ NLPDDPStrategy ,
28
+ NLPSaveRestoreConnector ,
29
+ )
24
30
from nemo .core .config import hydra_runner
25
31
from nemo .utils import logging
26
32
from nemo .utils .exp_manager import StatelessTimer , exp_manager
@@ -71,15 +77,14 @@ def main(cfg) -> None:
71
77
# NOTE: Only data can be overridden here since the file being restored should already correspond to a GLUE/XNLI finetuned model.
72
78
OmegaConf .set_struct (t5_cfg , True )
73
79
with open_dict (t5_cfg ):
74
- t5_cfg .masked_softmax_fusion = False
75
80
t5_cfg .precision = cfg .trainer .precision
76
81
# Overwrite data configs
77
- if cfg .model .data .validation_ds .src_file_name is not None :
82
+ if cfg .model .data .validation_ds .get ( ' src_file_name' , None ) is not None :
78
83
logging .info (
79
84
'Found validation_ds.src_file_name in the config file. Overriding the finetuned model config file with the values from the new config file.'
80
85
)
81
86
t5_cfg .data .validation_ds .src_file_name = cfg .model .data .validation_ds .src_file_name
82
- if cfg .model .data .validation_ds .tgt_file_name is not None :
87
+ if cfg .model .data .validation_ds .get ( ' tgt_file_name' , None ) is not None :
83
88
logging .info (
84
89
'Found validation_ds.tgt_file_name in the config file. Overriding the finetuned model config file with the values from the new config file.'
85
90
)
@@ -88,9 +93,28 @@ def main(cfg) -> None:
88
93
t5_cfg .data .validation_ds .micro_batch_size = cfg .model .data .validation_ds .micro_batch_size
89
94
t5_cfg .data .validation_ds .global_batch_size = cfg .model .data .validation_ds .global_batch_size
90
95
91
- model = MegatronT5FinetuneModel .restore_from (
92
- restore_path = cfg .model .restore_from_path , trainer = trainer , override_config_path = t5_cfg
93
- )
96
+ if hasattr (cfg .model .data .validation_ds , 'task_name' ):
97
+ model = MegatronT5GLUEModel .restore_from (
98
+ restore_path = cfg .model .restore_from_path ,
99
+ trainer = trainer ,
100
+ override_config_path = t5_cfg ,
101
+ save_restore_connector = NLPSaveRestoreConnector (),
102
+ )
103
+ elif hasattr (cfg .model .data .validation_ds , 'file_names' ):
104
+ model = MegatronT0Model .restore_from (
105
+ restore_path = cfg .model .restore_from_path ,
106
+ trainer = trainer ,
107
+ override_config_path = t5_cfg ,
108
+ save_restore_connector = NLPSaveRestoreConnector (),
109
+ )
110
+ else :
111
+ model = MegatronT5FinetuneModel .restore_from (
112
+ restore_path = cfg .model .restore_from_path ,
113
+ trainer = trainer ,
114
+ override_config_path = t5_cfg ,
115
+ save_restore_connector = NLPSaveRestoreConnector (),
116
+ )
117
+
94
118
model .freeze ()
95
119
trainer .validate (model )
96
120
if hasattr (cfg .model .data , 'test_ds' ):
0 commit comments