Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add sequence parallel support to Rope positional embedding #6178

Merged
merged 2 commits into from
Mar 14, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 14 additions & 5 deletions nemo/collections/nlp/modules/common/megatron/language_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
)

try:
from apex.transformer import tensor_parallel
from apex.transformer import parallel_state, tensor_parallel
from apex.transformer.enums import AttnMaskType

HAVE_APEX = True
Expand Down Expand Up @@ -491,6 +491,7 @@ def __init__(
self.output_layer_init_method = output_layer_init_method
self.position_embedding_type = position_embedding_type
self.share_embeddings_and_output_weights = share_embeddings_and_output_weights
self.sequence_parallel = sequence_parallel

if kv_channels is None:

Expand Down Expand Up @@ -664,17 +665,25 @@ def forward(
else:
pass

# encoder_input: [s, b, h]

# enc_attn_mask: [1, 1, s, s]

if self.position_embedding_type == 'rope':
if inference_max_sequence_len is not None:
rotary_pos_emb = self.rotary_pos_emb(inference_max_sequence_len)
elif self.encoder.input_tensor is not None:
rotary_pos_emb = self.rotary_pos_emb(self.encoder.input_tensor.size(0))
if self.sequence_parallel:
rotary_pos_emb = self.rotary_pos_emb(
self.encoder.input_tensor.size(0) * parallel_state.get_tensor_model_parallel_world_size()
)
else:
rotary_pos_emb = self.rotary_pos_emb(self.encoder.input_tensor.size(0))
else:
rotary_pos_emb = self.rotary_pos_emb(encoder_input.size(0))
if self.sequence_parallel:
rotary_pos_emb = self.rotary_pos_emb(
encoder_input.size(0) * parallel_state.get_tensor_model_parallel_world_size()
)
else:
rotary_pos_emb = self.rotary_pos_emb(encoder_input.size(0))
else:
rotary_pos_emb = None
# encoder.
Expand Down