2 changes: 1 addition & 1 deletion docs/models/vlm/ministral3.md
@@ -7,7 +7,7 @@ Ministral 3 models support multimodal tasks including image captioning, visual q
Ministral family models are supported via the Bridge system with auto-detected configuration and weight mapping.

```{important}
Please update `transformers` version to 5.0.0rc0 in order to use the Ministral 3 models.
Please upgrade to `transformers` v5, and upgrade `mistral-common` alongside it, in order to use the Ministral 3 models.
```

## Available Models
3 changes: 3 additions & 0 deletions src/megatron/bridge/data/vlm_datasets/hf_provider.py
@@ -67,6 +67,9 @@ class HFDatasetConversationProvider(DatasetProvider):
# DataloaderConfig fields are inherited (num_workers, dataloader_type, etc.)
dataloader_type: Optional[Literal["single", "cyclic", "external"]] = "single"

# Enable batch-level online sequence packing (dataset-level packing is available in FinetuneDatasetProvider)
pack_sequences_in_batch: bool = False

def _get_maker(self) -> Callable[..., List[Dict[str, Any]]]:
registry: Dict[str, Callable[..., List[Dict[str, Any]]]] = {
"make_rdr_dataset": make_rdr_dataset,
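The new `pack_sequences_in_batch` flag enables batch-level online packing, in contrast to the dataset-level packing available in `FinetuneDatasetProvider`. As a rough illustration of what packing a batch means — a minimal sketch assuming samples arrive as 1-D token tensors; `pack_batch` and its signature are hypothetical, not the provider's API:

```python
# Illustrative sketch of batch-level online packing (pack_batch is a
# hypothetical helper, not the provider's API): concatenate the samples of a
# batch into one sequence and record cumulative boundaries (cu_seqlens) that
# attention kernels use to keep each sample causally isolated.
from typing import List, Tuple

import torch


def pack_batch(samples: List[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
    """Concatenate 1-D token tensors; return (packed_tokens, cu_seqlens)."""
    lengths = torch.tensor([s.numel() for s in samples], dtype=torch.int32)
    cu_seqlens = torch.zeros(len(samples) + 1, dtype=torch.int32)
    cu_seqlens[1:] = torch.cumsum(lengths, dim=0)  # [0, n0, n0+n1, ...]
    return torch.cat(samples, dim=0), cu_seqlens


tokens, cu_seqlens = pack_batch([torch.arange(5), torch.arange(3)])
assert tokens.numel() == 8 and cu_seqlens.tolist() == [0, 5, 8]
```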
31 changes: 27 additions & 4 deletions src/megatron/bridge/models/gemma/gemma3_provider.py
@@ -367,11 +367,34 @@ def __init__(
**kwargs,
)

def forward(
self,
max_seq_len: int,
offset: int = 0,
packed_seq: bool = False,
cp_group: torch.distributed.ProcessGroup | None = None,
) -> tuple[Tensor, Tensor]:
"""Get global and local rope embedding.

Note: Caching is bypassed when cp_group is provided since ProcessGroup is unhashable.
"""
# ProcessGroup is unhashable, so bypass caching when cp_group is provided
if cp_group is not None:
rope_global = super().forward(max_seq_len, offset, packed_seq, cp_group)
rope_local = self.rope_local.forward(max_seq_len, offset, packed_seq, cp_group)
return rope_local, rope_global
return self._forward_cached(max_seq_len, offset, packed_seq)

@lru_cache(maxsize=32)
def forward(self, max_seq_len: int, offset: int = 0, packed_seq: bool = False) -> Tensor:
"""Get global and local rope embedding"""
rope_global = super().forward(max_seq_len, offset, packed_seq)
rope_local = self.rope_local.forward(max_seq_len, offset, packed_seq)
def _forward_cached(
self,
max_seq_len: int,
offset: int = 0,
packed_seq: bool = False,
) -> tuple[Tensor, Tensor]:
"""Cached forward for hashable parameters only."""
rope_global = super().forward(max_seq_len, offset, packed_seq, None)
rope_local = self.rope_local.forward(max_seq_len, offset, packed_seq, None)
return rope_local, rope_global


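The split between `forward` and `_forward_cached` exists because `functools.lru_cache` builds its key by hashing the arguments, and `torch.distributed.ProcessGroup` is unhashable. A minimal sketch of the pattern in isolation (class and helper names are illustrative, not the Gemma3 provider's):

```python
# Minimal sketch of the cache-bypass pattern: lru_cache hashes its
# arguments, so any call that carries an unhashable ProcessGroup must
# route around the cached path.
from functools import lru_cache


class RopeCache:
    def forward(self, seq_len: int, group=None):
        if group is not None:
            # ProcessGroup is unhashable -> compute directly, skip the cache.
            return self._compute(seq_len, group)
        return self._forward_cached(seq_len)

    @lru_cache(maxsize=32)  # keyed on (self, seq_len); both are hashable
    def _forward_cached(self, seq_len: int):
        return self._compute(seq_len, None)

    def _compute(self, seq_len: int, group):
        # Stand-in for the real rope computation.
        return list(range(seq_len))


cache = RopeCache()
assert cache.forward(4) == [0, 1, 2, 3]                   # cached path
assert cache.forward(4, group=object()) == [0, 1, 2, 3]   # bypass path
```

One trade-off of `lru_cache` on an instance method is that cached entries keep `self` alive; with a small `maxsize` and a long-lived provider object, as here, that is usually acceptable.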
54 changes: 43 additions & 11 deletions src/megatron/bridge/models/gemma_vl/modeling_gemma3_vl.py
@@ -14,22 +14,30 @@

import types
from dataclasses import dataclass
from typing import Optional, Tuple
from typing import TYPE_CHECKING, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from megatron.core.tensor_parallel.layers import ColumnParallelLinear
from megatron.core.tensor_parallel.mappings import scatter_to_sequence_parallel_region
from megatron.core.transformer import TransformerConfig
from megatron.core.transformer.module import MegatronModule
from torch import Tensor
from transformers import AutoModel, Gemma3Model

from megatron.bridge.models.gpt_provider import GPTModelProvider
from megatron.bridge.utils.common_utils import hook_hf_module_setattr_for_tp_grad_sync
from megatron.bridge.utils.common_utils import (
hook_hf_module_setattr_for_tp_grad_sync,
slice_batch_for_context_parallel,
)
from megatron.bridge.utils.import_utils import safe_import_from


if TYPE_CHECKING:
from megatron.core.packed_seq_params import PackedSeqParams


TENorm, _ = safe_import_from("megatron.core.extensions.transformer_engine", "TENorm")


@@ -110,12 +118,16 @@ def forward(
pixel_values: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
runtime_gather_output: Optional[bool] = None,
packed_seq_params: Optional["PackedSeqParams"] = None,
*,
loss_mask: Optional[Tensor] = None,
) -> Tensor:
) -> tuple[Tensor, Tensor | None]:
r"""
image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
The temporal, height and width of feature shape of each image in LLM.
Forward pass combining HuggingFace vision encoder with Megatron language model.

Returns:
tuple: (output_tensor, loss_mask) where output_tensor contains model output
and loss_mask is the CP-sliced mask for consistent loss computation.
"""
if self.pre_process:
if inputs_embeds is None:
@@ -134,7 +146,7 @@
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)

if inputs_embeds[special_image_mask].numel() != image_features.numel():
image_tokens_in_text = special_image_mask.sum(dim=1).item(dim=0)[0]
image_tokens_in_text = special_image_mask[:, :, 0].sum().item()
raise ValueError(
f"Number of images does not match number of special image tokens in the input text. "
f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} "
@@ -144,18 +156,38 @@
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
inputs_embeds = inputs_embeds.transpose(1, 0).contiguous() # (B, T, D) -> (T, B, D)

# Apply sequence parallelism scatter if enabled
if self.config.sequence_parallel:
inputs_embeds = scatter_to_sequence_parallel_region(inputs_embeds)

# Compute attention mask on FULL sequence (before CP slicing)
# This is needed because image regions need bidirectional attention
attention_mask = self._compute_attention_mask(input_ids)

# CP slicing: slice embeddings, labels, loss_mask, position_ids, and attention_mask
# This must happen AFTER vision-text merge so image token positions are correct
inputs_embeds, labels, loss_mask, position_ids, attention_mask = slice_batch_for_context_parallel(
inputs_embeds=inputs_embeds,
labels=labels,
loss_mask=loss_mask,
position_ids=position_ids,
attention_mask=attention_mask,
packed_seq_params=packed_seq_params,
pg_collection=self.config._pg_collection,
)

outputs = self.language_model.forward(
input_ids=None,
position_ids=position_ids,
attention_mask=attention_mask, # (B, 1, T, T)
decoder_input=inputs_embeds, # (T, B, D)
labels=labels, # (B, T)
attention_mask=attention_mask,
decoder_input=inputs_embeds,
labels=labels,
loss_mask=loss_mask,
runtime_gather_output=runtime_gather_output,
packed_seq_params=packed_seq_params,
)
return outputs
# Return both outputs and the CP-sliced loss_mask for consistent loss computation
return (outputs, loss_mask)

def freeze(self, freeze_language_model: bool, freeze_vision_model: bool, freeze_vision_projection: bool):
"""Freeze model modules.
@@ -191,7 +223,7 @@ def freeze(self, freeze_language_model: bool, freeze_vision_model: bool, freeze_
def _compute_attention_mask(
self,
input_ids: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> Optional[torch.Tensor]:
if not self.pre_process:
return None
batch_size, seq_len = input_ids.shape
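The ordering in this forward pass is the important part: the attention mask is built on the full sequence, image embeddings are merged into the text embeddings, and only then is everything sliced for context parallelism. A simplified sketch of sequence sharding, where `slice_for_cp` is a hypothetical stand-in for `slice_batch_for_context_parallel` — real CP implementations typically use a load-balanced two-chunk layout for causal attention rather than the plain contiguous shards shown here:

```python
# Simplified sketch of context-parallel slicing: each CP rank keeps only its
# shard of the sequence dimension, taken AFTER the vision-text merge so
# image-token positions are identical on every rank.
import torch


def slice_for_cp(x: torch.Tensor, cp_rank: int, cp_size: int, seq_dim: int) -> torch.Tensor:
    """Keep this rank's contiguous shard of the sequence dimension."""
    seq_len = x.size(seq_dim)
    assert seq_len % cp_size == 0, "sequence length must divide evenly across CP ranks"
    shard = seq_len // cp_size
    return x.narrow(seq_dim, cp_rank * shard, shard)


embeds = torch.randn(8, 2, 16)          # (T, B, D): embeddings are seq-first
labels = torch.randint(0, 10, (2, 8))   # (B, T): labels are batch-first
assert slice_for_cp(embeds, cp_rank=1, cp_size=2, seq_dim=0).shape == (4, 2, 16)
assert slice_for_cp(labels, cp_rank=1, cp_size=2, seq_dim=1).shape == (2, 4)
```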
8 changes: 7 additions & 1 deletion src/megatron/bridge/models/glm_vl/modeling_glm_45v.py
@@ -23,7 +23,7 @@
"""

import types
from typing import Optional
from typing import TYPE_CHECKING, Optional

import torch
import transformers
@@ -37,6 +37,10 @@
from megatron.bridge.utils.common_utils import hook_hf_module_setattr_for_tp_grad_sync


if TYPE_CHECKING:
from megatron.core.packed_seq_params import PackedSeqParams


def is_transformers_min_version(version):
"""Check if minimum version of transformers is installed."""
try:
@@ -158,6 +162,7 @@ def forward(
video_grid_thw: Optional[torch.LongTensor] = None,
labels: Optional[torch.Tensor] = None,
runtime_gather_output: Optional[bool] = None,
packed_seq_params: Optional["PackedSeqParams"] = None,
*,
loss_mask: Optional[Tensor] = None,
) -> Tensor:
@@ -233,6 +238,7 @@ def forward(
labels=labels,
loss_mask=loss_mask,
runtime_gather_output=runtime_gather_output,
packed_seq_params=packed_seq_params,
)
return outputs

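`is_transformers_min_version` (shown truncated above) gates code paths on the installed `transformers` release. A hedged sketch of how such a check is commonly written — the actual helper may differ, e.g. in how it treats pre-releases — using `packaging`, which `transformers` already depends on:

```python
# Hypothetical re-implementation of a minimum-version check; the real
# is_transformers_min_version may behave differently.
import transformers
from packaging.version import Version


def is_min_version(required: str) -> bool:
    try:
        return Version(transformers.__version__) >= Version(required)
    except Exception:
        # Unparseable or missing version string: fail closed.
        return False
```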
14 changes: 11 additions & 3 deletions src/megatron/bridge/models/ministral3/ministral3_provider.py
@@ -260,11 +260,17 @@ def __init__(
self.beta = 0 # No effect
self.max_position_embeddings = self.config.seq_length

@staticmethod
def _get_llama_4_attn_scale(
self, positions_ids: torch.Tensor, beta: float, max_position_embeddings: int
positions_ids: torch.Tensor, beta: float, max_position_embeddings: int, query_shape: tuple
) -> torch.Tensor:
scaling = 1 + beta * torch.log(1 + torch.floor(positions_ids / max_position_embeddings))
return scaling.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
# Add dimensions to match query shape: [seq_len] -> [seq_len, 1, 1] for packed or [seq_len, 1, 1, 1] for unpacked
# Query can be either [seq_len, num_heads, head_dim] (packed) or [seq_len, batch, num_heads, head_dim] (unpacked)
num_dims_to_add = len(query_shape) - 1
for _ in range(num_dims_to_add):
scaling = scaling.unsqueeze(-1)
return scaling

def forward(
self,
@@ -276,6 +282,8 @@ def forward(
**kwargs,
):
positions_ids = torch.arange(query.shape[0], device=query.device)
query *= self._get_llama_4_attn_scale(positions_ids, self.beta, self.max_position_embeddings).to(query.dtype)
query *= self._get_llama_4_attn_scale(positions_ids, self.beta, self.max_position_embeddings, query.shape).to(
query.dtype
)

return super().forward(query, key, value, attention_mask, attn_mask_type, **kwargs)
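The fix above replaces hard-coded `unsqueeze(-1)` calls with one per remaining query axis, so the scale broadcasts against both the packed `[T, H, Dh]` and unpacked `[T, B, H, Dh]` layouts. Note that this provider sets `beta = 0`, so the scale is identically 1 — the broadcast still has to be shape-correct. A self-contained sketch (`attn_scale` is an illustrative name):

```python
# Self-contained sketch of the broadcast fix: the per-position scale gains
# one trailing singleton dim per remaining query axis, so it multiplies
# cleanly against either query layout.
import torch


def attn_scale(position_ids: torch.Tensor, beta: float, max_pos: int, query_shape: tuple) -> torch.Tensor:
    scaling = 1 + beta * torch.log(1 + torch.floor(position_ids / max_pos))
    for _ in range(len(query_shape) - 1):
        scaling = scaling.unsqueeze(-1)  # [T] -> [T, 1, ...] for broadcasting
    return scaling


pos = torch.arange(6)
q_packed = torch.randn(6, 4, 8)        # [seq, heads, head_dim] (packed)
q_unpacked = torch.randn(6, 2, 4, 8)   # [seq, batch, heads, head_dim]
assert attn_scale(pos, 0.1, 4, q_packed.shape).shape == (6, 1, 1)
assert attn_scale(pos, 0.1, 4, q_unpacked.shape).shape == (6, 1, 1, 1)
```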
33 changes: 28 additions & 5 deletions src/megatron/bridge/models/ministral3/modeling_ministral3.py
@@ -24,14 +24,21 @@
"""

import types
from typing import Optional
from typing import TYPE_CHECKING, Optional

import torch
from megatron.core.transformer.module import MegatronModule
from torch import Tensor

from megatron.bridge.models.gpt_provider import GPTModelProvider
from megatron.bridge.utils.common_utils import hook_hf_module_setattr_for_tp_grad_sync
from megatron.bridge.utils.common_utils import (
hook_hf_module_setattr_for_tp_grad_sync,
slice_batch_for_context_parallel,
)


if TYPE_CHECKING:
from megatron.core.packed_seq_params import PackedSeqParams


# Import HuggingFace Mistral3 model classes with fallback
@@ -185,9 +192,10 @@ def forward(
labels: Optional[torch.Tensor] = None,
runtime_gather_output: Optional[bool] = None,
image_sizes: Optional[torch.Tensor] = None,
packed_seq_params: Optional["PackedSeqParams"] = None,
*,
loss_mask: Optional[Tensor] = None,
) -> Tensor:
) -> tuple[Tensor, Tensor | None]:
"""
Forward pass combining HuggingFace vision encoder with Megatron language model.

@@ -202,7 +210,8 @@
loss_mask: Mask for loss computation.

Returns:
Model output (logits or loss depending on mode).
tuple: (output_tensor, loss_mask) where output_tensor contains model output
and loss_mask is the CP-sliced mask for consistent loss computation.
"""
if self.pre_process:
if inputs_embeds is None:
@@ -237,6 +246,18 @@
# Transpose back to Megatron format [seq_len, batch, hidden]
inputs_embeds = inputs_embeds.transpose(1, 0).contiguous()

# CP slicing: slice embeddings, labels, loss_mask, position_ids, and attention_mask
# This must happen AFTER vision-text merge so image token positions are correct
inputs_embeds, labels, loss_mask, position_ids, attention_mask = slice_batch_for_context_parallel(
inputs_embeds=inputs_embeds,
labels=labels,
loss_mask=loss_mask,
position_ids=position_ids,
attention_mask=attention_mask,
packed_seq_params=packed_seq_params,
pg_collection=self.config._pg_collection,
)

# Forward through Megatron language model
outputs = self.language_model.forward(
input_ids=None,
@@ -246,8 +267,10 @@
labels=labels,
loss_mask=loss_mask,
runtime_gather_output=runtime_gather_output,
packed_seq_params=packed_seq_params,
)
return outputs
# Return both outputs and the CP-sliced loss_mask for consistent loss computation
return (outputs, loss_mask)

def freeze(
self,
1 change: 1 addition & 0 deletions src/megatron/bridge/models/qwen_vl/modeling_qwen25_vl.py
@@ -213,6 +213,7 @@ def forward(
labels=labels,
loss_mask=loss_mask,
runtime_gather_output=runtime_gather_output,
packed_seq_params=packed_seq_params,
)
return outputs

3 changes: 3 additions & 0 deletions src/megatron/bridge/recipes/gemma3_vl/gemma3_vl.py
@@ -234,9 +234,12 @@ def _gemma3_vl_common(
model_cfg.freeze_vision_model = freeze_vision_model
model_cfg.freeze_vision_projection = freeze_vision_projection
model_cfg.seq_length = seq_length
model_cfg.cp_comm_type = "a2a"

# Optimizer and scheduler - use finetune_lr if provided, otherwise use lr
effective_lr = finetune_lr if finetune_lr is not None else lr
if min_lr > effective_lr:
min_lr = effective_lr * 0.1
opt_config, scheduler = distributed_fused_adam_with_cosine_annealing(
lr_warmup_iters=lr_warmup_iters,
lr_decay_iters=lr_decay_iters if lr_decay_iters is not None else train_iters,
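The new guard keeps `min_lr` below the effective starting LR so the cosine schedule cannot anneal upward. A quick numeric illustration with made-up values:

```python
# Quick numeric illustration of the min_lr guard (values are illustrative):
lr, finetune_lr, min_lr = 3e-4, None, 1e-3

effective_lr = finetune_lr if finetune_lr is not None else lr
if min_lr > effective_lr:
    min_lr = effective_lr * 0.1  # 1e-3 > 3e-4, so min_lr becomes 3e-5

assert min_lr <= effective_lr
```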
12 changes: 8 additions & 4 deletions src/megatron/bridge/training/utils/packed_seq_utils.py
@@ -44,14 +44,18 @@ def get_packed_seq_params(batch: dict[str, torch.Tensor]) -> PackedSeqParams:
cu_seqlens_unpadded_argmin = batch.get("cu_seqlens_unpadded_argmin")

if cu_seqlens_argmin is not None:
cu_seqlens_padded = cu_seqlens_padded[: cu_seqlens_argmin.item()]
else:
argmin_idx = cu_seqlens_argmin.item()
assert argmin_idx == 0 or cu_seqlens_padded[argmin_idx] == -1 # cu_seqlens padding is -1
cu_seqlens_padded = cu_seqlens_padded[:argmin_idx]
elif torch.min(cu_seqlens_padded) == -1:
cu_seqlens_padded = cu_seqlens_padded[: torch.argmin(cu_seqlens_padded)]

if cu_seqlens_unpadded is not None:
if cu_seqlens_unpadded_argmin is not None:
cu_seqlens_unpadded = cu_seqlens_unpadded[: cu_seqlens_unpadded_argmin.item()]
else:
argmin_idx = cu_seqlens_unpadded_argmin.item()
assert argmin_idx == 0 or cu_seqlens_unpadded[argmin_idx] == -1 # cu_seqlens padding is -1
cu_seqlens_unpadded = cu_seqlens_unpadded[:argmin_idx]
elif torch.min(cu_seqlens_unpadded) == -1:
cu_seqlens_unpadded = cu_seqlens_unpadded[: torch.argmin(cu_seqlens_unpadded)]

max_seqlen = batch["max_seqlen"].squeeze() if "max_seqlen" in batch else None
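The revised logic makes the padding contract explicit: `cu_seqlens` tensors are padded with `-1` to a fixed length, and a precomputed `*_argmin` index, when present, must point at the first pad. A standalone sketch of the trimming (`trim_cu_seqlens` is a hypothetical helper; it also handles the no-padding case, where the argmin lands on index 0):

```python
# Standalone sketch of the trimming contract: cu_seqlens is padded with -1,
# and the precomputed argmin index, when given, must point at the first pad.
from typing import Optional

import torch


def trim_cu_seqlens(cu_seqlens: torch.Tensor, argmin: Optional[torch.Tensor] = None) -> torch.Tensor:
    if argmin is not None:
        idx = argmin.item()
        if idx == 0:
            return cu_seqlens  # argmin at 0 means no -1 padding is present
        assert cu_seqlens[idx] == -1  # the padding sentinel is -1
        return cu_seqlens[:idx]
    if torch.min(cu_seqlens) == -1:
        return cu_seqlens[: torch.argmin(cu_seqlens)]
    return cu_seqlens  # already unpadded


padded = torch.tensor([0, 5, 8, -1, -1])
assert trim_cu_seqlens(padded).tolist() == [0, 5, 8]
assert trim_cu_seqlens(padded, torch.tensor(3)).tolist() == [0, 5, 8]
```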