diff --git a/docs/models/vlm/ministral3.md b/docs/models/vlm/ministral3.md
index b47a08df6c..71dcec8695 100644
--- a/docs/models/vlm/ministral3.md
+++ b/docs/models/vlm/ministral3.md
@@ -7,7 +7,7 @@ Ministral 3 models support multimodal tasks including image captioning, visual q
Ministral family models are supported via the Bridge system with auto-detected configuration and weight mapping.
```{important}
-Please update `transformers` version to 5.0.0rc0 in order to use the Ministral 3 models.
+Please upgrade `transformers` to v5 and `mistral-common` to its latest release in order to use the Ministral 3 models.
```
## Available Models
diff --git a/src/megatron/bridge/data/vlm_datasets/hf_provider.py b/src/megatron/bridge/data/vlm_datasets/hf_provider.py
index c63c95e780..8014276753 100644
--- a/src/megatron/bridge/data/vlm_datasets/hf_provider.py
+++ b/src/megatron/bridge/data/vlm_datasets/hf_provider.py
@@ -67,6 +67,9 @@ class HFDatasetConversationProvider(DatasetProvider):
# DataloaderConfig fields are inherited (num_workers, dataloader_type, etc.)
dataloader_type: Optional[Literal["single", "cyclic", "external"]] = "single"
+ # Enable batch-level online sequence packing (dataset-level packing is available in FinetuneDatasetProvider)
+ pack_sequences_in_batch: bool = False
+
def _get_maker(self) -> Callable[..., List[Dict[str, Any]]]:
registry: Dict[str, Callable[..., List[Dict[str, Any]]]] = {
"make_rdr_dataset": make_rdr_dataset,
diff --git a/src/megatron/bridge/models/gemma/gemma3_provider.py b/src/megatron/bridge/models/gemma/gemma3_provider.py
index 63d6413e32..80966260bc 100644
--- a/src/megatron/bridge/models/gemma/gemma3_provider.py
+++ b/src/megatron/bridge/models/gemma/gemma3_provider.py
@@ -367,11 +367,34 @@ def __init__(
**kwargs,
)
+ def forward(
+ self,
+ max_seq_len: int,
+ offset: int = 0,
+ packed_seq: bool = False,
+ cp_group: torch.distributed.ProcessGroup | None = None,
+    ) -> tuple[Tensor, Tensor]:
+ """Get global and local rope embedding.
+
+ Note: Caching is bypassed when cp_group is provided since ProcessGroup is unhashable.
+ """
+ # ProcessGroup is unhashable, so bypass caching when cp_group is provided
+ if cp_group is not None:
+ rope_global = super().forward(max_seq_len, offset, packed_seq, cp_group)
+ rope_local = self.rope_local.forward(max_seq_len, offset, packed_seq, cp_group)
+ return rope_local, rope_global
+ return self._forward_cached(max_seq_len, offset, packed_seq)
+
@lru_cache(maxsize=32)
- def forward(self, max_seq_len: int, offset: int = 0, packed_seq: bool = False) -> Tensor:
- """Get global and local rope embedding"""
- rope_global = super().forward(max_seq_len, offset, packed_seq)
- rope_local = self.rope_local.forward(max_seq_len, offset, packed_seq)
+ def _forward_cached(
+ self,
+ max_seq_len: int,
+ offset: int = 0,
+ packed_seq: bool = False,
+    ) -> tuple[Tensor, Tensor]:
+ """Cached forward for hashable parameters only."""
+ rope_global = super().forward(max_seq_len, offset, packed_seq, None)
+ rope_local = self.rope_local.forward(max_seq_len, offset, packed_seq, None)
return rope_local, rope_global
diff --git a/src/megatron/bridge/models/gemma_vl/modeling_gemma3_vl.py b/src/megatron/bridge/models/gemma_vl/modeling_gemma3_vl.py
index 0b1e937840..e68225ca1e 100644
--- a/src/megatron/bridge/models/gemma_vl/modeling_gemma3_vl.py
+++ b/src/megatron/bridge/models/gemma_vl/modeling_gemma3_vl.py
@@ -14,22 +14,30 @@
import types
from dataclasses import dataclass
-from typing import Optional, Tuple
+from typing import TYPE_CHECKING, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from megatron.core.tensor_parallel.layers import ColumnParallelLinear
+from megatron.core.tensor_parallel.mappings import scatter_to_sequence_parallel_region
from megatron.core.transformer import TransformerConfig
from megatron.core.transformer.module import MegatronModule
from torch import Tensor
from transformers import AutoModel, Gemma3Model
from megatron.bridge.models.gpt_provider import GPTModelProvider
-from megatron.bridge.utils.common_utils import hook_hf_module_setattr_for_tp_grad_sync
+from megatron.bridge.utils.common_utils import (
+ hook_hf_module_setattr_for_tp_grad_sync,
+ slice_batch_for_context_parallel,
+)
from megatron.bridge.utils.import_utils import safe_import_from
+if TYPE_CHECKING:
+ from megatron.core.packed_seq_params import PackedSeqParams
+
+
TENorm, _ = safe_import_from("megatron.core.extensions.transformer_engine", "TENorm")
@@ -110,12 +118,16 @@ def forward(
pixel_values: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
runtime_gather_output: Optional[bool] = None,
+ packed_seq_params: Optional["PackedSeqParams"] = None,
*,
loss_mask: Optional[Tensor] = None,
- ) -> Tensor:
+ ) -> tuple[Tensor, Tensor | None]:
r"""
- image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
- The temporal, height and width of feature shape of each image in LLM.
+ Forward pass combining HuggingFace vision encoder with Megatron language model.
+
+ Returns:
+ tuple: (output_tensor, loss_mask) where output_tensor contains model output
+ and loss_mask is the CP-sliced mask for consistent loss computation.
"""
if self.pre_process:
if inputs_embeds is None:
@@ -134,7 +146,7 @@ def forward(
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
if inputs_embeds[special_image_mask].numel() != image_features.numel():
- image_tokens_in_text = special_image_mask.sum(dim=1).item(dim=0)[0]
+ image_tokens_in_text = special_image_mask[:, :, 0].sum().item()
raise ValueError(
f"Number of images does not match number of special image tokens in the input text. "
f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} "
@@ -144,18 +156,38 @@ def forward(
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
inputs_embeds = inputs_embeds.transpose(1, 0).contiguous() # (B, T, D) -> (T, B, D)
+ # Apply sequence parallelism scatter if enabled
+ if self.config.sequence_parallel:
+ inputs_embeds = scatter_to_sequence_parallel_region(inputs_embeds)
+
+ # Compute attention mask on FULL sequence (before CP slicing)
+ # This is needed because image regions need bidirectional attention
attention_mask = self._compute_attention_mask(input_ids)
+ # CP slicing: slice embeddings, labels, loss_mask, position_ids, and attention_mask
+ # This must happen AFTER vision-text merge so image token positions are correct
+ inputs_embeds, labels, loss_mask, position_ids, attention_mask = slice_batch_for_context_parallel(
+ inputs_embeds=inputs_embeds,
+ labels=labels,
+ loss_mask=loss_mask,
+ position_ids=position_ids,
+ attention_mask=attention_mask,
+ packed_seq_params=packed_seq_params,
+ pg_collection=self.config._pg_collection,
+ )
+
outputs = self.language_model.forward(
input_ids=None,
position_ids=position_ids,
- attention_mask=attention_mask, # (B, 1, T, T)
- decoder_input=inputs_embeds, # (T, B, D)
- labels=labels, # (B, T)
+ attention_mask=attention_mask,
+ decoder_input=inputs_embeds,
+ labels=labels,
loss_mask=loss_mask,
runtime_gather_output=runtime_gather_output,
+ packed_seq_params=packed_seq_params,
)
- return outputs
+ # Return both outputs and the CP-sliced loss_mask for consistent loss computation
+ return (outputs, loss_mask)
def freeze(self, freeze_language_model: bool, freeze_vision_model: bool, freeze_vision_projection: bool):
"""Freeze model modules.
@@ -191,7 +223,7 @@ def freeze(self, freeze_language_model: bool, freeze_vision_model: bool, freeze_
def _compute_attention_mask(
self,
input_ids: torch.Tensor,
- ) -> Tuple[torch.Tensor, torch.Tensor]:
+ ) -> Optional[torch.Tensor]:
if not self.pre_process:
return None
batch_size, seq_len = input_ids.shape
diff --git a/src/megatron/bridge/models/glm_vl/modeling_glm_45v.py b/src/megatron/bridge/models/glm_vl/modeling_glm_45v.py
index 3c97fab9ff..34d20c42fb 100644
--- a/src/megatron/bridge/models/glm_vl/modeling_glm_45v.py
+++ b/src/megatron/bridge/models/glm_vl/modeling_glm_45v.py
@@ -23,7 +23,7 @@
"""
import types
-from typing import Optional
+from typing import TYPE_CHECKING, Optional
import torch
import transformers
@@ -37,6 +37,10 @@
from megatron.bridge.utils.common_utils import hook_hf_module_setattr_for_tp_grad_sync
+if TYPE_CHECKING:
+ from megatron.core.packed_seq_params import PackedSeqParams
+
+
def is_transformers_min_version(version):
"""Check if minimum version of transformers is installed."""
try:
@@ -158,6 +162,7 @@ def forward(
video_grid_thw: Optional[torch.LongTensor] = None,
labels: Optional[torch.Tensor] = None,
runtime_gather_output: Optional[bool] = None,
+ packed_seq_params: Optional["PackedSeqParams"] = None,
*,
loss_mask: Optional[Tensor] = None,
) -> Tensor:
@@ -233,6 +238,7 @@ def forward(
labels=labels,
loss_mask=loss_mask,
runtime_gather_output=runtime_gather_output,
+ packed_seq_params=packed_seq_params,
)
return outputs
diff --git a/src/megatron/bridge/models/ministral3/ministral3_provider.py b/src/megatron/bridge/models/ministral3/ministral3_provider.py
index 3f5ce8e59d..5c6605d13b 100644
--- a/src/megatron/bridge/models/ministral3/ministral3_provider.py
+++ b/src/megatron/bridge/models/ministral3/ministral3_provider.py
@@ -260,11 +260,17 @@ def __init__(
self.beta = 0 # No effect
self.max_position_embeddings = self.config.seq_length
+ @staticmethod
def _get_llama_4_attn_scale(
- self, positions_ids: torch.Tensor, beta: float, max_position_embeddings: int
+ positions_ids: torch.Tensor, beta: float, max_position_embeddings: int, query_shape: tuple
) -> torch.Tensor:
scaling = 1 + beta * torch.log(1 + torch.floor(positions_ids / max_position_embeddings))
- return scaling.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
+ # Add dimensions to match query shape: [seq_len] -> [seq_len, 1, 1] for packed or [seq_len, 1, 1, 1] for unpacked
+ # Query can be either [seq_len, num_heads, head_dim] (packed) or [seq_len, batch, num_heads, head_dim] (unpacked)
+ num_dims_to_add = len(query_shape) - 1
+ for _ in range(num_dims_to_add):
+ scaling = scaling.unsqueeze(-1)
+ return scaling
def forward(
self,
@@ -276,6 +282,8 @@ def forward(
**kwargs,
):
positions_ids = torch.arange(query.shape[0], device=query.device)
- query *= self._get_llama_4_attn_scale(positions_ids, self.beta, self.max_position_embeddings).to(query.dtype)
+ query *= self._get_llama_4_attn_scale(positions_ids, self.beta, self.max_position_embeddings, query.shape).to(
+ query.dtype
+ )
return super().forward(query, key, value, attention_mask, attn_mask_type, **kwargs)
diff --git a/src/megatron/bridge/models/ministral3/modeling_ministral3.py b/src/megatron/bridge/models/ministral3/modeling_ministral3.py
index 00590b801b..e541059e7a 100644
--- a/src/megatron/bridge/models/ministral3/modeling_ministral3.py
+++ b/src/megatron/bridge/models/ministral3/modeling_ministral3.py
@@ -24,14 +24,21 @@
"""
import types
-from typing import Optional
+from typing import TYPE_CHECKING, Optional
import torch
from megatron.core.transformer.module import MegatronModule
from torch import Tensor
from megatron.bridge.models.gpt_provider import GPTModelProvider
-from megatron.bridge.utils.common_utils import hook_hf_module_setattr_for_tp_grad_sync
+from megatron.bridge.utils.common_utils import (
+ hook_hf_module_setattr_for_tp_grad_sync,
+ slice_batch_for_context_parallel,
+)
+
+
+if TYPE_CHECKING:
+ from megatron.core.packed_seq_params import PackedSeqParams
# Import HuggingFace Mistral3 model classes with fallback
@@ -185,9 +192,10 @@ def forward(
labels: Optional[torch.Tensor] = None,
runtime_gather_output: Optional[bool] = None,
image_sizes: Optional[torch.Tensor] = None,
+ packed_seq_params: Optional["PackedSeqParams"] = None,
*,
loss_mask: Optional[Tensor] = None,
- ) -> Tensor:
+ ) -> tuple[Tensor, Tensor | None]:
"""
Forward pass combining HuggingFace vision encoder with Megatron language model.
@@ -202,7 +210,8 @@ def forward(
loss_mask: Mask for loss computation.
Returns:
- Model output (logits or loss depending on mode).
+ tuple: (output_tensor, loss_mask) where output_tensor contains model output
+ and loss_mask is the CP-sliced mask for consistent loss computation.
"""
if self.pre_process:
if inputs_embeds is None:
@@ -237,6 +246,18 @@ def forward(
# Transpose back to Megatron format [seq_len, batch, hidden]
inputs_embeds = inputs_embeds.transpose(1, 0).contiguous()
+ # CP slicing: slice embeddings, labels, loss_mask, position_ids, and attention_mask
+ # This must happen AFTER vision-text merge so image token positions are correct
+ inputs_embeds, labels, loss_mask, position_ids, attention_mask = slice_batch_for_context_parallel(
+ inputs_embeds=inputs_embeds,
+ labels=labels,
+ loss_mask=loss_mask,
+ position_ids=position_ids,
+ attention_mask=attention_mask,
+ packed_seq_params=packed_seq_params,
+ pg_collection=self.config._pg_collection,
+ )
+
# Forward through Megatron language model
outputs = self.language_model.forward(
input_ids=None,
@@ -246,8 +267,10 @@ def forward(
labels=labels,
loss_mask=loss_mask,
runtime_gather_output=runtime_gather_output,
+ packed_seq_params=packed_seq_params,
)
- return outputs
+ # Return both outputs and the CP-sliced loss_mask for consistent loss computation
+ return (outputs, loss_mask)
def freeze(
self,
diff --git a/src/megatron/bridge/models/qwen_vl/modeling_qwen25_vl.py b/src/megatron/bridge/models/qwen_vl/modeling_qwen25_vl.py
index bd1447347f..d00fd7ab9f 100644
--- a/src/megatron/bridge/models/qwen_vl/modeling_qwen25_vl.py
+++ b/src/megatron/bridge/models/qwen_vl/modeling_qwen25_vl.py
@@ -213,6 +213,7 @@ def forward(
labels=labels,
loss_mask=loss_mask,
runtime_gather_output=runtime_gather_output,
+ packed_seq_params=packed_seq_params,
)
return outputs
diff --git a/src/megatron/bridge/recipes/gemma3_vl/gemma3_vl.py b/src/megatron/bridge/recipes/gemma3_vl/gemma3_vl.py
index 1e36a9a8dd..fc3520edd1 100644
--- a/src/megatron/bridge/recipes/gemma3_vl/gemma3_vl.py
+++ b/src/megatron/bridge/recipes/gemma3_vl/gemma3_vl.py
@@ -234,9 +234,12 @@ def _gemma3_vl_common(
model_cfg.freeze_vision_model = freeze_vision_model
model_cfg.freeze_vision_projection = freeze_vision_projection
model_cfg.seq_length = seq_length
+ model_cfg.cp_comm_type = "a2a"
# Optimizer and scheduler - use finetune_lr if provided, otherwise use lr
effective_lr = finetune_lr if finetune_lr is not None else lr
+ if min_lr > effective_lr:
+ min_lr = effective_lr * 0.1
opt_config, scheduler = distributed_fused_adam_with_cosine_annealing(
lr_warmup_iters=lr_warmup_iters,
lr_decay_iters=lr_decay_iters if lr_decay_iters is not None else train_iters,
diff --git a/src/megatron/bridge/training/utils/packed_seq_utils.py b/src/megatron/bridge/training/utils/packed_seq_utils.py
index 98dbd6d5ac..e631be426a 100644
--- a/src/megatron/bridge/training/utils/packed_seq_utils.py
+++ b/src/megatron/bridge/training/utils/packed_seq_utils.py
@@ -44,14 +44,18 @@ def get_packed_seq_params(batch: dict[str, torch.Tensor]) -> PackedSeqParams:
cu_seqlens_unpadded_argmin = batch.get("cu_seqlens_unpadded_argmin")
if cu_seqlens_argmin is not None:
- cu_seqlens_padded = cu_seqlens_padded[: cu_seqlens_argmin.item()]
- else:
+ argmin_idx = cu_seqlens_argmin.item()
+ assert argmin_idx == 0 or cu_seqlens_padded[argmin_idx] == -1 # cu_seqlens padding is -1
+ cu_seqlens_padded = cu_seqlens_padded[:argmin_idx]
+ elif torch.min(cu_seqlens_padded) == -1:
cu_seqlens_padded = cu_seqlens_padded[: torch.argmin(cu_seqlens_padded)]
if cu_seqlens_unpadded is not None:
if cu_seqlens_unpadded_argmin is not None:
- cu_seqlens_unpadded = cu_seqlens_unpadded[: cu_seqlens_unpadded_argmin.item()]
- else:
+ argmin_idx = cu_seqlens_unpadded_argmin.item()
+ assert argmin_idx == 0 or cu_seqlens_unpadded[argmin_idx] == -1 # cu_seqlens padding is -1
+ cu_seqlens_unpadded = cu_seqlens_unpadded[:argmin_idx]
+ elif torch.min(cu_seqlens_unpadded) == -1:
cu_seqlens_unpadded = cu_seqlens_unpadded[: torch.argmin(cu_seqlens_unpadded)]
max_seqlen = batch["max_seqlen"].squeeze() if "max_seqlen" in batch else None
diff --git a/src/megatron/bridge/training/vlm_step.py b/src/megatron/bridge/training/vlm_step.py
index 1e013eeebd..3c9724ca80 100644
--- a/src/megatron/bridge/training/vlm_step.py
+++ b/src/megatron/bridge/training/vlm_step.py
@@ -19,7 +19,7 @@
import torch
from megatron.core.models.gpt import GPTModel
from megatron.core.pipeline_parallel.utils import is_pp_first_stage, is_pp_last_stage
-from megatron.core.utils import get_batch_on_this_cp_rank, get_model_config
+from megatron.core.utils import get_model_config
from megatron.bridge.training.config import ConfigContainer
from megatron.bridge.training.losses import (
@@ -97,19 +97,133 @@ def get_batch_from_iterator(
return _batch_required_keys
-def get_batch(
- data_iterator: Iterable, cfg: ConfigContainer, use_mtp: bool = False, *, pg_collection
-) -> tuple[
- torch.Tensor,
- torch.Tensor,
- torch.Tensor,
- torch.Tensor,
- torch.Tensor,
- torch.Tensor,
- torch.Tensor,
- torch.Tensor,
- Any,
-]:
+def pack_batch_sequences(
+ tokens: torch.Tensor,
+ labels: torch.Tensor,
+ loss_mask: torch.Tensor,
+ attention_mask: torch.Tensor,
+ position_ids: torch.Tensor,
+ pad_token_id: int = 0,
+ pad_to_multiple_of: int = 1,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ """
+ Pack sequences in a batch by concatenating them and removing padding.
+
+ Args:
+ tokens: [batch_size, seq_len]
+ labels: [batch_size, seq_len]
+ loss_mask: [batch_size, seq_len]
+ attention_mask: [batch_size, 1, seq_len, seq_len] or None
+ position_ids: [batch_size, seq_len]
+ pad_token_id: Token ID used for padding
+
+ Returns:
+ Tuple of:
+ - packed_tokens: [1, total_len] - concatenated sequences
+ - packed_labels: [1, total_len]
+ - packed_loss_mask: [1, total_len]
+ - packed_attention_mask: None (not used with packing)
+ - packed_position_ids: [1, total_len]
+ - cu_seqlens: [num_sequences + 1] - cumulative sequence lengths
+ - max_seqlen: tensor - max sequence length in packed batch
+ """
+ batch_size, seq_len = tokens.shape
+ device = tokens.device
+
+ # Find actual sequence lengths (excluding padding)
+ # Assuming padding is at the end and uses pad_token_id (0)
+ seq_lengths = []
+ valid_sequences = []
+
+ for i in range(batch_size):
+        # Locate the last non-padding token; everything after it is treated as trailing pad
+ non_pad_mask = tokens[i] != pad_token_id
+ if non_pad_mask.any():
+ # Find the last non-padding token
+ last_valid_idx = non_pad_mask.nonzero(as_tuple=True)[0][-1].item() + 1
+ else:
+ # Empty sequence, skip
+ continue
+
+ seq_lengths.append(last_valid_idx)
+ valid_sequences.append(i)
+
+ if len(valid_sequences) == 0:
+ # No valid sequences, return empty packed batch
+ logger.warning("No valid sequences found in batch, skipping packing")
+ return (
+ tokens[:, :0], # Empty tensor
+ labels[:, :0],
+ loss_mask[:, :0],
+ attention_mask,
+ position_ids[:, :0],
+ torch.tensor([0], dtype=torch.int32, device=device), # Empty cu_seqlens
+ torch.tensor(0, dtype=torch.int32, device=device),
+ )
+
+ # Build cumulative sequence lengths
+ cu_seqlens = [0]
+ padded_seq_lengths = []
+ for length in seq_lengths:
+ if pad_to_multiple_of > 1:
+ padded_len = ((length + pad_to_multiple_of - 1) // pad_to_multiple_of) * pad_to_multiple_of
+ else:
+ padded_len = length
+ padded_seq_lengths.append(padded_len)
+ # Use padded lengths for cu_seqlens so THD RoPE splits sum correctly under CP.
+ cu_seqlens.append(cu_seqlens[-1] + padded_len)
+
+ cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32, device=device)
+ max_seqlen = torch.tensor(max(padded_seq_lengths), dtype=torch.int32, device=device)
+ total_len = cu_seqlens[-1].item()
+
+ # Concatenate sequences (remove padding)
+ packed_tokens = torch.zeros(1, total_len, dtype=tokens.dtype, device=device)
+ packed_labels = torch.zeros(1, total_len, dtype=labels.dtype, device=device)
+ packed_loss_mask = torch.zeros(1, total_len, dtype=loss_mask.dtype, device=device)
+ packed_position_ids = torch.zeros(1, total_len, dtype=position_ids.dtype, device=device)
+
+ offset = 0
+ for i, seq_idx in enumerate(valid_sequences):
+ length = seq_lengths[i]
+ padded_len = padded_seq_lengths[i]
+ pad_len = padded_len - length
+ packed_tokens[0, offset : offset + length] = tokens[seq_idx, :length]
+ packed_labels[0, offset : offset + length] = labels[seq_idx, :length]
+ packed_loss_mask[0, offset : offset + length] = loss_mask[seq_idx, :length]
+ packed_position_ids[0, offset : offset + length] = position_ids[seq_idx, :length]
+ if pad_len > 0:
+ packed_tokens[0, offset + length : offset + padded_len] = pad_token_id
+ packed_labels[0, offset + length : offset + padded_len] = -100
+ packed_loss_mask[0, offset + length : offset + padded_len] = 0
+ start_pos = position_ids[seq_idx, length - 1] + 1
+ packed_position_ids[0, offset + length : offset + padded_len] = torch.arange(
+ start_pos,
+ start_pos + pad_len,
+ device=device,
+ dtype=position_ids.dtype,
+ )
+ offset += padded_len
+
+ logger.debug(
+ f"Packed {len(valid_sequences)} sequences: lengths={seq_lengths}, total_len={total_len}, max_len={max_seqlen}"
+ )
+
+ # Attention mask is not used with packed sequences (handled by cu_seqlens)
+ packed_attention_mask = None
+
+ return (
+ packed_tokens,
+ packed_labels,
+ packed_loss_mask,
+ packed_attention_mask,
+ packed_position_ids,
+ cu_seqlens,
+ max_seqlen,
+ )
+
+
+def get_batch(data_iterator: Iterable, cfg: ConfigContainer, use_mtp: bool = False, *, pg_collection) -> tuple:
"""Generate a batch.
Args:
@@ -133,75 +247,110 @@ def get_batch(
is_first_pp_stage=is_first,
is_last_pp_stage=is_last,
)
+ enable_packing = getattr(cfg.dataset, "pack_sequences_in_batch", False)
- # Slice only text tensors for context parallelism
- cp_keys = ("tokens", "input_ids", "labels", "loss_mask", "attention_mask", "position_ids")
- cp_slice = {k: batch.get(k) for k in cp_keys if k in batch}
- cp_slice = get_batch_on_this_cp_rank(cp_slice, cp_group=pg_collection.cp)
- for k, v in cp_slice.items():
- batch[k] = v
-
- # When using pipeline parallelism, ensure fixed shapes equal to cfg.model.seq_length
- if getattr(cfg.model, "pipeline_model_parallel_size", 1) > 1:
- seq_len = cfg.model.seq_length
+ if not enable_packing:
+ # When using pipeline parallelism, ensure fixed shapes equal to cfg.model.seq_length
+ if getattr(cfg.model, "pipeline_model_parallel_size", 1) > 1:
+ seq_len = cfg.model.seq_length
- tokens_or_input = batch.get("tokens") if batch.get("tokens") is not None else batch.get("input_ids")
- tokens_or_input = pad_or_truncate_2d_to_len(tokens_or_input, seq_len, seq_len, pad_value=0)
- if batch.get("tokens") is not None:
- batch["tokens"] = tokens_or_input # type: ignore[assignment]
- else:
- batch["input_ids"] = tokens_or_input # type: ignore[assignment]
- batch["labels"] = pad_or_truncate_2d_to_len(batch.get("labels"), seq_len, seq_len, pad_value=-100) # type: ignore[assignment]
- batch["loss_mask"] = pad_or_truncate_2d_to_len(batch.get("loss_mask"), seq_len, seq_len, pad_value=0) # type: ignore[assignment]
- batch["position_ids"] = pad_or_truncate_pos_to_len(batch.get("position_ids"), seq_len, seq_len) # type: ignore[assignment]
- if batch.get("attention_mask") is not None:
- batch["attention_mask"] = pad_or_truncate_attn_to_len(batch.get("attention_mask"), seq_len, seq_len) # type: ignore[assignment]
- else:
- # No PP: pad sequence length to nearest multiple of 128 for efficiency (capped at model seq_length)
- seq_cap = cfg.model.seq_length
-
- def _ceil_to_mult(n: int, mult: int) -> int:
- return ((n + mult - 1) // mult) * mult
-
- tokens_or_input = batch.get("tokens") if batch.get("tokens") is not None else batch.get("input_ids")
- if tokens_or_input is not None:
- cur_len = tokens_or_input.size(1)
- target_len = min(seq_cap, _ceil_to_mult(cur_len, 128))
-
- # tokens/input_ids
- padded_tokens = pad_or_truncate_2d_to_len(tokens_or_input, target_len, seq_cap, pad_value=0)
+ tokens_or_input = batch.get("tokens") if batch.get("tokens") is not None else batch.get("input_ids")
+ tokens_or_input = pad_or_truncate_2d_to_len(tokens_or_input, seq_len, seq_len, pad_value=0)
if batch.get("tokens") is not None:
- batch["tokens"] = padded_tokens # type: ignore[assignment]
+ batch["tokens"] = tokens_or_input # type: ignore[assignment]
else:
- batch["input_ids"] = padded_tokens # type: ignore[assignment]
+ batch["input_ids"] = tokens_or_input # type: ignore[assignment]
+ batch["labels"] = pad_or_truncate_2d_to_len(batch.get("labels"), seq_len, seq_len, pad_value=-100) # type: ignore[assignment]
+ batch["loss_mask"] = pad_or_truncate_2d_to_len(batch.get("loss_mask"), seq_len, seq_len, pad_value=0) # type: ignore[assignment]
+ batch["position_ids"] = pad_or_truncate_pos_to_len(batch.get("position_ids"), seq_len, seq_len) # type: ignore[assignment]
+ if batch.get("attention_mask") is not None:
+ batch["attention_mask"] = pad_or_truncate_attn_to_len(batch.get("attention_mask"), seq_len, seq_len) # type: ignore[assignment]
+ else:
+ # No PP: pad sequence length to nearest multiple of 128 for efficiency (capped at model seq_length)
+ seq_cap = cfg.model.seq_length
- # labels and loss mask
- batch["labels"] = pad_or_truncate_2d_to_len(batch.get("labels"), target_len, seq_cap, pad_value=-100) # type: ignore[assignment]
- batch["loss_mask"] = pad_or_truncate_2d_to_len(batch.get("loss_mask"), target_len, seq_cap, pad_value=0) # type: ignore[assignment]
+ def _ceil_to_mult(n: int, mult: int) -> int:
+ return ((n + mult - 1) // mult) * mult
- # position_ids: extend with increasing positions
- pos = batch.get("position_ids")
- pos = pad_or_truncate_pos_to_len(pos, target_len, seq_cap)
- if pos is not None:
- batch["position_ids"] = pos # type: ignore[assignment]
+ tokens_or_input = batch.get("tokens") if batch.get("tokens") is not None else batch.get("input_ids")
+ if tokens_or_input is not None:
+ cur_len = tokens_or_input.size(1)
+ target_len = min(seq_cap, _ceil_to_mult(cur_len, 128))
- # attention_mask if present
- attn = batch.get("attention_mask")
- if attn is not None:
- attn = pad_or_truncate_attn_to_len(attn, target_len, seq_cap)
- batch["attention_mask"] = attn # type: ignore[assignment]
+ # tokens/input_ids
+ padded_tokens = pad_or_truncate_2d_to_len(tokens_or_input, target_len, seq_cap, pad_value=0)
+ if batch.get("tokens") is not None:
+ batch["tokens"] = padded_tokens # type: ignore[assignment]
+ else:
+ batch["input_ids"] = padded_tokens # type: ignore[assignment]
+
+ # labels and loss mask
+ batch["labels"] = pad_or_truncate_2d_to_len(batch.get("labels"), target_len, seq_cap, pad_value=-100) # type: ignore[assignment]
+ batch["loss_mask"] = pad_or_truncate_2d_to_len(
+ batch.get("loss_mask"), target_len, seq_cap, pad_value=0
+ ) # type: ignore[assignment]
+
+ # position_ids: extend with increasing positions
+ pos = batch.get("position_ids")
+ pos = pad_or_truncate_pos_to_len(pos, target_len, seq_cap)
+ if pos is not None:
+ batch["position_ids"] = pos # type: ignore[assignment]
+
+ # attention_mask if present
+ attn = batch.get("attention_mask")
+ if attn is not None:
+ attn = pad_or_truncate_attn_to_len(attn, target_len, seq_cap)
+ batch["attention_mask"] = attn # type: ignore[assignment]
visual_inputs = batch.get("visual_inputs")
+ cp_size = pg_collection.cp.size() if pg_collection is not None and pg_collection.cp is not None else 1
+
+ if enable_packing:
+ # Pack sequences
+ tokens_or_input = batch.get("tokens") if batch.get("tokens") is not None else batch.get("input_ids")
+ (
+ packed_tokens,
+ packed_labels,
+ packed_loss_mask,
+ packed_attention_mask,
+ packed_position_ids,
+ cu_seqlens,
+ max_seqlen,
+ ) = pack_batch_sequences(
+ tokens=tokens_or_input,
+ labels=batch.get("labels"),
+ loss_mask=batch.get("loss_mask"),
+ attention_mask=batch.get("attention_mask"),
+ position_ids=batch.get("position_ids"),
+ pad_token_id=0,
+ pad_to_multiple_of=cp_size * 2 if cp_size > 1 else 1,
+ )
+
+ # Update batch dict with packed tensors
+ if batch.get("tokens") is not None:
+ batch["tokens"] = packed_tokens
+ else:
+ batch["input_ids"] = packed_tokens
+ batch["labels"] = packed_labels
+ batch["loss_mask"] = packed_loss_mask
+ batch["attention_mask"] = packed_attention_mask
+ batch["position_ids"] = packed_position_ids
+
+    # Add packing metadata
+ logger.debug(f"Packed batch: cu_seqlens={cu_seqlens.tolist()}, max_seqlen={max_seqlen}")
+ else:
+ # No packing, use dummy values
+ cu_seqlens = None
+ max_seqlen = None
return (
(batch.get("tokens") if batch.get("tokens") is not None else batch.get("input_ids")),
batch.get("labels"),
- batch.get("loss_mask"),
+ batch.get("loss_mask"), # Full packed loss_mask, will be CP-sliced by model
batch.get("attention_mask"),
batch.get("position_ids"),
- batch.get("cu_seqlens"),
- batch.get("cu_seqlens_argmin"),
- batch.get("max_seqlen"),
+ cu_seqlens,
+ max_seqlen,
visual_inputs,
)
@@ -227,6 +376,7 @@ def forward_step(
use_mtp = (getattr(config, "mtp_num_layers", None) or 0) > 0
timers("batch-generator", log_level=2).start()
+ pg_collection = get_pg_collection(model)
with straggler_timer(bdata=True):
(
tokens,
@@ -235,10 +385,9 @@ def forward_step(
attention_mask,
position_ids,
cu_seqlens,
- cu_seqlens_argmin,
max_seqlen,
visual_inputs,
- ) = get_batch(data_iterator, state.cfg, use_mtp, pg_collection=get_pg_collection(model))
+ ) = get_batch(data_iterator, state.cfg, use_mtp, pg_collection=pg_collection)
timers("batch-generator").stop()
forward_args = {
@@ -246,6 +395,7 @@ def forward_step(
"position_ids": position_ids,
"attention_mask": attention_mask,
"labels": labels,
+ "loss_mask": loss_mask, # Pass full loss_mask so model can slice it consistently with labels
}
if visual_inputs is not None:
@@ -255,7 +405,6 @@ def forward_step(
if cu_seqlens is not None:
packed_seq_params = {
"cu_seqlens": cu_seqlens,
- "cu_seqlens_argmin": cu_seqlens_argmin,
"max_seqlen": max_seqlen,
}
forward_args["packed_seq_params"] = get_packed_seq_params(packed_seq_params)
@@ -273,7 +422,12 @@ def forward_step(
loss_function = _create_loss_function(loss_mask, check_for_nan_in_loss, check_for_spiky_loss)
return schedule_plan, loss_function
else:
- output_tensor = model(**forward_args)
+ model_output = model(**forward_args)
+ # Handle tuple return: (output_tensor, sliced_loss_mask) from VLM models with CP
+ if isinstance(model_output, tuple):
+ output_tensor, loss_mask = model_output
+ else:
+ output_tensor = model_output
loss_function = _create_loss_function(loss_mask, check_for_nan_in_loss, check_for_spiky_loss)
diff --git a/src/megatron/bridge/utils/common_utils.py b/src/megatron/bridge/utils/common_utils.py
index 59ac803bf5..1946230bd6 100644
--- a/src/megatron/bridge/utils/common_utils.py
+++ b/src/megatron/bridge/utils/common_utils.py
@@ -271,3 +271,97 @@ def resolve_path(path: str) -> Path:
"""Resolve a path to an absolute path."""
return Path(path).expanduser().absolute().resolve()
+
+
+def slice_batch_for_context_parallel(
+ inputs_embeds: torch.Tensor,
+ labels: torch.Tensor,
+ loss_mask: torch.Tensor,
+ position_ids: torch.Tensor,
+ attention_mask: torch.Tensor,
+ packed_seq_params,
+ pg_collection,
+):
+ """Slice batch tensors for Context Parallelism (CP) in VLM models.
+
+ This function handles CP slicing AFTER vision-text embedding merge, ensuring
+ image token positions are correctly preserved. It supports both:
+ - THD format (packed sequences): Uses TransformerEngine's thd_get_partitioned_indices
+ - BSHD format: Uses Megatron's get_batch_on_this_cp_rank with zigzag pattern
+
+ Args:
+ inputs_embeds: Input embeddings tensor in (T, B, D) format.
+ labels: Labels tensor.
+ loss_mask: Loss mask tensor.
+ position_ids: Position IDs tensor.
+ attention_mask: Attention mask tensor.
+ packed_seq_params: PackedSeqParams for THD format, or None for BSHD.
+ pg_collection: ProcessGroupCollection containing CP group info.
+
+ Returns:
+ Tuple of (inputs_embeds, labels, loss_mask, position_ids, attention_mask)
+ with all tensors sliced for this CP rank. inputs_embeds remains in (T, B, D) format.
+ """
+ from megatron.core.utils import get_batch_on_this_cp_rank
+
+ cp_size = pg_collection.cp.size()
+ if cp_size <= 1:
+ return inputs_embeds, labels, loss_mask, position_ids, attention_mask
+
+ cp_rank = pg_collection.cp.rank()
+
+ # (T, B, D) -> (B, T, D) for slicing
+ if inputs_embeds is not None:
+ inputs_embeds = inputs_embeds.transpose(0, 1).contiguous()
+
+ # For THD (packed) format, use TE's thd_get_partitioned_indices
+ # This properly slices WITHIN each packed sequence, not across them
+ if packed_seq_params is not None and packed_seq_params.qkv_format == "thd":
+ import transformer_engine_torch as tex
+
+ if inputs_embeds is None:
+ raise ValueError("inputs_embeds is required for THD CP slicing")
+
+ cu_seqlens = packed_seq_params.cu_seqlens_q
+ cu_seqlens_padded = (
+ packed_seq_params.cu_seqlens_q_padded if packed_seq_params.cu_seqlens_q_padded is not None else cu_seqlens
+ )
+ seq_len = inputs_embeds.size(1)
+
+ index = tex.thd_get_partitioned_indices(cu_seqlens_padded, seq_len, cp_size, cp_rank)
+
+ # Slice all tensors using THD indices
+ if inputs_embeds is not None:
+ inputs_embeds = inputs_embeds.index_select(1, index)
+ if labels is not None:
+ labels = labels.index_select(1, index)
+ if loss_mask is not None:
+ loss_mask = loss_mask.index_select(1, index)
+ if position_ids is not None:
+ position_ids = position_ids.index_select(1, index)
+ # Note: attention_mask and packed_seq_params stay unchanged for ring attention
+ else:
+ # For BSHD format, use standard zigzag slicing
+ cp_group = pg_collection.cp
+ cp_batch = get_batch_on_this_cp_rank(
+ {
+ "decoder_input": inputs_embeds,
+ "labels": labels,
+ "loss_mask": loss_mask,
+ "position_ids": position_ids,
+ "attention_mask": attention_mask,
+ },
+ cp_group=cp_group,
+ )
+
+ inputs_embeds = cp_batch.get("decoder_input")
+ labels = cp_batch.get("labels")
+ loss_mask = cp_batch.get("loss_mask")
+ position_ids = cp_batch.get("position_ids")
+ attention_mask = cp_batch.get("attention_mask")
+
+ # Transpose back to (T, B, D)
+ if inputs_embeds is not None:
+ inputs_embeds = inputs_embeds.transpose(0, 1).contiguous()
+
+ return inputs_embeds, labels, loss_mask, position_ids, attention_mask
diff --git a/tests/functional_tests/L2_Launch_recipes_gemma_vl.sh b/tests/functional_tests/L2_Launch_recipes_gemma_vl.sh
new file mode 100644
index 0000000000..18870efe9d
--- /dev/null
+++ b/tests/functional_tests/L2_Launch_recipes_gemma_vl.sh
@@ -0,0 +1,31 @@
+#!/bin/bash
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+set -xeuo pipefail # Exit immediately if a command exits with a non-zero status
+
+export CUDA_VISIBLE_DEVICES="0,1"
+
+# Run Gemma3-VL recipe functional tests (CUDA_VISIBLE_DEVICES exposes 2 GPUs; torchrun launches a single process)
+# This script tests Gemma3-VL finetune recipe configurations with their default
+# settings to ensure they can run basic training without crashes.
+uv run python -m torch.distributed.run --nproc_per_node=1 --nnodes=1 \
+ -m coverage run --data-file=/opt/Megatron-Bridge/.coverage \
+ --source=/opt/Megatron-Bridge/ --parallel-mode \
+ -m pytest -o log_cli=true -o log_cli_level=INFO -v -s -x \
+ -m "not pleasefixme" --tb=short -rA \
+ tests/functional_tests/recipes/test_gemma3_vl_recipes_finetune.py
+
+coverage combine -q
+
diff --git a/tests/functional_tests/models/ministral3_vl/__init__.py b/tests/functional_tests/models/ministral3_vl/__init__.py
new file mode 100644
index 0000000000..341a77c5bc
--- /dev/null
+++ b/tests/functional_tests/models/ministral3_vl/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/tests/functional_tests/models/ministral3_vl/test_ministral3_vl_conversion.py b/tests/functional_tests/models/ministral3_vl/test_ministral3_vl_conversion.py
new file mode 100644
index 0000000000..1d2d0df65c
--- /dev/null
+++ b/tests/functional_tests/models/ministral3_vl/test_ministral3_vl_conversion.py
@@ -0,0 +1,257 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import subprocess
+from pathlib import Path
+
+import pytest
+import torch
+
+
+# Ministral 3 VL toy model configuration based on typical Ministral 3 VL structure
+# This is a minimized version for testing purposes
+HF_MINISTRAL3_VL_TOY_MODEL_CONFIG = {
+ "architectures": ["Mistral3ForConditionalGeneration"],
+ "model_type": "mistral3",
+ "torch_dtype": "bfloat16",
+ "transformers_version": "5.0.0",
+ "image_token_index": 10,
+ "text_config": {
+ "model_type": "mistral3_text",
+ "hidden_size": 512,
+ "intermediate_size": 1536, # 3 * hidden_size for FFN
+ "num_hidden_layers": 4,
+ "num_attention_heads": 8,
+ "num_key_value_heads": 2, # GQA: 2 KV heads, each shared by 4 of the 8 query heads
+ "vocab_size": 32768,
+ "max_position_embeddings": 4096,
+ "rms_norm_eps": 1e-5,
+ "tie_word_embeddings": True,
+ "rope_theta": 1000000,
+ "rope_parameters": {
+ "rope_type": "yarn",
+ "factor": 16.0,
+ "original_max_position_embeddings": 16384,
+ "llama_4_scaling_beta": 0.0,
+ },
+ },
+ "vision_config": {
+ "model_type": "pixtral",
+ "hidden_size": 256,
+ "intermediate_size": 1024,
+ "num_hidden_layers": 4,
+ "num_attention_heads": 4,
+ "image_size": 448,
+ "patch_size": 14,
+ "num_channels": 3,
+ },
+ "spatial_merge_size": 2,
+ "vision_feature_layer": -1,
+}
+
+
+class TestMinistral3VLConversion:
+ """
+ Test Ministral 3 VL model conversion from local HuggingFace model with different parallelism configurations.
+ """
+
+ @pytest.fixture(scope="class")
+ def ministral3_vl_toy_model_path(self, tmp_path_factory):
+ """
+ Create and save a HuggingFace Ministral 3 VL toy model from config to a temporary directory.
+
+ Args:
+ tmp_path_factory: Pytest temporary path factory for class-scoped fixtures
+
+ Returns:
+ str: Path to the saved HuggingFace model directory
+ """
+ # Skip if transformers doesn't have Mistral3 support
+ pytest.importorskip("transformers", minversion="5.0.0")
+
+ try:
+ from transformers import Mistral3ForConditionalGeneration
+ from transformers.models.mistral3.configuration_mistral3 import Mistral3Config
+ except ImportError:
+ pytest.skip("Mistral3ForConditionalGeneration not available in transformers")
+
+ # Create a temporary directory for this test class
+ temp_dir = tmp_path_factory.mktemp("ministral3_vl_toy_model")
+ model_dir = temp_dir / "ministral3_vl_toy"
+
+ # Create config from the toy model config
+ config_dict = HF_MINISTRAL3_VL_TOY_MODEL_CONFIG.copy()
+
+ # Create config object
+ config = Mistral3Config(**config_dict)
+ config.torch_dtype = torch.bfloat16
+
+ # Create model with random weights and convert to bfloat16
+ model = Mistral3ForConditionalGeneration(config)
+ model = model.bfloat16()
+
+ # Debug: Check model dtype before saving
+ for name, param in model.named_parameters():
+ print(f"Before save - {name}: {param.dtype}")
+ break # Just check the first parameter
+
+ # Create minimal tokenizer files
+ tokenizer_config = {
+ "tokenizer_class": "LlamaTokenizer",
+ "vocab_size": 32768,
+ "bos_token": "",
+ "eos_token": "",
+ "pad_token": "",
+ "unk_token": "",
+ }
+
+ model_dir.mkdir(parents=True, exist_ok=True)
+ with open(model_dir / "tokenizer_config.json", "w") as f:
+ json.dump(tokenizer_config, f, indent=2)
+
+ # Save model and config to directory
+ model.save_pretrained(model_dir, safe_serialization=True)
+
+ # Save config.json explicitly to ensure compatibility
+ config_path = model_dir / "config.json"
+ with open(config_path, "w") as f:
+ json.dump(config_dict, f, indent=2)
+
+ return str(model_dir)
+
+ def test_toy_model_creation(self, ministral3_vl_toy_model_path):
+ """
+ Test that the toy model is created correctly and can be loaded.
+
+ Args:
+ ministral3_vl_toy_model_path: Path to the toy Ministral 3 VL model (from fixture)
+ """
+ # Verify the model directory exists
+ model_path = Path(ministral3_vl_toy_model_path)
+ assert model_path.exists(), f"Model directory not found at {model_path}"
+
+ # Check essential files exist
+ config_file = model_path / "config.json"
+ assert config_file.exists(), f"config.json not found at {config_file}"
+
+ # Check for model weights (safetensors preferred)
+ weights_file = model_path / "model.safetensors"
+ if not weights_file.exists():
+ weights_file = model_path / "pytorch_model.bin"
+ assert weights_file.exists(), f"Model weights file not found in {model_path}"
+
+ # Check for tokenizer files
+ tokenizer_config_file = model_path / "tokenizer_config.json"
+ assert tokenizer_config_file.exists(), f"tokenizer_config.json not found at {tokenizer_config_file}"
+
+ # Load and verify config
+ with open(config_file) as f:
+ config_data = json.load(f)
+
+ assert config_data["model_type"] == "mistral3"
+ assert "text_config" in config_data
+ assert "vision_config" in config_data
+ assert config_data["text_config"]["hidden_size"] == 512
+ assert config_data["text_config"]["num_hidden_layers"] == 4
+ assert config_data["text_config"]["num_attention_heads"] == 8
+ assert config_data["vision_config"]["hidden_size"] == 256
+
+ @pytest.mark.run_only_on("GPU")
+ @pytest.mark.parametrize(
+ "tp,pp,test_name",
+ [
+ (2, 1, "TP"),
+ (1, 2, "PP"),
+ ],
+ )
+ def test_ministral3_vl_conversion_parallelism(self, ministral3_vl_toy_model_path, tmp_path, tp, pp, test_name):
+ """
+ Test Ministral 3 VL model conversion with different parallelism configurations.
+
+ Args:
+ ministral3_vl_toy_model_path: Path to the toy Ministral 3 VL model (from fixture)
+ tmp_path: Pytest temporary path fixture
+ tp: Tensor parallelism size
+ pp: Pipeline parallelism size
+ test_name: Name of the test for identification
+ """
+
+ # Create temporary output directory for conversion results
+ test_output_dir = tmp_path / f"ministral3_vl_{test_name}"
+ test_output_dir.mkdir(exist_ok=True)
+
+ # Run hf_megatron_roundtrip_multi_gpu.py with specified parallelism configuration on our toy model
+ cmd = [
+ "python",
+ "-m",
+ "torch.distributed.run",
+ "--nproc_per_node=2",
+ "--nnodes=1",
+ "-m",
+ "coverage",
+ "run",
+ "--data-file=/opt/Megatron-Bridge/.coverage",
+ "--source=/opt/Megatron-Bridge/",
+ "--parallel-mode",
+ "examples/conversion/hf_megatron_roundtrip_multi_gpu.py",
+ "--hf-model-id",
+ ministral3_vl_toy_model_path, # Use our local toy model instead of downloading
+ "--output-dir",
+ str(test_output_dir),
+ "--tp",
+ str(tp),
+ "--pp",
+ str(pp),
+ ]
+
+ result = subprocess.run(
+ cmd, capture_output=True, text=True, cwd=Path(__file__).parent.parent.parent.parent.parent
+ )
+ print(cmd)
+
+ # Check that the conversion completed successfully
+ if result.returncode != 0:
+ print(f"STDOUT: {result.stdout}")
+ print(f"STDERR: {result.stderr}")
+ assert False, f"Ministral 3 VL {test_name} conversion failed with return code {result.returncode}"
+
+ # Verify that the converted model was saved
+ # The output directory should be named after the last part of the model path
+ model_name = Path(ministral3_vl_toy_model_path).name # "ministral3_vl_toy"
+ converted_model_dir = test_output_dir / model_name
+ assert converted_model_dir.exists(), f"Converted model directory not found at {converted_model_dir}"
+
+ # Check that essential model files exist
+ config_file = converted_model_dir / "config.json"
+ assert config_file.exists(), f"config.json not found in converted model at {config_file}"
+
+ # Check for model weights file (could be either safetensors or pytorch_model.bin)
+ weights_file_safetensors = converted_model_dir / "model.safetensors"
+ weights_file_pytorch = converted_model_dir / "pytorch_model.bin"
+ assert weights_file_safetensors.exists() or weights_file_pytorch.exists(), (
+ f"Model weights file not found in converted model at {converted_model_dir}"
+ )
+
+ # Verify the config contains Ministral 3 VL-specific parameters
+ with open(config_file) as f:
+ saved_config = json.load(f)
+
+ assert saved_config["model_type"] == "mistral3", "Model type should be mistral3"
+ assert "text_config" in saved_config, "VL model should have text_config"
+ assert "vision_config" in saved_config, "VL model should have vision_config"
+ assert saved_config["text_config"]["hidden_size"] == 512, "Hidden size should match toy config"
+ assert saved_config["text_config"]["num_attention_heads"] == 8, (
+ "Number of attention heads should match toy config"
+ )
diff --git a/tests/functional_tests/recipes/test_gemma3_vl_recipes_finetune.py b/tests/functional_tests/recipes/test_gemma3_vl_recipes_finetune.py
index 6d4ae51323..3b94e2f84c 100644
--- a/tests/functional_tests/recipes/test_gemma3_vl_recipes_finetune.py
+++ b/tests/functional_tests/recipes/test_gemma3_vl_recipes_finetune.py
@@ -36,7 +36,7 @@ class TestGemma3VLRecipes:
"""Test class for Gemma3-VL recipe functional tests."""
@pytest.mark.run_only_on("GPU")
- @pytest.mark.parametrize("config_func,recipe_name,parallelism_overrides", GEMMA3_VL_FINETUNE_RECIPES)
+ @pytest.mark.parametrize("config_func,recipe_name,model_overrides", GEMMA3_VL_FINETUNE_RECIPES)
def test_gemma3_vl_finetune_recipes(self, config_func, recipe_name, model_overrides, tmp_path):
"""Functional test for Gemma3-VL recipes with appropriate parallelism configurations."""
run_pretrain_vl_recipe_test(config_func, recipe_name, tmp_path, model_overrides=model_overrides)
diff --git a/tests/unit_tests/models/gemma/test_gemma3_provider.py b/tests/unit_tests/models/gemma/test_gemma3_provider.py
index 1e842ad42d..a06d07c104 100644
--- a/tests/unit_tests/models/gemma/test_gemma3_provider.py
+++ b/tests/unit_tests/models/gemma/test_gemma3_provider.py
@@ -416,6 +416,86 @@ def test_gemma3_rotary_embedding_initialization(self):
# Verify that RotaryEmbedding was called for local rope
assert mock_rotary_embedding.call_count >= 1
+ def test_gemma3_rotary_embedding_forward_with_cp_group(self):
+ """Test Gemma3RotaryEmbedding forward method with cp_group (non-None path)."""
+ # Create a minimal Gemma3RotaryEmbedding instance via __new__ to avoid complex init
+ rope_emb = Gemma3RotaryEmbedding.__new__(Gemma3RotaryEmbedding)
+
+ # Mock the rope_local attribute
+ mock_rope_local = Mock()
+ mock_rope_local.forward = Mock(return_value=torch.tensor([1.0, 2.0]))
+ rope_emb.rope_local = mock_rope_local
+
+ # Mock the parent class forward method (called via super().forward)
+ mock_global_output = torch.tensor([3.0, 4.0])
+ mock_local_output = torch.tensor([1.0, 2.0])
+
+ # Create a mock cp_group (ProcessGroup)
+ mock_cp_group = Mock()
+
+ with patch.object(
+ Gemma3RotaryEmbedding.__bases__[0], "forward", return_value=mock_global_output
+ ) as mock_super_forward:
+ result = rope_emb.forward(max_seq_len=1024, offset=0, packed_seq=False, cp_group=mock_cp_group)
+
+ # Verify super().forward was called with cp_group
+ mock_super_forward.assert_called_once_with(1024, 0, False, mock_cp_group)
+
+ # Verify rope_local.forward was called with cp_group
+ mock_rope_local.forward.assert_called_once_with(1024, 0, False, mock_cp_group)
+
+ # Verify return is (rope_local, rope_global) tuple
+ assert isinstance(result, tuple)
+ assert len(result) == 2
+ rope_local_result, rope_global_result = result
+ assert torch.equal(rope_local_result, mock_local_output)
+ assert torch.equal(rope_global_result, mock_global_output)
+
+ def test_gemma3_rotary_embedding_forward_without_cp_group(self):
+ """Test Gemma3RotaryEmbedding forward method without cp_group (cached path)."""
+ # Create a minimal Gemma3RotaryEmbedding instance via __new__
+ rope_emb = Gemma3RotaryEmbedding.__new__(Gemma3RotaryEmbedding)
+
+ # Mock the _forward_cached method
+ mock_cached_result = (torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0]))
+ rope_emb._forward_cached = Mock(return_value=mock_cached_result)
+
+ # Call forward without cp_group (None)
+ result = rope_emb.forward(max_seq_len=1024, offset=0, packed_seq=False, cp_group=None)
+
+ # Verify _forward_cached was called
+ rope_emb._forward_cached.assert_called_once_with(1024, 0, False)
+
+ # Verify result matches cached result
+ assert result == mock_cached_result
+
+ def test_gemma3_rotary_embedding_forward_cached(self):
+ """Test Gemma3RotaryEmbedding _forward_cached method."""
+ # Create a minimal Gemma3RotaryEmbedding instance via __new__
+ rope_emb = Gemma3RotaryEmbedding.__new__(Gemma3RotaryEmbedding)
+
+ # Mock the rope_local attribute
+ mock_rope_local = Mock()
+ mock_rope_local.forward = Mock(return_value=torch.tensor([1.0, 2.0]))
+ rope_emb.rope_local = mock_rope_local
+
+ mock_global_output = torch.tensor([3.0, 4.0])
+
+ with patch.object(
+ Gemma3RotaryEmbedding.__bases__[0], "forward", return_value=mock_global_output
+ ) as mock_super_forward:
+ result = rope_emb._forward_cached(max_seq_len=512, offset=10, packed_seq=True)
+
+ # Verify super().forward was called with cp_group=None
+ mock_super_forward.assert_called_once_with(512, 10, True, None)
+
+ # Verify rope_local.forward was called with cp_group=None
+ mock_rope_local.forward.assert_called_once_with(512, 10, True, None)
+
+ # Verify return is (rope_local, rope_global) tuple
+ assert isinstance(result, tuple)
+ assert len(result) == 2
+
def test_te_row_parallel_linear_layer_norm(self):
"""Test TERowParallelLinearLayerNorm initialization and forward."""
# Test that the class exists and can be imported
diff --git a/tests/unit_tests/models/ministral3/test_ministral3_provider.py b/tests/unit_tests/models/ministral3/test_ministral3_provider.py
index 01564a61fc..c8058e1343 100644
--- a/tests/unit_tests/models/ministral3/test_ministral3_provider.py
+++ b/tests/unit_tests/models/ministral3/test_ministral3_provider.py
@@ -12,15 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import pytest
+import torch
from megatron.bridge.models.ministral3.ministral3_provider import (
Ministral3ModelProvider,
Ministral3ModelProvider3B,
Ministral3ModelProvider8B,
Ministral3ModelProvider14B,
+ MinistralTEDotProductAttention,
)
+pytestmark = pytest.mark.unit
+
+
class TestMinistral3ModelProvider:
"""Test cases for Ministral3ModelProvider base class."""
@@ -172,3 +178,166 @@ def test_ministral3_14b_initialization(self):
assert provider.ffn_hidden_size == 16384
assert provider.num_layers == 40
assert provider.rotary_base == 1000000000.0
+
+
+class TestGetLlama4AttnScale:
+ """Test cases for _get_llama_4_attn_scale function used in MinistralTEDotProductAttention.
+
+ This function computes attention scaling based on Llama 4 attention parameters.
+ The key change in PR 1997 is that it now handles different query shapes for
+ packed (3D) vs unpacked (4D) tensors.
+ """
+
+ # Use the actual production implementation
+ _get_llama_4_attn_scale = staticmethod(MinistralTEDotProductAttention._get_llama_4_attn_scale)
+
+ def test_unpacked_4d_query_shape(self):
+ """Test attention scaling with unpacked 4D query shape [seq_len, batch, num_heads, head_dim]."""
+ seq_len = 8
+ batch_size = 2
+ num_heads = 4
+ head_dim = 64
+
+ positions_ids = torch.arange(seq_len)
+ beta = 0.1
+ max_position_embeddings = 16384
+ query_shape = (seq_len, batch_size, num_heads, head_dim)
+
+ scaling = self._get_llama_4_attn_scale(positions_ids, beta, max_position_embeddings, query_shape)
+
+ # Output should have shape [seq_len, 1, 1, 1] for broadcasting
+ assert scaling.shape == (seq_len, 1, 1, 1)
+
+ # First position should have scaling = 1 (since log(1 + 0) = 0)
+ expected_first = 1 + beta * torch.log(torch.tensor(1.0))
+ assert torch.isclose(scaling[0, 0, 0, 0], expected_first, atol=1e-6)
+
+ def test_packed_3d_query_shape(self):
+ """Test attention scaling with packed 3D query shape [seq_len, num_heads, head_dim]."""
+ seq_len = 16
+ num_heads = 8
+ head_dim = 32
+
+ positions_ids = torch.arange(seq_len)
+ beta = 0.2
+ max_position_embeddings = 8192
+ query_shape = (seq_len, num_heads, head_dim)
+
+ scaling = self._get_llama_4_attn_scale(positions_ids, beta, max_position_embeddings, query_shape)
+
+ # Output should have shape [seq_len, 1, 1] for broadcasting (3D - 1 = 2 dims added)
+ assert scaling.shape == (seq_len, 1, 1)
+
+ # Verify scaling values are computed correctly
+ expected = 1 + beta * torch.log(1 + torch.floor(positions_ids / max_position_embeddings))
+ assert torch.allclose(scaling.squeeze(), expected, atol=1e-6)
+
+ def test_scaling_formula_correctness(self):
+ """Test that the scaling formula matches expected Llama 4 attention scaling."""
+ positions_ids = torch.tensor([0, 1, 100, 1000, 16384, 32768])
+ beta = 0.15
+ max_position_embeddings = 16384
+ query_shape = (6, 1, 1, 1)
+
+ scaling = self._get_llama_4_attn_scale(positions_ids, beta, max_position_embeddings, query_shape)
+
+ # Manual computation of expected values
+ # For position 0: 1 + 0.15 * log(1 + 0) = 1
+ # For position 16384: 1 + 0.15 * log(1 + 1) = 1 + 0.15 * log(2)
+ # For position 32768: 1 + 0.15 * log(1 + 2) = 1 + 0.15 * log(3)
+
+ expected_0 = 1.0
+ expected_16384 = 1 + beta * torch.log(torch.tensor(2.0))
+ expected_32768 = 1 + beta * torch.log(torch.tensor(3.0))
+
+ assert torch.isclose(scaling[0].squeeze(), torch.tensor(expected_0), atol=1e-6)
+ assert torch.isclose(scaling[4].squeeze(), expected_16384, atol=1e-6)
+ assert torch.isclose(scaling[5].squeeze(), expected_32768, atol=1e-6)
+
+ def test_beta_zero_returns_ones(self):
+ """Test that beta=0 returns all ones (no scaling)."""
+ positions_ids = torch.arange(10)
+ beta = 0.0
+ max_position_embeddings = 4096
+ query_shape = (10, 4, 64)
+
+ scaling = self._get_llama_4_attn_scale(positions_ids, beta, max_position_embeddings, query_shape)
+
+ assert torch.allclose(scaling.squeeze(), torch.ones(10), atol=1e-6)
+
+ def test_different_query_shapes_get_correct_dims(self):
+ """Test that different query shapes result in correct number of dimensions added."""
+ positions_ids = torch.arange(4)
+ beta = 0.1
+ max_position_embeddings = 1000
+
+ # 2D query shape
+ query_shape_2d = (4, 32)
+ scaling_2d = self._get_llama_4_attn_scale(positions_ids, beta, max_position_embeddings, query_shape_2d)
+ assert scaling_2d.shape == (4, 1) # 2-1 = 1 dim added
+
+ # 3D query shape (packed THD)
+ query_shape_3d = (4, 8, 32)
+ scaling_3d = self._get_llama_4_attn_scale(positions_ids, beta, max_position_embeddings, query_shape_3d)
+ assert scaling_3d.shape == (4, 1, 1) # 3-1 = 2 dims added
+
+ # 4D query shape (unpacked BSHD)
+ query_shape_4d = (4, 2, 8, 32)
+ scaling_4d = self._get_llama_4_attn_scale(positions_ids, beta, max_position_embeddings, query_shape_4d)
+ assert scaling_4d.shape == (4, 1, 1, 1) # 4-1 = 3 dims added
+
+ def test_broadcasting_compatibility(self):
+ """Test that scaling tensor is broadcastable to query tensor."""
+ seq_len = 8
+ num_heads = 4
+ head_dim = 64
+
+ positions_ids = torch.arange(seq_len)
+ beta = 0.1
+ max_position_embeddings = 16384
+
+ # Test for 3D packed format
+ query_3d = torch.randn(seq_len, num_heads, head_dim)
+ scaling_3d = self._get_llama_4_attn_scale(positions_ids, beta, max_position_embeddings, query_3d.shape)
+
+ # Broadcasting should work
+ result_3d = query_3d * scaling_3d.to(query_3d.dtype)
+ assert result_3d.shape == query_3d.shape
+
+ # Test for 4D unpacked format
+ batch = 2
+ query_4d = torch.randn(seq_len, batch, num_heads, head_dim)
+ scaling_4d = self._get_llama_4_attn_scale(positions_ids, beta, max_position_embeddings, query_4d.shape)
+
+ # Broadcasting should work
+ result_4d = query_4d * scaling_4d.to(query_4d.dtype)
+ assert result_4d.shape == query_4d.shape
+
+ @pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA")
+ def test_gpu_tensor_support(self):
+ """Test that the function works with GPU tensors if available."""
+ positions_ids = torch.arange(8, device="cuda")
+ beta = 0.1
+ max_position_embeddings = 1024
+ query_shape = (8, 4, 32)
+
+ scaling = self._get_llama_4_attn_scale(positions_ids, beta, max_position_embeddings, query_shape)
+
+ assert scaling.device.type == "cuda"
+ assert scaling.shape == (8, 1, 1)
+
+ def test_dtype_preservation(self):
+ """Test that output dtype matches input positions_ids dtype."""
+ positions_ids_float32 = torch.arange(4, dtype=torch.float32)
+ positions_ids_float64 = torch.arange(4, dtype=torch.float64)
+ beta = 0.1
+ max_position_embeddings = 100
+ query_shape = (4, 2, 8)
+
+ scaling_32 = self._get_llama_4_attn_scale(positions_ids_float32, beta, max_position_embeddings, query_shape)
+ scaling_64 = self._get_llama_4_attn_scale(positions_ids_float64, beta, max_position_embeddings, query_shape)
+
+ # The log-based scaling preserves the floating-point dtype of positions_ids,
+ # so float32 input yields float32 output and float64 yields float64.
+ assert scaling_32.dtype == torch.float32
+ assert scaling_64.dtype == torch.float64
diff --git a/tests/unit_tests/training/test_gpt_step.py b/tests/unit_tests/training/test_gpt_step.py
index 141fd9827f..d677c1334f 100644
--- a/tests/unit_tests/training/test_gpt_step.py
+++ b/tests/unit_tests/training/test_gpt_step.py
@@ -115,9 +115,8 @@ def test_packed_seq_params_no_padding(self):
# Verify the result is a PackedSeqParams object
assert isinstance(result, PackedSeqParams)
- # When there's no -1 padding, argmin returns 0 (index of min value)
- # So cu_seqlens[:0] returns empty tensor
- expected_cu_seqlens = torch.empty(0, dtype=torch.int32) # Empty tensor
+ # When there's no -1 padding, the tensor is returned unchanged
+ expected_cu_seqlens = torch.tensor([0, 7, 14], dtype=torch.int32)
assert torch.equal(result.cu_seqlens_q, expected_cu_seqlens)
assert torch.equal(result.cu_seqlens_kv, expected_cu_seqlens)
@@ -181,6 +180,113 @@ def test_packed_seq_params_all_fields_match(self):
assert torch.equal(result.cu_seqlens_q, result.cu_seqlens_kv)
assert torch.equal(result.max_seqlen_q, result.max_seqlen_kv)
+ def test_packed_seq_params_with_cu_seqlens_unpadded(self):
+ """Test functionality with cu_seqlens_unpadded for THD CP support."""
+ # Padded cu_seqlens (includes padding for CP divisibility)
+ cu_seqlens_padded = torch.tensor([[0, 8, 16, -1, -1]], dtype=torch.int32)
+ # Unpadded cu_seqlens (actual sequence boundaries)
+ cu_seqlens_unpadded = torch.tensor([[0, 6, 14, -1, -1]], dtype=torch.int32)
+
+ batch = {
+ "cu_seqlens": cu_seqlens_padded,
+ "cu_seqlens_unpadded": cu_seqlens_unpadded,
+ "max_seqlen": torch.tensor([[10]], dtype=torch.int32),
+ }
+
+ result = get_packed_seq_params(batch)
+
+ # cu_seqlens_q and cu_seqlens_kv should use unpadded values
+ expected_unpadded = torch.tensor([0, 6, 14], dtype=torch.int32)
+ assert torch.equal(result.cu_seqlens_q, expected_unpadded)
+ assert torch.equal(result.cu_seqlens_kv, expected_unpadded)
+
+ # cu_seqlens_q_padded and cu_seqlens_kv_padded should use padded values
+ expected_padded = torch.tensor([0, 8, 16], dtype=torch.int32)
+ assert torch.equal(result.cu_seqlens_q_padded, expected_padded)
+ assert torch.equal(result.cu_seqlens_kv_padded, expected_padded)
+
+ def test_packed_seq_params_cu_seqlens_unpadded_with_argmin(self):
+ """Test cu_seqlens_unpadded processing with argmin hint."""
+ batch = {
+ "cu_seqlens": torch.tensor([[0, 4, 8, 12, -1, -1]], dtype=torch.int32),
+ "cu_seqlens_argmin": torch.tensor(4), # Index where -1 starts
+ "cu_seqlens_unpadded": torch.tensor([[0, 3, 7, 10, -1, -1]], dtype=torch.int32),
+ "cu_seqlens_unpadded_argmin": torch.tensor(4), # Index where -1 starts
+ }
+
+ result = get_packed_seq_params(batch)
+
+ # Verify unpadded values are used for q/kv
+ expected_unpadded = torch.tensor([0, 3, 7, 10], dtype=torch.int32)
+ assert torch.equal(result.cu_seqlens_q, expected_unpadded)
+ assert torch.equal(result.cu_seqlens_kv, expected_unpadded)
+
+ # Verify padded values are set for _padded fields
+ expected_padded = torch.tensor([0, 4, 8, 12], dtype=torch.int32)
+ assert torch.equal(result.cu_seqlens_q_padded, expected_padded)
+ assert torch.equal(result.cu_seqlens_kv_padded, expected_padded)
+
+ def test_packed_seq_params_without_unpadded_fallback(self):
+ """Test fallback to cu_seqlens when cu_seqlens_unpadded is not provided."""
+ batch = {
+ "cu_seqlens": torch.tensor([[0, 5, 10, 15, -1]], dtype=torch.int32),
+ "max_seqlen": torch.tensor([[8]], dtype=torch.int32),
+ }
+
+ result = get_packed_seq_params(batch)
+
+ expected_cu_seqlens = torch.tensor([0, 5, 10, 15], dtype=torch.int32)
+
+ # Without unpadded, q/kv should use padded values
+ assert torch.equal(result.cu_seqlens_q, expected_cu_seqlens)
+ assert torch.equal(result.cu_seqlens_kv, expected_cu_seqlens)
+
+ # Padded fields should be None when cu_seqlens_unpadded is not provided
+ # (to avoid slower TE kernel paths)
+ assert result.cu_seqlens_q_padded is None
+ assert result.cu_seqlens_kv_padded is None
+
+ def test_packed_seq_params_no_padding_in_cu_seqlens(self):
+ """Test when cu_seqlens has no -1 padding markers."""
+ batch = {
+ "cu_seqlens": torch.tensor([[0, 5, 10]], dtype=torch.int32), # No -1 padding
+ "max_seqlen": torch.tensor([[7]], dtype=torch.int32),
+ }
+
+ result = get_packed_seq_params(batch)
+
+ # When no -1 present and min != -1, the tensor should remain as-is
+ expected = torch.tensor([0, 5, 10], dtype=torch.int32)
+ assert torch.equal(result.cu_seqlens_q, expected)
+ # Padded fields are None when cu_seqlens_unpadded is not provided
+ assert result.cu_seqlens_q_padded is None
+
+ def test_packed_seq_params_qkv_format_is_thd(self):
+ """Test that qkv_format is always set to 'thd'."""
+ batch = {
+ "cu_seqlens": torch.tensor([[0, 10, -1]], dtype=torch.int32),
+ }
+
+ result = get_packed_seq_params(batch)
+
+ assert result.qkv_format == "thd"
+
+ def test_packed_seq_params_cu_seqlens_unpadded_no_padding(self):
+ """Test cu_seqlens_unpadded with no padding markers."""
+ batch = {
+ "cu_seqlens": torch.tensor([[0, 6, 12]], dtype=torch.int32),
+ "cu_seqlens_unpadded": torch.tensor([[0, 5, 10]], dtype=torch.int32), # No -1
+ }
+
+ result = get_packed_seq_params(batch)
+
+ # Unpadded should be used as-is since no -1 and min != -1
+ expected_unpadded = torch.tensor([0, 5, 10], dtype=torch.int32)
+ expected_padded = torch.tensor([0, 6, 12], dtype=torch.int32)
+
+ assert torch.equal(result.cu_seqlens_q, expected_unpadded)
+ assert torch.equal(result.cu_seqlens_q_padded, expected_padded)
+
class TestCreateLossFunction:
"""Tests for the _create_loss_function helper function."""
diff --git a/tests/unit_tests/training/test_vlm_step.py b/tests/unit_tests/training/test_vlm_step.py
index 7a4cd47016..778a15b326 100755
--- a/tests/unit_tests/training/test_vlm_step.py
+++ b/tests/unit_tests/training/test_vlm_step.py
@@ -12,10 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import pytest
import torch
from megatron.bridge.training.utils.visual_inputs import Qwen2_5_VLVisualInputs
-from megatron.bridge.training.vlm_step import forward_step, get_batch, get_batch_from_iterator
+from megatron.bridge.training.vlm_step import (
+ forward_step,
+ get_batch,
+ get_batch_from_iterator,
+ pack_batch_sequences,
+)
class _Iterator:
@@ -92,6 +98,30 @@ def _as_nocuda(t):
assert out_vi.pixel_values is not None and out_vi.image_grid_thw is not None
+class _MockProcessGroup:
+ """Mock process group with rank/size methods for testing."""
+
+ def rank(self):
+ return 0
+
+ def size(self):
+ return 1
+
+
+class _MockPGCollection:
+ """Mock PG collection for testing."""
+
+ def __init__(self, cp_size=1):
+ self.pp = _MockProcessGroup()
+ self._cp_size = cp_size
+
+ @property
+ def cp(self):
+ pg = _MockProcessGroup()
+ pg.size = lambda: self._cp_size
+ return pg
+
+
def test_get_batch_padding_paths(monkeypatch):
# Simulate both first and last pipeline stages so tensors are returned
monkeypatch.setattr("megatron.core.pipeline_parallel.utils.is_pp_first_stage", lambda pg: True, raising=True)
@@ -135,21 +165,8 @@ def test_get_batch_padding_paths(monkeypatch):
# Iterator
it = _Iterator(batch)
- # Create a proper mock pg_collection with rank/size methods
- class _MockProcessGroup:
- def rank(self):
- return 0
-
- def size(self):
- return 1
-
- class _PG:
- def __init__(self):
- self.pp = _MockProcessGroup()
- self.cp = _MockProcessGroup()
-
tokens, labels, loss_mask, attention_mask, position_ids, *_ = get_batch(
- it, cfg, use_mtp=False, pg_collection=_PG()
+ it, cfg, use_mtp=False, pg_collection=_MockPGCollection()
)
# Length padded up to min(seq_cap, ceil_to_128(4)) == 32
assert tokens.shape[1] == 32
@@ -158,6 +175,175 @@ def __init__(self):
assert position_ids.shape[1] == 32
+def test_get_batch_enable_packing_path(monkeypatch):
+ """Test get_batch with pack_sequences_in_batch=True (enable_packing path)."""
+ # Simulate both first and last pipeline stages so tensors are returned
+ monkeypatch.setattr("megatron.core.pipeline_parallel.utils.is_pp_first_stage", lambda pg: True, raising=True)
+ monkeypatch.setattr("megatron.core.pipeline_parallel.utils.is_pp_last_stage", lambda pg: True, raising=True)
+
+ # Disable context parallel slicing effects
+ monkeypatch.setattr(
+ "megatron.core.utils.get_batch_on_this_cp_rank",
+ lambda x: x,
+ raising=True,
+ )
+
+ # Config with packing enabled
+ cfg = type("Cfg", (), {})()
+ cfg.model = type(
+ "M",
+ (),
+ {
+ "seq_length": 64,
+ "pipeline_model_parallel_size": 1,
+ },
+ )()
+ cfg.dataset = type(
+ "D",
+ (),
+ {
+ "skip_getting_attention_mask_from_dataset": True,
+ "pack_sequences_in_batch": True, # Enable packing
+ },
+ )()
+
+ # Batch with 2 sequences of different lengths (with padding)
+ # Seq 1: [1, 2, 3, 0, 0, 0, 0, 0] - length 3
+ # Seq 2: [4, 5, 6, 7, 8, 0, 0, 0] - length 5
+ tokens = torch.tensor(
+ [
+ [1, 2, 3, 0, 0, 0, 0, 0],
+ [4, 5, 6, 7, 8, 0, 0, 0],
+ ]
+ )
+ labels = torch.tensor(
+ [
+ [2, 3, -100, -100, -100, -100, -100, -100],
+ [5, 6, 7, 8, -100, -100, -100, -100],
+ ]
+ )
+ loss_mask = torch.tensor(
+ [
+ [1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+ [1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0],
+ ]
+ )
+ position_ids = torch.arange(8).unsqueeze(0).expand(2, -1).clone()
+
+ vi = Qwen2_5_VLVisualInputs(pixel_values=torch.randn(1, 1, 3, 4, 4), image_grid_thw=torch.tensor([[[1, 2, 2]]]))
+ batch = {
+ "input_ids": tokens,
+ "labels": labels,
+ "loss_mask": loss_mask,
+ "position_ids": position_ids,
+ "attention_mask": None,
+ "visual_inputs": vi,
+ }
+
+ it = _Iterator(batch)
+
+ (
+ out_tokens,
+ out_labels,
+ out_loss_mask,
+ out_attention_mask,
+ out_position_ids,
+ cu_seqlens,
+ max_seqlen,
+ visual_inputs,
+ ) = get_batch(it, cfg, use_mtp=False, pg_collection=_MockPGCollection())
+
+ # Verify packing occurred
+ # With pad_to_multiple_of=1 (cp_size=1), total packed length = 3 + 5 = 8
+ assert out_tokens.shape == (1, 8), f"Expected packed shape (1, 8), got {out_tokens.shape}"
+ assert out_labels.shape == (1, 8)
+ assert out_loss_mask.shape == (1, 8)
+ assert out_position_ids.shape == (1, 8)
+
+ # Verify cu_seqlens is populated (not None)
+ assert cu_seqlens is not None, "cu_seqlens should be set when packing is enabled"
+ assert cu_seqlens.tolist() == [0, 3, 8], f"Expected cu_seqlens [0, 3, 8], got {cu_seqlens.tolist()}"
+
+ # Verify max_seqlen
+ assert max_seqlen is not None, "max_seqlen should be set when packing is enabled"
+ assert max_seqlen.item() == 5, f"Expected max_seqlen 5, got {max_seqlen.item()}"
+
+ # Verify attention_mask is None for packed sequences
+ assert out_attention_mask is None, "attention_mask should be None for packed sequences"
+
+ # Verify packed tokens content
+ expected_tokens = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]])
+ assert torch.equal(out_tokens.cpu(), expected_tokens), f"Expected {expected_tokens}, got {out_tokens}"
+
+ # Verify visual_inputs passed through
+ assert visual_inputs is not None
+
+
+def test_get_batch_enable_packing_with_cp(monkeypatch):
+ """Test get_batch packing with context parallelism (pad_to_multiple_of > 1)."""
+ monkeypatch.setattr("megatron.core.pipeline_parallel.utils.is_pp_first_stage", lambda pg: True, raising=True)
+ monkeypatch.setattr("megatron.core.pipeline_parallel.utils.is_pp_last_stage", lambda pg: True, raising=True)
+ monkeypatch.setattr(
+ "megatron.core.utils.get_batch_on_this_cp_rank",
+ lambda x: x,
+ raising=True,
+ )
+
+ cfg = type("Cfg", (), {})()
+ cfg.model = type("M", (), {"seq_length": 64, "pipeline_model_parallel_size": 1})()
+ cfg.dataset = type(
+ "D",
+ (),
+ {
+ "skip_getting_attention_mask_from_dataset": True,
+ "pack_sequences_in_batch": True,
+ },
+ )()
+
+ # Sequences: length 3 and length 5
+ # With CP=2, pad_to_multiple_of = 2*2 = 4
+    # Seq 1: 3 -> padded to 4
+    # Seq 2: 5 -> padded to 8 (next multiple of 4)
+    # Total: 4 + 8 = 12
+ tokens = torch.tensor(
+ [
+ [1, 2, 3, 0, 0, 0, 0, 0],
+ [4, 5, 6, 7, 8, 0, 0, 0],
+ ]
+ )
+ labels = torch.tensor(
+ [
+ [2, 3, -100, -100, -100, -100, -100, -100],
+ [5, 6, 7, 8, -100, -100, -100, -100],
+ ]
+ )
+ loss_mask = torch.ones_like(tokens, dtype=torch.float)
+ position_ids = torch.arange(8).unsqueeze(0).expand(2, -1).clone()
+
+ batch = {
+ "input_ids": tokens,
+ "labels": labels,
+ "loss_mask": loss_mask,
+ "position_ids": position_ids,
+ "attention_mask": None,
+ "visual_inputs": None,
+ }
+
+ it = _Iterator(batch)
+
+ # Use CP size of 2
+ out_tokens, out_labels, out_loss_mask, _, out_position_ids, cu_seqlens, max_seqlen, _ = get_batch(
+ it, cfg, use_mtp=False, pg_collection=_MockPGCollection(cp_size=2)
+ )
+
+ # With CP=2, pad_to_multiple_of = 4
+ # Seq 1: 3 -> 4, Seq 2: 5 -> 8 (next multiple of 4)
+ # Total: 4 + 8 = 12
+ assert out_tokens.shape[1] == 12, f"Expected packed length 12, got {out_tokens.shape[1]}"
+ assert cu_seqlens.tolist() == [0, 4, 12], f"Expected cu_seqlens [0, 4, 12], got {cu_seqlens.tolist()}"
+ assert max_seqlen.item() == 8, f"Expected max_seqlen 8, got {max_seqlen.item()}"
+
+
def test_forward_step_schedule_plan(monkeypatch):
# Configure pipeline last/first to enable labels & loss_mask path
monkeypatch.setattr("megatron.core.pipeline_parallel.utils.is_pp_first_stage", lambda pg: True, raising=True)
@@ -264,3 +450,382 @@ def __init__(self):
# Execute schedule plan path
plan, loss_fn = forward_step(state, it, model, return_schedule_plan=True)
assert isinstance(plan, torch.Tensor)
+
+
+class TestPackBatchSequences:
+ """Tests for the pack_batch_sequences function."""
+
+ def test_basic_packing(self):
+ """Test basic sequence packing functionality."""
+ batch_size, seq_len = 2, 8
+ # Tokens with padding at the end (pad_token_id=0)
+ tokens = torch.tensor(
+ [
+ [1, 2, 3, 0, 0, 0, 0, 0], # length 3
+ [4, 5, 6, 7, 0, 0, 0, 0], # length 4
+ ]
+ )
+ labels = torch.tensor(
+ [
+ [2, 3, -100, -100, -100, -100, -100, -100],
+ [5, 6, 7, -100, -100, -100, -100, -100],
+ ]
+ )
+ loss_mask = torch.tensor(
+ [
+ [1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+ [1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+ ]
+ )
+ position_ids = torch.arange(seq_len).unsqueeze(0).expand(batch_size, -1)
+ attention_mask = None
+
+ result = pack_batch_sequences(
+ tokens=tokens,
+ labels=labels,
+ loss_mask=loss_mask,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ pad_token_id=0,
+ pad_to_multiple_of=1,
+ )
+
+ packed_tokens, packed_labels, packed_loss_mask, packed_attn, packed_pos, cu_seqlens, max_seqlen = result
+
+ # Packed output should have shape [1, total_valid_len]
+ assert packed_tokens.shape[0] == 1
+ total_len = packed_tokens.shape[1]
+ assert total_len == 7 # 3 + 4
+
+ # cu_seqlens should have num_sequences + 1 elements
+ assert len(cu_seqlens) == 3 # [0, 3, 7]
+ assert cu_seqlens[0] == 0
+ assert cu_seqlens[1] == 3 # first sequence length
+ assert cu_seqlens[2] == 7 # total length
+
+ # max_seqlen should be max of sequence lengths
+ assert max_seqlen.item() == 4
+
+ # Attention mask should be None for packed sequences
+ assert packed_attn is None
+
+ def test_packing_with_pad_to_multiple_of(self):
+ """Test packing with padding to a multiple (for CP compatibility)."""
+ batch_size = 2
+ tokens = torch.tensor(
+ [
+ [1, 2, 3, 0, 0, 0, 0, 0, 0, 0], # length 3 -> padded to 4 (mult of 2)
+ [4, 5, 6, 7, 8, 0, 0, 0, 0, 0], # length 5 -> padded to 6 (mult of 2)
+ ]
+ )
+ labels = torch.tensor(
+ [
+ [2, 3, -100, -100, -100, -100, -100, -100, -100, -100],
+ [5, 6, 7, 8, -100, -100, -100, -100, -100, -100],
+ ]
+ )
+ loss_mask = torch.ones_like(tokens, dtype=torch.float)
+ position_ids = torch.arange(10).unsqueeze(0).expand(batch_size, -1)
+
+ result = pack_batch_sequences(
+ tokens=tokens,
+ labels=labels,
+ loss_mask=loss_mask,
+ attention_mask=None,
+ position_ids=position_ids,
+ pad_token_id=0,
+ pad_to_multiple_of=2, # Pad each sequence to multiple of 2
+ )
+
+ packed_tokens, packed_labels, packed_loss_mask, packed_attn, packed_pos, cu_seqlens, max_seqlen = result
+
+ # Total length should be 4 + 6 = 10 (padded lengths)
+ assert packed_tokens.shape[1] == 10
+
+ # cu_seqlens should use padded lengths
+ assert cu_seqlens[0] == 0
+ assert cu_seqlens[1] == 4 # 3 -> 4 (padded)
+ assert cu_seqlens[2] == 10 # 5 -> 6, total = 4 + 6
+
+ # max_seqlen should be 6 (longest padded sequence)
+ assert max_seqlen.item() == 6
+
+ def test_packing_with_larger_multiple(self):
+ """Test packing with larger pad_to_multiple_of (e.g., for CP=4)."""
+ tokens = torch.tensor(
+ [
+ [1, 2, 0, 0, 0, 0, 0, 0], # length 2 -> padded to 4
+ [3, 4, 5, 0, 0, 0, 0, 0], # length 3 -> padded to 4
+ ]
+ )
+ labels = torch.tensor(
+ [
+ [2, -100, -100, -100, -100, -100, -100, -100],
+ [4, 5, -100, -100, -100, -100, -100, -100],
+ ]
+ )
+ loss_mask = torch.ones_like(tokens, dtype=torch.float)
+ position_ids = torch.arange(8).unsqueeze(0).expand(2, -1)
+
+ result = pack_batch_sequences(
+ tokens=tokens,
+ labels=labels,
+ loss_mask=loss_mask,
+ attention_mask=None,
+ position_ids=position_ids,
+ pad_token_id=0,
+ pad_to_multiple_of=4,
+ )
+
+ packed_tokens, *_, cu_seqlens, max_seqlen = result
+
+ # Both sequences padded to 4, total = 8
+ assert packed_tokens.shape[1] == 8
+ assert cu_seqlens.tolist() == [0, 4, 8]
+ assert max_seqlen.item() == 4
+
+ def test_packing_single_sequence(self):
+ """Test packing a single sequence."""
+ tokens = torch.tensor([[1, 2, 3, 4, 5, 0, 0, 0]]) # length 5
+ labels = torch.tensor([[2, 3, 4, 5, -100, -100, -100, -100]])
+ loss_mask = torch.ones_like(tokens, dtype=torch.float)
+ position_ids = torch.arange(8).unsqueeze(0)
+
+ result = pack_batch_sequences(
+ tokens=tokens,
+ labels=labels,
+ loss_mask=loss_mask,
+ attention_mask=None,
+ position_ids=position_ids,
+ pad_token_id=0,
+ pad_to_multiple_of=1,
+ )
+
+ packed_tokens, *_, cu_seqlens, max_seqlen = result
+
+ assert packed_tokens.shape[1] == 5
+ assert cu_seqlens.tolist() == [0, 5]
+ assert max_seqlen.item() == 5
+
+ def test_packing_no_padding_sequences(self):
+ """Test packing sequences with no padding."""
+ tokens = torch.tensor(
+ [
+ [1, 2, 3, 4],
+ [5, 6, 7, 8],
+ ]
+ )
+ labels = torch.tensor(
+ [
+ [2, 3, 4, -100],
+ [6, 7, 8, -100],
+ ]
+ )
+ loss_mask = torch.ones_like(tokens, dtype=torch.float)
+ position_ids = torch.arange(4).unsqueeze(0).expand(2, -1)
+
+ result = pack_batch_sequences(
+ tokens=tokens,
+ labels=labels,
+ loss_mask=loss_mask,
+ attention_mask=None,
+ position_ids=position_ids,
+ pad_token_id=0,
+ pad_to_multiple_of=1,
+ )
+
+ packed_tokens, *_, cu_seqlens, max_seqlen = result
+
+ # Both sequences full length
+ assert packed_tokens.shape[1] == 8
+ assert cu_seqlens.tolist() == [0, 4, 8]
+
+ def test_packing_preserves_loss_mask_zeros(self):
+ """Test that loss_mask zeros are preserved during packing."""
+ tokens = torch.tensor([[1, 2, 3, 0, 0]])
+ labels = torch.tensor([[2, 3, -100, -100, -100]])
+ loss_mask = torch.tensor([[1.0, 0.0, 1.0, 0.0, 0.0]]) # Second token masked
+ position_ids = torch.arange(5).unsqueeze(0)
+
+ result = pack_batch_sequences(
+ tokens=tokens,
+ labels=labels,
+ loss_mask=loss_mask,
+ attention_mask=None,
+ position_ids=position_ids,
+ pad_token_id=0,
+ pad_to_multiple_of=1,
+ )
+
+ packed_tokens, packed_labels, packed_loss_mask, *_ = result
+
+ # Only first 3 tokens should be kept
+ assert packed_loss_mask.shape[1] == 3
+ assert packed_loss_mask[0, 0].item() == 1.0
+ assert packed_loss_mask[0, 1].item() == 0.0 # Preserved
+ assert packed_loss_mask[0, 2].item() == 1.0
+
+ def test_packing_position_ids_reset(self):
+ """Test that position_ids are correctly packed."""
+ tokens = torch.tensor(
+ [
+ [1, 2, 0, 0],
+ [3, 4, 5, 0],
+ ]
+ )
+ labels = torch.zeros_like(tokens)
+ loss_mask = torch.ones_like(tokens, dtype=torch.float)
+ position_ids = torch.tensor(
+ [
+ [0, 1, 2, 3],
+ [0, 1, 2, 3],
+ ]
+ )
+
+ result = pack_batch_sequences(
+ tokens=tokens,
+ labels=labels,
+ loss_mask=loss_mask,
+ attention_mask=None,
+ position_ids=position_ids,
+ pad_token_id=0,
+ pad_to_multiple_of=1,
+ )
+
+ _, _, _, _, packed_pos, *_ = result
+
+ # Position IDs should be extracted from original sequences
+ assert packed_pos.shape[1] == 5 # 2 + 3
+ assert packed_pos[0, 0].item() == 0 # First seq, pos 0
+ assert packed_pos[0, 1].item() == 1 # First seq, pos 1
+ assert packed_pos[0, 2].item() == 0 # Second seq, pos 0
+ assert packed_pos[0, 3].item() == 1 # Second seq, pos 1
+ assert packed_pos[0, 4].item() == 2 # Second seq, pos 2
+
+ def test_packing_empty_batch_warning(self, caplog):
+        """Test that an all-padding batch returns empty tensors (warning emission itself is not asserted)."""
+ tokens = torch.tensor([[0, 0, 0, 0]]) # All padding
+ labels = torch.tensor([[-100, -100, -100, -100]])
+ loss_mask = torch.zeros(1, 4)
+ position_ids = torch.arange(4).unsqueeze(0)
+
+ result = pack_batch_sequences(
+ tokens=tokens,
+ labels=labels,
+ loss_mask=loss_mask,
+ attention_mask=None,
+ position_ids=position_ids,
+ pad_token_id=0,
+ pad_to_multiple_of=1,
+ )
+
+ packed_tokens, packed_labels, packed_loss_mask, packed_attn, packed_pos, cu_seqlens, max_seqlen = result
+
+ # No valid sequences found, should return empty tensors
+ assert packed_tokens.shape == (1, 0)
+ assert packed_labels.shape == (1, 0)
+ assert packed_loss_mask.shape == (1, 0)
+ assert packed_pos.shape == (1, 0)
+ # cu_seqlens should have just [0] for empty batch
+ assert len(cu_seqlens) == 1
+ assert cu_seqlens[0].item() == 0
+ assert max_seqlen.item() == 0
+
+ def test_packing_different_dtypes(self):
+ """Test packing with different tensor dtypes."""
+ tokens = torch.tensor([[1, 2, 3, 0]], dtype=torch.long)
+ labels = torch.tensor([[2, 3, -100, -100]], dtype=torch.long)
+ loss_mask = torch.tensor([[1.0, 1.0, 0.0, 0.0]], dtype=torch.float32)
+ position_ids = torch.tensor([[0, 1, 2, 3]], dtype=torch.long)
+
+ result = pack_batch_sequences(
+ tokens=tokens,
+ labels=labels,
+ loss_mask=loss_mask,
+ attention_mask=None,
+ position_ids=position_ids,
+ pad_token_id=0,
+ pad_to_multiple_of=1,
+ )
+
+ packed_tokens, packed_labels, packed_loss_mask, _, packed_pos, cu_seqlens, _ = result
+
+ # Dtypes should be preserved
+ assert packed_tokens.dtype == torch.long
+ assert packed_labels.dtype == torch.long
+ assert packed_loss_mask.dtype == torch.float32
+ assert packed_pos.dtype == torch.long
+ assert cu_seqlens.dtype == torch.int32
+
+ def test_packing_padding_extends_position_ids(self):
+ """Test that padding extends position_ids correctly."""
+ tokens = torch.tensor([[1, 2, 3, 0]]) # length 3
+ labels = torch.zeros_like(tokens)
+ loss_mask = torch.ones_like(tokens, dtype=torch.float)
+ position_ids = torch.tensor([[0, 1, 2, 3]])
+
+ result = pack_batch_sequences(
+ tokens=tokens,
+ labels=labels,
+ loss_mask=loss_mask,
+ attention_mask=None,
+ position_ids=position_ids,
+ pad_token_id=0,
+ pad_to_multiple_of=4, # Pad to 4
+ )
+
+ _, _, _, _, packed_pos, cu_seqlens, _ = result
+
+ # Length should be 4 (padded)
+ assert packed_pos.shape[1] == 4
+
+ # Original positions should be preserved
+ assert packed_pos[0, 0].item() == 0
+ assert packed_pos[0, 1].item() == 1
+ assert packed_pos[0, 2].item() == 2
+ # Padding position should be extended
+ assert packed_pos[0, 3].item() == 3
+
+ def test_packing_cu_seqlens_dtype(self):
+ """Test that cu_seqlens is int32 as expected by attention kernels."""
+ tokens = torch.tensor([[1, 2, 0]])
+ labels = torch.zeros_like(tokens)
+ loss_mask = torch.ones_like(tokens, dtype=torch.float)
+ position_ids = torch.arange(3).unsqueeze(0)
+
+ result = pack_batch_sequences(
+ tokens=tokens,
+ labels=labels,
+ loss_mask=loss_mask,
+ attention_mask=None,
+ position_ids=position_ids,
+ pad_token_id=0,
+ )
+
+ _, _, _, _, _, cu_seqlens, _ = result
+
+ assert cu_seqlens.dtype == torch.int32
+
+ def test_packing_gpu_tensor(self):
+ """Test packing works on GPU if available."""
+ if not torch.cuda.is_available():
+ pytest.skip("CUDA not available")
+
+ tokens = torch.tensor([[1, 2, 3, 0, 0]], device="cuda")
+ labels = torch.tensor([[2, 3, -100, -100, -100]], device="cuda")
+ loss_mask = torch.ones_like(tokens, dtype=torch.float, device="cuda")
+ position_ids = torch.arange(5, device="cuda").unsqueeze(0)
+
+ result = pack_batch_sequences(
+ tokens=tokens,
+ labels=labels,
+ loss_mask=loss_mask,
+ attention_mask=None,
+ position_ids=position_ids,
+ pad_token_id=0,
+ )
+
+ packed_tokens, _, _, _, _, cu_seqlens, _ = result
+
+ assert packed_tokens.device.type == "cuda"
+ assert cu_seqlens.device.type == "cuda"
diff --git a/tests/unit_tests/utils/test_slice_batch_for_cp.py b/tests/unit_tests/utils/test_slice_batch_for_cp.py
new file mode 100644
index 0000000000..2ed4b96eaf
--- /dev/null
+++ b/tests/unit_tests/utils/test_slice_batch_for_cp.py
@@ -0,0 +1,495 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for slice_batch_for_context_parallel function in common_utils."""
+
+from dataclasses import dataclass
+from typing import Optional
+from unittest.mock import MagicMock, patch
+
+import torch
+
+from megatron.bridge.utils.common_utils import slice_batch_for_context_parallel
+
+
+@dataclass
+class MockPackedSeqParams:
+ """Mock PackedSeqParams for testing THD format."""
+
+ cu_seqlens_q: torch.Tensor
+ cu_seqlens_kv: torch.Tensor
+ cu_seqlens_q_padded: Optional[torch.Tensor] = None
+ cu_seqlens_kv_padded: Optional[torch.Tensor] = None
+ max_seqlen_q: Optional[torch.Tensor] = None
+ max_seqlen_kv: Optional[torch.Tensor] = None
+ qkv_format: str = "thd"
+
+
+class MockPGCollection:
+ """Mock ProcessGroupCollection for testing."""
+
+ def __init__(self, cp_size: int = 1, cp_rank: int = 0):
+ self._cp_size = cp_size
+ self._cp_rank = cp_rank
+ self.cp = MagicMock()
+ self.cp.size.return_value = cp_size
+ self.cp.rank.return_value = cp_rank
+
+
+class TestSliceBatchForContextParallelCpSize1:
+ """Tests for slice_batch_for_context_parallel when CP size <= 1 (no-op case)."""
+
+ def test_cp_size_1_returns_unchanged(self):
+ """Test that tensors are returned unchanged when cp_size=1."""
+ batch_size, seq_len, hidden = 2, 16, 64
+ inputs_embeds = torch.randn(seq_len, batch_size, hidden)
+ labels = torch.randint(0, 1000, (batch_size, seq_len))
+ loss_mask = torch.ones(batch_size, seq_len)
+ position_ids = torch.arange(seq_len).unsqueeze(0).expand(batch_size, -1)
+ attention_mask = torch.ones(batch_size, 1, seq_len, seq_len)
+
+ pg_collection = MockPGCollection(cp_size=1, cp_rank=0)
+
+ result = slice_batch_for_context_parallel(
+ inputs_embeds=inputs_embeds,
+ labels=labels,
+ loss_mask=loss_mask,
+ position_ids=position_ids,
+ attention_mask=attention_mask,
+ packed_seq_params=None,
+ pg_collection=pg_collection,
+ )
+
+ out_embeds, out_labels, out_loss_mask, out_pos_ids, out_attn_mask = result
+
+ # All tensors should be unchanged
+ assert torch.equal(out_embeds, inputs_embeds)
+ assert torch.equal(out_labels, labels)
+ assert torch.equal(out_loss_mask, loss_mask)
+ assert torch.equal(out_pos_ids, position_ids)
+ assert torch.equal(out_attn_mask, attention_mask)
+
+ def test_cp_size_0_returns_unchanged(self):
+ """Test that tensors are returned unchanged when cp_size=0."""
+ batch_size, seq_len, hidden = 1, 8, 32
+ inputs_embeds = torch.randn(seq_len, batch_size, hidden)
+ labels = torch.randint(0, 100, (batch_size, seq_len))
+ loss_mask = torch.ones(batch_size, seq_len)
+ position_ids = torch.arange(seq_len).unsqueeze(0)
+ attention_mask = None
+
+ pg_collection = MockPGCollection(cp_size=0, cp_rank=0)
+
+ result = slice_batch_for_context_parallel(
+ inputs_embeds=inputs_embeds,
+ labels=labels,
+ loss_mask=loss_mask,
+ position_ids=position_ids,
+ attention_mask=attention_mask,
+ packed_seq_params=None,
+ pg_collection=pg_collection,
+ )
+
+ out_embeds, out_labels, out_loss_mask, out_pos_ids, out_attn_mask = result
+
+ assert torch.equal(out_embeds, inputs_embeds)
+ assert torch.equal(out_labels, labels)
+
+
+class TestSliceBatchForContextParallelBSHD:
+ """Tests for slice_batch_for_context_parallel with BSHD format (non-packed)."""
+
+ def test_bshd_format_uses_get_batch_on_this_cp_rank(self):
+ """Test that BSHD format triggers get_batch_on_this_cp_rank."""
+ batch_size, seq_len, hidden = 2, 16, 64
+ inputs_embeds = torch.randn(seq_len, batch_size, hidden)
+ labels = torch.randint(0, 1000, (batch_size, seq_len))
+ loss_mask = torch.ones(batch_size, seq_len)
+ position_ids = torch.arange(seq_len).unsqueeze(0).expand(batch_size, -1)
+ attention_mask = torch.ones(batch_size, 1, seq_len, seq_len)
+
+ pg_collection = MockPGCollection(cp_size=2, cp_rank=0)
+
+ # Mock get_batch_on_this_cp_rank to return sliced tensors
+ def mock_get_batch(batch_dict, cp_group=None):
+ # Simulate slicing by returning half of each tensor
+ result = {}
+ for k, v in batch_dict.items():
+ if v is not None and isinstance(v, torch.Tensor):
+ if k == "decoder_input": # (B, T, D) format after transpose
+ result[k] = v[:, : seq_len // 2, :]
+ elif v.dim() == 4: # attention_mask
+ result[k] = v[:, :, : seq_len // 2, : seq_len // 2]
+ else:
+ result[k] = v[:, : seq_len // 2]
+ else:
+ result[k] = v
+ return result
+
+ with patch(
+ "megatron.core.utils.get_batch_on_this_cp_rank",
+ side_effect=mock_get_batch,
+ ):
+ result = slice_batch_for_context_parallel(
+ inputs_embeds=inputs_embeds,
+ labels=labels,
+ loss_mask=loss_mask,
+ position_ids=position_ids,
+ attention_mask=attention_mask,
+ packed_seq_params=None, # BSHD format
+ pg_collection=pg_collection,
+ )
+
+ out_embeds, out_labels, out_loss_mask, out_pos_ids, out_attn_mask = result
+
+ # Check that output tensors are sliced (half the sequence length)
+ assert out_embeds.shape[0] == seq_len // 2 # T dimension
+ assert out_embeds.shape[1] == batch_size # B dimension
+ assert out_labels.shape[1] == seq_len // 2
+ assert out_loss_mask.shape[1] == seq_len // 2
+
+ def test_bshd_format_handles_none_tensors(self):
+ """Test that BSHD format handles None tensors gracefully."""
+ batch_size, seq_len, hidden = 1, 8, 32
+ inputs_embeds = torch.randn(seq_len, batch_size, hidden)
+
+ pg_collection = MockPGCollection(cp_size=2, cp_rank=0)
+
+ def mock_get_batch(batch_dict, cp_group=None):
+ result = {}
+ for k, v in batch_dict.items():
+ if v is not None and isinstance(v, torch.Tensor):
+ if k == "decoder_input":
+ result[k] = v[:, : seq_len // 2, :]
+ else:
+ result[k] = v[:, : seq_len // 2] if v.dim() > 1 else v
+ else:
+ result[k] = None
+ return result
+
+ with patch(
+ "megatron.core.utils.get_batch_on_this_cp_rank",
+ side_effect=mock_get_batch,
+ ):
+ result = slice_batch_for_context_parallel(
+ inputs_embeds=inputs_embeds,
+ labels=None,
+ loss_mask=None,
+ position_ids=None,
+ attention_mask=None,
+ packed_seq_params=None,
+ pg_collection=pg_collection,
+ )
+
+ out_embeds, out_labels, out_loss_mask, out_pos_ids, out_attn_mask = result
+
+ assert out_embeds is not None
+ assert out_labels is None
+ assert out_loss_mask is None
+ assert out_pos_ids is None
+ assert out_attn_mask is None
+
+
+class TestSliceBatchForContextParallelTHD:
+ """Tests for slice_batch_for_context_parallel with THD (packed) format."""
+
+ def test_thd_format_uses_tex_partitioned_indices(self):
+ """Test that THD format triggers TransformerEngine's thd_get_partitioned_indices."""
+ batch_size, seq_len, hidden = 1, 16, 64
+ cp_size = 2
+ cp_rank = 0
+
+ inputs_embeds = torch.randn(seq_len, batch_size, hidden)
+ labels = torch.randint(0, 1000, (batch_size, seq_len))
+ loss_mask = torch.ones(batch_size, seq_len)
+ position_ids = torch.arange(seq_len).unsqueeze(0)
+
+ cu_seqlens = torch.tensor([0, 8, 16], dtype=torch.int32)
+ packed_seq_params = MockPackedSeqParams(
+ cu_seqlens_q=cu_seqlens,
+ cu_seqlens_kv=cu_seqlens,
+ cu_seqlens_q_padded=cu_seqlens,
+ qkv_format="thd",
+ )
+
+ pg_collection = MockPGCollection(cp_size=cp_size, cp_rank=cp_rank)
+
+ # Mock tex.thd_get_partitioned_indices
+ mock_indices = torch.tensor([0, 1, 2, 3, 8, 9, 10, 11]) # First half of each sequence
+
+ with patch.dict("sys.modules", {"transformer_engine_torch": MagicMock()}):
+ import sys
+
+ mock_tex = sys.modules["transformer_engine_torch"]
+ mock_tex.thd_get_partitioned_indices.return_value = mock_indices
+
+ result = slice_batch_for_context_parallel(
+ inputs_embeds=inputs_embeds,
+ labels=labels,
+ loss_mask=loss_mask,
+ position_ids=position_ids,
+ attention_mask=None,
+ packed_seq_params=packed_seq_params,
+ pg_collection=pg_collection,
+ )
+
+ out_embeds, out_labels, out_loss_mask, out_pos_ids, out_attn_mask = result
+
+ # Verify tex.thd_get_partitioned_indices was called
+ mock_tex.thd_get_partitioned_indices.assert_called_once()
+
+ # Check output shapes match the indices
+ assert out_embeds.shape[0] == len(mock_indices) # T dimension
+ assert out_embeds.shape[1] == batch_size # B dimension
+
+ def test_thd_format_with_padded_cu_seqlens(self):
+ """Test THD format uses cu_seqlens_q_padded when available."""
+ batch_size, seq_len, hidden = 1, 20, 32
+ cp_size = 2
+ cp_rank = 1
+
+ inputs_embeds = torch.randn(seq_len, batch_size, hidden)
+ labels = torch.randint(0, 100, (batch_size, seq_len))
+ loss_mask = torch.ones(batch_size, seq_len)
+ position_ids = torch.arange(seq_len).unsqueeze(0)
+
+ # Padded cu_seqlens (includes padding for divisibility)
+ cu_seqlens = torch.tensor([0, 6, 16], dtype=torch.int32)
+ cu_seqlens_padded = torch.tensor([0, 10, 20], dtype=torch.int32)
+
+ packed_seq_params = MockPackedSeqParams(
+ cu_seqlens_q=cu_seqlens,
+ cu_seqlens_kv=cu_seqlens,
+ cu_seqlens_q_padded=cu_seqlens_padded,
+ qkv_format="thd",
+ )
+
+ pg_collection = MockPGCollection(cp_size=cp_size, cp_rank=cp_rank)
+
+ mock_indices = torch.tensor([5, 6, 7, 8, 9, 15, 16, 17, 18, 19]) # Second half
+
+ with patch.dict("sys.modules", {"transformer_engine_torch": MagicMock()}):
+ import sys
+
+ mock_tex = sys.modules["transformer_engine_torch"]
+ mock_tex.thd_get_partitioned_indices.return_value = mock_indices
+
+ result = slice_batch_for_context_parallel(
+ inputs_embeds=inputs_embeds,
+ labels=labels,
+ loss_mask=loss_mask,
+ position_ids=position_ids,
+ attention_mask=None,
+ packed_seq_params=packed_seq_params,
+ pg_collection=pg_collection,
+ )
+
+ out_embeds, out_labels, out_loss_mask, out_pos_ids, out_attn_mask = result
+
+ # Verify padded cu_seqlens was used
+ call_args = mock_tex.thd_get_partitioned_indices.call_args
+ assert torch.equal(call_args[0][0], cu_seqlens_padded)
+
+ def test_thd_format_without_padded_cu_seqlens_fallback(self):
+ """Test THD format falls back to cu_seqlens_q when padded version is None."""
+ batch_size, seq_len, hidden = 1, 12, 16
+ cp_size = 2
+ cp_rank = 0
+
+ inputs_embeds = torch.randn(seq_len, batch_size, hidden)
+ labels = torch.randint(0, 50, (batch_size, seq_len))
+ loss_mask = torch.ones(batch_size, seq_len)
+ position_ids = torch.arange(seq_len).unsqueeze(0)
+
+ cu_seqlens = torch.tensor([0, 6, 12], dtype=torch.int32)
+
+ packed_seq_params = MockPackedSeqParams(
+ cu_seqlens_q=cu_seqlens,
+ cu_seqlens_kv=cu_seqlens,
+ cu_seqlens_q_padded=None, # No padded version
+ qkv_format="thd",
+ )
+
+ pg_collection = MockPGCollection(cp_size=cp_size, cp_rank=cp_rank)
+
+ mock_indices = torch.tensor([0, 1, 2, 6, 7, 8])
+
+ with patch.dict("sys.modules", {"transformer_engine_torch": MagicMock()}):
+ import sys
+
+ mock_tex = sys.modules["transformer_engine_torch"]
+ mock_tex.thd_get_partitioned_indices.return_value = mock_indices
+
+ slice_batch_for_context_parallel(
+ inputs_embeds=inputs_embeds,
+ labels=labels,
+ loss_mask=loss_mask,
+ position_ids=position_ids,
+ attention_mask=None,
+ packed_seq_params=packed_seq_params,
+ pg_collection=pg_collection,
+ )
+
+ # Verify cu_seqlens_q was used as fallback
+ call_args = mock_tex.thd_get_partitioned_indices.call_args
+ assert torch.equal(call_args[0][0], cu_seqlens)
+
+
+class TestSliceBatchForContextParallelTranspose:
+ """Tests verifying correct tensor transpose operations."""
+
+ def test_input_transposed_before_slicing(self):
+ """Test that inputs_embeds is transposed from (T,B,D) to (B,T,D) for slicing."""
+ batch_size, seq_len, hidden = 2, 8, 32
+
+ # Input is (T, B, D) format
+ inputs_embeds = torch.randn(seq_len, batch_size, hidden)
+
+ pg_collection = MockPGCollection(cp_size=2, cp_rank=0)
+
+ captured_batch = {}
+
+ def mock_get_batch(batch_dict, cp_group=None):
+ captured_batch.update(batch_dict)
+ # Check that decoder_input is in (B, T, D) format
+ di = batch_dict.get("decoder_input")
+ if di is not None:
+ assert di.shape == (batch_size, seq_len, hidden), f"Expected (B,T,D), got {di.shape}"
+ return batch_dict
+
+ with patch(
+ "megatron.core.utils.get_batch_on_this_cp_rank",
+ side_effect=mock_get_batch,
+ ):
+ slice_batch_for_context_parallel(
+ inputs_embeds=inputs_embeds,
+ labels=None,
+ loss_mask=None,
+ position_ids=None,
+ attention_mask=None,
+ packed_seq_params=None, # BSHD format
+ pg_collection=pg_collection,
+ )
+
+ # Verify decoder_input was captured with correct shape
+ assert "decoder_input" in captured_batch
+ assert captured_batch["decoder_input"].shape == (batch_size, seq_len, hidden)
+
+ def test_output_transposed_back_to_tbd(self):
+ """Test that outputs are transposed back to (T,B,D) format."""
+ batch_size, seq_len, hidden = 2, 16, 64
+ sliced_seq_len = seq_len // 2
+
+ inputs_embeds = torch.randn(seq_len, batch_size, hidden)
+
+ pg_collection = MockPGCollection(cp_size=2, cp_rank=0)
+
+ def mock_get_batch(batch_dict, cp_group=None):
+ result = {}
+ for k, v in batch_dict.items():
+ if v is not None and isinstance(v, torch.Tensor):
+ if k == "decoder_input":
+ # Return sliced (B, T/2, D)
+ result[k] = v[:, :sliced_seq_len, :]
+ else:
+ result[k] = v
+ else:
+ result[k] = v
+ return result
+
+ with patch(
+ "megatron.core.utils.get_batch_on_this_cp_rank",
+ side_effect=mock_get_batch,
+ ):
+ result = slice_batch_for_context_parallel(
+ inputs_embeds=inputs_embeds,
+ labels=None,
+ loss_mask=None,
+ position_ids=None,
+ attention_mask=None,
+ packed_seq_params=None,
+ pg_collection=pg_collection,
+ )
+
+ out_embeds, *_ = result
+
+ # Output should be in (T, B, D) format
+ assert out_embeds.shape == (sliced_seq_len, batch_size, hidden)
+
+
+class TestSliceBatchForContextParallelEdgeCases:
+ """Edge case tests for slice_batch_for_context_parallel."""
+
+ def test_none_inputs_embeds_with_cp_size_gt_1(self):
+ """Test handling when inputs_embeds is None but CP is enabled."""
+ pg_collection = MockPGCollection(cp_size=2, cp_rank=0)
+
+ def mock_get_batch(batch_dict, cp_group=None):
+ return batch_dict
+
+ with patch(
+ "megatron.core.utils.get_batch_on_this_cp_rank",
+ side_effect=mock_get_batch,
+ ):
+ result = slice_batch_for_context_parallel(
+ inputs_embeds=None,
+ labels=torch.randint(0, 100, (1, 8)),
+ loss_mask=torch.ones(1, 8),
+ position_ids=torch.arange(8).unsqueeze(0),
+ attention_mask=None,
+ packed_seq_params=None,
+ pg_collection=pg_collection,
+ )
+
+ out_embeds, *_ = result
+ assert out_embeds is None
+
+ def test_non_thd_qkv_format_uses_bshd_path(self):
+ """Test that non-THD qkv_format (e.g., 'sbhd') uses BSHD slicing path."""
+ batch_size, seq_len, hidden = 1, 8, 16
+
+ inputs_embeds = torch.randn(seq_len, batch_size, hidden)
+ labels = torch.randint(0, 50, (batch_size, seq_len))
+
+ # Non-THD format
+ packed_seq_params = MockPackedSeqParams(
+ cu_seqlens_q=torch.tensor([0, 8]),
+ cu_seqlens_kv=torch.tensor([0, 8]),
+ qkv_format="sbhd", # Not THD
+ )
+
+ pg_collection = MockPGCollection(cp_size=2, cp_rank=0)
+
+ mock_called = {"get_batch": False}
+
+ def mock_get_batch(batch_dict, cp_group=None):
+ mock_called["get_batch"] = True
+ return batch_dict
+
+ with patch(
+ "megatron.core.utils.get_batch_on_this_cp_rank",
+ side_effect=mock_get_batch,
+ ):
+ slice_batch_for_context_parallel(
+ inputs_embeds=inputs_embeds,
+ labels=labels,
+ loss_mask=None,
+ position_ids=None,
+ attention_mask=None,
+ packed_seq_params=packed_seq_params,
+ pg_collection=pg_collection,
+ )
+
+ # Verify BSHD path was used (get_batch_on_this_cp_rank called)
+ assert mock_called["get_batch"]