From 8448ee3e83b4a72758e031f2b4eab73515a739a4 Mon Sep 17 00:00:00 2001
From: Yi Dong <43824965+yidong72@users.noreply.github.com>
Date: Mon, 13 Mar 2023 20:58:37 -0400
Subject: [PATCH] Add sequence parallel support to Rope positional embedding
 (#6178)

* support sp

Signed-off-by: Yi Dong

* use parallel_state

Signed-off-by: Yi Dong

---------

Signed-off-by: Yi Dong
Signed-off-by: hsiehjackson
---
 .../modules/common/megatron/language_model.py | 29 ++++++++++---------
 1 file changed, 16 insertions(+), 13 deletions(-)

diff --git a/nemo/collections/nlp/modules/common/megatron/language_model.py b/nemo/collections/nlp/modules/common/megatron/language_model.py
index 59918e0b9637..b75b8cac4518 100755
--- a/nemo/collections/nlp/modules/common/megatron/language_model.py
+++ b/nemo/collections/nlp/modules/common/megatron/language_model.py
@@ -31,7 +31,7 @@
 )
 
 try:
-    from apex.transformer import tensor_parallel
+    from apex.transformer import parallel_state, tensor_parallel
     from apex.transformer.enums import AttnMaskType
 
     HAVE_APEX = True
@@ -495,6 +495,7 @@ def __init__(
 
         self.output_layer_init_method = output_layer_init_method
         self.position_embedding_type = position_embedding_type
         self.share_embeddings_and_output_weights = share_embeddings_and_output_weights
+        self.sequence_parallel = sequence_parallel
 
         if kv_channels is None:
@@ -688,8 +689,6 @@ def forward(
         else:
             pass
 
-        # encoder_input: [s, b, h]
-        # enc_attn_mask: [1, 1, s, s]
 
         rotary_pos_emb = None
 
@@ -698,17 +697,21 @@
             if inference_max_sequence_len is not None:
                 rotary_pos_emb = self.rotary_pos_emb(inference_max_sequence_len)
             elif self.encoder.input_tensor is not None:
-                rotary_pos_emb = self.rotary_pos_emb(self.encoder.input_tensor.size(0))
+                if self.sequence_parallel:
+                    rotary_pos_emb = self.rotary_pos_emb(
+                        self.encoder.input_tensor.size(0) * parallel_state.get_tensor_model_parallel_world_size()
+                    )
+                else:
+                    rotary_pos_emb = self.rotary_pos_emb(self.encoder.input_tensor.size(0))
             else:
-                rotary_pos_emb = self.rotary_pos_emb(encoder_input.size(0))
-        elif self.position_embedding_type == 'alibi':
-            enc_seq_length = enc_input_ids.size(1)
-            encoder_self_attention_relative_position_bias = self.encoder_relative_position_embedding(
-                query_seq_length=enc_seq_length, key_seq_length=enc_seq_length,
-            )
-        elif self.position_embedding_type == 'kerple':
-            encoder_self_attention_relative_position_bias = self.encoder_relative_position_embedding
-
+                if self.sequence_parallel:
+                    rotary_pos_emb = self.rotary_pos_emb(
+                        encoder_input.size(0) * parallel_state.get_tensor_model_parallel_world_size()
+                    )
+                else:
+                    rotary_pos_emb = self.rotary_pos_emb(encoder_input.size(0))
+        else:
+            rotary_pos_emb = None
         # encoder.
         if enc_hidden_states is None:
             encoder_output = self.encoder(
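
Reviewer note (not part of the patch): below is a minimal, self-contained sketch of the sizing logic this change implements. With sequence parallelism enabled, the [s, b, h] activations entering the encoder are already split along the sequence dimension across the tensor-model-parallel ranks, so tensor.size(0) on any rank is only the local chunk length; multiplying by parallel_state.get_tensor_model_parallel_world_size() recovers the full sequence length before building the rotary embedding. The names full_seq_len, tp_world_size, and local_chunk are illustrative stand-ins, not identifiers from language_model.py.

# Illustrative sketch only (assumes PyTorch is available); mirrors the patched sizing logic.
import torch

full_seq_len, batch, hidden = 16, 2, 8   # hypothetical shapes
tp_world_size = 4                        # assumed tensor-model-parallel group size

# Under sequence parallelism, each rank holds a [s / tp, b, h] slice of the activations.
local_chunk = torch.empty(full_seq_len // tp_world_size, batch, hidden)

# Sequence parallel on: scale the local length back up to the full sequence length.
rope_len_sp = local_chunk.size(0) * tp_world_size
# Sequence parallel off: the tensor already carries the full sequence length.
rope_len_no_sp = torch.empty(full_seq_len, batch, hidden).size(0)

assert rope_len_sp == rope_len_no_sp == full_seq_len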