diff --git a/recipe/open_math_reasoning/run_sft_qwen3_8b.sh b/recipe/open_math_reasoning/run_sft_qwen3_8b.sh
index 3b7e9bb5c6c..ec564a1d602 100644
--- a/recipe/open_math_reasoning/run_sft_qwen3_8b.sh
+++ b/recipe/open_math_reasoning/run_sft_qwen3_8b.sh
@@ -55,7 +55,7 @@ MEGATRON_ENGINE_CONFIG="\
     engine.pipeline_model_parallel_size=${PP_SIZE} \
     engine.virtual_pipeline_model_parallel_size=${VPP_SIZE} \
     engine.context_parallel_size=${CP_SIZE} \
-    engine.use_mbridge=False"
+    engine.use_mbridge=True"
 
 if [ "$backend" = "fsdp" ]; then
     ENGINE_CONFIG="$FSDP_ENGINE_CONFIG"
diff --git a/tests/trainer/config/legacy_ppo_megatron_trainer.yaml b/tests/trainer/config/legacy_ppo_megatron_trainer.yaml
index ea2a15d685e..06e2e94a662 100644
--- a/tests/trainer/config/legacy_ppo_megatron_trainer.yaml
+++ b/tests/trainer/config/legacy_ppo_megatron_trainer.yaml
@@ -111,7 +111,7 @@ actor_rollout_ref:
       dist_checkpointing_path: null
       seed: 42
       override_transformer_config: {} # additional transformer config like: num_layers_in_first(/last)_pipeline_stage
-      use_mbridge: False
+      use_mbridge: True
       vanilla_mbridge: True
       profile: # profile the actor model in `update_policy`
         use_profile: False # open it when you want to profile the actor model
diff --git a/verl/models/mcore/registry.py b/verl/models/mcore/registry.py
index 88d90be3713..fd0f0ce1a5f 100644
--- a/verl/models/mcore/registry.py
+++ b/verl/models/mcore/registry.py
@@ -22,6 +22,54 @@
 import torch
 import torch.nn as nn
 
+from .model_forward import gptmodel_forward_no_padding, model_forward_gen
+from .model_forward_fused import fused_forward_model_gen
+
+
+class SupportedVLM(Enum):
+    QWEN2_5_VL = "Qwen2_5_VLForConditionalGeneration"
+    QWEN3_MOE_VL = "Qwen3VLMoeForConditionalGeneration"
+    QWEN3_VL = "Qwen3VLForConditionalGeneration"
+
+
+def get_mcore_forward_fn(hf_config) -> Callable:
+    """
+    Get the forward function for given model architecture.
+    """
+    assert len(hf_config.architectures) == 1, "Only one architecture is supported for now"
+    if hf_config.architectures[0] in SupportedVLM:
+        return model_forward_gen(True)
+    else:
+        # default to language model
+        return model_forward_gen(False)
+
+
+def get_mcore_forward_no_padding_fn(hf_config) -> Callable:
+    """
+    Get the forward function for given model architecture.
+    """
+    assert len(hf_config.architectures) == 1, "Only one architecture is supported for now"
+    return gptmodel_forward_no_padding
+
+
+def get_mcore_forward_fused_fn(hf_config) -> Callable:
+    """
+    Get the forward function for given model architecture.
+    """
+    assert len(hf_config.architectures) == 1, "Only one architecture is supported for now"
+    if hf_config.architectures[0] in SupportedVLM:
+        return fused_forward_model_gen(True)
+    else:
+        # default to language model
+        return fused_forward_model_gen(False)
+
+
+# ruff: noqa
+
+########################################################
+# below is the deprecated code
+########################################################
+
 from .config_converter import (
     PretrainedConfig,
     TransformerConfig,
@@ -33,8 +81,6 @@
     hf_to_mcore_config_qwen2moe,
     hf_to_mcore_config_qwen3moe,
 )
