diff --git a/src/megatron/bridge/models/gemma_vl/gemma3_vl_bridge.py b/src/megatron/bridge/models/gemma_vl/gemma3_vl_bridge.py index 3cb40dd3f5..195f9dd21b 100644 --- a/src/megatron/bridge/models/gemma_vl/gemma3_vl_bridge.py +++ b/src/megatron/bridge/models/gemma_vl/gemma3_vl_bridge.py @@ -30,10 +30,24 @@ from megatron.bridge.models.hf_pretrained.vlm import PreTrainedVLM -@MegatronModelBridge.register_bridge(source=Gemma3ForConditionalGeneration, target=Gemma3VLModel) +@MegatronModelBridge.register_bridge( + source=Gemma3ForConditionalGeneration, + target=Gemma3VLModel, + provider=Gemma3VLModelProvider, + model_type="gemma3_vl", +) class Gemma3VLBridge(MegatronModelBridge): """ Megatron Bridge for Gemma3 VL. + + This bridge handles the conversion between HuggingFace Gemma3ForConditionalGeneration + and Megatron-Core Gemma3VLModel formats, including weight mappings and + configuration translation for vision-language models. + + Example: + >>> from megatron.bridge import AutoBridge + >>> bridge = AutoBridge.from_hf_pretrained("google/gemma-3-4b-it") + >>> provider = bridge.to_megatron_provider() """ def provider_bridge(self, hf_pretrained: PreTrainedVLM) -> Gemma3VLModelProvider: @@ -41,37 +55,34 @@ def provider_bridge(self, hf_pretrained: PreTrainedVLM) -> Gemma3VLModelProvider text_config = hf_config.text_config vision_config = hf_config.vision_config - provider = Gemma3VLModelProvider( - # Text configuration - init_method_std=text_config.initializer_range, - hidden_size=text_config.hidden_size, - ffn_hidden_size=text_config.intermediate_size, - kv_channels=text_config.head_dim, - seq_length=text_config.max_position_embeddings, - num_attention_heads=text_config.num_attention_heads, - num_layers=text_config.num_hidden_layers, - num_query_groups=text_config.num_key_value_heads, - window_size=text_config.sliding_window, - rotary_base=(text_config.rope_local_base_freq, text_config.rope_theta), - layernorm_epsilon=text_config.rms_norm_eps, - vocab_size=text_config.vocab_size, - softmax_scale=1.0 / math.sqrt(text_config.query_pre_attn_scalar), - rope_scaling_factor=text_config.rope_scaling["factor"] if text_config.rope_scaling else 1.0, - # Vision configuration - vision_config=vision_config, - mm_tokens_per_image=hf_config.mm_tokens_per_image, - # VL-specific token IDs - bos_token_id=getattr(hf_config, "bos_token_id", 0), - eos_token_id=getattr(hf_config, "eos_token_id", 1), - vision_start_token_id=getattr(hf_config, "vision_start_token_id", 255999), - vision_end_token_id=getattr(hf_config, "vision_end_token_id", 256000), - image_token_id=getattr(hf_config, "image_token_id", 151655), - # Precision configuration - fp16=(self.dtype_from_hf(hf_config, default=torch.float32) == torch.float16), - bf16=(self.dtype_from_hf(hf_config, default=torch.float32) == torch.bfloat16), - params_dtype=self.dtype_from_hf(hf_config, default=torch.float32), - ) + # Use base class helper for common config conversion + provider_kwargs = self.hf_config_to_provider_kwargs(text_config) + provider = Gemma3VLModelProvider(**provider_kwargs) + + # Gemma3-specific features not in CONFIG_MAPPING + provider.window_size = text_config.sliding_window + provider.rotary_base = (text_config.rope_local_base_freq, text_config.rope_theta) + provider.softmax_scale = 1.0 / math.sqrt(text_config.query_pre_attn_scalar) + provider.rope_scaling_factor = text_config.rope_scaling["factor"] if text_config.rope_scaling else 1.0 + + # Override dtype and vocab settings to match baseline + provider.bf16 = True + provider.params_dtype = 
torch.bfloat16 + provider.autocast_dtype = torch.bfloat16 + provider.make_vocab_size_divisible_by = 128 + + # Vision configuration + provider.vision_config = vision_config + provider.mm_tokens_per_image = hf_config.mm_tokens_per_image + + # VL-specific token IDs + provider.bos_token_id = getattr(hf_config, "bos_token_id", 0) + provider.eos_token_id = getattr(hf_config, "eos_token_id", 1) + provider.vision_start_token_id = getattr(hf_config, "vision_start_token_id", 255999) + provider.vision_end_token_id = getattr(hf_config, "vision_end_token_id", 256000) + provider.image_token_id = getattr(hf_config, "image_token_id", 262144) + # Vision projector configuration provider.vision_projector_config.input_size = vision_config.hidden_size provider.vision_projector_config.hidden_size = text_config.hidden_size diff --git a/src/megatron/bridge/models/nemotron_vl/nemotron_vl_bridge.py b/src/megatron/bridge/models/nemotron_vl/nemotron_vl_bridge.py index ff70f7bd55..c750f67dd4 100644 --- a/src/megatron/bridge/models/nemotron_vl/nemotron_vl_bridge.py +++ b/src/megatron/bridge/models/nemotron_vl/nemotron_vl_bridge.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import torch +from megatron.core.activations import squared_relu from megatron.bridge.models import ColumnParallelMapping, RowParallelMapping from megatron.bridge.models.conversion.mapping_registry import MegatronMappingRegistry @@ -29,7 +29,12 @@ from megatron.bridge.models.nemotron_vl.nemotron_vl_provider import NemotronNano12Bv2VLModelProvider -@MegatronModelBridge.register_bridge(source="NemotronH_Nano_VL_V2", target=NemotronVLModel) +@MegatronModelBridge.register_bridge( + source="NemotronH_Nano_VL_V2", + target=NemotronVLModel, + provider=NemotronNano12Bv2VLModelProvider, + model_type="nemotron_vl", +) class NemotronVLBridge(MegatronModelBridge): """Conversion utilities between HF Nemotron-VL and Megatron-Core format.""" @@ -39,25 +44,26 @@ class NemotronVLBridge(MegatronModelBridge): def provider_bridge(self, hf_pretrained: PreTrainedVLM) -> NemotronNano12Bv2VLModelProvider: # type: ignore[override] hf_config = hf_pretrained.config + llm_config = hf_config.llm_config + + # Use base class helper for common config mapping + provider_kwargs = self.hf_config_to_provider_kwargs(llm_config) + + # Handle vocab size divisibility + provider_kwargs["make_vocab_size_divisible_by"] = self.make_vocab_size_divisible_by(llm_config.vocab_size) + + provider = NemotronNano12Bv2VLModelProvider(**provider_kwargs) + + # Nemotron VL-specific settings + # Note: Most defaults come from the provider class hierarchy (NemotronNano12Bv2Provider) + provider.scatter_embedding_sequence_parallel = False + provider.attention_softmax_in_fp32 = True + + # Override fields that should use NemotronH provider's specialized defaults + # instead of HF config values + provider.activation_func = squared_relu # Nemotron uses squared_relu, not HF's hidden_act + provider.autocast_dtype = None # Not set in original code - provider = NemotronNano12Bv2VLModelProvider( - num_layers=hf_config.llm_config.num_hidden_layers, - hidden_size=hf_config.llm_config.hidden_size, - ffn_hidden_size=hf_config.llm_config.intermediate_size, - num_attention_heads=hf_config.llm_config.num_attention_heads, - num_query_groups=getattr( - hf_config.llm_config, "num_key_value_heads", hf_config.llm_config.num_attention_heads // 2 - ), - init_method_std=hf_config.llm_config.initializer_range, - 
layernorm_epsilon=getattr(hf_config.llm_config, "layer_norm_epsilon", 1e-5), - make_vocab_size_divisible_by=self.make_vocab_size_divisible_by(hf_config.llm_config.vocab_size), - share_embeddings_and_output_weights=getattr(hf_config.llm_config, "tie_word_embeddings", False), - vocab_size=hf_config.llm_config.vocab_size, - seq_length=hf_config.llm_config.max_position_embeddings, - fp16=(self.dtype_from_hf(hf_config, default=torch.float32) == torch.float16), - bf16=(self.dtype_from_hf(hf_config, default=torch.float32) == torch.bfloat16), - params_dtype=self.dtype_from_hf(hf_config, default=torch.float32), - ) return provider # ------------------------------------------------------------------ diff --git a/src/megatron/bridge/models/qwen_vl/__init__.py b/src/megatron/bridge/models/qwen_vl/__init__.py index 14f0232a63..8495eaef29 100644 --- a/src/megatron/bridge/models/qwen_vl/__init__.py +++ b/src/megatron/bridge/models/qwen_vl/__init__.py @@ -20,7 +20,7 @@ Qwen3VLMoEModelProvider, ) from megatron.bridge.models.qwen_vl.qwen25_vl_bridge import Qwen25VLBridge -from megatron.bridge.models.qwen_vl.qwen_vl_provider import ( +from megatron.bridge.models.qwen_vl.qwen25_vl_provider import ( Qwen25VLModelProvider, ) diff --git a/src/megatron/bridge/models/qwen_vl/qwen25_vl_bridge.py b/src/megatron/bridge/models/qwen_vl/qwen25_vl_bridge.py index ebcbda729e..a659908e0a 100644 --- a/src/megatron/bridge/models/qwen_vl/qwen25_vl_bridge.py +++ b/src/megatron/bridge/models/qwen_vl/qwen25_vl_bridge.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import torch from transformers import Qwen2_5_VLForConditionalGeneration from megatron.bridge.models.conversion.mapping_registry import MegatronMappingRegistry @@ -25,10 +24,15 @@ ) from megatron.bridge.models.hf_pretrained.vlm import PreTrainedVLM from megatron.bridge.models.qwen_vl.modeling_qwen25_vl import Qwen25VLModel -from megatron.bridge.models.qwen_vl.qwen_vl_provider import Qwen25VLModelProvider +from megatron.bridge.models.qwen_vl.qwen25_vl_provider import Qwen25VLModelProvider -@MegatronModelBridge.register_bridge(source=Qwen2_5_VLForConditionalGeneration, target=Qwen25VLModel) +@MegatronModelBridge.register_bridge( + source=Qwen2_5_VLForConditionalGeneration, + target=Qwen25VLModel, + provider=Qwen25VLModelProvider, + model_type="qwen2_5_vl", +) class Qwen25VLBridge(MegatronModelBridge): """ Megatron Bridge for Qwen2.5-VL Conditional Generation. 
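Reviewer note: the refactored bridges in this diff all funnel HF configs through `hf_config_to_provider_kwargs`, which is not shown here. The sketch below is a guess at its CONFIG_MAPPING-driven shape, with entries inferred from the bridges and test updates in this patch (e.g. `rms_norm_eps -> layernorm_epsilon`); it is illustrative, not the actual implementation. It also explains why the test mocks now pin fields like `q_lora_rank` to `None`: a bare `Mock` would otherwise return a `Mock` object for every auto-mapped attribute.

```python
# Hypothetical sketch of the MegatronModelBridge helper used throughout this
# diff. The mapping entries are assumptions inferred from this patch.
CONFIG_MAPPING = {
    "num_hidden_layers": "num_layers",
    "hidden_size": "hidden_size",
    "intermediate_size": "ffn_hidden_size",
    "num_attention_heads": "num_attention_heads",
    "num_key_value_heads": "num_query_groups",
    "initializer_range": "init_method_std",
    "rms_norm_eps": "layernorm_epsilon",  # see test_nemotron_vl_bridge.py below
    "vocab_size": "vocab_size",
    "max_position_embeddings": "seq_length",
    "tie_word_embeddings": "share_embeddings_and_output_weights",
    "rope_theta": "rotary_base",
}


def hf_config_to_provider_kwargs(hf_config) -> dict:
    """Translate HF config fields into Megatron provider kwargs (sketch only)."""
    kwargs = {}
    for hf_name, provider_name in CONFIG_MAPPING.items():
        if hasattr(hf_config, hf_name):
            kwargs[provider_name] = getattr(hf_config, hf_name)
    return kwargs
```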
@@ -45,35 +49,28 @@ class Qwen25VLBridge(MegatronModelBridge): def provider_bridge(self, hf_pretrained: PreTrainedVLM) -> Qwen25VLModelProvider: hf_config = hf_pretrained.config + text_config = hf_config # Qwen2.5-VL has text config fields directly on main config - provider = Qwen25VLModelProvider( - num_layers=hf_config.num_hidden_layers, - hidden_size=hf_config.hidden_size, - ffn_hidden_size=hf_config.intermediate_size, - num_attention_heads=hf_config.num_attention_heads, - num_query_groups=hf_config.num_key_value_heads, - init_method_std=hf_config.initializer_range, - layernorm_epsilon=hf_config.rms_norm_eps, - gated_linear_unit=True, - make_vocab_size_divisible_by=self.make_vocab_size_divisible_by(hf_config.vocab_size), - rotary_base=hf_config.rope_theta, - share_embeddings_and_output_weights=getattr(hf_config, "tie_word_embeddings", False), - vocab_size=hf_config.vocab_size, - seq_length=hf_config.max_position_embeddings, - fp16=(self.dtype_from_hf(hf_config, default=torch.float32) == torch.float16), - bf16=(self.dtype_from_hf(hf_config, default=torch.float32) == torch.bfloat16), - params_dtype=self.dtype_from_hf(hf_config, default=torch.float32), - add_qkv_bias=True, # Qwen2 has bias in QKV projections - vision_config=hf_config.vision_config, - # VL-specific token IDs - bos_token_id=getattr(hf_config, "bos_token_id", 151643), - eos_token_id=getattr(hf_config, "eos_token_id", 151645), - vision_start_token_id=getattr(hf_config, "vision_start_token_id", 151652), - vision_end_token_id=getattr(hf_config, "vision_end_token_id", 151653), - vision_token_id=getattr(hf_config, "vision_token_id", 151654), - image_token_id=getattr(hf_config, "image_token_id", 151655), - video_token_id=getattr(hf_config, "video_token_id", 151656), - ) + provider_kwargs = self.hf_config_to_provider_kwargs(text_config) + provider = Qwen25VLModelProvider(**provider_kwargs) + + # Qwen2-specific settings + provider.normalization = "RMSNorm" + provider.gated_linear_unit = True + provider.add_qkv_bias = True + provider.add_bias_linear = False + provider.hidden_dropout = 0.0 + + # VL-specific overrides + provider.position_embedding_type = "mrope" + provider.vision_config = hf_config.vision_config + provider.bos_token_id = getattr(hf_config, "bos_token_id", 151643) + provider.eos_token_id = getattr(hf_config, "eos_token_id", 151645) + provider.vision_start_token_id = getattr(hf_config, "vision_start_token_id", 151652) + provider.vision_end_token_id = getattr(hf_config, "vision_end_token_id", 151653) + provider.vision_token_id = getattr(hf_config, "vision_token_id", 151654) + provider.image_token_id = getattr(hf_config, "image_token_id", 151655) + provider.video_token_id = getattr(hf_config, "video_token_id", 151656) return provider diff --git a/src/megatron/bridge/models/qwen_vl/qwen_vl_provider.py b/src/megatron/bridge/models/qwen_vl/qwen25_vl_provider.py similarity index 82% rename from src/megatron/bridge/models/qwen_vl/qwen_vl_provider.py rename to src/megatron/bridge/models/qwen_vl/qwen25_vl_provider.py index 9d6d61af35..1a0787cb34 100644 --- a/src/megatron/bridge/models/qwen_vl/qwen_vl_provider.py +++ b/src/megatron/bridge/models/qwen_vl/qwen25_vl_provider.py @@ -18,25 +18,16 @@ from megatron.core.models.gpt import GPTModel as MCoreGPTModel from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLVisionConfig -from megatron.bridge.models import ( - Qwen2ModelProvider, -) - -from .modeling_qwen25_vl import Qwen25VLModel - - -# 
============================================================================= -# Qwen 2.5 VL Model Providers -# ============================================================================= +from megatron.bridge.models.gpt_provider import GPTModelProvider +from megatron.bridge.models.qwen_vl.modeling_qwen25_vl import Qwen25VLModel @dataclass -class Qwen25VLModelProvider(Qwen2ModelProvider): +class Qwen25VLModelProvider(GPTModelProvider): """ Base model provider for Qwen 2.5 VL Models. """ - # Language configuration inherited from Qwen25ModelProvider3B # VL models shouldn't scatter embeddings across sequence parallel regions because # the vision embeddings are going to be inserted into the language embeddings. scatter_embedding_sequence_parallel: bool = False @@ -74,4 +65,4 @@ def provide(self, pre_process=None, post_process=None, vp_stage=None) -> Qwen25V return model def provide_language_model(self, pre_process=None, post_process=None, vp_stage=None) -> MCoreGPTModel: - return super().provide(pre_process=pre_process, post_process=post_process, vp_stage=vp_stage) + return GPTModelProvider.provide(self, pre_process=pre_process, post_process=post_process, vp_stage=vp_stage) diff --git a/src/megatron/bridge/models/qwen_vl/qwen3_vl_bridge.py b/src/megatron/bridge/models/qwen_vl/qwen3_vl_bridge.py index 3a8f7190ef..9a414653fc 100644 --- a/src/megatron/bridge/models/qwen_vl/qwen3_vl_bridge.py +++ b/src/megatron/bridge/models/qwen_vl/qwen3_vl_bridge.py @@ -35,7 +35,12 @@ from megatron.bridge.utils.common_utils import extract_expert_number_from_param -@MegatronModelBridge.register_bridge(source=Qwen3VLForConditionalGeneration, target=Qwen3VLModel) +@MegatronModelBridge.register_bridge( + source=Qwen3VLForConditionalGeneration, + target=Qwen3VLModel, + provider=Qwen3VLModelProvider, + model_type="qwen3_vl", +) class Qwen3VLBridge(MegatronModelBridge): """ Megatron Bridge for Qwen3-VL Conditional Generation. 
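Reviewer note: the providers in this diff now call `GPTModelProvider.provide(self, ...)` instead of `super().provide(...)` inside `provide_language_model`. A minimal, self-contained illustration of why the explicit base call is used (class names here are invented for the example, not taken from the codebase):

```python
class BaseProvider:
    def provide(self) -> str:
        return "language-only model"


class VLProvider(BaseProvider):
    def provide(self) -> str:
        # VL providers override provide() to build the combined
        # vision+language model, so the language-only path below must
        # bypass this override.
        return "vision+language model"

    def provide_language_model(self) -> str:
        # Explicit base call: stays pinned to BaseProvider.provide even if
        # another class is later inserted into the inheritance chain.
        return BaseProvider.provide(self)


assert VLProvider().provide() == "vision+language model"
assert VLProvider().provide_language_model() == "language-only model"
```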
@@ -70,51 +75,33 @@ def provider_bridge(self, hf_pretrained: PreTrainedVLM) -> Qwen3VLModelProvider: hf_config = hf_pretrained.config text_config = hf_config.text_config - # Get the model dtype from text config - model_dtype = self.dtype_from_hf(text_config, default=torch.float32) + provider_kwargs = self.hf_config_to_provider_kwargs(text_config) - # Set vision config dtype to match the language model dtype - # This ensures vision model parameters are initialized in the same dtype vision_config = hf_config.vision_config - vision_config.torch_dtype = model_dtype - - # Create the provider with text model configuration - provider = Qwen3VLModelProvider( - # Language model configuration from text_config - num_layers=text_config.num_hidden_layers, - hidden_size=text_config.hidden_size, - ffn_hidden_size=text_config.intermediate_size, - num_attention_heads=text_config.num_attention_heads, - num_query_groups=text_config.num_key_value_heads, # GQA configuration - head_dim=text_config.head_dim, - init_method_std=text_config.initializer_range, - layernorm_epsilon=text_config.rms_norm_eps, - gated_linear_unit=True, # Qwen3 uses gated linear units - make_vocab_size_divisible_by=self.make_vocab_size_divisible_by(text_config.vocab_size), - rotary_base=text_config.rope_theta, - share_embeddings_and_output_weights=getattr(text_config, "tie_word_embeddings", False), - vocab_size=text_config.vocab_size, - seq_length=text_config.max_position_embeddings, - fp16=(model_dtype == torch.float16), - bf16=(model_dtype == torch.bfloat16), - params_dtype=model_dtype, - # Qwen3 specific parameters - add_qkv_bias=text_config.attention_bias, # Qwen3 can have bias in QKV - qk_layernorm=True, # Qwen3 uses QK layernorm - # Vision configuration - vision_config=vision_config, - # Store the original HF text config for RoPE initialization - hf_text_config=text_config, - # Vision-Language token IDs - bos_token_id=getattr(text_config, "bos_token_id", 151643), - eos_token_id=getattr(text_config, "eos_token_id", 151645), - vision_start_token_id=getattr(hf_config, "vision_start_token_id", 151652), - vision_end_token_id=getattr(hf_config, "vision_end_token_id", 151653), - image_token_id=getattr(hf_config, "image_token_id", 151655), - video_token_id=getattr(hf_config, "video_token_id", 151656), - # MRoPE configuration for multimodal position embeddings - mrope_section=text_config.rope_scaling.get("mrope_section", [24, 20, 20]), - ) + vision_config.torch_dtype = provider_kwargs.get("params_dtype", torch.float32) + + provider = Qwen3VLModelProvider(**provider_kwargs) + + # Qwen3-specific settings + provider.normalization = "RMSNorm" + provider.gated_linear_unit = True + provider.add_qkv_bias = text_config.attention_bias + provider.add_bias_linear = False + provider.qk_layernorm = True + provider.hidden_dropout = 0.0 + + # VL-specific overrides + provider.position_embedding_type = "mrope" + provider.vision_config = vision_config + provider.hf_text_config = text_config + provider.head_dim = text_config.head_dim + provider.bos_token_id = getattr(text_config, "bos_token_id", 151643) + provider.eos_token_id = getattr(text_config, "eos_token_id", 151645) + provider.vision_start_token_id = getattr(hf_config, "vision_start_token_id", 151652) + provider.vision_end_token_id = getattr(hf_config, "vision_end_token_id", 151653) + provider.image_token_id = getattr(hf_config, "image_token_id", 151655) + provider.video_token_id = getattr(hf_config, "video_token_id", 151656) + provider.mrope_section = text_config.rope_scaling.get("mrope_section", [24, 
20, 20]) # TODO: setattr use_hf_vision_model to bridge instance in a dangerous way, maybe optimize it later. setattr(self, "use_hf_vision_model", provider.use_hf_vision_model) @@ -122,36 +109,14 @@ def provider_bridge(self, hf_pretrained: PreTrainedVLM) -> Qwen3VLModelProvider: return provider def mapping_registry(self) -> MegatronMappingRegistry: - """ - Return MegatronMappingRegistry containing parameter mappings from Megatron to HF format. - - The mappings are organized into: - 1. Simple 1:1 mappings for embeddings, layer norms, and output layers - 2. Vision model mappings (replicated without modification) - 3. QKV mappings that combine separate Q, K, V matrices - 4. Gated MLP mappings that combine gate and up projections - 5. Deepstack visual merger mappings - - Returns: - MegatronMappingRegistry with all parameter mappings - """ - # Dictionary maps Megatron parameter names -> HF parameter names - # Based on yan-mbridge weight mappings in __init__.py - - # Language model direct mappings param_mappings = { - # Embeddings and output layers "language_model.embedding.word_embeddings.weight": "model.language_model.embed_tokens.weight", "language_model.output_layer.weight": "lm_head.weight", "language_model.decoder.final_layernorm.weight": "model.language_model.norm.weight", - # Layer normalization for attention and MLP "language_model.decoder.layers.*.self_attention.linear_qkv.layer_norm_weight": "model.language_model.layers.*.input_layernorm.weight", "language_model.decoder.layers.*.mlp.linear_fc1.layer_norm_weight": "model.language_model.layers.*.post_attention_layernorm.weight", - # Attention output projection "language_model.decoder.layers.*.self_attention.linear_proj.weight": "model.language_model.layers.*.self_attn.o_proj.weight", - # MLP output projection "language_model.decoder.layers.*.mlp.linear_fc2.weight": "model.language_model.layers.*.mlp.down_proj.weight", - # QK layernorm weights (Qwen3 specific) "language_model.decoder.layers.*.self_attention.q_layernorm.weight": "model.language_model.layers.*.self_attn.q_norm.weight", "language_model.decoder.layers.*.self_attention.k_layernorm.weight": "model.language_model.layers.*.self_attn.k_norm.weight", # vision module attn @@ -233,7 +198,12 @@ def mapping_registry(self) -> MegatronMappingRegistry: return MegatronMappingRegistry(*mapping_list) -@MegatronModelBridge.register_bridge(source=Qwen3VLMoeForConditionalGeneration, target=Qwen3VLModel) +@MegatronModelBridge.register_bridge( + source=Qwen3VLMoeForConditionalGeneration, + target=Qwen3VLModel, + provider=Qwen3VLMoEModelProvider, + model_type="qwen3_vl_moe", +) class Qwen3VLMoEBridge(MegatronModelBridge): """ Megatron Bridge for Qwen3-VL MoE (Mixture of Experts) Conditional Generation. @@ -259,71 +229,54 @@ class Qwen3VLMoEBridge(MegatronModelBridge): def __init__(self): super().__init__() - # Cache expert shards during HF export until all ranks contribute. self.hf_weights_cache: Dict[str, Dict[int, torch.Tensor]] = {} def provider_bridge(self, hf_pretrained: PreTrainedVLM) -> Qwen3VLMoEModelProvider: - """ - Create a Qwen3VLMoEModelProvider from a HuggingFace pretrained MoE model. 
- - Args: - hf_pretrained: HuggingFace pretrained VLM MoE model - - Returns: - Qwen3VLMoEModelProvider configured with the HF MoE model's parameters - """ hf_config = hf_pretrained.config text_config = hf_config.text_config - # Get the model dtype from text config - model_dtype = self.dtype_from_hf(text_config, default=torch.float32) + provider_kwargs = self.hf_config_to_provider_kwargs(text_config) - # Set vision config dtype to match the language model dtype - # This ensures vision model parameters are initialized in the same dtype vision_config = hf_config.vision_config - vision_config.torch_dtype = model_dtype - - provider = Qwen3VLMoEModelProvider( - num_layers=text_config.num_hidden_layers, - hidden_size=text_config.hidden_size, - ffn_hidden_size=text_config.intermediate_size, # Dense FFN size (for non-MoE layers if any) - moe_ffn_hidden_size=text_config.moe_intermediate_size, # Expert FFN size - num_attention_heads=text_config.num_attention_heads, - num_query_groups=text_config.num_key_value_heads, # GQA configuration - head_dim=getattr(text_config, "head_dim", text_config.hidden_size // text_config.num_attention_heads), - init_method_std=text_config.initializer_range, - layernorm_epsilon=text_config.rms_norm_eps, - gated_linear_unit=True, # Qwen3 MoE uses gated linear units - make_vocab_size_divisible_by=self.make_vocab_size_divisible_by(text_config.vocab_size), - rotary_base=getattr(text_config, "rope_theta", 5000000.0), # Default Qwen3 rope theta - share_embeddings_and_output_weights=getattr(text_config, "tie_word_embeddings", False), - vocab_size=text_config.vocab_size, - seq_length=text_config.max_position_embeddings, - fp16=(model_dtype == torch.float16), - bf16=(model_dtype == torch.bfloat16), - params_dtype=model_dtype, - # Qwen3 specific parameters - add_qkv_bias=text_config.attention_bias, # Qwen3 can have bias in QKV - qk_layernorm=True, # Qwen3 uses QK layernorm - # MoE specific parameters - num_moe_experts=text_config.num_experts, - moe_router_topk=text_config.num_experts_per_tok, - decoder_sparse_step=getattr(text_config, "decoder_sparse_step", 1), # Default to every layer being MoE - mlp_only_layers=getattr(text_config, "mlp_only_layers", []), # Default to all layers using MoE - # Vision configuration - vision_config=vision_config, - # Store the original HF text config for RoPE initialization - hf_text_config=text_config, - # Vision-Language token IDs - bos_token_id=getattr(text_config, "bos_token_id", 151643), - eos_token_id=getattr(text_config, "eos_token_id", 151645), - vision_start_token_id=getattr(hf_config, "vision_start_token_id", 151652), - vision_end_token_id=getattr(hf_config, "vision_end_token_id", 151653), - image_token_id=getattr(hf_config, "image_token_id", 151655), - video_token_id=getattr(hf_config, "video_token_id", 151656), - # MRoPE configuration for multimodal position embeddings - mrope_section=getattr(text_config, "rope_scaling", {}).get("mrope_section", [24, 20, 20]), + vision_config.torch_dtype = provider_kwargs.get("params_dtype", torch.float32) + + provider = Qwen3VLMoEModelProvider(**provider_kwargs) + + # Qwen3 MoE-specific settings + provider.normalization = "RMSNorm" + provider.gated_linear_unit = True + provider.add_qkv_bias = text_config.attention_bias + provider.add_bias_linear = False + provider.qk_layernorm = True + provider.hidden_dropout = 0.0 + + # MoE specific parameters + provider.moe_ffn_hidden_size = text_config.moe_intermediate_size + provider.num_moe_experts = text_config.num_experts + provider.moe_router_topk = 
text_config.num_experts_per_tok
+        provider.decoder_sparse_step = getattr(text_config, "decoder_sparse_step", 1)
+        provider.mlp_only_layers = getattr(text_config, "mlp_only_layers", [])
+        provider.moe_grouped_gemm = True
+        provider.moe_router_load_balancing_type = "aux_loss"
+        provider.moe_aux_loss_coeff = 1e-3
+        provider.moe_router_pre_softmax = False
+        provider.moe_token_dispatcher_type = "alltoall"
+        provider.moe_permute_fusion = True
+
+        # VL-specific overrides
+        provider.position_embedding_type = "mrope"
+        provider.vision_config = vision_config
+        provider.hf_text_config = text_config
+        provider.head_dim = getattr(
+            text_config, "head_dim", text_config.hidden_size // text_config.num_attention_heads
         )
+        provider.bos_token_id = getattr(text_config, "bos_token_id", 151643)
+        provider.eos_token_id = getattr(text_config, "eos_token_id", 151645)
+        provider.vision_start_token_id = getattr(hf_config, "vision_start_token_id", 151652)
+        provider.vision_end_token_id = getattr(hf_config, "vision_end_token_id", 151653)
+        provider.image_token_id = getattr(hf_config, "image_token_id", 151655)
+        provider.video_token_id = getattr(hf_config, "video_token_id", 151656)
+        provider.mrope_section = getattr(text_config, "rope_scaling", {}).get("mrope_section", [24, 20, 20])

         return provider

diff --git a/src/megatron/bridge/models/qwen_vl/qwen3_vl_provider.py b/src/megatron/bridge/models/qwen_vl/qwen3_vl_provider.py
index 2ce17d4ac7..c7371f8670 100644
--- a/src/megatron/bridge/models/qwen_vl/qwen3_vl_provider.py
+++ b/src/megatron/bridge/models/qwen_vl/qwen3_vl_provider.py
@@ -28,12 +28,12 @@
 from transformers.models.qwen3_vl.configuration_qwen3_vl import Qwen3VLTextConfig, Qwen3VLVisionConfig
 from transformers.models.qwen3_vl_moe.configuration_qwen3_vl_moe import Qwen3VLMoeTextConfig

-from megatron.bridge.models import Qwen3ModelProvider, Qwen3MoEModelProvider
+from megatron.bridge.models.gpt_provider import GPTModelProvider
 from megatron.bridge.models.qwen_vl.modelling_qwen3_vl.model import Qwen3VLModel


 @dataclass
-class Qwen3VLModelProvider(Qwen3ModelProvider):
+class Qwen3VLModelProvider(GPTModelProvider):
     """
     Base model provider for Qwen 3 VL Models.
-    Inherits language model configuration from Qwen3ModelProvider.
+    Inherits the base language model configuration from GPTModelProvider.
@@ -42,9 +42,7 @@ class Qwen3VLModelProvider(GPTModelProvider):
     Default value of 8 is used for GQA (Grouped Query Attention).
""" - head_dim: int = 128 - hidden_size: int = 2048 - + # Fields from Qwen3VLTransformerConfig language_max_sequence_length: int = 2048 patch_size: int = 16 temporal_patch_size: int = 2 @@ -61,8 +59,7 @@ class Qwen3VLModelProvider(Qwen3ModelProvider): vision_config: Qwen3VLVisionConfig = field(default_factory=lambda: Qwen3VLVisionConfig()) hf_text_config: Optional[Qwen3VLTextConfig] = None - - # Vision-specific token IDs matching Qwen3VL configuration + # Vision-Language token IDs # Based on https://huggingface.co/Qwen/Qwen3-VL-8B-Instruct/blob/main/config.json # Token ID for image placeholder in text image_token_id: int = 151655 @@ -94,11 +91,8 @@ class Qwen3VLModelProvider(Qwen3ModelProvider): scatter_embedding_sequence_parallel: bool = False # Freeze options for fine-tuning scenarios - # Whether to freeze language model weights freeze_language_model: bool = False - # Whether to freeze vision encoder weights freeze_vision_model: bool = False - # Whether to freeze vision-to-language projection weights freeze_vision_projection: bool = False sequence_parallel: bool = False @@ -111,12 +105,9 @@ class Qwen3VLModelProvider(Qwen3ModelProvider): vision_dp_when_cp: bool = False - def provide(self, pre_process=None, post_process=None, vp_stage=None): - """ - Provide a Qwen3VL model instance with vision and language components. - """ + def provide(self, pre_process=None, post_process=None, vp_stage=None) -> Qwen3VLModel: + """Provide a Qwen3 VL model instance with vision and language components.""" language_transformer_config = self - hf_vision_config = self.vision_config # Spec for the Qwen3VLTransformerLayer @@ -147,35 +138,32 @@ def provide(self, pre_process=None, post_process=None, vp_stage=None): return model def provide_language_model(self, pre_process=None, post_process=None, vp_stage=None) -> MCoreGPTModel: - """ - Provide just the language model component without vision. - - Args: - pre_process: Whether this is the first stage in pipeline parallelism - post_process: Whether this is the last stage in pipeline parallelism - vp_stage: Virtual pipeline stage number - - Returns: - MCoreGPTModel instance (language model only) - """ - # Use parent class to create standard language model - return super().provide(pre_process=pre_process, post_process=post_process, vp_stage=vp_stage) + """Provide just the language model component without vision.""" + # Use GPTModelProvider's provide method to create standard language model + return GPTModelProvider.provide(self, pre_process=pre_process, post_process=post_process, vp_stage=vp_stage) @dataclass -class Qwen3VLMoEModelProvider(Qwen3MoEModelProvider): +class Qwen3VLMoEModelProvider(GPTModelProvider): """ - Base model provider for Qwen 3 VL MoE Models. - Inherits language model MoE configuration from Qwen3MoEModelProvider. + Base model provider for Qwen 3 VL MoE (Mixture of Experts) Models. + + This provider inherits directly from GPTModelProvider following the + provider_bridge refactoring pattern. It includes: + - Qwen3 MoE-specific LLM defaults (RMSNorm, gated linear unit, QK layernorm, MoE config) + - VL-specific configurations (vision_config, token IDs, mrope) - Key MoE Parameters (inherited from Qwen3MoEModelProvider): + The Qwen3VLMoEBridge leverages Qwen3MoEBridge for HF config mapping, + then applies VL-specific overrides. 
+ + Key MoE Parameters: - num_moe_experts: Number of total experts (default 128) - moe_router_topk: Number of experts selected per token (default 8) - moe_router_load_balancing_type: Load balancing strategy (default "aux_loss") - moe_aux_loss_coeff: Auxiliary loss coefficient (default 1e-3) - moe_grouped_gemm: Use grouped GEMM for efficiency (default True) - Note: num_query_groups in parent class corresponds to num_key_value_heads in HF config. + Note: num_query_groups corresponds to num_key_value_heads in HF config. """ # Vision configuration using the transformers Qwen3VLVisionConfig @@ -245,25 +233,20 @@ class Qwen3VLMoEModelProvider(Qwen3MoEModelProvider): decoder_sparse_step: int = 1 # Every layer is MoE by default # Freeze options for fine-tuning scenarios - # Whether to freeze language model weights freeze_language_model: bool = True - # Whether to freeze vision encoder weights freeze_vision_model: bool = True - # Whether to freeze vision-to-language projection weights freeze_vision_projection: bool = False language_max_sequence_length: int = 2048 - # QK layernorm is already True in Qwen3MoEModelProvider, no need to redefine - - # These are typically set in the base class but documented here for clarity - persist_layer_norm: bool = True # Persist layer norm for efficiency - bias_activation_fusion: bool = True # Fuse bias and activation - bias_dropout_fusion: bool = True # Fuse bias and dropout + # Performance optimizations + persist_layer_norm: bool = True + bias_activation_fusion: bool = True + bias_dropout_fusion: bool = True masked_softmax_fusion: bool = False # Don't fuse masked softmax (Qwen specific) - deallocate_pipeline_outputs: bool = True # Deallocate pipeline outputs to save memory - async_tensor_model_parallel_allreduce: bool = True # Async tensor parallel - distribute_saved_activations: bool = False # Don't distribute saved activations - cp_comm_type: str = "p2p" # Point-to-point communication for context parallel + deallocate_pipeline_outputs: bool = True + async_tensor_model_parallel_allreduce: bool = True + distribute_saved_activations: bool = False + cp_comm_type: str = "p2p" use_hf_vision_model: bool = False vision_dp_when_cp: bool = False @@ -271,17 +254,13 @@ class Qwen3VLMoEModelProvider(Qwen3MoEModelProvider): def finalize(self) -> None: if self.tensor_model_parallel_size > 1: self.sequence_parallel = True - super().finalize() - def provide(self, pre_process=None, post_process=None, vp_stage=None): - """ - Provide a Qwen3VL MoE model instance with vision and language components. 
- """ + def provide(self, pre_process=None, post_process=None, vp_stage=None) -> Qwen3VLModel: + """Provide a Qwen3 VL MoE model instance with vision and language components.""" language_transformer_config = self - - # handle vision config inside model initialization hf_vision_config = self.vision_config + language_transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec( num_experts=self.num_moe_experts, moe_grouped_gemm=True, @@ -289,7 +268,7 @@ def provide(self, pre_process=None, post_process=None, vp_stage=None): fp8=False, ) - # reuse Qwen3VLModel for MoE model but replace the language model with MoE language model + # Reuse Qwen3VLModel for MoE model but replace the language model with MoE language model model = Qwen3VLModel( language_transformer_config=language_transformer_config, language_transformer_layer_spec=language_transformer_layer_spec, @@ -310,16 +289,6 @@ def provide(self, pre_process=None, post_process=None, vp_stage=None): return model def provide_language_model(self, pre_process=None, post_process=None, vp_stage=None) -> MCoreGPTModel: - """ - Provide just the language MoE model component without vision. - - Args: - pre_process: Whether this is the first stage in pipeline parallelism - post_process: Whether this is the last stage in pipeline parallelism - vp_stage: Virtual pipeline stage number - - Returns: - MCoreGPTModel instance (MoE language model only) - """ - # Use parent class to create standard MoE language model - return super().provide(pre_process=pre_process, post_process=post_process, vp_stage=vp_stage) + """Provide just the language MoE model component without vision.""" + # Use GPTModelProvider's provide method to create standard MoE language model + return GPTModelProvider.provide(self, pre_process=pre_process, post_process=post_process, vp_stage=vp_stage) diff --git a/tests/unit_tests/models/gemma_vl/test_gemma3_vl_bridge.py b/tests/unit_tests/models/gemma_vl/test_gemma3_vl_bridge.py index c70611cfb9..a134d9d537 100644 --- a/tests/unit_tests/models/gemma_vl/test_gemma3_vl_bridge.py +++ b/tests/unit_tests/models/gemma_vl/test_gemma3_vl_bridge.py @@ -13,7 +13,7 @@ # limitations under the License. 
import math -from unittest.mock import Mock, patch +from unittest.mock import Mock import pytest import torch @@ -44,6 +44,14 @@ def mock_text_config(): config.rope_theta = 1000000.0 config.query_pre_attn_scalar = 256 config.rope_scaling = None + config.hidden_act = "gelu_pytorch_tanh" + # Set MLA-specific fields to None (these are auto-mapped in CONFIG_MAPPING) + config.q_lora_rank = None + config.kv_lora_rank = None + config.qk_nope_head_dim = None + config.qk_rope_head_dim = None + config.v_head_dim = None + config.num_nextn_predict_layers = None return config @@ -195,7 +203,7 @@ def test_provider_bridge_with_missing_token_ids(self, gemma3_vl_bridge, mock_hf_ # Should use defaults assert provider.vision_start_token_id == 255999 - assert provider.image_token_id == 151655 # Default from bridge + assert provider.image_token_id == 262144 # Default from bridge def test_provider_bridge_with_rope_scaling(self, gemma3_vl_bridge, mock_hf_pretrained): """Test provider_bridge with RoPE scaling configuration.""" @@ -206,39 +214,14 @@ def test_provider_bridge_with_rope_scaling(self, gemma3_vl_bridge, mock_hf_pretr assert provider.rope_scaling_factor == 2.0 - @patch.object(Gemma3VLBridge, "dtype_from_hf") - def test_provider_bridge_dtype_handling(self, mock_dtype_from_hf, gemma3_vl_bridge, mock_hf_pretrained): - """Test provider_bridge handles dtype correctly.""" - mock_dtype_from_hf.return_value = torch.float16 - - provider = gemma3_vl_bridge.provider_bridge(mock_hf_pretrained) - - assert provider.fp16 is True - assert provider.bf16 is False - assert provider.params_dtype == torch.float16 - - @patch.object(Gemma3VLBridge, "dtype_from_hf") - def test_provider_bridge_bfloat16_handling(self, mock_dtype_from_hf, gemma3_vl_bridge, mock_hf_pretrained): - """Test provider_bridge handles bfloat16 correctly.""" - mock_dtype_from_hf.return_value = torch.bfloat16 - + def test_provider_bridge_hardcoded_bf16(self, gemma3_vl_bridge, mock_hf_pretrained): + """Test provider_bridge hardcodes bf16 dtype.""" provider = gemma3_vl_bridge.provider_bridge(mock_hf_pretrained) - assert provider.fp16 is False + # Gemma3VL bridge hardcodes bf16 to match baseline assert provider.bf16 is True assert provider.params_dtype == torch.bfloat16 - @patch.object(Gemma3VLBridge, "dtype_from_hf") - def test_provider_bridge_float32_handling(self, mock_dtype_from_hf, gemma3_vl_bridge, mock_hf_pretrained): - """Test provider_bridge handles float32 correctly.""" - mock_dtype_from_hf.return_value = torch.float32 - - provider = gemma3_vl_bridge.provider_bridge(mock_hf_pretrained) - - assert provider.fp16 is False - assert provider.bf16 is False - assert provider.params_dtype == torch.float32 - class TestGemma3VLBridgeMappingRegistry: """Test mapping_registry method functionality.""" @@ -399,6 +382,14 @@ def test_provider_bridge_with_minimal_config(self, gemma3_vl_bridge): text_config.rope_theta = 1000000.0 text_config.query_pre_attn_scalar = 256 text_config.rope_scaling = None + text_config.hidden_act = "gelu_pytorch_tanh" + # Set MLA-specific fields to None + text_config.q_lora_rank = None + text_config.kv_lora_rank = None + text_config.qk_nope_head_dim = None + text_config.qk_rope_head_dim = None + text_config.v_head_dim = None + text_config.num_nextn_predict_layers = None # Create minimal vision config vision_config = SiglipVisionConfig() diff --git a/tests/unit_tests/models/nemotron_vl/test_nemotron_vl_bridge.py b/tests/unit_tests/models/nemotron_vl/test_nemotron_vl_bridge.py index 313560fec2..3e94446b32 100644 --- 
a/tests/unit_tests/models/nemotron_vl/test_nemotron_vl_bridge.py +++ b/tests/unit_tests/models/nemotron_vl/test_nemotron_vl_bridge.py @@ -34,9 +34,18 @@ def mock_llm_config(): cfg.num_attention_heads = 40 cfg.num_key_value_heads = 8 cfg.initializer_range = 0.02 - cfg.layer_norm_epsilon = 1e-5 + cfg.rms_norm_eps = 1e-5 # CONFIG_MAPPING uses rms_norm_eps -> layernorm_epsilon cfg.vocab_size = 262144 cfg.max_position_embeddings = 131072 + cfg.hidden_act = "relu2" + # Set MLA-specific fields to None (these are auto-mapped in CONFIG_MAPPING) + cfg.q_lora_rank = None + cfg.kv_lora_rank = None + cfg.qk_nope_head_dim = None + cfg.qk_rope_head_dim = None + cfg.v_head_dim = None + cfg.num_nextn_predict_layers = None + cfg.rope_scaling = None return cfg diff --git a/tests/unit_tests/models/qwen_vl/test_qwen25_vl_bridge.py b/tests/unit_tests/models/qwen_vl/test_qwen25_vl_bridge.py index 581f823d1a..0eda02e0f7 100644 --- a/tests/unit_tests/models/qwen_vl/test_qwen25_vl_bridge.py +++ b/tests/unit_tests/models/qwen_vl/test_qwen25_vl_bridge.py @@ -21,7 +21,7 @@ from megatron.bridge.models.conversion.mapping_registry import MegatronMappingRegistry from megatron.bridge.models.hf_pretrained.vlm import PreTrainedVLM from megatron.bridge.models.qwen_vl.qwen25_vl_bridge import Qwen25VLBridge -from megatron.bridge.models.qwen_vl.qwen_vl_provider import Qwen25VLModelProvider +from megatron.bridge.models.qwen_vl.qwen25_vl_provider import Qwen25VLModelProvider @pytest.fixture @@ -39,6 +39,15 @@ def mock_hf_config(): config.max_position_embeddings = 4096 config.rope_theta = 1000000.0 config.tie_word_embeddings = False + config.hidden_act = "silu" + # Set MLA-specific fields to None (these are auto-mapped in CONFIG_MAPPING) + config.q_lora_rank = None + config.kv_lora_rank = None + config.qk_nope_head_dim = None + config.qk_rope_head_dim = None + config.v_head_dim = None + config.num_nextn_predict_layers = None + config.rope_scaling = None # VL-specific configuration config.vision_config = Qwen2_5_VLVisionConfig() @@ -304,6 +313,15 @@ def test_provider_bridge_with_minimal_config(self, qwen25_vl_bridge): minimal_config.max_position_embeddings = 4096 minimal_config.rope_theta = 1000000.0 minimal_config.vision_config = Qwen2_5_VLVisionConfig() + minimal_config.hidden_act = "silu" + # Set MLA-specific fields to None + minimal_config.q_lora_rank = None + minimal_config.kv_lora_rank = None + minimal_config.qk_nope_head_dim = None + minimal_config.qk_rope_head_dim = None + minimal_config.v_head_dim = None + minimal_config.num_nextn_predict_layers = None + minimal_config.rope_scaling = None minimal_pretrained.config = minimal_config diff --git a/tests/unit_tests/models/qwen_vl/test_qwen25_vl_provider.py b/tests/unit_tests/models/qwen_vl/test_qwen25_vl_provider.py index b422411f54..2b9bfb605c 100644 --- a/tests/unit_tests/models/qwen_vl/test_qwen25_vl_provider.py +++ b/tests/unit_tests/models/qwen_vl/test_qwen25_vl_provider.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import torch.nn.functional as F from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLVisionConfig from megatron.bridge.models.qwen_vl import Qwen25VLModelProvider @@ -34,20 +33,22 @@ def test_qwen25_vl_model_provider_initialization(self): assert provider.hidden_size == 4096 assert provider.num_attention_heads == 32 - # Check Qwen2-inherited defaults - assert provider.normalization == "RMSNorm" - assert provider.activation_func is F.silu - assert provider.gated_linear_unit is True - assert provider.add_bias_linear is False - assert provider.add_qkv_bias is True - assert provider.seq_length == 4096 - assert provider.init_method_std == 0.02 - assert provider.hidden_dropout == 0.0 - assert provider.attention_dropout == 0.0 - assert provider.vocab_size == 151936 - assert provider.share_embeddings_and_output_weights is False - assert provider.layernorm_epsilon == 1e-6 - assert provider.rotary_base == 1000000.0 + # Check VL-specific defaults (inherits from GPTModelProvider) + assert provider.scatter_embedding_sequence_parallel is False + assert provider.position_embedding_type == "mrope" + assert provider.mrope_section == [16, 24, 24] + + # Check vision config + assert isinstance(provider.vision_config, Qwen2_5_VLVisionConfig) + + # Check token IDs + assert provider.bos_token_id == 151643 + assert provider.eos_token_id == 151645 + assert provider.vision_start_token_id == 151652 + assert provider.vision_end_token_id == 151653 + assert provider.vision_token_id == 151654 + assert provider.image_token_id == 151655 + assert provider.video_token_id == 151656 def test_qwen25_vl_vl_specific_defaults(self): """Test Qwen25VLModelProvider VL-specific default configuration.""" @@ -253,11 +254,11 @@ def test_qwen25_vl_model_provider_edge_cases(self): class TestQwen25VLModelProviderInheritance: """Test inheritance relationships for Qwen25VLModelProvider.""" - def test_qwen25_vl_inherits_from_qwen2_provider(self): - """Test that Qwen25VLModelProvider inherits from Qwen2ModelProvider.""" - from megatron.bridge.models import Qwen2ModelProvider + def test_qwen25_vl_inherits_from_gpt_provider(self): + """Test that Qwen25VLModelProvider inherits from GPTModelProvider.""" + from megatron.bridge.models.gpt_provider import GPTModelProvider - assert issubclass(Qwen25VLModelProvider, Qwen2ModelProvider) + assert issubclass(Qwen25VLModelProvider, GPTModelProvider) def test_qwen25_vl_provider_method_inheritance(self): """Test that inherited methods work correctly."""
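Reviewer note: a quick smoke test in the style of the suites above can confirm the VL defaults survive the move to GPTModelProvider. The constructor kwargs below are illustrative assumptions; only the asserted defaults come from this diff, so adjust the required fields to whatever the dataclass actually demands:

```python
from megatron.bridge.models.qwen_vl import Qwen25VLModelProvider


def test_qwen25_vl_defaults_smoke():
    # Illustrative language-model dimensions; not canonical values.
    provider = Qwen25VLModelProvider(
        num_layers=2,
        hidden_size=64,
        num_attention_heads=4,
    )
    # Defaults asserted here match the values set in qwen25_vl_provider.py.
    assert provider.position_embedding_type == "mrope"
    assert provider.mrope_section == [16, 24, 24]
    assert provider.image_token_id == 151655
    assert provider.video_token_id == 151656
```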