Merged
55 changes: 44 additions & 11 deletions src/megatron/bridge/training/checkpointing.py
@@ -49,7 +49,7 @@
from megatron.core.process_groups_config import ProcessGroupCollection
from megatron.core.rerun_state_machine import get_rerun_state_machine
from megatron.core.transformer import MegatronModule
from megatron.core.utils import unwrap_model
from megatron.core.utils import get_pg_size, unwrap_model
from modelopt.torch.opt.plugins import (
restore_modelopt_state,
save_modelopt_state,
@@ -386,12 +386,18 @@ def get_rng_state(
Optionally gathers states across data parallel ranks.
Returns format depends on checkpoint format.

For torch_dist format with Expert Parallelism (EP > 1), RNG states are sharded
by (PP, TP, DP) dimensions since different EP ranks may have different RNG states.
Without EP, states are sharded by (PP, TP) with DP rank as replica_id.

Args:
data_parallel_random_init: If True, gathers RNG states across data parallel ranks.
ckpt_format: The checkpoint format being used.
pg_collection: Process group collection for accessing parallel ranks/sizes.

Returns:
For torch_dist: A ShardedObject containing the RNG states.
For torch_dist: A ShardedObject containing the RNG states, sharded by
(PP, TP, DP) when EP > 1, or (PP, TP) with DP as replica_id otherwise.
For fsdp_dtensor: A dict mapping (pp_rank, tp_rank) to RNG state lists.
"""
rng_state = {
@@ -414,13 +420,30 @@
pp_size = pg_collection.pp.size()
tp_rank = pg_collection.tp.rank()
tp_size = pg_collection.tp.size()
rng_state_list = ShardedObject(
"rng_state",
rng_state_list,
(pp_size, tp_size),
(pp_rank, tp_rank),
replica_id=pg_collection.dp_cp.rank(),
)
ep_size = get_pg_size(pg_collection.ep)

if ep_size > 1:
# Shard RNG by PP, TP, DP when using expert parallelism.
# With EP, different EP ranks within the same DP group may have different
# RNG states for their respective experts, so DP rank must be part of
# the sharding dimensions rather than replica_id.
dp_rank = pg_collection.dp_cp.rank()
dp_size = pg_collection.dp_cp.size()
rng_state_list = ShardedObject(
"rng_state",
rng_state_list,
(pp_size, tp_size, dp_size),
(pp_rank, tp_rank, dp_rank),
replica_id=0,
)
else:
rng_state_list = ShardedObject(
"rng_state",
rng_state_list,
(pp_size, tp_size),
(pp_rank, tp_rank),
replica_id=pg_collection.dp_cp.rank(),
)
elif ckpt_format == "fsdp_dtensor":
pp_rank = pg_collection.pp.rank()
tp_rank = pg_collection.tp.rank()
@@ -1679,6 +1702,8 @@ def _load_checkpoint_from_path(
# Load RNG states
if not release and not cfg.checkpoint.finetune and cfg.checkpoint.load_rng and not ignore_rng_state:
try:
cuda_rng_tracker = tensor_parallel.get_cuda_rng_tracker()
graph_safe_rng = tensor_parallel.is_graph_safe_cuda_rng_tracker(cuda_rng_tracker)
if "rng_state" in state_dict:
if ckpt_format == "fsdp_dtensor":
# FSDP DTensor format: {(pp_rank, tp_rank): rng_state_list}
@@ -1709,15 +1734,23 @@
torch.cuda.set_rng_state(rng_state["cuda_rng_state"])
if not rng_state["rng_tracker_states"]:
raise KeyError
tensor_parallel.get_cuda_rng_tracker().set_states(rng_state["rng_tracker_states"])
rng_tracker_states = {
k: tensor_parallel.convert_cuda_rng_state(v, to_graphable=graph_safe_rng)
for k, v in rng_state["rng_tracker_states"].items()
}
cuda_rng_tracker.set_states(rng_tracker_states)
else: # backward compatibility
random.setstate(state_dict["random_rng_state"])
np.random.set_state(state_dict["np_rng_state"])
torch.set_rng_state(state_dict["torch_rng_state"])
torch.cuda.set_rng_state(state_dict["cuda_rng_state"])
if not state_dict["rng_tracker_states"]:
raise KeyError
tensor_parallel.get_cuda_rng_tracker().set_states(state_dict["rng_tracker_states"])
rng_tracker_states = {
k: tensor_parallel.convert_cuda_rng_state(v, to_graphable=graph_safe_rng)
for k, v in state_dict["rng_tracker_states"].items()
}
cuda_rng_tracker.set_states(rng_tracker_states)
except KeyError:
print_rank_0(
"Unable to load rng state from checkpoint {}. "
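To make the new layout concrete, here is a minimal, framework-free sketch of the sharding decision added to get_rng_state() above. It is plain Python with hypothetical rank/size numbers (borrowed from the unit tests below); the real code builds a megatron.core ShardedObject from pg_collection and get_pg_size(pg_collection.ep).

def rng_shard_layout(pp_rank, pp_size, tp_rank, tp_size, dp_rank, dp_size, ep_size):
    """Return (global_shape, global_offset, replica_id) for the rng_state shard."""
    if ep_size > 1:
        # With expert parallelism, EP ranks within one DP group can hold different
        # expert RNG states, so DP becomes a sharding dimension and no rank is
        # treated as a redundant replica.
        return (pp_size, tp_size, dp_size), (pp_rank, tp_rank, dp_rank), 0
    # Without EP, every DP rank holds the same RNG state, so DP rank only selects
    # which replica of the (PP, TP) shard is written.
    return (pp_size, tp_size), (pp_rank, tp_rank), dp_rank

# Hypothetical topology: PP=2, TP=4, DP=6, on the rank with coordinates (1, 3, 5).
print(rng_shard_layout(1, 2, 3, 4, 5, 6, ep_size=8))  # ((2, 4, 6), (1, 3, 5), 0)
print(rng_shard_layout(1, 2, 3, 4, 5, 6, ep_size=1))  # ((2, 4), (1, 3), 5)

Treating DP as a sharding dimension rather than a replica means each DP rank's RNG state is saved and restored individually instead of being deduplicated to a single copy, which is the point of the change above.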
158 changes: 158 additions & 0 deletions tests/unit_tests/training/test_checkpointing.py
@@ -275,6 +275,7 @@ def test_get_rng_state(self, mock_random, mock_np, mock_torch, mock_cuda, mock_d
mock_pg_collection.tp.size.return_value = 1
mock_pg_collection.dp_cp.rank.return_value = 0
mock_pg_collection.dp_cp.size.return_value = 1
mock_pg_collection.ep.size.return_value = 1 # EP = 1 (no expert parallelism)

result = get_rng_state(
data_parallel_random_init=False, ckpt_format="torch_dist", pg_collection=mock_pg_collection
@@ -290,6 +291,163 @@
assert rng_state["np_rng_state"] == "np_state"
assert rng_state["rng_tracker_states"] == "tracker_states"

@patch("megatron.bridge.training.checkpointing.get_pg_size")
@patch("megatron.bridge.training.checkpointing.tensor_parallel")
@patch("torch.distributed.is_initialized")
@patch("torch.cuda.get_rng_state")
@patch("torch.get_rng_state")
@patch("numpy.random.get_state")
@patch("random.getstate")
def test_get_rng_state_with_expert_parallelism(
self, mock_random, mock_np, mock_torch, mock_cuda, mock_dist_init, mock_tp, mock_get_pg_size
):
"""Test RNG state collection with Expert Parallelism (EP > 1).

When EP > 1, RNG state should be sharded by (PP, TP, DP) dimensions
with replica_id=0, since different EP ranks may have different RNG states.
"""
# Setup mocks
mock_dist_init.return_value = False
mock_random.return_value = "random_state"
mock_np.return_value = "np_state"
mock_torch.return_value = torch.tensor([1, 2, 3])
mock_cuda.return_value = torch.tensor([4, 5, 6])
mock_tracker = Mock()
mock_tracker.get_states.return_value = "tracker_states"
mock_tp.get_cuda_rng_tracker.return_value = mock_tracker

# Mock get_pg_size to return EP size > 1
mock_get_pg_size.return_value = 8 # EP > 1

# Create mock pg_collection with EP > 1 configuration
mock_pg_collection = Mock()
mock_pg_collection.pp.rank.return_value = 1
mock_pg_collection.pp.size.return_value = 2
mock_pg_collection.tp.rank.return_value = 3
mock_pg_collection.tp.size.return_value = 4
mock_pg_collection.dp_cp.rank.return_value = 5
mock_pg_collection.dp_cp.size.return_value = 6

result = get_rng_state(
data_parallel_random_init=False, ckpt_format="torch_dist", pg_collection=mock_pg_collection
)

# Verify get_pg_size was called with pg_collection.ep
mock_get_pg_size.assert_called_once_with(mock_pg_collection.ep)

# Verify the result is a ShardedObject with correct sharding
assert result.key == "rng_state"
# Shape should be (pp_size, tp_size, dp_size) when EP > 1
assert result.global_shape == (2, 4, 6)
# Global offset should include dp_rank
assert result.global_offset == (1, 3, 5)
# replica_id should be 0 (not dp_rank) when EP > 1
assert result.replica_id == 0

@patch("megatron.bridge.training.checkpointing.get_pg_size")
@patch("megatron.bridge.training.checkpointing.tensor_parallel")
@patch("torch.distributed.is_initialized")
@patch("torch.cuda.get_rng_state")
@patch("torch.get_rng_state")
@patch("numpy.random.get_state")
@patch("random.getstate")
def test_get_rng_state_without_expert_parallelism(
self, mock_random, mock_np, mock_torch, mock_cuda, mock_dist_init, mock_tp, mock_get_pg_size
):
"""Test RNG state collection without Expert Parallelism (EP = 1).

When EP = 1, RNG state should be sharded by (PP, TP) dimensions
with replica_id=dp_rank (standard behavior).
"""
# Setup mocks
mock_dist_init.return_value = False
mock_random.return_value = "random_state"
mock_np.return_value = "np_state"
mock_torch.return_value = torch.tensor([1, 2, 3])
mock_cuda.return_value = torch.tensor([4, 5, 6])
mock_tracker = Mock()
mock_tracker.get_states.return_value = "tracker_states"
mock_tp.get_cuda_rng_tracker.return_value = mock_tracker

# Mock get_pg_size to return EP size = 1
mock_get_pg_size.return_value = 1 # EP = 1

# Create mock pg_collection with EP = 1 configuration
mock_pg_collection = Mock()
mock_pg_collection.pp.rank.return_value = 1
mock_pg_collection.pp.size.return_value = 2
mock_pg_collection.tp.rank.return_value = 3
mock_pg_collection.tp.size.return_value = 4
mock_pg_collection.dp_cp.rank.return_value = 5
mock_pg_collection.dp_cp.size.return_value = 1

result = get_rng_state(
data_parallel_random_init=False, ckpt_format="torch_dist", pg_collection=mock_pg_collection
)

# Verify get_pg_size was called with pg_collection.ep
mock_get_pg_size.assert_called_once_with(mock_pg_collection.ep)

# Verify the result is a ShardedObject with correct sharding
assert result.key == "rng_state"
# Shape should be (pp_size, tp_size) when EP = 1
assert result.global_shape == (2, 4)
# Global offset should NOT include dp_rank
assert result.global_offset == (1, 3)
# replica_id should be dp_rank when EP = 1
assert result.replica_id == 5

@patch("megatron.bridge.training.checkpointing.get_pg_size")
@patch("megatron.bridge.training.checkpointing.tensor_parallel")
@patch("torch.distributed.is_initialized")
@patch("torch.cuda.get_rng_state")
@patch("torch.get_rng_state")
@patch("numpy.random.get_state")
@patch("random.getstate")
def test_get_rng_state_with_none_ep_group(
self, mock_random, mock_np, mock_torch, mock_cuda, mock_dist_init, mock_tp, mock_get_pg_size
):
"""Test RNG state collection when EP group is None (not initialized).

When pg_collection.ep is None, get_pg_size returns 1, so this should
behave the same as EP=1 (sharded by PP, TP with replica_id=dp_rank).
"""
# Setup mocks
mock_dist_init.return_value = False
mock_random.return_value = "random_state"
mock_np.return_value = "np_state"
mock_torch.return_value = torch.tensor([1, 2, 3])
mock_cuda.return_value = torch.tensor([4, 5, 6])
mock_tracker = Mock()
mock_tracker.get_states.return_value = "tracker_states"
mock_tp.get_cuda_rng_tracker.return_value = mock_tracker

# Mock get_pg_size to return 1 (what it does for None groups)
mock_get_pg_size.return_value = 1

# Create mock pg_collection with ep=None
mock_pg_collection = Mock()
mock_pg_collection.pp.rank.return_value = 0
mock_pg_collection.pp.size.return_value = 2
mock_pg_collection.tp.rank.return_value = 1
mock_pg_collection.tp.size.return_value = 4
mock_pg_collection.dp_cp.rank.return_value = 3
mock_pg_collection.dp_cp.size.return_value = 1
mock_pg_collection.ep = None # Explicitly None

result = get_rng_state(
data_parallel_random_init=False, ckpt_format="torch_dist", pg_collection=mock_pg_collection
)

# Verify get_pg_size was called with None
mock_get_pg_size.assert_called_once_with(None)

# Verify the result is a ShardedObject with correct sharding (same as EP=1)
assert result.key == "rng_state"
assert result.global_shape == (2, 4) # (pp_size, tp_size)
assert result.global_offset == (0, 1) # (pp_rank, tp_rank)
assert result.replica_id == 3 # dp_rank


class TestDeleteExtraState:
"""Tests for delete_extra_state utility added for cleanup of extraneous keys."""