t5-rpe-fix targeting r1.10.0; raise exception for PP>2. (#4469)
Signed-off-by: Hoo Chang Shin <[email protected]>

Co-authored-by: Hoo Chang Shin <[email protected]>
khcs and khcs authored Jun 29, 2022
1 parent d985ebc commit 6219718
Showing 5 changed files with 51 additions and 21 deletions.
3 changes: 3 additions & 0 deletions examples/nlp/language_modeling/conf/megatron_bart_config.yaml
@@ -63,6 +63,9 @@ model:
init_method_std: 0.02 # Standard deviation of the zero mean normal distribution used for weight initialization.')
hidden_dropout: 0.1 # Dropout probability for hidden state transformer.
attention_dropout: 0.1 # Dropout probability in the attention layer.
position_embedding_type: 'learned_absolute' # Position embedding type. Options ['learned_absolute', 'relative']
relative_attention_num_buckets: 32 # Relative position number of buckets for computing the bias
relative_attention_max_distance: 128 # max_distance to keep relative distance in the attention_num_buckets.
kv_channels: null # Projection weights dimension in multi-head attention. Set to hidden_size // num_attention_heads if null
apply_query_key_layer_scaling: True # scale Q * K^T by 1 / layer-number.
layernorm_epsilon: 1e-5
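
The three new keys expose the T5-style relative position bias settings next to the existing learned absolute embeddings; the same keys are added to the UL2 and NMT configs below. As a rough illustration, a minimal sketch of overriding them programmatically with OmegaConf (which NeMo configs build on); the values here are only examples:

    from omegaconf import OmegaConf

    cfg = OmegaConf.load("examples/nlp/language_modeling/conf/megatron_bart_config.yaml")
    cfg.model.position_embedding_type = "relative"        # switch away from 'learned_absolute'
    cfg.model.relative_attention_num_buckets = 32         # size of the learned per-head bias table
    cfg.model.relative_attention_max_distance = 128       # offsets beyond this share the last bucket
    print(OmegaConf.to_yaml(cfg.model))
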
3 changes: 3 additions & 0 deletions examples/nlp/language_modeling/conf/megatron_ul2_config.yaml
@@ -62,6 +62,9 @@ model:
init_method_std: 0.02 # Standard deviation of the zero mean normal distribution used for weight initialization.')
hidden_dropout: 0.1 # Dropout probability for hidden state transformer.
attention_dropout: 0.1 # Dropout probability in the attention layer.
position_embedding_type: 'learned_absolute' # Position embedding type. Options ['learned_absolute', 'relative']
relative_attention_num_buckets: 32 # Relative position number of buckets for computing the bias
relative_attention_max_distance: 128 # max_distance to keep relative distance in the attention_num_buckets.
kv_channels: null # Projection weights dimension in multi-head attention. Set to hidden_size // num_attention_heads if null
apply_query_key_layer_scaling: True # scale Q * K^T by 1 / layer-number.
layernorm_epsilon: 1e-5
3 changes: 3 additions & 0 deletions examples/nlp/machine_translation/conf/aayn_base_megatron.yaml
@@ -73,6 +73,9 @@ model:
init_method_std: 0.02 # Standard deviation of the zero mean normal distribution used for weight initialization.')
hidden_dropout: 0.1 # Dropout probability for hidden state transformer.
attention_dropout: 0.1 # Dropout probability in the attention layer.
position_embedding_type: 'learned_absolute' # Position embedding type. Options ['learned_absolute', 'relative']
relative_attention_num_buckets: 32 # Relative position number of buckets for computing the bias
relative_attention_max_distance: 128 # max_distance to keep relative distance in the attention_num_buckets.
kv_channels: null # Projection weights dimension in multi-head attention. Set to hidden_size // num_attention_heads if null
apply_query_key_layer_scaling: True # scale Q * K^T by 1 / layer-number.
layernorm_epsilon: 1e-5
8 changes: 6 additions & 2 deletions nemo/collections/nlp/models/language_modeling/megatron_lm_encoder_decoder_model.py
@@ -70,6 +70,10 @@ class MegatronLMEncoderDecoderModel(MegatronBaseModel):

def __init__(self, cfg: DictConfig, trainer: Trainer):
super().__init__(cfg, trainer=trainer)
if cfg.get('pipeline_model_parallel_size', 1) > 2 and self.cfg.get('position_embedding_type') == 'relative':
raise ValueError(
"pipeline_model_parallel_size cannot be > 2 with position_embedding_type == relative at the moment."
)
if cfg.get('pipeline_model_parallel_size', 1) > 1:
if cfg.get('pipeline_model_parallel_split_rank', 0) <= 0:
raise ValueError(
@@ -394,7 +398,7 @@ def allreduce_word_and_position_embeddings(self):
and parallel_state.get_pipeline_model_parallel_world_size() > 1
and parallel_state.get_pipeline_model_parallel_split_rank() is not None
):
if self.enc_dec_model.position_embedding_type != 'relative':
if self.cfg.get('position_embedding_type') != 'relative':
position_embeddings_weight = self.enc_dec_model.position_embeddings_weight()
if self.megatron_amp_o2:
grad = position_embeddings_weight.main_grad
@@ -707,7 +711,7 @@ def setup(self, stage=None):
# when using pipeline model parallel the final stage need to initialize word embeddings
if parallel_state.get_pipeline_model_parallel_world_size() > 1:
self.enc_dec_model.sync_initial_word_embeddings()
if self.enc_dec_model.position_embedding_type != 'relative':
if self.cfg.get('position_embedding_type') != 'relative':
self.enc_dec_model.sync_initial_position_embeddings()

def setup_training_data(self, cfg):
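
Two changes land in the encoder-decoder model: relative position embeddings are now rejected when pipeline_model_parallel_size is greater than 2, and the checks that skip the absolute position-embedding allreduce and sync for relative embeddings read the setting from self.cfg instead of from the submodule. A hypothetical standalone version of the new guard (the config values are invented for the example):

    from omegaconf import OmegaConf

    cfg = OmegaConf.create({"pipeline_model_parallel_size": 4, "position_embedding_type": "relative"})
    # Mirrors the guard added in __init__ above: PP sizes above 2 are not supported
    # together with relative position embeddings in this release.
    if cfg.get("pipeline_model_parallel_size", 1) > 2 and cfg.get("position_embedding_type") == "relative":
        raise ValueError(
            "pipeline_model_parallel_size cannot be > 2 with position_embedding_type == relative at the moment."
        )
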
55 changes: 36 additions & 19 deletions nemo/collections/nlp/modules/common/megatron/transformer.py
@@ -287,6 +287,7 @@ def __init__(
megatron_legacy=False,
bias=True,
headscale=False,
has_relative_attention_bias=False,
):
super(ParallelAttention, self).__init__()

@@ -299,6 +300,7 @@ def __init__(
self.attn_mask_type = attn_mask_type
self.megatron_legacy = megatron_legacy
self.headscale = headscale
self.has_relative_attention_bias = has_relative_attention_bias

if kv_channels is None:
assert (
@@ -382,7 +384,7 @@ def __init__(
self.position_embedding_type = position_embedding_type
self.relative_attention_num_buckets = relative_attention_num_buckets
self.relative_attention_max_distance = relative_attention_max_distance
if self.position_embedding_type == 'relative':
if self.position_embedding_type == 'relative' and self.has_relative_attention_bias:
self.relative_attention_bias = torch.nn.Embedding(
relative_attention_num_buckets, self.num_attention_heads_per_partition
).to(torch.cuda.current_device())
@@ -498,7 +500,7 @@ def compute_bias(self, query_length, key_length):
relative_position = memory_position - context_position # shape (query_length, key_length)
relative_position_bucket = self._relative_position_bucket(
relative_position, # shape (query_length, key_length)
bidirectional=(self.layer_type != LayerType.decoder), # (not self.is_decoder),
bidirectional=(self.attention_type != AttnMaskType.causal), # self.is_decoder and self_attention.
num_buckets=self.relative_attention_num_buckets,
max_distance=self.relative_attention_max_distance,
)
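
For reference, compute_bias follows the T5 bucketing scheme: each (query, key) offset is mapped to one of relative_attention_num_buckets buckets, with half the buckets reserved for small exact offsets and the rest spaced logarithmically up to relative_attention_max_distance, and a per-head bias is then looked up from the relative_attention_bias embedding. A self-contained sketch mirroring the HuggingFace T5 logic this code is modeled on (not copied from the diff; names and shapes are illustrative):

    import math
    import torch

    def relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
        # Map signed (key - query) offsets to bucket ids, T5-style.
        buckets = torch.zeros_like(relative_position)
        if bidirectional:
            num_buckets //= 2
            buckets = buckets + (relative_position > 0).long() * num_buckets
            relative_position = relative_position.abs()
        else:
            relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
        max_exact = num_buckets // 2                      # small offsets keep their own bucket
        is_small = relative_position < max_exact
        # larger offsets are binned logarithmically up to max_distance
        large = max_exact + (
            torch.log(relative_position.float().clamp(min=1) / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).long()
        large = torch.min(large, torch.full_like(large, num_buckets - 1))
        return buckets + torch.where(is_small, relative_position, large)

    query_length, key_length, num_heads = 8, 8, 12
    context_position = torch.arange(query_length)[:, None]
    memory_position = torch.arange(key_length)[None, :]
    bucket_ids = relative_position_bucket(memory_position - context_position)   # (q, k)
    bias_table = torch.nn.Embedding(32, num_heads)                              # stands in for relative_attention_bias
    position_bias = bias_table(bucket_ids).permute(2, 0, 1).unsqueeze(0)        # (1, heads, q, k), added to attention scores
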
@@ -683,9 +685,12 @@ def forward(

if position_bias is None:
if self.position_embedding_type == 'relative':
position_bias = self.compute_bias(real_seq_length, key_length)
else:
pass # HuggingFace implementation initialize position_bias to zero when not using
if self.has_relative_attention_bias:
position_bias = self.compute_bias(real_seq_length, key_length)
elif attention_mask is not None:
position_bias = torch.zeros_like(attention_mask).to(torch.cuda.current_device())
else:
position_bias = torch.zeros(1, key_length, key_length).to(torch.cuda.current_device())

# if key and values are already calculated
# we want only the last query position bias
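
When a layer has no relative_attention_bias table of its own, the position_bias it is handed is simply added to the raw attention scores, so a zero tensor is a neutral placeholder. Roughly, under invented shapes (the real code uses a [b, np, sq, sk]-style layout inside ParallelAttention):

    import math
    import torch

    batch, heads, q_len, k_len, d_head = 2, 12, 8, 8, 64
    q = torch.randn(batch, heads, q_len, d_head)
    k = torch.randn(batch, heads, k_len, d_head)
    position_bias = torch.zeros(1, heads, q_len, k_len)    # neutral stand-in when no bias table exists

    scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(d_head)
    scores = scores + position_bias                         # the relative bias (or zeros) enters here
    probs = torch.softmax(scores, dim=-1)
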
@@ -975,6 +980,7 @@ def __init__(
normalization='layernorm',
transformer_block_type='pre_ln',
headscale=False,
has_relative_attention_bias=False,
):
super(ParallelTransformerLayer_, self).__init__()

Expand Down Expand Up @@ -1038,6 +1044,7 @@ def __init__(
megatron_legacy=megatron_legacy,
bias=bias,
headscale=headscale,
has_relative_attention_bias=has_relative_attention_bias,
)
# Normformer normalization
if transformer_block_type == 'normformer':
@@ -1079,6 +1086,7 @@ def __init__(
megatron_legacy=megatron_legacy,
bias=bias,
headscale=headscale,
has_relative_attention_bias=False,
)
# Normformer normalization
if transformer_block_type == 'normformer':
@@ -1207,14 +1215,6 @@ def forward(
# Post-LN: x -> MHA -> Residual -> LN -> MLP -> Residual -> LN
# Normformer: x -> LN -> MHA -> LN -> Residual -> MLP (w/LN) -> Residual

if type(hidden_states) is tuple:
if len(hidden_states) == 2:
hidden_states, position_bias = hidden_states
elif len(hidden_states) == 3:
hidden_states, position_bias, encoder_decoder_position_bias = hidden_states
else:
raise IndexError('Hidden_states needs to be tuple containing 2 or 3 elements.')

residual = hidden_states
# Layer norm at the beginning of the transformer layer.
if self.transformer_block_type in ['pre_ln', 'normformer']:
@@ -1228,6 +1228,7 @@ def forward(
set_inference_key_value_memory=set_inference_key_value_memory,
inference_max_sequence_len=inference_max_sequence_len,
rotary_pos_emb=self_attention_pos_emb,
position_bias=position_bias,
)

if get_key_value:
@@ -1484,7 +1485,7 @@ def __init__(

self.num_layers = self.get_num_layers(num_layers)
# Transformer layers.
def build_layer(layer_number):
def build_layer(layer_number, has_relative_attention_bias=False):
if isinstance(layer_type, list):
lt = layer_type[layer_number - 1]
else:
@@ -1522,6 +1523,7 @@ def build_layer(layer_number):
normalization=normalization,
transformer_block_type=transformer_block_type,
headscale=headscale,
has_relative_attention_bias=has_relative_attention_bias,
)

if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None:
@@ -1558,7 +1560,14 @@ def build_layer(layer_number):
else:
offset = parallel_state.get_pipeline_model_parallel_rank() * self.num_layers

self.layers = torch.nn.ModuleList([build_layer(i + 1 + offset) for i in range(self.num_layers)])
self.layers = torch.nn.ModuleList(
[
build_layer(
i + 1 + offset, has_relative_attention_bias=(i == 0) and parallel_state.is_pipeline_first_stage()
)
for i in range(self.num_layers)
]
)

if self.post_process:
# Final layer norm before output.
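
With this change only the first layer of the first pipeline stage owns a relative_attention_bias table (and cross-attention never gets one); the bias that layer computes is returned alongside the hidden states and reused by the remaining layers, which is why the checkpointing wrapper below unpacks a tuple. A toy sketch of that sharing pattern (ToyLayer and all shapes are invented for illustration):

    import torch

    class ToyLayer(torch.nn.Module):
        """Stand-in for a transformer layer; only the first layer owns a bias table."""

        def __init__(self, num_heads, num_buckets, has_relative_attention_bias):
            super().__init__()
            self.bias_table = (
                torch.nn.Embedding(num_buckets, num_heads) if has_relative_attention_bias else None
            )

        def forward(self, hidden_states, position_bias=None):
            if position_bias is None and self.bias_table is not None:
                seq_len = hidden_states.shape[1]
                bucket_ids = torch.zeros(seq_len, seq_len, dtype=torch.long)  # the real code buckets by offset
                position_bias = self.bias_table(bucket_ids).permute(2, 0, 1).unsqueeze(0)
            # ... attention would add position_bias to its scores here ...
            return hidden_states, position_bias

    layers = torch.nn.ModuleList(
        [ToyLayer(num_heads=8, num_buckets=32, has_relative_attention_bias=(i == 0)) for i in range(4)]
    )
    hidden, bias = torch.randn(2, 16, 64), None
    for layer in layers:
        hidden, bias = layer(hidden, position_bias=bias)  # layer 0 builds the bias, later layers reuse it
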
@@ -1624,9 +1633,16 @@ def custom_forward(*inputs):
encoder_output,
enc_dec_attn_mask,
rotary_pos_emb,
position_bias,
encoder_decoder_position_bias,
position_bias=position_bias,
encoder_decoder_position_bias=encoder_decoder_position_bias,
)
if type(x_) is tuple:
if len(x_) == 2:
x_, position_bias = x_
elif len(x_) == 3:
x_, position_bias, encoder_decoder_position_bias = x_
else:
raise IndexError('Hidden_states (x_) needs to be tuple containing 2 or 3 elements.')
return x_

return custom_forward
@@ -1704,9 +1720,10 @@ def forward(
inference_max_sequence_len=None,
rotary_pos_emb=None, # list of positional embedding tensors, first one self attention, second one and third one are for cross attention (q, k)
retrieved_emb=None, # tensor of retrieved embedding of shape [b, k, r, n, d]
position_bias=None,
encoder_decoder_position_bias=None,
):
position_bias = None
encoder_decoder_position_bias = None

# Checks.
if inference_max_sequence_len:
assert self.activations_checkpoint_method is None, 'inference does not work with activation checkpointing'
