-
Notifications
You must be signed in to change notification settings - Fork 356
[train][2/N] Support for Megatron PP + CP for R3 #1335
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 8 commits
dc93a03
d2b2d53
8cb6155
d1712bd
3da803a
f7362c3
5d74b89
5855ee8
153a2d0
c46287e
33b3083
7066bf4
7744e69
bfcd8db
daf5752
493387c
089ee8b
0f18c70
1a73422
04d29f3
2619ed4
9a0088f
4cf3a48
4b844c0
44aab7d
be50ba6
f7c3086
bf84467
0a2fdb2
73741a3
43468a8
a4b9228
a841448
5def4cb
266b9c1
5628f22
2bc0ee0
fff1c6a
9463aec
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||
|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,107 @@ | ||||||||
| set -x | ||||||||
|
|
||||||||
| # Fully async GRPO training+generation for Qwen2.5-1.5B-Instruct on GSM8K. | ||||||||
| # This bash script is copied from examples/async/async_run_gsm8k.sh, except for: | ||||||||
| # - running examples.train.fully_async.main_fully_async | ||||||||
| # - setting the generator.batched=false. | ||||||||
| # - colocate_all=false | ||||||||
| # - the various generator configs at the end (http, chat template, etc.) | ||||||||
|
|
||||||||
| # uv run examples/train/gsm8k/gsm8k_dataset.py --output_dir $HOME/data/gsm8k | ||||||||
| # export WANDB_API_KEY=<your_key_here> | ||||||||
| # bash examples/train/router_replay/router_replay_fully_async.sh | ||||||||
|
|
||||||||
| # NOTE (sumanthrh): `micro_train_batch_size_per_gpu` and `micro_forward_batch_size_per_gpu` can be tuned | ||||||||
|
|
||||||||
| # You can override the default values with e.g.: `NUM_GPUS=1 bash examples/train/fully_async/fully_async_run_gsm8k.sh`. | ||||||||
|
|
||||||||
| : "${DATA_DIR:="$HOME/data/gsm8k"}" | ||||||||
| : "${NUM_INFERENCE_GPUS:=4}" | ||||||||
| : "${NUM_POLICY_GPUS:=4}" | ||||||||
| : "${LOGGER:=wandb}" # change to "console" to print to stdout / or use wandb | ||||||||
|
|
||||||||
| : "${INFERENCE_BACKEND:=vllm}" | ||||||||
|
|
||||||||
| # Fully async specific configuration knobs: | ||||||||
| : "${MINI_BATCH_SIZE:=256}" | ||||||||
| : "${MAX_STALENESS_STEPS:=4}" | ||||||||
| : "${NUM_PARALLEL_GENERATION_WORKERS:=$(( MINI_BATCH_SIZE * (MAX_STALENESS_STEPS + 1) ))}" | ||||||||
|
|
||||||||
| TIS_TYPE=token | ||||||||
| TIS_IMP_RATIO_CAP=2.0 | ||||||||
|
|
||||||||
| # moonlight16b | ||||||||
| MODEL_NAME="moonshotai/Moonlight-16B-A3B-Instruct" | ||||||||
|
|
||||||||
| NUM_NODES=1 | ||||||||
| NUM_GPUS=8 | ||||||||
|
|
||||||||
| MEGATRON_TP=1 | ||||||||
| MEGATRON_PP=2 | ||||||||
| MEGATRON_CP=2 | ||||||||
| MEGATRON_EP=4 | ||||||||
| MEGATRON_ETP=1 | ||||||||
|
|
||||||||
| NUM_INFERENCE_ENGINES=1 | ||||||||
| INFERENCE_ENGINE_TP=4 | ||||||||
|
|
||||||||
| # router replay (r3) | ||||||||
| ROUTER_REPLAY=true | ||||||||
| DISTRIBUTED_EXECUTION_BACKEND="mp" | ||||||||
|
|
||||||||
| RUN_NAME=gsm8k-fully-async-moonlight16b-a3b-useTIS_${TIS_TYPE}-maxStale${MAX_STALENESS_STEPS}-numCon${NUM_PARALLEL_GENERATION_WORKERS}-${NUM_POLICY_GPUS}train${NUM_INFERENCE_GPUS}gen_r3 | ||||||||
|
|
||||||||
| uv run --isolated --extra fsdp -m examples.train.fully_async.main_fully_async \ | ||||||||
| data.train_data="['$DATA_DIR/train.parquet']" \ | ||||||||
| data.val_data="['$DATA_DIR/validation.parquet']" \ | ||||||||
| trainer.fully_async.max_staleness_steps=${MAX_STALENESS_STEPS} \ | ||||||||
| trainer.fully_async.num_parallel_generation_workers=${NUM_PARALLEL_GENERATION_WORKERS} \ | ||||||||
| trainer.algorithm.advantage_estimator="grpo" \ | ||||||||
| trainer.algorithm.off_policy_correction.tis_ratio_type=$TIS_TYPE \ | ||||||||
| trainer.algorithm.off_policy_correction.token_tis_ratio_clip_high=$TIS_IMP_RATIO_CAP \ | ||||||||
| trainer.policy.model.path=$MODEL_NAME \ | ||||||||
| trainer.placement.colocate_all=false \ | ||||||||
| trainer.strategy=fsdp2 \ | ||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🟡 Duplicate Line 64 sets
Suggested change
Was this helpful? React with 👍 or 👎 to provide feedback. |
||||||||
| trainer.placement.policy_num_gpus_per_node=$NUM_POLICY_GPUS \ | ||||||||
| trainer.placement.critic_num_gpus_per_node=$NUM_POLICY_GPUS \ | ||||||||
| trainer.placement.ref_num_gpus_per_node=$NUM_POLICY_GPUS \ | ||||||||
| generator.inference_engine.num_engines=$NUM_INFERENCE_GPUS \ | ||||||||
| generator.inference_engine.distributed_executor_backend=$DISTRIBUTED_EXECUTION_BACKEND \ | ||||||||
| generator.inference_engine.enable_return_routed_experts=$ROUTER_REPLAY \ | ||||||||
| generator.inference_engine.tensor_parallel_size=1 \ | ||||||||
|
erictang000 marked this conversation as resolved.
Outdated
|
||||||||
| trainer.epochs=20 \ | ||||||||
| trainer.eval_batch_size=1024 \ | ||||||||
| trainer.eval_before_train=false \ | ||||||||
| trainer.eval_interval=4 \ | ||||||||
| trainer.strategy=megatron \ | ||||||||
| trainer.policy.megatron_config.tensor_model_parallel_size=$MEGATRON_TP \ | ||||||||
| trainer.policy.megatron_config.pipeline_model_parallel_size=$MEGATRON_PP \ | ||||||||
| trainer.policy.megatron_config.context_parallel_size=$MEGATRON_CP \ | ||||||||
| trainer.policy.megatron_config.expert_model_parallel_size=$MEGATRON_EP \ | ||||||||
| trainer.policy.megatron_config.expert_tensor_parallel_size=$MEGATRON_ETP \ | ||||||||
| trainer.policy.megatron_config.moe_enable_routing_replay=$ROUTER_REPLAY \ | ||||||||
| trainer.update_epochs_per_batch=1 \ | ||||||||
| trainer.train_batch_size=${MINI_BATCH_SIZE} \ | ||||||||
| trainer.policy_mini_batch_size=${MINI_BATCH_SIZE} \ | ||||||||
| trainer.micro_forward_batch_size_per_gpu=8 \ | ||||||||
| trainer.micro_train_batch_size_per_gpu=8 \ | ||||||||
| trainer.ckpt_interval=10 \ | ||||||||
| trainer.max_prompt_length=512 \ | ||||||||
| generator.sampling_params.max_generate_length=1024 \ | ||||||||
| trainer.policy.optimizer_config.lr=1.0e-6 \ | ||||||||
| trainer.algorithm.use_kl_loss=true \ | ||||||||
| generator.inference_engine.backend=$INFERENCE_BACKEND \ | ||||||||
| generator.inference_engine.run_engines_locally=true \ | ||||||||
| generator.inference_engine.weight_sync_backend=nccl \ | ||||||||
| generator.inference_engine.async_engine=true \ | ||||||||
| generator.batched=false \ | ||||||||
| environment.env_class=gsm8k \ | ||||||||
| generator.n_samples_per_prompt=5 \ | ||||||||
| generator.inference_engine.gpu_memory_utilization=0.8 \ | ||||||||
| trainer.logger="$LOGGER" \ | ||||||||
| trainer.project_name="gsm8k-async" \ | ||||||||
| trainer.run_name=${RUN_NAME} \ | ||||||||
| trainer.resume_mode=latest \ | ||||||||
| trainer.ckpt_path="$HOME/ckpts/${RUN_NAME}" \ | ||||||||
| generator.inference_engine.enforce_eager=true \ | ||||||||
| $@ | ||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -374,7 +374,9 @@ def init_configs( | |
|
|
||
| self.strategy.hf_config = hf_config | ||
| self.tokenizer = tokenizer | ||
| self.enable_router_replay = megatron_config.moe_enable_routing_replay | ||
| self.enable_router_replay = transformer_config_kwargs.get( | ||
|
devpatelio marked this conversation as resolved.
Outdated
|
||
| "moe_enable_routing_replay", megatron_config.moe_enable_routing_replay | ||
| ) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🟡 Validation for The PR changes Was this helpful? React with 👍 or 👎 to provide feedback. |
||
|
|
||
| def configure_lora(self, lora_config, lora_type: Optional[str] = "lora"): | ||
| if lora_type == "lora": | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.