Adds several configurable flags for Megatron GPT models (#5991)
* Initial

Signed-off-by: MaximumEntropy <[email protected]>

* Multiple fixes

Signed-off-by: MaximumEntropy <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix

Signed-off-by: MaximumEntropy <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add to CI test

Signed-off-by: MaximumEntropy <[email protected]>

* Fix

Signed-off-by: MaximumEntropy <[email protected]>

* check position embs for gpt prompt learning

Signed-off-by: Adi Renduchintala <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update args

Signed-off-by: MaximumEntropy <[email protected]>

* Disable tts unit test

Signed-off-by: MaximumEntropy <[email protected]>

* Fix

Signed-off-by: MaximumEntropy <[email protected]>

* Fix

Signed-off-by: MaximumEntropy <[email protected]>

* Empty

Signed-off-by: MaximumEntropy <[email protected]>

* Update Jenkinsfile

Changed optimizer for GPT training from 'fused_adam' to 'distributed_fused_adam'.

Signed-off-by: khcs <[email protected]>

* update config to use correct key

Signed-off-by: ericharper <[email protected]>

* revert Jenkinsfile back to fused_adam

Signed-off-by: ericharper <[email protected]>

---------

Signed-off-by: MaximumEntropy <[email protected]>
Signed-off-by: Adi Renduchintala <[email protected]>
Signed-off-by: khcs <[email protected]>
Signed-off-by: ericharper <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Adi Renduchintala <[email protected]>
Co-authored-by: khcs <[email protected]>
Co-authored-by: Oleksii Kuchaiev <[email protected]>
Co-authored-by: ericharper <[email protected]>
6 people committed Feb 18, 2023
1 parent 8e6f36a commit 4a56631
Showing 12 changed files with 245 additions and 42 deletions.
24 changes: 24 additions & 0 deletions Jenkinsfile
@@ -3166,6 +3166,12 @@ assert_frame_equal(training_curve, gt_curve, rtol=1e-3, atol=1e-3)"'''
model.max_position_embeddings=128 \
model.encoder_seq_length=128 \
model.data.seq_length=128 \
model.position_embedding_type=rope \
model.rotary_percentage=0.5 \
model.normalization=rmsnorm \
model.bias=False \
model.bias_activation_fusion=False \
model.bias_dropout_add_fusion=False \
model.tokenizer.vocab_file=/home/TestData/nlp/megatron_gpt/data/gpt/vocab.json \
model.tokenizer.merge_file=/home/TestData/nlp/megatron_gpt/data/gpt/merges.txt \
model.num_layers=8 \
@@ -3196,6 +3202,12 @@ assert_frame_equal(training_curve, gt_curve, rtol=1e-3, atol=1e-3)"'''
model.max_position_embeddings=128 \
model.encoder_seq_length=128 \
model.data.seq_length=128 \
model.position_embedding_type=rope \
model.rotary_percentage=0.5 \
model.normalization=rmsnorm \
model.bias=False \
model.bias_activation_fusion=False \
model.bias_dropout_add_fusion=False \
model.tokenizer.vocab_file=/home/TestData/nlp/megatron_gpt/data/gpt/vocab.json \
model.tokenizer.merge_file=/home/TestData/nlp/megatron_gpt/data/gpt/merges.txt \
model.num_layers=8 \
@@ -3237,6 +3249,12 @@ assert_frame_equal(training_curve, gt_curve, rtol=1e-3, atol=1e-3)"'''
model.optim.sched.min_lr=8e-5 \
model.max_position_embeddings=128 \
model.encoder_seq_length=128 \
model.activation=swiglu \
model.bias_activation_fusion=False \
model.hidden_dropout=0.0 \
model.attention_dropout=0.0 \
model.transformer_block_type=normformer \
model.headscale=True \
model.data.seq_length=128 \
model.tokenizer.vocab_file=/home/TestData/nlp/megatron_gpt/data/gpt/vocab.json \
model.tokenizer.merge_file=/home/TestData/nlp/megatron_gpt/data/gpt/merges.txt \
@@ -3267,6 +3285,12 @@ assert_frame_equal(training_curve, gt_curve, rtol=1e-3, atol=1e-3)"'''
model.optim.sched.min_lr=8e-5 \
model.max_position_embeddings=128 \
model.encoder_seq_length=128 \
model.activation=swiglu \
model.bias_activation_fusion=False \
model.hidden_dropout=0.0 \
model.attention_dropout=0.0 \
model.transformer_block_type=normformer \
model.headscale=True \
model.data.seq_length=128 \
model.tokenizer.vocab_file=/home/TestData/nlp/megatron_gpt/data/gpt/vocab.json \
model.tokenizer.merge_file=/home/TestData/nlp/megatron_gpt/data/gpt/merges.txt \
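For reference, the new overrides exercised by the CI runs above can also be composed programmatically. A minimal sketch using OmegaConf, assuming the stock megatron_gpt_config.yaml shipped with NeMo (the printed slice and usage pattern are illustrative, not part of this commit):

```python
# Sketch: apply the same RoPE / RMSNorm / bias-off overrides the CI commands pass on the CLI.
from omegaconf import OmegaConf

base = OmegaConf.load("examples/nlp/language_modeling/conf/megatron_gpt_config.yaml")
overrides = OmegaConf.from_dotlist(
    [
        "model.position_embedding_type=rope",
        "model.rotary_percentage=0.5",
        "model.normalization=rmsnorm",
        "model.bias=False",
        "model.bias_activation_fusion=False",
        "model.bias_dropout_add_fusion=False",
    ]
)
cfg = OmegaConf.merge(base, overrides)

# Inspect the merged model section (interpolations are left unresolved).
print(OmegaConf.to_yaml(cfg.model)[:400])
```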
14 changes: 12 additions & 2 deletions examples/nlp/language_modeling/conf/megatron_gpt_config.yaml
@@ -40,7 +40,6 @@ exp_manager:
filename: 'megatron_gpt--{val_loss:.2f}-{step}-{consumed_samples}'
model_parallel_size: ${multiply:${model.tensor_model_parallel_size}, ${model.pipeline_model_parallel_size}}


model:
# specify micro_batch_size, global_batch_size, and model parallelism
# gradient accumulation will be done automatically based on data_parallel_size
@@ -61,15 +60,26 @@ model:
use_scaled_init_method: True # use scaled residuals initialization
hidden_dropout: 0.1 # Dropout probability for hidden state transformer.
attention_dropout: 0.1 # Dropout probability for attention
ffn_dropout: 0.0 # Dropout probability in the feed-forward layer.
kv_channels: null # Projection weights dimension in multi-head attention. Set to hidden_size // num_attention_heads if null
apply_query_key_layer_scaling: True # scale Q * K^T by 1 / layer-number.
normalization: layernorm # Type of normalization layers
normalization: 'layernorm' # Normalization layer to use. Options are 'layernorm', 'rmsnorm'
layernorm_epsilon: 1e-5
do_layer_norm_weight_decay: False # True means weight decay on all params
make_vocab_size_divisible_by: 128 # Pad the vocab size to be divisible by this value for computation efficiency.
pre_process: True # add embedding
post_process: True # add pooler
persist_layer_norm: True # Use of persistent fused layer norm kernel.
bias: True # Whether to use bias terms in all weight matrices.
activation: 'gelu' # Options ['gelu', 'geglu', 'swiglu', 'reglu', 'squared-relu']
headscale: False # Whether to learn extra parameters that scale the output of each self-attention head.
transformer_block_type: 'pre_ln' # Options ['pre_ln', 'post_ln', 'normformer']
openai_gelu: False # Use OpenAI's GELU instead of the default GeLU
normalize_attention_scores: True # Whether to scale the output Q * K^T by 1 / sqrt(hidden_size_per_head). This arg is provided as a configuration option mostly for compatibility with models that have been weight-converted from HF. You almost always want to set this to True.
position_embedding_type: 'learned_absolute' # Position embedding type. Options ['learned_absolute', 'rope']
rotary_percentage: 1.0 # If using position_embedding_type=rope, then the per head dim is multiplied by this.
attention_type: 'multihead' # Attention type. Options ['multihead']
share_embeddings_and_output_weights: True # Share embedding and output layer weights.

tokenizer:
library: 'megatron'
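Among the new options, normalization: 'rmsnorm' swaps LayerNorm's mean-centering and bias for a pure root-mean-square rescale with a learned gain. A minimal reference implementation for orientation (a sketch, not the fused kernel NeMo actually dispatches to):

```python
import torch
import torch.nn as nn


class RMSNorm(nn.Module):
    """Root-mean-square norm: y = x / rms(x) * g, with no mean subtraction and no bias."""

    def __init__(self, hidden_size: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(hidden_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        inv_rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).rsqrt()
        return x * inv_rms * self.weight


x = torch.randn(2, 4, 8)
print(RMSNorm(8)(x).shape)  # torch.Size([2, 4, 8])
```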
@@ -114,6 +114,7 @@ def __init__(
use_cpu_initialization=False,
hidden_dropout=0.1,
attention_dropout=0.1,
ffn_dropout=0.0,
precision=16,
fp32_residual_connection=False,
activations_checkpoint_granularity=None,
@@ -122,9 +123,18 @@
activations_checkpoint_layers_per_pipeline=None,
normalization='layernorm',
layernorm_epsilon=1e-5,
bias=True,
bias_activation_fusion=True,
bias_dropout_add_fusion=True,
masked_softmax_fusion=True,
activation='gelu',
headscale=False,
transformer_block_type='pre_ln',
normalize_attention_scores=True,
position_embedding_type='learned_absolute',
rotary_percentage=1.0,
attention_type='multihead',
share_embeddings_and_output_weights=True,
gradient_accumulation_fusion=False,
persist_layer_norm=False,
openai_gelu=False,
@@ -141,15 +151,15 @@ def __init__(
reduce_amax=True,
use_emha=False,
):

super(GPTModel, self).__init__()
super(GPTModel, self).__init__(share_token_embeddings=share_embeddings_and_output_weights)

self.parallel_output = parallel_output
self.pre_process = pre_process
self.post_process = post_process
self.fp16_lm_cross_entropy = fp16_lm_cross_entropy
self.sequence_parallel = sequence_parallel
self.gradient_accumulation_fusion = gradient_accumulation_fusion
self.share_embeddings_and_output_weights = share_embeddings_and_output_weights

if kv_channels is None:
assert (
@@ -167,6 +177,7 @@ def __init__(
hidden_size=hidden_size,
hidden_dropout=hidden_dropout,
attention_dropout=attention_dropout,
ffn_dropout=ffn_dropout,
num_tokentypes=num_tokentypes,
max_position_embeddings=max_position_embeddings,
num_layers=num_layers,
@@ -190,10 +201,19 @@
activations_checkpoint_layers_per_pipeline=activations_checkpoint_layers_per_pipeline,
normalization=normalization,
layernorm_epsilon=layernorm_epsilon,
rotary_percentage=rotary_percentage,
share_embeddings_and_output_weights=share_embeddings_and_output_weights,
bias=bias,
bias_activation_fusion=bias_activation_fusion,
bias_dropout_add_fusion=bias_dropout_add_fusion,
masked_softmax_fusion=masked_softmax_fusion,
gradient_accumulation_fusion=gradient_accumulation_fusion,
activation=activation,
headscale=headscale,
transformer_block_type=transformer_block_type,
normalize_attention_scores=normalize_attention_scores,
position_embedding_type=position_embedding_type,
attention_type=attention_type,
persist_layer_norm=persist_layer_norm,
openai_gelu=openai_gelu,
onnx_safe=onnx_safe,
@@ -210,9 +230,10 @@ def __init__(
use_emha=use_emha,
)

self.initialize_word_embeddings(
init_method=init_method_normal(init_method_std), vocab_size=vocab_size, hidden_size=hidden_size
)
if self.share_embeddings_and_output_weights:
self.initialize_word_embeddings(
init_method=init_method_normal(init_method_std), vocab_size=vocab_size, hidden_size=hidden_size
)

def set_input_tensor(self, input_tensor):
"""See megatron.model.transformer.set_input_tensor()"""
@@ -253,7 +274,9 @@ def forward(
return post_language_model_processing(
lm_output,
labels,
self.word_embeddings_weight(),
self.language_model.output_layer.weight
if not self.share_embeddings_and_output_weights
else self.word_embeddings_weight(),
get_key_value,
self.parallel_output,
forward_method_parallel_output,
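The GPTModel changes make embedding/output weight tying optional: with share_embeddings_and_output_weights=True the logits reuse the word-embedding matrix, otherwise the language model's own output_layer weight is used. A minimal single-device sketch of the same idea (hypothetical TinyLM class, not NeMo's tensor-parallel implementation):

```python
import torch
import torch.nn as nn


class TinyLM(nn.Module):
    def __init__(self, vocab_size: int, hidden: int, share_embeddings_and_output_weights: bool = True):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden)
        self.output = nn.Linear(hidden, vocab_size, bias=False)
        if share_embeddings_and_output_weights:
            # Tied path: the output projection reuses the embedding matrix.
            self.output.weight = self.embed.weight

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        hidden = self.embed(tokens)   # stand-in for the transformer stack
        return self.output(hidden)    # [batch, seq, vocab] logits


tied = TinyLM(100, 16, share_embeddings_and_output_weights=True)
untied = TinyLM(100, 16, share_embeddings_and_output_weights=False)
print(tied.output.weight is tied.embed.weight)      # True
print(untied.output.weight is untied.embed.weight)  # False
```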
@@ -508,16 +508,15 @@ def _get_total_params_across_model_parallel_groups_gpt_bert(self, model):
num_parameters_on_device = sum(
[sum([p.nelement() for p in model_module.parameters()]) for model_module in model]
)
if parallel_state.get_pipeline_model_parallel_world_size() > 1 and parallel_state.is_pipeline_last_stage(
if parallel_state.get_pipeline_model_parallel_world_size() > 1 and parallel_state.is_pipeline_first_stage(
ignore_virtual=True
):
# subtract the embedding weights on the last virtual stage
num_word_embedding_parameters = sum([p.nelement() for p in model[-1].word_embeddings_weight()])
num_parameters_on_device -= num_word_embedding_parameters
else:
num_parameters_on_device = sum([p.nelement() for p in model.parameters()])

if parallel_state.get_pipeline_model_parallel_world_size() > 1 and parallel_state.is_pipeline_last_stage(
if parallel_state.get_pipeline_model_parallel_world_size() > 1 and parallel_state.is_pipeline_first_stage(
ignore_virtual=True
):
# subtract the embedding weights on the last stage
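The condition change above affects how the duplicated tied-embedding copy is excluded when reporting total parameter counts across pipeline stages. A rough sketch of the intended accounting (illustrative helper, simplified rank handling):

```python
def total_unique_params(stage_param_counts, num_embedding_params, stages_with_embedding_copy):
    """Sum per-pipeline-stage parameter counts without double-counting tied embeddings.

    stage_param_counts: raw parameter count per pipeline stage.
    num_embedding_params: size of the shared word-embedding matrix.
    stages_with_embedding_copy: number of stages holding a copy of the tied embeddings
        (2 when the first and last stages both keep one, 1 when embeddings are untied
        or there is no pipeline parallelism).
    """
    duplicates = max(stages_with_embedding_copy - 1, 0)
    return sum(stage_param_counts) - duplicates * num_embedding_params


# Two-stage pipeline with tied embeddings (5M embedding params counted on both stages):
print(total_unique_params([50_000_000, 48_000_000], 5_000_000, 2))  # 93000000
```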
@@ -176,6 +176,7 @@ def model_provider_func(self, pre_process, post_process):
use_cpu_initialization=self.cfg.get('use_cpu_initialization', False),
hidden_dropout=self.cfg.get('hidden_dropout', 0.1),
attention_dropout=self.cfg.get('attention_dropout', 0.1),
ffn_dropout=self.cfg.get('ffn_dropout', 0.0),
precision=self.cfg.get('precision', 16),
fp32_residual_connection=self.cfg.get('fp32_residual_connection', False),
activations_checkpoint_granularity=self.cfg.get('activations_checkpoint_granularity', None),
@@ -187,8 +188,18 @@
normalization=self.cfg.get('normalization', 'layernorm'),
layernorm_epsilon=self.cfg.get('layernorm_epsilon', 1e-5),
onnx_safe=self.cfg.get('onnx_safe', False),
bias=self.cfg.get('bias', True),
bias_activation_fusion=self.cfg.get('bias_activation_fusion', True),
bias_dropout_add_fusion=self.cfg.get('bias_dropout_add_fusion', True),
activation=self.cfg.get('activation', 'gelu'),
headscale=self.cfg.get('headscale', False),
transformer_block_type=self.cfg.get('transformer_block_type', 'pre_ln'),
openai_gelu=self.cfg.get('openai_gelu', False),
normalize_attention_scores=self.cfg.get('normalize_attention_scores', True),
position_embedding_type=self.cfg.get('position_embedding_type', 'learned_absolute'),
rotary_percentage=self.cfg.get('rotary_percentage', 1.0),
share_embeddings_and_output_weights=self.cfg.get('share_embeddings_and_output_weights', True),
attention_type=self.cfg.get('attention_type', 'multihead'),
masked_softmax_fusion=self.cfg.get('masked_softmax_fusion', True),
gradient_accumulation_fusion=self.cfg.get('gradient_accumulation_fusion', False),
persist_layer_norm=self.cfg.get('persist_layer_norm', False),
@@ -397,7 +408,9 @@ def training_step(self, dataloader_iter, batch_idx):
# so we all-reduce gradients after the pipeline
self.allreduce_gradients() # @sangkug we think this is causing memory to blow up (hurts perf)

if self.cfg.get('pipeline_model_parallel_size', 1) > 1:
if self.cfg.get('pipeline_model_parallel_size', 1) > 1 and self.cfg.get(
'share_embeddings_and_output_weights', True
):
# when using pipeline parallelism the first and last stage must keep embeddings in sync
self.allreduce_first_last_embeddings()

@@ -795,10 +808,12 @@ def setup(self, stage=None):
if isinstance(self.model, list):
for i, module in enumerate(self.model):
parallel_state.set_virtual_pipeline_model_parallel_rank(i)
module.sync_initial_word_embeddings()
if self.cfg.get('share_embeddings_and_output_weights', True):
module.sync_initial_word_embeddings()
parallel_state.set_virtual_pipeline_model_parallel_rank(0)
else:
self.model.sync_initial_word_embeddings()
if self.cfg.get('share_embeddings_and_output_weights', True):
self.model.sync_initial_word_embeddings()

if self.cfg.get('transformer_engine', False):
self.setup_transformer_engine_tp_groups()
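The training_step and setup changes above only synchronize embeddings when they are actually shared: with untied embeddings the first and last pipeline stages own independent weights, so there is nothing to all-reduce or sync. A schematic of the shared-weights sync (illustrative; the function name and process-group argument are stand-ins, the real code goes through parallel_state and word_embeddings_weight()):

```python
import torch
import torch.distributed as dist


def allreduce_tied_embedding_grads(word_embedding_weight, embedding_group, embeddings_are_tied):
    """All-reduce the embedding gradient across the first and last pipeline stages.

    Only needed when the input embedding and output projection share weights, so both
    stages step with identical gradients. `embedding_group` is assumed to be a process
    group containing just those two ranks.
    """
    if not embeddings_are_tied:
        return  # untied embeddings: each stage owns its own weights, nothing to sync
    if word_embedding_weight.grad is not None:
        dist.all_reduce(word_embedding_weight.grad, group=embedding_group)
```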
@@ -431,8 +431,13 @@ def forward(
input_embeds = self.embed_input_inference(input_ids, taskname_ids)
else:
input_embeds = self.embed_input_train(input_ids, taskname_ids)
position_embeddings = self.frozen_model.model.language_model.embedding.position_embeddings(position_ids)
encoder_input = input_embeds + position_embeddings
if hasattr(self.frozen_model.model.language_model.embedding, "position_embeddings"):
position_embeddings = self.frozen_model.model.language_model.embedding.position_embeddings(
position_ids
)
encoder_input = input_embeds + position_embeddings
else:
encoder_input = input_embeds
encoder_input = encoder_input.transpose(0, 1).contiguous()
if self.cfg.get("sequence_parallel", False):
encoder_input = tensor_parallel.mappings.scatter_to_sequence_parallel_region(encoder_input)
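The prompt-learning change above handles models trained with position_embedding_type='rope': such models have no learned position_embeddings table to add, because positions are injected by rotating query/key channels inside attention, and rotary_percentage controls how much of each head dimension is rotated. A compact illustrative sketch of that rotation (not NeMo's RotaryEmbedding implementation):

```python
import torch


def apply_rope(x: torch.Tensor, positions: torch.Tensor, rotary_percentage: float = 1.0):
    """Rotate the first rotary_percentage fraction of each head dim by position-dependent angles.

    x: [seq, heads, head_dim]; positions: [seq]. Illustrative shapes and base frequency.
    """
    head_dim = x.shape[-1]
    rot_dim = int(head_dim * rotary_percentage)
    x_rot, x_pass = x[..., :rot_dim], x[..., rot_dim:]

    inv_freq = 1.0 / (10000 ** (torch.arange(0, rot_dim, 2).float() / rot_dim))
    angles = positions.float()[:, None] * inv_freq[None, :]           # [seq, rot_dim/2]
    cos = torch.cos(angles).repeat_interleave(2, dim=-1)[:, None, :]  # [seq, 1, rot_dim]
    sin = torch.sin(angles).repeat_interleave(2, dim=-1)[:, None, :]

    # Pair up even/odd channels and rotate each pair by its angle.
    x1, x2 = x_rot[..., 0::2], x_rot[..., 1::2]
    rotated = torch.stack((-x2, x1), dim=-1).flatten(-2)
    return torch.cat([x_rot * cos + rotated * sin, x_pass], dim=-1)


q = torch.randn(5, 2, 8)                          # [seq, heads, head_dim]
print(apply_rope(q, torch.arange(5), 0.5).shape)  # torch.Size([5, 2, 8])
```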
76 changes: 55 additions & 21 deletions nemo/collections/nlp/modules/common/megatron/attention.py
@@ -81,6 +81,8 @@ def __init__(
megatron_legacy=False,
bias=True,
headscale=False,
position_embedding_type='learned_absolute',
multi_query_attention=False,
activations_checkpoint_granularity=None,
sequence_parallel=False,
gradient_accumulation_fusion=False,
@@ -92,6 +94,8 @@
self.attention_type = attention_type
self.attn_mask_type = attn_mask_type
self.normalize_attention_scores = normalize_attention_scores
self.position_embedding_type = position_embedding_type
self.multi_query_attention = multi_query_attention

self.megatron_legacy = megatron_legacy

@@ -164,6 +168,7 @@ def __init__(
kv_channels=kv_channels,
masked_softmax_fusion=masked_softmax_fusion,
attention_dropout=attention_dropout,
multi_query_attention=multi_query_attention,
sequence_parallel=sequence_parallel,
normalize_attention_scores=normalize_attention_scores,
)
@@ -651,13 +656,15 @@ def __init__(
attention_dropout=0.1,
sequence_parallel=False,
normalize_attention_scores=True,
multi_query_attention=False,
):

super(CoreAttention, self).__init__()

self.precision = precision
self.fp16 = precision == 16
self.bf16 = precision == 'bf16'
self.multi_query_attention = multi_query_attention

self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
self.attention_softmax_in_fp32 = False
@@ -741,28 +748,55 @@ def forward(
# otherwise, only relative positional embedding takes effect
# value_layer = apply_rotary_pos_emb(value_layer, k_pos_emb)

# [sq, b, np, hn] -> [sq, b * np, hn]
query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1)
# [sk, b, np, hn] -> [sk, b * np, hn]
key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1)

# preallocting input tensor: [b * np, sq, sk]
matmul_input_buffer = torch.empty(
output_size[0] * output_size[1],
output_size[2],
output_size[3],
dtype=query_layer.dtype,
device=torch.cuda.current_device(),
)
if self.multi_query_attention:
# [sq, b, np, hn] -> [b, np * sq, hn]
query_layer = query_layer.permute([1, 2, 0, 3]).reshape(
output_size[0], output_size[1] * output_size[2], -1
)

# Raw attention scores. [b * np, sq, sk]
matmul_result = torch.baddbmm(
matmul_input_buffer,
query_layer.transpose(0, 1), # [b * np, sq, hn]
key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk]
beta=0.0,
alpha=(1.0 / self.norm_factor) if self.normalize_attention_scores else 1.0,
)
# [sk, b, 1, hn] -> [b, hn, sk]
key_layer = key_layer.squeeze(2).permute(1, 2, 0)

# preallocating input tensor: [b * np, sq, sk]
matmul_input_buffer = torch.empty(
output_size[0] * output_size[1],
output_size[2],
output_size[3],
dtype=query_layer.dtype,
device=torch.cuda.current_device(),
)

# Raw attention scores. [b * np, sq, sk]
matmul_result = torch.baddbmm(
matmul_input_buffer,
query_layer, # [b * np, sq, hn]
key_layer, # [b * np, hn, sk]
beta=0.0,
alpha=(1.0 / self.norm_factor),
)
else:
# [sq, b, np, hn] -> [sq, b * np, hn]
query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1)
# [sk, b, np, hn] -> [sk, b * np, hn]
key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1)

# preallocating input tensor: [b * np, sq, sk]
matmul_input_buffer = torch.empty(
output_size[0] * output_size[1],
output_size[2],
output_size[3],
dtype=query_layer.dtype,
device=torch.cuda.current_device(),
)

# Raw attention scores. [b * np, sq, sk]
matmul_result = torch.baddbmm(
matmul_input_buffer,
query_layer.transpose(0, 1), # [b * np, sq, hn]
key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk]
beta=0.0,
alpha=(1.0 / self.norm_factor) if self.normalize_attention_scores else 1.0,
)

# change view to [b, np, sq, sk]
attention_scores = matmul_result.view(*output_size)
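With the new multi_query_attention flag, CoreAttention shares a single key/value head across all np query heads, which is why the key arrives as [sk, b, 1, hn] and the scores reduce to one batched matmul per batch element. A shape-level sketch of that score computation (plain tensors and torch.bmm standing in for the preallocated baddbmm above):

```python
import torch

sq, sk, b, np_, hn = 6, 6, 2, 4, 8       # seq lengths, batch, query heads, head dim
query = torch.randn(sq, b, np_, hn)      # [sq, b, np, hn]
key = torch.randn(sk, b, 1, hn)          # [sk, b, 1, hn] - one shared key head

# [sq, b, np, hn] -> [b, np * sq, hn]
q = query.permute(1, 2, 0, 3).reshape(b, np_ * sq, hn)
# [sk, b, 1, hn] -> [b, hn, sk]
k = key.squeeze(2).permute(1, 2, 0)

scores = torch.bmm(q, k) / hn ** 0.5     # [b, np * sq, sk]
attention_scores = scores.view(b, np_, sq, sk)
print(attention_scores.shape)            # torch.Size([2, 4, 6, 6])
```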