-
Notifications
You must be signed in to change notification settings - Fork 3.5k
[megatron] chore: clean legacy code path part 1, make engine use mbridge by default #4528
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -22,6 +22,54 @@ | |||||
| import torch | ||||||
| import torch.nn as nn | ||||||
|
|
||||||
| from .model_forward import gptmodel_forward_no_padding, model_forward_gen | ||||||
| from .model_forward_fused import fused_forward_model_gen | ||||||
|
|
||||||
|
|
||||||
class SupportedVLM(Enum):
    """HuggingFace architecture names of vision-language models that have a
    dedicated mcore forward path.

    The enum *values* are the exact strings found in
    ``hf_config.architectures``; membership should be tested against these
    values, not against the enum class itself.
    """

    QWEN2_5_VL = "Qwen2_5_VLForConditionalGeneration"
    QWEN3_MOE_VL = "Qwen3VLMoeForConditionalGeneration"
    QWEN3_VL = "Qwen3VLForConditionalGeneration"
|
|
||||||
|
|
||||||
def get_mcore_forward_fn(hf_config) -> Callable:
    """
    Get the forward function for given model architecture.

    Args:
        hf_config: HuggingFace model config; ``hf_config.architectures`` must
            contain exactly one architecture name.

    Returns:
        The VLM forward function when the architecture is a supported
        vision-language model, otherwise the language-model forward function.
    """
    assert len(hf_config.architectures) == 1, "Only one architecture is supported for now"
    # Bug fix: `x in SupportedVLM` tests a *string* against the Enum class.
    # That raises TypeError on Python < 3.12 and only matches by value on
    # 3.12+ — compare against the enum values explicitly instead.
    is_vlm = hf_config.architectures[0] in {vlm.value for vlm in SupportedVLM}
    if is_vlm:
        return model_forward_gen(True)
    else:
        # default to language model
        return model_forward_gen(False)
|
|
||||||
|
|
||||||
def get_mcore_forward_no_padding_fn(hf_config) -> Callable:
    """Return the no-padding forward function for the given model architecture.

    Every architecture currently shares the same no-padding forward path,
    so the config is only validated, not dispatched on.
    """
    architectures = hf_config.architectures
    assert len(architectures) == 1, "Only one architecture is supported for now"
    return gptmodel_forward_no_padding
|
|
||||||
|
|
||||||
def get_mcore_forward_fused_fn(hf_config) -> Callable:
    """
    Get the fused forward function for given model architecture.

    Args:
        hf_config: HuggingFace model config; ``hf_config.architectures`` must
            contain exactly one architecture name.

    Returns:
        The fused VLM forward function when the architecture is a supported
        vision-language model, otherwise the fused language-model forward
        function.
    """
    assert len(hf_config.architectures) == 1, "Only one architecture is supported for now"
    # Bug fix (same issue as get_mcore_forward_fn): membership of a string in
    # an Enum class raises TypeError on Python < 3.12; test the enum values.
    is_vlm = hf_config.architectures[0] in {vlm.value for vlm in SupportedVLM}
    if is_vlm:
        return fused_forward_model_gen(True)
    else:
        # default to language model
        return fused_forward_model_gen(False)
|
|
||||||
|
|
||||||
| # ruff: noqa | ||||||
|
|
||||||
| ######################################################## | ||||||
| # below is the deprecated code | ||||||
| ######################################################## | ||||||
|
|
||||||
| from .config_converter import ( | ||||||
| PretrainedConfig, | ||||||
| TransformerConfig, | ||||||
|
|
@@ -33,8 +81,6 @@ | |||||
| hf_to_mcore_config_qwen2moe, | ||||||
| hf_to_mcore_config_qwen3moe, | ||||||
| ) | ||||||
| from .model_forward import gptmodel_forward_no_padding, model_forward_gen | ||||||
| from .model_forward_fused import fused_forward_model_gen | ||||||
| from .model_initializer import ( | ||||||
| BaseModelInitializer, | ||||||
| DeepseekV3Model, | ||||||
|
|
@@ -239,33 +285,6 @@ def init_mcore_model( | |||||
| ) | ||||||
|
|
||||||
|
|
||||||
def get_mcore_forward_fn(hf_config: PretrainedConfig) -> Callable:
    """Get the forward function for given model architecture.

    NOTE: part of the deprecated registry-based code path (see the section
    header above in this file).
    """
    assert len(hf_config.architectures) == 1, "Only one architecture is supported for now"
    arch = hf_config.architectures[0]
    return MODEL_FORWARD_REGISTRY[get_supported_model(arch)]
|
|
||||||
|
|
||||||
def get_mcore_forward_no_padding_fn(hf_config: PretrainedConfig) -> Callable:
    """Get the no-padding forward function for given model architecture.

    NOTE: part of the deprecated registry-based code path (see the section
    header above in this file).
    """
    assert len(hf_config.architectures) == 1, "Only one architecture is supported for now"
    arch = hf_config.architectures[0]
    return MODEL_FORWARD_NOPAD_REGISTRY[get_supported_model(arch)]
|
|
||||||
|
|
||||||
def get_mcore_forward_fused_fn(hf_config: PretrainedConfig) -> Callable:
    """Get the fused forward function for given model architecture.

    NOTE: part of the deprecated registry-based code path (see the section
    header above in this file).
    """
    assert len(hf_config.architectures) == 1, "Only one architecture is supported for now"
    arch = hf_config.architectures[0]
    return MODEL_FORWARD_FUSED_REGISTRY[get_supported_model(arch)]
|
|
||||||
|
|
||||||
| def get_mcore_weight_converter(hf_config: PretrainedConfig, dtype: torch.dtype) -> Callable: | ||||||
| """ | ||||||
| Get the weight converter for given model architecture. | ||||||
|
|
||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should also set
`use_mbridge` to True? https://github.com/volcengine/verl/blob/main/verl/trainer/config/engine/megatron.yaml#L77-L78