t5-rpe-fix targeting r1.10.0; raise exception for PP>2. #4469

Merged: 1 commit, Jun 29, 2022
3 changes: 3 additions & 0 deletions examples/nlp/language_modeling/conf/megatron_bart_config.yaml
@@ -63,6 +63,9 @@ model:
init_method_std: 0.02 # Standard deviation of the zero mean normal distribution used for weight initialization.')
hidden_dropout: 0.1 # Dropout probability for hidden state transformer.
attention_dropout: 0.1 # Dropout probability in the attention layer.
position_embedding_type: 'learned_absolute' # Position embedding type. Options ['learned_absolute', 'relative']
relative_attention_num_buckets: 32 # Relative position number of buckets for computing the bias
relative_attention_max_distance: 128 # max_distance to keep relative distance in the attention_num_buckets.
kv_channels: null # Projection weights dimension in multi-head attention. Set to hidden_size // num_attention_heads if null
apply_query_key_layer_scaling: True # scale Q * K^T by 1 / layer-number.
layernorm_epsilon: 1e-5
3 changes: 3 additions & 0 deletions examples/nlp/language_modeling/conf/megatron_ul2_config.yaml
@@ -62,6 +62,9 @@ model:
init_method_std: 0.02 # Standard deviation of the zero mean normal distribution used for weight initialization.')
hidden_dropout: 0.1 # Dropout probability for hidden state transformer.
attention_dropout: 0.1 # Dropout probability in the attention layer.
position_embedding_type: 'learned_absolute' # Position embedding type. Options ['learned_absolute', 'relative']
relative_attention_num_buckets: 32 # Relative position number of buckets for computing the bias
relative_attention_max_distance: 128 # max_distance to keep relative distance in the attention_num_buckets.
kv_channels: null # Projection weights dimension in multi-head attention. Set to hidden_size // num_attention_heads if null
apply_query_key_layer_scaling: True # scale Q * K^T by 1 / layer-number.
layernorm_epsilon: 1e-5
3 changes: 3 additions & 0 deletions examples/nlp/machine_translation/conf/aayn_base_megatron.yaml
@@ -73,6 +73,9 @@ model:
init_method_std: 0.02 # Standard deviation of the zero mean normal distribution used for weight initialization.')
hidden_dropout: 0.1 # Dropout probability for hidden state transformer.
attention_dropout: 0.1 # Dropout probability in the attention layer.
position_embedding_type: 'learned_absolute' # Position embedding type. Options ['learned_absolute', 'relative']
relative_attention_num_buckets: 32 # Relative position number of buckets for computing the bias
relative_attention_max_distance: 128 # max_distance to keep relative distance in the attention_num_buckets.
kv_channels: null # Projection weights dimension in multi-head attention. Set to hidden_size // num_attention_heads if null
apply_query_key_layer_scaling: True # scale Q * K^T by 1 / layer-number.
layernorm_epsilon: 1e-5
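The same three keys are added to all three example configs. As a quick illustration of how they fit together, here is a hypothetical OmegaConf snippet (key names are taken from the diff above; the surrounding model config is assumed) that switches one of these models to relative position embeddings:

from omegaconf import OmegaConf

# Hypothetical override of the three new keys; everything else in the example configs is unchanged.
cfg = OmegaConf.create(
    {
        "model": {
            "position_embedding_type": "relative",    # the examples default to 'learned_absolute'
            "relative_attention_num_buckets": 32,     # buckets used when discretizing relative offsets
            "relative_attention_max_distance": 128,   # offsets beyond this share the outermost bucket
        }
    }
)
print(OmegaConf.to_yaml(cfg))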
@@ -70,6 +70,10 @@ class MegatronLMEncoderDecoderModel(MegatronBaseModel):

def __init__(self, cfg: DictConfig, trainer: Trainer):
super().__init__(cfg, trainer=trainer)
if cfg.get('pipeline_model_parallel_size', 1) > 2 and self.cfg.get('position_embedding_type') == 'relative':
raise ValueError(
"pipeline_model_parallel_size cannot be > 2 with position_embedding_type == relative at the moment."
)
if cfg.get('pipeline_model_parallel_size', 1) > 1:
if cfg.get('pipeline_model_parallel_split_rank', 0) <= 0:
raise ValueError(
@@ -394,7 +398,7 @@ def allreduce_word_and_position_embeddings(self):
and parallel_state.get_pipeline_model_parallel_world_size() > 1
and parallel_state.get_pipeline_model_parallel_split_rank() is not None
):
if self.enc_dec_model.position_embedding_type != 'relative':
if self.cfg.get('position_embedding_type') != 'relative':
position_embeddings_weight = self.enc_dec_model.position_embeddings_weight()
if self.megatron_amp_o2:
grad = position_embeddings_weight.main_grad
@@ -707,7 +711,7 @@ def setup(self, stage=None):
# when using pipeline model parallel the final stage need to initialize word embeddings
if parallel_state.get_pipeline_model_parallel_world_size() > 1:
self.enc_dec_model.sync_initial_word_embeddings()
if self.enc_dec_model.position_embedding_type != 'relative':
if self.cfg.get('position_embedding_type') != 'relative':
self.enc_dec_model.sync_initial_position_embeddings()

def setup_training_data(self, cfg):
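The constructor change above is the guard named in the PR title: relative position embeddings are currently only supported with pipeline_model_parallel_size <= 2, so larger values now fail fast instead of producing a broken model. A minimal standalone sketch of that check (the function name and plain-dict config are assumptions for illustration, not NeMo code):

def check_relative_pe_pipeline_support(cfg: dict) -> None:
    """Fail fast when relative position embeddings are combined with PP > 2 (sketch)."""
    pp_size = cfg.get("pipeline_model_parallel_size", 1)
    if pp_size > 2 and cfg.get("position_embedding_type") == "relative":
        raise ValueError(
            "pipeline_model_parallel_size cannot be > 2 with "
            "position_embedding_type == relative at the moment."
        )

check_relative_pe_pipeline_support(
    {"pipeline_model_parallel_size": 4, "position_embedding_type": "relative"}
)  # raises ValueError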
55 changes: 36 additions & 19 deletions nemo/collections/nlp/modules/common/megatron/transformer.py
@@ -287,6 +287,7 @@ def __init__(
megatron_legacy=False,
bias=True,
headscale=False,
has_relative_attention_bias=False,
):
super(ParallelAttention, self).__init__()

@@ -299,6 +300,7 @@ def __init__(
self.attn_mask_type = attn_mask_type
self.megatron_legacy = megatron_legacy
self.headscale = headscale
self.has_relative_attention_bias = has_relative_attention_bias

if kv_channels is None:
assert (
@@ -382,7 +384,7 @@ def __init__(
self.position_embedding_type = position_embedding_type
self.relative_attention_num_buckets = relative_attention_num_buckets
self.relative_attention_max_distance = relative_attention_max_distance
if self.position_embedding_type == 'relative':
if self.position_embedding_type == 'relative' and self.has_relative_attention_bias:
self.relative_attention_bias = torch.nn.Embedding(
relative_attention_num_buckets, self.num_attention_heads_per_partition
).to(torch.cuda.current_device())
@@ -498,7 +500,7 @@ def compute_bias(self, query_length, key_length):
relative_position = memory_position - context_position # shape (query_length, key_length)
relative_position_bucket = self._relative_position_bucket(
relative_position, # shape (query_length, key_length)
bidirectional=(self.layer_type != LayerType.decoder), # (not self.is_decoder),
bidirectional=(self.attention_type != AttnMaskType.causal), # self.is_decoder and self_attention.
Review comment (Contributor): Can we change this comment? I think it should be encoder + self-attention?

num_buckets=self.relative_attention_num_buckets,
max_distance=self.relative_attention_max_distance,
)
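For context, compute_bias() turns every (query, key) offset into a bucket index and looks that bucket up in the relative_attention_bias embedding, producing one value per attention head. The bucketing follows the standard T5 / HuggingFace scheme; a self-contained sketch of that reference algorithm is below (it is not the NeMo method itself):

import math

import torch

def relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
    """Standard T5-style bucketing of relative offsets (reference sketch, not the NeMo method)."""
    relative_buckets = 0
    if bidirectional:
        # half the buckets encode "key after query", the other half "key before query"
        num_buckets //= 2
        relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
        relative_position = torch.abs(relative_position)
    else:
        # causal attention only looks backwards, so clamp positive offsets to zero
        relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))

    # half of the remaining buckets are exact offsets, the rest grow logarithmically up to max_distance
    max_exact = num_buckets // 2
    is_small = relative_position < max_exact
    relative_position_if_large = max_exact + (
        torch.log(relative_position.float() / max_exact)
        / math.log(max_distance / max_exact)
        * (num_buckets - max_exact)
    ).to(torch.long)
    relative_position_if_large = torch.min(
        relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
    )
    return relative_buckets + torch.where(is_small, relative_position, relative_position_if_large)

# (query, key) offset matrix for a length-6 sequence; bucket ids land in [0, num_buckets)
offsets = torch.arange(6)[None, :] - torch.arange(6)[:, None]
print(relative_position_bucket(offsets))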
@@ -683,9 +685,12 @@ def forward(

if position_bias is None:
if self.position_embedding_type == 'relative':
position_bias = self.compute_bias(real_seq_length, key_length)
else:
pass # HuggingFace implementation initialize position_bias to zero when not using
if self.has_relative_attention_bias:
position_bias = self.compute_bias(real_seq_length, key_length)
elif attention_mask is not None:
position_bias = torch.zeros_like(attention_mask).to(torch.cuda.current_device())
else:
position_bias = torch.zeros(1, key_length, key_length).to(torch.cuda.current_device())

# if key and values are already calculated
# we want only the last query position bias
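With the new has_relative_attention_bias flag, only the layer that owns the bias table ever calls compute_bias(); every other layer either reuses the position_bias handed to it or falls back to an all-zero bias, matching the HuggingFace behaviour mentioned in the removed comment. A small illustrative helper capturing that decision (names and shapes are assumptions, not NeMo code):

import torch

def bootstrap_position_bias(position_bias, has_relative_attention_bias, compute_bias, attention_mask, key_length):
    """Pick the position bias a single attention layer should use (illustrative sketch)."""
    if position_bias is not None:
        return position_bias                        # bias threaded in from the owning layer
    if has_relative_attention_bias:
        return compute_bias()                       # only the owning layer builds the table
    if attention_mask is not None:
        return torch.zeros_like(attention_mask)     # zero bias shaped like the mask
    return torch.zeros(1, key_length, key_length)   # last-resort zero bias

mask = torch.ones(2, 1, 16, 16)
print(bootstrap_position_bias(None, False, None, mask, 16).shape)  # torch.Size([2, 1, 16, 16])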
@@ -975,6 +980,7 @@ def __init__(
normalization='layernorm',
transformer_block_type='pre_ln',
headscale=False,
has_relative_attention_bias=False,
):
super(ParallelTransformerLayer_, self).__init__()

@@ -1038,6 +1044,7 @@ def __init__(
megatron_legacy=megatron_legacy,
bias=bias,
headscale=headscale,
has_relative_attention_bias=has_relative_attention_bias,
)
# Normformer normalization
if transformer_block_type == 'normformer':
@@ -1079,6 +1086,7 @@ def __init__(
megatron_legacy=megatron_legacy,
bias=bias,
headscale=headscale,
has_relative_attention_bias=False,
)
# Normformer normalization
if transformer_block_type == 'normformer':
@@ -1207,14 +1215,6 @@ def forward(
# Post-LN: x -> MHA -> Residual -> LN -> MLP -> Residual -> LN
# Normformer: x -> LN -> MHA -> LN -> Residual -> MLP (w/LN) -> Residual

if type(hidden_states) is tuple:
if len(hidden_states) == 2:
hidden_states, position_bias = hidden_states
elif len(hidden_states) == 3:
hidden_states, position_bias, encoder_decoder_position_bias = hidden_states
else:
raise IndexError('Hidden_states needs to be tuple containing 2 or 3 elements.')

Review comments on lines -1210 to -1217:
Collaborator: Can you delete this?
Contributor: I think he moved it further down.

residual = hidden_states
# Layer norm at the beginning of the transformer layer.
if self.transformer_block_type in ['pre_ln', 'normformer']:
@@ -1228,6 +1228,7 @@ def forward(
set_inference_key_value_memory=set_inference_key_value_memory,
inference_max_sequence_len=inference_max_sequence_len,
rotary_pos_emb=self_attention_pos_emb,
position_bias=position_bias,
)

if get_key_value:
@@ -1484,7 +1485,7 @@ def __init__(

self.num_layers = self.get_num_layers(num_layers)
# Transformer layers.
def build_layer(layer_number):
def build_layer(layer_number, has_relative_attention_bias=False):
if isinstance(layer_type, list):
lt = layer_type[layer_number - 1]
else:
Expand Down Expand Up @@ -1522,6 +1523,7 @@ def build_layer(layer_number):
normalization=normalization,
transformer_block_type=transformer_block_type,
headscale=headscale,
has_relative_attention_bias=has_relative_attention_bias,
)

if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None:
@@ -1558,7 +1560,14 @@ def build_layer(layer_number):
else:
offset = parallel_state.get_pipeline_model_parallel_rank() * self.num_layers

self.layers = torch.nn.ModuleList([build_layer(i + 1 + offset) for i in range(self.num_layers)])
self.layers = torch.nn.ModuleList(
[
build_layer(
i + 1 + offset, has_relative_attention_bias=(i == 0) and parallel_state.is_pipeline_first_stage()
)
for i in range(self.num_layers)
]
)

if self.post_process:
# Final layer norm before output.
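Put together, only the first layer of the first pipeline stage is built with has_relative_attention_bias=True, and the bias it produces is reused by every later layer. A toy stack illustrating that wiring (a simplified, made-up module; the real attention math is omitted):

import torch
import torch.nn as nn

class ToySelfAttention(nn.Module):
    """Toy stand-in for ParallelAttention: only owns a bias table when asked to (sketch)."""

    def __init__(self, num_heads, num_buckets, has_relative_attention_bias):
        super().__init__()
        self.has_relative_attention_bias = has_relative_attention_bias
        if has_relative_attention_bias:
            self.relative_attention_bias = nn.Embedding(num_buckets, num_heads)

    def forward(self, hidden_states, position_bias=None):
        seq_len = hidden_states.size(1)
        if position_bias is None:
            if self.has_relative_attention_bias:
                # real code would bucket relative offsets first; all-zero bucket ids keep the toy short
                buckets = torch.zeros(seq_len, seq_len, dtype=torch.long)
                position_bias = self.relative_attention_bias(buckets).permute(2, 0, 1).unsqueeze(0)
            else:
                position_bias = torch.zeros(1, 1, seq_len, seq_len)
        # a real layer would add position_bias to the attention scores here
        return hidden_states, position_bias

# mirrors build_layer(...): only layer 0 of the first pipeline stage owns the bias table
layers = nn.ModuleList([ToySelfAttention(8, 32, has_relative_attention_bias=(i == 0)) for i in range(4)])

x, bias = torch.randn(2, 16, 64), None
for layer in layers:
    x, bias = layer(x, position_bias=bias)  # layer 0 creates the bias, later layers reuse it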
@@ -1624,9 +1633,16 @@ def custom_forward(*inputs):
encoder_output,
enc_dec_attn_mask,
rotary_pos_emb,
position_bias,
encoder_decoder_position_bias,
position_bias=position_bias,
encoder_decoder_position_bias=encoder_decoder_position_bias,
)
if type(x_) is tuple:
if len(x_) == 2:
x_, position_bias = x_
elif len(x_) == 3:
x_, position_bias, encoder_decoder_position_bias = x_
else:
raise IndexError('Hidden_states (x_) needs to be tuple containing 2 or 3 elements.')
return x_

return custom_forward
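Inside the activation-checkpointing path, a layer call can now return a tuple carrying the bias, so custom_forward unpacks it and passes only the hidden states across the checkpoint boundary. A minimal sketch of that pattern with torch.utils.checkpoint (an illustrative helper, not the NeMo function):

import torch
from torch.utils.checkpoint import checkpoint

def run_layer_checkpointed(layer, hidden_states, position_bias=None):
    """Checkpoint a layer whose output may be (hidden_states, position_bias) (sketch)."""
    def custom_forward(x):
        out = layer(x, position_bias=position_bias)
        if isinstance(out, tuple):      # layers with relative attention return a tuple
            out = out[0]                # only the hidden states cross the checkpoint boundary
        return out
    return checkpoint(custom_forward, hidden_states)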
@@ -1704,9 +1720,10 @@ def forward(
inference_max_sequence_len=None,
rotary_pos_emb=None, # list of positional embedding tensors, first one self attention, second one and third one are for cross attention (q, k)
retrieved_emb=None, # tensor of retrieved embedding of shape [b, k, r, n, d]
position_bias=None,
encoder_decoder_position_bias=None,
):
position_bias = None
encoder_decoder_position_bias = None

# Checks.
if inference_max_sequence_len:
assert self.activations_checkpoint_method is None, 'inference does not work with activation checkpointing'