diff --git a/megatron/core/pipeline_parallel/schedules.py b/megatron/core/pipeline_parallel/schedules.py index 15c5adfc7a2..9b8ea30b52a 100644 --- a/megatron/core/pipeline_parallel/schedules.py +++ b/megatron/core/pipeline_parallel/schedules.py @@ -20,7 +20,7 @@ is_vp_last_stage, ) from megatron.core.process_groups_config import ProcessGroupCollection -from megatron.core.transformer.cuda_graphs import create_cudagraphs +from megatron.core.transformer.cuda_graphs import create_cudagraphs, set_current_microbatch from megatron.core.transformer.enums import CudaGraphScope from megatron.core.transformer.moe.router import MoEAuxLossAutoScaler from megatron.core.utils import ( @@ -197,27 +197,6 @@ def custom_backward(output, grad_output): ) -def set_current_microbatch(model, microbatch_id): - """Set the current microbatch.""" - decoder_exists = True - model_with_decoder = None - try: - model_with_decoder = get_attr_wrapped_model( - model, "decoder", allow_none=False, return_model_obj=True - ) - except RuntimeError: - decoder_exists = False - if decoder_exists and model_with_decoder is not None: - for layer in model_with_decoder.decoder.layers: - layer.current_microbatch = microbatch_id - if hasattr(model_with_decoder, 'mtp'): - for layer in model_with_decoder.mtp.layers: - assert hasattr( - layer, 'mtp_model_layer' - ), f"MTP layer {layer} must have 'mtp_model_layer' attribute" - layer.mtp_model_layer.current_microbatch = microbatch_id - - def forward_step_calc_loss( model, output_tensor, diff --git a/megatron/core/transformer/cuda_graphs.py b/megatron/core/transformer/cuda_graphs.py index 48a023e6ddc..50caafd411b 100644 --- a/megatron/core/transformer/cuda_graphs.py +++ b/megatron/core/transformer/cuda_graphs.py @@ -1701,14 +1701,21 @@ def __init__(self, model, config, seq_length, micro_batch_size, optimizers=[]): # Number of microbatches to capture. The value will be set in _get_cuda_graph_input_data(). self.num_microbatches = None - # Get callables with captureable layers. + self._discover_layers() + + # One helper object can only capture CUDA Graphs once. Use this flag to check if the graphs + # have been created. + self._graphs_created = False + + def _discover_layers(self): + """Discover captureable layers from the model and populate internal data structures.""" self.chunks_with_decoder = [] self.num_layers_per_chunk = [] self.callables_per_chunk = [] self.callables_per_chunk_is_mtp = [] self.flattened_callables = [] self.flattened_callables_is_mtp = [] - for chunk_number, model_chunk in enumerate(model): + for chunk_number, model_chunk in enumerate(self.model): try: chunk_with_decoder = get_attr_wrapped_model( model_chunk, 'decoder', allow_none=False, return_model_obj=True @@ -1733,13 +1740,13 @@ def __init__(self, model, config, seq_length, micro_batch_size, optimizers=[]): callables, callables_is_mtp = [], [] for layer_number in range(num_decoder_layers): layer = chunk_with_decoder.decoder.layers[layer_number] - if _layer_is_graphable(layer, config): + if _layer_is_graphable(layer, self.config): num_graphable_layers += 1 callables.append(layer) callables_is_mtp.append(False) for layer_number in range(num_mtp_layers): layer = chunk_with_decoder.mtp.layers[layer_number].mtp_model_layer - if _layer_is_graphable(layer, config): + if _layer_is_graphable(layer, self.config): num_graphable_layers += 1 callables.append(layer) callables_is_mtp.append(True) @@ -1775,10 +1782,6 @@ def __init__(self, model, config, seq_length, micro_batch_size, optimizers=[]): f'{len(self.flattened_callables)} graphable layers.', ) - # One helper object can only capture CUDA Graphs once. Use this flag to check if the graphs - # have been created. - self._graphs_created = False - def graphs_created(self): """ Returns whether the CUDA Graphs have been created. @@ -2430,3 +2433,316 @@ def get_layer_range(c_id): add_order(c_id, l_b, is_wgrad=True) return new_order, chunk_id_list + + +# --------------------------------------------------------------------------- +# set_current_microbatch: sets per-layer microbatch index for TE graph replay +# --------------------------------------------------------------------------- + + +def set_current_microbatch(model, microbatch_id): + """Set the current microbatch on all layers that use TE CUDA graph replay. + + current_microbatch is read by _te_cuda_graph_replay to select the + correct graph index. This helper is called from the pipeline-parallel + schedule before each forward step. + """ + decoder_exists = True + model_with_decoder = None + try: + model_with_decoder = get_attr_wrapped_model( + model, "decoder", allow_none=False, return_model_obj=True + ) + except RuntimeError: + decoder_exists = False + if decoder_exists and model_with_decoder is not None: + for layer in model_with_decoder.decoder.layers: + layer.current_microbatch = microbatch_id + if hasattr(model_with_decoder, 'mtp'): + for layer in model_with_decoder.mtp.layers: + assert hasattr( + layer, 'mtp_model_layer' + ), f"MTP layer {layer} must have 'mtp_model_layer' attribute" + layer.mtp_model_layer.current_microbatch = microbatch_id + + # Also set current_microbatch on vision encoder layers so that + # _te_cuda_graph_replay selects the correct graph index. Without this, + # vision layers always use graph 0 (since current_microbatch defaults to 0), + # causing all microbatch forwards to overwrite the same static buffers. + # When backward runs for earlier microbatches, the buffers contain stale + # data from later forwards, producing NaN gradients. + try: + model_with_vision = get_attr_wrapped_model( + model, "vision_model", allow_none=True, return_model_obj=True + ) + except RuntimeError: + model_with_vision = None + if model_with_vision is not None and hasattr(model_with_vision, 'vision_model'): + vision_model = model_with_vision.vision_model + if hasattr(vision_model, 'decoder') and hasattr(vision_model.decoder, 'layers'): + for layer in vision_model.decoder.layers: + layer.current_microbatch = microbatch_id + + +# --------------------------------------------------------------------------- +# Vision encoder CUDA graph helpers +# --------------------------------------------------------------------------- + + +def _wrap_graph_for_vision(graph_fn): + """Wrap a graphed callable to filter out None outputs. + + During make_graphed_callables warmup, vision encoder layers go through their + normal forward() path which returns (output, context=None). _te_cuda_graph_replay + asserts len(output) == 1 but gets 2 elements. This wrapper filters out None + values so replay sees (output,) instead of (output, None). + """ + + def wrapped(*args, **kwargs): + result = graph_fn(*args, **kwargs) + if isinstance(result, tuple): + filtered = tuple(r for r in result if r is not None) + return filtered if filtered else result + return result + + for attr in ('backward_dw', 'reset'): + if hasattr(graph_fn, attr): + setattr(wrapped, attr, getattr(graph_fn, attr)) + return wrapped + + +def get_vision_cuda_graph_seq_length(vision_config, default_seq_length: int = 4096) -> int: + """Calculate the sequence length for vision encoder CUDA graphs. + + For vision encoders, the sequence length depends on: + - max_vision_cuda_graph_seq_length: explicit maximum (if set) + - num_position_embeddings: maximum number of patches + - spatial_merge_size: pooling factor that reduces sequence length + + Args: + vision_config: The TransformerConfig for vision encoder + default_seq_length: Default sequence length if cannot be calculated + + Returns: + The sequence length to use for CUDA graph capture + """ + if ( + hasattr(vision_config, 'max_vision_cuda_graph_seq_length') + and vision_config.max_vision_cuda_graph_seq_length + ): + return vision_config.max_vision_cuda_graph_seq_length + + if hasattr(vision_config, 'num_position_embeddings'): + seq_length = vision_config.num_position_embeddings + if hasattr(vision_config, 'spatial_merge_size'): + merge_factor = vision_config.spatial_merge_size**2 + seq_length = seq_length // merge_factor + return seq_length + + return default_seq_length + + +class VisionTECudaGraphHelper(TECudaGraphHelper): + """Helper to capture CUDA Graphs for vision encoder layers using TE. + + Inherits from TECudaGraphHelper and overrides only the + vision-specific behaviour: + + * Layer discovery finds vision_model.decoder.layers instead of the + language decoder layers. + * num_model_chunks is always 1 (vision has no virtual pipeline stages). + * Batch dimension is always 1 (images are concatenated along the sequence + dimension). + * Sample argument generation uses a simple loop (no rotary embeddings or + buffer-reuse optimization). + * Captured graph outputs are wrapped to filter None values that arise + from vision encoder layers returning (output, None). + + Args: + model: The full model (list of model chunks) containing vision_model. + vision_config: TransformerConfig for the vision encoder. + vision_seq_length: Sequence length for vision (max vision tokens). + micro_batch_size: Micro-batch size (unused for sample-arg generation + since the vision encoder always uses batch-dim = 1). + num_microbatches: Number of microbatches per step. + """ + + def __init__( + self, + model, + vision_config, + vision_seq_length: int, + micro_batch_size: int, + num_microbatches: int = 1, + ): + super().__init__(model, vision_config, vision_seq_length, micro_batch_size) + # Vision encoder concatenates all images along the sequence dimension + # with a fixed batch dimension of 1, regardless of the training MBS. + self.micro_batch_size = 1 + self.num_model_chunks = 1 + self.num_microbatches = num_microbatches + + def _discover_layers(self): + """Discover captureable layers from the vision encoder.""" + self.vision_model = None + vision_layers = [] + + for model_chunk in self.model: + try: + unwrapped = get_attr_wrapped_model( + model_chunk, 'vision_model', allow_none=True, return_model_obj=True + ) + if unwrapped is not None and hasattr(unwrapped, 'vision_model'): + self.vision_model = unwrapped.vision_model + break + except (RuntimeError, AttributeError): + continue + + if self.vision_model is not None: + if hasattr(self.vision_model, 'decoder') and hasattr( + self.vision_model.decoder, 'layers' + ): + for layer in self.vision_model.decoder.layers: + if _layer_is_graphable(layer, self.config): + vision_layers.append(layer) + + if vision_layers: + self.chunks_with_decoder = [self.vision_model] + self.num_layers_per_chunk = [len(vision_layers)] + self.callables_per_chunk = [vision_layers] + self.callables_per_chunk_is_mtp = [[False] * len(vision_layers)] + self.flattened_callables = list(vision_layers) + self.flattened_callables_is_mtp = [False] * len(vision_layers) + else: + if self.vision_model is None: + logger.warning( + 'VisionTECudaGraphHelper: No vision_model found in model. ' + 'CUDA graphs will not be captured for vision encoder.' + ) + self.chunks_with_decoder = [None] + self.num_layers_per_chunk = [0] + self.callables_per_chunk = [[]] + self.callables_per_chunk_is_mtp = [[]] + self.flattened_callables = [] + self.flattened_callables_is_mtp = [] + + # backward-compat aliases used by callers / tests + self.callables = vision_layers + self.num_layers = len(vision_layers) + + if vision_layers: + logger.info( + f'VisionTECudaGraphHelper: Found {self.num_layers} graphable vision encoder ' + f'layers. seq_length={self.seq_length} (all images concatenated, batch_dim=1)' + ) + + def _start_capturing(self): + """Start capturing for vision encoder. + + Unlike the parent, this skips torch.distributed.barrier() because + with PP > 1 only the first pipeline stage has vision layers — other + ranks return early from create_cudagraphs and never reach this + point, so a barrier would deadlock. + """ + assert not self._graphs_created, 'CUDA Graphs have already been created.' + gc.collect() + torch.cuda.empty_cache() + if FREEZE_GC: + gc.freeze() + _set_capture_start() + log_single_rank(logger, logging.INFO, 'Start vision encoder CUDA Graphs capture...') + return time.time() + + def _finish_capturing(self, start_time): + """Finish capturing for vision encoder. + + Unlike the parent, this skips: + - torch.distributed.barrier() (asymmetric: only first PP stage captures). + - model_chunk.zero_grad_buffer() / optimizer.zero_grad() (handled + by the LM decoder helper's _finish_capturing which runs on all ranks). + - clear_aux_losses_tracker / reset_model_temporary_tensors + (LM-specific cleanup already handled by the LM helper). + """ + log_single_rank( + logger, + logging.INFO, + f'Time spent in vision encoder CUDA Graphs capture on rank ' + f'{torch.distributed.get_rank()}: {time.time() - start_time}s', + ) + _set_capture_end() + if FREEZE_GC: + gc.unfreeze() + gc.collect() + torch.cuda.empty_cache() + self._graphs_created = True + + def _get_sample_arguments(self, order, chunk_id_list=None): + """Generate sample arguments for vision encoder CUDA Graph capturing. + + Vision uses a simple per-layer-per-microbatch loop with batch_dim=1 + and no rotary embeddings (unlike the parent's buffer-reuse + optimization). The order and chunk_id_list arguments are + unused because vision has num_model_chunks=1 and does not need + the pipeline-schedule-aware buffer lifecycle tracking. + + Returns: + Tuple of (sample_args, sample_kwargs) lists for each + (layer, microbatch) pair. + """ + if not self.flattened_callables: + return [], [] + + sample_args = [] + sample_kwargs_list = [] + hidden_size = self.config.hidden_size + + for _microbatch_idx in range(self.num_microbatches): + for layer in self.flattened_callables: + hidden_states = torch.zeros( + self.seq_length, + 1, + hidden_size, + dtype=torch.bfloat16, + device='cuda', + requires_grad=True, + ) + + if hasattr(layer, 'get_layer_static_inputs'): + static_inputs = layer.get_layer_static_inputs(self.seq_length, 1) + hidden_states = static_inputs.pop('hidden_states', hidden_states) + sample_args.append((hidden_states,)) + sample_kwargs_list.append(static_inputs) + else: + sample_args.append((hidden_states,)) + sample_kwargs_list.append({}) + + return sample_args, sample_kwargs_list + + def create_cudagraphs(self): + """Capture CUDA Graphs for vision encoder layers per microbatch. + + Delegates to the parent's capture workflow, then wraps the captured + graphs with _wrap_graph_for_vision to filter None from + (output, None) tuples so that _te_cuda_graph_replay's + len == 1 assertion passes. + """ + if not self.flattened_callables: + logger.warning( + 'VisionTECudaGraphHelper: No graphable layers found. ' + 'Skipping CUDA graph capture.' + ) + return + + super().create_cudagraphs() + + for layer in self.flattened_callables: + if hasattr(layer, 'cuda_graphs'): + layer.cuda_graphs = [_wrap_graph_for_vision(g) for g in layer.cuda_graphs] + + def cuda_graph_set_manual_hooks(self): + """No-op: vision encoder layers do not use DDP parameter-gather hooks. + + The parent derives hooks from model_chunk._make_forward_pre_hook which + requires overlap_param_gather=True. Vision encoder parameters are not + distributed with the same overlap strategy, so we skip hook setup. + """ diff --git a/tests/unit_tests/transformer/test_vision_cuda_graphs.py b/tests/unit_tests/transformer/test_vision_cuda_graphs.py new file mode 100644 index 00000000000..2a5679866a0 --- /dev/null +++ b/tests/unit_tests/transformer/test_vision_cuda_graphs.py @@ -0,0 +1,673 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + +import gc +import os +from copy import deepcopy +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest +import torch + +from megatron.core import parallel_state +from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec +from megatron.core.models.vision.vit_layer_specs import get_vit_layer_with_transformer_engine_spec +from megatron.core.tensor_parallel.random import ( + HAVE_TE, + initialize_rng_tracker, + model_parallel_cuda_manual_seed, +) +from megatron.core.transformer.cuda_graphs import ( + HAVE_TE_GRAPHS, + VisionTECudaGraphHelper, + _layer_is_graphable, + _wrap_graph_for_vision, + get_vision_cuda_graph_seq_length, + set_current_microbatch, +) +from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.core.utils import is_te_min_version +from tests.unit_tests.test_utilities import Utils + +TE_MIN_VERSION = "2.13.0" +_te_version_ok = HAVE_TE and is_te_min_version(TE_MIN_VERSION) +if not _te_version_ok and __name__ != "__main__": + pytest.skip( + f"Vision CUDA graph tests require TransformerEngine >= {TE_MIN_VERSION}", + allow_module_level=True, + ) + + +# --------------------------------------------------------------------------- +# Tests for _layer_is_graphable +# --------------------------------------------------------------------------- +class TestVisionLayerIsGraphable: + def test_non_transformer_layer_returns_false(self): + config = SimpleNamespace(cuda_graph_impl="transformer_engine") + layer = torch.nn.Linear(4, 4) + assert _layer_is_graphable(layer, config) is False + + def test_wrong_cuda_graph_impl_returns_false(self): + from megatron.core.transformer.transformer_layer import TransformerLayer + + config = SimpleNamespace(cuda_graph_impl="local") + layer = MagicMock(spec=TransformerLayer) + # isinstance check with MagicMock(spec=...) should pass + assert _layer_is_graphable(layer, config) is False + + def test_correct_config_with_transformer_layer(self): + """Real TransformerLayer + cuda_graph_impl='transformer_engine' -> True.""" + initialize_rng_tracker(use_te_rng_tracker=True, force_reset=True) + Utils.initialize_model_parallel( + tensor_model_parallel_size=1, pipeline_model_parallel_size=1 + ) + model_parallel_cuda_manual_seed(123) + + config = TransformerConfig( + num_layers=1, + hidden_size=16, + num_attention_heads=2, + use_cpu_initialization=True, + cuda_graph_impl="transformer_engine", + ) + from megatron.core.transformer.transformer_block import TransformerBlock + + block = TransformerBlock(config, get_vit_layer_with_transformer_engine_spec()) + layer = block.layers[0] + assert _layer_is_graphable(layer, config) is True + + Utils.destroy_model_parallel() + + +# --------------------------------------------------------------------------- +# Tests for _wrap_graph_for_vision +# --------------------------------------------------------------------------- +class TestWrapGraphForVision: + def test_filters_none_from_tuple(self): + def fake_graph(*args, **kwargs): + return (torch.tensor(1.0), None) + + wrapped = _wrap_graph_for_vision(fake_graph) + result = wrapped() + assert result == (torch.tensor(1.0),) + + def test_returns_non_tuple_unchanged(self): + t = torch.tensor(42.0) + + def fake_graph(*args, **kwargs): + return t + + wrapped = _wrap_graph_for_vision(fake_graph) + result = wrapped() + assert result is t + + def test_preserves_all_non_none(self): + a, b = torch.tensor(1.0), torch.tensor(2.0) + + def fake_graph(*args, **kwargs): + return (a, b) + + wrapped = _wrap_graph_for_vision(fake_graph) + result = wrapped() + assert result == (a, b) + + def test_all_none_returns_original(self): + def fake_graph(*args, **kwargs): + return (None, None) + + wrapped = _wrap_graph_for_vision(fake_graph) + result = wrapped() + # filtered is empty -> returns original tuple + assert result == (None, None) + + def test_preserves_te_attributes(self): + def fake_graph(*args, **kwargs): + return (torch.tensor(1.0),) + + fake_graph.backward_dw = "bwd_dw_fn" + fake_graph.reset = "reset_fn" + + wrapped = _wrap_graph_for_vision(fake_graph) + assert wrapped.backward_dw == "bwd_dw_fn" + assert wrapped.reset == "reset_fn" + + def test_missing_te_attributes_not_set(self): + def fake_graph(*args, **kwargs): + return (torch.tensor(1.0),) + + wrapped = _wrap_graph_for_vision(fake_graph) + assert not hasattr(wrapped, 'backward_dw') + assert not hasattr(wrapped, 'reset') + + +# --------------------------------------------------------------------------- +# Tests for get_vision_cuda_graph_seq_length +# --------------------------------------------------------------------------- +class TestGetVisionCudaGraphSeqLength: + def test_explicit_max_seq_length(self): + config = SimpleNamespace(max_vision_cuda_graph_seq_length=2048) + assert get_vision_cuda_graph_seq_length(config) == 2048 + + def test_explicit_max_seq_length_zero_falls_through(self): + """max_vision_cuda_graph_seq_length=0 is falsy, should fall through.""" + config = SimpleNamespace(max_vision_cuda_graph_seq_length=0) + assert get_vision_cuda_graph_seq_length(config, default_seq_length=999) == 999 + + def test_num_position_embeddings_only(self): + config = SimpleNamespace(num_position_embeddings=1024) + assert get_vision_cuda_graph_seq_length(config) == 1024 + + def test_num_position_embeddings_with_spatial_merge(self): + config = SimpleNamespace(num_position_embeddings=1024, spatial_merge_size=2) + # merge_factor = 2**2 = 4, seq = 1024 // 4 = 256 + assert get_vision_cuda_graph_seq_length(config) == 256 + + def test_spatial_merge_size_3(self): + config = SimpleNamespace(num_position_embeddings=900, spatial_merge_size=3) + # merge_factor = 9, seq = 900 // 9 = 100 + assert get_vision_cuda_graph_seq_length(config) == 100 + + def test_default_seq_length(self): + config = SimpleNamespace() + assert get_vision_cuda_graph_seq_length(config) == 4096 + + def test_custom_default(self): + config = SimpleNamespace() + assert get_vision_cuda_graph_seq_length(config, default_seq_length=512) == 512 + + def test_explicit_overrides_position_embeddings(self): + config = SimpleNamespace( + max_vision_cuda_graph_seq_length=8192, num_position_embeddings=1024 + ) + assert get_vision_cuda_graph_seq_length(config) == 8192 + + +# --------------------------------------------------------------------------- +# Integration test for VisionTECudaGraphHelper with LLaVA model +# --------------------------------------------------------------------------- +@pytest.mark.skipif( + not (HAVE_TE and is_te_min_version("1.5.0")), + reason="use_te_rng_tracker requires TransformerEngine version >= 1.5", +) +class TestVisionTECudaGraphHelper: + """Test VisionTECudaGraphHelper initialization, sample args, and graph lifecycle.""" + + def setup_method(self, method): + initialize_rng_tracker(use_te_rng_tracker=True, force_reset=True) + Utils.initialize_model_parallel( + tensor_model_parallel_size=1, + pipeline_model_parallel_size=1, + virtual_pipeline_model_parallel_size=None, + ) + model_parallel_cuda_manual_seed(123) + + from megatron.core.models.multimodal.llava_model import LLaVAModel + + self.language_hidden_size = 64 + self.vision_hidden_size = 16 + self.vision_num_layers = 2 + + language_config = TransformerConfig( + num_layers=2, + hidden_size=self.language_hidden_size, + num_attention_heads=4, + use_cpu_initialization=True, + ) + + self.vision_config = TransformerConfig( + num_layers=self.vision_num_layers, + hidden_size=self.vision_hidden_size, + num_attention_heads=2, + use_cpu_initialization=True, + cuda_graph_impl="transformer_engine", + bf16=True, + pipeline_dtype=torch.bfloat16, + ) + + vision_projection_config = TransformerConfig( + num_layers=1, + hidden_size=self.language_hidden_size, + ffn_hidden_size=32, + num_attention_heads=1, + use_cpu_initialization=True, + bf16=True, + pipeline_dtype=torch.bfloat16, + ) + + language_layer_spec = get_gpt_layer_with_transformer_engine_spec() + vision_layer_spec = get_vit_layer_with_transformer_engine_spec() + vision_projection_spec = deepcopy(language_layer_spec.submodules.mlp.submodules) + + self.vision_config.vision_model_type = "clip" + language_config.language_model_type = "dummy" + + self.llava_model = LLaVAModel( + language_transformer_config=language_config, + language_transformer_layer_spec=language_layer_spec, + language_vocab_size=8192, + language_max_sequence_length=4096, + vision_transformer_config=self.vision_config, + vision_transformer_layer_spec=vision_layer_spec, + drop_vision_class_token=False, + vision_projection_config=vision_projection_config, + vision_projection_layer_spec=vision_projection_spec, + img_h=336, + img_w=336, + patch_dim=14, + pre_process=True, + post_process=True, + add_encoder=True, + add_decoder=True, + ) + self.llava_model.bfloat16() + + self.vision_seq_length = 576 + self.micro_batch_size = 2 + + def teardown_method(self, method): + Utils.destroy_model_parallel() + gc.collect() + + def _make_helper(self, num_microbatches=1): + return VisionTECudaGraphHelper( + model=[self.llava_model], + vision_config=self.vision_config, + vision_seq_length=self.vision_seq_length, + micro_batch_size=self.micro_batch_size, + num_microbatches=num_microbatches, + ) + + # -- Initialization tests -- + + def test_init_finds_vision_layers(self): + helper = self._make_helper() + assert helper.vision_model is not None, "Should find vision_model" + assert helper.num_layers == self.vision_num_layers + assert len(helper.callables) == self.vision_num_layers + assert helper.graphs_created() is False + + def test_init_no_vision_model_warns(self): + """When model has no vision_model attr, helper should degrade gracefully.""" + dummy_model = torch.nn.Linear(4, 4) + helper = VisionTECudaGraphHelper( + model=[dummy_model], + vision_config=self.vision_config, + vision_seq_length=self.vision_seq_length, + micro_batch_size=self.micro_batch_size, + ) + assert helper.vision_model is None + assert len(helper.callables) == 0 + assert helper.graphs_created() is False + + # -- _get_sample_arguments tests -- + + def test_get_sample_arguments_shapes(self): + helper = self._make_helper(num_microbatches=1) + # order is unused by vision override; pass a dummy + sample_args, sample_kwargs_list = helper._get_sample_arguments(order=[1, -1]) + + expected_count = self.vision_num_layers * 1 # layers * microbatches + assert len(sample_args) == expected_count + assert len(sample_kwargs_list) == expected_count + + for i, (args_item, kwargs_item) in enumerate(zip(sample_args, sample_kwargs_list)): + assert isinstance(args_item, tuple), f"sample_args[{i}] should be tuple" + assert len(args_item) == 1, f"sample_args[{i}] should have one element (hidden_states)" + hs = args_item[0] + assert hs.shape == (self.vision_seq_length, 1, self.vision_hidden_size), ( + f"Expected ({self.vision_seq_length}, 1, {self.vision_hidden_size}), " + f"got {hs.shape}" + ) + assert hs.dtype == torch.bfloat16 + assert hs.device.type == 'cuda' + assert hs.requires_grad is True + + def test_get_sample_arguments_multi_microbatch(self): + helper = self._make_helper(num_microbatches=3) + sample_args, sample_kwargs_list = helper._get_sample_arguments(order=[1, -1]) + + expected_count = self.vision_num_layers * 3 + assert len(sample_args) == expected_count + assert len(sample_kwargs_list) == expected_count + + def test_get_sample_arguments_empty_when_no_callables(self): + dummy_model = torch.nn.Linear(4, 4) + helper = VisionTECudaGraphHelper( + model=[dummy_model], + vision_config=self.vision_config, + vision_seq_length=self.vision_seq_length, + micro_batch_size=self.micro_batch_size, + ) + sample_args, sample_kwargs_list = helper._get_sample_arguments(order=[1, -1]) + assert sample_args == [] + assert sample_kwargs_list == [] + + # -- create_cudagraphs / delete_cuda_graphs lifecycle -- + + @pytest.mark.skipif( + not (HAVE_TE_GRAPHS and is_te_min_version("2.7.0")), + reason="TE CUDA graph capture requires TransformerEngine >= 2.7.0", + ) + def test_create_and_delete_cudagraphs(self): + """Full lifecycle: create graphs, verify state, delete, verify cleanup.""" + self.llava_model.cuda() + helper = self._make_helper(num_microbatches=1) + + assert not helper.graphs_created() + + helper.create_cudagraphs() + assert helper.graphs_created() + + # Each vision layer should have cuda_graphs attached + for layer in helper.callables: + assert hasattr(layer, 'cuda_graphs'), "Layer should have cuda_graphs after capture" + assert len(layer.cuda_graphs) == 1 # 1 microbatch + + # cudagraph_manager should have been removed during capture + for layer in helper.callables: + assert not hasattr( + layer, 'cudagraph_manager' + ), "cudagraph_manager should be removed before TE capture" + + helper.delete_cuda_graphs() + assert not helper.graphs_created() + + # cuda_graphs should be empty after delete + for layer in helper.callables: + assert layer.cuda_graphs == [], "cuda_graphs should be empty after delete" + + @pytest.mark.skipif( + not (HAVE_TE_GRAPHS and is_te_min_version("2.7.0")), + reason="TE CUDA graph capture requires TransformerEngine >= 2.7.0", + ) + def test_create_cudagraphs_multi_microbatch(self): + """Verify that graphs are created per-microbatch per-layer.""" + self.llava_model.cuda() + num_mb = 2 + helper = self._make_helper(num_microbatches=num_mb) + + helper.create_cudagraphs() + assert helper.graphs_created() + + for layer in helper.callables: + assert hasattr(layer, 'cuda_graphs') + # PP=1 collapses to 1 microbatch internally + assert len(layer.cuda_graphs) == helper.num_microbatches + + helper.delete_cuda_graphs() + + def test_create_cudagraphs_no_callables_is_noop(self): + """create_cudagraphs on empty helper should not crash.""" + dummy_model = torch.nn.Linear(4, 4) + helper = VisionTECudaGraphHelper( + model=[dummy_model], + vision_config=self.vision_config, + vision_seq_length=self.vision_seq_length, + micro_batch_size=self.micro_batch_size, + ) + helper.create_cudagraphs() + assert not helper.graphs_created() + + def test_delete_cudagraphs_before_create_asserts(self): + """delete_cuda_graphs before creation should raise AssertionError.""" + helper = self._make_helper() + with pytest.raises(AssertionError): + helper.delete_cuda_graphs() + + +# --------------------------------------------------------------------------- +# Integration test with PP=2: vision encoder on first pipeline stage only +# --------------------------------------------------------------------------- +@pytest.mark.skipif( + not (HAVE_TE and is_te_min_version("1.5.0")), + reason="use_te_rng_tracker requires TransformerEngine version >= 1.5", +) +class TestVisionTECudaGraphHelperPP2: + """Test VisionTECudaGraphHelper with PP=2. + + With pipeline_model_parallel_size=2 the LLaVA model is split so that the + vision encoder lives exclusively on the first pipeline stage: + - pp_rank 0: add_encoder=True, pre_process=True, post_process=False + - pp_rank 1: add_encoder=False, pre_process=False, post_process=True + + This test verifies that: + 1. On stage 0 the helper finds and captures vision layers. + 2. On stage 1 the helper gracefully finds no vision layers. + 3. With PP>1, num_microbatches is NOT collapsed to 1. + """ + + def setup_method(self, method): + initialize_rng_tracker(use_te_rng_tracker=True, force_reset=True) + Utils.initialize_model_parallel( + tensor_model_parallel_size=1, + pipeline_model_parallel_size=2, + virtual_pipeline_model_parallel_size=None, + ) + model_parallel_cuda_manual_seed(123) + + from megatron.core.models.multimodal.llava_model import LLaVAModel + + self.language_hidden_size = 64 + self.vision_hidden_size = 16 + self.vision_num_layers = 2 + self.language_num_layers = 4 + + pp_rank = parallel_state.get_pipeline_model_parallel_rank() + is_first_stage = pp_rank == 0 + is_last_stage = pp_rank == (parallel_state.get_pipeline_model_parallel_world_size() - 1) + + language_config = TransformerConfig( + num_layers=self.language_num_layers, + hidden_size=self.language_hidden_size, + num_attention_heads=4, + use_cpu_initialization=True, + pipeline_model_parallel_size=2, + bf16=True, + pipeline_dtype=torch.bfloat16, + ) + + self.vision_config = TransformerConfig( + num_layers=self.vision_num_layers, + hidden_size=self.vision_hidden_size, + num_attention_heads=2, + use_cpu_initialization=True, + cuda_graph_impl="transformer_engine", + bf16=True, + pipeline_dtype=torch.bfloat16, + ) + + vision_projection_config = TransformerConfig( + num_layers=1, + hidden_size=self.language_hidden_size, + ffn_hidden_size=32, + num_attention_heads=1, + use_cpu_initialization=True, + bf16=True, + pipeline_dtype=torch.bfloat16, + ) + + language_layer_spec = get_gpt_layer_with_transformer_engine_spec() + vision_layer_spec = get_vit_layer_with_transformer_engine_spec() + vision_projection_spec = deepcopy(language_layer_spec.submodules.mlp.submodules) + + self.vision_config.vision_model_type = "clip" + language_config.language_model_type = "dummy" + + self.is_first_stage = is_first_stage + self.llava_model = LLaVAModel( + language_transformer_config=language_config, + language_transformer_layer_spec=language_layer_spec, + language_vocab_size=8192, + language_max_sequence_length=4096, + vision_transformer_config=self.vision_config, + vision_transformer_layer_spec=vision_layer_spec, + drop_vision_class_token=False, + vision_projection_config=vision_projection_config, + vision_projection_layer_spec=vision_projection_spec, + img_h=336, + img_w=336, + patch_dim=14, + pre_process=is_first_stage, + post_process=is_last_stage, + add_encoder=is_first_stage, + add_decoder=True, + ) + self.llava_model.bfloat16() + + self.vision_seq_length = 576 + self.micro_batch_size = 2 + + def teardown_method(self, method): + Utils.destroy_model_parallel() + gc.collect() + + def _make_helper(self, num_microbatches=4): + return VisionTECudaGraphHelper( + model=[self.llava_model], + vision_config=self.vision_config, + vision_seq_length=self.vision_seq_length, + micro_batch_size=self.micro_batch_size, + num_microbatches=num_microbatches, + ) + + def test_pp2_first_stage_finds_vision_layers(self): + """Stage 0 should discover all vision encoder layers.""" + if not self.is_first_stage: + pytest.skip("This assertion is only for pp_rank 0") + + helper = self._make_helper(num_microbatches=4) + assert helper.vision_model is not None + assert helper.num_layers == self.vision_num_layers + assert len(helper.callables) == self.vision_num_layers + + def test_pp2_last_stage_has_no_vision_layers(self): + """Stage 1 should find no vision model (encoder lives on stage 0).""" + if self.is_first_stage: + pytest.skip("This assertion is only for pp_rank 1") + + helper = self._make_helper(num_microbatches=4) + assert helper.vision_model is None + assert len(helper.callables) == 0 + assert not helper.graphs_created() + + def test_pp2_num_microbatches_preserved(self): + """With PP>1, num_microbatches should NOT be collapsed to 1.""" + if not self.is_first_stage: + pytest.skip("Vision layers only on pp_rank 0") + + num_mb = 8 + helper = self._make_helper(num_microbatches=num_mb) + # _get_sample_arguments generates layers * microbatches entries + sample_args, sample_kwargs_list = helper._get_sample_arguments(order=[1, -1]) + expected_count = self.vision_num_layers * num_mb + assert len(sample_args) == expected_count, ( + f"With PP>1, expected {expected_count} sample_args " + f"(layers={self.vision_num_layers} * mb={num_mb}), got {len(sample_args)}" + ) + + @pytest.mark.skipif( + not (HAVE_TE_GRAPHS and is_te_min_version("2.7.0")), + reason="TE CUDA graph capture requires TransformerEngine >= 2.7.0", + ) + def test_pp2_create_cudagraphs_first_stage(self): + """On stage 0, CUDA graphs should be captured with the full pipeline order.""" + if not self.is_first_stage: + pytest.skip("Vision layers only on pp_rank 0") + + self.llava_model.cuda() + num_mb = 4 + helper = self._make_helper(num_microbatches=num_mb) + + assert not helper.graphs_created() + + helper.create_cudagraphs() + assert helper.graphs_created() + + # num_microbatches should be preserved (PP>1 does not collapse) + assert helper.num_microbatches == num_mb + + # Each layer should have one graph per microbatch + for layer in helper.callables: + assert hasattr(layer, 'cuda_graphs') + assert ( + len(layer.cuda_graphs) == num_mb + ), f"Expected {num_mb} graphs per layer, got {len(layer.cuda_graphs)}" + + # Cleanup + helper.delete_cuda_graphs() + assert not helper.graphs_created() + for layer in helper.callables: + assert layer.cuda_graphs == [] + + @pytest.mark.skipif( + not (HAVE_TE_GRAPHS and is_te_min_version("2.7.0")), + reason="TE CUDA graph capture requires TransformerEngine >= 2.7.0", + ) + def test_pp2_create_cudagraphs_last_stage_noop(self): + """On stage 1 (no vision model), create_cudagraphs should be a no-op.""" + if self.is_first_stage: + pytest.skip("This assertion is only for pp_rank 1") + + helper = self._make_helper(num_microbatches=4) + helper.create_cudagraphs() + assert not helper.graphs_created() + + +if __name__ == "__main__": + if not _te_version_ok: + print(f"SKIPPED: Vision CUDA graph tests require TransformerEngine >= {TE_MIN_VERSION}") + exit(0) + + from _pytest.outcomes import Skipped + + def run_test(test_obj, test_fn_name): + """Run a test method, treating pytest.skip() as a non-error.""" + test_obj.setup_method(method=None) + try: + getattr(test_obj, test_fn_name)() + except Skipped as e: + print(f" SKIPPED {test_fn_name}: {e}") + finally: + test_obj.teardown_method(method=None) + + # Quick smoke tests for pure functions + t = TestWrapGraphForVision() + t.test_filters_none_from_tuple() + t.test_returns_non_tuple_unchanged() + t.test_preserves_all_non_none() + t.test_all_none_returns_original() + t.test_preserves_te_attributes() + t.test_missing_te_attributes_not_set() + print("_wrap_graph_for_vision tests passed.") + + t2 = TestGetVisionCudaGraphSeqLength() + t2.test_explicit_max_seq_length() + t2.test_explicit_max_seq_length_zero_falls_through() + t2.test_num_position_embeddings_only() + t2.test_num_position_embeddings_with_spatial_merge() + t2.test_spatial_merge_size_3() + t2.test_default_seq_length() + t2.test_custom_default() + t2.test_explicit_overrides_position_embeddings() + print("get_vision_cuda_graph_seq_length tests passed.") + + # Integration tests (require GPU + distributed init) + t3 = TestVisionTECudaGraphHelper() + run_test(t3, "test_init_finds_vision_layers") + run_test(t3, "test_get_sample_arguments_shapes") + run_test(t3, "test_create_and_delete_cudagraphs") + print("TestVisionTECudaGraphHelper tests passed.") + + # PP=2 integration tests (require 2+ GPUs) + if Utils.world_size >= 2: + t4 = TestVisionTECudaGraphHelperPP2() + run_test(t4, "test_pp2_first_stage_finds_vision_layers") + run_test(t4, "test_pp2_last_stage_has_no_vision_layers") + run_test(t4, "test_pp2_num_microbatches_preserved") + run_test(t4, "test_pp2_create_cudagraphs_first_stage") + run_test(t4, "test_pp2_create_cudagraphs_last_stage_noop") + print("TestVisionTECudaGraphHelperPP2 tests passed.") + else: + print("SKIPPED TestVisionTECudaGraphHelperPP2 (requires 2+ GPUs)") + + print("All vision CUDA graph tests passed.")