[mcore] moonlight (small model with deepseekv3 arch) #1284
The first file in the diff is a new example script that downloads Moonlight-16B-A3B-Instruct, converts the HF checkpoint to mcore format, and launches Megatron-backend PPO training on GSM8K:

```bash
set -x

# If you are using vllm<=0.6.3, you might need to set the following environment variable to avoid bugs:
# export VLLM_ATTENTION_BACKEND=XFORMERS
export CUDA_DEVICE_MAX_CONNECTIONS=1 # For megatron communication/computation overlapping

# 0. download the model
huggingface-cli download moonshotai/Moonlight-16B-A3B-Instruct

# 1. convert the model to mcore format
# change the HF_MODEL_PATH and DIST_CKPT_PATH to your own path
HF_MODEL_PATH=/data/models/moonshotai/Moonlight-16B-A3B-Instruct
DIST_CKPT_PATH=/data/mcore_ckpt/Moonlight-16B-A3B-Instruct
python scripts/converter_hf_to_mcore.py --hf_model_path $HF_MODEL_PATH --output_path $DIST_CKPT_PATH

# 2. run the script
gsm8k_train_path=$HOME/data/gsm8k/train.parquet
gsm8k_test_path=$HOME/data/gsm8k/test.parquet
train_files=$gsm8k_train_path
test_files=$gsm8k_test_path

# Offload toggles cascade: ALL_OFFLOAD sets the COMMON_* defaults, which in turn
# set the per-role (actor/ref/critic/rm) defaults; each can be overridden individually.
ALL_OFFLOAD=${ALL_OFFLOAD:-False}
COMMON_PARAM_OFFLOAD=${COMMON_PARAM_OFFLOAD:-$ALL_OFFLOAD}
COMMON_GRAD_OFFLOAD=${COMMON_GRAD_OFFLOAD:-$ALL_OFFLOAD}
COMMON_OPTIMIZER_OFFLOAD=${COMMON_OPTIMIZER_OFFLOAD:-$ALL_OFFLOAD}

ACTOR_PARAM_OFFLOAD=${ACTOR_PARAM_OFFLOAD:-$COMMON_PARAM_OFFLOAD}
ACTOR_GRAD_OFFLOAD=${ACTOR_GRAD_OFFLOAD:-$COMMON_GRAD_OFFLOAD}
ACTOR_OPTIMIZER_OFFLOAD=${ACTOR_OPTIMIZER_OFFLOAD:-$COMMON_OPTIMIZER_OFFLOAD}
REF_PARAM_OFFLOAD=${REF_PARAM_OFFLOAD:-$COMMON_PARAM_OFFLOAD}
CRITIC_PARAM_OFFLOAD=${CRITIC_PARAM_OFFLOAD:-$COMMON_PARAM_OFFLOAD}
CRITIC_GRAD_OFFLOAD=${CRITIC_GRAD_OFFLOAD:-$COMMON_GRAD_OFFLOAD}
CRITIC_OPTIMIZER_OFFLOAD=${CRITIC_OPTIMIZER_OFFLOAD:-$COMMON_OPTIMIZER_OFFLOAD}
RM_PARAM_OFFLOAD=${RM_PARAM_OFFLOAD:-$COMMON_PARAM_OFFLOAD}

NODES=4
PP=2
TP=8
EP=8
ETP=1
VLLM_TP=4

# NOTE: $LLM below is expected to point at the HF checkpoint (e.g. $HF_MODEL_PATH above);
# it is not defined in the lines shown in this diff.
# RAY_ADDRESS='auto' ray job submit --working-dir . --
python3 -m verl.trainer.main_ppo --config-path=./config --config-name='ppo_megatron_trainer' \
algorithm.adv_estimator=gae \
data.train_files="$train_files" \
data.val_files="$test_files" \
data.train_batch_size=1024 \
data.max_prompt_length=1024 \
data.max_response_length=512 \
data.filter_overlong_prompts=True \
data.truncation='error' \
+data.trust_remote_code=True \
actor_rollout_ref.model.path=$LLM \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.actor.ppo_mini_batch_size=256 \
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \
actor_rollout_ref.actor.use_kl_loss=False \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=2 \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \
critic.optim.lr=1e-5 \
critic.model.path=$LLM \
critic.model.enable_gradient_checkpointing=False \
critic.ppo_micro_batch_size_per_gpu=4 \
algorithm.use_kl_in_reward=False \
trainer.critic_warmup=0 \
trainer.logger=['console','wandb'] \
trainer.project_name='verl_megatron_gsm8k_examples' \
trainer.experiment_name='moonlight_16b_a3b_instruct_1node' \
trainer.n_gpus_per_node=8 \
trainer.nnodes=$NODES \
trainer.save_freq=-1 \
trainer.test_freq=5 \
actor_rollout_ref.model.trust_remote_code=True \
critic.model.trust_remote_code=True \
+actor_rollout_ref.actor.megatron.override_transformer_config.num_layers_in_last_pipeline_stage=13 \
actor_rollout_ref.rollout.tensor_model_parallel_size=$VLLM_TP \
actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=$PP \
actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=$PP \
critic.megatron.pipeline_model_parallel_size=$PP \
actor_rollout_ref.actor.megatron.tensor_model_parallel_size=$TP \
actor_rollout_ref.ref.megatron.tensor_model_parallel_size=$TP \
critic.megatron.tensor_model_parallel_size=$TP \
actor_rollout_ref.actor.megatron.expert_model_parallel_size=$EP \
actor_rollout_ref.ref.megatron.expert_model_parallel_size=$EP \
critic.megatron.expert_model_parallel_size=$EP \
actor_rollout_ref.actor.megatron.expert_tensor_parallel_size=$ETP \
actor_rollout_ref.ref.megatron.expert_tensor_parallel_size=$ETP \
critic.megatron.expert_tensor_parallel_size=$ETP \
actor_rollout_ref.actor.megatron.param_offload=${ACTOR_PARAM_OFFLOAD} \
actor_rollout_ref.actor.megatron.optimizer_offload=${ACTOR_OPTIMIZER_OFFLOAD} \
actor_rollout_ref.actor.megatron.grad_offload=${ACTOR_GRAD_OFFLOAD} \
actor_rollout_ref.ref.megatron.param_offload=${REF_PARAM_OFFLOAD} \
critic.megatron.param_offload=${CRITIC_PARAM_OFFLOAD} \
critic.megatron.optimizer_offload=${CRITIC_OPTIMIZER_OFFLOAD} \
critic.megatron.grad_offload=${CRITIC_GRAD_OFFLOAD} \
actor_rollout_ref.actor.megatron.use_dist_checkpointing=True \
actor_rollout_ref.ref.megatron.use_dist_checkpointing=True \
critic.megatron.use_dist_checkpointing=True \
actor_rollout_ref.actor.megatron.dist_checkpointing_path=$DIST_CKPT_PATH \
actor_rollout_ref.ref.megatron.dist_checkpointing_path=$DIST_CKPT_PATH \
critic.megatron.dist_checkpointing_path=$DIST_CKPT_PATH \
trainer.val_before_train=False \
trainer.total_epochs=100 "$@"
```
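For reference, here is a back-of-the-envelope check of what those parallel sizes imply on 4 nodes of 8 GPUs. This is a hedged sketch using the usual Megatron decomposition (world_size = TP × PP × DP), not validation code from this PR:

```python
# Sanity-check the parallel sizes from the script above (illustrative only).
NODES, GPUS_PER_NODE = 4, 8
TP, PP, EP, ETP, VLLM_TP = 8, 2, 8, 1, 4

world_size = NODES * GPUS_PER_NODE   # 32 GPUs in total
assert world_size % (TP * PP) == 0
dp = world_size // (TP * PP)         # 32 / (8 * 2) = 2 data-parallel replicas
# Expert layers re-partition the TP x DP grid into ETP x EP x expert-DP groups
# (assumed constraint; Megatron versions differ in how they validate this).
assert (TP * dp) % (ETP * EP) == 0
assert world_size % VLLM_TP == 0     # rollout: 32 / 4 = 8 vLLM engine groups
print(f"world_size={world_size}, training dp={dp}, vllm groups={world_size // VLLM_TP}")
```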
The second file is the HF→mcore config converter. `_get_base_transformer_config` now returns a plain dict of kwargs instead of a constructed `TransformerConfig`, so each architecture-specific helper can decide which config class to instantiate (the DeepSeek-V3 path below needs `MLATransformerConfig`):

```diff
@@ -23,7 +23,7 @@
 from transformers import PretrainedConfig


-def _get_base_transformer_config(hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs) -> TransformerConfig:
+def _get_base_transformer_config(hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs) -> dict:
     """
     Create a base TransformerConfig with common parameters across different model architectures.
     TODO: (ycl) use dataclass or converter config?
```
Each existing converter now builds its config object from the returned kwargs:

```diff
@@ -82,19 +82,20 @@ def _get_base_transformer_config(hf_config: PretrainedConfig, dtype: torch.dtype
     base_config.update(override_transformer_config_kwargs)
     print(f"Overridden TF init config: {base_config}")

-    return TransformerConfig(**base_config)
+    return base_config


 def hf_to_mcore_config_dense(hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs) -> TransformerConfig:
     # for LlamaForCausalLM or Qwen2ForCausalLM
     qkv_bias = True if "Qwen2ForCausalLM" in hf_config.architectures else getattr(hf_config, "attention_bias", False)
     qk_layernorm = True if "Qwen3ForCausalLM" in hf_config.architectures else False

-    return _get_base_transformer_config(hf_config=hf_config, dtype=dtype, use_cpu_initialization=False, add_bias_linear=False, add_qkv_bias=qkv_bias, qk_layernorm=qk_layernorm, **override_transformer_config_kwargs)
+    args = _get_base_transformer_config(hf_config=hf_config, dtype=dtype, use_cpu_initialization=False, add_bias_linear=False, add_qkv_bias=qkv_bias, qk_layernorm=qk_layernorm, **override_transformer_config_kwargs)
+    return TransformerConfig(**args)


 def hf_to_mcore_config_qwen2moe(hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs) -> TransformerConfig:
-    return _get_base_transformer_config(
+    args = _get_base_transformer_config(
         hf_config=hf_config,
         dtype=dtype,
         use_cpu_initialization=False,
@@ -121,10 +122,11 @@ def hf_to_mcore_config_qwen2moe(hf_config: PretrainedConfig, dtype: torch.dtype,
         add_qkv_bias=True,
         **override_transformer_config_kwargs,
     )
+    return TransformerConfig(**args)


 def hf_to_mcore_config_mixtral(hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs) -> TransformerConfig:
-    return _get_base_transformer_config(
+    args = _get_base_transformer_config(
         hf_config=hf_config,
         dtype=dtype,
         use_cpu_initialization=False,
@@ -150,10 +152,11 @@ def hf_to_mcore_config_mixtral(hf_config: PretrainedConfig, dtype: torch.dtype,
         bias_dropout_fusion=True,
         **override_transformer_config_kwargs,
     )
+    return TransformerConfig(**args)


 def hf_to_mcore_config_qwen3moe(hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs) -> TransformerConfig:
-    return _get_base_transformer_config(
+    args = _get_base_transformer_config(
         hf_config=hf_config,
         dtype=dtype,
         use_cpu_initialization=False,
```
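The reason for the dict refactor: `MLATransformerConfig` carries MLA-only fields that the base `TransformerConfig` does not accept, so the shared helper can no longer construct the object itself. A minimal runnable sketch of the pattern, with toy dataclasses standing in for the megatron.core classes (names and values here are illustrative, not the PR's code):

```python
from dataclasses import dataclass

@dataclass
class ToyTransformerConfig:          # stands in for mcore's TransformerConfig
    num_layers: int
    hidden_size: int
    qk_layernorm: bool = False

@dataclass
class ToyMLATransformerConfig(ToyTransformerConfig):  # stands in for MLATransformerConfig
    q_lora_rank: int = 0             # MLA-only fields the base class has no slot for
    kv_lora_rank: int = 0

def get_base_args(**overrides) -> dict:
    """Shared helper builds kwargs only; the caller picks the config class."""
    base = {"num_layers": 27, "hidden_size": 2048}
    base.update(overrides)           # mirrors base_config.update(override_transformer_config_kwargs)
    return base

dense = ToyTransformerConfig(**get_base_args(qk_layernorm=True))
mla = ToyMLATransformerConfig(**get_base_args(q_lora_rank=1536, kv_lora_rank=512))
print(dense)
print(mla)
```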
The main addition replaces the old `NotImplementedError` stub for DeepSeek-V3 (the architecture Moonlight uses) with a full MLA/MoE config, patching mcore 0.12 via `apply_patch` first:

```diff
@@ -178,11 +181,87 @@ def hf_to_mcore_config_qwen3moe(hf_config: PretrainedConfig, dtype: torch.dtype,
         qk_layernorm=True,
         **override_transformer_config_kwargs,
     )
+    return TransformerConfig(**args)


 def hf_to_mcore_config_dpskv3(hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs) -> MLATransformerConfig:
     # DeepseekV3ForCausalLM
-    raise NotImplementedError("DeepseekV3ForCausalLM is not supported yet")
+    from megatron.core.transformer.enums import AttnBackend
+
+    from .patch_v012 import apply_patch
+
+    apply_patch()
+
+    mla_rope_config = {
+        "beta_fast": 32,
+        "beta_slow": 1,
+        "factor": 1,
+        "mscale": 1.0,
+        "mscale_all_dim": 1.0,
+        "original_max_position_embeddings": 4096,
+        "type": "rope",
+    }
+    if "rope_scaling" in hf_config and hf_config.rope_scaling is not None:
+        mla_rope_config.update(hf_config.rope_scaling)
+    moe_layer_freq = [1] * hf_config.num_hidden_layers
+    for i in range(hf_config.first_k_dense_replace):
+        moe_layer_freq[i] = 0
+
+    args = _get_base_transformer_config(
+        hf_config=hf_config,
+        dtype=dtype,
+        use_cpu_initialization=False,
+        add_bias_linear=False,
+        attention_backend=AttnBackend.fused,
+        bf16=dtype is torch.bfloat16,
+        layernorm_epsilon=hf_config.rms_norm_eps,
+        ffn_hidden_size=hf_config.intermediate_size,
+        qk_layernorm=True,
+        # moe specific
+        moe_ffn_hidden_size=hf_config.moe_intermediate_size,
+        moe_token_dispatcher_type="alltoall",
+        moe_router_bias_update_rate=0.001,
+        moe_router_enable_expert_bias=True,
+        moe_router_topk=hf_config.num_experts_per_tok,
+        num_moe_experts=hf_config.n_routed_experts,
+        moe_shared_expert_intermediate_size=hf_config.moe_intermediate_size * hf_config.n_shared_experts,
+        moe_aux_loss_coeff=getattr(hf_config, "aux_loss_alpha", 0.001),
+        moe_router_load_balancing_type="seq_aux_loss",
+        moe_shared_expert_overlap=True,
+        # moe_permute_fusion=True, # need TE 2.1+
+        moe_grouped_gemm=True,
+        moe_router_score_function="sigmoid",
+        moe_router_pre_softmax=True,
+        moe_router_topk_scaling_factor=hf_config.routed_scaling_factor,
+        moe_layer_freq=moe_layer_freq,
+        # MLA
+        q_lora_rank=hf_config.q_lora_rank,
+        kv_lora_rank=hf_config.kv_lora_rank,
+        qk_head_dim=hf_config.qk_nope_head_dim,
+        qk_pos_emb_head_dim=hf_config.qk_rope_head_dim,
+        v_head_dim=hf_config.v_head_dim,
+        rotary_base=hf_config.rope_theta,
+        rotary_scaling_factor=mla_rope_config["factor"],
+        rope_type=mla_rope_config["type"],
+        mscale=mla_rope_config["mscale"],
+        mscale_all_dim=mla_rope_config["mscale_all_dim"],
+        max_position_embeddings=mla_rope_config["original_max_position_embeddings"],
+        beta_fast=mla_rope_config["beta_fast"],
+        beta_slow=mla_rope_config["beta_slow"],
+        # mcore 0.12 moe
+        moe_router_dtype="fp64",
+        disable_bf16_reduced_precision_matmul=True,
+        # other
+        # deallocate_pipeline_outputs=True,
+        # gradient_accumulation_fusion=True,
+        persist_layer_norm=True,
+        bias_activation_fusion=True,
+        bias_dropout_fusion=True,
+        **override_transformer_config_kwargs,
+    )
+    transformer_config = MLATransformerConfig(**args)
+
+    return transformer_config


 def hf_to_mcore_config_qwen2_5_vl(hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs) -> TransformerConfig:
```
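`moe_layer_freq` marks which layers use an MoE FFN (1) versus a dense FFN (0). Below is a small runnable sketch of what the loop produces for a Moonlight-like config; 27 layers with `first_k_dense_replace=1` is assumed from the model card rather than this diff, and is consistent with the script's `num_layers_in_last_pipeline_stage=13` override (27 = 14 + 13 across the two pipeline stages):

```python
# Reproduce the moe_layer_freq construction with assumed Moonlight-like values.
num_hidden_layers = 27      # assumed from the Moonlight model card
first_k_dense_replace = 1   # the first layer keeps a dense FFN

moe_layer_freq = [1] * num_hidden_layers
for i in range(first_k_dense_replace):
    moe_layer_freq[i] = 0

print(moe_layer_freq)  # [0, 1, 1, ..., 1] -- layer 0 dense, layers 1-26 MoE
```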
Review discussion on the `trust_remote_code` settings:

- "should `trust_remote_code` be set in the `model` section per ppo_megatron_trainer.yaml?"
- "oh, data preprocessing might need this as well. please ignore this if i misunderstand."
- "is this another topic beyond supporting moonlight? would it be better if we commit another small PR for the config file modification?"
- "agree. we should track any change in the config and keep it consistent."