diff --git a/docs/models/vlm/ministral3.md b/docs/models/vlm/ministral3.md index b47a08df6c..71dcec8695 100644 --- a/docs/models/vlm/ministral3.md +++ b/docs/models/vlm/ministral3.md @@ -7,7 +7,7 @@ Ministral 3 models support multimodal tasks including image captioning, visual q Ministral family models are supported via the Bridge system with auto-detected configuration and weight mapping. ```{important} -Please update `transformers` version to 5.0.0rc0 in order to use the Ministral 3 models. +Please upgrade to `transformers` v5 and upgrade `mistral-common` in order to use the Ministral 3 models. ``` ## Available Models diff --git a/src/megatron/bridge/data/vlm_datasets/hf_provider.py b/src/megatron/bridge/data/vlm_datasets/hf_provider.py index c63c95e780..8014276753 100644 --- a/src/megatron/bridge/data/vlm_datasets/hf_provider.py +++ b/src/megatron/bridge/data/vlm_datasets/hf_provider.py @@ -67,6 +67,9 @@ class HFDatasetConversationProvider(DatasetProvider): # DataloaderConfig fields are inherited (num_workers, dataloader_type, etc.) dataloader_type: Optional[Literal["single", "cyclic", "external"]] = "single" + # Enable batch-level online sequence packing (dataset-level packing is available in FinetuneDatasetProvider) + pack_sequences_in_batch: bool = False + def _get_maker(self) -> Callable[..., List[Dict[str, Any]]]: registry: Dict[str, Callable[..., List[Dict[str, Any]]]] = { "make_rdr_dataset": make_rdr_dataset, diff --git a/src/megatron/bridge/models/gemma/gemma3_provider.py b/src/megatron/bridge/models/gemma/gemma3_provider.py index 63d6413e32..80966260bc 100644 --- a/src/megatron/bridge/models/gemma/gemma3_provider.py +++ b/src/megatron/bridge/models/gemma/gemma3_provider.py @@ -367,11 +367,34 @@ def __init__( **kwargs, ) + def forward( + self, + max_seq_len: int, + offset: int = 0, + packed_seq: bool = False, + cp_group: torch.distributed.ProcessGroup | None = None, + ) -> Tensor: + """Get global and local rope embedding. 
+ + Note: Caching is bypassed when cp_group is provided since ProcessGroup is unhashable. + """ + # ProcessGroup is unhashable, so bypass caching when cp_group is provided + if cp_group is not None: + rope_global = super().forward(max_seq_len, offset, packed_seq, cp_group) + rope_local = self.rope_local.forward(max_seq_len, offset, packed_seq, cp_group) + return rope_local, rope_global + return self._forward_cached(max_seq_len, offset, packed_seq) + @lru_cache(maxsize=32) - def forward(self, max_seq_len: int, offset: int = 0, packed_seq: bool = False) -> Tensor: - """Get global and local rope embedding""" - rope_global = super().forward(max_seq_len, offset, packed_seq) - rope_local = self.rope_local.forward(max_seq_len, offset, packed_seq) + def _forward_cached( + self, + max_seq_len: int, + offset: int = 0, + packed_seq: bool = False, + ) -> Tensor: + """Cached forward for hashable parameters only.""" + rope_global = super().forward(max_seq_len, offset, packed_seq, None) + rope_local = self.rope_local.forward(max_seq_len, offset, packed_seq, None) return rope_local, rope_global diff --git a/src/megatron/bridge/models/gemma_vl/modeling_gemma3_vl.py b/src/megatron/bridge/models/gemma_vl/modeling_gemma3_vl.py index 0b1e937840..e68225ca1e 100644 --- a/src/megatron/bridge/models/gemma_vl/modeling_gemma3_vl.py +++ b/src/megatron/bridge/models/gemma_vl/modeling_gemma3_vl.py @@ -14,22 +14,30 @@ import types from dataclasses import dataclass -from typing import Optional, Tuple +from typing import TYPE_CHECKING, Optional import torch import torch.nn as nn import torch.nn.functional as F from megatron.core.tensor_parallel.layers import ColumnParallelLinear +from megatron.core.tensor_parallel.mappings import scatter_to_sequence_parallel_region from megatron.core.transformer import TransformerConfig from megatron.core.transformer.module import MegatronModule from torch import Tensor from transformers import AutoModel, Gemma3Model from megatron.bridge.models.gpt_provider import 
GPTModelProvider -from megatron.bridge.utils.common_utils import hook_hf_module_setattr_for_tp_grad_sync +from megatron.bridge.utils.common_utils import ( + hook_hf_module_setattr_for_tp_grad_sync, + slice_batch_for_context_parallel, +) from megatron.bridge.utils.import_utils import safe_import_from +if TYPE_CHECKING: + from megatron.core.packed_seq_params import PackedSeqParams + + TENorm, _ = safe_import_from("megatron.core.extensions.transformer_engine", "TENorm") @@ -110,12 +118,16 @@ def forward( pixel_values: Optional[torch.Tensor] = None, labels: Optional[torch.Tensor] = None, runtime_gather_output: Optional[bool] = None, + packed_seq_params: Optional["PackedSeqParams"] = None, *, loss_mask: Optional[Tensor] = None, - ) -> Tensor: + ) -> tuple[Tensor, Tensor | None]: r""" - image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): - The temporal, height and width of feature shape of each image in LLM. + Forward pass combining HuggingFace vision encoder with Megatron language model. + + Returns: + tuple: (output_tensor, loss_mask) where output_tensor contains model output + and loss_mask is the CP-sliced mask for consistent loss computation. """ if self.pre_process: if inputs_embeds is None: @@ -134,7 +146,7 @@ def forward( special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) if inputs_embeds[special_image_mask].numel() != image_features.numel(): - image_tokens_in_text = special_image_mask.sum(dim=1).item(dim=0)[0] + image_tokens_in_text = special_image_mask[:, :, 0].sum().item() raise ValueError( f"Number of images does not match number of special image tokens in the input text. 
" f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} " @@ -144,18 +156,38 @@ def forward( inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) inputs_embeds = inputs_embeds.transpose(1, 0).contiguous() # (B, T, D) -> (T, B, D) + # Apply sequence parallelism scatter if enabled + if self.config.sequence_parallel: + inputs_embeds = scatter_to_sequence_parallel_region(inputs_embeds) + + # Compute attention mask on FULL sequence (before CP slicing) + # This is needed because image regions need bidirectional attention attention_mask = self._compute_attention_mask(input_ids) + # CP slicing: slice embeddings, labels, loss_mask, position_ids, and attention_mask + # This must happen AFTER vision-text merge so image token positions are correct + inputs_embeds, labels, loss_mask, position_ids, attention_mask = slice_batch_for_context_parallel( + inputs_embeds=inputs_embeds, + labels=labels, + loss_mask=loss_mask, + position_ids=position_ids, + attention_mask=attention_mask, + packed_seq_params=packed_seq_params, + pg_collection=self.config._pg_collection, + ) + outputs = self.language_model.forward( input_ids=None, position_ids=position_ids, - attention_mask=attention_mask, # (B, 1, T, T) - decoder_input=inputs_embeds, # (T, B, D) - labels=labels, # (B, T) + attention_mask=attention_mask, + decoder_input=inputs_embeds, + labels=labels, loss_mask=loss_mask, runtime_gather_output=runtime_gather_output, + packed_seq_params=packed_seq_params, ) - return outputs + # Return both outputs and the CP-sliced loss_mask for consistent loss computation + return (outputs, loss_mask) def freeze(self, freeze_language_model: bool, freeze_vision_model: bool, freeze_vision_projection: bool): """Freeze model modules. 
@@ -191,7 +223,7 @@ def freeze(self, freeze_language_model: bool, freeze_vision_model: bool, freeze_ def _compute_attention_mask( self, input_ids: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> Optional[torch.Tensor]: if not self.pre_process: return None batch_size, seq_len = input_ids.shape diff --git a/src/megatron/bridge/models/glm_vl/modeling_glm_45v.py b/src/megatron/bridge/models/glm_vl/modeling_glm_45v.py index 3c97fab9ff..34d20c42fb 100644 --- a/src/megatron/bridge/models/glm_vl/modeling_glm_45v.py +++ b/src/megatron/bridge/models/glm_vl/modeling_glm_45v.py @@ -23,7 +23,7 @@ """ import types -from typing import Optional +from typing import TYPE_CHECKING, Optional import torch import transformers @@ -37,6 +37,10 @@ from megatron.bridge.utils.common_utils import hook_hf_module_setattr_for_tp_grad_sync +if TYPE_CHECKING: + from megatron.core.packed_seq_params import PackedSeqParams + + def is_transformers_min_version(version): """Check if minimum version of transformers is installed.""" try: @@ -158,6 +162,7 @@ def forward( video_grid_thw: Optional[torch.LongTensor] = None, labels: Optional[torch.Tensor] = None, runtime_gather_output: Optional[bool] = None, + packed_seq_params: Optional["PackedSeqParams"] = None, *, loss_mask: Optional[Tensor] = None, ) -> Tensor: @@ -233,6 +238,7 @@ def forward( labels=labels, loss_mask=loss_mask, runtime_gather_output=runtime_gather_output, + packed_seq_params=packed_seq_params, ) return outputs diff --git a/src/megatron/bridge/models/ministral3/ministral3_provider.py b/src/megatron/bridge/models/ministral3/ministral3_provider.py index 3f5ce8e59d..5c6605d13b 100644 --- a/src/megatron/bridge/models/ministral3/ministral3_provider.py +++ b/src/megatron/bridge/models/ministral3/ministral3_provider.py @@ -260,11 +260,17 @@ def __init__( self.beta = 0 # No effect self.max_position_embeddings = self.config.seq_length + @staticmethod def _get_llama_4_attn_scale( - self, positions_ids: torch.Tensor, beta: float, 
max_position_embeddings: int + positions_ids: torch.Tensor, beta: float, max_position_embeddings: int, query_shape: tuple ) -> torch.Tensor: scaling = 1 + beta * torch.log(1 + torch.floor(positions_ids / max_position_embeddings)) - return scaling.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) + # Add dimensions to match query shape: [seq_len] -> [seq_len, 1, 1] for packed or [seq_len, 1, 1, 1] for unpacked + # Query can be either [seq_len, num_heads, head_dim] (packed) or [seq_len, batch, num_heads, head_dim] (unpacked) + num_dims_to_add = len(query_shape) - 1 + for _ in range(num_dims_to_add): + scaling = scaling.unsqueeze(-1) + return scaling def forward( self, @@ -276,6 +282,8 @@ def forward( **kwargs, ): positions_ids = torch.arange(query.shape[0], device=query.device) - query *= self._get_llama_4_attn_scale(positions_ids, self.beta, self.max_position_embeddings).to(query.dtype) + query *= self._get_llama_4_attn_scale(positions_ids, self.beta, self.max_position_embeddings, query.shape).to( + query.dtype + ) return super().forward(query, key, value, attention_mask, attn_mask_type, **kwargs) diff --git a/src/megatron/bridge/models/ministral3/modeling_ministral3.py b/src/megatron/bridge/models/ministral3/modeling_ministral3.py index 00590b801b..e541059e7a 100644 --- a/src/megatron/bridge/models/ministral3/modeling_ministral3.py +++ b/src/megatron/bridge/models/ministral3/modeling_ministral3.py @@ -24,14 +24,21 @@ """ import types -from typing import Optional +from typing import TYPE_CHECKING, Optional import torch from megatron.core.transformer.module import MegatronModule from torch import Tensor from megatron.bridge.models.gpt_provider import GPTModelProvider -from megatron.bridge.utils.common_utils import hook_hf_module_setattr_for_tp_grad_sync +from megatron.bridge.utils.common_utils import ( + hook_hf_module_setattr_for_tp_grad_sync, + slice_batch_for_context_parallel, +) + + +if TYPE_CHECKING: + from megatron.core.packed_seq_params import PackedSeqParams # Import 
HuggingFace Mistral3 model classes with fallback @@ -185,9 +192,10 @@ def forward( labels: Optional[torch.Tensor] = None, runtime_gather_output: Optional[bool] = None, image_sizes: Optional[torch.Tensor] = None, + packed_seq_params: Optional["PackedSeqParams"] = None, *, loss_mask: Optional[Tensor] = None, - ) -> Tensor: + ) -> tuple[Tensor, Tensor | None]: """ Forward pass combining HuggingFace vision encoder with Megatron language model. @@ -202,7 +210,8 @@ def forward( loss_mask: Mask for loss computation. Returns: - Model output (logits or loss depending on mode). + tuple: (output_tensor, loss_mask) where output_tensor contains model output + and loss_mask is the CP-sliced mask for consistent loss computation. """ if self.pre_process: if inputs_embeds is None: @@ -237,6 +246,18 @@ def forward( # Transpose back to Megatron format [seq_len, batch, hidden] inputs_embeds = inputs_embeds.transpose(1, 0).contiguous() + # CP slicing: slice embeddings, labels, loss_mask, position_ids, and attention_mask + # This must happen AFTER vision-text merge so image token positions are correct + inputs_embeds, labels, loss_mask, position_ids, attention_mask = slice_batch_for_context_parallel( + inputs_embeds=inputs_embeds, + labels=labels, + loss_mask=loss_mask, + position_ids=position_ids, + attention_mask=attention_mask, + packed_seq_params=packed_seq_params, + pg_collection=self.config._pg_collection, + ) + # Forward through Megatron language model outputs = self.language_model.forward( input_ids=None, @@ -246,8 +267,10 @@ def forward( labels=labels, loss_mask=loss_mask, runtime_gather_output=runtime_gather_output, + packed_seq_params=packed_seq_params, ) - return outputs + # Return both outputs and the CP-sliced loss_mask for consistent loss computation + return (outputs, loss_mask) def freeze( self, diff --git a/src/megatron/bridge/models/qwen_vl/modeling_qwen25_vl.py b/src/megatron/bridge/models/qwen_vl/modeling_qwen25_vl.py index bd1447347f..d00fd7ab9f 100644 --- 
a/src/megatron/bridge/models/qwen_vl/modeling_qwen25_vl.py +++ b/src/megatron/bridge/models/qwen_vl/modeling_qwen25_vl.py @@ -213,6 +213,7 @@ def forward( labels=labels, loss_mask=loss_mask, runtime_gather_output=runtime_gather_output, + packed_seq_params=packed_seq_params, ) return outputs diff --git a/src/megatron/bridge/recipes/gemma3_vl/gemma3_vl.py b/src/megatron/bridge/recipes/gemma3_vl/gemma3_vl.py index 1e36a9a8dd..fc3520edd1 100644 --- a/src/megatron/bridge/recipes/gemma3_vl/gemma3_vl.py +++ b/src/megatron/bridge/recipes/gemma3_vl/gemma3_vl.py @@ -234,9 +234,12 @@ def _gemma3_vl_common( model_cfg.freeze_vision_model = freeze_vision_model model_cfg.freeze_vision_projection = freeze_vision_projection model_cfg.seq_length = seq_length + model_cfg.cp_comm_type = "a2a" # Optimizer and scheduler - use finetune_lr if provided, otherwise use lr effective_lr = finetune_lr if finetune_lr is not None else lr + if min_lr > effective_lr: + min_lr = effective_lr * 0.1 opt_config, scheduler = distributed_fused_adam_with_cosine_annealing( lr_warmup_iters=lr_warmup_iters, lr_decay_iters=lr_decay_iters if lr_decay_iters is not None else train_iters, diff --git a/src/megatron/bridge/training/utils/packed_seq_utils.py b/src/megatron/bridge/training/utils/packed_seq_utils.py index 98dbd6d5ac..e631be426a 100644 --- a/src/megatron/bridge/training/utils/packed_seq_utils.py +++ b/src/megatron/bridge/training/utils/packed_seq_utils.py @@ -44,14 +44,18 @@ def get_packed_seq_params(batch: dict[str, torch.Tensor]) -> PackedSeqParams: cu_seqlens_unpadded_argmin = batch.get("cu_seqlens_unpadded_argmin") if cu_seqlens_argmin is not None: - cu_seqlens_padded = cu_seqlens_padded[: cu_seqlens_argmin.item()] - else: + argmin_idx = cu_seqlens_argmin.item() + assert argmin_idx == 0 or cu_seqlens_padded[argmin_idx] == -1 # cu_seqlens padding is -1 + cu_seqlens_padded = cu_seqlens_padded[:argmin_idx] + elif torch.min(cu_seqlens_padded) == -1: cu_seqlens_padded = cu_seqlens_padded[: 
torch.argmin(cu_seqlens_padded)] if cu_seqlens_unpadded is not None: if cu_seqlens_unpadded_argmin is not None: - cu_seqlens_unpadded = cu_seqlens_unpadded[: cu_seqlens_unpadded_argmin.item()] - else: + argmin_idx = cu_seqlens_unpadded_argmin.item() + assert argmin_idx == 0 or cu_seqlens_unpadded[argmin_idx] == -1 # cu_seqlens padding is -1 + cu_seqlens_unpadded = cu_seqlens_unpadded[:argmin_idx] + elif torch.min(cu_seqlens_unpadded) == -1: cu_seqlens_unpadded = cu_seqlens_unpadded[: torch.argmin(cu_seqlens_unpadded)] max_seqlen = batch["max_seqlen"].squeeze() if "max_seqlen" in batch else None diff --git a/src/megatron/bridge/training/vlm_step.py b/src/megatron/bridge/training/vlm_step.py index 1e013eeebd..3c9724ca80 100644 --- a/src/megatron/bridge/training/vlm_step.py +++ b/src/megatron/bridge/training/vlm_step.py @@ -19,7 +19,7 @@ import torch from megatron.core.models.gpt import GPTModel from megatron.core.pipeline_parallel.utils import is_pp_first_stage, is_pp_last_stage -from megatron.core.utils import get_batch_on_this_cp_rank, get_model_config +from megatron.core.utils import get_model_config from megatron.bridge.training.config import ConfigContainer from megatron.bridge.training.losses import ( @@ -97,19 +97,133 @@ def get_batch_from_iterator( return _batch_required_keys -def get_batch( - data_iterator: Iterable, cfg: ConfigContainer, use_mtp: bool = False, *, pg_collection -) -> tuple[ - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - Any, -]: +def pack_batch_sequences( + tokens: torch.Tensor, + labels: torch.Tensor, + loss_mask: torch.Tensor, + attention_mask: torch.Tensor, + position_ids: torch.Tensor, + pad_token_id: int = 0, + pad_to_multiple_of: int = 1, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Pack sequences in a batch by concatenating them and removing padding. 
+ + Args: + tokens: [batch_size, seq_len] + labels: [batch_size, seq_len] + loss_mask: [batch_size, seq_len] + attention_mask: [batch_size, 1, seq_len, seq_len] or None + position_ids: [batch_size, seq_len] + pad_token_id: Token ID used for padding + + Returns: + Tuple of: + - packed_tokens: [1, total_len] - concatenated sequences + - packed_labels: [1, total_len] + - packed_loss_mask: [1, total_len] + - packed_attention_mask: None (not used with packing) + - packed_position_ids: [1, total_len] + - cu_seqlens: [num_sequences + 1] - cumulative sequence lengths + - max_seqlen: tensor - max sequence length in packed batch + """ + batch_size, seq_len = tokens.shape + device = tokens.device + + # Find actual sequence lengths (excluding padding) + # Assuming padding is at the end and uses pad_token_id (0) + seq_lengths = [] + valid_sequences = [] + + for i in range(batch_size): + # Find first padding token or use full length + non_pad_mask = tokens[i] != pad_token_id + if non_pad_mask.any(): + # Find the last non-padding token + last_valid_idx = non_pad_mask.nonzero(as_tuple=True)[0][-1].item() + 1 + else: + # Empty sequence, skip + continue + + seq_lengths.append(last_valid_idx) + valid_sequences.append(i) + + if len(valid_sequences) == 0: + # No valid sequences, return empty packed batch + logger.warning("No valid sequences found in batch, skipping packing") + return ( + tokens[:, :0], # Empty tensor + labels[:, :0], + loss_mask[:, :0], + attention_mask, + position_ids[:, :0], + torch.tensor([0], dtype=torch.int32, device=device), # Empty cu_seqlens + torch.tensor(0, dtype=torch.int32, device=device), + ) + + # Build cumulative sequence lengths + cu_seqlens = [0] + padded_seq_lengths = [] + for length in seq_lengths: + if pad_to_multiple_of > 1: + padded_len = ((length + pad_to_multiple_of - 1) // pad_to_multiple_of) * pad_to_multiple_of + else: + padded_len = length + padded_seq_lengths.append(padded_len) + # Use padded lengths for cu_seqlens so THD RoPE splits sum 
correctly under CP. + cu_seqlens.append(cu_seqlens[-1] + padded_len) + + cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32, device=device) + max_seqlen = torch.tensor(max(padded_seq_lengths), dtype=torch.int32, device=device) + total_len = cu_seqlens[-1].item() + + # Concatenate sequences (remove padding) + packed_tokens = torch.zeros(1, total_len, dtype=tokens.dtype, device=device) + packed_labels = torch.zeros(1, total_len, dtype=labels.dtype, device=device) + packed_loss_mask = torch.zeros(1, total_len, dtype=loss_mask.dtype, device=device) + packed_position_ids = torch.zeros(1, total_len, dtype=position_ids.dtype, device=device) + + offset = 0 + for i, seq_idx in enumerate(valid_sequences): + length = seq_lengths[i] + padded_len = padded_seq_lengths[i] + pad_len = padded_len - length + packed_tokens[0, offset : offset + length] = tokens[seq_idx, :length] + packed_labels[0, offset : offset + length] = labels[seq_idx, :length] + packed_loss_mask[0, offset : offset + length] = loss_mask[seq_idx, :length] + packed_position_ids[0, offset : offset + length] = position_ids[seq_idx, :length] + if pad_len > 0: + packed_tokens[0, offset + length : offset + padded_len] = pad_token_id + packed_labels[0, offset + length : offset + padded_len] = -100 + packed_loss_mask[0, offset + length : offset + padded_len] = 0 + start_pos = position_ids[seq_idx, length - 1] + 1 + packed_position_ids[0, offset + length : offset + padded_len] = torch.arange( + start_pos, + start_pos + pad_len, + device=device, + dtype=position_ids.dtype, + ) + offset += padded_len + + logger.debug( + f"Packed {len(valid_sequences)} sequences: lengths={seq_lengths}, total_len={total_len}, max_len={max_seqlen}" + ) + + # Attention mask is not used with packed sequences (handled by cu_seqlens) + packed_attention_mask = None + + return ( + packed_tokens, + packed_labels, + packed_loss_mask, + packed_attention_mask, + packed_position_ids, + cu_seqlens, + max_seqlen, + ) + + +def get_batch(data_iterator: 
Iterable, cfg: ConfigContainer, use_mtp: bool = False, *, pg_collection) -> tuple[...]: """Generate a batch. Args: @@ -133,75 +247,110 @@ def get_batch( is_first_pp_stage=is_first, is_last_pp_stage=is_last, ) + enable_packing = getattr(cfg.dataset, "pack_sequences_in_batch", False) - # Slice only text tensors for context parallelism - cp_keys = ("tokens", "input_ids", "labels", "loss_mask", "attention_mask", "position_ids") - cp_slice = {k: batch.get(k) for k in cp_keys if k in batch} - cp_slice = get_batch_on_this_cp_rank(cp_slice, cp_group=pg_collection.cp) - for k, v in cp_slice.items(): - batch[k] = v - - # When using pipeline parallelism, ensure fixed shapes equal to cfg.model.seq_length - if getattr(cfg.model, "pipeline_model_parallel_size", 1) > 1: - seq_len = cfg.model.seq_length + if not enable_packing: + # When using pipeline parallelism, ensure fixed shapes equal to cfg.model.seq_length + if getattr(cfg.model, "pipeline_model_parallel_size", 1) > 1: + seq_len = cfg.model.seq_length - tokens_or_input = batch.get("tokens") if batch.get("tokens") is not None else batch.get("input_ids") - tokens_or_input = pad_or_truncate_2d_to_len(tokens_or_input, seq_len, seq_len, pad_value=0) - if batch.get("tokens") is not None: - batch["tokens"] = tokens_or_input # type: ignore[assignment] - else: - batch["input_ids"] = tokens_or_input # type: ignore[assignment] - batch["labels"] = pad_or_truncate_2d_to_len(batch.get("labels"), seq_len, seq_len, pad_value=-100) # type: ignore[assignment] - batch["loss_mask"] = pad_or_truncate_2d_to_len(batch.get("loss_mask"), seq_len, seq_len, pad_value=0) # type: ignore[assignment] - batch["position_ids"] = pad_or_truncate_pos_to_len(batch.get("position_ids"), seq_len, seq_len) # type: ignore[assignment] - if batch.get("attention_mask") is not None: - batch["attention_mask"] = pad_or_truncate_attn_to_len(batch.get("attention_mask"), seq_len, seq_len) # type: ignore[assignment] - else: - # No PP: pad sequence length to nearest multiple 
of 128 for efficiency (capped at model seq_length) - seq_cap = cfg.model.seq_length - - def _ceil_to_mult(n: int, mult: int) -> int: - return ((n + mult - 1) // mult) * mult - - tokens_or_input = batch.get("tokens") if batch.get("tokens") is not None else batch.get("input_ids") - if tokens_or_input is not None: - cur_len = tokens_or_input.size(1) - target_len = min(seq_cap, _ceil_to_mult(cur_len, 128)) - - # tokens/input_ids - padded_tokens = pad_or_truncate_2d_to_len(tokens_or_input, target_len, seq_cap, pad_value=0) + tokens_or_input = batch.get("tokens") if batch.get("tokens") is not None else batch.get("input_ids") + tokens_or_input = pad_or_truncate_2d_to_len(tokens_or_input, seq_len, seq_len, pad_value=0) if batch.get("tokens") is not None: - batch["tokens"] = padded_tokens # type: ignore[assignment] + batch["tokens"] = tokens_or_input # type: ignore[assignment] else: - batch["input_ids"] = padded_tokens # type: ignore[assignment] + batch["input_ids"] = tokens_or_input # type: ignore[assignment] + batch["labels"] = pad_or_truncate_2d_to_len(batch.get("labels"), seq_len, seq_len, pad_value=-100) # type: ignore[assignment] + batch["loss_mask"] = pad_or_truncate_2d_to_len(batch.get("loss_mask"), seq_len, seq_len, pad_value=0) # type: ignore[assignment] + batch["position_ids"] = pad_or_truncate_pos_to_len(batch.get("position_ids"), seq_len, seq_len) # type: ignore[assignment] + if batch.get("attention_mask") is not None: + batch["attention_mask"] = pad_or_truncate_attn_to_len(batch.get("attention_mask"), seq_len, seq_len) # type: ignore[assignment] + else: + # No PP: pad sequence length to nearest multiple of 128 for efficiency (capped at model seq_length) + seq_cap = cfg.model.seq_length - # labels and loss mask - batch["labels"] = pad_or_truncate_2d_to_len(batch.get("labels"), target_len, seq_cap, pad_value=-100) # type: ignore[assignment] - batch["loss_mask"] = pad_or_truncate_2d_to_len(batch.get("loss_mask"), target_len, seq_cap, pad_value=0) # type: 
ignore[assignment] + def _ceil_to_mult(n: int, mult: int) -> int: + return ((n + mult - 1) // mult) * mult - # position_ids: extend with increasing positions - pos = batch.get("position_ids") - pos = pad_or_truncate_pos_to_len(pos, target_len, seq_cap) - if pos is not None: - batch["position_ids"] = pos # type: ignore[assignment] + tokens_or_input = batch.get("tokens") if batch.get("tokens") is not None else batch.get("input_ids") + if tokens_or_input is not None: + cur_len = tokens_or_input.size(1) + target_len = min(seq_cap, _ceil_to_mult(cur_len, 128)) - # attention_mask if present - attn = batch.get("attention_mask") - if attn is not None: - attn = pad_or_truncate_attn_to_len(attn, target_len, seq_cap) - batch["attention_mask"] = attn # type: ignore[assignment] + # tokens/input_ids + padded_tokens = pad_or_truncate_2d_to_len(tokens_or_input, target_len, seq_cap, pad_value=0) + if batch.get("tokens") is not None: + batch["tokens"] = padded_tokens # type: ignore[assignment] + else: + batch["input_ids"] = padded_tokens # type: ignore[assignment] + + # labels and loss mask + batch["labels"] = pad_or_truncate_2d_to_len(batch.get("labels"), target_len, seq_cap, pad_value=-100) # type: ignore[assignment] + batch["loss_mask"] = pad_or_truncate_2d_to_len( + batch.get("loss_mask"), target_len, seq_cap, pad_value=0 + ) # type: ignore[assignment] + + # position_ids: extend with increasing positions + pos = batch.get("position_ids") + pos = pad_or_truncate_pos_to_len(pos, target_len, seq_cap) + if pos is not None: + batch["position_ids"] = pos # type: ignore[assignment] + + # attention_mask if present + attn = batch.get("attention_mask") + if attn is not None: + attn = pad_or_truncate_attn_to_len(attn, target_len, seq_cap) + batch["attention_mask"] = attn # type: ignore[assignment] visual_inputs = batch.get("visual_inputs") + cp_size = pg_collection.cp.size() if pg_collection is not None and pg_collection.cp is not None else 1 + + if enable_packing: + # Pack sequences + 
tokens_or_input = batch.get("tokens") if batch.get("tokens") is not None else batch.get("input_ids") + ( + packed_tokens, + packed_labels, + packed_loss_mask, + packed_attention_mask, + packed_position_ids, + cu_seqlens, + max_seqlen, + ) = pack_batch_sequences( + tokens=tokens_or_input, + labels=batch.get("labels"), + loss_mask=batch.get("loss_mask"), + attention_mask=batch.get("attention_mask"), + position_ids=batch.get("position_ids"), + pad_token_id=0, + pad_to_multiple_of=cp_size * 2 if cp_size > 1 else 1, + ) + + # Update batch dict with packed tensors + if batch.get("tokens") is not None: + batch["tokens"] = packed_tokens + else: + batch["input_ids"] = packed_tokens + batch["labels"] = packed_labels + batch["loss_mask"] = packed_loss_mask + batch["attention_mask"] = packed_attention_mask + batch["position_ids"] = packed_position_ids + + # # Add packing metadata + logger.debug(f"Packed batch: cu_seqlens={cu_seqlens.tolist()}, max_seqlen={max_seqlen}") + else: + # No packing, use dummy values + cu_seqlens = None + max_seqlen = None return ( (batch.get("tokens") if batch.get("tokens") is not None else batch.get("input_ids")), batch.get("labels"), - batch.get("loss_mask"), + batch.get("loss_mask"), # Full packed loss_mask, will be CP-sliced by model batch.get("attention_mask"), batch.get("position_ids"), - batch.get("cu_seqlens"), - batch.get("cu_seqlens_argmin"), - batch.get("max_seqlen"), + cu_seqlens, + max_seqlen, visual_inputs, ) @@ -227,6 +376,7 @@ def forward_step( use_mtp = (getattr(config, "mtp_num_layers", None) or 0) > 0 timers("batch-generator", log_level=2).start() + pg_collection = get_pg_collection(model) with straggler_timer(bdata=True): ( tokens, @@ -235,10 +385,9 @@ def forward_step( attention_mask, position_ids, cu_seqlens, - cu_seqlens_argmin, max_seqlen, visual_inputs, - ) = get_batch(data_iterator, state.cfg, use_mtp, pg_collection=get_pg_collection(model)) + ) = get_batch(data_iterator, state.cfg, use_mtp, pg_collection=pg_collection) 
timers("batch-generator").stop() forward_args = { @@ -246,6 +395,7 @@ def forward_step( "position_ids": position_ids, "attention_mask": attention_mask, "labels": labels, + "loss_mask": loss_mask, # Pass full loss_mask so model can slice it consistently with labels } if visual_inputs is not None: @@ -255,7 +405,6 @@ def forward_step( if cu_seqlens is not None: packed_seq_params = { "cu_seqlens": cu_seqlens, - "cu_seqlens_argmin": cu_seqlens_argmin, "max_seqlen": max_seqlen, } forward_args["packed_seq_params"] = get_packed_seq_params(packed_seq_params) @@ -273,7 +422,12 @@ def forward_step( loss_function = _create_loss_function(loss_mask, check_for_nan_in_loss, check_for_spiky_loss) return schedule_plan, loss_function else: - output_tensor = model(**forward_args) + model_output = model(**forward_args) + # Handle tuple return: (output_tensor, sliced_loss_mask) from VLM models with CP + if isinstance(model_output, tuple): + output_tensor, loss_mask = model_output + else: + output_tensor = model_output loss_function = _create_loss_function(loss_mask, check_for_nan_in_loss, check_for_spiky_loss) diff --git a/src/megatron/bridge/utils/common_utils.py b/src/megatron/bridge/utils/common_utils.py index 59ac803bf5..1946230bd6 100644 --- a/src/megatron/bridge/utils/common_utils.py +++ b/src/megatron/bridge/utils/common_utils.py @@ -271,3 +271,97 @@ def resolve_path(path: str) -> Path: """Resolve a path to an absolute path.""" return Path(path).expanduser().absolute().resolve() + + +def slice_batch_for_context_parallel( + inputs_embeds: torch.Tensor, + labels: torch.Tensor, + loss_mask: torch.Tensor, + position_ids: torch.Tensor, + attention_mask: torch.Tensor, + packed_seq_params, + pg_collection, +): + """Slice batch tensors for Context Parallelism (CP) in VLM models. + + This function handles CP slicing AFTER vision-text embedding merge, ensuring + image token positions are correctly preserved. 
It supports both: + - THD format (packed sequences): Uses TransformerEngine's thd_get_partitioned_indices + - BSHD format: Uses Megatron's get_batch_on_this_cp_rank with zigzag pattern + + Args: + inputs_embeds: Input embeddings tensor in (T, B, D) format. + labels: Labels tensor. + loss_mask: Loss mask tensor. + position_ids: Position IDs tensor. + attention_mask: Attention mask tensor. + packed_seq_params: PackedSeqParams for THD format, or None for BSHD. + pg_collection: ProcessGroupCollection containing CP group info. + + Returns: + Tuple of (inputs_embeds, labels, loss_mask, position_ids, attention_mask) + with all tensors sliced for this CP rank. inputs_embeds remains in (T, B, D) format. + """ + from megatron.core.utils import get_batch_on_this_cp_rank + + cp_size = pg_collection.cp.size() + if cp_size <= 1: + return inputs_embeds, labels, loss_mask, position_ids, attention_mask + + cp_rank = pg_collection.cp.rank() + + # (T, B, D) -> (B, T, D) for slicing + if inputs_embeds is not None: + inputs_embeds = inputs_embeds.transpose(0, 1).contiguous() + + # For THD (packed) format, use TE's thd_get_partitioned_indices + # This properly slices WITHIN each packed sequence, not across them + if packed_seq_params is not None and packed_seq_params.qkv_format == "thd": + import transformer_engine_torch as tex + + if inputs_embeds is None: + raise ValueError("inputs_embeds is required for THD CP slicing") + + cu_seqlens = packed_seq_params.cu_seqlens_q + cu_seqlens_padded = ( + packed_seq_params.cu_seqlens_q_padded if packed_seq_params.cu_seqlens_q_padded is not None else cu_seqlens + ) + seq_len = inputs_embeds.size(1) + + index = tex.thd_get_partitioned_indices(cu_seqlens_padded, seq_len, cp_size, cp_rank) + + # Slice all tensors using THD indices + if inputs_embeds is not None: + inputs_embeds = inputs_embeds.index_select(1, index) + if labels is not None: + labels = labels.index_select(1, index) + if loss_mask is not None: + loss_mask = loss_mask.index_select(1, 
index) + if position_ids is not None: + position_ids = position_ids.index_select(1, index) + # Note: attention_mask and packed_seq_params stay unchanged for ring attention + else: + # For BSHD format, use standard zigzag slicing + cp_group = pg_collection.cp + cp_batch = get_batch_on_this_cp_rank( + { + "decoder_input": inputs_embeds, + "labels": labels, + "loss_mask": loss_mask, + "position_ids": position_ids, + "attention_mask": attention_mask, + }, + cp_group=cp_group, + ) + + inputs_embeds = cp_batch.get("decoder_input") + labels = cp_batch.get("labels") + loss_mask = cp_batch.get("loss_mask") + position_ids = cp_batch.get("position_ids") + attention_mask = cp_batch.get("attention_mask") + + # Transpose back to (T, B, D) + if inputs_embeds is not None: + inputs_embeds = inputs_embeds.transpose(0, 1).contiguous() + + return inputs_embeds, labels, loss_mask, position_ids, attention_mask diff --git a/tests/functional_tests/L2_Launch_recipes_gemma_vl.sh b/tests/functional_tests/L2_Launch_recipes_gemma_vl.sh new file mode 100644 index 0000000000..18870efe9d --- /dev/null +++ b/tests/functional_tests/L2_Launch_recipes_gemma_vl.sh @@ -0,0 +1,31 @@ +#!/bin/bash +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +set -xeuo pipefail # Exit immediately if a command exits with a non-zero status + +export CUDA_VISIBLE_DEVICES="0,1" + +# Run Gemma3-VL recipe functional tests on 2 GPUs +# This script tests Gemma3-VL finetune recipe configurations with their default +# settings to ensure they can run basic training without crashes. +uv run python -m torch.distributed.run --nproc_per_node=1 --nnodes=1 \ + -m coverage run --data-file=/opt/Megatron-Bridge/.coverage \ + --source=/opt/Megatron-Bridge/ --parallel-mode \ + -m pytest -o log_cli=true -o log_cli_level=INFO -v -s -x \ + -m "not pleasefixme" --tb=short -rA \ + tests/functional_tests/recipes/test_gemma3_vl_recipes_finetune.py + +coverage combine -q + diff --git a/tests/functional_tests/models/ministral3_vl/__init__.py b/tests/functional_tests/models/ministral3_vl/__init__.py new file mode 100644 index 0000000000..341a77c5bc --- /dev/null +++ b/tests/functional_tests/models/ministral3_vl/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/functional_tests/models/ministral3_vl/test_ministral3_vl_conversion.py b/tests/functional_tests/models/ministral3_vl/test_ministral3_vl_conversion.py new file mode 100644 index 0000000000..1d2d0df65c --- /dev/null +++ b/tests/functional_tests/models/ministral3_vl/test_ministral3_vl_conversion.py @@ -0,0 +1,257 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import subprocess +from pathlib import Path + +import pytest +import torch + + +# Ministral 3 VL toy model configuration based on typical Ministral 3 VL structure +# This is a minimized version for testing purposes +HF_MINISTRAL3_VL_TOY_MODEL_CONFIG = { + "architectures": ["Mistral3ForConditionalGeneration"], + "model_type": "mistral3", + "torch_dtype": "bfloat16", + "transformers_version": "5.0.0", + "image_token_index": 10, + "text_config": { + "model_type": "mistral3_text", + "hidden_size": 512, + "intermediate_size": 1536, # 3 * hidden_size for FFN + "num_hidden_layers": 4, + "num_attention_heads": 8, + "num_key_value_heads": 2, # GQA with 4 query groups + "vocab_size": 32768, + "max_position_embeddings": 4096, + "rms_norm_eps": 1e-5, + "tie_word_embeddings": True, + "rope_theta": 1000000, + "rope_parameters": { + "rope_type": "yarn", + "factor": 16.0, + "original_max_position_embeddings": 16384, + "llama_4_scaling_beta": 0.0, + }, + }, + "vision_config": { + "model_type": "pixtral", + "hidden_size": 256, + "intermediate_size": 1024, + "num_hidden_layers": 4, + "num_attention_heads": 4, + "image_size": 448, + "patch_size": 14, + "num_channels": 3, + }, + "spatial_merge_size": 2, + "vision_feature_layer": -1, +} + + +class TestMinistral3VLConversion: + """ + Test Ministral 3 VL model conversion from local HuggingFace model with different parallelism configurations. 
+ """ + + @pytest.fixture(scope="class") + def ministral3_vl_toy_model_path(self, tmp_path_factory): + """ + Create and save a HuggingFace Ministral 3 VL toy model from config to a temporary directory. + + Args: + tmp_path_factory: Pytest temporary path factory for class-scoped fixtures + + Returns: + str: Path to the saved HuggingFace model directory + """ + # Skip if transformers doesn't have Mistral3 support + pytest.importorskip("transformers", minversion="5.0.0") + + try: + from transformers import Mistral3ForConditionalGeneration + from transformers.models.mistral3.configuration_mistral3 import Mistral3Config + except ImportError: + pytest.skip("Mistral3ForConditionalGeneration not available in transformers") + + # Create a temporary directory for this test class + temp_dir = tmp_path_factory.mktemp("ministral3_vl_toy_model") + model_dir = temp_dir / "ministral3_vl_toy" + + # Create config from the toy model config + config_dict = HF_MINISTRAL3_VL_TOY_MODEL_CONFIG.copy() + + # Create config object + config = Mistral3Config(**config_dict) + config.torch_dtype = torch.bfloat16 + + # Create model with random weights and convert to bfloat16 + model = Mistral3ForConditionalGeneration(config) + model = model.bfloat16() + + # Debug: Check model dtype before saving + for name, param in model.named_parameters(): + print(f"Before save - {name}: {param.dtype}") + break # Just check the first parameter + + # Create minimal tokenizer files + tokenizer_config = { + "tokenizer_class": "LlamaTokenizer", + "vocab_size": 32768, + "bos_token": "", + "eos_token": "", + "pad_token": "", + "unk_token": "", + } + + model_dir.mkdir(parents=True, exist_ok=True) + with open(model_dir / "tokenizer_config.json", "w") as f: + json.dump(tokenizer_config, f, indent=2) + + # Save model and config to directory + model.save_pretrained(model_dir, safe_serialization=True) + + # Save config.json explicitly to ensure compatibility + config_path = model_dir / "config.json" + with open(config_path, 
"w") as f: + json.dump(config_dict, f, indent=2) + + return str(model_dir) + + def test_toy_model_creation(self, ministral3_vl_toy_model_path): + """ + Test that the toy model is created correctly and can be loaded. + + Args: + ministral3_vl_toy_model_path: Path to the toy Ministral 3 VL model (from fixture) + """ + # Verify the model directory exists + model_path = Path(ministral3_vl_toy_model_path) + assert model_path.exists(), f"Model directory not found at {model_path}" + + # Check essential files exist + config_file = model_path / "config.json" + assert config_file.exists(), f"config.json not found at {config_file}" + + # Check for model weights (safetensors preferred) + weights_file = model_path / "model.safetensors" + if not weights_file.exists(): + weights_file = model_path / "pytorch_model.bin" + assert weights_file.exists(), f"Model weights file not found in {model_path}" + + # Check for tokenizer files + tokenizer_config_file = model_path / "tokenizer_config.json" + assert tokenizer_config_file.exists(), f"tokenizer_config.json not found at {tokenizer_config_file}" + + # Load and verify config + with open(config_file) as f: + config_data = json.load(f) + + assert config_data["model_type"] == "mistral3" + assert "text_config" in config_data + assert "vision_config" in config_data + assert config_data["text_config"]["hidden_size"] == 512 + assert config_data["text_config"]["num_hidden_layers"] == 4 + assert config_data["text_config"]["num_attention_heads"] == 8 + assert config_data["vision_config"]["hidden_size"] == 256 + + @pytest.mark.run_only_on("GPU") + @pytest.mark.parametrize( + "tp,pp,test_name", + [ + (2, 1, "TP"), + (1, 2, "PP"), + ], + ) + def test_ministral3_vl_conversion_parallelism(self, ministral3_vl_toy_model_path, tmp_path, tp, pp, test_name): + """ + Test Ministral 3 VL model conversion with different parallelism configurations. 
+ + Args: + ministral3_vl_toy_model_path: Path to the toy Ministral 3 VL model (from fixture) + tmp_path: Pytest temporary path fixture + tp: Tensor parallelism size + pp: Pipeline parallelism size + test_name: Name of the test for identification + """ + + # Create temporary output directory for conversion results + test_output_dir = tmp_path / f"ministral3_vl_{test_name}" + test_output_dir.mkdir(exist_ok=True) + + # Run hf_megatron_roundtrip_multi_gpu.py with specified parallelism configuration on our toy model + cmd = [ + "python", + "-m", + "torch.distributed.run", + "--nproc_per_node=2", + "--nnodes=1", + "-m", + "coverage", + "run", + "--data-file=/opt/Megatron-Bridge/.coverage", + "--source=/opt/Megatron-Bridge/", + "--parallel-mode", + "examples/conversion/hf_megatron_roundtrip_multi_gpu.py", + "--hf-model-id", + ministral3_vl_toy_model_path, # Use our local toy model instead of downloading + "--output-dir", + str(test_output_dir), + "--tp", + str(tp), + "--pp", + str(pp), + ] + + result = subprocess.run( + cmd, capture_output=True, text=True, cwd=Path(__file__).parent.parent.parent.parent.parent + ) + print(cmd) + + # Check that the conversion completed successfully + if result.returncode != 0: + print(f"STDOUT: {result.stdout}") + print(f"STDERR: {result.stderr}") + assert False, f"Ministral 3 VL {test_name} conversion failed with return code {result.returncode}" + + # Verify that the converted model was saved + # The output directory should be named after the last part of the model path + model_name = Path(ministral3_vl_toy_model_path).name # "ministral3_vl_toy" + converted_model_dir = test_output_dir / model_name + assert converted_model_dir.exists(), f"Converted model directory not found at {converted_model_dir}" + + # Check that essential model files exist + config_file = converted_model_dir / "config.json" + assert config_file.exists(), f"config.json not found in converted model at {config_file}" + + # Check for model weights file (could be either 
safetensors or pytorch_model.bin) + weights_file_safetensors = converted_model_dir / "model.safetensors" + weights_file_pytorch = converted_model_dir / "pytorch_model.bin" + assert weights_file_safetensors.exists() or weights_file_pytorch.exists(), ( + f"Model weights file not found in converted model at {converted_model_dir}" + ) + + # Verify the config contains Ministral 3 VL-specific parameters + with open(config_file) as f: + saved_config = json.load(f) + + assert saved_config["model_type"] == "mistral3", "Model type should be mistral3" + assert "text_config" in saved_config, "VL model should have text_config" + assert "vision_config" in saved_config, "VL model should have vision_config" + assert saved_config["text_config"]["hidden_size"] == 512, "Hidden size should match toy config" + assert saved_config["text_config"]["num_attention_heads"] == 8, ( + "Number of attention heads should match toy config" + ) diff --git a/tests/functional_tests/recipes/test_gemma3_vl_recipes_finetune.py b/tests/functional_tests/recipes/test_gemma3_vl_recipes_finetune.py index 6d4ae51323..3b94e2f84c 100644 --- a/tests/functional_tests/recipes/test_gemma3_vl_recipes_finetune.py +++ b/tests/functional_tests/recipes/test_gemma3_vl_recipes_finetune.py @@ -36,7 +36,7 @@ class TestGemma3VLRecipes: """Test class for Gemma3-VL recipe functional tests.""" @pytest.mark.run_only_on("GPU") - @pytest.mark.parametrize("config_func,recipe_name,parallelism_overrides", GEMMA3_VL_FINETUNE_RECIPES) + @pytest.mark.parametrize("config_func,recipe_name,model_overrides", GEMMA3_VL_FINETUNE_RECIPES) def test_gemma3_vl_finetune_recipes(self, config_func, recipe_name, model_overrides, tmp_path): """Functional test for Gemma3-VL recipes with appropriate parallelism configurations.""" run_pretrain_vl_recipe_test(config_func, recipe_name, tmp_path, model_overrides=model_overrides) diff --git a/tests/unit_tests/models/gemma/test_gemma3_provider.py b/tests/unit_tests/models/gemma/test_gemma3_provider.py index 
1e842ad42d..a06d07c104 100644 --- a/tests/unit_tests/models/gemma/test_gemma3_provider.py +++ b/tests/unit_tests/models/gemma/test_gemma3_provider.py @@ -416,6 +416,86 @@ def test_gemma3_rotary_embedding_initialization(self): # Verify that RotaryEmbedding was called for local rope assert mock_rotary_embedding.call_count >= 1 + def test_gemma3_rotary_embedding_forward_with_cp_group(self): + """Test Gemma3RotaryEmbedding forward method with cp_group (non-None path).""" + # Create a minimal Gemma3RotaryEmbedding instance via __new__ to avoid complex init + rope_emb = Gemma3RotaryEmbedding.__new__(Gemma3RotaryEmbedding) + + # Mock the rope_local attribute + mock_rope_local = Mock() + mock_rope_local.forward = Mock(return_value=torch.tensor([1.0, 2.0])) + rope_emb.rope_local = mock_rope_local + + # Mock the parent class forward method (called via super().forward) + mock_global_output = torch.tensor([3.0, 4.0]) + mock_local_output = torch.tensor([1.0, 2.0]) + + # Create a mock cp_group (ProcessGroup) + mock_cp_group = Mock() + + with patch.object( + Gemma3RotaryEmbedding.__bases__[0], "forward", return_value=mock_global_output + ) as mock_super_forward: + result = rope_emb.forward(max_seq_len=1024, offset=0, packed_seq=False, cp_group=mock_cp_group) + + # Verify super().forward was called with cp_group + mock_super_forward.assert_called_once_with(1024, 0, False, mock_cp_group) + + # Verify rope_local.forward was called with cp_group + mock_rope_local.forward.assert_called_once_with(1024, 0, False, mock_cp_group) + + # Verify return is (rope_local, rope_global) tuple + assert isinstance(result, tuple) + assert len(result) == 2 + rope_local_result, rope_global_result = result + assert torch.equal(rope_local_result, mock_local_output) + assert torch.equal(rope_global_result, mock_global_output) + + def test_gemma3_rotary_embedding_forward_without_cp_group(self): + """Test Gemma3RotaryEmbedding forward method without cp_group (cached path).""" + # Create a minimal 
Gemma3RotaryEmbedding instance via __new__ + rope_emb = Gemma3RotaryEmbedding.__new__(Gemma3RotaryEmbedding) + + # Mock the _forward_cached method + mock_cached_result = (torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0])) + rope_emb._forward_cached = Mock(return_value=mock_cached_result) + + # Call forward without cp_group (None) + result = rope_emb.forward(max_seq_len=1024, offset=0, packed_seq=False, cp_group=None) + + # Verify _forward_cached was called + rope_emb._forward_cached.assert_called_once_with(1024, 0, False) + + # Verify result matches cached result + assert result == mock_cached_result + + def test_gemma3_rotary_embedding_forward_cached(self): + """Test Gemma3RotaryEmbedding _forward_cached method.""" + # Create a minimal Gemma3RotaryEmbedding instance via __new__ + rope_emb = Gemma3RotaryEmbedding.__new__(Gemma3RotaryEmbedding) + + # Mock the rope_local attribute + mock_rope_local = Mock() + mock_rope_local.forward = Mock(return_value=torch.tensor([1.0, 2.0])) + rope_emb.rope_local = mock_rope_local + + mock_global_output = torch.tensor([3.0, 4.0]) + + with patch.object( + Gemma3RotaryEmbedding.__bases__[0], "forward", return_value=mock_global_output + ) as mock_super_forward: + result = rope_emb._forward_cached(max_seq_len=512, offset=10, packed_seq=True) + + # Verify super().forward was called with cp_group=None + mock_super_forward.assert_called_once_with(512, 10, True, None) + + # Verify rope_local.forward was called with cp_group=None + mock_rope_local.forward.assert_called_once_with(512, 10, True, None) + + # Verify return is (rope_local, rope_global) tuple + assert isinstance(result, tuple) + assert len(result) == 2 + def test_te_row_parallel_linear_layer_norm(self): """Test TERowParallelLinearLayerNorm initialization and forward.""" # Test that the class exists and can be imported diff --git a/tests/unit_tests/models/ministral3/test_ministral3_provider.py b/tests/unit_tests/models/ministral3/test_ministral3_provider.py index 
01564a61fc..c8058e1343 100644 --- a/tests/unit_tests/models/ministral3/test_ministral3_provider.py +++ b/tests/unit_tests/models/ministral3/test_ministral3_provider.py @@ -12,15 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. +import pytest +import torch from megatron.bridge.models.ministral3.ministral3_provider import ( Ministral3ModelProvider, Ministral3ModelProvider3B, Ministral3ModelProvider8B, Ministral3ModelProvider14B, + MinistralTEDotProductAttention, ) +pytestmark = pytest.mark.unit + + class TestMinistral3ModelProvider: """Test cases for Ministral3ModelProvider base class.""" @@ -172,3 +178,166 @@ def test_ministral3_14b_initialization(self): assert provider.ffn_hidden_size == 16384 assert provider.num_layers == 40 assert provider.rotary_base == 1000000000.0 + + +class TestGetLlama4AttnScale: + """Test cases for _get_llama_4_attn_scale function used in MinistralTEDotProductAttention. + + This function computes attention scaling based on Llama 4 attention parameters. + The key change in PR 1997 is that it now handles different query shapes for + packed (3D) vs unpacked (4D) tensors. 
+ """ + + # Use the actual production implementation + _get_llama_4_attn_scale = staticmethod(MinistralTEDotProductAttention._get_llama_4_attn_scale) + + def test_unpacked_4d_query_shape(self): + """Test attention scaling with unpacked 4D query shape [seq_len, batch, num_heads, head_dim].""" + seq_len = 8 + batch_size = 2 + num_heads = 4 + head_dim = 64 + + positions_ids = torch.arange(seq_len) + beta = 0.1 + max_position_embeddings = 16384 + query_shape = (seq_len, batch_size, num_heads, head_dim) + + scaling = self._get_llama_4_attn_scale(positions_ids, beta, max_position_embeddings, query_shape) + + # Output should have shape [seq_len, 1, 1, 1] for broadcasting + assert scaling.shape == (seq_len, 1, 1, 1) + + # First position should have scaling = 1 (since log(1 + 0) = 0) + expected_first = 1 + beta * torch.log(torch.tensor(1.0)) + assert torch.isclose(scaling[0, 0, 0, 0], expected_first, atol=1e-6) + + def test_packed_3d_query_shape(self): + """Test attention scaling with packed 3D query shape [seq_len, num_heads, head_dim].""" + seq_len = 16 + num_heads = 8 + head_dim = 32 + + positions_ids = torch.arange(seq_len) + beta = 0.2 + max_position_embeddings = 8192 + query_shape = (seq_len, num_heads, head_dim) + + scaling = self._get_llama_4_attn_scale(positions_ids, beta, max_position_embeddings, query_shape) + + # Output should have shape [seq_len, 1, 1] for broadcasting (3D - 1 = 2 dims added) + assert scaling.shape == (seq_len, 1, 1) + + # Verify scaling values are computed correctly + expected = 1 + beta * torch.log(1 + torch.floor(positions_ids / max_position_embeddings)) + assert torch.allclose(scaling.squeeze(), expected, atol=1e-6) + + def test_scaling_formula_correctness(self): + """Test that the scaling formula matches expected Llama 4 attention scaling.""" + positions_ids = torch.tensor([0, 1, 100, 1000, 16384, 32768]) + beta = 0.15 + max_position_embeddings = 16384 + query_shape = (6, 1, 1, 1) + + scaling = self._get_llama_4_attn_scale(positions_ids, 
beta, max_position_embeddings, query_shape) + + # Manual computation of expected values + # For position 0: 1 + 0.15 * log(1 + 0) = 1 + # For position 16384: 1 + 0.15 * log(1 + 1) = 1 + 0.15 * log(2) + # For position 32768: 1 + 0.15 * log(1 + 2) = 1 + 0.15 * log(3) + + expected_0 = 1.0 + expected_16384 = 1 + beta * torch.log(torch.tensor(2.0)) + expected_32768 = 1 + beta * torch.log(torch.tensor(3.0)) + + assert torch.isclose(scaling[0].squeeze(), torch.tensor(expected_0), atol=1e-6) + assert torch.isclose(scaling[4].squeeze(), expected_16384, atol=1e-6) + assert torch.isclose(scaling[5].squeeze(), expected_32768, atol=1e-6) + + def test_beta_zero_returns_ones(self): + """Test that beta=0 returns all ones (no scaling).""" + positions_ids = torch.arange(10) + beta = 0.0 + max_position_embeddings = 4096 + query_shape = (10, 4, 64) + + scaling = self._get_llama_4_attn_scale(positions_ids, beta, max_position_embeddings, query_shape) + + assert torch.allclose(scaling.squeeze(), torch.ones(10), atol=1e-6) + + def test_different_query_shapes_get_correct_dims(self): + """Test that different query shapes result in correct number of dimensions added.""" + positions_ids = torch.arange(4) + beta = 0.1 + max_position_embeddings = 1000 + + # 2D query shape + query_shape_2d = (4, 32) + scaling_2d = self._get_llama_4_attn_scale(positions_ids, beta, max_position_embeddings, query_shape_2d) + assert scaling_2d.shape == (4, 1) # 2-1 = 1 dim added + + # 3D query shape (packed THD) + query_shape_3d = (4, 8, 32) + scaling_3d = self._get_llama_4_attn_scale(positions_ids, beta, max_position_embeddings, query_shape_3d) + assert scaling_3d.shape == (4, 1, 1) # 3-1 = 2 dims added + + # 4D query shape (unpacked BSHD) + query_shape_4d = (4, 2, 8, 32) + scaling_4d = self._get_llama_4_attn_scale(positions_ids, beta, max_position_embeddings, query_shape_4d) + assert scaling_4d.shape == (4, 1, 1, 1) # 4-1 = 3 dims added + + def test_broadcasting_compatibility(self): + """Test that scaling tensor 
is broadcastable to query tensor.""" + seq_len = 8 + num_heads = 4 + head_dim = 64 + + positions_ids = torch.arange(seq_len) + beta = 0.1 + max_position_embeddings = 16384 + + # Test for 3D packed format + query_3d = torch.randn(seq_len, num_heads, head_dim) + scaling_3d = self._get_llama_4_attn_scale(positions_ids, beta, max_position_embeddings, query_3d.shape) + + # Broadcasting should work + result_3d = query_3d * scaling_3d.to(query_3d.dtype) + assert result_3d.shape == query_3d.shape + + # Test for 4D unpacked format + batch = 2 + query_4d = torch.randn(seq_len, batch, num_heads, head_dim) + scaling_4d = self._get_llama_4_attn_scale(positions_ids, beta, max_position_embeddings, query_4d.shape) + + # Broadcasting should work + result_4d = query_4d * scaling_4d.to(query_4d.dtype) + assert result_4d.shape == query_4d.shape + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA") + def test_gpu_tensor_support(self): + """Test that the function works with GPU tensors if available.""" + positions_ids = torch.arange(8, device="cuda") + beta = 0.1 + max_position_embeddings = 1024 + query_shape = (8, 4, 32) + + scaling = self._get_llama_4_attn_scale(positions_ids, beta, max_position_embeddings, query_shape) + + assert scaling.device.type == "cuda" + assert scaling.shape == (8, 1, 1) + + def test_dtype_preservation(self): + """Test that output dtype matches input positions_ids dtype.""" + positions_ids_float32 = torch.arange(4, dtype=torch.float32) + positions_ids_float64 = torch.arange(4, dtype=torch.float64) + beta = 0.1 + max_position_embeddings = 100 + query_shape = (4, 2, 8) + + scaling_32 = self._get_llama_4_attn_scale(positions_ids_float32, beta, max_position_embeddings, query_shape) + scaling_64 = self._get_llama_4_attn_scale(positions_ids_float64, beta, max_position_embeddings, query_shape) + + # Note: torch.arange with int creates int tensors, but the function uses float operations + # The scaling result will be float due to log operation 
+ assert scaling_32.dtype == torch.float32 + assert scaling_64.dtype == torch.float64 diff --git a/tests/unit_tests/training/test_gpt_step.py b/tests/unit_tests/training/test_gpt_step.py index 141fd9827f..d677c1334f 100644 --- a/tests/unit_tests/training/test_gpt_step.py +++ b/tests/unit_tests/training/test_gpt_step.py @@ -115,9 +115,8 @@ def test_packed_seq_params_no_padding(self): # Verify the result is a PackedSeqParams object assert isinstance(result, PackedSeqParams) - # When there's no -1 padding, argmin returns 0 (index of min value) - # So cu_seqlens[:0] returns empty tensor - expected_cu_seqlens = torch.empty(0, dtype=torch.int32) # Empty tensor + # When there's no -1 padding, the tensor is returned unchanged + expected_cu_seqlens = torch.tensor([0, 7, 14], dtype=torch.int32) assert torch.equal(result.cu_seqlens_q, expected_cu_seqlens) assert torch.equal(result.cu_seqlens_kv, expected_cu_seqlens) @@ -181,6 +180,113 @@ def test_packed_seq_params_all_fields_match(self): assert torch.equal(result.cu_seqlens_q, result.cu_seqlens_kv) assert torch.equal(result.max_seqlen_q, result.max_seqlen_kv) + def test_packed_seq_params_with_cu_seqlens_unpadded(self): + """Test functionality with cu_seqlens_unpadded for THD CP support.""" + # Padded cu_seqlens (includes padding for CP divisibility) + cu_seqlens_padded = torch.tensor([[0, 8, 16, -1, -1]], dtype=torch.int32) + # Unpadded cu_seqlens (actual sequence boundaries) + cu_seqlens_unpadded = torch.tensor([[0, 6, 14, -1, -1]], dtype=torch.int32) + + batch = { + "cu_seqlens": cu_seqlens_padded, + "cu_seqlens_unpadded": cu_seqlens_unpadded, + "max_seqlen": torch.tensor([[10]], dtype=torch.int32), + } + + result = get_packed_seq_params(batch) + + # cu_seqlens_q and cu_seqlens_kv should use unpadded values + expected_unpadded = torch.tensor([0, 6, 14], dtype=torch.int32) + assert torch.equal(result.cu_seqlens_q, expected_unpadded) + assert torch.equal(result.cu_seqlens_kv, expected_unpadded) + + # cu_seqlens_q_padded and 
cu_seqlens_kv_padded should use padded values + expected_padded = torch.tensor([0, 8, 16], dtype=torch.int32) + assert torch.equal(result.cu_seqlens_q_padded, expected_padded) + assert torch.equal(result.cu_seqlens_kv_padded, expected_padded) + + def test_packed_seq_params_cu_seqlens_unpadded_with_argmin(self): + """Test cu_seqlens_unpadded processing with argmin hint.""" + batch = { + "cu_seqlens": torch.tensor([[0, 4, 8, 12, -1, -1]], dtype=torch.int32), + "cu_seqlens_argmin": torch.tensor(4), # Index where -1 starts + "cu_seqlens_unpadded": torch.tensor([[0, 3, 7, 10, -1, -1]], dtype=torch.int32), + "cu_seqlens_unpadded_argmin": torch.tensor(4), # Index where -1 starts + } + + result = get_packed_seq_params(batch) + + # Verify unpadded values are used for q/kv + expected_unpadded = torch.tensor([0, 3, 7, 10], dtype=torch.int32) + assert torch.equal(result.cu_seqlens_q, expected_unpadded) + assert torch.equal(result.cu_seqlens_kv, expected_unpadded) + + # Verify padded values are set for _padded fields + expected_padded = torch.tensor([0, 4, 8, 12], dtype=torch.int32) + assert torch.equal(result.cu_seqlens_q_padded, expected_padded) + assert torch.equal(result.cu_seqlens_kv_padded, expected_padded) + + def test_packed_seq_params_without_unpadded_fallback(self): + """Test fallback to cu_seqlens when cu_seqlens_unpadded is not provided.""" + batch = { + "cu_seqlens": torch.tensor([[0, 5, 10, 15, -1]], dtype=torch.int32), + "max_seqlen": torch.tensor([[8]], dtype=torch.int32), + } + + result = get_packed_seq_params(batch) + + expected_cu_seqlens = torch.tensor([0, 5, 10, 15], dtype=torch.int32) + + # Without unpadded, q/kv should use padded values + assert torch.equal(result.cu_seqlens_q, expected_cu_seqlens) + assert torch.equal(result.cu_seqlens_kv, expected_cu_seqlens) + + # Padded fields should be None when cu_seqlens_unpadded is not provided + # (to avoid slower TE kernel paths) + assert result.cu_seqlens_q_padded is None + assert result.cu_seqlens_kv_padded is 
None + + def test_packed_seq_params_no_padding_in_cu_seqlens(self): + """Test when cu_seqlens has no -1 padding markers.""" + batch = { + "cu_seqlens": torch.tensor([[0, 5, 10]], dtype=torch.int32), # No -1 padding + "max_seqlen": torch.tensor([[7]], dtype=torch.int32), + } + + result = get_packed_seq_params(batch) + + # When no -1 present and min != -1, the tensor should remain as-is + expected = torch.tensor([0, 5, 10], dtype=torch.int32) + assert torch.equal(result.cu_seqlens_q, expected) + # Padded fields are None when cu_seqlens_unpadded is not provided + assert result.cu_seqlens_q_padded is None + + def test_packed_seq_params_qkv_format_is_thd(self): + """Test that qkv_format is always set to 'thd'.""" + batch = { + "cu_seqlens": torch.tensor([[0, 10, -1]], dtype=torch.int32), + } + + result = get_packed_seq_params(batch) + + assert result.qkv_format == "thd" + + def test_packed_seq_params_cu_seqlens_unpadded_no_padding(self): + """Test cu_seqlens_unpadded with no padding markers.""" + batch = { + "cu_seqlens": torch.tensor([[0, 6, 12]], dtype=torch.int32), + "cu_seqlens_unpadded": torch.tensor([[0, 5, 10]], dtype=torch.int32), # No -1 + } + + result = get_packed_seq_params(batch) + + # Unpadded should be used as-is since no -1 and min != -1 + expected_unpadded = torch.tensor([0, 5, 10], dtype=torch.int32) + expected_padded = torch.tensor([0, 6, 12], dtype=torch.int32) + + assert torch.equal(result.cu_seqlens_q, expected_unpadded) + assert torch.equal(result.cu_seqlens_q_padded, expected_padded) + class TestCreateLossFunction: """Tests for the _create_loss_function helper function.""" diff --git a/tests/unit_tests/training/test_vlm_step.py b/tests/unit_tests/training/test_vlm_step.py index 7a4cd47016..778a15b326 100755 --- a/tests/unit_tests/training/test_vlm_step.py +++ b/tests/unit_tests/training/test_vlm_step.py @@ -12,10 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import pytest import torch from megatron.bridge.training.utils.visual_inputs import Qwen2_5_VLVisualInputs -from megatron.bridge.training.vlm_step import forward_step, get_batch, get_batch_from_iterator +from megatron.bridge.training.vlm_step import ( + forward_step, + get_batch, + get_batch_from_iterator, + pack_batch_sequences, +) class _Iterator: @@ -92,6 +98,30 @@ def _as_nocuda(t): assert out_vi.pixel_values is not None and out_vi.image_grid_thw is not None +class _MockProcessGroup: + """Mock process group with rank/size methods for testing.""" + + def rank(self): + return 0 + + def size(self): + return 1 + + +class _MockPGCollection: + """Mock PG collection for testing.""" + + def __init__(self, cp_size=1): + self.pp = _MockProcessGroup() + self._cp_size = cp_size + + @property + def cp(self): + pg = _MockProcessGroup() + pg.size = lambda: self._cp_size + return pg + + def test_get_batch_padding_paths(monkeypatch): # Simulate both first and last pipeline stages so tensors are returned monkeypatch.setattr("megatron.core.pipeline_parallel.utils.is_pp_first_stage", lambda pg: True, raising=True) @@ -135,21 +165,8 @@ def test_get_batch_padding_paths(monkeypatch): # Iterator it = _Iterator(batch) - # Create a proper mock pg_collection with rank/size methods - class _MockProcessGroup: - def rank(self): - return 0 - - def size(self): - return 1 - - class _PG: - def __init__(self): - self.pp = _MockProcessGroup() - self.cp = _MockProcessGroup() - tokens, labels, loss_mask, attention_mask, position_ids, *_ = get_batch( - it, cfg, use_mtp=False, pg_collection=_PG() + it, cfg, use_mtp=False, pg_collection=_MockPGCollection() ) # Length padded up to min(seq_cap, ceil_to_128(4)) == 32 assert tokens.shape[1] == 32 @@ -158,6 +175,175 @@ def __init__(self): assert position_ids.shape[1] == 32 +def test_get_batch_enable_packing_path(monkeypatch): + """Test get_batch with pack_sequences_in_batch=True (enable_packing path).""" + # Simulate both first and last pipeline stages so 
tensors are returned + monkeypatch.setattr("megatron.core.pipeline_parallel.utils.is_pp_first_stage", lambda pg: True, raising=True) + monkeypatch.setattr("megatron.core.pipeline_parallel.utils.is_pp_last_stage", lambda pg: True, raising=True) + + # Disable context parallel slicing effects + monkeypatch.setattr( + "megatron.core.utils.get_batch_on_this_cp_rank", + lambda x: x, + raising=True, + ) + + # Config with packing enabled + cfg = type("Cfg", (), {})() + cfg.model = type( + "M", + (), + { + "seq_length": 64, + "pipeline_model_parallel_size": 1, + }, + )() + cfg.dataset = type( + "D", + (), + { + "skip_getting_attention_mask_from_dataset": True, + "pack_sequences_in_batch": True, # Enable packing + }, + )() + + # Batch with 2 sequences of different lengths (with padding) + # Seq 1: [1, 2, 3, 0, 0, 0, 0, 0] - length 3 + # Seq 2: [4, 5, 6, 7, 8, 0, 0, 0] - length 5 + tokens = torch.tensor( + [ + [1, 2, 3, 0, 0, 0, 0, 0], + [4, 5, 6, 7, 8, 0, 0, 0], + ] + ) + labels = torch.tensor( + [ + [2, 3, -100, -100, -100, -100, -100, -100], + [5, 6, 7, 8, -100, -100, -100, -100], + ] + ) + loss_mask = torch.tensor( + [ + [1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0], + ] + ) + position_ids = torch.arange(8).unsqueeze(0).expand(2, -1).clone() + + vi = Qwen2_5_VLVisualInputs(pixel_values=torch.randn(1, 1, 3, 4, 4), image_grid_thw=torch.tensor([[[1, 2, 2]]])) + batch = { + "input_ids": tokens, + "labels": labels, + "loss_mask": loss_mask, + "position_ids": position_ids, + "attention_mask": None, + "visual_inputs": vi, + } + + it = _Iterator(batch) + + ( + out_tokens, + out_labels, + out_loss_mask, + out_attention_mask, + out_position_ids, + cu_seqlens, + max_seqlen, + visual_inputs, + ) = get_batch(it, cfg, use_mtp=False, pg_collection=_MockPGCollection()) + + # Verify packing occurred + # With pad_to_multiple_of=1 (cp_size=1), total packed length = 3 + 5 = 8 + assert out_tokens.shape == (1, 8), f"Expected packed shape (1, 8), got 
{out_tokens.shape}" + assert out_labels.shape == (1, 8) + assert out_loss_mask.shape == (1, 8) + assert out_position_ids.shape == (1, 8) + + # Verify cu_seqlens is populated (not None) + assert cu_seqlens is not None, "cu_seqlens should be set when packing is enabled" + assert cu_seqlens.tolist() == [0, 3, 8], f"Expected cu_seqlens [0, 3, 8], got {cu_seqlens.tolist()}" + + # Verify max_seqlen + assert max_seqlen is not None, "max_seqlen should be set when packing is enabled" + assert max_seqlen.item() == 5, f"Expected max_seqlen 5, got {max_seqlen.item()}" + + # Verify attention_mask is None for packed sequences + assert out_attention_mask is None, "attention_mask should be None for packed sequences" + + # Verify packed tokens content + expected_tokens = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]]) + assert torch.equal(out_tokens.cpu(), expected_tokens), f"Expected {expected_tokens}, got {out_tokens}" + + # Verify visual_inputs passed through + assert visual_inputs is not None + + +def test_get_batch_enable_packing_with_cp(monkeypatch): + """Test get_batch packing with context parallelism (pad_to_multiple_of > 1).""" + monkeypatch.setattr("megatron.core.pipeline_parallel.utils.is_pp_first_stage", lambda pg: True, raising=True) + monkeypatch.setattr("megatron.core.pipeline_parallel.utils.is_pp_last_stage", lambda pg: True, raising=True) + monkeypatch.setattr( + "megatron.core.utils.get_batch_on_this_cp_rank", + lambda x: x, + raising=True, + ) + + cfg = type("Cfg", (), {})() + cfg.model = type("M", (), {"seq_length": 64, "pipeline_model_parallel_size": 1})() + cfg.dataset = type( + "D", + (), + { + "skip_getting_attention_mask_from_dataset": True, + "pack_sequences_in_batch": True, + }, + )() + + # Sequences: length 3 and length 5 + # With CP=2, pad_to_multiple_of = 2*2 = 4 + # Seq 1: 3 -> padded to 4 + # Seq 2: 5 -> padded to 8 + # Total: 4 + 8 = 12 + tokens = torch.tensor( + [ + [1, 2, 3, 0, 0, 0, 0, 0], + [4, 5, 6, 7, 8, 0, 0, 0], + ] + ) + labels = torch.tensor( + [ 
+ [2, 3, -100, -100, -100, -100, -100, -100], + [5, 6, 7, 8, -100, -100, -100, -100], + ] + ) + loss_mask = torch.ones_like(tokens, dtype=torch.float) + position_ids = torch.arange(8).unsqueeze(0).expand(2, -1).clone() + + batch = { + "input_ids": tokens, + "labels": labels, + "loss_mask": loss_mask, + "position_ids": position_ids, + "attention_mask": None, + "visual_inputs": None, + } + + it = _Iterator(batch) + + # Use CP size of 2 + out_tokens, out_labels, out_loss_mask, _, out_position_ids, cu_seqlens, max_seqlen, _ = get_batch( + it, cfg, use_mtp=False, pg_collection=_MockPGCollection(cp_size=2) + ) + + # With CP=2, pad_to_multiple_of = 4 + # Seq 1: 3 -> 4, Seq 2: 5 -> 8 (next multiple of 4) + # Total: 4 + 8 = 12 + assert out_tokens.shape[1] == 12, f"Expected packed length 12, got {out_tokens.shape[1]}" + assert cu_seqlens.tolist() == [0, 4, 12], f"Expected cu_seqlens [0, 4, 12], got {cu_seqlens.tolist()}" + assert max_seqlen.item() == 8, f"Expected max_seqlen 8, got {max_seqlen.item()}" + + def test_forward_step_schedule_plan(monkeypatch): # Configure pipeline last/first to enable labels & loss_mask path monkeypatch.setattr("megatron.core.pipeline_parallel.utils.is_pp_first_stage", lambda pg: True, raising=True) @@ -264,3 +450,382 @@ def __init__(self): # Execute schedule plan path plan, loss_fn = forward_step(state, it, model, return_schedule_plan=True) assert isinstance(plan, torch.Tensor) + + +class TestPackBatchSequences: + """Tests for the pack_batch_sequences function.""" + + def test_basic_packing(self): + """Test basic sequence packing functionality.""" + batch_size, seq_len = 2, 8 + # Tokens with padding at the end (pad_token_id=0) + tokens = torch.tensor( + [ + [1, 2, 3, 0, 0, 0, 0, 0], # length 3 + [4, 5, 6, 7, 0, 0, 0, 0], # length 4 + ] + ) + labels = torch.tensor( + [ + [2, 3, -100, -100, -100, -100, -100, -100], + [5, 6, 7, -100, -100, -100, -100, -100], + ] + ) + loss_mask = torch.tensor( + [ + [1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [1.0, 
1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0], + ] + ) + position_ids = torch.arange(seq_len).unsqueeze(0).expand(batch_size, -1) + attention_mask = None + + result = pack_batch_sequences( + tokens=tokens, + labels=labels, + loss_mask=loss_mask, + attention_mask=attention_mask, + position_ids=position_ids, + pad_token_id=0, + pad_to_multiple_of=1, + ) + + packed_tokens, packed_labels, packed_loss_mask, packed_attn, packed_pos, cu_seqlens, max_seqlen = result + + # Packed output should have shape [1, total_valid_len] + assert packed_tokens.shape[0] == 1 + total_len = packed_tokens.shape[1] + assert total_len == 7 # 3 + 4 + + # cu_seqlens should have num_sequences + 1 elements + assert len(cu_seqlens) == 3 # [0, 3, 7] + assert cu_seqlens[0] == 0 + assert cu_seqlens[1] == 3 # first sequence length + assert cu_seqlens[2] == 7 # total length + + # max_seqlen should be max of sequence lengths + assert max_seqlen.item() == 4 + + # Attention mask should be None for packed sequences + assert packed_attn is None + + def test_packing_with_pad_to_multiple_of(self): + """Test packing with padding to a multiple (for CP compatibility).""" + batch_size = 2 + tokens = torch.tensor( + [ + [1, 2, 3, 0, 0, 0, 0, 0, 0, 0], # length 3 -> padded to 4 (mult of 2) + [4, 5, 6, 7, 8, 0, 0, 0, 0, 0], # length 5 -> padded to 6 (mult of 2) + ] + ) + labels = torch.tensor( + [ + [2, 3, -100, -100, -100, -100, -100, -100, -100, -100], + [5, 6, 7, 8, -100, -100, -100, -100, -100, -100], + ] + ) + loss_mask = torch.ones_like(tokens, dtype=torch.float) + position_ids = torch.arange(10).unsqueeze(0).expand(batch_size, -1) + + result = pack_batch_sequences( + tokens=tokens, + labels=labels, + loss_mask=loss_mask, + attention_mask=None, + position_ids=position_ids, + pad_token_id=0, + pad_to_multiple_of=2, # Pad each sequence to multiple of 2 + ) + + packed_tokens, packed_labels, packed_loss_mask, packed_attn, packed_pos, cu_seqlens, max_seqlen = result + + # Total length should be 4 + 6 = 10 (padded lengths) + 
assert packed_tokens.shape[1] == 10 + + # cu_seqlens should use padded lengths + assert cu_seqlens[0] == 0 + assert cu_seqlens[1] == 4 # 3 -> 4 (padded) + assert cu_seqlens[2] == 10 # 5 -> 6, total = 4 + 6 + + # max_seqlen should be 6 (longest padded sequence) + assert max_seqlen.item() == 6 + + def test_packing_with_larger_multiple(self): + """Test packing with larger pad_to_multiple_of (e.g., for CP=4).""" + tokens = torch.tensor( + [ + [1, 2, 0, 0, 0, 0, 0, 0], # length 2 -> padded to 4 + [3, 4, 5, 0, 0, 0, 0, 0], # length 3 -> padded to 4 + ] + ) + labels = torch.tensor( + [ + [2, -100, -100, -100, -100, -100, -100, -100], + [4, 5, -100, -100, -100, -100, -100, -100], + ] + ) + loss_mask = torch.ones_like(tokens, dtype=torch.float) + position_ids = torch.arange(8).unsqueeze(0).expand(2, -1) + + result = pack_batch_sequences( + tokens=tokens, + labels=labels, + loss_mask=loss_mask, + attention_mask=None, + position_ids=position_ids, + pad_token_id=0, + pad_to_multiple_of=4, + ) + + packed_tokens, *_, cu_seqlens, max_seqlen = result + + # Both sequences padded to 4, total = 8 + assert packed_tokens.shape[1] == 8 + assert cu_seqlens.tolist() == [0, 4, 8] + assert max_seqlen.item() == 4 + + def test_packing_single_sequence(self): + """Test packing a single sequence.""" + tokens = torch.tensor([[1, 2, 3, 4, 5, 0, 0, 0]]) # length 5 + labels = torch.tensor([[2, 3, 4, 5, -100, -100, -100, -100]]) + loss_mask = torch.ones_like(tokens, dtype=torch.float) + position_ids = torch.arange(8).unsqueeze(0) + + result = pack_batch_sequences( + tokens=tokens, + labels=labels, + loss_mask=loss_mask, + attention_mask=None, + position_ids=position_ids, + pad_token_id=0, + pad_to_multiple_of=1, + ) + + packed_tokens, *_, cu_seqlens, max_seqlen = result + + assert packed_tokens.shape[1] == 5 + assert cu_seqlens.tolist() == [0, 5] + assert max_seqlen.item() == 5 + + def test_packing_no_padding_sequences(self): + """Test packing sequences with no padding.""" + tokens = torch.tensor( + 
[ + [1, 2, 3, 4], + [5, 6, 7, 8], + ] + ) + labels = torch.tensor( + [ + [2, 3, 4, -100], + [6, 7, 8, -100], + ] + ) + loss_mask = torch.ones_like(tokens, dtype=torch.float) + position_ids = torch.arange(4).unsqueeze(0).expand(2, -1) + + result = pack_batch_sequences( + tokens=tokens, + labels=labels, + loss_mask=loss_mask, + attention_mask=None, + position_ids=position_ids, + pad_token_id=0, + pad_to_multiple_of=1, + ) + + packed_tokens, *_, cu_seqlens, max_seqlen = result + + # Both sequences full length + assert packed_tokens.shape[1] == 8 + assert cu_seqlens.tolist() == [0, 4, 8] + + def test_packing_preserves_loss_mask_zeros(self): + """Test that loss_mask zeros are preserved during packing.""" + tokens = torch.tensor([[1, 2, 3, 0, 0]]) + labels = torch.tensor([[2, 3, -100, -100, -100]]) + loss_mask = torch.tensor([[1.0, 0.0, 1.0, 0.0, 0.0]]) # Second token masked + position_ids = torch.arange(5).unsqueeze(0) + + result = pack_batch_sequences( + tokens=tokens, + labels=labels, + loss_mask=loss_mask, + attention_mask=None, + position_ids=position_ids, + pad_token_id=0, + pad_to_multiple_of=1, + ) + + packed_tokens, packed_labels, packed_loss_mask, *_ = result + + # Only first 3 tokens should be kept + assert packed_loss_mask.shape[1] == 3 + assert packed_loss_mask[0, 0].item() == 1.0 + assert packed_loss_mask[0, 1].item() == 0.0 # Preserved + assert packed_loss_mask[0, 2].item() == 1.0 + + def test_packing_position_ids_reset(self): + """Test that position_ids are correctly packed.""" + tokens = torch.tensor( + [ + [1, 2, 0, 0], + [3, 4, 5, 0], + ] + ) + labels = torch.zeros_like(tokens) + loss_mask = torch.ones_like(tokens, dtype=torch.float) + position_ids = torch.tensor( + [ + [0, 1, 2, 3], + [0, 1, 2, 3], + ] + ) + + result = pack_batch_sequences( + tokens=tokens, + labels=labels, + loss_mask=loss_mask, + attention_mask=None, + position_ids=position_ids, + pad_token_id=0, + pad_to_multiple_of=1, + ) + + _, _, _, _, packed_pos, *_ = result + + # Position IDs 
should be extracted from original sequences + assert packed_pos.shape[1] == 5 # 2 + 3 + assert packed_pos[0, 0].item() == 0 # First seq, pos 0 + assert packed_pos[0, 1].item() == 1 # First seq, pos 1 + assert packed_pos[0, 2].item() == 0 # Second seq, pos 0 + assert packed_pos[0, 3].item() == 1 # Second seq, pos 1 + assert packed_pos[0, 4].item() == 2 # Second seq, pos 2 + + def test_packing_empty_batch_warning(self, caplog): + """Test that all-padding batch returns empty tensors with warning.""" + tokens = torch.tensor([[0, 0, 0, 0]]) # All padding + labels = torch.tensor([[-100, -100, -100, -100]]) + loss_mask = torch.zeros(1, 4) + position_ids = torch.arange(4).unsqueeze(0) + + result = pack_batch_sequences( + tokens=tokens, + labels=labels, + loss_mask=loss_mask, + attention_mask=None, + position_ids=position_ids, + pad_token_id=0, + pad_to_multiple_of=1, + ) + + packed_tokens, packed_labels, packed_loss_mask, packed_attn, packed_pos, cu_seqlens, max_seqlen = result + + # No valid sequences found, should return empty tensors + assert packed_tokens.shape == (1, 0) + assert packed_labels.shape == (1, 0) + assert packed_loss_mask.shape == (1, 0) + assert packed_pos.shape == (1, 0) + # cu_seqlens should have just [0] for empty batch + assert len(cu_seqlens) == 1 + assert cu_seqlens[0].item() == 0 + assert max_seqlen.item() == 0 + + def test_packing_different_dtypes(self): + """Test packing with different tensor dtypes.""" + tokens = torch.tensor([[1, 2, 3, 0]], dtype=torch.long) + labels = torch.tensor([[2, 3, -100, -100]], dtype=torch.long) + loss_mask = torch.tensor([[1.0, 1.0, 0.0, 0.0]], dtype=torch.float32) + position_ids = torch.tensor([[0, 1, 2, 3]], dtype=torch.long) + + result = pack_batch_sequences( + tokens=tokens, + labels=labels, + loss_mask=loss_mask, + attention_mask=None, + position_ids=position_ids, + pad_token_id=0, + pad_to_multiple_of=1, + ) + + packed_tokens, packed_labels, packed_loss_mask, _, packed_pos, cu_seqlens, _ = result + + # Dtypes 
should be preserved + assert packed_tokens.dtype == torch.long + assert packed_labels.dtype == torch.long + assert packed_loss_mask.dtype == torch.float32 + assert packed_pos.dtype == torch.long + assert cu_seqlens.dtype == torch.int32 + + def test_packing_padding_extends_position_ids(self): + """Test that padding extends position_ids correctly.""" + tokens = torch.tensor([[1, 2, 3, 0]]) # length 3 + labels = torch.zeros_like(tokens) + loss_mask = torch.ones_like(tokens, dtype=torch.float) + position_ids = torch.tensor([[0, 1, 2, 3]]) + + result = pack_batch_sequences( + tokens=tokens, + labels=labels, + loss_mask=loss_mask, + attention_mask=None, + position_ids=position_ids, + pad_token_id=0, + pad_to_multiple_of=4, # Pad to 4 + ) + + _, _, _, _, packed_pos, cu_seqlens, _ = result + + # Length should be 4 (padded) + assert packed_pos.shape[1] == 4 + + # Original positions should be preserved + assert packed_pos[0, 0].item() == 0 + assert packed_pos[0, 1].item() == 1 + assert packed_pos[0, 2].item() == 2 + # Padding position should be extended + assert packed_pos[0, 3].item() == 3 + + def test_packing_cu_seqlens_dtype(self): + """Test that cu_seqlens is int32 as expected by attention kernels.""" + tokens = torch.tensor([[1, 2, 0]]) + labels = torch.zeros_like(tokens) + loss_mask = torch.ones_like(tokens, dtype=torch.float) + position_ids = torch.arange(3).unsqueeze(0) + + result = pack_batch_sequences( + tokens=tokens, + labels=labels, + loss_mask=loss_mask, + attention_mask=None, + position_ids=position_ids, + pad_token_id=0, + ) + + _, _, _, _, _, cu_seqlens, _ = result + + assert cu_seqlens.dtype == torch.int32 + + def test_packing_gpu_tensor(self): + """Test packing works on GPU if available.""" + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + tokens = torch.tensor([[1, 2, 3, 0, 0]], device="cuda") + labels = torch.tensor([[2, 3, -100, -100, -100]], device="cuda") + loss_mask = torch.ones_like(tokens, dtype=torch.float, device="cuda") 
+ position_ids = torch.arange(5, device="cuda").unsqueeze(0) + + result = pack_batch_sequences( + tokens=tokens, + labels=labels, + loss_mask=loss_mask, + attention_mask=None, + position_ids=position_ids, + pad_token_id=0, + ) + + packed_tokens, _, _, _, _, cu_seqlens, _ = result + + assert packed_tokens.device.type == "cuda" + assert cu_seqlens.device.type == "cuda" diff --git a/tests/unit_tests/utils/test_slice_batch_for_cp.py b/tests/unit_tests/utils/test_slice_batch_for_cp.py new file mode 100644 index 0000000000..2ed4b96eaf --- /dev/null +++ b/tests/unit_tests/utils/test_slice_batch_for_cp.py @@ -0,0 +1,495 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for slice_batch_for_context_parallel function in common_utils.""" + +from dataclasses import dataclass +from typing import Optional +from unittest.mock import MagicMock, patch + +import torch + +from megatron.bridge.utils.common_utils import slice_batch_for_context_parallel + + +@dataclass +class MockPackedSeqParams: + """Mock PackedSeqParams for testing THD format.""" + + cu_seqlens_q: torch.Tensor + cu_seqlens_kv: torch.Tensor + cu_seqlens_q_padded: Optional[torch.Tensor] = None + cu_seqlens_kv_padded: Optional[torch.Tensor] = None + max_seqlen_q: Optional[torch.Tensor] = None + max_seqlen_kv: Optional[torch.Tensor] = None + qkv_format: str = "thd" + + +class MockPGCollection: + """Mock ProcessGroupCollection for testing.""" + + def __init__(self, cp_size: int = 1, cp_rank: int = 0): + self._cp_size = cp_size + self._cp_rank = cp_rank + self.cp = MagicMock() + self.cp.size.return_value = cp_size + self.cp.rank.return_value = cp_rank + + +class TestSliceBatchForContextParallelCpSize1: + """Tests for slice_batch_for_context_parallel when CP size <= 1 (no-op case).""" + + def test_cp_size_1_returns_unchanged(self): + """Test that tensors are returned unchanged when cp_size=1.""" + batch_size, seq_len, hidden = 2, 16, 64 + inputs_embeds = torch.randn(seq_len, batch_size, hidden) + labels = torch.randint(0, 1000, (batch_size, seq_len)) + loss_mask = torch.ones(batch_size, seq_len) + position_ids = torch.arange(seq_len).unsqueeze(0).expand(batch_size, -1) + attention_mask = torch.ones(batch_size, 1, seq_len, seq_len) + + pg_collection = MockPGCollection(cp_size=1, cp_rank=0) + + result = slice_batch_for_context_parallel( + inputs_embeds=inputs_embeds, + labels=labels, + loss_mask=loss_mask, + position_ids=position_ids, + attention_mask=attention_mask, + packed_seq_params=None, + pg_collection=pg_collection, + ) + + out_embeds, out_labels, out_loss_mask, out_pos_ids, out_attn_mask = result + + # All tensors should be unchanged + assert 
torch.equal(out_embeds, inputs_embeds) + assert torch.equal(out_labels, labels) + assert torch.equal(out_loss_mask, loss_mask) + assert torch.equal(out_pos_ids, position_ids) + assert torch.equal(out_attn_mask, attention_mask) + + def test_cp_size_0_returns_unchanged(self): + """Test that tensors are returned unchanged when cp_size=0.""" + batch_size, seq_len, hidden = 1, 8, 32 + inputs_embeds = torch.randn(seq_len, batch_size, hidden) + labels = torch.randint(0, 100, (batch_size, seq_len)) + loss_mask = torch.ones(batch_size, seq_len) + position_ids = torch.arange(seq_len).unsqueeze(0) + attention_mask = None + + pg_collection = MockPGCollection(cp_size=0, cp_rank=0) + + result = slice_batch_for_context_parallel( + inputs_embeds=inputs_embeds, + labels=labels, + loss_mask=loss_mask, + position_ids=position_ids, + attention_mask=attention_mask, + packed_seq_params=None, + pg_collection=pg_collection, + ) + + out_embeds, out_labels, out_loss_mask, out_pos_ids, out_attn_mask = result + + assert torch.equal(out_embeds, inputs_embeds) + assert torch.equal(out_labels, labels) + + +class TestSliceBatchForContextParallelBSHD: + """Tests for slice_batch_for_context_parallel with BSHD format (non-packed).""" + + def test_bshd_format_uses_get_batch_on_this_cp_rank(self): + """Test that BSHD format triggers get_batch_on_this_cp_rank.""" + batch_size, seq_len, hidden = 2, 16, 64 + inputs_embeds = torch.randn(seq_len, batch_size, hidden) + labels = torch.randint(0, 1000, (batch_size, seq_len)) + loss_mask = torch.ones(batch_size, seq_len) + position_ids = torch.arange(seq_len).unsqueeze(0).expand(batch_size, -1) + attention_mask = torch.ones(batch_size, 1, seq_len, seq_len) + + pg_collection = MockPGCollection(cp_size=2, cp_rank=0) + + # Mock get_batch_on_this_cp_rank to return sliced tensors + def mock_get_batch(batch_dict, cp_group=None): + # Simulate slicing by returning half of each tensor + result = {} + for k, v in batch_dict.items(): + if v is not None and isinstance(v, 
torch.Tensor): + if k == "decoder_input": # (B, T, D) format after transpose + result[k] = v[:, : seq_len // 2, :] + elif v.dim() == 4: # attention_mask + result[k] = v[:, :, : seq_len // 2, : seq_len // 2] + else: + result[k] = v[:, : seq_len // 2] + else: + result[k] = v + return result + + with patch( + "megatron.core.utils.get_batch_on_this_cp_rank", + side_effect=mock_get_batch, + ): + result = slice_batch_for_context_parallel( + inputs_embeds=inputs_embeds, + labels=labels, + loss_mask=loss_mask, + position_ids=position_ids, + attention_mask=attention_mask, + packed_seq_params=None, # BSHD format + pg_collection=pg_collection, + ) + + out_embeds, out_labels, out_loss_mask, out_pos_ids, out_attn_mask = result + + # Check that output tensors are sliced (half the sequence length) + assert out_embeds.shape[0] == seq_len // 2 # T dimension + assert out_embeds.shape[1] == batch_size # B dimension + assert out_labels.shape[1] == seq_len // 2 + assert out_loss_mask.shape[1] == seq_len // 2 + + def test_bshd_format_handles_none_tensors(self): + """Test that BSHD format handles None tensors gracefully.""" + batch_size, seq_len, hidden = 1, 8, 32 + inputs_embeds = torch.randn(seq_len, batch_size, hidden) + + pg_collection = MockPGCollection(cp_size=2, cp_rank=0) + + def mock_get_batch(batch_dict, cp_group=None): + result = {} + for k, v in batch_dict.items(): + if v is not None and isinstance(v, torch.Tensor): + if k == "decoder_input": + result[k] = v[:, : seq_len // 2, :] + else: + result[k] = v[:, : seq_len // 2] if v.dim() > 1 else v + else: + result[k] = None + return result + + with patch( + "megatron.core.utils.get_batch_on_this_cp_rank", + side_effect=mock_get_batch, + ): + result = slice_batch_for_context_parallel( + inputs_embeds=inputs_embeds, + labels=None, + loss_mask=None, + position_ids=None, + attention_mask=None, + packed_seq_params=None, + pg_collection=pg_collection, + ) + + out_embeds, out_labels, out_loss_mask, out_pos_ids, out_attn_mask = result + 
+ assert out_embeds is not None + assert out_labels is None + assert out_loss_mask is None + assert out_pos_ids is None + assert out_attn_mask is None + + +class TestSliceBatchForContextParallelTHD: + """Tests for slice_batch_for_context_parallel with THD (packed) format.""" + + def test_thd_format_uses_tex_partitioned_indices(self): + """Test that THD format triggers TransformerEngine's thd_get_partitioned_indices.""" + batch_size, seq_len, hidden = 1, 16, 64 + cp_size = 2 + cp_rank = 0 + + inputs_embeds = torch.randn(seq_len, batch_size, hidden) + labels = torch.randint(0, 1000, (batch_size, seq_len)) + loss_mask = torch.ones(batch_size, seq_len) + position_ids = torch.arange(seq_len).unsqueeze(0) + + cu_seqlens = torch.tensor([0, 8, 16], dtype=torch.int32) + packed_seq_params = MockPackedSeqParams( + cu_seqlens_q=cu_seqlens, + cu_seqlens_kv=cu_seqlens, + cu_seqlens_q_padded=cu_seqlens, + qkv_format="thd", + ) + + pg_collection = MockPGCollection(cp_size=cp_size, cp_rank=cp_rank) + + # Mock tex.thd_get_partitioned_indices + mock_indices = torch.tensor([0, 1, 2, 3, 8, 9, 10, 11]) # First half of each sequence + + with patch.dict("sys.modules", {"transformer_engine_torch": MagicMock()}): + import sys + + mock_tex = sys.modules["transformer_engine_torch"] + mock_tex.thd_get_partitioned_indices.return_value = mock_indices + + result = slice_batch_for_context_parallel( + inputs_embeds=inputs_embeds, + labels=labels, + loss_mask=loss_mask, + position_ids=position_ids, + attention_mask=None, + packed_seq_params=packed_seq_params, + pg_collection=pg_collection, + ) + + out_embeds, out_labels, out_loss_mask, out_pos_ids, out_attn_mask = result + + # Verify tex.thd_get_partitioned_indices was called + mock_tex.thd_get_partitioned_indices.assert_called_once() + + # Check output shapes match the indices + assert out_embeds.shape[0] == len(mock_indices) # T dimension + assert out_embeds.shape[1] == batch_size # B dimension + + def test_thd_format_with_padded_cu_seqlens(self): 
+ """Test THD format uses cu_seqlens_q_padded when available.""" + batch_size, seq_len, hidden = 1, 20, 32 + cp_size = 2 + cp_rank = 1 + + inputs_embeds = torch.randn(seq_len, batch_size, hidden) + labels = torch.randint(0, 100, (batch_size, seq_len)) + loss_mask = torch.ones(batch_size, seq_len) + position_ids = torch.arange(seq_len).unsqueeze(0) + + # Padded cu_seqlens (includes padding for divisibility) + cu_seqlens = torch.tensor([0, 6, 16], dtype=torch.int32) + cu_seqlens_padded = torch.tensor([0, 10, 20], dtype=torch.int32) + + packed_seq_params = MockPackedSeqParams( + cu_seqlens_q=cu_seqlens, + cu_seqlens_kv=cu_seqlens, + cu_seqlens_q_padded=cu_seqlens_padded, + qkv_format="thd", + ) + + pg_collection = MockPGCollection(cp_size=cp_size, cp_rank=cp_rank) + + mock_indices = torch.tensor([5, 6, 7, 8, 9, 15, 16, 17, 18, 19]) # Second half + + with patch.dict("sys.modules", {"transformer_engine_torch": MagicMock()}): + import sys + + mock_tex = sys.modules["transformer_engine_torch"] + mock_tex.thd_get_partitioned_indices.return_value = mock_indices + + result = slice_batch_for_context_parallel( + inputs_embeds=inputs_embeds, + labels=labels, + loss_mask=loss_mask, + position_ids=position_ids, + attention_mask=None, + packed_seq_params=packed_seq_params, + pg_collection=pg_collection, + ) + + out_embeds, out_labels, out_loss_mask, out_pos_ids, out_attn_mask = result + + # Verify padded cu_seqlens was used + call_args = mock_tex.thd_get_partitioned_indices.call_args + assert torch.equal(call_args[0][0], cu_seqlens_padded) + + def test_thd_format_without_padded_cu_seqlens_fallback(self): + """Test THD format falls back to cu_seqlens_q when padded version is None.""" + batch_size, seq_len, hidden = 1, 12, 16 + cp_size = 2 + cp_rank = 0 + + inputs_embeds = torch.randn(seq_len, batch_size, hidden) + labels = torch.randint(0, 50, (batch_size, seq_len)) + loss_mask = torch.ones(batch_size, seq_len) + position_ids = torch.arange(seq_len).unsqueeze(0) + + cu_seqlens = 
torch.tensor([0, 6, 12], dtype=torch.int32) + + packed_seq_params = MockPackedSeqParams( + cu_seqlens_q=cu_seqlens, + cu_seqlens_kv=cu_seqlens, + cu_seqlens_q_padded=None, # No padded version + qkv_format="thd", + ) + + pg_collection = MockPGCollection(cp_size=cp_size, cp_rank=cp_rank) + + mock_indices = torch.tensor([0, 1, 2, 6, 7, 8]) + + with patch.dict("sys.modules", {"transformer_engine_torch": MagicMock()}): + import sys + + mock_tex = sys.modules["transformer_engine_torch"] + mock_tex.thd_get_partitioned_indices.return_value = mock_indices + + slice_batch_for_context_parallel( + inputs_embeds=inputs_embeds, + labels=labels, + loss_mask=loss_mask, + position_ids=position_ids, + attention_mask=None, + packed_seq_params=packed_seq_params, + pg_collection=pg_collection, + ) + + # Verify cu_seqlens_q was used as fallback + call_args = mock_tex.thd_get_partitioned_indices.call_args + assert torch.equal(call_args[0][0], cu_seqlens) + + +class TestSliceBatchForContextParallelTranspose: + """Tests verifying correct tensor transpose operations.""" + + def test_input_transposed_before_slicing(self): + """Test that inputs_embeds is transposed from (T,B,D) to (B,T,D) for slicing.""" + batch_size, seq_len, hidden = 2, 8, 32 + + # Input is (T, B, D) format + inputs_embeds = torch.randn(seq_len, batch_size, hidden) + + pg_collection = MockPGCollection(cp_size=2, cp_rank=0) + + captured_batch = {} + + def mock_get_batch(batch_dict, cp_group=None): + captured_batch.update(batch_dict) + # Check that decoder_input is in (B, T, D) format + di = batch_dict.get("decoder_input") + if di is not None: + assert di.shape == (batch_size, seq_len, hidden), f"Expected (B,T,D), got {di.shape}" + return batch_dict + + with patch( + "megatron.core.utils.get_batch_on_this_cp_rank", + side_effect=mock_get_batch, + ): + slice_batch_for_context_parallel( + inputs_embeds=inputs_embeds, + labels=None, + loss_mask=None, + position_ids=None, + attention_mask=None, + packed_seq_params=None, # BSHD 
format + pg_collection=pg_collection, + ) + + # Verify decoder_input was captured with correct shape + assert "decoder_input" in captured_batch + assert captured_batch["decoder_input"].shape == (batch_size, seq_len, hidden) + + def test_output_transposed_back_to_tbd(self): + """Test that outputs are transposed back to (T,B,D) format.""" + batch_size, seq_len, hidden = 2, 16, 64 + sliced_seq_len = seq_len // 2 + + inputs_embeds = torch.randn(seq_len, batch_size, hidden) + + pg_collection = MockPGCollection(cp_size=2, cp_rank=0) + + def mock_get_batch(batch_dict, cp_group=None): + result = {} + for k, v in batch_dict.items(): + if v is not None and isinstance(v, torch.Tensor): + if k == "decoder_input": + # Return sliced (B, T/2, D) + result[k] = v[:, :sliced_seq_len, :] + else: + result[k] = v + else: + result[k] = v + return result + + with patch( + "megatron.core.utils.get_batch_on_this_cp_rank", + side_effect=mock_get_batch, + ): + result = slice_batch_for_context_parallel( + inputs_embeds=inputs_embeds, + labels=None, + loss_mask=None, + position_ids=None, + attention_mask=None, + packed_seq_params=None, + pg_collection=pg_collection, + ) + + out_embeds, *_ = result + + # Output should be in (T, B, D) format + assert out_embeds.shape == (sliced_seq_len, batch_size, hidden) + + +class TestSliceBatchForContextParallelEdgeCases: + """Edge case tests for slice_batch_for_context_parallel.""" + + def test_none_inputs_embeds_with_cp_size_gt_1(self): + """Test handling when inputs_embeds is None but CP is enabled.""" + pg_collection = MockPGCollection(cp_size=2, cp_rank=0) + + def mock_get_batch(batch_dict, cp_group=None): + return batch_dict + + with patch( + "megatron.core.utils.get_batch_on_this_cp_rank", + side_effect=mock_get_batch, + ): + result = slice_batch_for_context_parallel( + inputs_embeds=None, + labels=torch.randint(0, 100, (1, 8)), + loss_mask=torch.ones(1, 8), + position_ids=torch.arange(8).unsqueeze(0), + attention_mask=None, + packed_seq_params=None, 
+ pg_collection=pg_collection, + ) + + out_embeds, *_ = result + assert out_embeds is None + + def test_non_thd_qkv_format_uses_bshd_path(self): + """Test that non-THD qkv_format (e.g., 'sbhd') uses BSHD slicing path.""" + batch_size, seq_len, hidden = 1, 8, 16 + + inputs_embeds = torch.randn(seq_len, batch_size, hidden) + labels = torch.randint(0, 50, (batch_size, seq_len)) + + # Non-THD format + packed_seq_params = MockPackedSeqParams( + cu_seqlens_q=torch.tensor([0, 8]), + cu_seqlens_kv=torch.tensor([0, 8]), + qkv_format="sbhd", # Not THD + ) + + pg_collection = MockPGCollection(cp_size=2, cp_rank=0) + + mock_called = {"get_batch": False} + + def mock_get_batch(batch_dict, cp_group=None): + mock_called["get_batch"] = True + return batch_dict + + with patch( + "megatron.core.utils.get_batch_on_this_cp_rank", + side_effect=mock_get_batch, + ): + slice_batch_for_context_parallel( + inputs_embeds=inputs_embeds, + labels=labels, + loss_mask=None, + position_ids=None, + attention_mask=None, + packed_seq_params=packed_seq_params, + pg_collection=pg_collection, + ) + + # Verify BSHD path was used (get_batch_on_this_cp_rank called) + assert mock_called["get_batch"]