2 changes: 1 addition & 1 deletion recipe/open_math_reasoning/run_sft_qwen3_8b.sh
@@ -55,7 +55,7 @@ MEGATRON_ENGINE_CONFIG="\
engine.pipeline_model_parallel_size=${PP_SIZE} \
engine.virtual_pipeline_model_parallel_size=${VPP_SIZE} \
engine.context_parallel_size=${CP_SIZE} \
engine.use_mbridge=False"
engine.use_mbridge=True"
if [ "$backend" = "fsdp" ]; then
ENGINE_CONFIG="$FSDP_ENGINE_CONFIG"
2 changes: 1 addition & 1 deletion tests/trainer/config/legacy_ppo_megatron_trainer.yaml
@@ -111,7 +111,7 @@ actor_rollout_ref:
dist_checkpointing_path: null
seed: 42
override_transformer_config: {} # additional transformer config like: num_layers_in_first(/last)_pipeline_stage
use_mbridge: False
use_mbridge: True
vanilla_mbridge: True
profile: # profile the actor model in `update_policy`
use_profile: False # open it when you want to profile the actor model
77 changes: 48 additions & 29 deletions verl/models/mcore/registry.py
@@ -22,6 +22,54 @@
import torch
import torch.nn as nn

from .model_forward import gptmodel_forward_no_padding, model_forward_gen
from .model_forward_fused import fused_forward_model_gen


class SupportedVLM(Enum):
QWEN2_5_VL = "Qwen2_5_VLForConditionalGeneration"
QWEN3_MOE_VL = "Qwen3VLMoeForConditionalGeneration"
QWEN3_VL = "Qwen3VLForConditionalGeneration"


def get_mcore_forward_fn(hf_config) -> Callable:
"""
Get the forward function for given model architecture.
"""
assert len(hf_config.architectures) == 1, "Only one architecture is supported for now"
if hf_config.architectures[0] in SupportedVLM:
Contributor (review comment on the diff above):

critical

The check hf_config.architectures[0] in SupportedVLM is incorrect: the in operator on an Enum class tests for enum members, not their string values. For a plain string this raises a TypeError on Python versions before 3.12, so the dispatch either crashes or depends on the Python version instead of reliably selecting the VLM forward function for Vision Language Models (VLMs).

Suggested change
if hf_config.architectures[0] in SupportedVLM:
if hf_config.architectures[0] in {item.value for item in SupportedVLM}:
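For reference, a minimal standalone sketch of the behaviour the comment describes (one abbreviated member; not part of the diff):

    from enum import Enum

    class SupportedVLM(Enum):
        QWEN2_5_VL = "Qwen2_5_VLForConditionalGeneration"

    arch = "Qwen2_5_VLForConditionalGeneration"

    # Enum membership: before Python 3.12, `arch in SupportedVLM` raises TypeError
    # for a plain string, so it cannot be used to match architecture names.
    # Comparing against the member values works across versions:
    is_vlm = arch in {item.value for item in SupportedVLM}
    print(is_vlm)  # True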

return model_forward_gen(True)
else:
# default to language model
return model_forward_gen(False)


def get_mcore_forward_no_padding_fn(hf_config) -> Callable:
"""
Get the forward function for given model architecture.
"""
assert len(hf_config.architectures) == 1, "Only one architecture is supported for now"
return gptmodel_forward_no_padding


def get_mcore_forward_fused_fn(hf_config) -> Callable:
"""
Get the forward function for given model architecture.
"""
assert len(hf_config.architectures) == 1, "Only one architecture is supported for now"
if hf_config.architectures[0] in SupportedVLM:
Contributor (review comment on the diff above):

critical

Same issue as in get_mcore_forward_fn: hf_config.architectures[0] in SupportedVLM tests Enum membership rather than the members' string values (raising a TypeError for a plain string on Python versions before 3.12), so the fused forward function is not reliably selected for Vision Language Models (VLMs).

Suggested change
if hf_config.architectures[0] in SupportedVLM:
if hf_config.architectures[0] in {item.value for item in SupportedVLM}:

return fused_forward_model_gen(True)
else:
# default to language model
return fused_forward_model_gen(False)


# ruff: noqa

########################################################
# below is the deprecated code
########################################################

from .config_converter import (
PretrainedConfig,
TransformerConfig,
@@ -33,8 +81,6 @@
hf_to_mcore_config_qwen2moe,
hf_to_mcore_config_qwen3moe,
)
from .model_forward import gptmodel_forward_no_padding, model_forward_gen
from .model_forward_fused import fused_forward_model_gen
from .model_initializer import (
BaseModelInitializer,
DeepseekV3Model,
@@ -239,33 +285,6 @@ def init_mcore_model(
)


def get_mcore_forward_fn(hf_config: PretrainedConfig) -> Callable:
"""
Get the forward function for given model architecture.
"""
assert len(hf_config.architectures) == 1, "Only one architecture is supported for now"
model = get_supported_model(hf_config.architectures[0])
return MODEL_FORWARD_REGISTRY[model]


def get_mcore_forward_no_padding_fn(hf_config: PretrainedConfig) -> Callable:
"""
Get the forward function for given model architecture.
"""
assert len(hf_config.architectures) == 1, "Only one architecture is supported for now"
model = get_supported_model(hf_config.architectures[0])
return MODEL_FORWARD_NOPAD_REGISTRY[model]


def get_mcore_forward_fused_fn(hf_config: PretrainedConfig) -> Callable:
"""
Get the forward function for given model architecture.
"""
assert len(hf_config.architectures) == 1, "Only one architecture is supported for now"
model = get_supported_model(hf_config.architectures[0])
return MODEL_FORWARD_FUSED_REGISTRY[model]


def get_mcore_weight_converter(hf_config: PretrainedConfig, dtype: torch.dtype) -> Callable:
"""
Get the weight converter for given model architecture.
4 changes: 2 additions & 2 deletions verl/trainer/config/_generated_ppo_megatron_trainer.yaml
@@ -52,7 +52,7 @@ actor_rollout_ref:
recompute_num_layers: null
attention_backend: flash
override_mcore_model_config: {}
use_mbridge: false
use_mbridge: true
vanilla_mbridge: true
use_remove_padding: true
forward_only: false
@@ -433,7 +433,7 @@ critic:
recompute_num_layers: null
attention_backend: flash
override_mcore_model_config: {}
use_mbridge: false
use_mbridge: true
vanilla_mbridge: true
use_remove_padding: true
forward_only: false
2 changes: 1 addition & 1 deletion verl/trainer/config/engine/megatron.yaml
@@ -75,7 +75,7 @@ override_transformer_config:
override_mcore_model_config: {}

# oc.select: default val for ref.megatron.use_mbridge
use_mbridge: False
use_mbridge: True

# oc.select: default val for ref.megatron.vanilla_mbridge
vanilla_mbridge: True
8 changes: 8 additions & 0 deletions verl/utils/megatron_utils.py
@@ -1220,3 +1220,11 @@ def register_megatron_training_hooks(model: list[torch.nn.Module], optimizer):
config.param_sync_func = [model_chunk.start_param_sync for model_chunk in model]
if len(model) == 1:
config.param_sync_func = config.param_sync_func[0]


def mapping_string_to_attn_backend(args: dict) -> dict:
if "attention_backend" in args and isinstance(args["attention_backend"], str):
from megatron.core.transformer.enums import AttnBackend

args["attention_backend"] = AttnBackend[args["attention_backend"]]
return args
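
A short usage sketch of the new helper (assumes megatron.core is importable; AttnBackend.flash is the same member the bridge provider code in this PR uses):

    from megatron.core.transformer.enums import AttnBackend

    overrides = {"attention_backend": "flash"}
    converted = mapping_string_to_attn_backend(overrides)
    assert converted["attention_backend"] is AttnBackend.flash  # looked up by member name
    # A dict without the key, or with a non-string value, is returned unchanged.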
2 changes: 1 addition & 1 deletion verl/workers/config/engine.py
@@ -120,7 +120,7 @@ class McoreEngineConfig(EngineConfig):
override_ddp_config: dict[str, Any] = field(default_factory=dict)
override_transformer_config: dict[str, Any] = field(default_factory=dict)
override_mcore_model_config: dict[str, Any] = field(default_factory=dict)
use_mbridge: bool = False
use_mbridge: bool = True
vanilla_mbridge: bool = True
strategy: str = "megatron"

143 changes: 55 additions & 88 deletions verl/workers/engine/megatron/transformer_impl.py
@@ -41,13 +41,11 @@
load_megatron_optimizer,
offload_megatron_model_to_cpu,
offload_megatron_optimizer,
per_tensor_generator,
register_megatron_training_hooks,
)
from verl.utils.model import (
extract_multi_modal_inputs_tensordict,
load_mcore_dist_weights,
load_megatron_gptmodel_weights,
)
from verl.workers.config import HFModelConfig, McoreEngineConfig, McoreOptimizerConfig

@@ -76,7 +74,7 @@ def __init__(
self.engine_config = engine_config
self.optimizer_config = optimizer_config
self.checkpoint_config = checkpoint_config

assert self.engine_config.use_mbridge, "use_mbridge must be True"
self._init_device_mesh()

set_random_seed(seed=self.engine_config.seed)
@@ -110,70 +108,62 @@ def _init_device_mesh(self):
)

def _build_tf_config(self):
from verl.models.mcore import hf_to_mcore_config
from verl.models.mcore.config_converter import mapping_string_to_attn_backend
from verl.utils.megatron_utils import mapping_string_to_attn_backend
from verl.utils.torch_dtypes import PrecisionType

self.param_dtype = PrecisionType.to_dtype(self.engine_config.dtype)
if self.param_dtype == torch.float16:
assert self.engine_config.use_mbridge, "fp16 mode requires use_mbridge to be True"
self.dtype = PrecisionType.to_dtype(self.param_dtype)

override_transformer_config = mapping_string_to_attn_backend({**self.engine_config.override_transformer_config})

use_mbridge = self.engine_config.use_mbridge
self.provider = None
self.vanilla_bridge = self.engine_config.vanilla_mbridge
if use_mbridge:
if self.vanilla_bridge:
from verl.models.mcore.mbridge import AutoBridge

bridge = AutoBridge.from_config(self.model_config.hf_config, dtype=self.param_dtype)
bridge.set_extra_args(**override_transformer_config)
tf_config = bridge.config
tf_config.fp16 = self.param_dtype == torch.float16
tf_config.bf16 = self.param_dtype == torch.bfloat16
else:
from verl.models.mcore.bridge import AutoBridge

# Use Megatron-Bridge to convert HF config to Megatron config
bridge = AutoBridge.from_hf_pretrained(
self.model_config.local_path, trust_remote_code=self.model_config.trust_remote_code
)
# Get Megatron provider and configure it
provider = bridge.to_megatron_provider(load_weights=False)

# In case of invalid overrides, we need to make sure some critical params are set correctly
provider.params_dtype = self.param_dtype

# Pass distributed info
provider.tensor_model_parallel_size = self.engine_config.tensor_model_parallel_size
provider.pipeline_model_parallel_size = self.engine_config.pipeline_model_parallel_size
provider.expert_model_parallel_size = self.engine_config.expert_model_parallel_size
provider.expert_tensor_parallel_size = self.engine_config.expert_tensor_parallel_size
provider.virtual_pipeline_model_parallel_size = self.engine_config.virtual_pipeline_model_parallel_size
provider.context_parallel_size = self.engine_config.context_parallel_size
provider.sequence_parallel = self.engine_config.sequence_parallel

# Match verl implementation (need variable_seq_lengths)
from megatron.core.transformer.enums import AttnBackend

provider.attention_backend = AttnBackend.flash
provider.variable_seq_lengths = True
provider.moe_token_dispatcher_type = "alltoall"
provider.moe_router_load_balancing_type = "none"

# Apply transformer config overrides
for key, value in override_transformer_config.items():
setattr(provider, key, value)

provider.finalize()
self.provider = provider
tf_config = None # Will be set after model creation
self.bridge = bridge
if self.vanilla_bridge:
from verl.models.mcore.mbridge import AutoBridge

bridge = AutoBridge.from_config(self.model_config.hf_config, dtype=self.param_dtype)
bridge.set_extra_args(**override_transformer_config)
tf_config = bridge.config
tf_config.fp16 = self.param_dtype == torch.float16
tf_config.bf16 = self.param_dtype == torch.bfloat16
else:
self.bridge = None
tf_config = hf_to_mcore_config(self.model_config.hf_config, self.dtype, **override_transformer_config)
from verl.models.mcore.bridge import AutoBridge

# Use Megatron-Bridge to convert HF config to Megatron config
bridge = AutoBridge.from_hf_pretrained(
self.model_config.local_path, trust_remote_code=self.model_config.trust_remote_code
)
# Get Megatron provider and configure it
provider = bridge.to_megatron_provider(load_weights=False)

# In case of invalid overrides, we need to make sure some critical params are set correctly
provider.params_dtype = self.param_dtype

# Pass distributed info
provider.tensor_model_parallel_size = self.engine_config.tensor_model_parallel_size
provider.pipeline_model_parallel_size = self.engine_config.pipeline_model_parallel_size
provider.expert_model_parallel_size = self.engine_config.expert_model_parallel_size
provider.expert_tensor_parallel_size = self.engine_config.expert_tensor_parallel_size
provider.virtual_pipeline_model_parallel_size = self.engine_config.virtual_pipeline_model_parallel_size
provider.context_parallel_size = self.engine_config.context_parallel_size
provider.sequence_parallel = self.engine_config.sequence_parallel

# Match verl implementation (need variable_seq_lengths)
from megatron.core.transformer.enums import AttnBackend

provider.attention_backend = AttnBackend.flash
provider.variable_seq_lengths = True
provider.moe_token_dispatcher_type = "alltoall"
provider.moe_router_load_balancing_type = "none"

# Apply transformer config overrides
for key, value in override_transformer_config.items():
setattr(provider, key, value)

provider.finalize()
self.provider = provider
tf_config = None # Will be set after model creation
self.bridge = bridge

if not self.bridge:
self.weight_converter = get_mcore_weight_converter(self.model_config.hf_config, self.dtype)
@@ -232,28 +222,14 @@ def _build_megatron_module(self):
if self.engine_config.use_dist_checkpointing:
load_mcore_dist_weights(module, self.engine_config.dist_checkpointing_path, is_value_model=is_value_model)
else:
if self.bridge is not None:
if self.vanilla_bridge:
self.bridge.load_weights(module, self.model_config.local_path)
else:
allowed_mismatched_params = []
if self.is_value_model:
allowed_mismatched_params = ["output_layer.weight"]
self.bridge.load_hf_weights(
module, self.model_config.local_path, allowed_mismatched_params=allowed_mismatched_params
)
if self.vanilla_bridge:
self.bridge.load_weights(module, self.model_config.local_path)
else:
# (vermouth1992) this is a workaround to be compatible with the old API
tmp_config = OmegaConf.create(
{"model": {"path": self.model_config.local_path, "use_shm": self.model_config.use_shm}}
)

load_megatron_gptmodel_weights(
tmp_config,
self.model_config.hf_config,
module,
params_dtype=self.dtype,
is_value_model=is_value_model,
allowed_mismatched_params = []
if self.is_value_model:
allowed_mismatched_params = ["output_layer.weight"]
self.bridge.load_hf_weights(
module, self.model_config.local_path, allowed_mismatched_params=allowed_mismatched_params
)

if torch.distributed.get_rank() == 0:
@@ -562,16 +538,7 @@ def forward_backward_batch(self, data: TensorDict, loss_function: Callable, forw
def get_per_tensor_param(self):
if self._is_offload_param:
load_megatron_model_to_gpu(self.module, load_grad=False)
if self.bridge is not None:
per_tensor_param = self.bridge.export_weights(self.module)
else:
per_tensor_param = per_tensor_generator(
self.module,
self.model_config.hf_config,
self.weight_converter,
self.tf_config,
self.layer_name_mapping,
)
per_tensor_param = self.bridge.export_weights(self.module)
# TODO: support megatron LoRA
return per_tensor_param, None
