BioNeMo version: 1.10

Description:
When setting save_top_k > 1, training will produce k + 1 .ckpt files: the top-k checkpoints plus the last checkpoint before training stops. There are a few things I observed:
It seems that only the last checkpoint file is automatically converted to .nemo. In other words, the .nemo file might not be the checkpoint with the best metrics.
When exp_manager.checkpoint_callback_params.always_save_nemo=True is set, this should in theory create a NeMo checkpoint for each Megatron checkpoint that is saved. However, this does not seem to be the actual behavior: only one .nemo file was saved. There is a similar issue reported for NeMo.
This defeats the purpose of save_top_k, and the automatically generated .nemo file is not necessarily the checkpoint we intend to keep.
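For reference, these are the checkpointing knobs under discussion. The snippet below is just a minimal sketch with illustrative values, printed via OmegaConf so the combination of save_top_k and always_save_nemo is explicit; the exact layout of your pretrain YAML may differ.

from omegaconf import OmegaConf

# Minimal sketch of the relevant exp_manager keys (values are illustrative).
# save_top_k > 1 keeps the k best .ckpt files plus the last one;
# always_save_nemo is expected to write a .nemo alongside each of them.
cfg = OmegaConf.create(
    {
        "exp_manager": {
            "checkpoint_callback_params": {
                "monitor": "val_loss",
                "save_top_k": 3,
                "always_save_nemo": True,
            }
        }
    }
)
print(OmegaConf.to_yaml(cfg))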
Proposal:
I propose that we add a script to the documentation for converting a .ckpt file to .nemo. The method in our docs here does not work for MolMIM.
I have confirmed that the following method works for single-node trained .ckpt files with model/pipeline parallel size = 1:
Create a megatron_ckpt_to_nemo.py:
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Note: I modified the script and confirmed that it works for MolMIM trained on a single node
# (model parallel size = 1, pipeline parallel size = 1).
# TODO: I have not tested it for multi-node training.

r"""Conversion script to convert PTL checkpoints into a nemo checkpoint.

Example to run this conversion script:
    python -m torch.distributed.launch --nproc_per_node=<tensor_model_parallel_size * pipeline_model_parallel_size> \
        megatron_ckpt_to_nemo.py \
        --checkpoint_folder <path_to_PTL_checkpoints_folder> \
        --checkpoint_name <checkpoint_name> \
        --nemo_file_path <path_to_output_nemo_file> \
        --tensor_model_parallel_size <tensor_model_parallel_size> \
        --pipeline_model_parallel_size <pipeline_model_parallel_size> \
        --gpus_per_node <gpus_per_node> \
        --model_type <model_type>
"""

import dis
import os
from argparse import ArgumentParser

import torch
from genericpath import isdir
from megatron.core import parallel_state
from omegaconf import OmegaConf, open_dict
from pytorch_lightning.plugins.environments import TorchElasticEnvironment
from pytorch_lightning.trainer.trainer import Trainer

# Note: I deleted these imports because they are not used. They also generated errors because some
# dependencies are not installed in bionemo.
# from nemo.collections.nlp.models.language_modeling.megatron_bart_model import MegatronBARTModel
# from nemo.collections.nlp.models.language_modeling.megatron_bert_model import MegatronBertModel
# from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
# from nemo.collections.nlp.models.language_modeling.megatron_gpt_sft_model import MegatronGPTSFTModel
# from nemo.collections.nlp.models.language_modeling.megatron_retrieval_model import MegatronRetrievalModel
# from nemo.collections.nlp.models.language_modeling.megatron_t5_model import MegatronT5Model
# from nemo.collections.nlp.models.machine_translation.megatron_nmt_model import MegatronNMTModel
from bionemo.model.molecule.molmim import MolMIMModel
from nemo.collections.nlp.parts.nlp_overrides import (
    GradScaler,
    NLPDDPStrategy,
    NLPSaveRestoreConnector,
    PipelineMixedPrecisionPlugin,
)
from nemo.utils import AppState, logging
from nemo.utils.distributed import initialize_distributed
from nemo.utils.model_utils import inject_model_parallel_rank


def get_args():
    parser = ArgumentParser()
    parser.add_argument(
        "--checkpoint_folder",
        type=str,
        default=None,
        required=True,
        help="Path to PTL checkpoints saved during training. Ex: /raid/nemo_experiments/megatron_gpt/checkpoints",
    )
    parser.add_argument(
        "--checkpoint_name",
        type=str,
        default=None,
        required=True,
        help="Name of checkpoint to be used. Ex: megatron_gpt--val_loss=6.34-step=649-last.ckpt",
    )
    parser.add_argument(
        "--hparams_file",
        type=str,
        default=None,
        required=False,
        help="Path config for restoring. It's created during training and may need to be modified during restore if restore environment is different than training. Ex: /raid/nemo_experiments/megatron_gpt/hparams.yaml",
    )
    parser.add_argument("--nemo_file_path", type=str, default=None, required=True, help="Path to output .nemo file.")
    parser.add_argument(
        "--no_pack_nemo_file",
        action="store_true",
        help="If passed, output will be written under nemo_file_path as a directory instead of packed as a tarred .nemo file.",
    )
    parser.add_argument("--gpus_per_node", type=int, required=True, default=None)
    parser.add_argument("--tensor_model_parallel_size", type=int, required=True, default=None)
    parser.add_argument("--pipeline_model_parallel_size", type=int, required=True, default=None)
    parser.add_argument(
        "--pipeline_model_parallel_split_rank",
        type=int,
        required=False,
        default=None,
        help="If pipeline parallel size > 1, this is the rank at which the encoder ends and the decoder begins.",
    )
    parser.add_argument(
        "--model_type",
        type=str,
        required=True,
        default="gpt",
        choices=["gpt", "sft", "t5", "bert", "nmt", "bart", "retro", "molmim"],
    )
    parser.add_argument("--local-rank", type=int, required=False, default=os.getenv('LOCAL_RANK', -1))
    parser.add_argument("--bcp", action="store_true", help="Whether on BCP platform")
    parser.add_argument(
        "--precision",
        type=str,
        required=False,
        default='16-mixed',
        choices=['32-true', '16-mixed', 'bf16-mixed', '32'],  # note: I added '32'
        help="Precision value for the trainer that matches with precision of the ckpt",
    )
    args = parser.parse_args()
    return args


def convert(local_rank, rank, world_size, args):
    app_state = AppState()
    app_state.data_parallel_rank = 0
    num_nodes = world_size // args.gpus_per_node
    plugins = []
    strategy = "auto"
    if args.bcp:
        plugins.append(TorchElasticEnvironment())
    if args.model_type == 'gpt':
        strategy = NLPDDPStrategy()

    cfg = {
        'trainer': {
            'devices': args.gpus_per_node,
            'num_nodes': num_nodes,
            'accelerator': 'gpu',
            'precision': args.precision,
        },
        'model': {'native_amp_init_scale': 2**32, 'native_amp_growth_interval': 1000, 'hysteresis': 2},
    }
    cfg = OmegaConf.create(cfg)

    scaler = None
    # If FP16, create a GradScaler as the build_model_parallel_config of MegatronBaseModel expects it
    if cfg.trainer.precision == '16-mixed':
        scaler = GradScaler(
            init_scale=cfg.model.get('native_amp_init_scale', 2**32),
            growth_interval=cfg.model.get('native_amp_growth_interval', 1000),
            hysteresis=cfg.model.get('hysteresis', 2),
        )
    # Note: I disabled the pipeline mixed precision plugin because it was generating errors.
    # plugins.append(PipelineMixedPrecisionPlugin(precision=cfg.trainer.precision, device='cuda', scaler=scaler))
    # Set precision None after precision plugins are created as PTL >= 2.1 does not allow both
    # precision plugins and precision to exist.
    # Note: since I disabled PipelineMixedPrecisionPlugin, I don't need to set precision to None.
    # cfg.trainer.precision = None
    trainer = Trainer(plugins=plugins, strategy=strategy, **cfg.trainer)

    app_state.pipeline_model_parallel_size = args.pipeline_model_parallel_size
    app_state.tensor_model_parallel_size = args.tensor_model_parallel_size

    # Auto set split rank for T5, BART, NMT if split rank is None.
    if args.pipeline_model_parallel_size > 1 and args.model_type in ['t5', 'bart', 'nmt', 'molmim']:
        if args.pipeline_model_parallel_split_rank is not None:
            app_state.pipeline_model_parallel_split_rank = args.pipeline_model_parallel_split_rank
        else:
            if args.pipeline_model_parallel_size % 2 != 0:
                raise ValueError(
                    f"Pipeline model parallel size {args.pipeline_model_parallel_size} must be even if split rank is not specified."
                )
            else:
                # If split rank is not set, then we set it to be pipeline_model_parallel_size // 2,
                # because in most cases we have the same number of enc/dec layers.
                app_state.pipeline_model_parallel_split_rank = args.pipeline_model_parallel_size // 2
    else:
        app_state.pipeline_model_parallel_split_rank = None

    app_state.model_parallel_size = app_state.tensor_model_parallel_size * app_state.pipeline_model_parallel_size

    parallel_state.initialize_model_parallel(
        tensor_model_parallel_size=app_state.tensor_model_parallel_size,
        pipeline_model_parallel_size=app_state.pipeline_model_parallel_size,
        pipeline_model_parallel_split_rank=app_state.pipeline_model_parallel_split_rank,
    )

    app_state.pipeline_model_parallel_rank = parallel_state.get_pipeline_model_parallel_rank()
    app_state.tensor_model_parallel_rank = parallel_state.get_tensor_model_parallel_rank()

    # check for distributed checkpoint
    dist_ckpt_dir = os.path.join(args.checkpoint_folder, args.checkpoint_name)
    if os.path.isdir(dist_ckpt_dir):
        checkpoint_path = dist_ckpt_dir
    else:
        # legacy checkpoint needs model parallel injection
        checkpoint_path = inject_model_parallel_rank(os.path.join(args.checkpoint_folder, args.checkpoint_name))

    logging.info(
        f'rank: {rank}, local_rank: {local_rank}, is loading checkpoint: {checkpoint_path} for tp_rank: {app_state.tensor_model_parallel_rank} and pp_rank: {app_state.pipeline_model_parallel_rank}'
    )

    # Note: I disabled these since we just need the bionemo model.
    # if args.model_type == 'gpt':
    #     model = MegatronGPTModel.load_from_checkpoint(checkpoint_path, hparams_file=args.hparams_file, trainer=trainer)
    # elif args.model_type == 'sft':
    #     model = MegatronGPTSFTModel.load_from_checkpoint(
    #         checkpoint_path, hparams_file=args.hparams_file, trainer=trainer
    #     )
    #     # we force the target for the loaded model to have the correct target
    #     # because the hparams.yaml sometimes contains MegatronGPTModel as the target.
    #     with open_dict(model.cfg):
    #         model.cfg.target = f"{MegatronGPTSFTModel.__module__}.{MegatronGPTSFTModel.__name__}"
    # elif args.model_type == 'bert':
    #     model = MegatronBertModel.load_from_checkpoint(
    #         checkpoint_path, hparams_file=args.hparams_file, trainer=trainer
    #     )
    # elif args.model_type == 't5':
    #     model = MegatronT5Model.load_from_checkpoint(checkpoint_path, hparams_file=args.hparams_file, trainer=trainer)
    # elif args.model_type == 'bart':
    #     model = MegatronBARTModel.load_from_checkpoint(
    #         checkpoint_path, hparams_file=args.hparams_file, trainer=trainer
    #     )
    # elif args.model_type == 'nmt':
    #     model = MegatronNMTModel.load_from_checkpoint(checkpoint_path, hparams_file=args.hparams_file, trainer=trainer)
    # elif args.model_type == 'retro':
    #     model = MegatronRetrievalModel.load_from_checkpoint(
    #         checkpoint_path, hparams_file=args.hparams_file, trainer=trainer
    #     )

    # Note: I added MolMIM.
    if args.model_type == 'molmim':
        model = MolMIMModel.load_from_checkpoint(checkpoint_path, hparams_file=args.hparams_file, trainer=trainer)
        logging.info("loaded MolMIM")
    else:
        # Only the bionemo MolMIM model is supported in this modified script.
        raise ValueError(f"Unsupported model_type for this script: {args.model_type}")

    model._save_restore_connector = NLPSaveRestoreConnector()

    save_file_path = args.nemo_file_path
    if args.no_pack_nemo_file:
        # With --no_pack_nemo_file, nemo_file_path is expected to be a directory.
        # Adding a dummy model filename here conforms with SaveRestoreConnector's convention.
        model._save_restore_connector.pack_nemo_file = False
        save_file_path = os.path.join(save_file_path, 'model.nemo')

    if torch.distributed.is_initialized():
        torch.distributed.barrier()

    model.save_to(save_file_path)
    logging.info(f'NeMo model saved to: {args.nemo_file_path}')


if __name__ == '__main__':
    args = get_args()
    local_rank, rank, world_size = initialize_distributed(args)
    convert(local_rank, rank, world_size, args)
Create a megatron_ckpt_to_nemo.sh:
#!/bin/bash
set -e

# check your pretrain YAML file for these parameters
tensor_model_parallel_size=1
pipeline_model_parallel_size=1
# nproc_per_node = tensor_model_parallel_size * pipeline_model_parallel_size
nproc_per_node=$((tensor_model_parallel_size * pipeline_model_parallel_size))
gpus_per_node=8
model_type="molmim"
# retrieve from trainer.precision or model.precision. Must be a string!
precision='32'

# result base directory
result_base_dir="path/to/your/result/directory"

# name of the folder containing the .ckpt file that you want to convert
checkpoint_folder="${result_base_dir}/checkpoints"

# name of the ckpt file that you want to convert. Just the file name, not the path.
checkpoint_name="MolMIM-small--val_molecular_accuracy=1.00-val_loss=0.50-step=300000-consumed_samples=2457600000.0.ckpt"

# Remove the .ckpt extension from the file name. We will use the same name for the output .nemo file.
output_file_name="${checkpoint_name%.ckpt}.nemo"

# output folder
output_folder="${result_base_dir}/checkpoints_converted"
# Create the directory if it doesn't exist
mkdir -p "$output_folder"

# Define the path where you want to save the output .nemo file
path_to_output_nemo_file="${output_folder}/${output_file_name}"

# Path to the hparams.yaml file. This is the file in your training results folder, generated automatically by the training script.
# It is NOT the input YAML file that you use for training.
# It is NOT the YAML file produced by unpacking the nemo tar file.
hparams_file="${result_base_dir}/hparams.yaml"

python -m torch.distributed.launch --nproc_per_node=${nproc_per_node} megatron_ckpt_to_nemo.py \
    --checkpoint_folder ${checkpoint_folder} \
    --checkpoint_name ${checkpoint_name} \
    --nemo_file_path ${path_to_output_nemo_file} \
    --model_type ${model_type} \
    --hparams_file ${hparams_file} \
    --tensor_model_parallel_size ${tensor_model_parallel_size} \
    --pipeline_model_parallel_size ${pipeline_model_parallel_size} \
    --gpus_per_node ${gpus_per_node} \
    --precision ${precision}

echo "Conversion completed!"
To use, run
bash megatron_ckpt_to_nemo.sh
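After conversion, it is worth sanity-checking that the resulting .nemo file actually restores. The following is a minimal sketch, assuming the same bionemo container used for training and a placeholder path to the converted file; it relies on the standard NeMo restore_from API (adjust trainer settings to your environment).

# Sanity check: restore the converted .nemo file (sketch; the path is a placeholder).
from pytorch_lightning.trainer.trainer import Trainer

from bionemo.model.molecule.molmim import MolMIMModel
from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy, NLPSaveRestoreConnector

trainer = Trainer(devices=1, accelerator="gpu", strategy=NLPDDPStrategy())
model = MolMIMModel.restore_from(
    restore_path="path/to/your/result/directory/checkpoints_converted/your_checkpoint.nemo",
    trainer=trainer,
    save_restore_connector=NLPSaveRestoreConnector(),
)
print(model.cfg)  # the cleaned-up model config packed into the .nemo file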
Note:
The hparams.yaml file is very important. It is automatically produced after a training run.
You cannot simply untar the .nemo file and replace its ckpt with another ckpt: there are parameter mismatches, as well as extra training state (optimizer steps, learning rate, etc.) that is removed during a standard ckpt -> nemo conversion.
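To see concretely why a raw swap fails, you can list what the Lightning .ckpt actually carries; the conversion keeps only the model weights and config. A small sketch (the key names below are the standard PyTorch Lightning ones, the exact set can vary by version, and the path is a placeholder):

import torch

# Load the raw Lightning checkpoint on CPU and list its top-level keys.
# Besides 'state_dict' (the model weights), it typically carries training
# state such as 'optimizer_states', 'lr_schedulers', 'epoch' and 'global_step',
# none of which belong inside a packed .nemo file.
ckpt = torch.load("path/to/your/checkpoint.ckpt", map_location="cpu")
print(list(ckpt.keys()))
print(f"number of weight tensors: {len(ckpt['state_dict'])}")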