Merged

32 commits
6155a65
use mcore config_converter and model_initializer for more types of models
ISEEKYAN Apr 13, 2025
8869168
remove megatron_config from actor/critic
ISEEKYAN Apr 13, 2025
9216811
reward model use gptmodel api, clean megatron_worker
ISEEKYAN Apr 13, 2025
6c46c2a
mcore model_forward for registry
ISEEKYAN Apr 13, 2025
a9c21cf
Merge branch 'main' into mcore_refactor
ISEEKYAN Apr 14, 2025
e709dc3
(WIP) support qwen2moe
ISEEKYAN Apr 16, 2025
0775d36
qwen2moe config converter and weight converter
ISEEKYAN Apr 17, 2025
6113b10
add scripts to run qwen1.5moe_a2.7b
ISEEKYAN Apr 17, 2025
bbf41b6
Merge branch 'main' into mcore_qwen2moe
ISEEKYAN Apr 17, 2025
5f8d8a0
format
ISEEKYAN Apr 18, 2025
d2376ec
update scripts
ISEEKYAN Apr 18, 2025
39e0658
Merge branch 'main' into mcore_qwen2moe
ISEEKYAN Apr 18, 2025
57d9671
fix for pre-commit
ISEEKYAN Apr 18, 2025
5181a99
Merge branch 'main' into mcore_qwen2moe
ISEEKYAN Apr 19, 2025
7b66d82
fix bug of merge
ISEEKYAN Apr 19, 2025
941ab95
compatible with mcore 0.12
ISEEKYAN Apr 19, 2025
267a119
WIP support moonlight
ISEEKYAN Apr 21, 2025
8801841
fix
ISEEKYAN Apr 28, 2025
e5d6ca0
typo
ISEEKYAN Apr 28, 2025
7f84424
Merge branch 'main' into mcore_moonlight
ISEEKYAN Apr 28, 2025
ae550a8
add scripts
ISEEKYAN Apr 28, 2025
4c1be5f
Merge branch 'main' into mcore_moonlight
ISEEKYAN May 23, 2025
eedee64
fix ckpt converter
ISEEKYAN May 23, 2025
6f99304
fix bug
ISEEKYAN May 23, 2025
39ce67b
succeed in running moonlight
ISEEKYAN May 23, 2025
70c1201
Merge branch 'main' into mcore_moonlight
ISEEKYAN May 24, 2025
579e831
add `trust_remote_code` in config
ISEEKYAN May 26, 2025
3b82afc
update deepseekv3 initializer
ISEEKYAN May 26, 2025
e0d43ff
fix initializer
ISEEKYAN May 26, 2025
8757327
adjust for review
ISEEKYAN May 26, 2025
b73f16a
fix for sglang
ISEEKYAN May 27, 2025
81f474b
add option for trust_remote_code in converter
ISEEKYAN May 28, 2025
109 changes: 109 additions & 0 deletions examples/ppo_trainer/run_moonlight16b_a3b_gsm8k_megatron.sh
@@ -0,0 +1,109 @@
set -x

# If you are using vllm<=0.6.3, you might need to set the following environment variable to avoid bugs:
# export VLLM_ATTENTION_BACKEND=XFORMERS
export CUDA_DEVICE_MAX_CONNECTIONS=1 # For megatron communication/computation overlapping


# 0. download the model
huggingface-cli download moonshotai/Moonlight-16B-A3B-Instruct

# 1. convert the model to mcore format
# change the HF_MODEL_PATH and DIST_CKPT_PATH to your own path
HF_MODEL_PATH=/data/models/moonshotai/Moonlight-16B-A3B-Instruct
DIST_CKPT_PATH=/data/mcore_ckpt/Moonlight-16B-A3B-Instruct
python scripts/converter_hf_to_mcore.py --hf_model_path $HF_MODEL_PATH --output_path $DIST_CKPT_PATH
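# NOTE (hedged): checkpoints that ship custom modeling code, such as Moonlight,
# may also need the converter's new --trust_remote_code flag, e.g.:
# python scripts/converter_hf_to_mcore.py --hf_model_path $HF_MODEL_PATH --output_path $DIST_CKPT_PATH --trust_remote_code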


# 2. run the script
gsm8k_train_path=$HOME/data/gsm8k/train.parquet
gsm8k_test_path=$HOME/data/gsm8k/test.parquet
train_files=$gsm8k_train_path
test_files=$gsm8k_test_path

ALL_OFFLOAD=${ALL_OFFLOAD:-False}
COMMON_PARAM_OFFLOAD=${COMMON_PARAM_OFFLOAD:-$ALL_OFFLOAD}
COMMON_GRAD_OFFLOAD=${COMMON_GRAD_OFFLOAD:-$ALL_OFFLOAD}
COMMON_OPTIMIZER_OFFLOAD=${COMMON_OPTIMIZER_OFFLOAD:-$ALL_OFFLOAD}

ACTOR_PARAM_OFFLOAD=${ACTOR_PARAM_OFFLOAD:-$COMMON_PARAM_OFFLOAD}
ACTOR_GRAD_OFFLOAD=${ACTOR_GRAD_OFFLOAD:-$COMMON_GRAD_OFFLOAD}
ACTOR_OPTIMIZER_OFFLOAD=${ACTOR_OPTIMIZER_OFFLOAD:-$COMMON_OPTIMIZER_OFFLOAD}
REF_PARAM_OFFLOAD=${REF_PARAM_OFFLOAD:-$COMMON_PARAM_OFFLOAD}
CRITIC_PARAM_OFFLOAD=${CRITIC_PARAM_OFFLOAD:-$COMMON_PARAM_OFFLOAD}
CRITIC_GRAD_OFFLOAD=${CRITIC_GRAD_OFFLOAD:-$COMMON_GRAD_OFFLOAD}
CRITIC_OPTIMIZER_OFFLOAD=${CRITIC_OPTIMIZER_OFFLOAD:-$COMMON_OPTIMIZER_OFFLOAD}
RM_PARAM_OFFLOAD=${RM_PARAM_OFFLOAD:-$COMMON_PARAM_OFFLOAD}


NODES=4
PP=2
TP=8
EP=8
ETP=1
VLLM_TP=4

# RAY_ADDRESS='auto' ray job submit --working-dir . --
python3 -m verl.trainer.main_ppo --config-path=./config --config-name='ppo_megatron_trainer' \
algorithm.adv_estimator=gae \
data.train_files="$train_files" \
data.val_files="$test_files" \
data.train_batch_size=1024 \
data.max_prompt_length=1024 \
data.max_response_length=512 \
data.filter_overlong_prompts=True \
data.truncation='error' \
+data.trust_remote_code=True \
Collaborator: should trust_remote_code be set in the model per ppo_megatron_trainer.yaml?

Collaborator: oh data preprocessing might need this as well. please ignore this if i misunderstand.

Collaborator (author): is this another topic beyond supporting moonlight? would it be better if we commit another small PR for the config file modification?

Collaborator: agree. we should track any change in the config and keep it consistent.

actor_rollout_ref.model.path=$LLM \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.actor.ppo_mini_batch_size=256 \
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \
actor_rollout_ref.actor.use_kl_loss=False \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=2 \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \
critic.optim.lr=1e-5 \
critic.model.path=$LLM \
critic.model.enable_gradient_checkpointing=False \
critic.ppo_micro_batch_size_per_gpu=4 \
algorithm.use_kl_in_reward=False \
trainer.critic_warmup=0 \
trainer.logger=['console','wandb'] \
trainer.project_name='verl_megatron_gsm8k_examples' \
trainer.experiment_name='moonlight_16b_a3b_instruct_1node' \
trainer.n_gpus_per_node=8 \
trainer.nnodes=$NODES \
trainer.save_freq=-1 \
trainer.test_freq=5 \
actor_rollout_ref.model.trust_remote_code=True \
critic.model.trust_remote_code=True \
+actor_rollout_ref.actor.megatron.override_transformer_config.num_layers_in_last_pipeline_stage=13 \
actor_rollout_ref.rollout.tensor_model_parallel_size=$VLLM_TP \
actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=$PP \
actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=$PP \
critic.megatron.pipeline_model_parallel_size=$PP \
actor_rollout_ref.actor.megatron.tensor_model_parallel_size=$TP \
actor_rollout_ref.ref.megatron.tensor_model_parallel_size=$TP \
critic.megatron.tensor_model_parallel_size=$TP \
actor_rollout_ref.actor.megatron.expert_model_parallel_size=$EP \
actor_rollout_ref.ref.megatron.expert_model_parallel_size=$EP \
critic.megatron.expert_model_parallel_size=$EP \
actor_rollout_ref.actor.megatron.expert_tensor_parallel_size=$ETP \
actor_rollout_ref.ref.megatron.expert_tensor_parallel_size=$ETP \
critic.megatron.expert_tensor_parallel_size=$ETP \
actor_rollout_ref.actor.megatron.param_offload=${ACTOR_PARAM_OFFLOAD} \
actor_rollout_ref.actor.megatron.optimizer_offload=${ACTOR_OPTIMIZER_OFFLOAD} \
actor_rollout_ref.actor.megatron.grad_offload=${ACTOR_GRAD_OFFLOAD} \
actor_rollout_ref.ref.megatron.param_offload=${REF_PARAM_OFFLOAD} \
critic.megatron.param_offload=${CRITIC_PARAM_OFFLOAD} \
critic.megatron.optimizer_offload=${CRITIC_OPTIMIZER_OFFLOAD} \
critic.megatron.grad_offload=${CRITIC_GRAD_OFFLOAD} \
actor_rollout_ref.actor.megatron.use_dist_checkpointing=True \
actor_rollout_ref.ref.megatron.use_dist_checkpointing=True \
critic.megatron.use_dist_checkpointing=True \
actor_rollout_ref.actor.megatron.dist_checkpointing_path=$DIST_CKPT_PATH \
actor_rollout_ref.ref.megatron.dist_checkpointing_path=$DIST_CKPT_PATH \
critic.megatron.dist_checkpointing_path=$DIST_CKPT_PATH \
trainer.val_before_train=False \
trainer.total_epochs=100 "$@"
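
# Hedged usage sketch: extra Hydra overrides are forwarded via "$@" and the offload
# knobs default through ALL_OFFLOAD, so a one-off run might look like:
#   ALL_OFFLOAD=True bash examples/ppo_trainer/run_moonlight16b_a3b_gsm8k_megatron.sh trainer.total_epochs=1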

76 changes: 72 additions & 4 deletions scripts/converter_hf_to_mcore.py
@@ -35,6 +35,7 @@ def _init_args():
parser.add_argument("--output_path", type=str, required=True, help="The path for the output mcore model")
parser.add_argument("--use_cpu_initialization", action="store_true", help="Whether to use cpu initialization")
parser.add_argument("--test", action="store_true", help="Whether to test the conversion")
parser.add_argument("--trust_remote_code", action="store_true", help="Whether to trust remote code")
args = parser.parse_args()
return args

@@ -120,7 +121,7 @@ def convert_checkpoint_from_transformers_to_megatron(hf_model, model, hf_config)
v_bias = hf_layer.self_attn.v_proj.bias.view([num_key_value_heads, -1])
qkv_bias = torch.cat([q_bias, k_bias, v_bias], dim=1).view(-1).contiguous()
layer.self_attention.linear_qkv.bias.copy_(qkv_bias)

if hasattr(hf_layer.self_attn, "q_norm"):
layer.self_attention.q_layernorm.weight.copy_(hf_layer.self_attn.q_norm.weight.data)
layer.self_attention.k_layernorm.weight.copy_(hf_layer.self_attn.k_norm.weight.data)
@@ -145,7 +146,72 @@ def convert_checkpoint_from_transformers_to_megatron(hf_model, model, hf_config)
model.output_layer.weight.copy_(hf_model.lm_head.weight)


def convert_hf_to_mcore(hf_model_path, output_path, use_cpu_initialization=False, test=False):
@torch.no_grad()
def convert_checkpoint_from_transformers_to_megatron_dpskv3(hf_model, model, hf_config, tfconfig):
warnings.warn("MTP model is not supported yet", stacklevel=2)

def safe_copy(
src_tensor: torch.Tensor,
dst_tensor: torch.Tensor,
skip_dtype_assert: bool = False,
):
if not skip_dtype_assert:
if src_tensor.dtype != dst_tensor.dtype:
raise ValueError(f"Get source dtype {src_tensor.dtype}, but target dtype {dst_tensor.dtype}")
assert src_tensor.shape == dst_tensor.shape
dst_tensor.data.copy_(src_tensor.data)
return src_tensor.numel()
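# Hypothetical usage note (not in the diff): the strict dtype/shape checks catch
# silent casts, and the returned numel lets a caller tally copied parameters,
# e.g. total_copied += safe_copy(hf_tensor, mcore_tensor).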

model.embedding.word_embeddings.weight.copy_(hf_model.model.embed_tokens.weight)
for layer_idx, (layer, hf_layer) in enumerate(zip(model.decoder.layers, hf_model.model.layers)):
print(layer_idx)
layer.input_layernorm.weight.copy_(hf_layer.input_layernorm.weight)

if hf_config.q_lora_rank is None:
layer.self_attention.linear_q_proj.weight.copy_(hf_layer.self_attn.q_proj.weight)
else:
layer.self_attention.linear_q_down_proj.weight.copy_(hf_layer.self_attn.q_a_proj.weight)
layer.self_attention.linear_q_up_proj.weight.copy_(hf_layer.self_attn.q_b_proj.weight)
layer.self_attention.linear_q_up_proj.layer_norm_weight.copy_(hf_layer.self_attn.q_a_layernorm.weight)

layer.self_attention.linear_kv_down_proj.weight.copy_(hf_layer.self_attn.kv_a_proj_with_mqa.weight)
layer.self_attention.linear_kv_up_proj.weight.copy_(hf_layer.self_attn.kv_b_proj.weight)
layer.self_attention.linear_kv_up_proj.layer_norm_weight.copy_(hf_layer.self_attn.kv_a_layernorm.weight)
layer.self_attention.linear_proj.weight.copy_(hf_layer.self_attn.o_proj.weight)

if not hasattr(layer.mlp, "router"):
layer.mlp.linear_fc1.layer_norm_weight.copy_(hf_layer.post_attention_layernorm.weight)
layer.mlp.linear_fc1.weight.copy_(torch.cat([hf_layer.mlp.gate_proj.weight, hf_layer.mlp.up_proj.weight]))
layer.mlp.linear_fc2.weight.copy_(hf_layer.mlp.down_proj.weight)
else:
layer.mlp.router.weight.copy_(hf_layer.mlp.gate.weight)
# NOTE: e_score_correction_bias in the mcore model is initialized in bfloat16 and
# recovered to fp32 during the first forward pass, so a small diff (~0.3%) between the two models is expected.
safe_copy(hf_layer.mlp.gate.e_score_correction_bias, layer.mlp.router.expert_bias, skip_dtype_assert=True)
if tfconfig.moe_grouped_gemm:
for i, hf_expert in enumerate(hf_layer.mlp.experts):
fc1_weight = torch.cat([hf_expert.gate_proj.weight, hf_expert.up_proj.weight])
linear_fc1_weighti = getattr(layer.mlp.experts.linear_fc1, "weight" + str(i))
linear_fc1_weighti.copy_(fc1_weight)
linear_fc2_weighti = getattr(layer.mlp.experts.linear_fc2, "weight" + str(i))
linear_fc2_weighti.copy_(hf_expert.down_proj.weight)
else:
for i, hf_expert in enumerate(hf_layer.mlp.experts):
expert = layer.mlp.experts.local_experts[i]
fc1_weight = torch.cat([hf_expert.gate_proj.weight, hf_expert.up_proj.weight])
expert.linear_fc1.weight.copy_(fc1_weight)
expert.linear_fc2.weight.copy_(hf_expert.down_proj.weight)
layer.pre_mlp_layernorm.weight.copy_(hf_layer.post_attention_layernorm.weight)
shared_fc1_weight = torch.cat([hf_layer.mlp.shared_experts.gate_proj.weight, hf_layer.mlp.shared_experts.up_proj.weight])
layer.mlp.shared_experts.linear_fc1.weight.copy_(shared_fc1_weight)
layer.mlp.shared_experts.linear_fc2.weight.copy_(hf_layer.mlp.shared_experts.down_proj.weight)

model.decoder.final_layernorm.weight.copy_(hf_model.model.norm.weight)
if not hf_config.tie_word_embeddings:
model.output_layer.weight.copy_(hf_model.lm_head.weight)


def convert_hf_to_mcore(hf_model_path, output_path, use_cpu_initialization=False, test=False, trust_remote_code=False):
os.makedirs(output_path, exist_ok=True)
if len(os.listdir(output_path)) > 0 and not test:
print(f"Output path {output_path} is not empty, skipping conversion")
@@ -200,12 +266,14 @@ def megatron_model_provider(pre_process, post_process):
warnings.simplefilter("ignore")

# init hf model
hf_model = AutoModelForCausalLM.from_pretrained(hf_model_path, torch_dtype=torch.bfloat16)
hf_model = AutoModelForCausalLM.from_pretrained(hf_model_path, torch_dtype=torch.bfloat16, trust_remote_code=trust_remote_code)
hf_state_dict = hf_model.state_dict()

# load hf state dict to megatron model
if "Qwen2MoeForCausalLM" in hf_config.architectures:
convert_checkpoint_from_transformers_to_megatron(hf_model, model[0].module, hf_config)
elif "DeepseekV3ForCausalLM" in hf_config.architectures:
convert_checkpoint_from_transformers_to_megatron_dpskv3(hf_model, model[0].module, hf_config, tfconfig=tfconfig)
elif "Qwen3MoeForCausalLM" in hf_config.architectures:
convert_checkpoint_from_transformers_to_megatron(hf_model, model[0].module, hf_config)
else:
@@ -232,4 +300,4 @@ def megatron_model_provider(pre_process, post_process):

if __name__ == "__main__":
args = _init_args()
convert_hf_to_mcore(args.hf_model_path, args.output_path, args.use_cpu_initialization, args.test)
convert_hf_to_mcore(args.hf_model_path, args.output_path, args.use_cpu_initialization, args.test, args.trust_remote_code)
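
A note on the review thread above: in this diff, trust_remote_code travels from the new CLI flag through convert_hf_to_mcore into from_pretrained. A minimal sketch of just that plumbing (simplified from the diff; the real function also builds, converts, and saves the mcore model):

import torch
from transformers import AutoModelForCausalLM

def load_hf_model(hf_model_path: str, trust_remote_code: bool = False):
    # trust_remote_code must reach from_pretrained, or checkpoints that ship
    # custom modeling code (e.g. Moonlight) fail to load.
    return AutoModelForCausalLM.from_pretrained(
        hf_model_path,
        torch_dtype=torch.bfloat16,
        trust_remote_code=trust_remote_code,
    )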
93 changes: 86 additions & 7 deletions verl/models/mcore/config_converter.py
@@ -23,7 +23,7 @@
from transformers import PretrainedConfig


def _get_base_transformer_config(hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs) -> TransformerConfig:
def _get_base_transformer_config(hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs) -> dict:
"""
Create a base TransformerConfig with common parameters across different model architectures.
TODO: (ycl) use dataclass or converter config?
@@ -82,19 +82,20 @@ def _get_base_transformer_config(hf_config: PretrainedConfig, dtype: torch.dtype
base_config.update(override_transformer_config_kwargs)
print(f"Overridden TF init config: {base_config}")

return TransformerConfig(**base_config)
return base_config
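# NOTE: _get_base_transformer_config now returns a plain dict rather than a
# TransformerConfig, so each architecture-specific converter below can append
# its own kwargs before constructing TransformerConfig / MLATransformerConfig.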


def hf_to_mcore_config_dense(hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs) -> TransformerConfig:
# for LlamaForCausalLM or Qwen2ForCausalLM
qkv_bias = True if "Qwen2ForCausalLM" in hf_config.architectures else getattr(hf_config, "attention_bias", False)
qk_layernorm = True if "Qwen3ForCausalLM" in hf_config.architectures else False

return _get_base_transformer_config(hf_config=hf_config, dtype=dtype, use_cpu_initialization=False, add_bias_linear=False, add_qkv_bias=qkv_bias, qk_layernorm=qk_layernorm, **override_transformer_config_kwargs)
args = _get_base_transformer_config(hf_config=hf_config, dtype=dtype, use_cpu_initialization=False, add_bias_linear=False, add_qkv_bias=qkv_bias, qk_layernorm=qk_layernorm, **override_transformer_config_kwargs)
return TransformerConfig(**args)


def hf_to_mcore_config_qwen2moe(hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs) -> TransformerConfig:
return _get_base_transformer_config(
args = _get_base_transformer_config(
hf_config=hf_config,
dtype=dtype,
use_cpu_initialization=False,
@@ -121,10 +122,11 @@ def hf_to_mcore_config_qwen2moe(hf_config: PretrainedConfig, dtype: torch.dtype,
add_qkv_bias=True,
**override_transformer_config_kwargs,
)
return TransformerConfig(**args)


def hf_to_mcore_config_mixtral(hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs) -> TransformerConfig:
return _get_base_transformer_config(
args = _get_base_transformer_config(
hf_config=hf_config,
dtype=dtype,
use_cpu_initialization=False,
@@ -150,10 +152,11 @@ def hf_to_mcore_config_mixtral(hf_config: PretrainedConfig, dtype: torch.dtype,
bias_dropout_fusion=True,
**override_transformer_config_kwargs,
)
return TransformerConfig(**args)


def hf_to_mcore_config_qwen3moe(hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs) -> TransformerConfig:
return _get_base_transformer_config(
args = _get_base_transformer_config(
hf_config=hf_config,
dtype=dtype,
use_cpu_initialization=False,
@@ -178,11 +181,87 @@ def hf_to_mcore_config_qwen3moe(hf_config: PretrainedConfig, dtype: torch.dtype,
qk_layernorm=True,
**override_transformer_config_kwargs,
)
return TransformerConfig(**args)


def hf_to_mcore_config_dpskv3(hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs) -> MLATransformerConfig:
# DeepseekV3ForCausalLM
raise NotImplementedError("DeepseekV3ForCausalLM is not supported yet")
from megatron.core.transformer.enums import AttnBackend

from .patch_v012 import apply_patch

apply_patch()

mla_rope_config = {
"beta_fast": 32,
"beta_slow": 1,
"factor": 1,
"mscale": 1.0,
"mscale_all_dim": 1.0,
"original_max_position_embeddings": 4096,
"type": "rope",
}
if "rope_scaling" in hf_config and hf_config.rope_scaling is not None:
mla_rope_config.update(hf_config.rope_scaling)
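# The first `first_k_dense_replace` layers stay dense (moe_layer_freq=0); the rest are MoE.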
moe_layer_freq = [1] * hf_config.num_hidden_layers
for i in range(hf_config.first_k_dense_replace):
moe_layer_freq[i] = 0

args = _get_base_transformer_config(
hf_config=hf_config,
dtype=dtype,
use_cpu_initialization=False,
add_bias_linear=False,
attention_backend=AttnBackend.fused,
Collaborator (@ccclyu, May 28, 2025): is AttnBackend.fused specific to the deepseek v3 model? is AttnBackend.auto enough here?

Collaborator (author): When fed AttnBackend.auto, TE would pick flash, but flash is not implemented for MLA; the error is:
ValueError: No dot product attention backend is available for the provided inputs. Please run with NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=2 to find out the reasons for disabling all backends.

bf16=dtype is torch.bfloat16,
layernorm_epsilon=hf_config.rms_norm_eps,
ffn_hidden_size=hf_config.intermediate_size,
qk_layernorm=True,
# moe specific
moe_ffn_hidden_size=hf_config.moe_intermediate_size,
moe_token_dispatcher_type="alltoall",
moe_router_bias_update_rate=0.001,
moe_router_enable_expert_bias=True,
moe_router_topk=hf_config.num_experts_per_tok,
num_moe_experts=hf_config.n_routed_experts,
moe_shared_expert_intermediate_size=hf_config.moe_intermediate_size * hf_config.n_shared_experts,
moe_aux_loss_coeff=getattr(hf_config, "aux_loss_alpha", 0.001),
moe_router_load_balancing_type="seq_aux_loss",
moe_shared_expert_overlap=True,
# moe_permute_fusion=True, # need TE 2.1+
moe_grouped_gemm=True,
moe_router_score_function="sigmoid",
moe_router_pre_softmax=True,
moe_router_topk_scaling_factor=hf_config.routed_scaling_factor,
moe_layer_freq=moe_layer_freq,
# MLA
q_lora_rank=hf_config.q_lora_rank,
kv_lora_rank=hf_config.kv_lora_rank,
qk_head_dim=hf_config.qk_nope_head_dim,
qk_pos_emb_head_dim=hf_config.qk_rope_head_dim,
v_head_dim=hf_config.v_head_dim,
rotary_base=hf_config.rope_theta,
rotary_scaling_factor=mla_rope_config["factor"],
rope_type=mla_rope_config["type"],
mscale=mla_rope_config["mscale"],
mscale_all_dim=mla_rope_config["mscale_all_dim"],
max_position_embeddings=mla_rope_config["original_max_position_embeddings"],
beta_fast=mla_rope_config["beta_fast"],
beta_slow=mla_rope_config["beta_slow"],
# mcore 0.12 moe
moe_router_dtype="fp64",
disable_bf16_reduced_precision_matmul=True,
# other
# deallocate_pipeline_outputs=True,
# gradient_accumulation_fusion=True,
persist_layer_norm=True,
bias_activation_fusion=True,
bias_dropout_fusion=True,
**override_transformer_config_kwargs,
)
transformer_config = MLATransformerConfig(**args)

return transformer_config
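# Hedged usage sketch (not part of the diff; path and dtype are illustrative):
#   from transformers import AutoConfig
#   hf_config = AutoConfig.from_pretrained(HF_MODEL_PATH, trust_remote_code=True)
#   tfconfig = hf_to_mcore_config_dpskv3(hf_config, torch.bfloat16)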


def hf_to_mcore_config_qwen2_5_vl(hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs) -> TransformerConfig: