From 20bb03034a5e7efd9426c678001b3c9e677cbc6d Mon Sep 17 00:00:00 2001
From: Yi Dong
Date: Mon, 13 Mar 2023 03:25:15 +0000
Subject: [PATCH 1/2] support sp

Signed-off-by: Yi Dong
---
 .../modules/common/megatron/language_model.py | 18 ++++++++++++++----
 1 file changed, 14 insertions(+), 4 deletions(-)

diff --git a/nemo/collections/nlp/modules/common/megatron/language_model.py b/nemo/collections/nlp/modules/common/megatron/language_model.py
index 1ef5fe0060e4..c7134b0625ac 100755
--- a/nemo/collections/nlp/modules/common/megatron/language_model.py
+++ b/nemo/collections/nlp/modules/common/megatron/language_model.py
@@ -25,6 +25,7 @@
     init_method_normal,
     scaled_init_method_normal,
 )
+from nemo.utils.app_state import AppState
 
 try:
     from apex.transformer import tensor_parallel
@@ -491,6 +492,7 @@ def __init__(
         self.output_layer_init_method = output_layer_init_method
         self.position_embedding_type = position_embedding_type
         self.share_embeddings_and_output_weights = share_embeddings_and_output_weights
+        self.sequence_parallel = sequence_parallel
 
         if kv_channels is None:
 
@@ -664,17 +666,25 @@ def forward(
         else:
             pass
 
-        # encoder_input: [s, b, h]
-        # enc_attn_mask: [1, 1, s, s]
 
         if self.position_embedding_type == 'rope':
             if inference_max_sequence_len is not None:
                 rotary_pos_emb = self.rotary_pos_emb(inference_max_sequence_len)
             elif self.encoder.input_tensor is not None:
-                rotary_pos_emb = self.rotary_pos_emb(self.encoder.input_tensor.size(0))
+                if self.sequence_parallel:
+                    app_state = AppState()
+                    rotary_pos_emb = self.rotary_pos_emb(
+                        self.encoder.input_tensor.size(0) * app_state.tensor_model_parallel_size
+                    )
+                else:
+                    rotary_pos_emb = self.rotary_pos_emb(self.encoder.input_tensor.size(0))
             else:
-                rotary_pos_emb = self.rotary_pos_emb(encoder_input.size(0))
+                if self.sequence_parallel:
+                    app_state = AppState()
+                    rotary_pos_emb = self.rotary_pos_emb(encoder_input.size(0) * app_state.tensor_model_parallel_size)
+                else:
+                    rotary_pos_emb = self.rotary_pos_emb(encoder_input.size(0))
         else:
             rotary_pos_emb = None
 
         # encoder.
From 98c5b17b0029addb0a6606bdd027fa21858412f2 Mon Sep 17 00:00:00 2001
From: Yi Dong
Date: Mon, 13 Mar 2023 13:32:28 +0000
Subject: [PATCH 2/2] use parallel_state

Signed-off-by: Yi Dong
---
 .../nlp/modules/common/megatron/language_model.py | 11 +++++------
 1 file changed, 5 insertions(+), 6 deletions(-)

diff --git a/nemo/collections/nlp/modules/common/megatron/language_model.py b/nemo/collections/nlp/modules/common/megatron/language_model.py
index c7134b0625ac..902d571ae170 100755
--- a/nemo/collections/nlp/modules/common/megatron/language_model.py
+++ b/nemo/collections/nlp/modules/common/megatron/language_model.py
@@ -25,10 +25,9 @@
     init_method_normal,
     scaled_init_method_normal,
 )
-from nemo.utils.app_state import AppState
 
 try:
-    from apex.transformer import tensor_parallel
+    from apex.transformer import parallel_state, tensor_parallel
     from apex.transformer.enums import AttnMaskType
 
     HAVE_APEX = True
@@ -673,16 +672,16 @@ def forward(
                 rotary_pos_emb = self.rotary_pos_emb(inference_max_sequence_len)
             elif self.encoder.input_tensor is not None:
                 if self.sequence_parallel:
-                    app_state = AppState()
                     rotary_pos_emb = self.rotary_pos_emb(
-                        self.encoder.input_tensor.size(0) * app_state.tensor_model_parallel_size
+                        self.encoder.input_tensor.size(0) * parallel_state.get_tensor_model_parallel_world_size()
                     )
                 else:
                     rotary_pos_emb = self.rotary_pos_emb(self.encoder.input_tensor.size(0))
             else:
                 if self.sequence_parallel:
-                    app_state = AppState()
-                    rotary_pos_emb = self.rotary_pos_emb(encoder_input.size(0) * app_state.tensor_model_parallel_size)
+                    rotary_pos_emb = self.rotary_pos_emb(
+                        encoder_input.size(0) * parallel_state.get_tensor_model_parallel_world_size()
+                    )
                 else:
                     rotary_pos_emb = self.rotary_pos_emb(encoder_input.size(0))
         else:
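Why both patches multiply by the tensor-model-parallel world size: with sequence
parallelism enabled, the activation entering the encoder is split along the sequence
dimension across tensor-model-parallel ranks, so encoder_input.size(0) (or
self.encoder.input_tensor.size(0)) is only the per-rank shard length, while the rotary
position embeddings still have to cover the full sequence. Below is a minimal standalone
sketch of that length calculation (plain PyTorch, not NeMo code; the integer
tp_world_size stands in for parallel_state.get_tensor_model_parallel_world_size(), which
needs an initialized distributed setup).

import torch


def rotary_embedding_length(local_input: torch.Tensor, sequence_parallel: bool, tp_world_size: int) -> int:
    """Sequence length the rotary embedding table must cover.

    Under sequence parallelism, dim 0 of the local activation is seq_len / tp_world_size,
    so the global length is recovered by multiplying back by the world size.
    """
    local_len = local_input.size(0)
    return local_len * tp_world_size if sequence_parallel else local_len


# Example: a global sequence of 2048 tokens sharded across 4 tensor-parallel ranks.
shard = torch.empty(2048 // 4, 2, 1024)  # local [s/tp, b, h] shard on one rank
assert rotary_embedding_length(shard, sequence_parallel=True, tp_world_size=4) == 2048
assert rotary_embedding_length(shard, sequence_parallel=False, tp_world_size=4) == 512

The only difference between the two patches is where that world size comes from: the
first reads it from NeMo's AppState, the second from apex's parallel_state, presumably
to query the parallel topology from the same place the surrounding Megatron code path
already uses.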