diff --git a/megatron/core/tensor_parallel/__init__.py b/megatron/core/tensor_parallel/__init__.py index afa53bdc6e1..e629e5982b1 100644 --- a/megatron/core/tensor_parallel/__init__.py +++ b/megatron/core/tensor_parallel/__init__.py @@ -28,9 +28,11 @@ from .random import ( CheckpointWithoutOutput, checkpoint, + convert_cuda_rng_state, get_cuda_rng_tracker, get_data_parallel_rng_tracker_name, get_expert_parallel_rng_tracker_name, + is_graph_safe_cuda_rng_tracker, model_parallel_cuda_manual_seed, ) from .utils import ( @@ -63,9 +65,11 @@ "scatter_to_sequence_parallel_region", # random.py "checkpoint", + "convert_cuda_rng_state", "get_cuda_rng_tracker", "model_parallel_cuda_manual_seed", "get_expert_parallel_rng_tracker_name", + "is_graph_safe_cuda_rng_tracker", "CheckpointWithoutOutput", # utils.py "split_tensor_along_last_dim", diff --git a/megatron/core/tensor_parallel/random.py b/megatron/core/tensor_parallel/random.py index 396e5c54a2d..617d2803c12 100644 --- a/megatron/core/tensor_parallel/random.py +++ b/megatron/core/tensor_parallel/random.py @@ -111,6 +111,41 @@ def cb(): _lazy_call(cb) +def convert_cuda_rng_state( + state: Union[torch.Tensor, torch.Generator], to_graphable: bool = False +) -> Union[torch.Tensor, torch.Generator]: + """ + Convert the cuda rng state tensor to the graphable version, + or from the graphable version to the non-graphable tensor version. + """ + if to_graphable: + if isinstance(state, torch.Tensor): + # Convert to the graphable version. + # Store current rng state. + orig_cuda_rng_state = _get_cuda_rng_state(graph_safe=False) + # Set rng state to the desired one + _set_cuda_rng_state(state, graph_safe=False) + # Get the graphable state + graphable_state = _get_cuda_rng_state(clone=True, graph_safe=True) + # And set the state to the original state we started with. + _set_cuda_rng_state(orig_cuda_rng_state, graph_safe=False) + return graphable_state + elif isinstance(state, torch.Generator): + # already graphable, just return it. + return state + else: + raise ValueError(f"Invalid state type: {type(state)}") + else: + if isinstance(state, torch.Tensor): + # already non-graphable, just return it. + return state + elif isinstance(state, torch.Generator): + # Convert to the non-graphable tensor version. + return state.get_state() + else: + raise ValueError(f"Invalid state type: {type(state)}") + + def get_expert_parallel_rng_tracker_name(): """Get the expert parallel rng tracker name""" global _EXPERT_PARALLEL_RNG_TRACKER_NAME @@ -161,6 +196,10 @@ def reset(self): # Seeds are just for book keeping and ensure no seed is set twice. self.seeds_ = set() + # Name of the rng state currently being used in the generator. + # The default one is "default-rng" and won't be pushed to the self.states_ dictionary. + self._current_state_name = "default-rng" + def get_states(self): """Get rng states. Copy the dictionary so we have direct pointers to the states, not just a pointer to the dictionary.""" @@ -207,10 +246,14 @@ def fork(self, name=_MODEL_PARALLEL_RNG_TRACKER_NAME): # Check if we have added the state if name not in self.states_: raise Exception('cuda rng state {} is not added'.format(name)) - # Store current rng state. + # Store current rng state and name. Store in self.states_ if it's not the default state. orig_cuda_rng_state = _get_cuda_rng_state(graph_safe=self.use_cudagraphable_rng) - # Set rng state to the desired one + orig_state_name = self._current_state_name + if orig_state_name != "default-rng": + self.states_[orig_state_name] = orig_cuda_rng_state + # Set rng state and name to the desired one. _set_cuda_rng_state(self.states_[name], graph_safe=self.use_cudagraphable_rng) + self._current_state_name = name # Record cpu RNG state cpu_rng_state = torch.get_rng_state() # Do the stuff we wanted to do. @@ -220,10 +263,19 @@ def fork(self, name=_MODEL_PARALLEL_RNG_TRACKER_NAME): # Throw a warning if cpu RNG state changed if not torch.all(cpu_rng_state == torch.get_rng_state()).item(): logging.getLogger(__name__).warning('CPU RNG state changed within GPU RNG context') + # Check if the current state name is the same as the desired state name. + if self._current_state_name != name: + raise Exception( + f'current state name {self._current_state_name} is not the same as the desired ' + f'state name {name}.' + ) # Update the current rng state for later use. self.states_[name] = _get_cuda_rng_state(graph_safe=self.use_cudagraphable_rng) - # And set the state to the original state we started with. + # And set the state and name to the original state we started with. + if orig_state_name != "default-rng": + orig_cuda_rng_state = self.states_[orig_state_name] _set_cuda_rng_state(orig_cuda_rng_state, graph_safe=self.use_cudagraphable_rng) + self._current_state_name = orig_state_name # RNG tracker object. @@ -377,10 +429,24 @@ def model_parallel_cuda_manual_seed( _CUDA_RNG_STATE_TRACKER.add(_EXPERT_PARALLEL_RNG_TRACKER_NAME, expert_parallel_seed) +def is_graph_safe_cuda_rng_tracker(cuda_rng_tracker): + """Check if the cuda rng tracker is graph safe version.""" + if HAVE_TE and is_te_min_version("1.5.0"): + from megatron.core.extensions.transformer_engine import TECudaRNGStatesTracker + + if isinstance(cuda_rng_tracker, TECudaRNGStatesTracker): + return True + if getattr(cuda_rng_tracker, "use_cudagraphable_rng", False): + return True + return False + + def _get_all_rng_states(): """Get all the rng states.""" cpu_rng_state = torch.get_rng_state() - cuda_rng_state = _get_cuda_rng_state() + cuda_rng_state = _get_cuda_rng_state( + graph_safe=is_graph_safe_cuda_rng_tracker(get_cuda_rng_tracker()) + ) cuda_rng_state_tracker = get_cuda_rng_tracker().get_states() return cpu_rng_state, cuda_rng_state, cuda_rng_state_tracker @@ -388,7 +454,9 @@ def _get_all_rng_states(): def _set_all_rng_states(cpu_rng_state, cuda_rng_state, cuda_rng_state_tracker): """Set all the rng states.""" torch.set_rng_state(cpu_rng_state) - _set_cuda_rng_state(cuda_rng_state) + _set_cuda_rng_state( + cuda_rng_state, graph_safe=is_graph_safe_cuda_rng_tracker(get_cuda_rng_tracker()) + ) get_cuda_rng_tracker().set_states(cuda_rng_state_tracker) diff --git a/megatron/core/transformer/cuda_graphs.py b/megatron/core/transformer/cuda_graphs.py index 6f75d67549e..27e6c65c738 100644 --- a/megatron/core/transformer/cuda_graphs.py +++ b/megatron/core/transformer/cuda_graphs.py @@ -1907,7 +1907,12 @@ def create_cudagraphs(self): # Prepare CUDA Graph capturing input data and call `make_graphed_callables`. sample_args, kwargs = self._get_cuda_graph_input_data() - graphs = make_graphed_callables(tuple(self.flattened_callables), sample_args, **kwargs) + if self.config.sequence_parallel: + rng_context = get_cuda_rng_tracker().fork() + else: + rng_context = nullcontext() + with rng_context: + graphs = make_graphed_callables(tuple(self.flattened_callables), sample_args, **kwargs) # Push the captured graphs to the corresponding TransformerBlock. num_layers_accumulated = 0 diff --git a/megatron/core/transformer/moe/moe_utils.py b/megatron/core/transformer/moe/moe_utils.py index 8bab8d70065..28cff06f5ec 100644 --- a/megatron/core/transformer/moe/moe_utils.py +++ b/megatron/core/transformer/moe/moe_utils.py @@ -10,9 +10,11 @@ from megatron.core.fp4_utils import get_fp4_align_size from megatron.core.fp8_utils import get_fp8_align_size from megatron.core.process_groups_config import ProcessGroupCollection +from megatron.core.tensor_parallel import get_cuda_rng_tracker, get_expert_parallel_rng_tracker_name from megatron.core.transformer.cuda_graphs import is_graph_capturing from megatron.core.transformer.enums import CudaGraphScope from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.core.utils import internal_api try: import transformer_engine as te # pylint: disable=unused-import @@ -913,6 +915,7 @@ def get_moe_layer_wise_logging_tracker(): return _MOE_LAYER_WISE_LOGGING_TRACKER +@internal_api class RandomSTE(torch.autograd.Function): """ Straight-Through Estimator(STE) function that returns random values @@ -921,26 +924,14 @@ class RandomSTE(torch.autograd.Function): This is used to generate random logits of router for load-balanced benchmark. """ - generator = None - random_logits = None - @staticmethod def forward(ctx, logits): """ Forward pass returns random logits with rank-specific seed. """ - if is_graph_capturing() and RandomSTE.random_logits is not None: - return RandomSTE.random_logits - - if RandomSTE.generator is None: - global_rank = torch.distributed.get_rank() - base_seed = 42 - seed = base_seed + global_rank - RandomSTE.generator = torch.Generator(device=logits.device) - RandomSTE.generator.manual_seed(seed) - - RandomSTE.random_logits = logits.clone().normal_(generator=RandomSTE.generator) - return RandomSTE.random_logits + with get_cuda_rng_tracker().fork(get_expert_parallel_rng_tracker_name()): + random_logits = logits.clone().normal_() + return random_logits @staticmethod def backward(ctx, grad_output): diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index 70d1e4b1306..c157d062c53 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -1277,7 +1277,10 @@ def validate_args(args, defaults={}): # CUDA Graphs if args.cuda_graph_impl != "none": - if args.transformer_impl == 'transformer_engine' and not args.te_rng_tracker: + if ( + "transformer_engine" in (args.transformer_impl, args.cuda_graph_impl) + and not args.te_rng_tracker + ): args.te_rng_tracker = True warn_rank_0("te_rng_tracker is not enabled, enabling it for CUDA graphs.", args.rank) assert ( diff --git a/megatron/training/checkpointing.py b/megatron/training/checkpointing.py index 48a2025fa63..19206312b67 100644 --- a/megatron/training/checkpointing.py +++ b/megatron/training/checkpointing.py @@ -1766,6 +1766,8 @@ def load_model_state_dict(module, state_dict, strict: bool): # rng states. if not release and not args.finetune and not args.no_load_rng and not ignore_rng_state: try: + cuda_rng_tracker = tensor_parallel.get_cuda_rng_tracker() + graph_safe_rng = tensor_parallel.is_graph_safe_cuda_rng_tracker(cuda_rng_tracker) if 'rng_state' in state_dict: if args.ckpt_format == "fsdp_dtensor": # FSDP DTensor checkpoints store rng_state in a different format. @@ -1791,8 +1793,10 @@ def load_model_state_dict(module, state_dict, strict: bool): # Check for empty states array if not rng_state['rng_tracker_states']: raise KeyError - tensor_parallel.get_cuda_rng_tracker().set_states( - rng_state['rng_tracker_states']) + rng_tracker_states = { + k: tensor_parallel.convert_cuda_rng_state(v, to_graphable=graph_safe_rng) + for k, v in rng_state['rng_tracker_states'].items() + } else: # backward compatability random.setstate(state_dict['random_rng_state']) np.random.set_state(state_dict['np_rng_state']) @@ -1801,8 +1805,11 @@ def load_model_state_dict(module, state_dict, strict: bool): # Check for empty states array if not state_dict['rng_tracker_states']: raise KeyError - tensor_parallel.get_cuda_rng_tracker().set_states( - state_dict['rng_tracker_states']) + rng_tracker_states = { + k: tensor_parallel.convert_cuda_rng_state(v, to_graphable=graph_safe_rng) + for k, v in state_dict['rng_tracker_states'].items() + } + cuda_rng_tracker.set_states(rng_tracker_states) except KeyError: print_rank_0('Unable to load rng state from checkpoint {}. ' 'Specify --no-load-rng or --finetune to prevent ' diff --git a/tests/unit_tests/tensor_parallel/test_random.py b/tests/unit_tests/tensor_parallel/test_random.py index 47b607b8795..a15ad83cb90 100644 --- a/tests/unit_tests/tensor_parallel/test_random.py +++ b/tests/unit_tests/tensor_parallel/test_random.py @@ -1,3 +1,5 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + import pytest import torch @@ -5,6 +7,7 @@ CheckpointWithoutOutput, CudaRNGStatesTracker, checkpoint, + convert_cuda_rng_state, get_cuda_rng_tracker, model_parallel_cuda_manual_seed, ) @@ -33,6 +36,148 @@ def test_cuda_rng_states_tracker(): assert torch.equal(rng_tracker.get_states()['state2'], rng_state) +@pytest.mark.parametrize("use_cudagraphable_rng", [True, False]) +def test_double_fork_cuda_rng_states_tracker(use_cudagraphable_rng): + rng_tracker = CudaRNGStatesTracker(use_cudagraphable_rng=use_cudagraphable_rng) + rng_tracker.add("state1", 1234) + rng_tracker.add("state2", 5678) + randn_double_fork_1 = [] + randn_double_fork_2 = [] + with rng_tracker.fork("state1"): + randn_double_fork_1.append(torch.randn(10, device="cuda")) + with rng_tracker.fork("state2"): + randn_double_fork_2.append(torch.randn(10, device="cuda")) + with rng_tracker.fork("state1"): + randn_double_fork_1.append(torch.randn(10, device="cuda")) + randn_double_fork_2.append(torch.randn(10, device="cuda")) + randn_double_fork_1.append(torch.randn(10, device="cuda")) + if use_cudagraphable_rng: + double_fork_state1 = rng_tracker.get_states()["state1"].get_state() + double_fork_state2 = rng_tracker.get_states()["state2"].get_state() + else: + double_fork_state1 = rng_tracker.get_states()["state1"] + double_fork_state2 = rng_tracker.get_states()["state2"] + + rng_tracker.reset() + rng_tracker.add("state1", 1234) + rng_tracker.add("state2", 5678) + randn_single_fork_1 = [] + randn_single_fork_2 = [] + with rng_tracker.fork("state1"): + randn_single_fork_1.append(torch.randn(10, device="cuda")) + randn_single_fork_1.append(torch.randn(10, device="cuda")) + randn_single_fork_1.append(torch.randn(10, device="cuda")) + with rng_tracker.fork("state2"): + randn_single_fork_2.append(torch.randn(10, device="cuda")) + randn_single_fork_2.append(torch.randn(10, device="cuda")) + if use_cudagraphable_rng: + single_fork_state1 = rng_tracker.get_states()["state1"].get_state() + single_fork_state2 = rng_tracker.get_states()["state2"].get_state() + else: + single_fork_state1 = rng_tracker.get_states()["state1"] + single_fork_state2 = rng_tracker.get_states()["state2"] + + assert torch.equal(randn_double_fork_1[0], randn_single_fork_1[0]) + assert torch.equal(randn_double_fork_1[1], randn_single_fork_1[1]) + assert torch.equal(randn_double_fork_1[2], randn_single_fork_1[2]) + assert torch.equal(randn_double_fork_2[0], randn_single_fork_2[0]) + assert torch.equal(randn_double_fork_2[1], randn_single_fork_2[1]) + assert torch.equal(double_fork_state1, single_fork_state1) + assert torch.equal(double_fork_state2, single_fork_state2) + + +def test_convert_cuda_rng_state(): + ## Get the default rng state + torch.cuda.manual_seed(999) + randn = torch.randn(10, device="cuda") + rng_state = torch.cuda.get_rng_state() + + try: + from megatron.core.extensions.transformer_engine import TECudaRNGStatesTracker + except ImportError: + TECudaRNGStatesTracker = None + + ## from non-graphable RNG to graphable RNG + # get state from non-graphable RNG + tracker = CudaRNGStatesTracker(use_cudagraphable_rng=False) + tracker.add("state1", 123) + for i in range(3): + with tracker.fork("state1"): + randn = torch.randn(10, device="cuda") + state = convert_cuda_rng_state(tracker.states_["state1"], to_graphable=True) + rand_tensors = [] + for i in range(3): + with tracker.fork("state1"): + randn = torch.randn(10, device="cuda") + rand_tensors.append(randn) + + # set state to local graph RNG + cudagraphable_tracker = CudaRNGStatesTracker(use_cudagraphable_rng=True) + cudagraphable_tracker.set_states({"state1": state.clone_state()}) + for i in range(3): + with cudagraphable_tracker.fork("state1"): + randn = torch.randn(10, device="cuda") + assert torch.equal(randn, rand_tensors[i]) + + # set state to TE RNG + if TECudaRNGStatesTracker is not None: + te_tracker = TECudaRNGStatesTracker() + te_tracker.set_states({"state1": state}) + for i in range(3): + with te_tracker.fork("state1"): + randn = torch.randn(10, device="cuda") + assert torch.equal(randn, rand_tensors[i]) + + ## from graphable RNG to non-graphable RNG + # get state from graphable RNG + cudagraphable_tracker = CudaRNGStatesTracker(use_cudagraphable_rng=True) + cudagraphable_tracker.add("state2", 123) + for i in range(3): + with cudagraphable_tracker.fork("state2"): + randn = torch.randn(10, device="cuda") + state = convert_cuda_rng_state(cudagraphable_tracker.states_["state2"], to_graphable=False) + rand_tensors = [] + for i in range(3): + with cudagraphable_tracker.fork("state2"): + randn = torch.randn(10, device="cuda") + rand_tensors.append(randn) + + # set state to non-graphable RNG + tracker = CudaRNGStatesTracker(use_cudagraphable_rng=False) + tracker.set_states({"state2": state}) + for i in range(3): + with tracker.fork("state2"): + randn = torch.randn(10, device="cuda") + assert torch.equal(randn, rand_tensors[i]) + + ## from TE RNG to non-graphable RNG + if TECudaRNGStatesTracker is not None: + # get state from TE RNG + cudagraphable_tracker = TECudaRNGStatesTracker() + cudagraphable_tracker.add("state3", 123) + for i in range(3): + with cudagraphable_tracker.fork("state3"): + randn = torch.randn(10, device="cuda") + state = convert_cuda_rng_state(cudagraphable_tracker.states_["state3"], to_graphable=False) + rand_tensors = [] + for i in range(3): + with cudagraphable_tracker.fork("state3"): + randn = torch.randn(10, device="cuda") + rand_tensors.append(randn) + + # set state to non-graphable RNG + tracker = CudaRNGStatesTracker(use_cudagraphable_rng=False) + tracker.set_states({"state3": state}) + for i in range(3): + with tracker.fork("state3"): + randn = torch.randn(10, device="cuda") + assert torch.equal(randn, rand_tensors[i]) + + ## After all tests, check if the default rng state is still the same. + rng_state_final = torch.cuda.get_rng_state() + assert torch.equal(rng_state, rng_state_final) + + def test_model_parallel_cuda_manual_seed(): Utils.initialize_model_parallel(4, 2) model_parallel_cuda_manual_seed(0, force_reset_rng=True)