diff --git a/megatron/training/checkpointing.py b/megatron/training/checkpointing.py
index d0f2945c684..6d44dbd8fa2 100644
--- a/megatron/training/checkpointing.py
+++ b/megatron/training/checkpointing.py
@@ -360,27 +360,8 @@ def get_rng_state(ckpt_format: str, tp_group: torch.distributed.ProcessGroup, pp
         pp_size = get_pg_size(pp_group)
         tp_rank = get_pg_rank(tp_group)
         tp_size = get_pg_size(tp_group)
-        ep_size = mpu.get_expert_model_parallel_world_size()
-
-        if ep_size > 1:
-            # Shard RNG by PP, TP, DP when using expert parallelism.
-            dp_rank = mpu.get_data_parallel_rank(with_context_parallel=True)
-            dp_size = mpu.get_data_parallel_world_size(with_context_parallel=True)
-            rng_state_list = ShardedObject(
-                'rng_state',
-                rng_state_list,
-                (pp_size, tp_size, dp_size),
-                (pp_rank, tp_rank, dp_rank),
-                replica_id=0,
-            )
-        else:
-            rng_state_list = ShardedObject(
-                'rng_state',
-                rng_state_list,
-                (pp_size, tp_size),
-                (pp_rank, tp_rank),
-                replica_id=mpu.get_data_parallel_rank(with_context_parallel=True),
-            )
+        rng_state_list = ShardedObject('rng_state', rng_state_list, (pp_size, tp_size), (pp_rank, tp_rank),
+                                       replica_id=mpu.get_data_parallel_rank(with_context_parallel=True))
     elif ckpt_format == "fsdp_dtensor":
         pp_rank = mpu.get_pipeline_model_parallel_rank()
         tp_rank = mpu.get_tensor_model_parallel_rank()
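As a reading aid for the hunk above, here is a minimal, self-contained sketch of the mapping the simplified call expresses: each (pipeline, tensor) coordinate owns one RNG shard, and the data-parallel rank is passed as replica_id, so DP copies other than rank 0 are treated as redundant replicas rather than separate shards. This is not Megatron-LM code; ShardedObjectSketch and shard_rng_state are hypothetical names standing in for the real ShardedObject API, and the field semantics are inferred from the diff.

# Hedged sketch only; names and semantics are assumptions, not the library's API.
from dataclasses import dataclass
from typing import Any, Tuple


@dataclass
class ShardedObjectSketch:
    """Stand-in for the distributed-checkpointing ShardedObject (hypothetical)."""
    key: str
    data: Any
    global_shape: Tuple[int, ...]   # full grid of shards, e.g. (pp_size, tp_size)
    global_offset: Tuple[int, ...]  # this rank's coordinate in that grid
    replica_id: int                 # equal copies share a shard; 0 marks the primary


def shard_rng_state(rng_state_list, pp_rank, pp_size, tp_rank, tp_size, dp_rank):
    # One shard per (pp, tp) coordinate; every data-parallel rank holds a replica.
    return ShardedObjectSketch(
        key="rng_state",
        data=rng_state_list,
        global_shape=(pp_size, tp_size),
        global_offset=(pp_rank, tp_rank),
        replica_id=dp_rank,
    )


if __name__ == "__main__":
    # Example: PP=2, TP=2, this rank sits at (pp=1, tp=0) and has DP rank 3,
    # so it is a non-primary replica of the shard at coordinate (1, 0).
    shard = shard_rng_state(["rng-blob"], pp_rank=1, pp_size=2,
                            tp_rank=0, tp_size=2, dp_rank=3)
    print(shard)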