18 changes: 18 additions & 0 deletions src/megatron/bridge/models/gpt_provider.py
@@ -198,6 +198,24 @@ class GPTModelProvider(TransformerConfig, ModelProviderMixin[MCoreGPTModel]):

_pg_collection: Optional[ProcessGroupCollection] = None

# vision model type will be used to override the vision model config.
Contributor: It's a bit weird to add the DistTrain and vision configs here. Find a better approach; this shouldn't be exposed to all gpt_provider models, only qwen3_vl. You can edit this in the qwen3 provider.
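A rough sketch of that suggestion, for illustration only (the Qwen3VLModelProvider name and field set are assumptions, not part of this PR): the Qwen3-VL-specific options would live on a provider subclass rather than on GPTModelProvider itself.

from dataclasses import dataclass
from typing import Optional

from megatron.bridge.models.gpt_provider import GPTModelProvider


@dataclass
class Qwen3VLModelProvider(GPTModelProvider):
    # Vision/DistTrain knobs kept off the generic GPT provider (hypothetical).
    vision_model_type: Optional[str] = None  # only "vit_2b" is currently supported
    use_dist_train: bool = False
    dist_train_vision_chunk_size: Optional[int] = 1
    add_vision_module: bool = True
    add_language_module: bool = True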

# if None, the vision model config will be used as is.
# currently, only vit_2b is supported.
vision_model_type: Optional[str] = None

# parameters for DistTrain
use_dist_train: bool = False
Contributor: maybe a train config

dist_train_vision_chunk_size: Optional[int] = 1
vision_world_size: Optional[int] = None
Contributor: maybe a dist_train dataclass
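A minimal sketch of the dist_train dataclass idea (names are hypothetical): the per-module parallelism sizes would be grouped into one config object that the provider carries as a single field.

from dataclasses import dataclass
from typing import Optional


@dataclass
class DistTrainConfig:
    # Grouped DistTrain options; the provider would expose a single
    # dist_train: Optional[DistTrainConfig] = None field instead of many flags.
    vision_chunk_size: int = 1
    vision_world_size: Optional[int] = None
    language_world_size: Optional[int] = None
    vision_tensor_model_parallel_size: Optional[int] = None
    vision_pipeline_model_parallel_size: Optional[int] = None
    vision_context_parallel_size: Optional[int] = None
    vision_expert_tensor_parallel_size: Optional[int] = None
    vision_expert_model_parallel_size: Optional[int] = None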

language_world_size: Optional[int] = None
vision_tensor_model_parallel_size: Optional[int] = None
vision_pipeline_model_parallel_size: Optional[int] = None
vision_context_parallel_size: Optional[int] = None
vision_expert_tensor_parallel_size: Optional[int] = None
vision_expert_model_parallel_size: Optional[int] = None
add_vision_module: bool = True
add_language_module: bool = True

def provide(self, pre_process=None, post_process=None, vp_stage=None) -> MCoreGPTModel:
"""Configure and instantiate a Megatron Core GPT model based on this configuration.

139 changes: 96 additions & 43 deletions src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/model.py
@@ -21,6 +21,7 @@
)
from megatron.core.models.vision.vit_layer_specs import get_vit_layer_with_transformer_engine_spec
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.pipeline_parallel.utils import is_pp_last_stage
from megatron.core.process_groups_config import ProcessGroupCollection
from megatron.core.transformer import MegatronModule
from megatron.core.transformer.spec_utils import ModuleSpec
@@ -45,6 +46,7 @@
split_deepstack_embs,
)
from megatron.bridge.models.qwen_vl.modelling_qwen3_vl.vision_model import Qwen3VLVisionModel
from typing import List, Dict


class Qwen3VLModel(MegatronModule):
@@ -83,14 +85,21 @@ def __init__(
pg_collection: ProcessGroupCollection = None,
) -> None:
super().__init__(config=language_transformer_config)
self.vision_embeds = None
self.deepstack_feature_lists = None
self.language_transformer_config = language_transformer_config

language_transformer_layer_spec.submodules.self_attention.module = Qwen3VLSelfAttention

self.pre_process = pre_process
self.post_process = post_process
self.add_encoder = add_encoder
self.add_decoder = add_decoder

if language_transformer_config.use_dist_train:
self.add_encoder = add_encoder
self.add_decoder = add_decoder
assert not (self.add_encoder and self.add_decoder) and (self.add_encoder or self.add_decoder), "add_encoder and add_decoder should not be both True or both False for dist train"
else:
self.add_encoder = self.pre_process
self.add_decoder = True
self.encoder_hidden_state = None
self.vision_model = None
self.language_model = None
@@ -135,41 +144,44 @@ def __init__(
)
megatron_vision_transformer_config.pipeline_model_parallel_size = 1
megatron_vision_transformer_config.first_pipeline_num_layers = None
self.vision_transformer_config = megatron_vision_transformer_config

if self.add_encoder:
self.vision_model = Qwen3VLVisionModel(
megatron_vision_transformer_config,
vision_transformer_layer_spec,
vision_patch_merger_spec,
pre_process=True,
post_process=True,
pg_collection=pg_collection,
)

self.language_model = Qwen3VLGPTModel(
config=language_transformer_config,
transformer_layer_spec=language_transformer_layer_spec,
vocab_size=language_transformer_config.vocab_size,
max_sequence_length=language_transformer_config.language_max_sequence_length,
parallel_output=parallel_output,
position_embedding_type="mrope",
rotary_percent=language_transformer_config.rotary_percent,
pre_process=self.pre_process,
post_process=self.post_process,
rotary_base=language_transformer_config.rotary_base,
fp16_lm_cross_entropy=language_transformer_config.fp16_lm_cross_entropy,
share_embeddings_and_output_weights=language_transformer_config.share_embeddings_and_output_weights,
scatter_embedding_sequence_parallel=False,
pg_collection=pg_collection,
)
if pre_process:
assert len(vision_transformer_config.deepstack_visual_indexes) <= len(
self.language_model.decoder.layers
), (
"the deepstack_visual_embeds should on the first pp-stage of language model",
f"got {len(vision_transformer_config.deepstack_visual_indexes)} deepstack_visual_indexes, "
f" {len(self.language_model.decoder.layers)} language model layers",
if self.add_decoder:
self.language_model = Qwen3VLGPTModel(
config=language_transformer_config,
transformer_layer_spec=language_transformer_layer_spec,
vocab_size=language_transformer_config.vocab_size,
max_sequence_length=language_transformer_config.language_max_sequence_length,
parallel_output=parallel_output,
position_embedding_type="mrope",
rotary_percent=language_transformer_config.rotary_percent,
pre_process=self.pre_process,
post_process=self.post_process,
rotary_base=language_transformer_config.rotary_base,
fp16_lm_cross_entropy=language_transformer_config.fp16_lm_cross_entropy,
share_embeddings_and_output_weights=language_transformer_config.share_embeddings_and_output_weights,
scatter_embedding_sequence_parallel=False,
pg_collection=pg_collection,
)
if pre_process:
assert len(vision_transformer_config.deepstack_visual_indexes) <= len(
self.language_model.decoder.layers
), (
"the deepstack_visual_embeds should on the first pp-stage of language model",
f"got {len(vision_transformer_config.deepstack_visual_indexes)} deepstack_visual_indexes, "
f" {len(self.language_model.decoder.layers)} language model layers",
)

self.share_embeddings_and_output_weights = self.language_model.share_embeddings_and_output_weights
self.share_embeddings_and_output_weights = self.language_model.share_embeddings_and_output_weights

if self.pg_collection.cp.size() > 1:
assert self.config.calculate_per_token_loss, (
@@ -183,17 +195,34 @@ def shared_embedding_or_output_weight(self):
return self.language_model.shared_embedding_or_output_weight()
return None

def set_input_tensor(self, input_tensor) -> None:
# This is usually handled in schedules.py but some inference code still
# gives us non-lists or None
if not isinstance(input_tensor, list):
input_tensor = [input_tensor]
assert len(input_tensor) == 1, "input_tensor should only be length 1 for Qwen3VL"

if self.pre_process:
self.encoder_hidden_state = input_tensor[0]
def set_input_tensor(self, input_tensor: List[Dict[str, torch.Tensor]]):
"""Set input tensor for pipeline parallelism.
"""
if input_tensor is None or len(input_tensor) == 0 or input_tensor[0] is None:
return
if self.config.use_dist_train:
assert isinstance(input_tensor, list), f"Input tensor must be a list, but got {type(input_tensor)}"
assert len(input_tensor) == 1, f"Input tensor must be a list of length 1, but got {len(input_tensor)}"
assert isinstance(input_tensor[0], dict), f"Input tensor[0] must be a dictionary, but got {type(input_tensor[0])}"
input_dict = input_tensor[0]

if 'vision_module' in input_dict:
vision_module_output_tensor = input_dict['vision_module']
num_chunks = len(self.vision_transformer_config.deepstack_visual_indexes) + 1
chunks = torch.chunk(vision_module_output_tensor, chunks=num_chunks, dim=0)
self.vision_embeds = chunks[-1]
self.deepstack_feature_lists = chunks[:-1]
if 'language_module' in input_dict:
self.language_model.set_input_tensor(input_dict['language_module'])
else:
self.language_model.set_input_tensor(input_tensor[0])
if not isinstance(input_tensor, list):
input_tensor = [input_tensor]
assert len(input_tensor) == 1, "input_tensor should only be length 1 for Qwen3VL"

if self.pre_process:
self.encoder_hidden_state = input_tensor[0]
else:
self.language_model.set_input_tensor(input_tensor[0])

def freeze(
self,
@@ -303,6 +332,7 @@ def forward(
)

vision_embeds = None
vision_module_output = None
if vision_grid_thw is not None and vision_grid_thw.shape[0] > 0:
if cp_size > 1 and self.config.vision_dp_when_cp:
if cp_img_num is None:
@@ -323,25 +353,42 @@
)
vision_grid_thw = collapse_thw(vision_grid_thw)
if vision_data.shape[0] > 0:
vision_embeds, deepstack_feature_lists = self.vision_model(
hidden_states=vision_data,
grid_thw=vision_grid_thw,
)
if self.vision_model is not None:
if self.config.use_dist_train:
assert cp_size == 1, "currently, dist train does not support context parallelism for encoder"
num_chunks = self.config.dist_train_vision_chunk_size
chunk_idx = self.pg_collection.dp.rank() % num_chunks
vision_data_chunks = torch.chunk(vision_data, chunks=num_chunks, dim=0)
vision_data = vision_data_chunks[chunk_idx]
vision_grid_thw_chunks = torch.chunk(vision_grid_thw, chunks=num_chunks, dim=0)
vision_grid_thw = vision_grid_thw_chunks[chunk_idx]
vision_embeds, deepstack_feature_lists = self.vision_model(
hidden_states=vision_data,
grid_thw=vision_grid_thw,
)
vision_module_output = deepstack_feature_lists
vision_module_output.append(vision_embeds)
vision_module_output_tensor = torch.cat(vision_module_output, dim=0)
output_vision_module = {'vision_module': vision_module_output_tensor}
else:
vision_embeds = self.vision_embeds
deepstack_feature_lists = self.deepstack_feature_lists
else:
vision_embeds = torch.zeros(
(0, self.language_model.config.hidden_size),
device=vision_data.device,
dtype=torch.bfloat16,
)
deepstack_feature_lists = []
for _ in self.vision_model.config.deepstack_visual_indexes:
for _ in self.vision_transformer_config.deepstack_visual_indexes:
deepstack_feature_lists.append(
torch.zeros(
(0, self.language_model.config.hidden_size),
device=vision_data.device,
dtype=torch.bfloat16,
)
)

if cp_size > 1 and self.config.vision_dp_when_cp:
vision_embeds = AllGatherVisionEmbeddings.apply(
vision_embeds,
Expand All @@ -355,6 +402,9 @@ def forward(
cp_group=self.pg_collection.cp,
)

if self.language_model is None:
# TODO(shifang): need to handle the case when num_images is 0 for some samples.
return output_vision_module
combined_embeddings = self.language_model.embedding(
input_ids=input_ids,
position_ids=None, # NOTE: disable
Expand Down Expand Up @@ -490,5 +540,8 @@ def forward(
**kwargs,
)
torch.cuda.nvtx.range_pop()
if self.config.use_dist_train:
if not is_pp_last_stage(self.pg_collection.pp):
return {'language_module': output}

return output
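To make the DistTrain handoff above easier to follow, here is a small self-contained sketch of how the vision-side output is packed into a single tensor and unpacked again in set_input_tensor (shapes and the number of deepstack levels are illustrative, not taken from the PR):

import torch

# Pretend vision output: three deepstack feature levels plus the final embeddings,
# all with the same (num_tokens, hidden_size) shape.
deepstack_feature_lists = [torch.randn(128, 2048) for _ in range(3)]
vision_embeds = torch.randn(128, 2048)

# Vision stage: concatenate everything into one tensor and ship it as a dict.
vision_module_output_tensor = torch.cat(deepstack_feature_lists + [vision_embeds], dim=0)
output_vision_module = {"vision_module": vision_module_output_tensor}

# Language stage (set_input_tensor): split the tensor back into
# len(deepstack_visual_indexes) + 1 equal chunks; the last chunk is vision_embeds.
num_chunks = len(deepstack_feature_lists) + 1
chunks = torch.chunk(vision_module_output_tensor, chunks=num_chunks, dim=0)
recovered_vision_embeds = chunks[-1]
recovered_deepstack_features = chunks[:-1]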
8 changes: 5 additions & 3 deletions src/megatron/bridge/models/qwen_vl/modelling_qwen3_vl/rope.py
@@ -17,7 +17,6 @@

import torch
import torch.nn as nn
from megatron.core import parallel_state
from megatron.core.models.common.embeddings.rope_utils import (
_apply_rotary_pos_emb_bshd,
get_pos_emb_on_this_cp_rank,
@@ -52,6 +51,7 @@ def __init__(
rotary_interleaved: bool = False,
seq_len_interpolation_factor: Optional[float] = None,
rotary_base: int = 10000,
cp_group: torch.distributed.ProcessGroup = None,
) -> None:
super().__init__()

@@ -69,6 +69,8 @@

# default mrope section is [24, 20, 20]; used when no mrope section is provided
self.mrope_section = [24, 20, 20]
assert cp_group is not None, "cp_group is required"
self.cp_group = cp_group

def apply_interleaved_mrope(self, freqs, mrope_section):
"""Apply interleaved MRoPE to 3D rotary embeddings.
@@ -127,10 +129,10 @@ def forward(

# shape (seq_length, bs, 1, 2 * dim)
emb = emb[..., None, :].transpose(0, 1).contiguous()
if parallel_state.get_context_parallel_world_size() > 1 and not self.is_thd_format:
if self.cp_group.size() > 1 and not self.is_thd_format:
# slice rotary_pos_emb along sequence dimension and select the partition of the current
# CP rank
emb = get_pos_emb_on_this_cp_rank(emb, 0, parallel_state.get_context_parallel_group())
emb = get_pos_emb_on_this_cp_rank(emb, 0, self.cp_group)
return emb


@@ -88,6 +88,7 @@ def __init__(
rotary_interleaved=self.config.rotary_interleaved,
seq_len_interpolation_factor=seq_len_interpolation_factor,
rotary_base=rotary_base,
cp_group=self.pg_collection.cp,
)
self.mrope_section = self.config.mrope_section
assert self.mrope_section is not None, (
@@ -22,7 +22,7 @@
from typing import Optional, Union

import torch
from megatron.core import parallel_state, tensor_parallel
from megatron.core import tensor_parallel
from megatron.core.dist_checkpointing.mapping import ShardedStateDict
from megatron.core.dist_checkpointing.utils import replace_prefix_for_sharding
from megatron.core.enums import Fp8Recipe
@@ -67,6 +67,7 @@ def __init__(
# model_comm_pgs: ModelCommProcessGroups = None,
vp_stage: Optional[int] = None,
patch_merger_spec: ModuleSpec = None,
pg_collection: Optional[torch.distributed.ProcessGroup] = None,
):
assert post_process and pre_process, "not support pp for deepstack_merger_list"
super().__init__(
@@ -77,7 +78,13 @@
post_process=post_process,
# model_comm_pgs=model_comm_pgs,
vp_stage=vp_stage,
pg_collection=pg_collection,
)
self.pg_collection = pg_collection
self.cp_group = pg_collection.cp
self.tp_group = pg_collection.tp
self.pp_group = pg_collection.pp

self.deepstack_visual_indexes = config.deepstack_visual_indexes
self.deepstack_merger_list = nn.ModuleList(
[
@@ -141,7 +148,7 @@ def checkpoint_handler(forward_func):
forward_func,
self.config.distribute_saved_activations,
tensor_parallel.random.get_cuda_rng_tracker,
parallel_state.get_tensor_model_parallel_group(),
self.tp_group,
hidden_states,
attention_mask,
context,
@@ -391,7 +398,7 @@ def sharded_state_dict(
layer_prefix = f"{prefix}layers."
num_layers = self.config.num_layers
for layer in self.layers:
offset = get_transformer_layer_offset(self.config, self.vp_stage)
offset = get_transformer_layer_offset(self.config, self.vp_stage, pp_rank=self.pp_group.rank())

global_layer_offset = layer.layer_number - 1 # self.layer_number starts at 1
state_dict_prefix = f"{layer_prefix}{global_layer_offset - offset}." # module list index in TransformerBlock # pylint: disable=line-too-long
@@ -503,7 +510,7 @@ def checkpoint_handler(forward_func):
forward_func,
self.config.distribute_saved_activations,
tensor_parallel.random.get_cuda_rng_tracker,
parallel_state.get_tensor_model_parallel_group(),
self.tp_group,
hidden_states,
attention_mask,
context,
@@ -67,6 +67,17 @@ def get_vision_model_config(hf_config, megatron_config=None):
add_bias_linear=True,
add_qkv_bias=True,
)
if megatron_config.vision_model_type == "vit_2b":
hf_config.depth = 45
hf_config.hidden_size = 1536
hf_config.num_heads = 16
hf_config.intermediate_size = 8960
hf_config.patch_size = 16
hf_config.spatial_merge_size = 2
if hasattr(hf_config, "head_dim"):
hf_config.head_dim = 96
else:
assert megatron_config.vision_model_type is None, f"only vit_2b is supported, but got {megatron_config.vision_model_type}"

# apply text model config to vision model config
config.recompute_granularity = megatron_config.recompute_granularity
@@ -128,7 +128,7 @@ def __init__(
self.input_size = config.hidden_size
if self.use_postshuffle_norm:
self.input_size = self.hidden_size
tp_group = get_tensor_model_parallel_group_if_none(tp_group, is_expert=False)
self.tp_group = get_tensor_model_parallel_group_if_none(tp_group, is_expert=False)

self.patch_norm = build_module(
submodules.patch_norm,