RoPE length extrapolation with interpolation #7005
Changes from all commits: b9287e3, 0a4574b, 6e79e92, 363831e, 95282d4, 21739f0, 9cf880a, f170ed2, 5daffb7, 9a6fab2, 6d5f75e, f1d7cbe, 61ac17f, bc6fe06, 895d8bb, b60be05
@@ -0,0 +1,193 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import tempfile

from omegaconf.omegaconf import OmegaConf, open_dict
from pytorch_lightning import Trainer
from pytorch_lightning.plugins.environments import TorchElasticEnvironment
from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector

from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
from nemo.collections.nlp.modules.common.megatron.megatron_init import fake_initialize_model_parallel
from nemo.collections.nlp.parts.nlp_overrides import (
    GradScaler,
    MegatronHalfPrecisionPlugin,
    NLPDDPStrategy,
    NLPSaveRestoreConnector,
    PipelineMixedPrecisionPlugin,
)
from nemo.core.config import hydra_runner
from nemo.utils import AppState, logging
from nemo.utils.exp_manager import exp_manager
from nemo.utils.model_utils import inject_model_parallel_rank

def _modify_config(gpt_cfg, cfg, add_cfg_to_tree=False):
    """
    This function modifies the original GPT pre-training config (gpt_cfg) with attributes from the finetuning config (cfg).
    The `add_cfg_to_tree` arg adds `cfg` to the top of the yaml tree, which is needed for all `hparams.yaml` files when passed as an arg to `load_from_checkpoint()`.
    """
    OmegaConf.set_struct(gpt_cfg, True)
    OmegaConf.resolve(cfg)
    with open_dict(gpt_cfg):
        gpt_cfg.megatron_amp_O2 = cfg.model.get('megatron_amp_O2', False)
        gpt_cfg.micro_batch_size = cfg.model.micro_batch_size
        gpt_cfg.global_batch_size = cfg.model.global_batch_size
        gpt_cfg.sequence_parallel = cfg.model.get("sequence_parallel", False)
        gpt_cfg.activations_checkpoint_granularity = cfg.model.get("activations_checkpoint_granularity", None)
        gpt_cfg.activations_checkpoint_num_layers = cfg.model.get("activations_checkpoint_num_layers", None)
        gpt_cfg.activations_checkpoint_method = cfg.model.get("activations_checkpoint_method", None)
        gpt_cfg.data = cfg.model.data
        gpt_cfg.optim = cfg.model.optim
        gpt_cfg.precision = cfg.trainer.precision
        gpt_cfg.restore_from_path = cfg.restore_from_path
        gpt_cfg.resume_from_checkpoint = cfg.model.resume_from_checkpoint
        gpt_cfg.gradient_as_bucket_view = cfg.model.gradient_as_bucket_view
        gpt_cfg.encoder_seq_length = cfg.model.encoder_seq_length
        gpt_cfg.max_position_embeddings = cfg.model.max_position_embeddings
        gpt_cfg.seq_len_interpolation_factor = cfg.model.seq_len_interpolation_factor
        gpt_cfg.use_flash_attention = cfg.model.use_flash_attention

        # This is needed when modifying a hparam file directly to load `.ckpt` files.
        # This is not needed to modify the cfg in `.nemo` files.
        if add_cfg_to_tree:
            OmegaConf.resolve(gpt_cfg)
            gpt_cfg.cfg = gpt_cfg

    return gpt_cfg


def load_from_nemo(cls, cfg, trainer, gpt_cfg, modify_confg_fn):
    gpt_cfg = modify_confg_fn(gpt_cfg, cfg, add_cfg_to_tree=False)
    save_restore_connector = NLPSaveRestoreConnector()
    if os.path.isdir(cfg.restore_from_path):
        save_restore_connector.model_extracted_dir = cfg.restore_from_path
    model = cls.restore_from(
        restore_path=cfg.restore_from_path,
        trainer=trainer,
        override_config_path=gpt_cfg,
        save_restore_connector=save_restore_connector,
    )
    return model


def load_from_checkpoint_dir(cls, cfg, trainer, modify_confg_fn):
    app_state = AppState()
    if cfg.model.tensor_model_parallel_size > 1 or cfg.model.pipeline_model_parallel_size > 1:
        app_state.model_parallel_size = cfg.model.tensor_model_parallel_size * cfg.model.pipeline_model_parallel_size
        app_state.tensor_model_parallel_size = cfg.model.tensor_model_parallel_size
        app_state.pipeline_model_parallel_size = cfg.model.pipeline_model_parallel_size
        (
            app_state.tensor_model_parallel_rank,
            app_state.pipeline_model_parallel_rank,
            app_state.model_parallel_size,
            app_state.data_parallel_size,
            app_state.pipeline_model_parallel_split_rank,
            app_state.virtual_pipeline_model_parallel_rank,
        ) = fake_initialize_model_parallel(
            world_size=app_state.model_parallel_size,
            rank=trainer.global_rank,
            tensor_model_parallel_size_=cfg.model.tensor_model_parallel_size,
            pipeline_model_parallel_size_=cfg.model.pipeline_model_parallel_size,
            pipeline_model_parallel_split_rank_=cfg.model.pipeline_model_parallel_split_rank,
        )
    checkpoint_path = inject_model_parallel_rank(
        os.path.join(cfg.model.pretrained_checkpoint.checkpoint_dir, cfg.model.pretrained_checkpoint.checkpoint_name)
    )
    hparams_file = OmegaConf.load(cfg.model.pretrained_checkpoint.hparams_file)
    gpt_cfg = modify_confg_fn(hparams_file.cfg, cfg, add_cfg_to_tree=True)
    with tempfile.NamedTemporaryFile(suffix='.yaml') as f:
        OmegaConf.save(config=gpt_cfg, f=f.name)
        model = cls.load_from_checkpoint(checkpoint_path=checkpoint_path, trainer=trainer, hparams_file=f.name,)
        return model


def validate_checkpoint_loading_args(cfg):
    if cfg.checkpoint_dir is None or not os.path.isdir(cfg.checkpoint_dir):
        raise ValueError(f'Checkpoint directory {cfg.checkpoint_dir} does not exist or is not a directory.')
    if cfg.checkpoint_name is None:
        raise ValueError(f'Checkpoint name {cfg.checkpoint_name} is not valid.')
    if cfg.hparams_file is None or not os.path.isfile(cfg.hparams_file):
        raise ValueError(f'Hparams file {cfg.hparams_file} does not exist or is not a file.')


@hydra_runner(config_path="conf", config_name="megatron_gpt_config")
def main(cfg) -> None:
    logging.info("\n\n************** Experiment configuration ***********")
    logging.info(f'\n{OmegaConf.to_yaml(cfg)}')

    megatron_amp_o2 = cfg.model.get('megatron_amp_O2', False)
    with_distributed_adam = cfg.model.optim.get('name', 'fused_adam') == 'distributed_fused_adam'
    plugins = []
    strategy = NLPDDPStrategy(
        no_ddp_communication_hook=True,
        gradient_as_bucket_view=cfg.model.gradient_as_bucket_view,
        find_unused_parameters=False,
    )
    if cfg.trainer.precision in [16, 'bf16']:
        scaler = None
        if cfg.trainer.precision == 16:
            scaler = GradScaler(
                init_scale=cfg.model.get('native_amp_init_scale', 2 ** 32),
                growth_interval=cfg.model.get('native_amp_growth_interval', 1000),
                hysteresis=cfg.model.get('hysteresis', 2),
            )
        if megatron_amp_o2 and not with_distributed_adam:
            plugins.append(MegatronHalfPrecisionPlugin(precision=cfg.trainer.precision, device='cuda', scaler=scaler))
        else:
            plugins.append(PipelineMixedPrecisionPlugin(precision=cfg.trainer.precision, device='cuda', scaler=scaler))

    if cfg.get('cluster_type', None) == 'BCP':
        plugins.append(TorchElasticEnvironment())

    trainer = Trainer(plugins=plugins, strategy=strategy, **cfg.trainer)

    exp_manager(trainer, cfg.exp_manager)

    # update resume from checkpoint found by exp_manager
    if cfg.model.resume_from_checkpoint is not None:
        resume_from_checkpoint = cfg.model.resume_from_checkpoint
    else:
        resume_from_checkpoint = trainer._checkpoint_connector.resume_from_checkpoint_fit_path
    logging.info(f'Resuming training from checkpoint: {resume_from_checkpoint}')

    trainer._checkpoint_connector = CheckpointConnector(trainer, resume_from_checkpoint=resume_from_checkpoint)

    if cfg.restore_from_path:
        save_restore_connector = NLPSaveRestoreConnector()
        if os.path.isdir(cfg.restore_from_path):
            save_restore_connector.model_extracted_dir = cfg.restore_from_path
        gpt_cfg = MegatronGPTModel.restore_from(
            restore_path=cfg.restore_from_path,
            trainer=trainer,
            return_config=True,
            save_restore_connector=save_restore_connector,
        )
        model = load_from_nemo(MegatronGPTModel, cfg, trainer, gpt_cfg, modify_confg_fn=_modify_config)
    elif cfg.model.get("pretrained_checkpoint", None) is not None:
        validate_checkpoint_loading_args(cfg.model.pretrained_checkpoint)
        model = load_from_checkpoint_dir(MegatronGPTModel, cfg, trainer, gpt_cfg, modify_confg_fn=_modify_config)
Check failure (Code scanning / CodeQL): Wrong number of arguments in a call. Call to function load_from_checkpoint_dir with too many arguments; should be no more than 4.
Check failure (Code scanning / CodeQL): Potentially uninitialized local variable. Local variable 'gpt_cfg' may be used before it is initialized.

    else:
        print(' > WARNING: No checkpoint provided. Starting from scratch.')
        # hydra interpolation does not work here as the interpolation key is lost when PTL saves hparams
        with open_dict(cfg):
            cfg.model.precision = cfg.trainer.precision
        model = MegatronGPTModel(cfg.model, trainer)
    trainer.fit(model)


if __name__ == '__main__':
    main()
@@ -60,6 +60,8 @@ model:
  activations_checkpoint_num_layers: null # not used with 'selective'
  answer_only_loss: False # not used right now
  gradient_as_bucket_view: False
  seq_len_interpolation_factor: null # if not None, seq_len_interpolation_factor will match the base model's value

Review comment: add some explanation about how the interpolation factor translates to longer sequences. (A sketch addressing this follows after the diff.)

  use_flash_attention: null # if not None, will match the base model's value

  hidden_dropout: 0.0
  attention_dropout: 0.0
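To address the review comment above, here is an illustrative sketch of linear position interpolation; it is not code from this PR, and all names and numbers below are made up for the example. The idea is that RoPE position indices are divided by seq_len_interpolation_factor before the rotary angles are computed, so a model pre-trained at sequence length L can be fine-tuned and evaluated at roughly factor * L tokens while every scaled position still falls inside the range seen during pre-training.

# Illustrative only: hypothetical numbers, not values from this PR.
pretrain_seq_len = 4096            # sequence length the base model was pre-trained with
seq_len_interpolation_factor = 4   # factor set in the config above

# Effective context window the fine-tuned model is expected to handle:
extended_seq_len = pretrain_seq_len * seq_len_interpolation_factor  # 16384

# Position indices for the extended context are squeezed back into the
# pre-training range before the rotary sin/cos angles are computed,
# e.g. position 16383 is rotated as if it were position ~4095.75.
positions = list(range(extended_seq_len))
interpolated = [p / seq_len_interpolation_factor for p in positions]
assert max(interpolated) < pretrain_seq_len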
@@ -123,6 +123,7 @@ def get_language_model(
    use_emha=False,
    ub_tp_comm_overlap=False,
    use_flash_attention=False,
    seq_len_interpolation_factor=None,
):
    """Build language model and return along with the key to save."""

@@ -200,6 +201,7 @@ def get_language_model(
        use_emha=use_emha,
        ub_tp_comm_overlap=ub_tp_comm_overlap,
        use_flash_attention=use_flash_attention,
        seq_len_interpolation_factor=seq_len_interpolation_factor,
    )
    # key used for checkpoints.
    language_model_key = 'language_model'

@@ -508,6 +510,7 @@ def __init__(
        use_emha=False,
        ub_tp_comm_overlap=False,
        use_flash_attention=False,
        seq_len_interpolation_factor=None,
    ):
        super(TransformerLanguageModel, self).__init__(share_token_embeddings=share_embeddings_and_output_weights)

@@ -559,7 +562,9 @@ def __init__(
            assert 0 < rotary_percentage <= 1
            if rotary_percentage < 1:
                rotary_dim = int(rotary_dim * rotary_percentage)
-           self.rotary_pos_emb = RotaryEmbedding(rotary_dim)
+           self.rotary_pos_emb = RotaryEmbedding(
+               rotary_dim, seq_len_interpolation_factor=seq_len_interpolation_factor
+           )

Review comment: maybe not in this PR, but we need to add seq_len_interpolation_factor for all the models that use RoPE.

        elif position_embedding_type == 'alibi':
            # TODO: If this is used for encoder-decoder model, implement proper logic and following
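For readers unfamiliar with how the factor is consumed, the toy module below shows the general pattern: scale the position indices before building the rotary frequencies. This is a simplified sketch, not NeMo's RotaryEmbedding implementation, and it omits everything except the interpolation step.

import torch


class ToyRotaryEmbedding(torch.nn.Module):
    """Minimal rotary embedding with optional linear position interpolation (illustrative only)."""

    def __init__(self, dim, base=10000, seq_len_interpolation_factor=None):
        super().__init__()
        self.seq_len_interpolation_factor = seq_len_interpolation_factor
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)

    def forward(self, max_seq_len, offset=0):
        seq = torch.arange(max_seq_len, device=self.inv_freq.device).float() + offset
        if self.seq_len_interpolation_factor is not None:
            # Squeeze the extended positions back into the pre-training range.
            seq = seq / self.seq_len_interpolation_factor
        # Outer product of positions and inverse frequencies -> rotary angles [seq_len, dim/2].
        freqs = torch.einsum('i,j->ij', seq, self.inv_freq)
        # Duplicate so each rotary pair shares the same angle -> [seq_len, dim].
        return torch.cat((freqs, freqs), dim=-1)

With seq_len_interpolation_factor=4, ToyRotaryEmbedding(128, seq_len_interpolation_factor=4)(8192) produces angles no larger than a factor-less module would produce at 2048 positions, which is the point of the interpolation.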
Review comment: these functions (_modify_config, load_from_nemo, load_from_checkpoint_dir, validate_checkpoint_loading_args) are the same as in the SFT code. Can we put them into a utility file?

Reply: They are almost the same, not 100% identical. The issue is that each one (SFT vs. continued training) modifies some common attributes like data and optim, but also a few different things.
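As a follow-up to this thread, one hypothetical way to factor out the shared logic while allowing per-script differences is sketched below. The helper name and the override lists are assumptions for illustration only, not code from this PR or from NeMo.

from omegaconf import OmegaConf, open_dict

# Attributes both SFT and continued training copy from the finetuning config (assumed list).
COMMON_MODEL_OVERRIDES = ('micro_batch_size', 'global_batch_size', 'data', 'optim')


def modify_base_config(base_cfg, cfg, extra_model_overrides=(), add_cfg_to_tree=False):
    """Copy selected cfg.model attributes onto the base (pre-training) config.

    Each script passes its own extra_model_overrides, e.g. continued training could pass
    ('seq_len_interpolation_factor', 'use_flash_attention').
    """
    OmegaConf.set_struct(base_cfg, True)
    OmegaConf.resolve(cfg)
    with open_dict(base_cfg):
        for key in (*COMMON_MODEL_OVERRIDES, *extra_model_overrides):
            setattr(base_cfg, key, getattr(cfg.model, key))
        if add_cfg_to_tree:
            OmegaConf.resolve(base_cfg)
            base_cfg.cfg = base_cfg
    return base_cfg

Attributes that come from outside cfg.model (for example cfg.trainer.precision or cfg.restore_from_path) would still need a small per-script wrapper, which matches the observation that the two versions are close but not identical.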