Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions megatron/core/tensor_parallel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,11 @@
from .random import (
CheckpointWithoutOutput,
checkpoint,
convert_cuda_rng_state,
get_cuda_rng_tracker,
get_data_parallel_rng_tracker_name,
get_expert_parallel_rng_tracker_name,
is_graph_safe_cuda_rng_tracker,
model_parallel_cuda_manual_seed,
)
from .utils import (
Expand Down Expand Up @@ -64,9 +66,11 @@
"scatter_to_sequence_parallel_region",
# random.py
"checkpoint",
"convert_cuda_rng_state",
"get_cuda_rng_tracker",
"model_parallel_cuda_manual_seed",
"get_expert_parallel_rng_tracker_name",
"is_graph_safe_cuda_rng_tracker",
"CheckpointWithoutOutput",
# utils.py
"split_tensor_along_last_dim",
Expand Down
78 changes: 73 additions & 5 deletions megatron/core/tensor_parallel/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,41 @@ def cb():
_lazy_call(cb)


def convert_cuda_rng_state(
    state: Union[torch.Tensor, torch.Generator], to_graphable: bool = False
) -> Union[torch.Tensor, torch.Generator]:
    """
    Convert a CUDA RNG state between its two representations.

    With ``to_graphable=True`` a tensor state is converted to the graphable
    ``torch.Generator`` form; with ``to_graphable=False`` a generator state is
    converted back to the non-graphable tensor form. A state that is already
    in the requested form is returned unchanged.

    Raises:
        ValueError: if ``state`` is neither a tensor nor a generator.
    """
    if not isinstance(state, (torch.Tensor, torch.Generator)):
        raise ValueError(f"Invalid state type: {type(state)}")

    is_generator = isinstance(state, torch.Generator)

    if to_graphable:
        if is_generator:
            # Already the graphable representation; nothing to do.
            return state
        # Round-trip through the default CUDA generator: load the tensor
        # state, read it back in graphable form, then restore what was there.
        saved_state = _get_cuda_rng_state(graph_safe=False)
        _set_cuda_rng_state(state, graph_safe=False)
        graphable_state = _get_cuda_rng_state(clone=True, graph_safe=True)
        _set_cuda_rng_state(saved_state, graph_safe=False)
        return graphable_state

    if is_generator:
        # Extract the plain tensor state from the generator.
        return state.get_state()
    # Already the non-graphable tensor representation; nothing to do.
    return state


def get_expert_parallel_rng_tracker_name():
"""Get the expert parallel rng tracker name"""
global _EXPERT_PARALLEL_RNG_TRACKER_NAME
Expand Down Expand Up @@ -161,6 +196,10 @@ def reset(self):
# Seeds are just for book keeping and ensure no seed is set twice.
self.seeds_ = set()

# Name of the rng state currently being used in the generator.
# The default one is "default-rng" and won't be pushed to the self.states_ dictionary.
self._current_state_name = "default-rng"

def get_states(self):
"""Get rng states. Copy the dictionary so we have direct
pointers to the states, not just a pointer to the dictionary."""
Expand Down Expand Up @@ -207,10 +246,14 @@ def fork(self, name=_MODEL_PARALLEL_RNG_TRACKER_NAME):
# Check if we have added the state
if name not in self.states_:
raise Exception('cuda rng state {} is not added'.format(name))
# Store current rng state.
# Store current rng state and name. Store in self.states_ if it's not the default state.
orig_cuda_rng_state = _get_cuda_rng_state(graph_safe=self.use_cudagraphable_rng)
# Set rng state to the desired one
orig_state_name = self._current_state_name
if orig_state_name != "default-rng":
self.states_[orig_state_name] = orig_cuda_rng_state
# Set rng state and name to the desired one.
_set_cuda_rng_state(self.states_[name], graph_safe=self.use_cudagraphable_rng)
self._current_state_name = name
# Record cpu RNG state
cpu_rng_state = torch.get_rng_state()
# Do the stuff we wanted to do.
Expand All @@ -220,10 +263,19 @@ def fork(self, name=_MODEL_PARALLEL_RNG_TRACKER_NAME):
# Throw a warning if cpu RNG state changed
if not torch.all(cpu_rng_state == torch.get_rng_state()).item():
logging.getLogger(__name__).warning('CPU RNG state changed within GPU RNG context')
# Check if the current state name is the same as the desired state name.
if self._current_state_name != name:
raise Exception(
f'current state name {self._current_state_name} is not the same as the desired '
f'state name {name}.'
)
# Update the current rng state for later use.
self.states_[name] = _get_cuda_rng_state(graph_safe=self.use_cudagraphable_rng)
# And set the state to the original state we started with.
# And set the state and name to the original state we started with.
if orig_state_name != "default-rng":
orig_cuda_rng_state = self.states_[orig_state_name]
_set_cuda_rng_state(orig_cuda_rng_state, graph_safe=self.use_cudagraphable_rng)
self._current_state_name = orig_state_name


# RNG tracker object.
Expand Down Expand Up @@ -377,18 +429,34 @@ def model_parallel_cuda_manual_seed(
_CUDA_RNG_STATE_TRACKER.add(_EXPERT_PARALLEL_RNG_TRACKER_NAME, expert_parallel_seed)


def is_graph_safe_cuda_rng_tracker(cuda_rng_tracker):
    """Return True if ``cuda_rng_tracker`` is a graph-safe (CUDA-graph capturable) tracker.

    A tracker counts as graph safe when it is Transformer Engine's
    ``TECudaRNGStatesTracker`` (TE >= 1.5.0 available) or when it advertises a
    truthy ``use_cudagraphable_rng`` attribute.
    """
    if HAVE_TE and is_te_min_version("1.5.0"):
        from megatron.core.extensions.transformer_engine import TECudaRNGStatesTracker

        if isinstance(cuda_rng_tracker, TECudaRNGStatesTracker):
            return True
    return bool(getattr(cuda_rng_tracker, "use_cudagraphable_rng", False))


def _get_all_rng_states():
    """Capture the CPU RNG state, the CUDA RNG state, and the tracker states.

    The CUDA state is read in graph-safe form when the active tracker is
    graph safe, so that it round-trips through ``_set_all_rng_states``.
    """
    tracker = get_cuda_rng_tracker()
    graph_safe = is_graph_safe_cuda_rng_tracker(tracker)
    return (
        torch.get_rng_state(),
        _get_cuda_rng_state(graph_safe=graph_safe),
        tracker.get_states(),
    )


def _set_all_rng_states(cpu_rng_state, cuda_rng_state, cuda_rng_state_tracker):
    """Restore the CPU, CUDA, and tracker RNG states captured by ``_get_all_rng_states``."""
    tracker = get_cuda_rng_tracker()
    torch.set_rng_state(cpu_rng_state)
    # Write the CUDA state back with the same graph-safety mode it was read with.
    _set_cuda_rng_state(
        cuda_rng_state, graph_safe=is_graph_safe_cuda_rng_tracker(tracker)
    )
    tracker.set_states(cuda_rng_state_tracker)


Expand Down
7 changes: 6 additions & 1 deletion megatron/core/transformer/cuda_graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1911,7 +1911,12 @@ def create_cudagraphs(self):

# Prepare CUDA Graph capturing input data and call `make_graphed_callables`.
sample_args, kwargs = self._get_cuda_graph_input_data()
graphs = make_graphed_callables(tuple(self.flattened_callables), sample_args, **kwargs)
if self.config.sequence_parallel:
rng_context = get_cuda_rng_tracker().fork()
else:
rng_context = nullcontext()
with rng_context:
graphs = make_graphed_callables(tuple(self.flattened_callables), sample_args, **kwargs)

# Push the captured graphs to the corresponding TransformerBlock.
num_layers_accumulated = 0
Expand Down
20 changes: 5 additions & 15 deletions megatron/core/transformer/moe/moe_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from megatron.core.fp4_utils import get_fp4_align_size
from megatron.core.fp8_utils import get_fp8_align_size
from megatron.core.process_groups_config import ProcessGroupCollection
from megatron.core.tensor_parallel import get_cuda_rng_tracker, get_expert_parallel_rng_tracker_name
from megatron.core.transformer.cuda_graphs import is_graph_capturing
from megatron.core.transformer.enums import CudaGraphScope
from megatron.core.transformer.transformer_config import TransformerConfig
Expand Down Expand Up @@ -918,6 +919,7 @@ def get_moe_layer_wise_logging_tracker():
return _MOE_LAYER_WISE_LOGGING_TRACKER


@internal_api
class RandomSTE(torch.autograd.Function):
"""
Straight-Through Estimator(STE) function that returns random values
Expand All @@ -926,26 +928,14 @@ class RandomSTE(torch.autograd.Function):
This is used to generate random logits of router for load-balanced benchmark.
"""

generator = None
random_logits = None

@staticmethod
def forward(ctx, logits):
"""
Forward pass returns random logits with rank-specific seed.
"""
if is_graph_capturing() and RandomSTE.random_logits is not None:
return RandomSTE.random_logits

if RandomSTE.generator is None:
global_rank = torch.distributed.get_rank()
base_seed = 42
seed = base_seed + global_rank
RandomSTE.generator = torch.Generator(device=logits.device)
RandomSTE.generator.manual_seed(seed)

RandomSTE.random_logits = logits.clone().normal_(generator=RandomSTE.generator)
return RandomSTE.random_logits
with get_cuda_rng_tracker().fork(get_expert_parallel_rng_tracker_name()):
random_logits = logits.clone().normal_()
return random_logits

@staticmethod
def backward(ctx, grad_output):
Expand Down
5 changes: 4 additions & 1 deletion megatron/training/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -1251,7 +1251,10 @@ def validate_args(args, defaults={}):

# CUDA Graphs
if args.cuda_graph_impl != "none":
if args.transformer_impl == 'transformer_engine' and not args.te_rng_tracker:
if (
"transformer_engine" in (args.transformer_impl, args.cuda_graph_impl)
and not args.te_rng_tracker
):
args.te_rng_tracker = True
warn_rank_0("te_rng_tracker is not enabled, enabling it for CUDA graphs.", args.rank)
assert (
Expand Down
15 changes: 11 additions & 4 deletions megatron/training/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1766,6 +1766,8 @@ def load_model_state_dict(module, state_dict, strict: bool):
# rng states.
if not release and not args.finetune and not args.no_load_rng and not ignore_rng_state:
try:
cuda_rng_tracker = tensor_parallel.get_cuda_rng_tracker()
graph_safe_rng = tensor_parallel.is_graph_safe_cuda_rng_tracker(cuda_rng_tracker)
if 'rng_state' in state_dict:
if args.ckpt_format == "fsdp_dtensor":
# FSDP DTensor checkpoints store rng_state in a different format.
Expand All @@ -1791,8 +1793,10 @@ def load_model_state_dict(module, state_dict, strict: bool):
# Check for empty states array
if not rng_state['rng_tracker_states']:
raise KeyError
tensor_parallel.get_cuda_rng_tracker().set_states(
rng_state['rng_tracker_states'])
rng_tracker_states = {
k: tensor_parallel.convert_cuda_rng_state(v, to_graphable=graph_safe_rng)
for k, v in rng_state['rng_tracker_states'].items()
}
else: # backward compatibility
random.setstate(state_dict['random_rng_state'])
np.random.set_state(state_dict['np_rng_state'])
Expand All @@ -1801,8 +1805,11 @@ def load_model_state_dict(module, state_dict, strict: bool):
# Check for empty states array
if not state_dict['rng_tracker_states']:
raise KeyError
tensor_parallel.get_cuda_rng_tracker().set_states(
state_dict['rng_tracker_states'])
rng_tracker_states = {
k: tensor_parallel.convert_cuda_rng_state(v, to_graphable=graph_safe_rng)
for k, v in state_dict['rng_tracker_states'].items()
}
cuda_rng_tracker.set_states(rng_tracker_states)
except KeyError:
print_rank_0('Unable to load rng state from checkpoint {}. '
'Specify --no-load-rng or --finetune to prevent '
Expand Down
Loading
Loading