Refactoring training inference importance sampling with sequence/geometry level #429
Merged
Commits: 70 (changes shown from 63 commits)
30e01df (zhaochenyang20): fix quick start docs in zh/en
82695df (zhaochenyang20): Update run-qwen3-30B-A3B.sh
fc0ec27 (zhaochenyang20): [Importance sampling] seperate importance sampling as a function
2ee67b6 (zhaochenyang20): Merge branch 'THUDM:main' into importance_sampling
35f414c (zhaochenyang20): fix lint
901da8d (zhaochenyang20): Merge branch 'THUDM:main' into main
19649f9 (zhaochenyang20): fix lint of main
81ffa47 (zhaochenyang20): adding pre-commit as a CI flow
243cc44 (zhaochenyang20): only pre-commit with yml
0b164f3 (zhaochenyang20): fix up pre commit
0717171 (zhaochenyang20): unigy local pre-commit with third party
0d27716 (zhaochenyang20): rebase with main for lint
984c724 (zhaochenyang20): fix lint with main
0695c73 (zhaochenyang20): adding kl metircs
cac9fa9 (zhaochenyang20): fix type custing for metrics
e7cf0c2 (zhaochenyang20): comments to compute_tis_weights
2924975 (zhaochenyang20): [lint] tis comment
7963809 (zhaochenyang20): refactor clip mode in sequence level
6f36eef (zhaochenyang20): [test] geometric level
3cc4982 (zhaochenyang20): adding metrics to new tis
d60a595 (zhaochenyang20): [log probs in 1D]
5cac6e0 (zhaochenyang20): stash with main
71194c3 (zhaochenyang20): slice tis with slice_log_prob_with_cp
92e6e97 (zhaochenyang20): [todo] filter out catastrophic tokens
dbd27d7 (zhaochenyang20): before metrics
083c300 (zhaochenyang20): fix with rollout log probs
ff68b32 (zhaochenyang20): [wait for the metircs]
afc1cd7 (zhaochenyang20): logging a whole sequence
4ab1aee (zhaochenyang20): rebase with main, ready to metircs global
f31c9cd (guapisolo): tmp commit
46adaef (guapisolo): make good abstraction to metrics
2d9ffd8 (guapisolo): fix seq concat problem
f44a7d1 (guapisolo): fix key not exist
8d77db6 (guapisolo): update metrics calculation
88f24dd (guapisolo): fix cp>1 with scatter_with_cp func and update argument
3b68708 (guapisolo): move cp scatter to loss.py
74ea791 (guapisolo): revert fsdp tis
7e4ee12 (guapisolo): fix argument names from tis to is
3fd863b (guapisolo): Update clip bound name and veto impl to fully follow the paper
c047b2c (guapisolo): upd clip_mask to mask
ca6e90e (guapisolo): upd arg name, fix small bug, delete scatter_with_cp
c0aeb32 (guapisolo): move cp logic to tis.py
4444634 (guapisolo): add a qwen3-4b tis sample
c4441ad (guapisolo): fix code comments, file name
293cbce (guapisolo): fix small bug and arg setting
d750125 (guapisolo): fix small bug in veto, add mask, and add a script
ce01008 (guapisolo): small change fsdp
43d04bd (guapisolo): fix parameter type in some functions and update comments
1074d7a (zhaochenyang20): Update train_infer_is.py
003c002 (zhaochenyang20): Merge pull request #1 from guapisolo/tis
99224fa (zhaochenyang20): remove sh
8a8c44c (zhaochenyang20): logging a whole sequence
4326b28 (zhaochenyang20): rebase with main
52a401e (zhaochenyang20): Update run-qwen3-30B-A3B.sh
ac4e63a (zhaochenyang20): create test scripts
4df0724 (zhaochenyang20): revert change in qwen3 30B sh
7b01369 (zhaochenyang20): remove two tests sh
4455479 (guapisolo): add kl metrics
3319f38 (guapisolo): fix comment
ce541be (zhaochenyang20): Merge branch 'main' into importance_sampling
1d35f45 (zhaochenyang20): Merge pull request #3 from guapisolo/tis
bf8df31 (zhaochenyang20): adding kl metrics
e7a88e3 (zhaochenyang20): Merge branch 'main' into importance_sampling
8249244 (guapisolo): Merge branch 'main' into cytis
fa4606d (guapisolo): revert changes to use_tis
a66753f (guapisolo): move to examples and use yaml for custom args parsing
5809cd1 (zhaochenyang20): Merge pull request #5 from guapisolo/tis
fffeab9 (guapisolo): fix small bug
9e4bf7c (guapisolo): remove tis file
eb7711c (guapisolo): give vanilla tis func
@@ -79,7 +79,7 @@ GRPO_ARGS=(
    --eps-clip 0.2
    --eps-clip-high 0.28

-   --use-tis
+   --use-train-infer-is
 )

 OPTIMIZER_ARGS=(
@@ -86,7 +86,7 @@ GRPO_ARGS=(
    --eps-clip 1e-4
    --eps-clip-high 2e-4

-   --use-tis
+   --use-train-infer-is
 )

 OPTIMIZER_ARGS=(
@@ -0,0 +1,166 @@
+#!/bin/bash
+
+# for rerun the task
+pkill -9 sglang
+sleep 3
+ray stop --force
+pkill -9 ray
+pkill -9 python
+sleep 3
+pkill -9 ray
+pkill -9 python
+pkill -9 redis
+
+set -ex
+
+# will prevent ray from buffering stdout/stderr
+export PYTHONBUFFERED=16
+
+NVLINK_COUNT=$(nvidia-smi topo -m 2>/dev/null | grep -o 'NV[0-9][0-9]*' | wc -l)
+if [ "$NVLINK_COUNT" -gt 0 ]; then
+    HAS_NVLINK=1
+else
+    HAS_NVLINK=0
+fi
+echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)"
+
+SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)"
+source "${SCRIPT_DIR}/models/qwen3-30B-A3B.sh"
+
+CKPT_ARGS=(
+   --hf-checkpoint /root/Qwen3-30B-A3B
+   #--hf-checkpoint /root/Qwen3-30B-A3B-FP8
+   --ref-load /root/Qwen3-30B-A3B_torch_dist
+   # --load /root/Qwen3-30B-A3B_slime/
+   # --save /root/Qwen3-30B-A3B_slime/
+   # --save-interval 20
+)
+
+ROLLOUT_ARGS=(
+   --prompt-data /root/dapo-math-17k/dapo-math-17k.jsonl
+   --input-key prompt
+   --label-key label
+   --apply-chat-template
+   --rollout-shuffle
+   --rm-type deepscaler
+   --num-rollout 3000
+   --rollout-batch-size 8
+   --n-samples-per-prompt 4
+   --rollout-max-response-len 8192
+   --rollout-temperature 0.8
+
+   --global-batch-size 32
+   --balance-data
+)
+
+EVAL_ARGS=(
+   # --eval-interval 20
+   --eval-prompt-data aime /root/aime-2024/aime-2024.jsonl
+   --n-samples-per-eval-prompt 16
+   --eval-max-response-len 16384
+   --eval-top-p 0.7
+)
+
+PERF_ARGS=(
+   --tensor-model-parallel-size 2
+   --sequence-parallel
+   --pipeline-model-parallel-size 1
+   --context-parallel-size 2
+   --expert-model-parallel-size 4
+   --expert-tensor-parallel-size 1
+
+   --recompute-granularity full
+   --recompute-method uniform
+   --recompute-num-layers 1
+
+   # --micro-batch-size 1
+   --use-dynamic-batch-size
+   --max-tokens-per-gpu 20480
+)
+
+GRPO_ARGS=(
+   --advantage-estimator grpo
+   --use-kl-loss
+   --kl-loss-coef 0.00
+   --kl-loss-type low_var_kl
+   --entropy-coef 0.00
+   --eps-clip 0.2
+   --eps-clip-high 0.28
+)
+
+IS_ARGS=(
+   --use-train-infer-is
+   --train-infer-is-level geometric
+   --train-infer-is-mode mask
+   --train-infer-is-lower-bound 0.5
+   --train-infer-is-upper-bound 2.0
+   --train-infer-is-veto-threshold 1e-3
+)
+
+OPTIMIZER_ARGS=(
+   --optimizer adam
+   --lr 1e-6
+   --lr-decay-style constant
+   --weight-decay 0.1
+   --adam-beta1 0.9
+   --adam-beta2 0.98
+
+   --optimizer-cpu-offload
+   --overlap-cpu-optimizer-d2h-h2d
+   --use-precision-aware-optimizer
+)
+
+WANDB_ARGS=(
+   --use-wandb
+   --wandb-project slime-dev
+   --wandb-group qwen3-30B-A3B-TIS
+   --wandb-run-id qwen3-30B-A3B-TIS-sequence
+   --wandb-key ${WANDB_KEY}
+)
+
+SGLANG_ARGS=(
+   --rollout-num-gpus-per-engine 4
+   --sglang-mem-fraction-static 0.7
+   --sglang-cuda-graph-bs 1 2 4 8 $(seq 16 8 256)
+)
+
+MISC_ARGS=(
+   # default dropout in megatron is 0.1
+   --attention-dropout 0.0
+   --hidden-dropout 0.0
+   # should be good for model performance
+   --accumulate-allreduce-grads-in-fp32
+   --attention-softmax-in-fp32
+   # need to comment this when using model with MLA
+   --attention-backend flash
+)
+
+# launch the master node of ray in container
+export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
+ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus 4 --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265
+
+# Build the runtime environment JSON with proper variable substitution
+RUNTIME_ENV_JSON="{
+  \"env_vars\": {
+    \"PYTHONPATH\": \"/root/Megatron-LM/\",
+    \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\",
+    \"NCCL_NVLS_ENABLE\": \"${HAS_NVLINK}\"
+  }
+}"
+
+ray job submit --address="http://127.0.0.1:8265" \
+   --runtime-env-json="${RUNTIME_ENV_JSON}" \
+   -- python3 train.py \
+   --actor-num-nodes 1 \
+   --actor-num-gpus-per-node 4 \
+   --colocate \
+   ${MODEL_ARGS[@]} \
+   ${CKPT_ARGS[@]} \
+   ${ROLLOUT_ARGS[@]} \
+   ${OPTIMIZER_ARGS[@]} \
+   ${GRPO_ARGS[@]} \
+   ${WANDB_ARGS[@]} \
+   ${PERF_ARGS[@]} \
+   ${EVAL_ARGS[@]} \
+   ${SGLANG_ARGS[@]} \
+   ${MISC_ARGS[@]}
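
The new script defines `IS_ARGS`, but the `ray job submit` command as shown never expands it, so the importance-sampling flags would not reach `train.py` as written. The following is a minimal sketch, assuming the intent is to forward that array along with the others. The `TRAIN_CMD` array is a hypothetical helper introduced only for illustration, and the command is printed rather than executed so the sketch runs without a Ray cluster.

```shell
#!/bin/bash
# Hypothetical sketch: forward the IS flags from the script to the trainer.
# The flag names and values are copied from the IS_ARGS array in the diff.
IS_ARGS=(
   --use-train-infer-is
   --train-infer-is-level geometric
   --train-infer-is-mode mask
   --train-infer-is-lower-bound 0.5
   --train-infer-is-upper-bound 2.0
   --train-infer-is-veto-threshold 1e-3
)

# TRAIN_CMD is an assumption, not part of the PR. Quoting "${IS_ARGS[@]}"
# expands each flag and each value as its own word.
TRAIN_CMD=(python3 train.py --colocate "${IS_ARGS[@]}")

# Print the assembled command instead of submitting it.
printf '%s\n' "${TRAIN_CMD[*]}"
```

In the script itself, the equivalent change would be one more continuation line, `${IS_ARGS[@]} \`, next to the other argument arrays in the `ray job submit` call.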