-from .model_forward import gptmodel_forward_no_padding, model_forward_gen
-from .model_forward_fused import fused_forward_model_gen
 from .model_initializer import (
     BaseModelInitializer,
     DeepseekV3Model,
@@ -239,33 +285,6 @@ def init_mcore_model(
     )
 
 
-def get_mcore_forward_fn(hf_config: PretrainedConfig) -> Callable:
-    """
-    Get the forward function for given model architecture.
-    """
-    assert len(hf_config.architectures) == 1, "Only one architecture is supported for now"
-    model = get_supported_model(hf_config.architectures[0])
-    return MODEL_FORWARD_REGISTRY[model]
-
-
-def get_mcore_forward_no_padding_fn(hf_config: PretrainedConfig) -> Callable:
-    """
-    Get the forward function for given model architecture.
-    """
-    assert len(hf_config.architectures) == 1, "Only one architecture is supported for now"
-    model = get_supported_model(hf_config.architectures[0])
-    return MODEL_FORWARD_NOPAD_REGISTRY[model]
-
-
-def get_mcore_forward_fused_fn(hf_config: PretrainedConfig) -> Callable:
-    """
-    Get the forward function for given model architecture.
-    """
-    assert len(hf_config.architectures) == 1, "Only one architecture is supported for now"
-    model = get_supported_model(hf_config.architectures[0])
-    return MODEL_FORWARD_FUSED_REGISTRY[model]
-
-
 def get_mcore_weight_converter(hf_config: PretrainedConfig, dtype: torch.dtype) -> Callable:
     """
     Get the weight converter for given model architecture.
diff --git a/verl/trainer/config/_generated_ppo_megatron_trainer.yaml b/verl/trainer/config/_generated_ppo_megatron_trainer.yaml
index 9e8e0e813bd..a117c0f332f 100644
--- a/verl/trainer/config/_generated_ppo_megatron_trainer.yaml
+++ b/verl/trainer/config/_generated_ppo_megatron_trainer.yaml
@@ -52,7 +52,7 @@ actor_rollout_ref:
       recompute_num_layers: null
       attention_backend: flash
       override_mcore_model_config: {}
-      use_mbridge: false
+      use_mbridge: true
       vanilla_mbridge: true
       use_remove_padding: true
       forward_only: false
@@ -433,7 +433,7 @@ critic:
     recompute_num_layers: null
     attention_backend: flash
     override_mcore_model_config: {}
-    use_mbridge: false
+    use_mbridge: true
     vanilla_mbridge: true
    use_remove_padding: true
     forward_only: false
diff --git a/verl/trainer/config/engine/megatron.yaml b/verl/trainer/config/engine/megatron.yaml
index 84601f5a3f5..b588a96c1b3 100644
--- a/verl/trainer/config/engine/megatron.yaml
+++ b/verl/trainer/config/engine/megatron.yaml
@@ -75,7 +75,7 @@ override_transformer_config:
 override_mcore_model_config: {}
 
 # oc.select: default val for ref.megatron.use_mbridge
-use_mbridge: False
+use_mbridge: True
 
 # oc.select: default val for ref.megatron.vanilla_mbridge
 vanilla_mbridge: True
diff --git a/verl/utils/megatron_utils.py b/verl/utils/megatron_utils.py
index c319910d855..13c63ebdce0 100644
--- a/verl/utils/megatron_utils.py
+++ b/verl/utils/megatron_utils.py
@@ -1220,3 +1220,11 @@ def register_megatron_training_hooks(model: list[torch.nn.Module], optimizer):
         config.param_sync_func = [model_chunk.start_param_sync for model_chunk in model]
         if len(model) == 1:
             config.param_sync_func = config.param_sync_func[0]
+
+
+def mapping_string_to_attn_backend(args: dict) -> dict:
+    if "attention_backend" in args and isinstance(args["attention_backend"], str):
+        from megatron.core.transformer.enums import AttnBackend
+
+        args["attention_backend"] = AttnBackend[args["attention_backend"]]
+    return args
diff --git a/verl/workers/config/engine.py b/verl/workers/config/engine.py
index a8799c35691..b645715489e 100644
--- a/verl/workers/config/engine.py
+++ b/verl/workers/config/engine.py
@@ -120,7 +120,7 @@ class McoreEngineConfig(EngineConfig):
     override_ddp_config: dict[str, Any] = field(default_factory=dict)
     override_transformer_config: dict[str, Any] = field(default_factory=dict)
     override_mcore_model_config: dict[str, Any] = field(default_factory=dict)
-    use_mbridge: bool = False
+    use_mbridge: bool = True
     vanilla_mbridge: bool = True
     strategy: str = "megatron"
 
diff --git a/verl/workers/engine/megatron/transformer_impl.py b/verl/workers/engine/megatron/transformer_impl.py
index 00292a73c2e..1a1756012d8 100644
--- a/verl/workers/engine/megatron/transformer_impl.py
+++ b/verl/workers/engine/megatron/transformer_impl.py
@@ -41,13 +41,11 @@
     load_megatron_optimizer,
     offload_megatron_model_to_cpu,
     offload_megatron_optimizer,
-    per_tensor_generator,
     register_megatron_training_hooks,
 )
 from verl.utils.model import (
     extract_multi_modal_inputs_tensordict,
     load_mcore_dist_weights,
-    load_megatron_gptmodel_weights,
 )
 from verl.workers.config import HFModelConfig, McoreEngineConfig, McoreOptimizerConfig
 
@@ -76,7 +74,7 @@ def __init__(
         self.engine_config = engine_config
         self.optimizer_config = optimizer_config
         self.checkpoint_config = checkpoint_config
-
+        assert self.engine_config.use_mbridge, "use_mbridge must be True"
         self._init_device_mesh()
 
         set_random_seed(seed=self.engine_config.seed)
@@ -110,70 +108,62 @@ def _init_device_mesh(self):
         )
 
     def _build_tf_config(self):
-        from verl.models.mcore import hf_to_mcore_config
-        from verl.models.mcore.config_converter import mapping_string_to_attn_backend
+        from verl.utils.megatron_utils import mapping_string_to_attn_backend
         from verl.utils.torch_dtypes import PrecisionType
 
         self.param_dtype = PrecisionType.to_dtype(self.engine_config.dtype)
-        if self.param_dtype == torch.float16:
-            assert self.engine_config.use_mbridge, "fp16 mode requires use_mbridge to be True"
         self.dtype = PrecisionType.to_dtype(self.param_dtype)
 
         override_transformer_config = mapping_string_to_attn_backend({**self.engine_config.override_transformer_config})
-        use_mbridge = self.engine_config.use_mbridge
         self.provider = None
         self.vanilla_bridge = self.engine_config.vanilla_mbridge
 
-        if use_mbridge:
-            if self.vanilla_bridge:
-                from verl.models.mcore.mbridge import AutoBridge
-
-                bridge = AutoBridge.from_config(self.model_config.hf_config, dtype=self.param_dtype)
-                bridge.set_extra_args(**override_transformer_config)
-                tf_config = bridge.config
-                tf_config.fp16 = self.param_dtype == torch.float16
-                tf_config.bf16 = self.param_dtype == torch.bfloat16
-            else:
-                from verl.models.mcore.bridge import AutoBridge
-
-                # Use Megatron-Bridge to convert HF config to Megatron config
-                bridge = AutoBridge.from_hf_pretrained(
-                    self.model_config.local_path, trust_remote_code=self.model_config.trust_remote_code
-                )
-                # Get Megatron provider and configure it
-                provider = bridge.to_megatron_provider(load_weights=False)
-
-                # In case of invalid overrides, we need to make sure some critical params are set correctly
-                provider.params_dtype = self.param_dtype
-
-                # Pass distributed info
-                provider.tensor_model_parallel_size = self.engine_config.tensor_model_parallel_size
-                provider.pipeline_model_parallel_size = self.engine_config.pipeline_model_parallel_size
-                provider.expert_model_parallel_size = self.engine_config.expert_model_parallel_size
-                provider.expert_tensor_parallel_size = self.engine_config.expert_tensor_parallel_size
-                provider.virtual_pipeline_model_parallel_size = self.engine_config.virtual_pipeline_model_parallel_size
-                provider.context_parallel_size = self.engine_config.context_parallel_size
-                provider.sequence_parallel = self.engine_config.sequence_parallel
-
-                # Match verl implementation (need variable_seq_lengths)
-                from megatron.core.transformer.enums import AttnBackend
-
-                provider.attention_backend = AttnBackend.flash
-                provider.variable_seq_lengths = True
-                provider.moe_token_dispatcher_type = "alltoall"
-                provider.moe_router_load_balancing_type = "none"
-
-                # Apply transformer config overrides
-                for key, value in override_transformer_config.items():
-                    setattr(provider, key, value)
-
-                provider.finalize()
-                self.provider = provider
-                tf_config = None  # Will be set after model creation
-                self.bridge = bridge
+        if self.vanilla_bridge:
+            from verl.models.mcore.mbridge import AutoBridge
+
+            bridge = AutoBridge.from_config(self.model_config.hf_config, dtype=self.param_dtype)
+            bridge.set_extra_args(**override_transformer_config)
+            tf_config = bridge.config
+            tf_config.fp16 = self.param_dtype == torch.float16
+            tf_config.bf16 = self.param_dtype == torch.bfloat16
         else:
-            self.bridge = None
-            tf_config = hf_to_mcore_config(self.model_config.hf_config, self.dtype, **override_transformer_config)
+            from verl.models.mcore.bridge import AutoBridge
+
+            # Use Megatron-Bridge to convert HF config to Megatron config
+            bridge = AutoBridge.from_hf_pretrained(
+                self.model_config.local_path, trust_remote_code=self.model_config.trust_remote_code
+            )
+            # Get Megatron provider and configure it
+            provider = bridge.to_megatron_provider(load_weights=False)
+
+            # In case of invalid overrides, we need to make sure some critical params are set correctly
+            provider.params_dtype = self.param_dtype
+
+            # Pass distributed info
+            provider.tensor_model_parallel_size = self.engine_config.tensor_model_parallel_size
+            provider.pipeline_model_parallel_size = self.engine_config.pipeline_model_parallel_size
+            provider.expert_model_parallel_size = self.engine_config.expert_model_parallel_size
+            provider.expert_tensor_parallel_size = self.engine_config.expert_tensor_parallel_size
+            provider.virtual_pipeline_model_parallel_size = self.engine_config.virtual_pipeline_model_parallel_size
+            provider.context_parallel_size = self.engine_config.context_parallel_size
+            provider.sequence_parallel = self.engine_config.sequence_parallel
+
+            # Match verl implementation (need variable_seq_lengths)
+            from megatron.core.transformer.enums import AttnBackend
+
+            provider.attention_backend = AttnBackend.flash
+            provider.variable_seq_lengths = True
+            provider.moe_token_dispatcher_type = "alltoall"
+            provider.moe_router_load_balancing_type = "none"
+
+            # Apply transformer config overrides
+            for key, value in override_transformer_config.items():
+                setattr(provider, key, value)
+
+            provider.finalize()
+            self.provider = provider
+            tf_config = None  # Will be set after model creation
+            self.bridge = bridge
 
         if not self.bridge:
             self.weight_converter = get_mcore_weight_converter(self.model_config.hf_config, self.dtype)
@@ -232,28 +222,14 @@ def _build_megatron_module(self):
         if self.engine_config.use_dist_checkpointing:
             load_mcore_dist_weights(module, self.engine_config.dist_checkpointing_path, is_value_model=is_value_model)
         else:
-            if self.bridge is not None:
-                if self.vanilla_bridge:
-                    self.bridge.load_weights(module, self.model_config.local_path)
-                else:
-                    allowed_mismatched_params = []
-                    if self.is_value_model:
-                        allowed_mismatched_params = ["output_layer.weight"]
-                    self.bridge.load_hf_weights(
-                        module, self.model_config.local_path, allowed_mismatched_params=allowed_mismatched_params
-                    )
+            if self.vanilla_bridge:
+                self.bridge.load_weights(module, self.model_config.local_path)
             else:
-                # (vermouth1992) this is a workaround to be compatible with the old API
-                tmp_config = OmegaConf.create(
-                    {"model": {"path": self.model_config.local_path, "use_shm": self.model_config.use_shm}}
-                )
-
-                load_megatron_gptmodel_weights(
-                    tmp_config,
-                    self.model_config.hf_config,
-                    module,
-                    params_dtype=self.dtype,
-                    is_value_model=is_value_model,
+                allowed_mismatched_params = []
+                if self.is_value_model:
+                    allowed_mismatched_params = ["output_layer.weight"]
+                self.bridge.load_hf_weights(
+                    module, self.model_config.local_path, allowed_mismatched_params=allowed_mismatched_params
                 )
 
         if torch.distributed.get_rank() == 0:
@@ -562,16 +538,7 @@ def forward_backward_batch(self, data: TensorDict, loss_function: Callable, forw
     def get_per_tensor_param(self):
         if self._is_offload_param:
             load_megatron_model_to_gpu(self.module, load_grad=False)
-        if self.bridge is not None:
-            per_tensor_param = self.bridge.export_weights(self.module)
-        else:
-            per_tensor_param = per_tensor_generator(
-                self.module,
-                self.model_config.hf_config,
-                self.weight_converter,
-                self.tf_config,
-                self.layer_name_mapping,
-            )
+        per_tensor_param = self.bridge.export_weights(self.module)
 
         # TODO: support megatron LoRA
         return per_tensor_param, None