TP fixes.
Signed-off-by: Alexandros Koumparoulis <[email protected]>
akoumpa committed Feb 1, 2024
1 parent 166a116 commit e89c6e3
Showing 3 changed files with 3 additions and 25 deletions.
examples/nlp/language_modeling/tuning/megatron_gpt_sft.py (1 change: 1 addition, 0 deletions)
@@ -73,6 +73,7 @@ def _modify_config(gpt_cfg, cfg, add_cfg_to_tree=False):
gpt_cfg.ffn_dropout = cfg.model.ffn_dropout
gpt_cfg.use_flash_attention = cfg.model.get('use_flash_attention', False)
gpt_cfg.tensor_model_parallel_size = cfg.model.get('tensor_model_parallel_size', 1)
gpt_cfg.expert_model_parallel_size = cfg.model.get('expert_model_parallel_size', 1)
gpt_cfg.pipeline_model_parallel_size = cfg.model.get('pipeline_model_parallel_size', 1)
gpt_cfg.pipeline_model_parallel_split_rank = cfg.model.get('pipeline_model_parallel_split_rank', 0)

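For context, the added expert_model_parallel_size line follows the same pattern as the surrounding assignments in _modify_config: read an optional key from the fine-tuning config and fall back to a default of 1 when it is absent. A minimal sketch of that OmegaConf .get behavior (the config values below are illustrative, not taken from the commit):

from omegaconf import OmegaConf

# Illustrative fine-tuning config; only one parallelism key is set.
cfg = OmegaConf.create({"model": {"tensor_model_parallel_size": 2}})
gpt_cfg = OmegaConf.create({})

# A present key is copied through; a missing key falls back to the default of 1,
# mirroring how _modify_config populates the restored model's config.
gpt_cfg.tensor_model_parallel_size = cfg.model.get("tensor_model_parallel_size", 1)
gpt_cfg.expert_model_parallel_size = cfg.model.get("expert_model_parallel_size", 1)

print(gpt_cfg.tensor_model_parallel_size)  # 2
print(gpt_cfg.expert_model_parallel_size)  # 1 (default)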
nemo/collections/nlp/modules/common/megatron/megatron_init.py (25 changes: 2 additions, 23 deletions)
@@ -313,29 +313,8 @@ def fake_initialize_model_parallel(
logging.info(f'All tensor model parallel group ranks: {all_tensor_model_parallel_group_ranks}')
logging.info(f'Rank {rank} has tensor model parallel rank: {tensor_model_parallel_rank}')

# Build expert-parallelism groups
all_tensor_expert_parallel_group_ranks = []
expert_model_parallel_group = None
tensor_and_data_group_size: int = tensor_model_parallel_size * data_parallel_size
num_tensor_and_data_groups: int = world_size // tensor_and_data_group_size
tensor_and_expert_group_size: int = tensor_model_parallel_size * expert_model_parallel_size
num_expert_groups: int = data_parallel_size // expert_model_parallel_size
for i in range(num_tensor_and_data_groups):
for j in range(num_expert_groups):
start_rank = i * tensor_and_data_group_size + j * tensor_and_expert_group_size
end_rank = i * tensor_and_data_group_size + (j + 1) * tensor_and_expert_group_size
ranks = list(range(start_rank, end_rank))
all_tensor_expert_parallel_group_ranks.append(ranks)
if rank in ranks:
expert_model_parallel_group = ranks
logging.info(f'Rank {rank} has expert model parallel group: {ranks}')

if expert_model_parallel_group is not None:
expert_model_parallel_rank = expert_model_parallel_group.index(rank)
else:
expert_model_parallel_rank = 0
logging.info(f'All tensor expert parallel group ranks: {all_tensor_expert_parallel_group_ranks}')
logging.info(f'Rank {rank} has expert model parallel rank: {expert_model_parallel_rank}')
# EP not supported; set rank to zero.
expert_model_parallel_rank = 0

# Build the pipeline model-parallel groups and embedding groups
# (first and last rank in each pipeline model-parallel group).
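For context, the block removed above partitioned the world into tensor-and-expert groups by slicing contiguous rank ranges inside each tensor-and-data block and recording the group that contains the current rank. A standalone sketch of that grouping arithmetic, using the same variable names as the removed code but hypothetical sizes (not part of the commit):

def expert_group_ranks(world_size, tensor_model_parallel_size,
                       data_parallel_size, expert_model_parallel_size):
    # Reproduces the rank slicing performed by the removed double loop.
    tensor_and_data_group_size = tensor_model_parallel_size * data_parallel_size
    num_tensor_and_data_groups = world_size // tensor_and_data_group_size
    tensor_and_expert_group_size = tensor_model_parallel_size * expert_model_parallel_size
    num_expert_groups = data_parallel_size // expert_model_parallel_size
    groups = []
    for i in range(num_tensor_and_data_groups):
        for j in range(num_expert_groups):
            start_rank = i * tensor_and_data_group_size + j * tensor_and_expert_group_size
            end_rank = i * tensor_and_data_group_size + (j + 1) * tensor_and_expert_group_size
            groups.append(list(range(start_rank, end_rank)))
    return groups

# Hypothetical layout: 8 ranks, TP=2, DP=4, EP=2 -> two groups of four consecutive ranks.
print(expert_group_ranks(8, 2, 4, 2))  # [[0, 1, 2, 3], [4, 5, 6, 7]]

After this commit the fake initialization skips that computation and simply reports an expert-parallel rank of zero.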
nemo/collections/nlp/parts/nlp_overrides.py (2 changes: 0 additions, 2 deletions)
@@ -133,9 +133,7 @@ def init_model_parallel(sharp: bool, nccl_communicator_config_path: str = None)
# assert that fake tp and pp rank match after model parallel init
assert app_state.tensor_model_parallel_rank == parallel_state.get_tensor_model_parallel_rank()
assert app_state.pipeline_model_parallel_rank == parallel_state.get_pipeline_model_parallel_rank()
assert app_state.expert_model_parallel_rank == parallel_state.get_expert_model_parallel_rank()

# TODO(akoumparouli): EP
app_state.tensor_model_parallel_group = parallel_state.get_tensor_model_parallel_group()
app_state.data_parallel_group = parallel_state.get_data_parallel_group()
app_state.data_parallel_rank = parallel_state.get_data_parallel_rank()
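For context, the dropped assertion followed the same pattern as the two kept above it: ranks precomputed by the fake initialization and stored on app_state are cross-checked against what megatron-core's parallel_state reports once real model parallelism is set up, with the expert-parallel comparison deferred per the TODO. A minimal sketch of that check, assuming a distributed run where parallel_state has already been initialized (verify_fake_ranks is an illustrative helper, not part of the commit):

from megatron.core import parallel_state

def verify_fake_ranks(app_state):
    # The precomputed ("fake") ranks must agree with megatron-core after real init.
    assert app_state.tensor_model_parallel_rank == parallel_state.get_tensor_model_parallel_rank()
    assert app_state.pipeline_model_parallel_rank == parallel_state.get_pipeline_model_parallel_rank()
    # The expert-parallel rank is intentionally not checked here (see the TODO above).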
