Add sequence parallel support to Rope positional embedding (NVIDIA#6178)
* support sp

Signed-off-by: Yi Dong <[email protected]>

* use parallel_state

Signed-off-by: Yi Dong <[email protected]>

---------

Signed-off-by: Yi Dong <[email protected]>
Signed-off-by: hsiehjackson <[email protected]>
yidong72 authored and hsiehjackson committed Jun 2, 2023
1 parent 4b80d6a commit 8448ee3
Showing 1 changed file with 16 additions and 13 deletions.
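
Editorial note: with sequence parallelism enabled, the activations entering the language model are already sharded along the sequence dimension across the tensor-parallel ranks, so encoder_input.size(0) covers only seq_len / TP tokens (for example, 512 of a 2048-token sequence when TP = 4). The rotary position embedding still has to be computed for the full sequence length, which is why the diff below multiplies the local size by parallel_state.get_tensor_model_parallel_world_size(). The snippet that follows is a minimal, self-contained sketch of that adjustment; ToyRotaryEmbedding and build_rope are hypothetical stand-ins, not NeMo's actual classes or API.

import torch


class ToyRotaryEmbedding(torch.nn.Module):
    """Builds the table of rotary-embedding angles for a given sequence length."""

    def __init__(self, dim: int, base: int = 10000):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)

    def forward(self, max_seq_len: int) -> torch.Tensor:
        positions = torch.arange(max_seq_len, device=self.inv_freq.device).float()
        freqs = torch.einsum('i,j->ij', positions, self.inv_freq)  # [seq_len, dim/2]
        emb = torch.cat((freqs, freqs), dim=-1)                    # [seq_len, dim]
        return emb[:, None, None, :]                               # [seq_len, 1, 1, dim]


def build_rope(rotary, local_hidden, sequence_parallel, tp_world_size):
    # local_hidden is laid out [s, b, h]; under sequence parallelism each
    # tensor-parallel rank holds only s / tp_world_size of the sequence, so the
    # RoPE table must be sized for the full sequence, not the local shard.
    local_seq_len = local_hidden.size(0)
    if sequence_parallel:
        return rotary(local_seq_len * tp_world_size)
    return rotary(local_seq_len)


if __name__ == '__main__':
    rotary = ToyRotaryEmbedding(dim=64)
    # e.g. a 2048-token sequence split over 4 tensor-parallel ranks -> 512 local tokens
    local_hidden = torch.zeros(512, 2, 1024)
    emb = build_rope(rotary, local_hidden, sequence_parallel=True, tp_world_size=4)
    print(emb.shape)  # torch.Size([2048, 1, 1, 64])

Scaling the local length by the tensor-parallel world size (rather than gathering the full sequence) keeps the fix local to where the RoPE table is sized, which is the approach the diff below takes.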
nemo/collections/nlp/modules/common/megatron/language_model.py (16 additions, 13 deletions)
@@ -31,7 +31,7 @@
 )

 try:
-    from apex.transformer import tensor_parallel
+    from apex.transformer import parallel_state, tensor_parallel
     from apex.transformer.enums import AttnMaskType

     HAVE_APEX = True
@@ -495,6 +495,7 @@ def __init__(
         self.output_layer_init_method = output_layer_init_method
         self.position_embedding_type = position_embedding_type
         self.share_embeddings_and_output_weights = share_embeddings_and_output_weights
+        self.sequence_parallel = sequence_parallel

         if kv_channels is None:

@@ -688,8 +689,6 @@ def forward(
         else:
             pass

-        # encoder_input: [s, b, h]
-
         # enc_attn_mask: [1, 1, s, s]

         rotary_pos_emb = None
@@ -698,17 +697,21 @@
             if inference_max_sequence_len is not None:
                 rotary_pos_emb = self.rotary_pos_emb(inference_max_sequence_len)
             elif self.encoder.input_tensor is not None:
-                rotary_pos_emb = self.rotary_pos_emb(self.encoder.input_tensor.size(0))
+                if self.sequence_parallel:
+                    rotary_pos_emb = self.rotary_pos_emb(
+                        self.encoder.input_tensor.size(0) * parallel_state.get_tensor_model_parallel_world_size()
+                    )
+                else:
+                    rotary_pos_emb = self.rotary_pos_emb(self.encoder.input_tensor.size(0))
             else:
-                rotary_pos_emb = self.rotary_pos_emb(encoder_input.size(0))
-        elif self.position_embedding_type == 'alibi':
-            enc_seq_length = enc_input_ids.size(1)
-            encoder_self_attention_relative_position_bias = self.encoder_relative_position_embedding(
-                query_seq_length=enc_seq_length, key_seq_length=enc_seq_length,
-            )
-        elif self.position_embedding_type == 'kerple':
-            encoder_self_attention_relative_position_bias = self.encoder_relative_position_embedding
-
+                if self.sequence_parallel:
+                    rotary_pos_emb = self.rotary_pos_emb(
+                        encoder_input.size(0) * parallel_state.get_tensor_model_parallel_world_size()
+                    )
+                else:
+                    rotary_pos_emb = self.rotary_pos_emb(encoder_input.size(0))
+        else:
+            rotary_pos_emb = None
         # encoder.
         if enc_hidden_states is None:
             encoder_output = self.encoder(
