RoPE length extrapolation with interpolation #7005

Merged: 16 commits, Jul 11, 2023
@@ -83,6 +83,7 @@ model:
share_embeddings_and_output_weights: True # Share embedding and output layer weights.
overlap_p2p_comm: False # Overlap p2p communication with computes. This argument is valid only when `virtual_pipeline_model_parallel_size` is larger than 1
batch_p2p_comm: True # Batch consecutive inter-peer send/recv operations. This argument is valid only when `virtual_pipeline_model_parallel_size` is larger than 1
seq_len_interpolation_factor: null # RoPE Interpolation factor for sequence length. This is used to build long-context models with RoPE ex: https://arxiv.org/abs/2306.15595.

tokenizer:
library: 'megatron'
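For orientation, an editorial sketch (not part of the diff): when `seq_len_interpolation_factor` is set, RoPE position indices are divided by that factor before the rotary angles are computed, which is the position-interpolation trick from the linked paper. A minimal illustration with assumed numbers:

```python
# Assumed numbers for illustration: a model pre-trained with a 2048-token context,
# extended with seq_len_interpolation_factor = 4.
pretrained_seq_len = 2048
seq_len_interpolation_factor = 4
extended_seq_len = pretrained_seq_len * seq_len_interpolation_factor  # 8192 usable positions

# Every extended position index is scaled back into the range seen at pre-training time.
scaled = [p / seq_len_interpolation_factor for p in range(extended_seq_len)]
assert max(scaled) < pretrained_seq_len
```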
193 changes: 193 additions & 0 deletions examples/nlp/language_modeling/megatron_gpt_continue_training.py
@@ -0,0 +1,193 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import tempfile

from omegaconf.omegaconf import OmegaConf, open_dict
from pytorch_lightning import Trainer
from pytorch_lightning.plugins.environments import TorchElasticEnvironment
from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector

from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
from nemo.collections.nlp.modules.common.megatron.megatron_init import fake_initialize_model_parallel
from nemo.collections.nlp.parts.nlp_overrides import (
    GradScaler,
    MegatronHalfPrecisionPlugin,
    NLPDDPStrategy,
    NLPSaveRestoreConnector,
    PipelineMixedPrecisionPlugin,
)
from nemo.core.config import hydra_runner
from nemo.utils import AppState, logging
from nemo.utils.exp_manager import exp_manager
from nemo.utils.model_utils import inject_model_parallel_rank


def _modify_config(gpt_cfg, cfg, add_cfg_to_tree=False):
Collaborator:
These _modify_config, load_from_nemo, load_from_checkpoint_dir, and validate_checkpoint_loading_args functions are the same as in the SFT code. Can we put them into a utility file?

Contributor (author):
They are almost the same, but not 100% identical. The issue is that each one (SFT vs. continued training) modifies some common attributes, like data and optim, but also a few different things.

"""
This function modifies the original gpt pre-training config (t5_cfg) with attributes from the finetuning config (cfg).
The `add_cfg_to_tree` arg adds `cfg` to the top of the yaml tree which is needed for all `hparams.yaml` files when passed as an arg to `load_from_checkpoint()`.
"""
    OmegaConf.set_struct(gpt_cfg, True)
    OmegaConf.resolve(cfg)
    with open_dict(gpt_cfg):
        gpt_cfg.megatron_amp_O2 = cfg.model.get('megatron_amp_O2', False)
        gpt_cfg.micro_batch_size = cfg.model.micro_batch_size
        gpt_cfg.global_batch_size = cfg.model.global_batch_size
        gpt_cfg.sequence_parallel = cfg.model.get("sequence_parallel", False)
        gpt_cfg.activations_checkpoint_granularity = cfg.model.get("activations_checkpoint_granularity", None)
        gpt_cfg.activations_checkpoint_num_layers = cfg.model.get("activations_checkpoint_num_layers", None)
        gpt_cfg.activations_checkpoint_method = cfg.model.get("activations_checkpoint_method", None)
        gpt_cfg.data = cfg.model.data
        gpt_cfg.optim = cfg.model.optim
        gpt_cfg.precision = cfg.trainer.precision
        gpt_cfg.restore_from_path = cfg.restore_from_path
        gpt_cfg.resume_from_checkpoint = cfg.model.resume_from_checkpoint
        gpt_cfg.gradient_as_bucket_view = cfg.model.gradient_as_bucket_view
        gpt_cfg.encoder_seq_length = cfg.model.encoder_seq_length
        gpt_cfg.max_position_embeddings = cfg.model.max_position_embeddings
        gpt_cfg.seq_len_interpolation_factor = cfg.model.seq_len_interpolation_factor
        gpt_cfg.use_flash_attention = cfg.model.use_flash_attention

        # This is needed when modifying a hparam file directly to load `.ckpt` files.
        # This is not needed to modify the cfg in `.nemo` files.
        if add_cfg_to_tree:
            OmegaConf.resolve(gpt_cfg)
            gpt_cfg.cfg = gpt_cfg

    return gpt_cfg
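As an aside, the function above relies on OmegaConf's struct mode: `set_struct(gpt_cfg, True)` locks the schema so stray keys raise, and `open_dict` temporarily lifts that lock while the selected fine-tuning attributes are copied in. A small self-contained sketch of that pattern, using toy configs rather than the real NeMo ones:

```python
from omegaconf import OmegaConf, open_dict

# Toy stand-ins for the restored pre-training config and the fine-tuning config.
gpt_cfg = OmegaConf.create({"encoder_seq_length": 4096, "seq_len_interpolation_factor": None})
cfg = OmegaConf.create({"model": {"encoder_seq_length": 16384, "seq_len_interpolation_factor": 4}})

OmegaConf.set_struct(gpt_cfg, True)  # lock the schema: assigning unknown keys now raises
with open_dict(gpt_cfg):             # temporarily allow edits and new keys
    gpt_cfg.encoder_seq_length = cfg.model.encoder_seq_length
    gpt_cfg.seq_len_interpolation_factor = cfg.model.seq_len_interpolation_factor

print(OmegaConf.to_yaml(gpt_cfg))
```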


def load_from_nemo(cls, cfg, trainer, gpt_cfg, modify_confg_fn):
    gpt_cfg = modify_confg_fn(gpt_cfg, cfg, add_cfg_to_tree=False)
    save_restore_connector = NLPSaveRestoreConnector()
    if os.path.isdir(cfg.restore_from_path):
        save_restore_connector.model_extracted_dir = cfg.restore_from_path
    model = cls.restore_from(
        restore_path=cfg.restore_from_path,
        trainer=trainer,
        override_config_path=gpt_cfg,
        save_restore_connector=save_restore_connector,
    )
    return model


def load_from_checkpoint_dir(cls, cfg, trainer, modify_confg_fn):
    app_state = AppState()
    if cfg.model.tensor_model_parallel_size > 1 or cfg.model.pipeline_model_parallel_size > 1:
        app_state.model_parallel_size = cfg.model.tensor_model_parallel_size * cfg.model.pipeline_model_parallel_size
        app_state.tensor_model_parallel_size = cfg.model.tensor_model_parallel_size
        app_state.pipeline_model_parallel_size = cfg.model.pipeline_model_parallel_size
        (
            app_state.tensor_model_parallel_rank,
            app_state.pipeline_model_parallel_rank,
            app_state.model_parallel_size,
            app_state.data_parallel_size,
            app_state.pipeline_model_parallel_split_rank,
            app_state.virtual_pipeline_model_parallel_rank,
        ) = fake_initialize_model_parallel(
            world_size=app_state.model_parallel_size,
            rank=trainer.global_rank,
            tensor_model_parallel_size_=cfg.model.tensor_model_parallel_size,
            pipeline_model_parallel_size_=cfg.model.pipeline_model_parallel_size,
            pipeline_model_parallel_split_rank_=cfg.model.pipeline_model_parallel_split_rank,
        )
    checkpoint_path = inject_model_parallel_rank(
        os.path.join(cfg.model.pretrained_checkpoint.checkpoint_dir, cfg.model.pretrained_checkpoint.checkpoint_name)
    )
    hparams_file = OmegaConf.load(cfg.model.pretrained_checkpoint.hparams_file)
    gpt_cfg = modify_confg_fn(hparams_file.cfg, cfg, add_cfg_to_tree=True)
    with tempfile.NamedTemporaryFile(suffix='.yaml') as f:
        OmegaConf.save(config=gpt_cfg, f=f.name)
        model = cls.load_from_checkpoint(checkpoint_path=checkpoint_path, trainer=trainer, hparams_file=f.name,)
        return model
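The temporary-YAML step above exists because Lightning's `load_from_checkpoint` reads hyperparameters from a file path rather than from an in-memory object, so the modified config is written to a short-lived file just long enough for the load. A reduced sketch of that serialization step, with a toy config and no Lightning involved:

```python
import tempfile

from omegaconf import OmegaConf

gpt_cfg = OmegaConf.create({"precision": "bf16", "seq_len_interpolation_factor": 2})

# The file only needs to live for the duration of the `with` block, which is all
# that a call like `load_from_checkpoint(..., hparams_file=f.name)` would require.
with tempfile.NamedTemporaryFile(suffix='.yaml') as f:
    OmegaConf.save(config=gpt_cfg, f=f.name)
    print(open(f.name).read())
```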


def validate_checkpoint_loading_args(cfg):
    if cfg.checkpoint_dir is None or not os.path.isdir(cfg.checkpoint_dir):
        raise ValueError(f'Checkpoint directory {cfg.checkpoint_dir} does not exist or is not a directory.')
    if cfg.checkpoint_name is None:
        raise ValueError(f'Checkpoint name {cfg.checkpoint_name} is not valid.')
    if cfg.hparams_file is None or not os.path.isfile(cfg.hparams_file):
        raise ValueError(f'Hparams file {cfg.hparams_file} does not exist or is not a file.')


@hydra_runner(config_path="conf", config_name="megatron_gpt_config")
def main(cfg) -> None:
    logging.info("\n\n************** Experiment configuration ***********")
    logging.info(f'\n{OmegaConf.to_yaml(cfg)}')

    megatron_amp_o2 = cfg.model.get('megatron_amp_O2', False)
    with_distributed_adam = cfg.model.optim.get('name', 'fused_adam') == 'distributed_fused_adam'
    plugins = []
    strategy = NLPDDPStrategy(
        no_ddp_communication_hook=True,
        gradient_as_bucket_view=cfg.model.gradient_as_bucket_view,
        find_unused_parameters=False,
    )
    if cfg.trainer.precision in [16, 'bf16']:
        scaler = None
        if cfg.trainer.precision == 16:
            scaler = GradScaler(
                init_scale=cfg.model.get('native_amp_init_scale', 2 ** 32),
                growth_interval=cfg.model.get('native_amp_growth_interval', 1000),
                hysteresis=cfg.model.get('hysteresis', 2),
            )
        if megatron_amp_o2 and not with_distributed_adam:
            plugins.append(MegatronHalfPrecisionPlugin(precision=cfg.trainer.precision, device='cuda', scaler=scaler))
        else:
            plugins.append(PipelineMixedPrecisionPlugin(precision=cfg.trainer.precision, device='cuda', scaler=scaler))

    if cfg.get('cluster_type', None) == 'BCP':
        plugins.append(TorchElasticEnvironment())

    trainer = Trainer(plugins=plugins, strategy=strategy, **cfg.trainer)

    exp_manager(trainer, cfg.exp_manager)

    # update resume from checkpoint found by exp_manager
    if cfg.model.resume_from_checkpoint is not None:
        resume_from_checkpoint = cfg.model.resume_from_checkpoint
    else:
        resume_from_checkpoint = trainer._checkpoint_connector.resume_from_checkpoint_fit_path
    logging.info(f'Resuming training from checkpoint: {resume_from_checkpoint}')

    trainer._checkpoint_connector = CheckpointConnector(trainer, resume_from_checkpoint=resume_from_checkpoint)

    if cfg.restore_from_path:
        save_restore_connector = NLPSaveRestoreConnector()
        if os.path.isdir(cfg.restore_from_path):
            save_restore_connector.model_extracted_dir = cfg.restore_from_path
        gpt_cfg = MegatronGPTModel.restore_from(
            restore_path=cfg.restore_from_path,
            trainer=trainer,
            return_config=True,
            save_restore_connector=save_restore_connector,
        )
        model = load_from_nemo(MegatronGPTModel, cfg, trainer, gpt_cfg, modify_confg_fn=_modify_config)
    elif cfg.model.get("pretrained_checkpoint", None) is not None:
        validate_checkpoint_loading_args(cfg.model.pretrained_checkpoint)
        model = load_from_checkpoint_dir(MegatronGPTModel, cfg, trainer, gpt_cfg, modify_confg_fn=_modify_config)

Check failure (Code scanning / CodeQL): Wrong number of arguments in a call
Call to function load_from_checkpoint_dir with too many arguments; should be no more than 4.

Check failure (Code scanning / CodeQL): Potentially uninitialized local variable
Local variable 'gpt_cfg' may be used before it is initialized.
    else:
        print(' > WARNING: No checkpoint provided. Starting from scratch.')
        # hydra interpolation does not work here as the interpolation key is lost when PTL saves hparams
        with open_dict(cfg):
            cfg.model.precision = cfg.trainer.precision
        model = MegatronGPTModel(cfg.model, trainer)
    trainer.fit(model)


if __name__ == '__main__':
    main()
@@ -60,6 +60,8 @@ model:
activations_checkpoint_num_layers: null # not used with 'selective'
answer_only_loss: False # not used right now
gradient_as_bucket_view: False
seq_len_interpolation_factor: null # if not None, seq_len_interpolation_factor will match the base model's value
Collaborator:
Add some explanation about how the interpolation factor translates to longer sequences, e.g. factor = 2 means 2x the sequence length.

use_flash_attention: null # if not None, will match the base model's value

hidden_dropout: 0.0
attention_dropout: 0.0
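To spell out the reviewer's request above (an editorial illustration, not a PR comment): the factor is the ratio between the target sequence length and the sequence length the base model was trained with, so a factor of 2 doubles the usable sequence length.

```python
# Hypothetical lengths for the reviewer's example: factor = 2, sequence length x 2.
base_model_seq_len = 4096
seq_len_interpolation_factor = 2
assert base_model_seq_len * seq_len_interpolation_factor == 8192
```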
@@ -127,6 +127,10 @@ def main(cfg) -> None:
        peft_model_cfg.data.test_ds = cfg.model.data.test_ds
        peft_model_cfg.activations_checkpoint_granularity = None
        peft_model_cfg.activations_checkpoint_method = None
        if peft_model_cfg.get("use_flash_attention", False):
            peft_model_cfg.use_flash_attention = cfg.model.use_flash_attention
        if cfg.model.get("seq_len_interpolation_factor", None) is not None:
            peft_model_cfg["seq_len_interpolation_factor"] = cfg.model.seq_len_interpolation_factor

    with open_dict(cfg):
        # update the config with the trained model config
9 changes: 9 additions & 0 deletions examples/nlp/language_modeling/tuning/megatron_gpt_sft.py
@@ -64,6 +64,15 @@ def _modify_config(gpt_cfg, cfg, add_cfg_to_tree=False):
        sft_cls = MegatronGPTSFTModel
        gpt_cfg.target = f"{sft_cls.__module__}.{sft_cls.__name__}"

        if cfg.model.get('use_flash_attention', None) is not None:
            gpt_cfg.use_flash_attention = cfg.model.use_flash_attention

        if cfg.model.get('seq_len_interpolation_factor', None) is not None:
            gpt_cfg.seq_len_interpolation_factor = cfg.model.seq_len_interpolation_factor

        sft_cls = MegatronGPTSFTModel
        gpt_cfg.target = f"{sft_cls.__module__}.{sft_cls.__name__}"

        # This is needed when modifying a hparam file directly to load `.ckpt` files.
        # This is not needed to modify the cfg in `.nemo` files.
        if add_cfg_to_tree:
@@ -166,6 +166,7 @@ def __init__(
        use_emha=False,
        ub_tp_comm_overlap=False,
        use_flash_attention=False,
        seq_len_interpolation_factor=None,
    ):
        super(GPTModel, self).__init__(share_token_embeddings=share_embeddings_and_output_weights)

@@ -249,6 +250,7 @@ def __init__(
            use_emha=use_emha,
            ub_tp_comm_overlap=ub_tp_comm_overlap,
            use_flash_attention=use_flash_attention,
            seq_len_interpolation_factor=seq_len_interpolation_factor,
        )

        if self.share_embeddings_and_output_weights:
@@ -249,10 +249,20 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
                if isinstance(self.model, list):
                    converted_model = []
                    for module in self.model:
                        converted_model.append(Float16Module(module=module, precision=cfg.precision))
                        converted_model.append(
                            Float16Module(
                                module=module,
                                precision=cfg.precision,
                                share_token_embeddings=self.cfg.get('share_embeddings_and_output_weights', True),
                            )
                        )
                    self.model = converted_model
                else:
                    self.model = Float16Module(module=self.model, precision=cfg.precision)
                    self.model = Float16Module(
                        module=self.model,
                        precision=cfg.precision,
                        share_token_embeddings=self.cfg.get('share_embeddings_and_output_weights', True),
                    )

        if self.trainer.precision == 'bf16':
            self.autocast_dtype = torch.bfloat16
@@ -360,6 +370,7 @@ def model_provider_func(self, pre_process, post_process):
            ub_tp_comm_overlap=self.cfg.get('ub_tp_comm_overlap', False),
            use_flash_attention=self.cfg.get('use_flash_attention', False),
            megatron_legacy=self.cfg.get('megatron_legacy', False),
            seq_len_interpolation_factor=self.cfg.get('seq_len_interpolation_factor', None),
        )

        return model
@@ -981,7 +992,7 @@ def build_pretraining_data_loader(
                data_parallel_size=parallel_state.get_data_parallel_world_size(),
                drop_last=drop_last,
                global_batch_size=self.cfg.global_batch_size,
                rampup_batch_size=self.cfg.rampup_batch_size,
                rampup_batch_size=self.cfg.get('rampup_batch_size', None),
                pad_samples_to_global_batch_size=pad_samples_to_global_batch_size,
            )
        elif self.cfg.data.dataloader_type == 'cyclic':
@@ -123,6 +123,7 @@ def get_language_model(
    use_emha=False,
    ub_tp_comm_overlap=False,
    use_flash_attention=False,
    seq_len_interpolation_factor=None,
):
    """Build language model and return along with the key to save."""

@@ -200,6 +201,7 @@ def get_language_model(
        use_emha=use_emha,
        ub_tp_comm_overlap=ub_tp_comm_overlap,
        use_flash_attention=use_flash_attention,
        seq_len_interpolation_factor=seq_len_interpolation_factor,
    )
    # key used for checkpoints.
    language_model_key = 'language_model'
@@ -508,6 +510,7 @@ def __init__(
        use_emha=False,
        ub_tp_comm_overlap=False,
        use_flash_attention=False,
        seq_len_interpolation_factor=None,
    ):
        super(TransformerLanguageModel, self).__init__(share_token_embeddings=share_embeddings_and_output_weights)

@@ -559,7 +562,9 @@ def __init__(
            assert 0 < rotary_percentage <= 1
            if rotary_percentage < 1:
                rotary_dim = int(rotary_dim * rotary_percentage)
            self.rotary_pos_emb = RotaryEmbedding(rotary_dim)
            self.rotary_pos_emb = RotaryEmbedding(
Collaborator:
Maybe not in this PR, but we need to add seq_len_interpolation_factor for all the models that use RoPE.
                rotary_dim, seq_len_interpolation_factor=seq_len_interpolation_factor
            )

        elif position_embedding_type == 'alibi':
            # TODO: If this is used for encoder-decoder model, implement proper logic and following
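For context on the `rotary_percentage` handling above: the rotary dimension starts from the per-head size and is optionally truncated before being passed to `RotaryEmbedding` together with the interpolation factor. A small numeric sketch with assumed sizes (not taken from the PR):

```python
# Assumed model sizes for illustration only.
hidden_size = 4096
num_attention_heads = 32
rotary_percentage = 0.5
seq_len_interpolation_factor = 2

rotary_dim = hidden_size // num_attention_heads  # 128: defaults to the per-head dimension
assert 0 < rotary_percentage <= 1
if rotary_percentage < 1:
    rotary_dim = int(rotary_dim * rotary_percentage)  # 64: only part of each head is rotated

# These two values are what RotaryEmbedding receives in the constructor call above.
print(rotary_dim, seq_len_interpolation_factor)
```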
4 changes: 2 additions & 2 deletions nemo/collections/nlp/modules/common/megatron/module.py
@@ -254,12 +254,12 @@ def float_conversion(val):


class Float16Module(MegatronModule):
    def __init__(self, module, precision):
    def __init__(self, module, precision, share_token_embeddings=True):
        if not HAVE_MEGATRON_CORE:
            raise ImportError(
                "Megatron-core was not found. Please see the NeMo README for installation instructions: https://github.com/NVIDIA/NeMo#megatron-gpt."
            )
        super().__init__()
        super().__init__(share_token_embeddings=share_token_embeddings)
        self.precision = precision

        if precision == 'bf16':
@@ -21,13 +21,28 @@


class RotaryEmbedding(nn.Module):
    def __init__(self, dim):
    """
    Implements Rotary Position Embedding from https://arxiv.org/abs/2104.09864.
    """

    def __init__(self, dim: int, seq_len_interpolation_factor: int = None):
        """
        Args:

            dim (int): rotary embedding dimension
            seq_len_interpolation_factor (int): if not None, discrete positions will be interpolated
                by this factor via the trick in https://arxiv.org/abs/2306.15595.
        """
        super().__init__()
        self.seq_len_interpolation_factor = seq_len_interpolation_factor
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)

    def forward(self, max_seq_len, offset=0):
        seq = torch.arange(max_seq_len, device=self.inv_freq.device) + offset
        if self.seq_len_interpolation_factor is not None:
            seq = seq.type_as(self.inv_freq)
            seq *= 1 / self.seq_len_interpolation_factor
        freqs = einsum('i , j -> i j', seq.type_as(self.inv_freq), self.inv_freq)
        # first part even vector components, second part odd vector components,
        # 2 * dim in dimension size
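To see the effect end to end, here is a self-contained re-creation of the frequency computation above in plain PyTorch, independent of NeMo: with an interpolation factor of 2, every even extended position lands exactly on a position the model saw during pre-training, and all 8192 scaled indices stay inside the original 4096-position range.

```python
import torch


def rotary_freqs(max_seq_len, dim=64, interpolation_factor=None):
    """Re-create the rotary frequency table from the class above, outside of NeMo."""
    inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
    seq = torch.arange(max_seq_len).float()
    if interpolation_factor is not None:
        seq = seq / interpolation_factor  # the interpolation trick: shrink position indices
    return torch.outer(seq, inv_freq)     # rotary angles, shape [max_seq_len, dim // 2]


base = rotary_freqs(4096)                                # positions seen during pre-training
extended = rotary_freqs(8192, interpolation_factor=2.0)  # doubled context, factor = 2

# Interpolated position 2*k reproduces trained position k, and the scaled indices
# never leave the trained range [0, 4096).
assert torch.allclose(extended[::2], base)
assert (torch.arange(8192) / 2).max() < 4096
```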