From f8edc88592742c9aa3fcf78c0b73989ef8a40474 Mon Sep 17 00:00:00 2001 From: Robin Zhang Date: Thu, 11 Dec 2025 18:57:36 -0800 Subject: [PATCH 1/6] Fix cuda rng states Signed-off-by: Robin Zhang --- megatron/core/tensor_parallel/__init__.py | 2 + megatron/core/tensor_parallel/random.py | 35 +++++++ megatron/core/transformer/moe/moe_utils.py | 19 +--- megatron/training/arguments.py | 2 +- megatron/training/checkpointing.py | 22 ++++- .../unit_tests/tensor_parallel/test_random.py | 93 +++++++++++++++++++ 6 files changed, 153 insertions(+), 20 deletions(-) diff --git a/megatron/core/tensor_parallel/__init__.py b/megatron/core/tensor_parallel/__init__.py index afa53bdc6e1..00287074613 100644 --- a/megatron/core/tensor_parallel/__init__.py +++ b/megatron/core/tensor_parallel/__init__.py @@ -28,6 +28,7 @@ from .random import ( CheckpointWithoutOutput, checkpoint, + convert_cuda_rng_state, get_cuda_rng_tracker, get_data_parallel_rng_tracker_name, get_expert_parallel_rng_tracker_name, @@ -63,6 +64,7 @@ "scatter_to_sequence_parallel_region", # random.py "checkpoint", + "convert_cuda_rng_state", "get_cuda_rng_tracker", "model_parallel_cuda_manual_seed", "get_expert_parallel_rng_tracker_name", diff --git a/megatron/core/tensor_parallel/random.py b/megatron/core/tensor_parallel/random.py index 396e5c54a2d..981e93037d3 100644 --- a/megatron/core/tensor_parallel/random.py +++ b/megatron/core/tensor_parallel/random.py @@ -111,6 +111,41 @@ def cb(): _lazy_call(cb) +def convert_cuda_rng_state( + state: Union[torch.Tensor, torch.Generator], to_graphable: bool = False +) -> Union[torch.Tensor, torch.Generator]: + """ + Convert the cuda rng state tensor to the graphable version, + or from the graphable version to the non-graphable tensor version. + """ + if to_graphable: + if isinstance(state, torch.Tensor): + # Convert to the graphable version. + # Store current rng state. + orig_cuda_rng_state = _get_cuda_rng_state(graph_safe=False) + # Set rng state to the desired one + _set_cuda_rng_state(state, graph_safe=False) + # Get the graphable state + graphable_state = _get_cuda_rng_state(clone=True, graph_safe=True) + # And set the state to the original state we started with. + _set_cuda_rng_state(orig_cuda_rng_state, graph_safe=False) + return graphable_state + elif isinstance(state, torch.Generator): + # already graphable, just return it. + return state + else: + raise ValueError(f"Invalid state type: {type(state)}") + else: + if isinstance(state, torch.Tensor): + # already non-graphable, just return it. + return state + elif isinstance(state, torch.Generator): + # Convert to the non-graphable tensor version. + return state.get_state() + else: + raise ValueError(f"Invalid state type: {type(state)}") + + def get_expert_parallel_rng_tracker_name(): """Get the expert parallel rng tracker name""" global _EXPERT_PARALLEL_RNG_TRACKER_NAME diff --git a/megatron/core/transformer/moe/moe_utils.py b/megatron/core/transformer/moe/moe_utils.py index 8bab8d70065..a17d913f6da 100644 --- a/megatron/core/transformer/moe/moe_utils.py +++ b/megatron/core/transformer/moe/moe_utils.py @@ -10,6 +10,7 @@ from megatron.core.fp4_utils import get_fp4_align_size from megatron.core.fp8_utils import get_fp8_align_size from megatron.core.process_groups_config import ProcessGroupCollection +from megatron.core.tensor_parallel import get_cuda_rng_tracker, get_expert_parallel_rng_tracker_name from megatron.core.transformer.cuda_graphs import is_graph_capturing from megatron.core.transformer.enums import CudaGraphScope from megatron.core.transformer.transformer_config import TransformerConfig @@ -921,26 +922,14 @@ class RandomSTE(torch.autograd.Function): This is used to generate random logits of router for load-balanced benchmark. """ - generator = None - random_logits = None - @staticmethod def forward(ctx, logits): """ Forward pass returns random logits with rank-specific seed. """ - if is_graph_capturing() and RandomSTE.random_logits is not None: - return RandomSTE.random_logits - - if RandomSTE.generator is None: - global_rank = torch.distributed.get_rank() - base_seed = 42 - seed = base_seed + global_rank - RandomSTE.generator = torch.Generator(device=logits.device) - RandomSTE.generator.manual_seed(seed) - - RandomSTE.random_logits = logits.clone().normal_(generator=RandomSTE.generator) - return RandomSTE.random_logits + with get_cuda_rng_tracker().fork(get_expert_parallel_rng_tracker_name()): + random_logits = logits.clone().normal_() + return random_logits @staticmethod def backward(ctx, grad_output): diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index 7c9e4531c6d..ed107525ca2 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -1276,7 +1276,7 @@ def validate_args(args, defaults={}): # CUDA Graphs if args.cuda_graph_impl != "none": - if args.transformer_impl == 'transformer_engine' and not args.te_rng_tracker: + if not args.te_rng_tracker: args.te_rng_tracker = True warn_rank_0("te_rng_tracker is not enabled, enabling it for CUDA graphs.", args.rank) assert ( diff --git a/megatron/training/checkpointing.py b/megatron/training/checkpointing.py index 48a2025fa63..1af23f48ce0 100644 --- a/megatron/training/checkpointing.py +++ b/megatron/training/checkpointing.py @@ -1766,6 +1766,15 @@ def load_model_state_dict(module, state_dict, strict: bool): # rng states. if not release and not args.finetune and not args.no_load_rng and not ignore_rng_state: try: + cuda_rng_tracker = tensor_parallel.get_cuda_rng_tracker() + try: + from megatron.core.extensions.transformer_engine import TECudaRNGStatesTracker + except ImportError: + TECudaRNGStatesTracker = None + use_cudagraphable_rng = ( + TECudaRNGStatesTracker is not None + and isinstance(cuda_rng_tracker, TECudaRNGStatesTracker) + ) or getattr(cuda_rng_tracker, 'use_cudagraphable_rng', False) if 'rng_state' in state_dict: if args.ckpt_format == "fsdp_dtensor": # FSDP DTensor checkpoints store rng_state in a different format. @@ -1791,8 +1800,10 @@ def load_model_state_dict(module, state_dict, strict: bool): # Check for empty states array if not rng_state['rng_tracker_states']: raise KeyError - tensor_parallel.get_cuda_rng_tracker().set_states( - rng_state['rng_tracker_states']) + rng_tracker_states = { + k: tensor_parallel.convert_cuda_rng_state(v, to_graphable=use_cudagraphable_rng) + for k, v in rng_state['rng_tracker_states'].items() + } else: # backward compatability random.setstate(state_dict['random_rng_state']) np.random.set_state(state_dict['np_rng_state']) @@ -1801,8 +1812,11 @@ def load_model_state_dict(module, state_dict, strict: bool): # Check for empty states array if not state_dict['rng_tracker_states']: raise KeyError - tensor_parallel.get_cuda_rng_tracker().set_states( - state_dict['rng_tracker_states']) + rng_tracker_states = { + k: tensor_parallel.convert_cuda_rng_state(v, to_graphable=use_cudagraphable_rng) + for k, v in state_dict['rng_tracker_states'].items() + } + cuda_rng_tracker.set_states(rng_tracker_states) except KeyError: print_rank_0('Unable to load rng state from checkpoint {}. ' 'Specify --no-load-rng or --finetune to prevent ' diff --git a/tests/unit_tests/tensor_parallel/test_random.py b/tests/unit_tests/tensor_parallel/test_random.py index 47b607b8795..b4ca9e752b7 100644 --- a/tests/unit_tests/tensor_parallel/test_random.py +++ b/tests/unit_tests/tensor_parallel/test_random.py @@ -5,6 +5,7 @@ CheckpointWithoutOutput, CudaRNGStatesTracker, checkpoint, + convert_cuda_rng_state, get_cuda_rng_tracker, model_parallel_cuda_manual_seed, ) @@ -33,6 +34,98 @@ def test_cuda_rng_states_tracker(): assert torch.equal(rng_tracker.get_states()['state2'], rng_state) +def test_convert_cuda_rng_state(): + ## Get the default rng state + torch.cuda.manual_seed(999) + randn = torch.randn(10, device="cuda") + rng_state = torch.cuda.get_rng_state() + + try: + from megatron.core.extensions.transformer_engine import TECudaRNGStatesTracker + except ImportError: + TECudaRNGStatesTracker = None + + ## from non-graphable RNG to graphable RNG + # get state from non-graphable RNG + tracker = CudaRNGStatesTracker(use_cudagraphable_rng=False) + tracker.add("state1", 123) + for i in range(3): + with tracker.fork("state1"): + randn = torch.randn(10, device="cuda") + state = convert_cuda_rng_state(tracker.states_["state1"], to_graphable=True) + rand_tensors = [] + for i in range(3): + with tracker.fork("state1"): + randn = torch.randn(10, device="cuda") + rand_tensors.append(randn) + + # set state to local graph RNG + cudagraphable_tracker = CudaRNGStatesTracker(use_cudagraphable_rng=True) + cudagraphable_tracker.set_states({"state1": state.clone_state()}) + for i in range(3): + with cudagraphable_tracker.fork("state1"): + randn = torch.randn(10, device="cuda") + assert torch.equal(randn, rand_tensors[i]) + + # set state to TE RNG + if TECudaRNGStatesTracker is not None: + te_tracker = TECudaRNGStatesTracker() + te_tracker.set_states({"state1": state}) + for i in range(3): + with te_tracker.fork("state1"): + randn = torch.randn(10, device="cuda") + assert torch.equal(randn, rand_tensors[i]) + + ## from graphable RNG to non-graphable RNG + # get state from graphable RNG + cudagraphable_tracker = CudaRNGStatesTracker(use_cudagraphable_rng=True) + cudagraphable_tracker.add("state2", 123) + for i in range(3): + with cudagraphable_tracker.fork("state2"): + randn = torch.randn(10, device="cuda") + state = convert_cuda_rng_state(cudagraphable_tracker.states_["state2"], to_graphable=False) + rand_tensors = [] + for i in range(3): + with cudagraphable_tracker.fork("state2"): + randn = torch.randn(10, device="cuda") + rand_tensors.append(randn) + + # set state to non-graphable RNG + tracker = CudaRNGStatesTracker(use_cudagraphable_rng=False) + tracker.set_states({"state2": state}) + for i in range(3): + with tracker.fork("state2"): + randn = torch.randn(10, device="cuda") + assert torch.equal(randn, rand_tensors[i]) + + ## from TE RNG to non-graphable RNG + if TECudaRNGStatesTracker is not None: + # get state from TE RNG + cudagraphable_tracker = TECudaRNGStatesTracker() + cudagraphable_tracker.add("state3", 123) + for i in range(3): + with cudagraphable_tracker.fork("state3"): + randn = torch.randn(10, device="cuda") + state = convert_cuda_rng_state(cudagraphable_tracker.states_["state3"], to_graphable=False) + rand_tensors = [] + for i in range(3): + with cudagraphable_tracker.fork("state3"): + randn = torch.randn(10, device="cuda") + rand_tensors.append(randn) + + # set state to non-graphable RNG + tracker = CudaRNGStatesTracker(use_cudagraphable_rng=False) + tracker.set_states({"state3": state}) + for i in range(3): + with tracker.fork("state3"): + randn = torch.randn(10, device="cuda") + assert torch.equal(randn, rand_tensors[i]) + + ## After all tests, check if the default rng state is still the same. + rng_state_final = torch.cuda.get_rng_state() + assert torch.equal(rng_state, rng_state_final) + + def test_model_parallel_cuda_manual_seed(): Utils.initialize_model_parallel(4, 2) model_parallel_cuda_manual_seed(0, force_reset_rng=True) From 4811190ef49e26bfafe917321265f20c001543dd Mon Sep 17 00:00:00 2001 From: Robin Zhang Date: Thu, 11 Dec 2025 20:12:02 -0800 Subject: [PATCH 2/6] update copyright Signed-off-by: Robin Zhang --- tests/unit_tests/tensor_parallel/test_random.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/unit_tests/tensor_parallel/test_random.py b/tests/unit_tests/tensor_parallel/test_random.py index b4ca9e752b7..1a663532230 100644 --- a/tests/unit_tests/tensor_parallel/test_random.py +++ b/tests/unit_tests/tensor_parallel/test_random.py @@ -1,3 +1,5 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + import pytest import torch From be130ecd3d78c7dba10f72445d0176b0dc0146e4 Mon Sep 17 00:00:00 2001 From: Robin Zhang Date: Fri, 12 Dec 2025 06:37:30 -0800 Subject: [PATCH 3/6] graphsafe get/set all_rng_states Signed-off-by: Robin Zhang --- megatron/core/tensor_parallel/random.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/megatron/core/tensor_parallel/random.py b/megatron/core/tensor_parallel/random.py index 981e93037d3..dcaf684901b 100644 --- a/megatron/core/tensor_parallel/random.py +++ b/megatron/core/tensor_parallel/random.py @@ -412,10 +412,24 @@ def model_parallel_cuda_manual_seed( _CUDA_RNG_STATE_TRACKER.add(_EXPERT_PARALLEL_RNG_TRACKER_NAME, expert_parallel_seed) +def _is_graph_safe_cuda_rng_tracker(cuda_rng_tracker): + """Check if the cuda rng tracker is graph safe version.""" + if HAVE_TE and is_te_min_version("1.5.0"): + from megatron.core.extensions.transformer_engine import TECudaRNGStatesTracker + + if isinstance(cuda_rng_tracker, TECudaRNGStatesTracker): + return True + if getattr(cuda_rng_tracker, "use_cudagraphable_rng", False): + return True + return False + + def _get_all_rng_states(): """Get all the rng states.""" cpu_rng_state = torch.get_rng_state() - cuda_rng_state = _get_cuda_rng_state() + cuda_rng_state = _get_cuda_rng_state( + graph_safe=_is_graph_safe_cuda_rng_tracker(get_cuda_rng_tracker()) + ) cuda_rng_state_tracker = get_cuda_rng_tracker().get_states() return cpu_rng_state, cuda_rng_state, cuda_rng_state_tracker @@ -423,7 +437,9 @@ def _get_all_rng_states(): def _set_all_rng_states(cpu_rng_state, cuda_rng_state, cuda_rng_state_tracker): """Set all the rng states.""" torch.set_rng_state(cpu_rng_state) - _set_cuda_rng_state(cuda_rng_state) + _set_cuda_rng_state( + cuda_rng_state, graph_safe=_is_graph_safe_cuda_rng_tracker(get_cuda_rng_tracker()) + ) get_cuda_rng_tracker().set_states(cuda_rng_state_tracker) From bbcdde1b5696e09b0ed9f2c7919f72f79bb2bbaf Mon Sep 17 00:00:00 2001 From: Robin Zhang Date: Fri, 12 Dec 2025 07:12:25 -0800 Subject: [PATCH 4/6] update Signed-off-by: Robin Zhang --- megatron/core/tensor_parallel/__init__.py | 2 ++ megatron/core/tensor_parallel/random.py | 6 +++--- megatron/training/checkpointing.py | 13 +++---------- 3 files changed, 8 insertions(+), 13 deletions(-) diff --git a/megatron/core/tensor_parallel/__init__.py b/megatron/core/tensor_parallel/__init__.py index 00287074613..e629e5982b1 100644 --- a/megatron/core/tensor_parallel/__init__.py +++ b/megatron/core/tensor_parallel/__init__.py @@ -32,6 +32,7 @@ get_cuda_rng_tracker, get_data_parallel_rng_tracker_name, get_expert_parallel_rng_tracker_name, + is_graph_safe_cuda_rng_tracker, model_parallel_cuda_manual_seed, ) from .utils import ( @@ -68,6 +69,7 @@ "get_cuda_rng_tracker", "model_parallel_cuda_manual_seed", "get_expert_parallel_rng_tracker_name", + "is_graph_safe_cuda_rng_tracker", "CheckpointWithoutOutput", # utils.py "split_tensor_along_last_dim", diff --git a/megatron/core/tensor_parallel/random.py b/megatron/core/tensor_parallel/random.py index dcaf684901b..c76eb387ac5 100644 --- a/megatron/core/tensor_parallel/random.py +++ b/megatron/core/tensor_parallel/random.py @@ -412,7 +412,7 @@ def model_parallel_cuda_manual_seed( _CUDA_RNG_STATE_TRACKER.add(_EXPERT_PARALLEL_RNG_TRACKER_NAME, expert_parallel_seed) -def _is_graph_safe_cuda_rng_tracker(cuda_rng_tracker): +def is_graph_safe_cuda_rng_tracker(cuda_rng_tracker): """Check if the cuda rng tracker is graph safe version.""" if HAVE_TE and is_te_min_version("1.5.0"): from megatron.core.extensions.transformer_engine import TECudaRNGStatesTracker @@ -428,7 +428,7 @@ def _get_all_rng_states(): """Get all the rng states.""" cpu_rng_state = torch.get_rng_state() cuda_rng_state = _get_cuda_rng_state( - graph_safe=_is_graph_safe_cuda_rng_tracker(get_cuda_rng_tracker()) + graph_safe=is_graph_safe_cuda_rng_tracker(get_cuda_rng_tracker()) ) cuda_rng_state_tracker = get_cuda_rng_tracker().get_states() return cpu_rng_state, cuda_rng_state, cuda_rng_state_tracker @@ -438,7 +438,7 @@ def _set_all_rng_states(cpu_rng_state, cuda_rng_state, cuda_rng_state_tracker): """Set all the rng states.""" torch.set_rng_state(cpu_rng_state) _set_cuda_rng_state( - cuda_rng_state, graph_safe=_is_graph_safe_cuda_rng_tracker(get_cuda_rng_tracker()) + cuda_rng_state, graph_safe=is_graph_safe_cuda_rng_tracker(get_cuda_rng_tracker()) ) get_cuda_rng_tracker().set_states(cuda_rng_state_tracker) diff --git a/megatron/training/checkpointing.py b/megatron/training/checkpointing.py index 1af23f48ce0..19206312b67 100644 --- a/megatron/training/checkpointing.py +++ b/megatron/training/checkpointing.py @@ -1767,14 +1767,7 @@ def load_model_state_dict(module, state_dict, strict: bool): if not release and not args.finetune and not args.no_load_rng and not ignore_rng_state: try: cuda_rng_tracker = tensor_parallel.get_cuda_rng_tracker() - try: - from megatron.core.extensions.transformer_engine import TECudaRNGStatesTracker - except ImportError: - TECudaRNGStatesTracker = None - use_cudagraphable_rng = ( - TECudaRNGStatesTracker is not None - and isinstance(cuda_rng_tracker, TECudaRNGStatesTracker) - ) or getattr(cuda_rng_tracker, 'use_cudagraphable_rng', False) + graph_safe_rng = tensor_parallel.is_graph_safe_cuda_rng_tracker(cuda_rng_tracker) if 'rng_state' in state_dict: if args.ckpt_format == "fsdp_dtensor": # FSDP DTensor checkpoints store rng_state in a different format. @@ -1801,7 +1794,7 @@ def load_model_state_dict(module, state_dict, strict: bool): if not rng_state['rng_tracker_states']: raise KeyError rng_tracker_states = { - k: tensor_parallel.convert_cuda_rng_state(v, to_graphable=use_cudagraphable_rng) + k: tensor_parallel.convert_cuda_rng_state(v, to_graphable=graph_safe_rng) for k, v in rng_state['rng_tracker_states'].items() } else: # backward compatability @@ -1813,7 +1806,7 @@ def load_model_state_dict(module, state_dict, strict: bool): if not state_dict['rng_tracker_states']: raise KeyError rng_tracker_states = { - k: tensor_parallel.convert_cuda_rng_state(v, to_graphable=use_cudagraphable_rng) + k: tensor_parallel.convert_cuda_rng_state(v, to_graphable=graph_safe_rng) for k, v in state_dict['rng_tracker_states'].items() } cuda_rng_tracker.set_states(rng_tracker_states) From 4e6f2dc68b4e75bf8e765815555b179252598465 Mon Sep 17 00:00:00 2001 From: Robin Zhang Date: Sun, 14 Dec 2025 22:01:29 -0800 Subject: [PATCH 5/6] nested fork and cg rng context Signed-off-by: Robin Zhang --- megatron/core/tensor_parallel/random.py | 23 +++++++-- megatron/core/transformer/cuda_graphs.py | 7 ++- megatron/training/arguments.py | 5 +- .../unit_tests/tensor_parallel/test_random.py | 50 +++++++++++++++++++ 4 files changed, 80 insertions(+), 5 deletions(-) diff --git a/megatron/core/tensor_parallel/random.py b/megatron/core/tensor_parallel/random.py index c76eb387ac5..617d2803c12 100644 --- a/megatron/core/tensor_parallel/random.py +++ b/megatron/core/tensor_parallel/random.py @@ -196,6 +196,10 @@ def reset(self): # Seeds are just for book keeping and ensure no seed is set twice. self.seeds_ = set() + # Name of the rng state currently being used in the generator. + # The default one is "default-rng" and won't be pushed to the self.states_ dictionary. + self._current_state_name = "default-rng" + def get_states(self): """Get rng states. Copy the dictionary so we have direct pointers to the states, not just a pointer to the dictionary.""" @@ -242,10 +246,14 @@ def fork(self, name=_MODEL_PARALLEL_RNG_TRACKER_NAME): # Check if we have added the state if name not in self.states_: raise Exception('cuda rng state {} is not added'.format(name)) - # Store current rng state. + # Store current rng state and name. Store in self.states_ if it's not the default state. orig_cuda_rng_state = _get_cuda_rng_state(graph_safe=self.use_cudagraphable_rng) - # Set rng state to the desired one + orig_state_name = self._current_state_name + if orig_state_name != "default-rng": + self.states_[orig_state_name] = orig_cuda_rng_state + # Set rng state and name to the desired one. _set_cuda_rng_state(self.states_[name], graph_safe=self.use_cudagraphable_rng) + self._current_state_name = name # Record cpu RNG state cpu_rng_state = torch.get_rng_state() # Do the stuff we wanted to do. @@ -255,10 +263,19 @@ def fork(self, name=_MODEL_PARALLEL_RNG_TRACKER_NAME): # Throw a warning if cpu RNG state changed if not torch.all(cpu_rng_state == torch.get_rng_state()).item(): logging.getLogger(__name__).warning('CPU RNG state changed within GPU RNG context') + # Check if the current state name is the same as the desired state name. + if self._current_state_name != name: + raise Exception( + f'current state name {self._current_state_name} is not the same as the desired ' + f'state name {name}.' + ) # Update the current rng state for later use. self.states_[name] = _get_cuda_rng_state(graph_safe=self.use_cudagraphable_rng) - # And set the state to the original state we started with. + # And set the state and name to the original state we started with. + if orig_state_name != "default-rng": + orig_cuda_rng_state = self.states_[orig_state_name] _set_cuda_rng_state(orig_cuda_rng_state, graph_safe=self.use_cudagraphable_rng) + self._current_state_name = orig_state_name # RNG tracker object. diff --git a/megatron/core/transformer/cuda_graphs.py b/megatron/core/transformer/cuda_graphs.py index bcc90dc1240..ced850bc9f9 100644 --- a/megatron/core/transformer/cuda_graphs.py +++ b/megatron/core/transformer/cuda_graphs.py @@ -1873,7 +1873,12 @@ def create_cudagraphs(self): # Prepare CUDA Graph capturing input data and call `make_graphed_callables`. sample_args, kwargs = self._get_cuda_graph_input_data() - graphs = make_graphed_callables(tuple(self.flattened_callables), sample_args, **kwargs) + if self.config.sequence_parallel: + rng_context = get_cuda_rng_tracker().fork() + else: + rng_context = nullcontext() + with rng_context: + graphs = make_graphed_callables(tuple(self.flattened_callables), sample_args, **kwargs) # Push the captured graphs to the corresponding TransformerBlock. num_layers_accumulated = 0 diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index ed107525ca2..ec8432666c6 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -1276,7 +1276,10 @@ def validate_args(args, defaults={}): # CUDA Graphs if args.cuda_graph_impl != "none": - if not args.te_rng_tracker: + if ( + "transformer_engine" in (args.transformer_impl, args.cuda_graph_impl) + and not args.te_rng_tracker + ): args.te_rng_tracker = True warn_rank_0("te_rng_tracker is not enabled, enabling it for CUDA graphs.", args.rank) assert ( diff --git a/tests/unit_tests/tensor_parallel/test_random.py b/tests/unit_tests/tensor_parallel/test_random.py index 1a663532230..a15ad83cb90 100644 --- a/tests/unit_tests/tensor_parallel/test_random.py +++ b/tests/unit_tests/tensor_parallel/test_random.py @@ -36,6 +36,56 @@ def test_cuda_rng_states_tracker(): assert torch.equal(rng_tracker.get_states()['state2'], rng_state) +@pytest.mark.parametrize("use_cudagraphable_rng", [True, False]) +def test_double_fork_cuda_rng_states_tracker(use_cudagraphable_rng): + rng_tracker = CudaRNGStatesTracker(use_cudagraphable_rng=use_cudagraphable_rng) + rng_tracker.add("state1", 1234) + rng_tracker.add("state2", 5678) + randn_double_fork_1 = [] + randn_double_fork_2 = [] + with rng_tracker.fork("state1"): + randn_double_fork_1.append(torch.randn(10, device="cuda")) + with rng_tracker.fork("state2"): + randn_double_fork_2.append(torch.randn(10, device="cuda")) + with rng_tracker.fork("state1"): + randn_double_fork_1.append(torch.randn(10, device="cuda")) + randn_double_fork_2.append(torch.randn(10, device="cuda")) + randn_double_fork_1.append(torch.randn(10, device="cuda")) + if use_cudagraphable_rng: + double_fork_state1 = rng_tracker.get_states()["state1"].get_state() + double_fork_state2 = rng_tracker.get_states()["state2"].get_state() + else: + double_fork_state1 = rng_tracker.get_states()["state1"] + double_fork_state2 = rng_tracker.get_states()["state2"] + + rng_tracker.reset() + rng_tracker.add("state1", 1234) + rng_tracker.add("state2", 5678) + randn_single_fork_1 = [] + randn_single_fork_2 = [] + with rng_tracker.fork("state1"): + randn_single_fork_1.append(torch.randn(10, device="cuda")) + randn_single_fork_1.append(torch.randn(10, device="cuda")) + randn_single_fork_1.append(torch.randn(10, device="cuda")) + with rng_tracker.fork("state2"): + randn_single_fork_2.append(torch.randn(10, device="cuda")) + randn_single_fork_2.append(torch.randn(10, device="cuda")) + if use_cudagraphable_rng: + single_fork_state1 = rng_tracker.get_states()["state1"].get_state() + single_fork_state2 = rng_tracker.get_states()["state2"].get_state() + else: + single_fork_state1 = rng_tracker.get_states()["state1"] + single_fork_state2 = rng_tracker.get_states()["state2"] + + assert torch.equal(randn_double_fork_1[0], randn_single_fork_1[0]) + assert torch.equal(randn_double_fork_1[1], randn_single_fork_1[1]) + assert torch.equal(randn_double_fork_1[2], randn_single_fork_1[2]) + assert torch.equal(randn_double_fork_2[0], randn_single_fork_2[0]) + assert torch.equal(randn_double_fork_2[1], randn_single_fork_2[1]) + assert torch.equal(double_fork_state1, single_fork_state1) + assert torch.equal(double_fork_state2, single_fork_state2) + + def test_convert_cuda_rng_state(): ## Get the default rng state torch.cuda.manual_seed(999) From 1f0fc35492a5365c97240c8d079f7b9aaae9455a Mon Sep 17 00:00:00 2001 From: Robin Zhang Date: Wed, 17 Dec 2025 03:03:26 -0800 Subject: [PATCH 6/6] internal api Signed-off-by: Robin Zhang --- megatron/core/transformer/moe/moe_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/megatron/core/transformer/moe/moe_utils.py b/megatron/core/transformer/moe/moe_utils.py index a17d913f6da..28cff06f5ec 100644 --- a/megatron/core/transformer/moe/moe_utils.py +++ b/megatron/core/transformer/moe/moe_utils.py @@ -14,6 +14,7 @@ from megatron.core.transformer.cuda_graphs import is_graph_capturing from megatron.core.transformer.enums import CudaGraphScope from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.core.utils import internal_api try: import transformer_engine as te # pylint: disable=unused-import @@ -914,6 +915,7 @@ def get_moe_layer_wise_logging_tracker(): return _MOE_LAYER_WISE_LOGGING_TRACKER +@internal_api class RandomSTE(torch.autograd.Function): """ Straight-Through Estimator(STE) function that returns random values