[sglang] feat: Add SGLang async multi-turn rollout with tool support #1037
Changes from 65 commits
**`.github/workflows/sgl.yml`** (new file; the push trigger originally pointed at `.github/workflows/vllm.yml`, fixed here to reference this workflow):

```yaml
name: sgl

on:
  # Trigger the workflow on push or pull request,
  # but only for the main branch
  push:
    branches:
      - main
      - v0.2.x
    paths:
      - "**/*.py"
      - .github/workflows/sgl.yml
  pull_request:
    branches:
      - main
      - v0.2.x
    paths:
      - "**/*.py"
      - "verl/trainer/config/*.yaml"
      - .github/workflows/sgl.yml

# Cancel jobs on the same ref if a new one is triggered
concurrency:
  group: ${{ github.workflow }}-${{ github.ref }}
  cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}

# Declare read-only permissions.
permissions:
  contents: read

jobs:
  sgl:
    runs-on: [self-hosted, l20-0]
    timeout-minutes: 20 # Increase this timeout value as needed
    env:
      HTTP_PROXY: ${{ secrets.PROXY_HTTP }}
      HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }}
      NO_PROXY: "localhost,127.0.0.1"
      HF_HUB_ENABLE_HF_TRANSFER: 1
    container:
      image: ocss884/verl-sglang:ngc-th2.5.1-cu126-sglang0.4.4.post3
      options: --gpus all --shm-size=10g
    steps:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
        with:
          fetch-depth: 0
      - name: Install the current repository
        run: |
          pip3 install hf_transfer
          pip3 install -e .[test,gpu,sglang] --no-deps
      - name: Test the latest SGLang
        run: |
          cd tests/rollout
          torchrun --standalone --nnodes=1 --nproc_per_node=4 $(which pytest) -s test_sglang_spmd.py
      - name: Test the latest SGLang async
        run: |
          cd tests/rollout
          torchrun --standalone --nnodes=1 --nproc_per_node=4 $(which pytest) -s test_async_sglang_spmd.py
```
**Documentation** (new file, reStructuredText):

```rst
Multi-turn Rollout Support
==========================

Basic Configuration
~~~~~~~~~~~~~~~~~~~

To enable multi-turn rollout, make sure to configure the following fields in your rollout configuration:

.. code-block:: yaml

    actor_rollout_ref:
        rollout:
            multi_turn: True
            name: "sglang_async"

This configuration activates the sglang_async engine for multi-turn interaction during rollout.

Custom Tool Configuration
~~~~~~~~~~~~~~~~~~~~~~~~~

For custom environment interaction tools, you can specify your tool configurations in a YAML file.
To do so, use the following format in your rollout config:

.. code-block:: yaml

    actor_rollout_ref:
        rollout:
            tool_kwargs:
                tools_config_file: <path_to_tool_yaml_file>

This allows integration of customized tool behaviors during actor rollout steps. You may refer to the GSM8KTool_example_configuration_ for guidance.

GSM8K Multi-turn Training Performance
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

See the training performance of multi-turn rollout on the GSM8K task HERE_.

.. _HERE: https://wandb.ai/zhaochenyang20/gsm8k_async_rl/runs/1ro1r7om?nw=nwuserzhaochenyang20

.. _GSM8KTool_example_configuration: ../../examples/sglang_multiturn/config/tool_config/gsm8k_tool_config.yaml
```
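To illustrate what a multi-turn rollout with a tool call produces at the message level, here is a minimal sketch. The tool name `calc_gsm8k_reward` comes from this PR's GSM8K tool config; the `build_gsm8k_turns` helper and the exact message layout are hypothetical, not part of the verl API.

```python
import json

def build_gsm8k_turns(question: str, answer: str, reward: float) -> list[dict]:
    """Sketch of one multi-turn trajectory in OpenAI-style chat format:
    user question -> assistant tool call -> tool result -> final answer."""
    tool_call = {
        "name": "calc_gsm8k_reward",
        "arguments": json.dumps({"answer": answer}),
    }
    return [
        {"role": "user", "content": question},
        {
            "role": "assistant",
            "content": None,
            "tool_calls": [{"type": "function", "function": tool_call}],
        },
        # The tool's reward feeds back into the conversation as a tool message,
        # letting the model refine its answer on the next turn if needed.
        {"role": "tool", "name": "calc_gsm8k_reward", "content": str(reward)},
        {"role": "assistant", "content": f"#### {answer}"},
    ]

turns = build_gsm8k_turns("What is 6 * 12?", "72", 1.0)
print([t["role"] for t in turns])  # → ['user', 'assistant', 'tool', 'assistant']
```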
**`examples/sglang_multiturn/config/gsm8k_multiturn_grpo.yaml`** (new file):

```yaml
hydra:
  searchpath:
    - file://verl/trainer/config

defaults:
  - ppo_trainer
  - _self_

data:
  max_prompt_length: 1024
  max_response_length: 1024
  train_batch_size: 256
  return_raw_chat: True

actor_rollout_ref:
  hybrid_engine: True
  rollout:
    name: sglang_async
    multi_turn:
      enable: True
      max_turns: 5
      # tool_config_path: "./config/tool_config/gsm8k_tool_config.yaml"
```

Review discussion on this file:

> **Collaborator:** Please follow https://github.com/volcengine/verl/blob/main/recipe/prime/config/prime_trainer.yaml and only add the critical config diff on top of the default trainer config from verl.
>
> **Collaborator:** Where do we have a grpo yaml?
>
> **Collaborator:** There isn't one, but you can inherit config from ppo_trainer and only include values that differ from the base config.
**`examples/sglang_multiturn/config/tool_config/gsm8k_tool_config.yaml`** (new file):

```yaml
tools:
  - class_name: "verl.workers.tool.gsm8k_tool.Gsm8kTool"
    config: {}
    tool_schema:
      type: "function"
      function:
        name: "calc_gsm8k_reward"
        description: "A tool for calculating the reward of gsm8k. (1.0 if your answer is correct, 0.0 if your answer is incorrect)"
        parameters:
          type: "object"
          properties:
            answer:
              type: "string"
              description: "The model's answer to the GSM8K math problem; must be a digit string"
          required: ["answer"]
```

> **Collaborator:** Move it to data_preprocess. Rename gsm8k_tools.py.
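The schema above can be exercised with a stdlib-only sketch. This is a hypothetical stand-in for `Gsm8kTool`, not verl's implementation: it validates a JSON tool-call argument string against the `required`/`properties` fields and scores the answer exactly as the description states (1.0 if correct, 0.0 otherwise).

```python
import json

# Mirrors the "parameters" block of the tool_schema above.
SCHEMA = {
    "type": "object",
    "properties": {"answer": {"type": "string"}},
    "required": ["answer"],
}

def calc_gsm8k_reward(arguments: str, ground_truth: str) -> float:
    """Validate the model's tool-call arguments, then compare to ground truth."""
    args = json.loads(arguments)
    for key in SCHEMA["required"]:
        if not isinstance(args.get(key), str):
            raise ValueError(f"missing or non-string required field: {key}")
    return 1.0 if args["answer"].strip() == ground_truth else 0.0

print(calc_gsm8k_reward('{"answer": "72"}', "72"))  # → 1.0
print(calc_gsm8k_reward('{"answer": "71"}', "72"))  # → 0.0
```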
**`gsm8k.py`** (data preprocessing, new file):

```python
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023-2024 SGLang Team and ModelBest Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Preprocess the GSM8K dataset to parquet format
"""

import argparse
import os
import re

import datasets

from verl.utils.hdfs_io import copy, makedirs


def extract_solution(solution_str):
    solution = re.search("#### (\\-?[0-9\\.\\,]+)", solution_str)
    assert solution is not None
    final_solution = solution.group(0)
    final_solution = final_solution.split("#### ")[1].replace(",", "")
    return final_solution


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--local_dir", default="~/data/gsm8k")
    parser.add_argument("--hdfs_dir", default=None)

    args = parser.parse_args()

    data_source = "openai/gsm8k"
    dataset = datasets.load_dataset(data_source, "main")

    train_dataset = dataset["train"]
    test_dataset = dataset["test"]

    instruction_following = (
        "You must use the `calc_gsm8k_reward` tool to calculate the reward of your answer "
        "(1.0 if your answer is correct, 0.0 if your answer is incorrect) at least once before "
        "submitting it, and refine your answer if necessary. Put your final answer in the "
        "format of `#### <answer>`."
    )

    # add a row to each data item that represents a unique id
    def make_map_fn(split):
        def process_fn(example, idx):
            question_raw = example.pop("question")

            question = question_raw + " " + instruction_following

            answer_raw = example.pop("answer")
            solution = extract_solution(answer_raw)
            data = {
                "data_source": data_source,
                "prompt": [
                    {
                        "role": "system",
                        "content": (
                            "You are a math expert. You are given a question and you need to "
                            "solve it step by step. `calc_gsm8k_reward` is a tool for "
                            "calculating the reward of gsm8k. You should use this tool to "
                            "calculate the reward of your answer (1.0 if your answer is "
                            "correct, 0.0 if your answer is incorrect) before submitting it, "
                            "and refine your answer if necessary. Put your final answer in "
                            "the format of `#### <answer>`."
                        ),
                    },
                    {
                        "role": "user",
                        "content": question,
                    },
                ],
                "ability": "math",
                "reward_model": {"style": "rule", "ground_truth": solution},
                "extra_info": {
                    "split": split,
                    "index": idx,
                    "answer": answer_raw,
                    "question": question_raw,
                    "need_tools_kwargs": True,
                    "tools_kwargs": {
                        "calc_gsm8k_reward": {
                            "create_kwargs": {"ground_truth": solution},
                            # "execute_kwargs": {},
                            # "calc_reward_kwargs": {},
                            # "release_kwargs": {},
                        },
                    },
                },
            }
            return data

        return process_fn

    train_dataset = train_dataset.map(function=make_map_fn("train"), with_indices=True)
    test_dataset = test_dataset.map(function=make_map_fn("test"), with_indices=True)

    local_dir = args.local_dir
    hdfs_dir = args.hdfs_dir

    train_dataset.to_parquet(os.path.join(local_dir, "train.parquet"))
    test_dataset.to_parquet(os.path.join(local_dir, "test.parquet"))

    if hdfs_dir is not None:
        makedirs(hdfs_dir)

        copy(src=local_dir, dst=hdfs_dir)
```
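The `extract_solution` regex can be checked standalone: GSM8K reference answers end with `#### <number>`, possibly containing commas, and the helper strips the prefix and the commas. The sample answer strings below are made up for illustration.

```python
import re

# Same logic as the preprocessing script's extract_solution helper.
def extract_solution(solution_str):
    solution = re.search("#### (\\-?[0-9\\.\\,]+)", solution_str)
    assert solution is not None
    final_solution = solution.group(0)
    # Drop the "#### " marker and any thousands separators.
    final_solution = final_solution.split("#### ")[1].replace(",", "")
    return final_solution

print(extract_solution("48 + 24 = 72 clips.\n#### 72"))    # → 72
print(extract_solution("Total cost rises.\n#### 1,234"))   # → 1234
print(extract_solution("#### -5"))                         # → -5
```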
**`examples/sglang_multiturn/README.md`** (new file):

# Multi-Turn Rollout Example (GSM8K)

This example demonstrates how to perform **multi-turn rollout** using SGLang with a tool-calling capable model (e.g., Qwen2.5-3B) on the GSM8K dataset.

## Usage

### Step 1: Download GSM8K Dataset

```bash
cd examples/sglang_multiturn
python3 gsm8k.py
```

This will download and preprocess the GSM8K dataset into `~/data/gsm8k/`.

### Step 2: Run Multi-Turn Rollout

If you have 8 GPUs, use the standard 8-GPU script:

```bash
cd your_verl_root_dir
bash examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_multiturn.sh
```

If you have only 4 GPUs, use the fallback 4-GPU script:

```bash
cd your_verl_root_dir
bash examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_multiturn_4xgpu.sh
```

## Notes

- The rollout supports multi-turn conversations with tool-calling capabilities.
- Current tools are used for GSM8K answer evaluation.
- Future versions may extend to search and code interpreter tools.
**`examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_multiturn.sh`** (new file):

```bash
# run on 8xH100
# make sure your current working directory is the root of the project

set -x

ulimit -n 65535

PROJECT_DIR="$(pwd)"
CONFIG_PATH="$PROJECT_DIR/examples/sglang_multiturn/config"

python3 -m verl.trainer.main_ppo \
    --config-path="$CONFIG_PATH" \
    --config-name='gsm8k_multiturn_grpo' \
    algorithm.adv_estimator=grpo \
    data.train_batch_size=256 \
    data.max_prompt_length=1024 \
    data.max_response_length=1024 \
    data.filter_overlong_prompts=True \
    data.truncation='error' \
    data.return_raw_chat=True \
    actor_rollout_ref.model.path=Qwen/Qwen2.5-3B-Instruct \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.actor.ppo_mini_batch_size=256 \
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=40 \
    actor_rollout_ref.actor.use_kl_loss=True \
    actor_rollout_ref.actor.kl_loss_coef=0.001 \
    actor_rollout_ref.actor.kl_loss_type=low_var_kl \
    actor_rollout_ref.actor.entropy_coeff=0 \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.actor.fsdp_config.param_offload=False \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=40 \
    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
    actor_rollout_ref.rollout.name=sglang_async \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \
    actor_rollout_ref.rollout.n=16 \
    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=40 \
    actor_rollout_ref.ref.fsdp_config.param_offload=True \
    algorithm.use_kl_in_reward=False \
    trainer.critic_warmup=0 \
    trainer.logger=['console','wandb'] \
    trainer.project_name='gsm8k_async_rl' \
    trainer.experiment_name='qwen2.5-3b_function_rm-gsm8k-async-sgl-multi-w-tool-verify-n16' \
    trainer.n_gpus_per_node=8 \
    trainer.nnodes=1 \
    trainer.save_freq=-1 \
    trainer.test_freq=20 \
    data.train_files=$HOME/data/gsm8k/train.parquet \
    data.val_files=$HOME/data/gsm8k/test.parquet \
    actor_rollout_ref.rollout.multi_turn.tool_config_path="$PROJECT_DIR/examples/sglang_multiturn/config/tool_config/gsm8k_tool_config.yaml" \
    trainer.total_epochs=15 $@
```