@@ -16,12 +16,17 @@

model:
seq_length: 4096
use_hf_vision_model: false
tensor_model_parallel_size: 1
pipeline_model_parallel_size: 1
context_parallel_size: 1
cross_entropy_loss_fusion: false

train:
train_iters: 20
global_batch_size: 8
micro_batch_size: 1
eval_iters: 5
train_iters: 1000
global_batch_size: 16
micro_batch_size: 2
eval_iters: 100

optimizer:
lr: 0.00025
@@ -40,6 +45,8 @@ dist:

logger:
log_interval: 1
log_throughput: true
log_params_norm: true

dataset:
sequence_length: 4096
@@ -50,4 +57,14 @@ rng:
ddp:
grad_reduce_in_fp32: true

profiling:
memory_snapshot_path: snapshot.pickle
nvtx_ranges: false
profile_ranks: [0]
profile_step_end: 12
profile_step_start: 10
record_memory_history: false
record_shapes: false
use_nsys_profiler: false
use_pytorch_profiler: false

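Not part of this diff: a minimal sketch of how the new profiling block might be toggled when a trace is actually wanted, assuming the keys map directly onto the corresponding profiler options (everything ships disabled by default in this recipe):

profiling:
  # capture a short window (steps 10-12) on rank 0 with the PyTorch profiler
  use_pytorch_profiler: true
  profile_step_start: 10
  profile_step_end: 12
  profile_ranks: [0]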
5 changes: 3 additions & 2 deletions examples/recipes/qwen_vl/finetune_qwen_vl.py
@@ -103,9 +103,9 @@
create_omegaconf_dict_config,
parse_hydra_overrides,
)
from megatron.bridge.training.vlm_step import forward_step
from megatron.bridge.training.qwen3vl_step import forward_step
from megatron.bridge.utils.common_utils import get_rank_safe

from functools import partial

logger: logging.Logger = logging.getLogger(__name__)

@@ -185,6 +185,7 @@ def parse_cli_args() -> Tuple[argparse.Namespace, list[str]]:
help="Use preloaded dataset provider (enabled automatically when --data-path is set).",
)
parser.add_argument("--debug", action="store_true", help="Enable debug logging")

args, cli_dotlist_overrides = parser.parse_known_args()
return args, cli_dotlist_overrides

2 changes: 1 addition & 1 deletion scripts/performance/argument_parser.py
@@ -145,7 +145,7 @@ def parse_cli_args():
parser.add_argument(
"--domain",
type=lower_str,
choices=["llm", "vlm"],
choices=["llm", "vlm", "qwen3vl"],
help="Domain to use for experiment.",
default="llm",
)
3 changes: 3 additions & 0 deletions scripts/performance/run_script.py
@@ -22,6 +22,7 @@
from megatron.bridge.training.gpt_step import forward_step
from megatron.bridge.training.pretrain import pretrain
from megatron.bridge.training.vlm_step import forward_step as vlm_forward_step
from megatron.bridge.training.qwen3vl_step import forward_step as qwen3vl_forward_step


logger = logging.getLogger(__name__)
@@ -61,6 +62,8 @@ def main():
# Select forward step function based on the model family name.
if args.domain == "vlm":
forward_step_func = vlm_forward_step
elif args.domain == "qwen3vl":
forward_step_func = qwen3vl_forward_step
else:
forward_step_func = forward_step

3 changes: 3 additions & 0 deletions src/megatron/bridge/data/vlm_datasets/hf_provider.py
@@ -67,6 +67,9 @@ class HFDatasetConversationProvider(DatasetProvider):
# DataloaderConfig fields are inherited (num_workers, dataloader_type, etc.)
dataloader_type: Optional[Literal["single", "cyclic", "external"]] = "single"

# Enable batch-level online sequence packing (dataset-level packing is available in FinetuneDatasetProvider)
pack_sequences_in_batch: bool = False

def _get_maker(self) -> Callable[..., List[Dict[str, Any]]]:
registry: Dict[str, Callable[..., List[Dict[str, Any]]]] = {
"make_rdr_dataset": make_rdr_dataset,
2 changes: 1 addition & 1 deletion src/megatron/bridge/models/conversion/param_mapping.py
@@ -1840,7 +1840,7 @@ class ConcatenatedQKVMapping(MegatronParamMapping[Dict[str, torch.Tensor]]):
.. code-block:: python

# Create mapping for attention weights
mapping = QKVMapping(
mapping = ConcatenatedQKVMapping(
megatron_param="decoder.layers.*.self_attention.linear_qkv.weight",
qkv="model.layers.*.self_attn.qkv.weight",
)
260 changes: 260 additions & 0 deletions src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/attention.py
@@ -0,0 +1,260 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2024 Alibaba PAI Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from megatron.core.transformer.attention import *

from megatron.bridge.models.qwen_vl.modelling_qwen3_vl.rope import apply_rotary_pos_emb_absolute


class Qwen3VLSelfAttention(SelfAttention):
"""
Overrides the SelfAttention class; the difference is that Qwen3-VL applies apply_rotary_pos_emb_absolute
instead of apply_rotary_pos_emb.
"""

def forward(
self,
hidden_states: Tensor,
attention_mask: Tensor,
key_value_states: Optional[Tensor] = None,
inference_context: Optional[BaseInferenceContext] = None,
rotary_pos_emb: Optional[Union[Tensor, Tuple[Tensor, Tensor]]] = None,
rotary_pos_cos: Optional[Tensor] = None,
rotary_pos_sin: Optional[Tensor] = None,
attention_bias: Optional[Tensor] = None,
packed_seq_params: Optional[PackedSeqParams] = None,
sequence_len_offset: Optional[int] = None,
*,
inference_params: Optional[BaseInferenceContext] = None,
rotary_pos_cos_sin: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor]:
"""
Perform a forward pass through the attention module.

Args:
hidden_states (Tensor): Hidden states.
attention_mask (Tensor): Attention mask.
key_value_states (Optional[Tensor]): Key/value states (for cross attention).
inference_context (Optional[BaseInferenceContext]): Inference context that manages
KV cache.
rotary_pos_emb (Optional[Union[Tensor, Tuple[Tensor, Tensor]]]): Rotary
embedding tensor(s).
rotary_pos_cos (Optional[Tensor]): Rotary embedding cosine.
rotary_pos_sin (Optional[Tensor]): Rotary embedding sine.
attention_bias (Optional[Tensor]): Attention bias.
packed_seq_params (Optional[PackedSeqParams]): Parameters used for THD format.
sequence_len_offset (Optional[int]): Sequence length offset used for
inference CUDA graphs.

Returns:
(Tuple[Tensor, Tensor]) Attention output and bias.

"""
# Check if we need to skip RoPE
# no_rope_freq is a 0-indexed array and self.layer_number is 1-indexed
no_rope = self.config.no_rope_freq[self.layer_number - 1] if self.config.no_rope_freq else False
if no_rope:
rotary_pos_emb = None

inference_context = deprecate_inference_params(inference_context, inference_params)

if inference_context and inference_context.is_dynamic_batching():
assert HAVE_FA3 or is_fa_min_version("2.7.3"), (
"flash attn verion v2.7.3 and above is required for dynamic batching."
)

# hidden_states: [sq, b, h]
if self.config.flash_decode and not self.training and inference_context is not None:
rotary_pos_emb = None
else:
assert rotary_pos_cos is None and rotary_pos_sin is None

# For self attention we just duplicate the rotary_pos_emb if it isn't already a tuple
if rotary_pos_emb is not None and not isinstance(rotary_pos_emb, tuple):
rotary_pos_emb = (rotary_pos_emb,) * 2

# =====================
# Query, Key, and Value
# =====================
# Get the query, key and value tensors based on the type of attention -
# self or cross attn.
nvtx_range_push(suffix="qkv")
query, key, value = self.get_query_key_value_tensors(hidden_states, key_value_states)
nvtx_range_pop(suffix="qkv")

# ===================================================
# Adjust key, value, and rotary_pos_emb for inference
# ===================================================

in_decode_mode = inference_context is not None and inference_context.is_decode_only() and not self.training

# This branch only runs in the decode phase of flash decoding and returns after the linear
# projection. This conditional is not used in the prefill phase or non-flash-decoding cases.
nvtx_range_push(suffix="adjust_key_value")
if in_decode_mode and self.config.flash_decode:
assert self.layer_number in inference_context.key_value_memory_dict
assert inference_context.sequence_len_offset is not None
inference_key_memory, inference_value_memory = inference_context.key_value_memory_dict[self.layer_number]
output = self.flash_decode(
sequence_len_offset=sequence_len_offset,
query_layer=query,
key_layer=key,
value_layer=value,
inference_key_memory=inference_key_memory,
inference_value_memory=inference_value_memory,
rotary_cos=rotary_pos_cos,
rotary_sin=rotary_pos_sin,
rotary_interleaved=self.config.rotary_interleaved,
)
out = output.transpose(0, 1).contiguous()
context_layer = out.view(out.size(0), out.size(1), -1)
output, bias = self.linear_proj(context_layer)
return output, bias

if in_decode_mode and self.config.enable_cuda_graph and inference_context.is_static_batching():
raise ValueError(f"CUDA graphs must use flash decode with static batching!")

query, key, value, rotary_pos_emb, attn_mask_type, block_table = self._adjust_key_value_for_inference(
inference_context,
query,
key,
value,
rotary_pos_emb,
rotary_pos_cos,
rotary_pos_sin,
sequence_len_offset,
)

if packed_seq_params is not None:
query = query.squeeze(1)
key = key.squeeze(1)
value = value.squeeze(1)
nvtx_range_pop(suffix="adjust_key_value")

# ================================================
# relative positional embedding (rotary embedding)
# ================================================
nvtx_range_push(suffix="rotary_pos_emb")
if rotary_pos_emb is not None and not self.config.flash_decode:
q_pos_emb, k_pos_emb = rotary_pos_emb

if packed_seq_params is not None:
if packed_seq_params.cu_seqlens_q_padded is not None:
cu_seqlens_q = packed_seq_params.cu_seqlens_q_padded
else:
cu_seqlens_q = packed_seq_params.cu_seqlens_q
if packed_seq_params.cu_seqlens_kv_padded is not None:
cu_seqlens_kv = packed_seq_params.cu_seqlens_kv_padded
else:
cu_seqlens_kv = packed_seq_params.cu_seqlens_kv
else:
cu_seqlens_q = cu_seqlens_kv = None

if q_pos_emb is not None:
# TODO VIJAY: simplify
if inference_context is None or inference_context.is_static_batching():
query = apply_rotary_pos_emb_absolute(
query,
q_pos_emb,
config=self.config,
cu_seqlens=cu_seqlens_q,
)
else:
query = inference_context.apply_rotary_emb_query(
query,
q_pos_emb,
self.config,
cu_seqlens_q,
self.model_comm_pgs.cp,
)
if k_pos_emb is not None:
key = apply_rotary_pos_emb_absolute(
key,
k_pos_emb,
config=self.config,
cu_seqlens=cu_seqlens_kv,
)

# TODO: positional embedding could also be applied to value_layer so that it carries
# absolute positional information; otherwise only the relative positional embedding
# takes effect.
# value_layer = apply_rotary_pos_emb(value_layer, k_pos_emb)
nvtx_range_pop(suffix="rotary_pos_emb")

# ==================================
# core attention computation
# ==================================

nvtx_range_push(suffix="core_attention")
if self.checkpoint_core_attention and self.training:
core_attn_out = self._checkpointed_attention_forward(
query,
key,
value,
attention_mask,
attn_mask_type=attn_mask_type,
attention_bias=attention_bias,
packed_seq_params=packed_seq_params,
)
else:
if inference_context is None or inference_context.is_static_batching():
# Static batching attention kernel.
core_attn_out = self.core_attention(
query,
key,
value,
attention_mask,
attn_mask_type=attn_mask_type,
attention_bias=attention_bias,
packed_seq_params=packed_seq_params,
)

else:
# Dynamic batching attention kernel.
q, k, v = (query, key, value)
cu_query_lengths, max_seqlen_q = inference_context.cu_query_lengths()
cu_kv_lengths, kv_lengths, kv_lengths_decode_only, max_seqlen_k = inference_context.cu_kv_lengths()

core_attn_out = self.flash_decode_and_prefill(
q,
k,
v,
max_seqlen_q,
max_seqlen_k,
cu_query_lengths,
cu_kv_lengths,
kv_lengths,
kv_lengths_decode_only,
block_table,
)
core_attn_out = rearrange(core_attn_out, "s b h d -> s b (h d)")

if packed_seq_params is not None and packed_seq_params.qkv_format == "thd":
# reshape to same output shape as unpacked case
# (t, np, hn) -> (t, b=1, h=np*hn)
# t is the pack size = sum (sq_i)
# note that batch is a dummy dimension in the packed case
core_attn_out = core_attn_out.reshape(core_attn_out.size(0), 1, -1)
nvtx_range_pop(suffix="core_attention")

# =================
# Output. [sq, b, h]
# =================

nvtx_range_push(suffix="linear_proj")
output, bias = self.linear_proj(core_attn_out)
nvtx_range_pop(suffix="linear_proj")

return output, bias