Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ def __init__(
# on the word embeddings inside `finalize_model_grads._allreduce_word_embedding_grads`.
self.share_embeddings_and_output_weights = False
# process groups
if pg_collection is None:
pg_collection = ProcessGroupCollection.use_mpu_process_groups()
self.pg_collection = pg_collection
self.cp_group = pg_collection.cp
self.tp_group = pg_collection.tp
Expand Down Expand Up @@ -142,6 +144,7 @@ def __init__(
vision_patch_merger_spec,
pre_process=True,
post_process=True,
pg_collection=pg_collection,
)

self.language_model = Qwen3VLGPTModel(
Expand Down
8 changes: 5 additions & 3 deletions src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/rope.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

import torch
import torch.nn as nn
from megatron.core import parallel_state
from megatron.core.models.common.embeddings.rope_utils import (
_apply_rotary_pos_emb_bshd,
get_pos_emb_on_this_cp_rank,
Expand Down Expand Up @@ -52,6 +51,7 @@ def __init__(
rotary_interleaved: bool = False,
seq_len_interpolation_factor: Optional[float] = None,
rotary_base: int = 10000,
cp_group: torch.distributed.ProcessGroup = None,
) -> None:
super().__init__()

Expand All @@ -69,6 +69,8 @@ def __init__(

# default mrope section is [24, 20, 20], if no mrope section is provided, use default mrope section
self.mrope_section = [24, 20, 20]
assert cp_group is not None, "cp_group is required"
self.cp_group = cp_group

def apply_interleaved_mrope(self, freqs, mrope_section):
"""Apply interleaved MRoPE to 3D rotary embeddings.
Expand Down Expand Up @@ -127,10 +129,10 @@ def forward(

# shape (seq_length, bs, 1, 2 * dim)
emb = emb[..., None, :].transpose(0, 1).contiguous()
if parallel_state.get_context_parallel_world_size() > 1 and not self.is_thd_format:
if self.cp_group.size() > 1 and not self.is_thd_format:
# slice rotary_pos_emb along sequence dimension and select the parition of the current
# CP rank
emb = get_pos_emb_on_this_cp_rank(emb, 0, parallel_state.get_context_parallel_group())
emb = get_pos_emb_on_this_cp_rank(emb, 0, self.cp_group)
return emb


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ def __init__(
rotary_interleaved=self.config.rotary_interleaved,
seq_len_interpolation_factor=seq_len_interpolation_factor,
rotary_base=rotary_base,
cp_group=self.pg_collection.cp,
)
self.mrope_section = self.config.mrope_section
assert self.mrope_section is not None, (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,14 @@
from typing import Optional, Union

import torch
from megatron.core import parallel_state, tensor_parallel
from megatron.core import tensor_parallel
from megatron.core.dist_checkpointing.mapping import ShardedStateDict
from megatron.core.dist_checkpointing.utils import replace_prefix_for_sharding
from megatron.core.enums import Fp8Recipe
from megatron.core.fp8_utils import get_fp8_context
from megatron.core.inference.contexts import BaseInferenceContext
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.process_groups_config import ProcessGroupCollection
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_block import TransformerBlock, TransformerBlockSubmodules
from megatron.core.transformer.transformer_layer import get_transformer_layer_offset
Expand Down Expand Up @@ -64,9 +65,9 @@ def __init__(
post_layer_norm: bool = True,
pre_process: bool = True,
post_process: bool = True,
# model_comm_pgs: ModelCommProcessGroups = None,
vp_stage: Optional[int] = None,
patch_merger_spec: ModuleSpec = None,
pg_collection: Optional[ProcessGroupCollection] = None,
):
assert post_process and pre_process, "not support pp for deepstack_merger_list"
super().__init__(
Expand All @@ -75,9 +76,16 @@ def __init__(
post_layer_norm=post_layer_norm,
pre_process=pre_process,
post_process=post_process,
# model_comm_pgs=model_comm_pgs,
vp_stage=vp_stage,
pg_collection=pg_collection,
)
if pg_collection is None:
pg_collection = ProcessGroupCollection.use_mpu_process_groups()
self.pg_collection = pg_collection
self.cp_group = pg_collection.cp
self.tp_group = pg_collection.tp
self.pp_group = pg_collection.pp
Comment thread
coderabbitai[bot] marked this conversation as resolved.

self.deepstack_visual_indexes = config.deepstack_visual_indexes
self.deepstack_merger_list = nn.ModuleList(
[
Expand Down Expand Up @@ -141,7 +149,7 @@ def checkpoint_handler(forward_func):
forward_func,
self.config.distribute_saved_activations,
tensor_parallel.random.get_cuda_rng_tracker,
parallel_state.get_tensor_model_parallel_group(),
self.tp_group,
hidden_states,
attention_mask,
context,
Expand Down Expand Up @@ -391,7 +399,7 @@ def sharded_state_dict(
layer_prefix = f"{prefix}layers."
num_layers = self.config.num_layers
for layer in self.layers:
offset = get_transformer_layer_offset(self.config, self.vp_stage)
offset = get_transformer_layer_offset(self.config, self.vp_stage, pp_rank=self.pp_group.rank())

global_layer_offset = layer.layer_number - 1 # self.layer_number starts at 1
state_dict_prefix = f"{layer_prefix}{global_layer_offset - offset}." # module list index in TransformerBlock # pylint: disable=line-too-long
Expand Down Expand Up @@ -436,6 +444,32 @@ class Qwen3VLTransformerBlock(TransformerBlock):
Transformer Block for Qwen3VL model.
"""

def __init__(
self,
config: Qwen3VLTransformerConfig,
spec: Union[TransformerBlockSubmodules, ModuleSpec],
post_layer_norm: bool = True,
pre_process: bool = True,
post_process: bool = True,
vp_stage: Optional[int] = None,
pg_collection: Optional[ProcessGroupCollection] = None,
):
super().__init__(
config=config,
spec=spec,
post_layer_norm=post_layer_norm,
pre_process=pre_process,
post_process=post_process,
vp_stage=vp_stage,
pg_collection=pg_collection,
)
if pg_collection is None:
pg_collection = ProcessGroupCollection.use_mpu_process_groups()
self.pg_collection = pg_collection
self.cp_group = pg_collection.cp
self.tp_group = pg_collection.tp
self.pp_group = pg_collection.pp

def _checkpointed_forward(
self,
hidden_states: Tensor,
Expand Down Expand Up @@ -503,7 +537,7 @@ def checkpoint_handler(forward_func):
forward_func,
self.config.distribute_saved_activations,
tensor_parallel.random.get_cuda_rng_tracker,
parallel_state.get_tensor_model_parallel_group(),
self.tp_group,
Comment thread
coderabbitai[bot] marked this conversation as resolved.
hidden_states,
attention_mask,
context,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def __init__(
self.input_size = config.hidden_size
if self.use_postshuffle_norm:
self.input_size = self.hidden_size
tp_group = get_tensor_model_parallel_group_if_none(tp_group, is_expert=False)
self.tp_group = get_tensor_model_parallel_group_if_none(tp_group, is_expert=False)

self.patch_norm = build_module(
submodules.patch_norm,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from megatron.core import InferenceParams
from megatron.core.models.common.vision_module.vision_module import VisionModule
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.process_groups_config import ProcessGroupCollection
from megatron.core.transformer.enums import ModelType
from megatron.core.transformer.spec_utils import ModuleSpec
from torch import nn
Expand Down Expand Up @@ -48,13 +49,17 @@ def __init__(
patch_merger_spec: ModuleSpec,
pre_process: bool = True,
post_process: bool = True,
pg_collection: Optional[torch.distributed.ProcessGroup] = None,
pg_collection: Optional[ProcessGroupCollection] = None,
) -> None:
assert post_process and pre_process, "not support pp for deepstack_merger_list"
super().__init__(config=transformer_config)
self.spatial_merge_size = transformer_config.spatial_merge_size
self.patch_size = transformer_config.patch_size
self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size
self.pg_collection = pg_collection
if pg_collection is None:
pg_collection = ProcessGroupCollection.use_mpu_process_groups()
self.tp_group = self.pg_collection.tp

assert transformer_config.context_parallel_size == 1, (
f"context_parallel_size should be 1 in vision model but got {transformer_config.context_parallel_size}"
Expand All @@ -79,6 +84,7 @@ def __init__(
post_process=self.post_process,
post_layer_norm=False,
patch_merger_spec=patch_merger_spec,
pg_collection=self.pg_collection,
)

self.merger = None
Expand Down
75 changes: 75 additions & 0 deletions tests/unit_tests/models/qwen_vl/modelling_qwen3_vl/test_rope.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,15 @@
Run with: pytest tests/unit_tests/models/qwen_vl/modelling_qwen3_vl/test_rope.py
"""

import datetime
import os

import pytest
import torch
import torch.distributed as dist
from megatron.core import parallel_state
from megatron.core.process_groups_config import ProcessGroupCollection
from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
from transformers import Qwen3VLMoeTextConfig
from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeTextRotaryEmbedding

Expand All @@ -33,12 +40,71 @@ def hf_config():
class TestQwen3VLTextRotaryEmbedding:
"""Test suite for Qwen3VL Text Rotary Embedding."""

@classmethod
def setup_class(cls):
"""Setup distributed process group once for all tests in this class."""
if not dist.is_initialized():
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "29500"
os.environ["RANK"] = "0"
os.environ["LOCAL_RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"

device_count = torch.cuda.device_count()
if device_count > 0:
torch.cuda.set_device(0)

dist.init_process_group(
backend="nccl" if device_count > 0 else "gloo",
world_size=1,
rank=0,
timeout=datetime.timedelta(minutes=30),
)

@classmethod
def teardown_class(cls):
"""Teardown distributed process group once after all tests in this class."""
if dist.is_initialized():
dist.destroy_process_group()

def _setup_parallel_state(self, tp_size=1, ep_size=1, pp_size=1, cp_size=1):
"""Setup Megatron parallel state with specified parallelism configuration.

Args:
tp_size: Tensor model parallel size
ep_size: Expert model parallel size
pp_size: Pipeline model parallel size
cp_size: Context parallel size
"""
# Clean up any existing parallel state before initializing
if parallel_state.model_parallel_is_initialized():
parallel_state.destroy_model_parallel()

parallel_state.initialize_model_parallel(
tensor_model_parallel_size=tp_size,
pipeline_model_parallel_size=pp_size,
virtual_pipeline_model_parallel_size=None,
context_parallel_size=cp_size,
expert_model_parallel_size=ep_size,
expert_tensor_parallel_size=1,
)

model_parallel_cuda_manual_seed(123)

def teardown_method(self):
"""Teardown Megatron parallel state after each test method."""
if parallel_state.model_parallel_is_initialized():
parallel_state.destroy_model_parallel()

def test_qwen3_vl_text_rotary_embedding(self, hf_config):
"""Test that MBridge RoPE output matches HuggingFace implementation."""
self._setup_parallel_state(tp_size=1, ep_size=1, pp_size=1, cp_size=1)
pg_collection = ProcessGroupCollection.use_mpu_process_groups()
hf_rope_embedding = Qwen3VLMoeTextRotaryEmbedding(hf_config)
mbridge_rope_embedding = Qwen3VLMultimodalRotaryEmbedding(
kv_channels=hf_config.head_dim,
rotary_base=rope_theta_from_hf(hf_config),
cp_group=pg_collection.cp,
)

seq_len = 1024
Expand Down Expand Up @@ -80,9 +146,12 @@ def test_qwen3_vl_text_rotary_embedding(self, hf_config):

def test_qwen3_vl_text_rotary_embedding_2d_position_ids(self, hf_config):
"""Test Qwen3VLMultimodalRotaryEmbedding with 2D position_ids (should auto-expand to 3D)."""
self._setup_parallel_state(tp_size=1, ep_size=1, pp_size=1, cp_size=1)
pg_collection = ProcessGroupCollection.use_mpu_process_groups()
mbridge_rope_embedding = Qwen3VLMultimodalRotaryEmbedding(
kv_channels=hf_config.head_dim,
rotary_base=rope_theta_from_hf(hf_config),
cp_group=pg_collection.cp,
)

seq_len = 512
Expand All @@ -103,9 +172,12 @@ def test_qwen3_vl_text_rotary_embedding_2d_position_ids(self, hf_config):

def test_qwen3_vl_text_rotary_embedding_default_mrope_section(self, hf_config):
"""Test Qwen3VLMultimodalRotaryEmbedding with None mrope_section (should use default)."""
self._setup_parallel_state(tp_size=1, ep_size=1, pp_size=1, cp_size=1)
pg_collection = ProcessGroupCollection.use_mpu_process_groups()
mbridge_rope_embedding = Qwen3VLMultimodalRotaryEmbedding(
kv_channels=hf_config.head_dim,
rotary_base=rope_theta_from_hf(hf_config),
cp_group=pg_collection.cp,
)

seq_len = 256
Expand All @@ -122,9 +194,12 @@ def test_qwen3_vl_text_rotary_embedding_default_mrope_section(self, hf_config):

def test_qwen3_vl_moe_text_rotary_embedding(self, hf_config):
"""Test Qwen3VLMultimodalRotaryEmbedding forward pass."""
self._setup_parallel_state(tp_size=1, ep_size=1, pp_size=1, cp_size=1)
pg_collection = ProcessGroupCollection.use_mpu_process_groups()
mbridge_rope_embedding = Qwen3VLMultimodalRotaryEmbedding(
kv_channels=hf_config.head_dim,
rotary_base=rope_theta_from_hf(hf_config),
cp_group=pg_collection.cp,
)

seq_len = 512
Expand Down
Loading