Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions verl/models/mcore/config_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,38 @@ def hf_to_mcore_config_qwen2_5_vl(
return TransformerConfig(**args)


def hf_to_mcore_config_qwen2_5_omni(
    hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs
) -> TransformerConfig:
    """Convert Qwen2_5OmniForConditionalGeneration config to Megatron TransformerConfig.

    Qwen2_5_Omni has a nested config structure:
    - Qwen2_5OmniConfig -> thinker_config -> Qwen2_5OmniThinkerConfig -> text_config -> Qwen2_5OmniTextConfig
    """
    # The language-model hyperparameters live on the innermost text config.
    text_cfg = hf_config.thinker_config.text_config

    # rope_parameters may be absent, None, a plain dict, or an attribute-style
    # object; extract mrope_section from whichever form is present.
    rope_params = getattr(text_cfg, "rope_parameters", None)
    if rope_params is None:
        mrope_section = None
    elif isinstance(rope_params, dict):
        mrope_section = rope_params.get("mrope_section", None)
    else:
        mrope_section = getattr(rope_params, "mrope_section", None)

    args = _get_base_transformer_config(
        hf_config=text_cfg,
        dtype=dtype,
        add_bias_linear=False,
        add_qkv_bias=True,
        mrope_section=mrope_section,
    )

    # Caller-supplied overrides win over the derived defaults.
    args.update(override_transformer_config_kwargs)
    args = mapping_string_to_attn_backend(args)
    return TransformerConfig(**args)


def hf_to_mcore_config_llama4(
hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs
) -> TransformerConfig:
Expand Down
7 changes: 7 additions & 0 deletions verl/models/mcore/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class SupportedVLM(Enum):
QWEN2_5_VL = "Qwen2_5_VLForConditionalGeneration"
QWEN3_MOE_VL = "Qwen3VLMoeForConditionalGeneration"
QWEN3_VL = "Qwen3VLForConditionalGeneration"
QWEN2_5_OMNI = "Qwen2_5OmniForConditionalGeneration"


supported_vlm = [member.value for member in SupportedVLM]
Expand Down Expand Up @@ -83,6 +84,7 @@ def get_mcore_forward_fused_fn(hf_config) -> Callable:
hf_to_mcore_config_qwen2_5_vl,
hf_to_mcore_config_qwen2moe,
hf_to_mcore_config_qwen3moe,
hf_to_mcore_config_qwen2_5_omni,
)
from .model_initializer import (
BaseModelInitializer,
Expand Down Expand Up @@ -118,6 +120,7 @@ class SupportedModel(Enum):
LLAMA_TOKEN_CLASSIFICATION = "LlamaForTokenClassification"
QWEN3_MOE_VL = "Qwen3VLMoeForConditionalGeneration"
QWEN3_VL = "Qwen3VLForConditionalGeneration"
QWEN2_5_OMNI = "Qwen2_5OmniForConditionalGeneration"
GPT_OSS = "GptOssForCausalLM"
MiMO = "MiMoForCausalLM"

Expand All @@ -135,6 +138,7 @@ class SupportedModel(Enum):
SupportedModel.QWEN3_MOE: hf_to_mcore_config_qwen3moe,
SupportedModel.QWEN3_TOKEN_CLASSIFICATION: hf_to_mcore_config_dense,
SupportedModel.LLAMA_TOKEN_CLASSIFICATION: hf_to_mcore_config_dense,
SupportedModel.QWEN2_5_OMNI: hf_to_mcore_config_qwen2_5_omni,
}

# Registry for model initializers
Expand Down Expand Up @@ -170,6 +174,7 @@ class SupportedModel(Enum):
SupportedModel.LLAMA_TOKEN_CLASSIFICATION: model_forward_gen(),
SupportedModel.GPT_OSS: model_forward_gen(),
SupportedModel.MiMO: model_forward_gen(),
# SupportedModel.QWEN2_5_OMNI: model_forward_gen(True), # not implemented
}

# Registry for model forward functions
Expand All @@ -190,6 +195,7 @@ class SupportedModel(Enum):
SupportedModel.LLAMA_TOKEN_CLASSIFICATION: gptmodel_forward_no_padding,
SupportedModel.GPT_OSS: gptmodel_forward_no_padding,
SupportedModel.MiMO: gptmodel_forward_no_padding,
# SupportedModel.QWEN2_5_OMNI: gptmodel_forward_no_padding, # not implemented
}

# Registry for model forward functions
Expand All @@ -208,6 +214,7 @@ class SupportedModel(Enum):
SupportedModel.GLM4_MOE: fused_forward_model_gen(),
SupportedModel.GPT_OSS: fused_forward_model_gen(),
SupportedModel.MiMO: fused_forward_model_gen(),
# SupportedModel.QWEN2_5_OMNI: fused_forward_model_gen(True), # not implemented
}

# Registry for model weight converters
Expand Down