Skip to content

Commit

Permalink
RoPE length extrapolation with interpolation (#7005)
Browse files Browse the repository at this point in the history
* Push changes

Signed-off-by: MaximumEntropy <[email protected]>

* Fixes

Signed-off-by: MaximumEntropy <[email protected]>

* add continue training script

Signed-off-by: MaximumEntropy <[email protected]>

* [WIP] nonlinear interp

Signed-off-by: MaximumEntropy <[email protected]>

* Fix

Signed-off-by: MaximumEntropy <[email protected]>

* override encoder_seq_len

Signed-off-by: MaximumEntropy <[email protected]>

* Remove nonlinear

Signed-off-by: MaximumEntropy <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* sft with pi (#7006)

* sft with pi

Signed-off-by: Evelina <[email protected]>

* update values only if not None"

Signed-off-by: Evelina <[email protected]>

---------

Signed-off-by: Evelina <[email protected]>

* Address comments

Signed-off-by: MaximumEntropy <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add info

Signed-off-by: MaximumEntropy <[email protected]>

* Empty

Signed-off-by: MaximumEntropy <[email protected]>

---------

Signed-off-by: MaximumEntropy <[email protected]>
Signed-off-by: Evelina <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Evelina <[email protected]>
Signed-off-by: Gerald Shen <[email protected]>
  • Loading branch information
3 people authored and gshennvm committed Jul 12, 2023
1 parent 58170f1 commit 57f345d
Show file tree
Hide file tree
Showing 10 changed files with 249 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ model:
share_embeddings_and_output_weights: True # Share embedding and output layer weights.
overlap_p2p_comm: False # Overlap p2p communication with computes. This argument is valid only when `virtual_pipeline_model_parallel_size` is larger than 1
batch_p2p_comm: True # Batch consecutive inter-peer send/recv operations. This argument is valid only when `virtual_pipeline_model_parallel_size` is larger than 1
seq_len_interpolation_factor: null # RoPE Interpolation factor for sequence length. This is used to build long-context models with RoPE ex: https://arxiv.org/abs/2306.15595.

tokenizer:
library: 'megatron'
Expand Down
193 changes: 193 additions & 0 deletions examples/nlp/language_modeling/megatron_gpt_continue_training.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import tempfile

from omegaconf.omegaconf import OmegaConf, open_dict
from pytorch_lightning import Trainer
from pytorch_lightning.plugins.environments import TorchElasticEnvironment
from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector

from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
from nemo.collections.nlp.modules.common.megatron.megatron_init import fake_initialize_model_parallel
from nemo.collections.nlp.parts.nlp_overrides import (
GradScaler,
MegatronHalfPrecisionPlugin,
NLPDDPStrategy,
NLPSaveRestoreConnector,
PipelineMixedPrecisionPlugin,
)
from nemo.core.config import hydra_runner
from nemo.utils import AppState, logging
from nemo.utils.exp_manager import exp_manager
from nemo.utils.model_utils import inject_model_parallel_rank


def _modify_config(gpt_cfg, cfg, add_cfg_to_tree=False):
"""
This function modifies the original gpt pre-training config (t5_cfg) with attributes from the finetuning config (cfg).
The `add_cfg_to_tree` arg adds `cfg` to the top of the yaml tree which is needed for all `hparams.yaml` files when passed as an arg to `load_from_checkpoint()`.
"""
OmegaConf.set_struct(gpt_cfg, True)
OmegaConf.resolve(cfg)
with open_dict(gpt_cfg):
gpt_cfg.megatron_amp_O2 = cfg.model.get('megatron_amp_O2', False)
gpt_cfg.micro_batch_size = cfg.model.micro_batch_size
gpt_cfg.global_batch_size = cfg.model.global_batch_size
gpt_cfg.sequence_parallel = cfg.model.get("sequence_parallel", False)
gpt_cfg.activations_checkpoint_granularity = cfg.model.get("activations_checkpoint_granularity", None)
gpt_cfg.activations_checkpoint_num_layers = cfg.model.get("activations_checkpoint_num_layers", None)
gpt_cfg.activations_checkpoint_method = cfg.model.get("activations_checkpoint_method", None)
gpt_cfg.data = cfg.model.data
gpt_cfg.optim = cfg.model.optim
gpt_cfg.precision = cfg.trainer.precision
gpt_cfg.restore_from_path = cfg.restore_from_path
gpt_cfg.resume_from_checkpoint = cfg.model.resume_from_checkpoint
gpt_cfg.gradient_as_bucket_view = cfg.model.gradient_as_bucket_view
gpt_cfg.encoder_seq_length = cfg.model.encoder_seq_length
gpt_cfg.max_position_embeddings = cfg.model.max_position_embeddings
gpt_cfg.seq_len_interpolation_factor = cfg.model.seq_len_interpolation_factor
gpt_cfg.use_flash_attention = cfg.model.use_flash_attention

# This is needed when modifying a hparam file directly to load `.ckpt` files.
# This is not needed to modify the cfg in `.nemo` files.
if add_cfg_to_tree:
OmegaConf.resolve(gpt_cfg)
gpt_cfg.cfg = gpt_cfg

return gpt_cfg


def load_from_nemo(cls, cfg, trainer, gpt_cfg, modify_confg_fn):
gpt_cfg = modify_confg_fn(gpt_cfg, cfg, add_cfg_to_tree=False)
save_restore_connector = NLPSaveRestoreConnector()
if os.path.isdir(cfg.restore_from_path):
save_restore_connector.model_extracted_dir = cfg.restore_from_path
model = cls.restore_from(
restore_path=cfg.restore_from_path,
trainer=trainer,
override_config_path=gpt_cfg,
save_restore_connector=save_restore_connector,
)
return model


def load_from_checkpoint_dir(cls, cfg, trainer, modify_confg_fn):
app_state = AppState()
if cfg.model.tensor_model_parallel_size > 1 or cfg.model.pipeline_model_parallel_size > 1:
app_state.model_parallel_size = cfg.model.tensor_model_parallel_size * cfg.model.pipeline_model_parallel_size
app_state.tensor_model_parallel_size = cfg.model.tensor_model_parallel_size
app_state.pipeline_model_parallel_size = cfg.model.pipeline_model_parallel_size
(
app_state.tensor_model_parallel_rank,
app_state.pipeline_model_parallel_rank,
app_state.model_parallel_size,
app_state.data_parallel_size,
app_state.pipeline_model_parallel_split_rank,
app_state.virtual_pipeline_model_parallel_rank,
) = fake_initialize_model_parallel(
world_size=app_state.model_parallel_size,
rank=trainer.global_rank,
tensor_model_parallel_size_=cfg.model.tensor_model_parallel_size,
pipeline_model_parallel_size_=cfg.model.pipeline_model_parallel_size,
pipeline_model_parallel_split_rank_=cfg.model.pipeline_model_parallel_split_rank,
)
checkpoint_path = inject_model_parallel_rank(
os.path.join(cfg.model.pretrained_checkpoint.checkpoint_dir, cfg.model.pretrained_checkpoint.checkpoint_name)
)
hparams_file = OmegaConf.load(cfg.model.pretrained_checkpoint.hparams_file)
gpt_cfg = modify_confg_fn(hparams_file.cfg, cfg, add_cfg_to_tree=True)
with tempfile.NamedTemporaryFile(suffix='.yaml') as f:
OmegaConf.save(config=gpt_cfg, f=f.name)
model = cls.load_from_checkpoint(checkpoint_path=checkpoint_path, trainer=trainer, hparams_file=f.name,)
return model


def validate_checkpoint_loading_args(cfg):
if cfg.checkpoint_dir is None or not os.path.isdir(cfg.checkpoint_dir):
raise ValueError(f'Checkpoint directory {cfg.checkpoint_dir} does not exist or is not a directory.')
if cfg.checkpoint_name is None:
raise ValueError(f'Checkpoint name {cfg.checkpoint_name} is not valid.')
if cfg.hparams_file is None or not os.path.isfile(cfg.hparams_file):
raise ValueError(f'Hparams file {cfg.hparams_file} does not exist or is not a file.')


@hydra_runner(config_path="conf", config_name="megatron_gpt_config")
def main(cfg) -> None:
logging.info("\n\n************** Experiment configuration ***********")
logging.info(f'\n{OmegaConf.to_yaml(cfg)}')

megatron_amp_o2 = cfg.model.get('megatron_amp_O2', False)
with_distributed_adam = cfg.model.optim.get('name', 'fused_adam') == 'distributed_fused_adam'
plugins = []
strategy = NLPDDPStrategy(
no_ddp_communication_hook=True,
gradient_as_bucket_view=cfg.model.gradient_as_bucket_view,
find_unused_parameters=False,
)
if cfg.trainer.precision in [16, 'bf16']:
scaler = None
if cfg.trainer.precision == 16:
scaler = GradScaler(
init_scale=cfg.model.get('native_amp_init_scale', 2 ** 32),
growth_interval=cfg.model.get('native_amp_growth_interval', 1000),
hysteresis=cfg.model.get('hysteresis', 2),
)
if megatron_amp_o2 and not with_distributed_adam:
plugins.append(MegatronHalfPrecisionPlugin(precision=cfg.trainer.precision, device='cuda', scaler=scaler))
else:
plugins.append(PipelineMixedPrecisionPlugin(precision=cfg.trainer.precision, device='cuda', scaler=scaler))

if cfg.get('cluster_type', None) == 'BCP':
plugins.append(TorchElasticEnvironment())

trainer = Trainer(plugins=plugins, strategy=strategy, **cfg.trainer)

exp_manager(trainer, cfg.exp_manager)

# update resume from checkpoint found by exp_manager
if cfg.model.resume_from_checkpoint is not None:
resume_from_checkpoint = cfg.model.resume_from_checkpoint
else:
resume_from_checkpoint = trainer._checkpoint_connector.resume_from_checkpoint_fit_path
logging.info(f'Resuming training from checkpoint: {resume_from_checkpoint}')

trainer._checkpoint_connector = CheckpointConnector(trainer, resume_from_checkpoint=resume_from_checkpoint)

if cfg.restore_from_path:
save_restore_connector = NLPSaveRestoreConnector()
if os.path.isdir(cfg.restore_from_path):
save_restore_connector.model_extracted_dir = cfg.restore_from_path
gpt_cfg = MegatronGPTModel.restore_from(
restore_path=cfg.restore_from_path,
trainer=trainer,
return_config=True,
save_restore_connector=save_restore_connector,
)
model = load_from_nemo(MegatronGPTModel, cfg, trainer, gpt_cfg, modify_confg_fn=_modify_config)
elif cfg.model.get("pretrained_checkpoint", None) is not None:
validate_checkpoint_loading_args(cfg.model.pretrained_checkpoint)
model = load_from_checkpoint_dir(MegatronGPTModel, cfg, trainer, gpt_cfg, modify_confg_fn=_modify_config)
else:
print(' > WARNING: No checkpoint provided. Starting from scratch.')
# hydra interpolation does not work here as the interpolation key is lost when PTL saves hparams
with open_dict(cfg):
cfg.model.precision = cfg.trainer.precision
model = MegatronGPTModel(cfg.model, trainer)
trainer.fit(model)


if __name__ == '__main__':
main()
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ model:
activations_checkpoint_num_layers: null # not used with 'selective'
answer_only_loss: False # not used right now
gradient_as_bucket_view: False
seq_len_interpolation_factor: null # if not None, seq_len_interpolation_factor will match the base model's value
use_flash_attention: null # if not None, will match the base model's value

hidden_dropout: 0.0
attention_dropout: 0.0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,10 @@ def main(cfg) -> None:
peft_model_cfg.data.test_ds = cfg.model.data.test_ds
peft_model_cfg.activations_checkpoint_granularity = None
peft_model_cfg.activations_checkpoint_method = None
if peft_model_cfg.get("use_flash_attention", False):
peft_model_cfg.use_flash_attention = cfg.model.use_flash_attention
if cfg.model.get("seq_len_interpolation_factor", None) is not None:
peft_model_cfg["seq_len_interpolation_factor"] = cfg.model.seq_len_interpolation_factor

with open_dict(cfg):
# update the config with the trained model config
Expand Down
9 changes: 9 additions & 0 deletions examples/nlp/language_modeling/tuning/megatron_gpt_sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,15 @@ def _modify_config(gpt_cfg, cfg, add_cfg_to_tree=False):
sft_cls = MegatronGPTSFTModel
gpt_cfg.target = f"{sft_cls.__module__}.{sft_cls.__name__}"

if cfg.model.get('use_flash_attention', None) is not None:
gpt_cfg.use_flash_attention = cfg.model.use_flash_attention

if cfg.model.get('seq_len_interpolation_factor', None) is not None:
gpt_cfg.seq_len_interpolation_factor = cfg.model.seq_len_interpolation_factor

sft_cls = MegatronGPTSFTModel
gpt_cfg.target = f"{sft_cls.__module__}.{sft_cls.__name__}"

# This is needed when modifying a hparam file directly to load `.ckpt` files.
# This is not needed to modify the cfg in `.nemo` files.
if add_cfg_to_tree:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ def __init__(
use_emha=False,
ub_tp_comm_overlap=False,
use_flash_attention=False,
seq_len_interpolation_factor=None,
):
super(GPTModel, self).__init__(share_token_embeddings=share_embeddings_and_output_weights)

Expand Down Expand Up @@ -249,6 +250,7 @@ def __init__(
use_emha=use_emha,
ub_tp_comm_overlap=ub_tp_comm_overlap,
use_flash_attention=use_flash_attention,
seq_len_interpolation_factor=seq_len_interpolation_factor,
)

if self.share_embeddings_and_output_weights:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -249,10 +249,20 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
if isinstance(self.model, list):
converted_model = []
for module in self.model:
converted_model.append(Float16Module(module=module, precision=cfg.precision))
converted_model.append(
Float16Module(
module=module,
precision=cfg.precision,
share_token_embeddings=self.cfg.get('share_embeddings_and_output_weights', True),
)
)
self.model = converted_model
else:
self.model = Float16Module(module=self.model, precision=cfg.precision)
self.model = Float16Module(
module=self.model,
precision=cfg.precision,
share_token_embeddings=self.cfg.get('share_embeddings_and_output_weights', True),
)

if self.trainer.precision == 'bf16':
self.autocast_dtype = torch.bfloat16
Expand Down Expand Up @@ -360,6 +370,7 @@ def model_provider_func(self, pre_process, post_process):
ub_tp_comm_overlap=self.cfg.get('ub_tp_comm_overlap', False),
use_flash_attention=self.cfg.get('use_flash_attention', False),
megatron_legacy=self.cfg.get('megatron_legacy', False),
seq_len_interpolation_factor=self.cfg.get('seq_len_interpolation_factor', None),
)

return model
Expand Down Expand Up @@ -981,7 +992,7 @@ def build_pretraining_data_loader(
data_parallel_size=parallel_state.get_data_parallel_world_size(),
drop_last=drop_last,
global_batch_size=self.cfg.global_batch_size,
rampup_batch_size=self.cfg.rampup_batch_size,
rampup_batch_size=self.cfg.get('rampup_batch_size', None),
pad_samples_to_global_batch_size=pad_samples_to_global_batch_size,
)
elif self.cfg.data.dataloader_type == 'cyclic':
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ def get_language_model(
use_emha=False,
ub_tp_comm_overlap=False,
use_flash_attention=False,
seq_len_interpolation_factor=None,
):
"""Build language model and return along with the key to save."""

Expand Down Expand Up @@ -200,6 +201,7 @@ def get_language_model(
use_emha=use_emha,
ub_tp_comm_overlap=ub_tp_comm_overlap,
use_flash_attention=use_flash_attention,
seq_len_interpolation_factor=seq_len_interpolation_factor,
)
# key used for checkpoints.
language_model_key = 'language_model'
Expand Down Expand Up @@ -508,6 +510,7 @@ def __init__(
use_emha=False,
ub_tp_comm_overlap=False,
use_flash_attention=False,
seq_len_interpolation_factor=None,
):
super(TransformerLanguageModel, self).__init__(share_token_embeddings=share_embeddings_and_output_weights)

Expand Down Expand Up @@ -559,7 +562,9 @@ def __init__(
assert 0 < rotary_percentage <= 1
if rotary_percentage < 1:
rotary_dim = int(rotary_dim * rotary_percentage)
self.rotary_pos_emb = RotaryEmbedding(rotary_dim)
self.rotary_pos_emb = RotaryEmbedding(
rotary_dim, seq_len_interpolation_factor=seq_len_interpolation_factor
)

elif position_embedding_type == 'alibi':
# TODO: If this is used for encoder-decodemax_position_embeddingsr model, implement proper logic and following
Expand Down
4 changes: 2 additions & 2 deletions nemo/collections/nlp/modules/common/megatron/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,12 +254,12 @@ def float_conversion(val):


class Float16Module(MegatronModule):
def __init__(self, module, precision):
def __init__(self, module, precision, share_token_embeddings=True):
if not HAVE_MEGATRON_CORE:
raise ImportError(
"Megatron-core was not found. Please see the NeMo README for installation instructions: https://github.com/NVIDIA/NeMo#megatron-gpt."
)
super().__init__()
super().__init__(share_token_embeddings=share_token_embeddings)
self.precision = precision

if precision == 'bf16':
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,28 @@


class RotaryEmbedding(nn.Module):
def __init__(self, dim):
"""
Implements Rotary Position Embedding from https://arxiv.org/abs/2104.09864.
"""

def __init__(self, dim: int, seq_len_interpolation_factor: int = None):
"""
Args:
dim (int): rotary embedding dimension
seq_len_interpolation_factor (int): if not None, discrete positions will be interpolated
by this factor via the trick in https://arxiv.org/abs/2306.15595.
"""
super().__init__()
self.seq_len_interpolation_factor = seq_len_interpolation_factor
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer('inv_freq', inv_freq)

def forward(self, max_seq_len, offset=0):
seq = torch.arange(max_seq_len, device=self.inv_freq.device) + offset
if self.seq_len_interpolation_factor is not None:
seq = seq.type_as(self.inv_freq)
seq *= 1 / self.seq_len_interpolation_factor
freqs = einsum('i , j -> i j', seq.type_as(self.inv_freq), self.inv_freq)
# first part even vector components, second part odd vector components,
# 2 * dim in dimension size
Expand Down

0 comments on commit 57f345d

Please sign in to comment.