84 changes: 61 additions & 23 deletions megatron/core/transformer/cuda_graphs.py
@@ -20,8 +20,8 @@
import torch
from torch.utils._pytree import tree_map as tree_map_pyt

from megatron.core import parallel_state
from megatron.core.num_microbatches_calculator import get_num_microbatches
from megatron.core.process_groups_config import ProcessGroupCollection
from megatron.core.tensor_parallel.random import (
CudaRNGStatesTracker,
get_all_rng_states,
@@ -1387,7 +1387,12 @@ class CudaGraphManager(torch.nn.Module):
global_mempool = None

def __init__(
self, config: TransformerConfig, base_module=None, function_name=None, need_backward=True
self,
config: TransformerConfig,
base_module=None,
function_name=None,
need_backward=True,
pg_collection=None,
):
super().__init__()
"""Creates a CudaGraphManager to manage CUDA graphs for a Megatron module.
@@ -1396,6 +1401,9 @@ def __init__(
config: TransformerConfig object containing CUDA graph settings for memory
pooling, graph retention, gradient accumulation, FP8/FP4, and warmup steps.
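pg_collection: Optional ProcessGroupCollection providing the process groups used by
    the manager (e.g. the pipeline-parallel group); falls back to the global MPU
    process groups when None.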
"""
if pg_collection is None:
pg_collection = ProcessGroupCollection.use_mpu_process_groups()
self.pg_collection = pg_collection
rng_tracker = get_cuda_rng_tracker()
self.need_backward = need_backward
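A minimal usage sketch of the new argument; `config` and `layer` below are placeholders, and only `ProcessGroupCollection.use_mpu_process_groups()` and the constructor signature are taken from this diff:

```python
from megatron.core.process_groups_config import ProcessGroupCollection
from megatron.core.transformer.cuda_graphs import CudaGraphManager

# Explicit groups: build one collection and reuse it across managers instead
# of letting each CudaGraphManager touch the global MPU state.
pg_collection = ProcessGroupCollection.use_mpu_process_groups()
manager = CudaGraphManager(config, base_module=layer, pg_collection=pg_collection)

# Omitting the argument keeps the previous behaviour: the manager falls back
# to the MPU process groups on its own.
manager_legacy = CudaGraphManager(config, base_module=layer)
```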

@@ -1441,7 +1449,7 @@ def wrapped_func(*args, **kwargs):
# Without pipeline parallelism, microbatches execute one at a time.
# Therefore modules will always execute in the same order, so cudagraphs
# can both be reused and share a single mempool.
self.reuse_cudagraphs = parallel_state.get_pipeline_model_parallel_world_size() == 1
self.reuse_cudagraphs = self.pg_collection.pp.size() == 1
if CudaGraphManager.global_mempool is None:
CudaGraphManager.global_mempool = torch.cuda.graph_pool_handle()
# Cudagraph stream capture requires no operations on the default stream prior to the
@@ -1689,7 +1697,9 @@ class TECudaGraphHelper:
parameters that are covered by cudagraphs.
"""

def __init__(self, model, config, seq_length, micro_batch_size, optimizers=[]):
def __init__(
self, model, config, seq_length, micro_batch_size, optimizers=[], pg_collection=None
):
assert HAVE_TE_GRAPHS, "CUDA Graphs are not supported without TE."
assert (
config.cuda_graph_impl == "transformer_engine"
@@ -1730,6 +1740,15 @@ def _discover_layers(self):
self.callables_per_chunk_is_mtp = []
self.flattened_callables = []
self.flattened_callables_is_mtp = []
if pg_collection is None:
pg_collection = ProcessGroupCollection.use_mpu_process_groups()
self.pg_collection = pg_collection
self.tp_group = self.pg_collection.tp
self.dp_cp_group = self.pg_collection.dp_cp
self.pp_group = self.pg_collection.pp
from megatron.core.pipeline_parallel.p2p_communication import P2PCommunicator

self.p2p_communicator = P2PCommunicator(pp_group=self.pp_group, config=self.config)
for chunk_number, model_chunk in enumerate(self.model):
try:
chunk_with_decoder = get_attr_wrapped_model(
@@ -1739,8 +1758,8 @@
num_graphable_layers = 0
log_on_each_pipeline_stage(
logger=logger,
tp_group=None,
dp_cp_group=None,
tp_group=self.tp_group,
dp_cp_group=self.dp_cp_group,
level=logging.DEBUG,
msg=f'Rank {torch.distributed.get_rank()}: '
f'No valid layer in model chunk {chunk_number}.',
@@ -1767,8 +1786,8 @@
callables_is_mtp.append(True)
log_on_each_pipeline_stage(
logger=logger,
tp_group=None,
dp_cp_group=None,
tp_group=self.tp_group,
dp_cp_group=self.dp_cp_group,
level=logging.DEBUG,
msg=f'Rank {torch.distributed.get_rank()}: '
f'{num_decoder_layers} decoder layers and {num_mtp_layers} MTP layers in '
@@ -1790,8 +1809,8 @@

log_on_each_pipeline_stage(
logger=logger,
tp_group=None,
dp_cp_group=None,
tp_group=self.tp_group,
dp_cp_group=self.dp_cp_group,
level=logging.INFO,
msg=f'Rank {torch.distributed.get_rank()}: '
f'{len(self.flattened_callables)} graphable layers.',
@@ -2048,6 +2067,23 @@ def get_rotary_pos_emb(transformer_module, transformer_input):

return sample_args, sample_kwargs

def _get_amax_reduction_group(self, with_context_parallel=False, tp_only_amax_red=False):
"""Get the FP8 amax reduction group the caller rank belongs to."""
if with_context_parallel:
if not tp_only_amax_red:
assert self.pg_collection.tp_dp_cp is not None
return self.pg_collection.tp_dp_cp
else:
assert self.pg_collection.tp_cp is not None
return self.pg_collection.tp_cp
else:
if not tp_only_amax_red:
assert self.pg_collection.tp_dp is not None
return self.pg_collection.tp_dp
else:
assert self.pg_collection.tp is not None
return self.pg_collection.tp
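For reference, the selection above resolves to one of four groups; the consumer sketch is a hedged illustration of how such a group is typically fed to Transformer Engine (the `transformer_engine` import, `helper`, `config`, `recipe`, `layer`, and `hidden_states` are assumptions, not part of this PR):

```python
# with_context_parallel=True,  tp_only_amax_red=False -> pg_collection.tp_dp_cp
# with_context_parallel=True,  tp_only_amax_red=True  -> pg_collection.tp_cp
# with_context_parallel=False, tp_only_amax_red=False -> pg_collection.tp_dp
# with_context_parallel=False, tp_only_amax_red=True  -> pg_collection.tp

import transformer_engine.pytorch as te

fp8_group = helper._get_amax_reduction_group(
    with_context_parallel=True, tp_only_amax_red=config.tp_only_amax_red
)
# Amax statistics are then reduced over this group during FP8 execution.
with te.fp8_autocast(enabled=True, fp8_recipe=recipe, fp8_group=fp8_group):
    out = layer(hidden_states)
```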

def _get_cuda_graph_input_data(self):
"""
Create the CUDA Graph capturing input data.
@@ -2061,10 +2097,7 @@ def _get_cuda_graph_input_data(self):
)

# If PP is not enabled, we only need to capture one microbatch.
if (
parallel_state.get_pipeline_model_parallel_world_size() == 1
and not self.config.overlap_moe_expert_parallel_comm
):
if self.pp_group.size() == 1 and not self.config.overlap_moe_expert_parallel_comm:
assert (
self.num_model_chunks == 1
), "If PP is not enabled, there should be only one model chunk."
@@ -2076,7 +2109,8 @@
self.num_microbatches,
self.num_model_chunks,
self.config.microbatch_group_size_per_vp_stage,
False,
forward_only=False,
p2p_communicator=self.p2p_communicator,
)
schedule_table = get_schedule_table(
self.num_microbatches,
@@ -2088,8 +2122,8 @@
)
log_on_each_pipeline_stage(
logger=logger,
tp_group=None,
dp_cp_group=None,
tp_group=self.tp_group,
dp_cp_group=self.dp_cp_group,
level=logging.DEBUG,
msg=f'Rank {torch.distributed.get_rank()}: ORDER {order}',
)
@@ -2114,8 +2148,8 @@
self.num_microbatches = len(_order_without_wgrad) // self.num_model_chunks // 2
log_on_each_pipeline_stage(
logger=logger,
tp_group=None,
dp_cp_group=None,
tp_group=self.tp_group,
dp_cp_group=self.dp_cp_group,
level=logging.DEBUG,
msg=f'Rank {torch.distributed.get_rank()}: '
f'ORDER after overlap_moe_expert_parallel_comm {order}',
@@ -2186,8 +2220,12 @@ def _get_fp8_enabled():
get_fp8_recipe(self.config) if self.config.fp8 else get_fp4_recipe(self.config)
)
kwargs['fp8_weight_caching'] = True
if is_te_min_version("1.14.0") and parallel_state.model_parallel_is_initialized():
kwargs['fp8_group'] = parallel_state.get_amax_reduction_group(
if (
is_te_min_version("1.14.0")
and self.pg_collection is not None
and self.pg_collection.tp is not None
):
kwargs['fp8_group'] = self._get_amax_reduction_group(
with_context_parallel=True, tp_only_amax_red=self.config.tp_only_amax_red
)
else:
@@ -2310,8 +2348,8 @@ def delete_cuda_graphs(self):

log_on_each_pipeline_stage(
logger=logger,
tp_group=None,
dp_cp_group=None,
tp_group=self.tp_group,
dp_cp_group=self.dp_cp_group,
level=logging.INFO,
msg=f'Rank {torch.distributed.get_rank()}: '
f'{graphs_reset} graphs deleted with explicit reset, '