Mtp/pp for v3 #9932

Open · wants to merge 11 commits into base: develop
12 changes: 9 additions & 3 deletions llm/auto_parallel/deepseek-v3/run_pretrain_auto.py
@@ -39,14 +39,13 @@
AutoTokenizer,
CosineAnnealingWithWarmupDecay,
DeepseekV2Config,
- DeepseekV2PretrainingCriterion,
DeepseekV3ForCausalLMAuto,
LinearAnnealingWithWarmupDecay,
)
from paddlenlp.utils.log import logger

MODEL_CLASSES = {
"deepseekv3_auto": (DeepseekV2Config, DeepseekV3ForCausalLMAuto, DeepseekV2PretrainingCriterion),
"deepseekv3_auto": (DeepseekV2Config, DeepseekV3ForCausalLMAuto, None),
}


@@ -237,6 +236,10 @@ class ModelArguments:
default=False,
metadata={"help": "recompute_use_reentrant"},
)
+ first_k_dense_replace: int = field(
+     default=3,
+     metadata={"help": "first_k_dense_replace"},
+ )


def create_pretrained_dataset(
@@ -532,6 +535,7 @@ def main():
config.no_recompute_layers = model_args.no_recompute_layers
config.pp_recompute_interval = model_args.pp_recompute_interval
config.recompute_use_reentrant = model_args.recompute_use_reentrant
+ config.first_k_dense_replace = model_args.first_k_dense_replace

config.use_recompute = training_args.recompute
config.tensor_parallel_degree = training_args.tensor_parallel_degree
@@ -555,7 +559,9 @@

with paddle.LazyGuard():
model = model_class.from_config(config, dtype="float32")
- criterion = criterion_class(config)
+ criterion = None
+ if criterion_class is not None:
+     criterion = criterion_class(config)

if training_args.recompute:

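With the `MODEL_CLASSES` entry for `deepseekv3_auto` now mapping to a `None` criterion, the script only instantiates a criterion when one is registered. A minimal sketch of the resulting pattern (illustrative only, not the PaddleNLP trainer code; the assumption is that, with no external criterion, the loss comes from the model's own forward pass, e.g. a loss that already folds in the MTP term):

```python
# Illustrative sketch, not PaddleNLP code: consuming an optional criterion.
def compute_loss(model_output, labels=None, criterion=None):
    if criterion is None:
        # Assumption: the model's forward already returned the loss.
        return model_output
    return criterion(model_output, labels)

# Plain floats stand in for tensors here.
print(compute_loss(0.42))                                    # 0.42
print(compute_loss(2.0, 1.0, criterion=lambda o, l: o - l))  # 1.0
```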
12 changes: 7 additions & 5 deletions llm/auto_parallel/deepseek-v3/run_pretrain_auto.sh
@@ -35,7 +35,7 @@ export PYTHONPATH=../../../:$PYTHONPATH
to_static=0 # whether to enable dynamic-to-static training

python -u -m paddle.distributed.launch \
--gpus "0,1,2,3" \
--gpus "0,1,2,3,4,5,6,7" \
--log_dir "output/$task_name""_log" \
run_pretrain_auto.py \
--model_type "deepseekv3_auto" \
@@ -44,7 +44,7 @@ python -u -m paddle.distributed.launch \
--input_dir "./data" \
--output_dir "output/$task_name" \
--split 949,50,1 \
- --max_seq_length 2048 \
+ --max_seq_length 512 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 2 \
--gradient_accumulation_steps 2 \
@@ -53,9 +53,9 @@ python -u -m paddle.distributed.launch \
--fp16 0 \
--fp16_opt_level "O2" \
--scale_loss 1024 \
- --pipeline_parallel_degree 1 \
+ --pipeline_parallel_degree 4 \
--tensor_parallel_degree 2 \
- --sharding_parallel_degree 2 \
+ --sharding_parallel_degree 1 \
--learning_rate 0.0001 \
--min_learning_rate 0.00001 \
--max_steps 2 \
@@ -75,6 +75,8 @@ python -u -m paddle.distributed.launch \
--data_impl "mmap" \
--enable_auto_parallel 1 \
--max_grad_norm 1.0 \
- --num_hidden_layers 1 \
+ --num_hidden_layers 6 \
--use_intermediate_api true \
--to_static $to_static \
+ --first_k_dense_replace 7 \
+ --hidden_size 1792 \
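A quick sanity check on the updated launch configuration (illustrative arithmetic, not part of the script): with 8 GPUs, `--pipeline_parallel_degree 4` and `--tensor_parallel_degree 2` account for all 8 ranks, which is assumed here to leave a single data/sharding-parallel group, consistent with `--sharding_parallel_degree 1`.

```python
# Illustrative check of the parallel layout implied by the flags above.
gpus = 8          # --gpus "0,1,2,3,4,5,6,7"
pp, tp = 4, 2     # --pipeline_parallel_degree 4, --tensor_parallel_degree 2

assert gpus % (pp * tp) == 0, "pp * tp must divide the GPU count"
remaining = gpus // (pp * tp)
print(f"ranks left for data/sharding parallelism: {remaining}")  # 1
```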
4 changes: 3 additions & 1 deletion paddlenlp/transformers/deepseek_v2/configuration.py
@@ -139,7 +139,8 @@ def __init__(
intermediate_size=11008,
moe_intermediate_size=1407,
num_hidden_layers=30,
- num_nextn_predict_layers=1,
+ num_nextn_predict_layers=0,
+ num_nextn_predict_lambda=1.0,
num_attention_heads=32,
num_key_value_heads=32,
n_shared_experts=None,
@@ -187,6 +188,7 @@ def __init__(
self.moe_intermediate_size = moe_intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_nextn_predict_layers = num_nextn_predict_layers
+ self.num_nextn_predict_lambda = num_nextn_predict_lambda
self.num_attention_heads = num_attention_heads
self.n_shared_experts = n_shared_experts
self.n_routed_experts = n_routed_experts
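The new `num_nextn_predict_lambda` field reads as a weighting coefficient for the multi-token-prediction (MTP) auxiliary loss controlled by `num_nextn_predict_layers` (which now defaults to 0, i.e. MTP disabled). A hedged sketch of how such a coefficient is commonly applied, following the DeepSeek-V3 convention of an averaged MTP loss scaled by λ; the exact formula used by this PR may differ:

```python
# Hypothetical sketch: combine the main next-token loss with MTP losses from
# num_nextn_predict_layers extra heads, scaled by num_nextn_predict_lambda.
def total_loss(main_loss, mtp_losses, lam=1.0):
    if not mtp_losses:  # num_nextn_predict_layers == 0 -> MTP disabled
        return main_loss
    return main_loss + lam * sum(mtp_losses) / len(mtp_losses)

print(total_loss(2.3, []))               # 2.3 (default config, MTP off)
print(total_loss(2.3, [2.5, 2.7], 1.0))  # ~4.9 = 2.3 + 1.0 * mean([2.5, 2.7])
```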