diff --git a/.github/workflows/dataset.yml b/.github/workflows/dataset.yml index 396b46b95d..90b7f92706 100644 --- a/.github/workflows/dataset.yml +++ b/.github/workflows/dataset.yml @@ -33,7 +33,8 @@ jobs: env: HTTP_PROXY: ${{ secrets.PROXY_HTTP }} HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }} - NO_PROXY: "localhost,127.0.0.1" + NO_PROXY: "localhost,127.0.0.1,hf-mirror.com" + HF_ENDPOINT: "https://hf-mirror.com" HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable container: image: whatcanyousee/verl:ngc-cu124-vllm0.8.3-sglang0.4.5-mcore0.12.0-te2.2 diff --git a/.github/workflows/e2e_dapo.yml b/.github/workflows/e2e_dapo.yml index bb97eb898c..e6a01e78ad 100644 --- a/.github/workflows/e2e_dapo.yml +++ b/.github/workflows/e2e_dapo.yml @@ -36,7 +36,8 @@ jobs: env: HTTP_PROXY: ${{ secrets.PROXY_HTTP }} HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }} - NO_PROXY: "localhost,127.0.0.1" + NO_PROXY: "localhost,127.0.0.1,hf-mirror.com" + HF_ENDPOINT: "https://hf-mirror.com" HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable container: image: whatcanyousee/verl:ngc-cu124-vllm0.8.3-sglang0.4.5-mcore0.12.0-te2.2 diff --git a/.github/workflows/e2e_eval_aime24.yml b/.github/workflows/e2e_eval_aime24.yml index 1273d1a8f2..c6638db834 100644 --- a/.github/workflows/e2e_eval_aime24.yml +++ b/.github/workflows/e2e_eval_aime24.yml @@ -37,7 +37,8 @@ jobs: env: HTTP_PROXY: ${{ secrets.PROXY_HTTP }} HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }} - NO_PROXY: "localhost,127.0.0.1" + NO_PROXY: "localhost,127.0.0.1,hf-mirror.com" + HF_ENDPOINT: "https://hf-mirror.com" HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable container: image: whatcanyousee/verl:ngc-cu124-vllm0.8.3-sglang0.4.5-mcore0.12.0-te2.2 diff --git a/.github/workflows/e2e_ppo_trainer.yml b/.github/workflows/e2e_ppo_trainer.yml index 140d73a052..d397287bb4 100644 --- a/.github/workflows/e2e_ppo_trainer.yml +++ b/.github/workflows/e2e_ppo_trainer.yml @@ -61,7 +61,8 @@ jobs: env: HTTP_PROXY: ${{ secrets.PROXY_HTTP }} HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }} - NO_PROXY: "localhost,127.0.0.1" + NO_PROXY: "localhost,127.0.0.1,hf-mirror.com" + HF_ENDPOINT: "https://hf-mirror.com" HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable container: image: whatcanyousee/verl:ngc-cu124-vllm0.8.3-sglang0.4.5-mcore0.12.0-te2.2 @@ -139,7 +140,8 @@ jobs: env: HTTP_PROXY: ${{ secrets.PROXY_HTTP }} HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }} - NO_PROXY: "localhost,127.0.0.1" + NO_PROXY: "localhost,127.0.0.1,hf-mirror.com" + HF_ENDPOINT: "https://hf-mirror.com" HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable container: image: hiyouga/verl:ngc-th2.6.0-cu126-vllm0.8.3-flashinfer0.2.2-cxx11abi0 @@ -172,7 +174,8 @@ jobs: env: HTTP_PROXY: ${{ secrets.PROXY_HTTP }} HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }} - NO_PROXY: "localhost,127.0.0.1" + NO_PROXY: "localhost,127.0.0.1,hf-mirror.com" + HF_ENDPOINT: "https://hf-mirror.com" HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable container: image: ocss884/verl-sglang:ngc-th2.6.0-cu126-sglang0.4.5.post3 @@ -193,6 +196,64 @@ jobs: ray stop --force ENGINE=sglang bash tests/e2e/ppo_trainer/run_function_reward.sh + e2e_ppo_trainer_sglang_async: + runs-on: [L20x8] + needs: pre_commit_for_ppo + timeout-minutes: 40 # Increase this timeout value as needed + env: + HTTP_PROXY: ${{ secrets.PROXY_HTTP }} + HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }} + NO_PROXY: "localhost,127.0.0.1,hf-mirror.com" + HF_ENDPOINT: "https://hf-mirror.com" + HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable + container: + image: ocss884/verl-sglang:ngc-th2.6.0-cu126-sglang0.4.5.post3 
+ options: --gpus all --shm-size=10g + steps: + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + fetch-depth: 0 + - name: Install the current repository + run: | + pip3 install -e .[test,gpu,sglang] --no-deps + - name: Prepare gsm8k dataset + run: | + ray stop --force + python3 examples/data_preprocess/gsm8k.py + - name: Running GSM8K E2E training tests on 8 L20 GPUs with rmpad using function rm and save ckpt with sglang async + run: | + ray stop --force + ENGINE=sglang_async bash tests/e2e/ppo_trainer/run_function_reward.sh + + e2e_ppo_trainer_sglang_async_with_tool: + runs-on: [L20x8] + needs: pre_commit_for_ppo + timeout-minutes: 40 # Increase this timeout value as needed + env: + HTTP_PROXY: ${{ secrets.PROXY_HTTP }} + HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }} + NO_PROXY: "localhost,127.0.0.1,hf-mirror.com" + HF_ENDPOINT: "https://hf-mirror.com" + HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable + container: + image: ocss884/verl-sglang:ngc-th2.6.0-cu126-sglang0.4.5.post3 + options: --gpus all --shm-size=10g + steps: + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + fetch-depth: 0 + - name: Install the current repository + run: | + pip3 install -e .[test,gpu,sglang] --no-deps + - name: Prepare gsm8k dataset with tool + run: | + ray stop --force + python3 examples/data_preprocess/gsm8k_multiturn_w_tool.py --local_dir $HOME/data/gsm8k_verl_sgl_multi_turn_preprocessed + - name: Running GSM8K with tool E2E training tests on 8 L20 GPUs with rmpad using function rm and save ckpt with sglang async + run: | + ray stop --force + bash tests/e2e/run_gsm8k_fsdp_sgl_multiturn_w_tool.sh + e2e_ppo_trainer_sglang_vlm: runs-on: [L20x8] needs: pre_commit_for_ppo @@ -200,7 +261,8 @@ jobs: env: HTTP_PROXY: ${{ secrets.PROXY_HTTP }} HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }} - NO_PROXY: "localhost,127.0.0.1" + NO_PROXY: "localhost,127.0.0.1,hf-mirror.com" + HF_ENDPOINT: "https://hf-mirror.com" HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable container: image: ocss884/verl-sglang:ngc-th2.6.0-cu126-sglang0.4.5.post3 diff --git a/.github/workflows/e2e_ppo_trainer_megatron.yml b/.github/workflows/e2e_ppo_trainer_megatron.yml index 0a4a493ebd..d90d1eb913 100644 --- a/.github/workflows/e2e_ppo_trainer_megatron.yml +++ b/.github/workflows/e2e_ppo_trainer_megatron.yml @@ -44,7 +44,8 @@ jobs: env: HTTP_PROXY: ${{ secrets.PROXY_HTTP }} HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }} - NO_PROXY: "localhost,127.0.0.1" + NO_PROXY: "localhost,127.0.0.1,hf-mirror.com" + HF_ENDPOINT: "https://hf-mirror.com" HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable container: image: whatcanyousee/verl:ngc-cu124-vllm0.8.3-sglang0.4.5-mcore0.12.0-te2.2 @@ -82,7 +83,8 @@ jobs: env: HTTP_PROXY: ${{ secrets.PROXY_HTTP }} HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }} - NO_PROXY: "localhost,127.0.0.1" + NO_PROXY: "localhost,127.0.0.1,hf-mirror.com" + HF_ENDPOINT: "https://hf-mirror.com" HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable container: image: whatcanyousee/verl:ngc-cu124-vllm0.8.3-sglang0.4.5-mcore0.12.0-te2.2 diff --git a/.github/workflows/e2e_prime.yml b/.github/workflows/e2e_prime.yml index df5de828d3..2c684812d4 100644 --- a/.github/workflows/e2e_prime.yml +++ b/.github/workflows/e2e_prime.yml @@ -36,7 +36,8 @@ jobs: env: HTTP_PROXY: ${{ secrets.PROXY_HTTP }} HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }} - NO_PROXY: "localhost,127.0.0.1" + NO_PROXY: "localhost,127.0.0.1,hf-mirror.com" + HF_ENDPOINT: "https://hf-mirror.com" HF_HUB_ENABLE_HF_TRANSFER: "0" # This 
is more stable container: image: whatcanyousee/verl:ngc-cu124-vllm0.8.3-sglang0.4.5-mcore0.12.0-te2.2 diff --git a/.github/workflows/e2e_sft.yml b/.github/workflows/e2e_sft.yml index 88aab250e3..b786a4b951 100644 --- a/.github/workflows/e2e_sft.yml +++ b/.github/workflows/e2e_sft.yml @@ -42,7 +42,8 @@ jobs: env: HTTP_PROXY: ${{ secrets.PROXY_HTTP }} HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }} - NO_PROXY: "localhost,127.0.0.1" + NO_PROXY: "localhost,127.0.0.1,hf-mirror.com" + HF_ENDPOINT: "https://hf-mirror.com" HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable container: image: whatcanyousee/verl:ngc-cu124-vllm0.8.3-sglang0.4.5-mcore0.12.0-te2.2 diff --git a/.github/workflows/model.yml b/.github/workflows/model.yml index dc4c48d990..195025623b 100644 --- a/.github/workflows/model.yml +++ b/.github/workflows/model.yml @@ -29,7 +29,8 @@ jobs: env: HTTP_PROXY: ${{ secrets.PROXY_HTTP }} HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }} - NO_PROXY: "localhost,127.0.0.1" + NO_PROXY: "localhost,127.0.0.1,hf-mirror.com" + HF_ENDPOINT: "https://hf-mirror.com" HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable container: image: whatcanyousee/verl:ngc-cu124-vllm0.8.3-sglang0.4.5-mcore0.12.0-te2.2 diff --git a/.github/workflows/ray_test.yml b/.github/workflows/ray_test.yml index 7d4064b46d..561b4ee32c 100644 --- a/.github/workflows/ray_test.yml +++ b/.github/workflows/ray_test.yml @@ -32,7 +32,8 @@ jobs: env: HTTP_PROXY: ${{ secrets.PROXY_HTTP }} HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }} - NO_PROXY: "localhost,127.0.0.1" + NO_PROXY: "localhost,127.0.0.1,hf-mirror.com" + HF_ENDPOINT: "https://hf-mirror.com" HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable container: image: whatcanyousee/verl:ngc-cu124-vllm0.8.3-sglang0.4.5-mcore0.12.0-te2.2 diff --git a/.github/workflows/sandbox.yml b/.github/workflows/sandbox.yml index 973eb5c7da..dc4f1a0c88 100644 --- a/.github/workflows/sandbox.yml +++ b/.github/workflows/sandbox.yml @@ -31,7 +31,8 @@ jobs: env: HTTP_PROXY: ${{ secrets.PROXY_HTTP }} HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }} - NO_PROXY: "localhost,127.0.0.1" + NO_PROXY: "localhost,127.0.0.1,hf-mirror.com" + HF_ENDPOINT: "https://hf-mirror.com" HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable container: image: whatcanyousee/verl:ngc-cu124-vllm0.8.3-sglang0.4.5-mcore0.12.0-te2.2 diff --git a/.github/workflows/sgl.yml b/.github/workflows/sgl.yml new file mode 100644 index 0000000000..6794c5faf8 --- /dev/null +++ b/.github/workflows/sgl.yml @@ -0,0 +1,65 @@ +name: sgl + +on: + workflow_dispatch: # Manual + # Trigger the workflow on push or pull request, + # but only for the main branch + push: + branches: + - main + - v0.2.x + paths: + - "**/*.py" + - .github/workflows/sgl.yml + pull_request: + branches: + - main + - v0.2.x + paths: + - "**/*.py" + - "verl/trainer/config/*.yaml" + - .github/workflows/sgl.yml + +# Cancel jobs on the same ref if a new one is triggered +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} + +# Declare permissions: just read content.
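The `HF_ENDPOINT` variable exported throughout these workflows is honored by `huggingface_hub`, which reads it once at import time in place of the default `https://huggingface.co`. A minimal sketch of the behavior the jobs rely on (the repo id is only an example taken from the tests in this PR):

```python
import os

# Must be set before huggingface_hub is imported; the hub client reads
# HF_ENDPOINT once at import time to pick its base URL.
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"

from huggingface_hub import snapshot_download

# Downloads now route through the mirror instead of huggingface.co.
local_path = snapshot_download(repo_id="Qwen/Qwen2.5-0.5B")
print(local_path)
```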
+permissions: + contents: read + +jobs: + sgl: + runs-on: [self-hosted, l20-0] + timeout-minutes: 20 # Increase this timeout value as needed + env: + HTTP_PROXY: ${{ secrets.PROXY_HTTP }} + HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }} + NO_PROXY: "localhost,127.0.0.1,hf-mirror.com" + HF_ENDPOINT: "https://hf-mirror.com" + HF_HUB_ENABLE_HF_TRANSFER: 1 + SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK: "True" + container: + image: ocss884/verl-sglang:ngc-th2.6.0-cu126-sglang0.4.5.post3 + options: --gpus all --shm-size=10g + steps: + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + fetch-depth: 0 + - name: Install the current repository + run: | + pip3 install hf_transfer + pip3 install -e .[test,gpu,sglang] --no-deps + - name: Test the latest SGLang + run: | + cd tests/rollout + torchrun --nnodes=1 --nproc_per_node=4 $(which pytest) -s test_sglang_spmd.py + - name: Test the latest SGLang async + run: | + cd tests/rollout + torchrun --nnodes=1 --nproc_per_node=2 $(which pytest) -s test_sglang_async_spmd.py + - name: Test the latest SGLang Rollout async with tool + run: | + cd tests/rollout + torchrun --nnodes=1 --nproc_per_node=2 $(which pytest) -s test_sglang_async_rollout_w_tools.py diff --git a/.github/workflows/vllm.yml b/.github/workflows/vllm.yml index f5da0b91ac..abe8bc7828 100644 --- a/.github/workflows/vllm.yml +++ b/.github/workflows/vllm.yml @@ -46,9 +46,9 @@ jobs: env: HTTP_PROXY: ${{ secrets.PROXY_HTTP }} HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }} - NO_PROXY: "localhost,127.0.0.1" - HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable + NO_PROXY: "localhost,127.0.0.1,hf-mirror.com" HF_ENDPOINT: "https://hf-mirror.com" + HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable container: image: whatcanyousee/verl:vemlp-th2.4.0-cu124-vllm0.6.3-ray2.10-te2.0-megatron0.11.0-v0.0.6 options: --gpus all --shm-size=10g diff --git a/docs/sglang_multiturn/multiturn.rst b/docs/sglang_multiturn/multiturn.rst new file mode 100644 index 0000000000..8d6cde3261 --- /dev/null +++ b/docs/sglang_multiturn/multiturn.rst @@ -0,0 +1,41 @@ +Multi-turn Rollout Support +========================== + +Basic Configuration +~~~~~~~~~~~~~~~~~~~ + +To enable multi-turn rollout, make sure to configure the following fields in your rollout configuration: + +.. code-block:: yaml + + actor_rollout_ref: + rollout: + name: "sglang_async" + multi_turn: + enable: True + +This configuration activates the sglang_async engine for multi-turn interaction during rollout. + +Custom Tool Configuration +~~~~~~~~~~~~~~~~~~~~~~~~~ + +For custom environment interaction tools, you can specify your tool configurations in a YAML file. +To do so, set the following field in your rollout config: + +.. code-block:: yaml + + actor_rollout_ref: + rollout: + multi_turn: + tool_config_path: <path_to_tool_yaml_file> + +This allows integration of customized tool behaviors during actor rollout steps. You may refer to the GSM8KTool_example_configuration_ for guidance. + +GSM8K Multi-turn Training Performance +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +See the training performance of multi-turn rollout on the GSM8K task HERE_. + +.. _HERE: https://wandb.ai/zhaochenyang20/gsm8k_async_rl/runs/1ro1r7om?nw=nwuserzhaochenyang20 + +..
_GSM8KTool_example_configuration: ../../examples/sglang_multiturn/config/tool_config/gsm8k_tool_config.yaml \ No newline at end of file diff --git a/examples/data_preprocess/gsm8k_multiturn_w_tool.py b/examples/data_preprocess/gsm8k_multiturn_w_tool.py new file mode 100644 index 0000000000..b3b708861a --- /dev/null +++ b/examples/data_preprocess/gsm8k_multiturn_w_tool.py @@ -0,0 +1,117 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023-2024 SGLang Team +# Copyright 2025 ModelBest Inc. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Preprocess the GSM8k dataset to parquet format +""" + +import argparse +import os +import re + +import datasets + +from verl.utils.hdfs_io import copy, makedirs + + +def extract_solution(solution_str): + solution = re.search("#### (\\-?[0-9\\.\\,]+)", solution_str) + assert solution is not None + final_solution = solution.group(0) + final_solution = final_solution.split("#### ")[1].replace(",", "") + return final_solution + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--local_dir", default="~/data/gsm8k") + parser.add_argument("--hdfs_dir", default=None) + + args = parser.parse_args() + + data_source = "openai/gsm8k" + dataset = datasets.load_dataset(data_source, "main") + + train_dataset = dataset["train"] + test_dataset = dataset["test"] + + instruction_following = """ + You must use the `calc_gsm8k_reward` tool at least once to calculate the reward + of your answer (1.0 if your answer is correct, 0.0 if your answer is incorrect) + before submitting it, and refine your answer if necessary. + Put your final answer in the format of `#### <answer>`. + """ + + # add a field to each data item that represents a unique id + def make_map_fn(split): + def process_fn(example, idx): + question_raw = example.pop("question") + + question = question_raw + " " + instruction_following + + answer_raw = example.pop("answer") + solution = extract_solution(answer_raw) + data = { + "data_source": data_source, + "prompt": [ + { + "role": "system", + "content": """ + You are a math expert. You are given a question and you need to solve it step by step. + `calc_gsm8k_reward` is a tool for calculating the reward of gsm8k. You should use this + tool to calculate the reward of your answer (1.0 if your answer is correct, 0.0 if your + answer is incorrect) before submitting it and refine your answer if necessary.
Put your + final answer in the format of `#### <answer>`.""", + }, + { + "role": "user", + "content": question, + }, + ], + "ability": "math", + "reward_model": {"style": "rule", "ground_truth": solution}, + "extra_info": { + "split": split, + "index": idx, + "answer": answer_raw, + "question": question_raw, + "need_tools_kwargs": True, + "tools_kwargs": { + "calc_gsm8k_reward": { + "create_kwargs": {"ground_truth": solution}, + # "execute_kwargs": {}, + # "calc_reward_kwargs": {}, + # "release_kwargs": {}, + }, + }, + }, + } + return data + + return process_fn + + train_dataset = train_dataset.map(function=make_map_fn("train"), with_indices=True) + test_dataset = test_dataset.map(function=make_map_fn("test"), with_indices=True) + + local_dir = args.local_dir + hdfs_dir = args.hdfs_dir + + train_dataset.to_parquet(os.path.join(local_dir, "train.parquet")) + test_dataset.to_parquet(os.path.join(local_dir, "test.parquet")) + + if hdfs_dir is not None: + makedirs(hdfs_dir) + + copy(src=local_dir, dst=hdfs_dir) diff --git a/examples/sglang_multiturn/config/gsm8k_multiturn_grpo.yaml b/examples/sglang_multiturn/config/gsm8k_multiturn_grpo.yaml new file mode 100644 index 0000000000..c57a5afde2 --- /dev/null +++ b/examples/sglang_multiturn/config/gsm8k_multiturn_grpo.yaml @@ -0,0 +1,22 @@ +hydra: + searchpath: + - file://verl/trainer/config + +defaults: + - ppo_trainer + - _self_ + +data: + max_prompt_length: 1024 + max_response_length: 1024 + train_batch_size: 256 + return_raw_chat: True + +actor_rollout_ref: + hybrid_engine: True + rollout: + name: sglang_async + multi_turn: + enable: True + max_turns: 5 + # tool_config_path: "./config/tool_config/gsm8k_tool_config.yaml" diff --git a/examples/sglang_multiturn/config/tool_config/gsm8k_tool_config.yaml b/examples/sglang_multiturn/config/tool_config/gsm8k_tool_config.yaml new file mode 100644 index 0000000000..c388e4c233 --- /dev/null +++ b/examples/sglang_multiturn/config/tool_config/gsm8k_tool_config.yaml @@ -0,0 +1,15 @@ +tools: + - class_name: "verl.tools.gsm8k_tool.Gsm8kTool" + config: {} + tool_schema: + type: "function" + function: + name: "calc_gsm8k_reward" + description: "A tool for calculating the reward of gsm8k (1.0 if your answer is correct, 0.0 if your answer is incorrect)" + parameters: + type: "object" + properties: + answer: + type: "string" + description: "The model's answer to the GSM8K math problem; must be digits" + required: ["answer"] diff --git a/examples/sglang_multiturn/readme.md b/examples/sglang_multiturn/readme.md new file mode 100644 index 0000000000..13c86e8166 --- /dev/null +++ b/examples/sglang_multiturn/readme.md @@ -0,0 +1,40 @@ +# Multi-Turn Rollout Example (GSM8K) + +This example demonstrates how to perform **multi-turn rollout** using SGLang with a tool-calling capable model (e.g., Qwen2.5-3B) on the GSM8K dataset. + +## Usage + + +### Step 1: Download the GSM8K Dataset + +```bash +cd your_verl_root_dir +python3 examples/data_preprocess/gsm8k.py +``` + +This will download and preprocess the GSM8K dataset into `~/data/gsm8k/`. + +### Step 2: Run Multi-Turn Rollout +If you have 8 GPUs, +use the standard 8-GPU script: + +```bash +cd your_verl_root_dir +bash examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_multiturn.sh +``` + +If you have only 4 GPUs, +use the fallback 4-GPU script: + +```bash +cd your_verl_root_dir +bash examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_multiturn_4xgpu.sh +``` + +## Notes + +- The rollout supports multi-turn conversations with tool-calling capabilities. + +- Current tools are used for GSM8K answer evaluation.
+ +- Future versions may extend to search and code interpreter tools. \ No newline at end of file diff --git a/examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_multiturn.sh b/examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_multiturn.sh new file mode 100644 index 0000000000..7dedada0ec --- /dev/null +++ b/examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_multiturn.sh @@ -0,0 +1,53 @@ +# run on 8xH100 +# make sure your current working directory is the root of the project + +set -x + +ulimit -n 65535 + +PROJECT_DIR="$(pwd)" +CONFIG_PATH="$PROJECT_DIR/examples/sglang_multiturn/config" + +python3 -m verl.trainer.main_ppo \ + --config-path="$CONFIG_PATH" \ + --config-name='gsm8k_multiturn_grpo' \ + algorithm.adv_estimator=grpo \ + data.train_batch_size=256 \ + data.max_prompt_length=1024 \ + data.max_response_length=1024 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + data.return_raw_chat=True \ + actor_rollout_ref.model.path=Qwen/Qwen2.5-3B-Instruct \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=256 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=32 \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ + actor_rollout_ref.rollout.name=sglang_async \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \ + actor_rollout_ref.rollout.n=16 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger=['console','wandb'] \ + trainer.project_name='gsm8k_async_rl' \ + trainer.experiment_name='qwen2.5-3b_function_rm-gsm8k-async-sgl-multi-w-tool-verify-n16' \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=1 \ + trainer.save_freq=-1 \ + trainer.test_freq=20 \ + data.train_files=$HOME/data/gsm8k/train.parquet \ + data.val_files=$HOME/data/gsm8k/test.parquet \ + actor_rollout_ref.rollout.multi_turn.tool_config_path="$PROJECT_DIR/examples/sglang_multiturn/config/tool_config/gsm8k_tool_config.yaml" \ + trainer.total_epochs=15 $@ + diff --git a/examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_multiturn_4xgpu.sh b/examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_multiturn_4xgpu.sh new file mode 100644 index 0000000000..5cca4b471b --- /dev/null +++ b/examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_multiturn_4xgpu.sh @@ -0,0 +1,58 @@ +# run on 4xH100 +# make sure your current working directory is the root of the project + +set -x +export HYDRA_FULL_ERROR=1 +ulimit -n 65535 + +PROJECT_DIR="$(pwd)" +CONFIG_PATH="$PROJECT_DIR/examples/sglang_multiturn/config" + +python3 -m verl.trainer.main_ppo \ + --config-path="$CONFIG_PATH" \ + --config-name='gsm8k_multiturn_grpo' \ + algorithm.adv_estimator=grpo \ + data.train_batch_size=256 \ + data.max_prompt_length=1024 \ + data.max_response_length=1024 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + data.return_raw_chat=True \ + actor_rollout_ref.model.path=Qwen/Qwen2.5-3B-Instruct \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + 
actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=256 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=32 \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ + actor_rollout_ref.rollout.name=sglang_async \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \ + actor_rollout_ref.rollout.n=16 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger=['console','wandb'] \ + trainer.project_name='gsm8k_async_rl' \ + trainer.experiment_name='qwen2.5-3b_function_rm-gsm8k-async-sgl-multi-w-tool-verify-n16-4cards' \ + trainer.n_gpus_per_node=4 \ + trainer.nnodes=1 \ + trainer.save_freq=-1 \ + trainer.test_freq=20 \ + trainer.total_epochs=15 \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=8192 \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=8192 \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=8192 \ + critic.ppo_max_token_len_per_gpu=8192 \ + critic.forward_max_token_len_per_gpu=8192 \ + data.train_files=$HOME/data/gsm8k/train.parquet \ + data.val_files=$HOME/data/gsm8k/test.parquet \ + actor_rollout_ref.rollout.multi_turn.tool_config_path="$PROJECT_DIR/examples/sglang_multiturn/config/tool_config/gsm8k_tool_config.yaml" \ + $@ \ No newline at end of file diff --git a/scripts/model_merger.py b/scripts/model_merger.py index 40750678ba..c8d00b0f0a 100644 --- a/scripts/model_merger.py +++ b/scripts/model_merger.py @@ -45,10 +45,7 @@ "--local_dir", type=str, required=True, - help=( - "The path for your saved model. For megatron, point to the base dir of model, rng, optimizer checkpoints, " - "commonly be `config.default_local_dir/global_step_\{global_step\}`." - ), + help=("The path for your saved model. For megatron, point to the base dir of model, rng, optimizer checkpoints, commonly be `config.default_local_dir/global_step_\{global_step\}`."), ) parser.add_argument("--target_dir", required=False, default="tmp", type=str, help="The path for the target model") parser.add_argument("--hf_upload_path", default=False, type=str, help="The path of the huggingface repo to upload") @@ -96,10 +93,7 @@ def patch_model_generation_config(model, hf_model_path): try: model.generation_config = GenerationConfig.from_pretrained(args.hf_model_path) except OSError: - print( - f"Warning: Generation config file not found in {args.hf_model_path}, " - "using a generation config created from the model config." 
- ) + print(f"Warning: Generation config file not found in {args.hf_model_path}, using a generation config created from the model config.") pass return model diff --git a/tests/e2e/run_gsm8k_fsdp_sgl_multiturn_w_tool.sh b/tests/e2e/run_gsm8k_fsdp_sgl_multiturn_w_tool.sh new file mode 100644 index 0000000000..364a03723a --- /dev/null +++ b/tests/e2e/run_gsm8k_fsdp_sgl_multiturn_w_tool.sh @@ -0,0 +1,55 @@ +# run on 8xH100 +# make sure your current working directory is the root of the project + +set -x + +huggingface-cli download Qwen/Qwen2.5-3B-Instruct --local-dir $HOME/models/Qwen/Qwen2.5-3B-Instruct + +ulimit -n 65535 + +PROJECT_DIR="$(pwd)" +CONFIG_PATH="$PROJECT_DIR/examples/sglang_multiturn/config" + +python3 -m verl.trainer.main_ppo \ + --config-path="$CONFIG_PATH" \ + --config-name='gsm8k_multiturn_grpo' \ + algorithm.adv_estimator=grpo \ + data.train_batch_size=256 \ + data.max_prompt_length=1024 \ + data.max_response_length=1024 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + data.return_raw_chat=True \ + actor_rollout_ref.model.path=$HOME/models/Qwen/Qwen2.5-3B-Instruct \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=256 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ + actor_rollout_ref.rollout.name=sglang_async \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \ + actor_rollout_ref.rollout.n=8 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.logger=['console'] \ + trainer.project_name='gsm8k_async_rl' \ + trainer.experiment_name='qwen2.5-3b_function_rm-gsm8k-async-sgl-multi-w-tool-rebased-0427-verify-n16' \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=1 \ + trainer.save_freq=-1 \ + trainer.test_freq=-1 \ + data.train_files=$HOME/data/gsm8k_verl_sgl_multi_turn_preprocessed/train.parquet \ + data.val_files=$HOME/data/gsm8k_verl_sgl_multi_turn_preprocessed/test.parquet \ + actor_rollout_ref.rollout.multi_turn.tool_config_path="$PROJECT_DIR/examples/sglang_multiturn/config/tool_config/gsm8k_tool_config.yaml" \ + trainer.val_before_train=False \ + trainer.total_training_steps=1 $@ diff --git a/tests/rollout/test_sglang_async_rollout_w_tools.py b/tests/rollout/test_sglang_async_rollout_w_tools.py new file mode 100644 index 0000000000..b8b84be48f --- /dev/null +++ b/tests/rollout/test_sglang_async_rollout_w_tools.py @@ -0,0 +1,133 @@ +# Copyright 2023-2024 SGLang Team +# Copyright 2025 ModelBest Inc. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +usage: torchrun --standalone --nnodes=1 \ + --nproc_per_node=2 $(which pytest) \ + -s test_sglang_async_rollout_w_tools.py +""" + +import numpy as np +import torch +from tensordict import TensorDict +from torch.distributed.device_mesh import init_device_mesh +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp import MixedPrecision, ShardingStrategy +from utils_sglang import ( + are_lists_similar, + clean_torchelastic_env, + generate_hf_output, + get_rollout_config, + initialize_global_process_group, + load_tokenizer_and_model, + prepare_inputs, +) + +from verl import DataProto +from verl.workers.rollout.sglang_rollout.async_sglang_rollout import AsyncSGLangRollout +from verl.workers.sharding_manager.fsdp_sglang import FSDPAsyncSGLangShardingManager + + +def test_async_sglang_rollout_w_tool(): + assert torch.cuda.device_count() >= 2 + initialize_global_process_group() + clean_torchelastic_env() + + max_prompt_length = 32 + max_response_length = 16 + dtype = "bfloat16" + tensor_parallel_size = 2 + local_model_path = "Qwen/Qwen2.5-0.5B" + + tokenizer, actor_model = load_tokenizer_and_model(local_model_path) + + preencode_prompts = [ + [{"role": "user", "content": prompt, "tool_calls": None}] + for prompt in [ + "Who won the Champions League in 2019?", + "The founder of Apple is", + "What's the best way to learn python?", + ] + ] + prompts = [tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True) for message in preencode_prompts] + input_ids, attention_mask, position_ids = prepare_inputs(tokenizer, prompts, max_prompt_length) + + hf_response_tokens = generate_hf_output(actor_model, input_ids, attention_mask, tokenizer, max_response_length) + + fsdp_device_mesh = init_device_mesh("cuda", mesh_shape=(tensor_parallel_size,), mesh_dim_names=("fsdp",)) + inference_device_mesh_cpu = init_device_mesh("cpu", mesh_shape=(1, tensor_parallel_size, 1), mesh_dim_names=("dp", "infer_tp", "pp")) + + fsdp_model = FSDP( + actor_model, + use_orig_params=True, + device_id=fsdp_device_mesh["fsdp"].get_local_rank(), + mixed_precision=MixedPrecision(param_dtype=getattr(torch, dtype)), + sharding_strategy=ShardingStrategy.FULL_SHARD, + device_mesh=fsdp_device_mesh, + ) + + rollout_config = get_rollout_config(max_response_length, max_prompt_length, dtype, tensor_parallel_size) + rollout = AsyncSGLangRollout(actor_module=local_model_path, config=rollout_config, tokenizer=tokenizer, model_hf_config=actor_model.config) + + rollout_sharding_manager = FSDPAsyncSGLangShardingManager( + module=fsdp_model, + inference_engine=rollout._engine, + model_config=actor_model.config, + full_params=True, + device_mesh=inference_device_mesh_cpu, + ) + + with rollout_sharding_manager: + prompt_dict = TensorDict( + { + "input_ids": input_ids, + "attention_mask": attention_mask, + "position_ids": position_ids, + }, + batch_size=input_ids.shape[0], + ) + print(f"preprocessed {input_ids.shape=}") + + messages = np.asarray(preencode_prompts) + prompts = DataProto(batch=prompt_dict, non_tensor_batch={"raw_prompt": messages}) + + prompts.meta_info.update( 
+ { + "eos_token_id": tokenizer.eos_token_id, + "pad_token_id": tokenizer.pad_token_id, + } + ) + + prompts = rollout_sharding_manager.preprocess_data(prompts) + # log_gpu_memory_usage("Before generating sequences", logger=None) + output = rollout.generate_sequences_with_tools(prompts=prompts) + print(f"generated {output.batch['responses'].shape=}") + # log_gpu_memory_usage("After generating sequences", logger=None) + output = rollout_sharding_manager.postprocess_data(output) + print(f"postprocessed {output.batch['responses'].shape=}") + sglang_output = output.to("cpu") + + sglang_response_tokens = tokenizer.batch_decode(sglang_output.batch["responses"]) + + print(f"hf response: {hf_response_tokens}") + print(f"sglang response: {sglang_response_tokens}") + assert are_lists_similar(hf_response_tokens, sglang_response_tokens) + print("SGLang w tool Test Passed!") + + torch.distributed.barrier() + torch.distributed.destroy_process_group() + + +if __name__ == "__main__": + test_async_sglang_rollout_w_tool() diff --git a/tests/rollout/test_sglang_async_spmd.py b/tests/rollout/test_sglang_async_spmd.py new file mode 100644 index 0000000000..d8187c9694 --- /dev/null +++ b/tests/rollout/test_sglang_async_spmd.py @@ -0,0 +1,113 @@ +# Copyright 2023-2024 SGLang Team +# Copyright 2025 ModelBest Inc. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +usage: torchrun --standalone --nnodes=1 \ + --nproc_per_node=2 $(which pytest) \ + -s test_sglang_async_spmd.py +""" + +import asyncio + +import torch +from sglang.srt.entrypoints.engine import Engine +from sglang.srt.utils import broadcast_pyobj +from torch.distributed.device_mesh import init_device_mesh +from utils_sglang import ( + are_lists_similar, + clean_torchelastic_env, + generate_hf_output, + initialize_global_process_group, + load_tokenizer_and_model, + prepare_inputs, +) + + +def _pre_process_inputs(pad_token_id, prompt_token_ids: torch.Tensor): + non_pad_index = torch.nonzero(prompt_token_ids != pad_token_id, as_tuple=False)[0][0] + token_ids = prompt_token_ids[non_pad_index:].tolist() + return token_ids + + +def test_sglang_spmd(): + assert torch.cuda.device_count() >= 2 + initialize_global_process_group(spmd=True) + clean_torchelastic_env() + + max_prompt_length = 16 + max_response_length = 16 + + local_model_path = "Qwen/Qwen2.5-0.5B" + tokenizer, actor_model = load_tokenizer_and_model(local_model_path) + + preencode_prompts = ["Who won the Champions League in 2019?", "The founder of Apple is", "What's your name?"] + input_ids, attention_mask, _ = prepare_inputs(tokenizer, preencode_prompts, max_prompt_length) + + hf_response_tokens = generate_hf_output(actor_model, input_ids, attention_mask, tokenizer, max_response_length) + + tensor_parallel_size = 2 + inference_device_mesh_cpu = init_device_mesh("cpu", mesh_shape=(1, tensor_parallel_size, 1), mesh_dim_names=["dp", "tp", "pp"]) + tp_rank = inference_device_mesh_cpu["tp"].get_local_rank() + + if tp_rank == 0: + llm = Engine( + model_path=local_model_path, + dtype="bfloat16", + mem_fraction_static=0.5, + enable_memory_saver=True, + tp_size=inference_device_mesh_cpu["tp"].size(), + ) + + input_ids = input_ids.cuda() + idx_list = [] + + pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id + for i in range(input_ids.shape[0]): + idx_list.append(_pre_process_inputs(pad_token_id, input_ids[i])) + + sampling_params = dict( + n=1, + temperature=0, + top_p=1, + top_k=-1, + max_new_tokens=max_response_length, + presence_penalty=0.0, + frequency_penalty=0.0, + repetition_penalty=1.0, + skip_special_tokens=True, + spaces_between_special_tokens=True, + ignore_eos=False, + ) + + loop = asyncio.get_event_loop() + outputs = loop.run_until_complete(llm.async_generate(input_ids=idx_list, sampling_params=sampling_params)) + else: + outputs = None + + [outputs] = broadcast_pyobj( + [outputs], + rank=inference_device_mesh_cpu["tp"].get_local_rank(), + src=inference_device_mesh_cpu["tp"].mesh[0].item(), + dist_group=inference_device_mesh_cpu["tp"].get_group(), + force_cpu_device=False, + ) + + sglang_response_tokens = [output["text"] for output in outputs] + + print(f"sglang response: {sglang_response_tokens}") + assert are_lists_similar(hf_response_tokens, sglang_response_tokens) + print("SPMD Test Passed!") + + torch.distributed.barrier() + torch.distributed.destroy_process_group() diff --git a/tests/rollout/test_sglang_spmd.py b/tests/rollout/test_sglang_spmd.py index fdb26d25b0..40c514dd5e 100644 --- a/tests/rollout/test_sglang_spmd.py +++ b/tests/rollout/test_sglang_spmd.py @@ -1,17 +1,5 @@ # Copyright 2023-2024 SGLang Team -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2025 ModelBest Inc. and/or its affiliates # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -24,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import os import torch @@ -114,7 +101,7 @@ def test_sglang_spmd(): preencode_prompts = [ "Who won the Champions League in 2019?", "The founder of Apple is", - "What's your name", + "What's your name?", ] tokenizer.pad_token = tokenizer.eos_token prompts = tokenizer(preencode_prompts, return_tensors="pt", padding=True) @@ -207,7 +194,6 @@ def test_sglang_spmd(): def _pre_process_inputs(pad_token_id, prompt_token_ids: torch.Tensor): # remove the left padding in the prompt token_id - # pad_token_id = self.llm_engine.tokenizer.pad_token_id if self.llm_engine.tokenizer.pad_token_id is not None else self.llm_engine.tokenizer.eos_token_id non_pad_index = torch.nonzero(prompt_token_ids != pad_token_id, as_tuple=False)[0][0] token_ids = prompt_token_ids[non_pad_index:].tolist() return token_ids diff --git a/tests/rollout/test_vllm_spmd.py b/tests/rollout/test_vllm_spmd.py index 8e010dea36..5b96078753 100644 --- a/tests/rollout/test_vllm_spmd.py +++ b/tests/rollout/test_vllm_spmd.py @@ -89,7 +89,7 @@ def test_vllm_spmd(): preencode_prompts = [ "Who won the Champions League in 2019?", "The founder of Apple is", - "What's your name", + "What's your name?", ] tokenizer.pad_token = tokenizer.eos_token prompts = tokenizer(preencode_prompts, return_tensors="pt", padding=True) diff --git a/tests/rollout/utils_sglang.py b/tests/rollout/utils_sglang.py new file mode 100644 index 0000000000..de97920942 --- /dev/null +++ b/tests/rollout/utils_sglang.py @@ -0,0 +1,162 @@ +# Copyright 2023-2024 SGLang Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
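The `utils_sglang.py` helper module introduced below provides the fuzzy comparison all three rollout tests rely on: two response lists count as similar when the summed Levenshtein distance stays within 10% of the summed string lengths. A small illustrative check (the strings are hypothetical):

```python
from utils_sglang import are_lists_similar

# One single-character edit over ~48 total characters is about 2%,
# well under the 10% threshold, so the lists compare as similar.
hf_out = ["The answer is 42.", "Paris is the capital of France."]
sgl_out = ["The answer is 42!", "Paris is the capital of France."]
assert are_lists_similar(hf_out, sgl_out)
```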
+import os +from datetime import timedelta + +import torch +from omegaconf import OmegaConf +from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig + +from verl.utils.model import compute_position_id_with_mask +from verl.utils.torch_functional import pad_sequence_to_length + + +# ====================== utils ====================== +def levenshtein(s1, s2): + m, n = len(s1), len(s2) + dp = [[0] * (n + 1) for _ in range(m + 1)] + for i in range(m + 1): + dp[i][0] = i + for j in range(n + 1): + dp[0][j] = j + for i in range(1, m + 1): + for j in range(1, n + 1): + cost = 0 if s1[i - 1] == s2[j - 1] else 1 + dp[i][j] = min(dp[i - 1][j] + 1, dp[i][j - 1] + 1, dp[i - 1][j - 1] + cost) + return dp[m][n] + + +def are_lists_similar(a, b): + if len(a) != len(b): + print("The lists are of different lengths.") + return False + total_length = 0 + total_diff = 0 + for s1, s2 in zip(a, b): + max_len = max(len(s1), len(s2)) + total_length += max_len + total_diff += levenshtein(s1, s2) + percentage_difference = (total_diff / total_length) * 100 + print(f"Total difference: {percentage_difference:.2f}%") + return percentage_difference <= 10 + + +def initialize_global_process_group(timeout_second=36000, spmd=False): + import torch.distributed + + if not torch.distributed.is_initialized(): # Check if already initialized + print("Initializing process group...") + torch.distributed.init_process_group(timeout=timedelta(seconds=timeout_second)) + else: + print("Process group already initialized.") + + local_rank = int(os.environ["LOCAL_RANK"]) + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + torch.cuda.set_device(local_rank) + + CUDA_VISIBLE_DEVICES = os.environ.get("CUDA_VISIBLE_DEVICES", "") + if not CUDA_VISIBLE_DEVICES: + if spmd: + # CUDA_VISIBLE_DEVICES = ','.join(str(i) for i in range(tensor_parallel_size)) + CUDA_VISIBLE_DEVICES = ",".join(str(i) for i in range(world_size)) + else: + CUDA_VISIBLE_DEVICES = str(local_rank) + os.environ["CUDA_VISIBLE_DEVICES"] = CUDA_VISIBLE_DEVICES + print(f"CUDA_VISIBLE_DEVICES is not set, set to {CUDA_VISIBLE_DEVICES}") + + return local_rank, rank, world_size + + +def clean_torchelastic_env(): + for k in ["TORCHELASTIC_USE_AGENT_STORE"]: + if k in os.environ: + del os.environ[k] + + +def load_tokenizer_and_model(local_model_path, dtype="bfloat16"): + tokenizer = AutoTokenizer.from_pretrained(local_model_path, padding_side="left") + tokenizer.pad_token = tokenizer.eos_token + model = AutoModelForCausalLM.from_pretrained(local_model_path, torch_dtype=getattr(torch, dtype), device_map="cuda") + return tokenizer, model + + +def prepare_inputs(tokenizer, prompts, max_prompt_length): + pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id + tokenized = tokenizer(prompts, return_tensors="pt", padding=True) + input_ids = pad_sequence_to_length(tokenized["input_ids"], max_prompt_length, pad_token_id, left_pad=True) + attention_mask = pad_sequence_to_length(tokenized["attention_mask"], max_prompt_length, pad_token_id=0, left_pad=True) + position_ids = compute_position_id_with_mask(attention_mask) + position_ids = pad_sequence_to_length(position_ids, max_prompt_length, pad_token_id=0, left_pad=True) + return input_ids, attention_mask, position_ids + + +def generate_hf_output(model, input_ids, attention_mask, tokenizer, max_response_length): + generation_config = GenerationConfig(do_sample=False) + output = model.generate( + input_ids=input_ids.cuda(), + 
attention_mask=attention_mask.cuda(), + max_new_tokens=max_response_length, + eos_token_id=tokenizer.eos_token_id, + pad_token_id=tokenizer.pad_token_id, + generation_config=generation_config, + output_scores=False, + return_dict_in_generate=True, + use_cache=False, + ) + seq = output.sequences + response = seq[:, input_ids.shape[1] :] + return tokenizer.batch_decode(response) + + +def get_rollout_config(max_response_length, max_prompt_length, dtype, tensor_parallel_size): + sampling_params = dict( + n=1, + temperature=0, + top_p=1, + top_k=-1, + max_new_tokens=max_response_length, + presence_penalty=0.0, + frequency_penalty=0.0, + repetition_penalty=1.0, + skip_special_tokens=True, + spaces_between_special_tokens=True, + ignore_eos=False, + ) + + rollout_config = OmegaConf.create( + { + "name": "sglang", + "load_format": "dummy_dtensor", + "enforce_eager": False, + "free_cache_engine": False, + "dtype": dtype, + "gpu_memory_utilization": 0.5, + "ignore_eos": False, + "max_num_batched_tokens": 8192, + "prompt_length": max_prompt_length, + "response_length": max_response_length, + "tensor_model_parallel_size": tensor_parallel_size, + "multi_turn": { + "max_turns": 4, + "enable": True, + "tool_config_path": None, + "format": "chatml", + }, + "max_model_len": None, + **sampling_params, + } + ) + + return rollout_config diff --git a/tests/sanity/check_license.py b/tests/sanity/check_license.py index 1b760916f3..9cf2da3ce9 100644 --- a/tests/sanity/check_license.py +++ b/tests/sanity/check_license.py @@ -19,7 +19,16 @@ # Add custom license headers below license_head_prime = "Copyright 2024 PRIME team and/or its affiliates" license_head_individual = "Copyright 2025 Individual Contributor:" -license_headers = [license_head_bytedance, license_head_bytedance_25, license_head_prime, license_head_individual] +license_head_sglang = "Copyright 2023-2024 SGLang Team" +license_head_modelbest = "Copyright 2025 ModelBest Inc. and/or its affiliates" +license_headers = [ + license_head_bytedance, + license_head_bytedance_25, + license_head_prime, + license_head_individual, + license_head_sglang, + license_head_modelbest, +] if __name__ == "__main__": diff --git a/verl/tools/__init__.py b/verl/tools/__init__.py new file mode 100644 index 0000000000..c4b932b1ae --- /dev/null +++ b/verl/tools/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2023-2024 SGLang Team +# Copyright 2025 ModelBest Inc. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/verl/tools/base_tool.py b/verl/tools/base_tool.py new file mode 100644 index 0000000000..639a7f80cc --- /dev/null +++ b/verl/tools/base_tool.py @@ -0,0 +1,86 @@ +# Copyright 2023-2024 SGLang Team +# Copyright 2025 ModelBest Inc. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional, Tuple +from uuid import uuid4 + +from .schemas import OpenAIFunctionToolSchema + + +class BaseTool: + """Base class for tools. + + A tool should support the following methods: + + - `get_openai_tool_schema`: return the tool schema in OpenAI format. + - `create`: create a tool instance for a trajectory. + - `execute`: execute the tool. + - `calc_reward`: calculate the reward with respect to the tool state. + - `release`: release the tool instance. + """ + + def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema): + self.config = config + self.name = tool_schema.function.name + self.tool_schema = tool_schema + + def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema: + return self.tool_schema + + async def create(self, instance_id: Optional[str] = None, **kwargs) -> str: + """Create a tool instance. + + Args: + instance_id: The instance id of the tool. + + Returns: + The instance id of the tool. + """ + if instance_id is None: + return str(uuid4()) + else: + return instance_id + + async def execute(self, instance_id: str, parameters: str, **kwargs) -> Tuple[str, float, dict]: + """Execute the tool. + + Args: + instance_id: The instance id of the tool. + parameters: The JSON string of the tool parameters. + + Returns: tool_response, tool_reward_score, tool_metrics + tool_response: The response str of the tool. + tool_reward_score: The step reward score of the tool. + tool_metrics: The metrics of the tool. + """ + return "Updated the tool state.", 0.0, {} + + async def calc_reward(self, instance_id: str, **kwargs) -> float: + """Calculate the reward of the tool. + + Args: + instance_id: The instance id of the tool. + + Returns: + The reward of the tool. + """ + return 0.0 + + async def release(self, instance_id: str, **kwargs) -> None: + """Release the tool instance. + + Args: + instance_id: The instance id of the tool. + """ + pass diff --git a/verl/tools/gsm8k_tool.py b/verl/tools/gsm8k_tool.py new file mode 100644 index 0000000000..c02cd9e6b4 --- /dev/null +++ b/verl/tools/gsm8k_tool.py @@ -0,0 +1,109 @@ +# Copyright 2023-2024 SGLang Team +# Copyright 2025 ModelBest Inc. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
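The `BaseTool` contract above amounts to an async lifecycle: `create` an instance per trajectory, `execute` it with JSON-encoded parameters, read the reward via `calc_reward`, then `release`. A hedged sketch of driving the GSM8K tool defined below (the schema mirrors `gsm8k_tool_config.yaml` from this PR; the ground-truth value is made up):

```python
import asyncio

from verl.tools.gsm8k_tool import Gsm8kTool
from verl.tools.schemas import OpenAIFunctionToolSchema

schema = OpenAIFunctionToolSchema.model_validate(
    {
        "type": "function",
        "function": {
            "name": "calc_gsm8k_reward",
            "description": "A tool for calculating the reward of gsm8k.",
            "parameters": {
                "type": "object",
                "properties": {"answer": {"type": "string", "description": "The model's answer"}},
                "required": ["answer"],
            },
        },
    }
)


async def demo():
    tool = Gsm8kTool(config={}, tool_schema=schema)
    instance_id = await tool.create(ground_truth="72")  # one instance per trajectory
    text, step_reward, metrics = await tool.execute(instance_id, '{"answer": "72"}')
    reward = await tool.calc_reward(instance_id)  # 1.0 when the parsed answer matches
    await tool.release(instance_id)
    print(text, step_reward, reward)


asyncio.run(demo())
```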
+ +import json +import logging +import os +from typing import Optional, Tuple +from uuid import uuid4 + +from verl.utils.reward_score import gsm8k + +from .base_tool import BaseTool +from .schemas import OpenAIFunctionToolSchema + +logger = logging.getLogger(__name__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + + +class Gsm8kTool(BaseTool): + """A demo tool for calculating the reward of gsm8k. + + - `get_openai_tool_schema`: return the tool schema in OpenAI format. + - `create`: create a tool instance for a trajectory. + - `execute`: execute the tool. + - `calc_reward`: calculate the reward with respect to the tool state. + - `release`: release the tool instance. + """ + + def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema): + """ + _tool_schema = OpenAIFunctionToolSchema.model_validate({ + "type": "function", + "function": { + "name": "calc_gsm8k_reward", + "description": "A tool for calculating the reward of gsm8k", + "parameters": { + "type": "object", + "properties": { + "answer": { + "type": "string", + "description": "The answer to the question", + }, + }, + "required": ["answer"], + }, + } + }) + """ + super().__init__(config, tool_schema) + self._instance_dict = {} + + def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema: + return self.tool_schema + + async def create(self, instance_id: Optional[str] = None, ground_truth: Optional[str] = None, **kwargs) -> str: + if instance_id is None: + instance_id = str(uuid4()) + self._instance_dict[instance_id] = { + "response": "", + "ground_truth": ground_truth, + "reward": 0.0, + } + return instance_id + + async def execute(self, instance_id: str, parameters: str, **kwargs) -> Tuple[str, float, dict]: + try: + _parameters = json.loads(parameters) + except json.JSONDecodeError: + _parameters = {} + if isinstance(_parameters, dict): + answer = _parameters.get("answer", "") + if not isinstance(answer, str): + answer = str(answer) + else: + answer = "" + if answer.startswith("#### "): + self._instance_dict[instance_id]["response"] = answer + else: + self._instance_dict[instance_id]["response"] = "#### " + answer + reward = await self.calc_reward(instance_id) + # penalty for a non-improved answer submission + tool_reward = 0.0 if reward > self._instance_dict[instance_id]["reward"] else -0.05 + # update the reward + self._instance_dict[instance_id]["reward"] = reward + return f"Current parsed {answer=} {reward=}", tool_reward, {} + + async def calc_reward(self, instance_id: str, **kwargs) -> float: + return gsm8k.compute_score( + self._instance_dict[instance_id]["response"], + self._instance_dict[instance_id]["ground_truth"], + method="flexible", + format_score=0.0, + score=1.0, + ) + + async def release(self, instance_id: str, **kwargs) -> None: + del self._instance_dict[instance_id] diff --git a/verl/tools/schemas.py b/verl/tools/schemas.py new file mode 100644 index 0000000000..d4e0797b58 --- /dev/null +++ b/verl/tools/schemas.py @@ -0,0 +1,64 @@ +# Copyright 2023-2024 SGLang Team +# Copyright 2025 ModelBest Inc. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Literal + +from pydantic import BaseModel + + +class OpenAIFunctionPropertySchema(BaseModel): + """The schema of a parameter in OpenAI format.""" + + type: str + description: str | None = None + enum: list[str] | None = None + + +class OpenAIFunctionParametersSchema(BaseModel): + """The schema of parameters in OpenAI format.""" + + type: str + properties: dict[str, OpenAIFunctionPropertySchema] + required: list[str] + + +class OpenAIFunctionSchema(BaseModel): + """The schema of a function in OpenAI format.""" + + name: str + description: str + parameters: OpenAIFunctionParametersSchema + strict: bool = False + + +class OpenAIFunctionToolSchema(BaseModel): + """The schema of a tool in OpenAI format.""" + + type: str + function: OpenAIFunctionSchema + + +class OpenAIFunctionParsedSchema(BaseModel): + """The parsed schema of a tool in OpenAI format.""" + + name: str + arguments: str # JSON string + + +class OpenAIFunctionToolCall(BaseModel): + """The tool call in OpenAI format.""" + + id: str + type: Literal["function"] = "function" + function: OpenAIFunctionParsedSchema diff --git a/verl/trainer/config/ppo_megatron_trainer.yaml b/verl/trainer/config/ppo_megatron_trainer.yaml index 22420d7206..d6bf4c3db1 100644 --- a/verl/trainer/config/ppo_megatron_trainer.yaml +++ b/verl/trainer/config/ppo_megatron_trainer.yaml @@ -146,6 +146,11 @@ actor_rollout_ref: temperature: 0 n: 1 do_sample: False # default eager for validation + multi_turn: + enable: False # should set rollout.name to sglang_async if True + max_turns: null # null for no limit (default max_length // 3) + tool_config_path: null # null for no tool + format: chatml # chatml, more formats will be supported in the future critic: rollout_n: ${actor_rollout_ref.rollout.n} diff --git a/verl/trainer/config/ppo_trainer.yaml b/verl/trainer/config/ppo_trainer.yaml index ca14b378bf..1b6668dce8 100644 --- a/verl/trainer/config/ppo_trainer.yaml +++ b/verl/trainer/config/ppo_trainer.yaml @@ -121,6 +121,11 @@ actor_rollout_ref: temperature: 0 n: 1 do_sample: False # default eager for validation + multi_turn: + enable: False # should set rollout.name to sglang_async if True + max_turns: null # null for no limit (default max_length // 3) + tool_config_path: null # null for no tool + format: chatml # chatml, more formats will be supported in the future critic: rollout_n: ${actor_rollout_ref.rollout.n} diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py index 1637a40b8b..bbeacff487 100644 --- a/verl/trainer/ppo/ray_trainer.py +++ b/verl/trainer/ppo/ray_trainer.py @@ -1,4 +1,6 @@ # Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023-2024 SGLang Team +# Copyright 2025 ModelBest Inc. and/or its affiliates # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
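For orientation, the pydantic models added in `verl/tools/schemas.py` above round-trip the OpenAI tool-call wire format; a small sketch constructing a parsed call (the values are hypothetical):

```python
from verl.tools.schemas import OpenAIFunctionParsedSchema, OpenAIFunctionToolCall

call = OpenAIFunctionToolCall(
    id="call_0",
    function=OpenAIFunctionParsedSchema(
        name="calc_gsm8k_reward",
        arguments='{"answer": "72"}',  # arguments remain a JSON string, per the schema
    ),
)
print(call.model_dump())  # the "type" field defaults to "function"
```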
@@ -141,13 +143,18 @@ def _check_resource_available(self): raise ValueError(f"Resource pool {resource_pool_name}: {num_gpus}*{num_nodes}" + "cannot be satisfied in this ray cluster") -def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.AdaptiveKLController, kl_penalty="kl"): +def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.AdaptiveKLController, kl_penalty="kl", multi_turn=False): responses = data.batch["responses"] response_length = responses.size(1) token_level_scores = data.batch["token_level_scores"] batch_size = data.batch.batch_size[0] - attention_mask = data.batch["attention_mask"] - response_mask = attention_mask[:, -response_length:] + + if multi_turn: + loss_mask = data.batch["loss_mask"] + response_mask = loss_mask[:, -response_length:] + else: + attention_mask = data.batch["attention_mask"] + response_mask = attention_mask[:, -response_length:] # compute kl between ref_policy and current policy # When apply_kl_penalty, algorithm.use_kl_in_reward=True, so the reference model has been enabled. @@ -176,7 +183,7 @@ def compute_response_mask(data: DataProto): return attention_mask[:, -response_length:] -def compute_advantage(data: DataProto, adv_estimator, gamma=1.0, lam=1.0, num_repeat=1, norm_adv_by_std_in_grpo=True): +def compute_advantage(data: DataProto, adv_estimator, gamma=1.0, lam=1.0, num_repeat=1, multi_turn=False, norm_adv_by_std_in_grpo=True): # Back-compatible with trainers that do not compute response mask in fit if "response_mask" not in data.batch: data.batch["response_mask"] = compute_response_mask(data) @@ -193,9 +200,16 @@ def compute_advantage(data: DataProto, adv_estimator, gamma=1.0, lam=1.0, num_re data.batch["advantages"] = advantages data.batch["returns"] = returns elif adv_estimator == AdvantageEstimator.GRPO: + # TODO: test on more adv estimator type + grpo_calculation_mask = data.batch["response_mask"] + if multi_turn: + # If multi-turn, replace the mask with the relevant part of loss_mask + response_length = grpo_calculation_mask.size(1) # Get length from the initial response mask + grpo_calculation_mask = data.batch["loss_mask"][:, -response_length:] # This mask is the one intended for GRPO + # Call compute_grpo_outcome_advantage with parameters matching its definition advantages, returns = core_algos.compute_grpo_outcome_advantage( token_level_rewards=data.batch["token_level_rewards"], - response_mask=data.batch["response_mask"], + response_mask=grpo_calculation_mask, index=data.non_tensor_batch["uid"], norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo, ) @@ -414,27 +428,39 @@ def check_mutually_exclusive(mbs, mbs_per_gpu, name: str): if config.actor_rollout_ref.rollout.val_kwargs.do_sample: assert config.actor_rollout_ref.rollout.temperature > 0, "validation gen temperature should be greater than 0 when enabling do_sample" + # check multi_turn with tool config + if config.actor_rollout_ref.rollout.multi_turn.enable: + assert config.actor_rollout_ref.rollout.multi_turn.tool_config_path is not None, "tool_config_path must be set when enabling multi_turn with tool, due to no role-playing support" + assert config.algorithm.adv_estimator in [AdvantageEstimator.GRPO], "only GRPO is tested for multi-turn with tool" + print("[validate_config] All configuration checks passed successfully!") def _create_dataloader(self): - # TODO: we have to make sure the batch size is divisible by the dp size + """ + Creates the train and validation dataloaders. 
+ """ + # make sure the batch size is divisible by the dp size from verl.utils.import_utils import load_extern_type if "custom_cls" in self.config.data and self.config.data.custom_cls.get("path", None) is not None: - dataset_cls = load_extern_type(self.config.data.custom_cls.path, self.config.data.custom_cls.name) - if not issubclass(dataset_cls, Dataset): - raise TypeError(f"The custom dataset class '{self.config.data.custom_cls.name}' from '{self.config.data.custom_cls.path}' must inherit from torch.utils.data.Dataset") + # Dynamically load the custom dataset class specified in config + try: + dataset_cls = load_extern_type(self.config.data.custom_cls.path, self.config.data.custom_cls.name) + if not issubclass(dataset_cls, Dataset): + raise TypeError(f"The custom dataset class '{self.config.data.custom_cls.name}' from '{self.config.data.custom_cls.path}' must inherit from torch.utils.data.Dataset") + print(f"Using custom dataset class: {dataset_cls.__name__}") + except Exception as e: + print(f"Error loading custom dataset class: {e}") + raise e else: dataset_cls = RLHFDataset - + print(f"Using default dataset class: {dataset_cls.__name__}") self.train_dataset = dataset_cls( data_files=self.config.data.train_files, tokenizer=self.tokenizer, processor=self.processor, config=self.config.data, ) - - # use sampler for better ckpt resume if self.config.data.shuffle: train_dataloader_generator = torch.Generator() train_dataloader_generator.manual_seed(self.config.data.get("seed", 1)) @@ -445,7 +471,7 @@ def _create_dataloader(self): self.train_dataloader = StatefulDataLoader( dataset=self.train_dataset, batch_size=self.config.data.get("gen_batch_size", self.config.data.train_batch_size), - num_workers=8, + num_workers=self.config.data.get("dataloader_num_workers", 8), drop_last=True, collate_fn=collate_fn, sampler=sampler, @@ -457,24 +483,25 @@ def _create_dataloader(self): processor=self.processor, config=self.config.data, ) - # consider the design of single controller with a large val dataset in multi-modal scenarios - # may lead to oom issues - val_batch_size = self.config.data.val_batch_size or len(self.val_dataset) + + val_batch_size = self.config.data.val_batch_size # Prefer config value if set + if val_batch_size is None: + val_batch_size = len(self.val_dataset) + self.val_dataloader = StatefulDataLoader( dataset=self.val_dataset, batch_size=val_batch_size, - num_workers=8, + num_workers=self.config.data.get("dataloader_num_workers", 8), shuffle=False, drop_last=False, collate_fn=collate_fn, ) - assert len(self.train_dataloader) >= 1 - assert len(self.val_dataloader) >= 1 + assert len(self.train_dataloader) >= 1, "Train dataloader is empty!" + assert len(self.val_dataloader) >= 1, "Validation dataloader is empty!" print(f"Size of train dataloader: {len(self.train_dataloader)}, Size of val dataloader: {len(self.val_dataloader)}") - # inject total_training_steps to actor/critic optim_config. This is hacky. 
total_training_steps = len(self.train_dataloader) * self.config.trainer.total_epochs if self.config.trainer.total_training_steps is not None: @@ -483,10 +510,15 @@ def _create_dataloader(self): self.total_training_steps = total_training_steps print(f"Total training steps: {self.total_training_steps}") - OmegaConf.set_struct(self.config, True) - with open_dict(self.config): - self.config.actor_rollout_ref.actor.optim.total_training_steps = total_training_steps - self.config.critic.optim.total_training_steps = total_training_steps + try: + OmegaConf.set_struct(self.config, True) + with open_dict(self.config): + if OmegaConf.select(self.config, "actor_rollout_ref.actor.optim"): + self.config.actor_rollout_ref.actor.optim.total_training_steps = total_training_steps + if OmegaConf.select(self.config, "critic.optim"): + self.config.critic.optim.total_training_steps = total_training_steps + except Exception as e: + print(f"Warning: Could not set total_training_steps in config. Structure missing? Error: {e}") def _dump_generations(self, inputs, outputs, scores, reward_extra_infos_dict, dump_path): """Dump rollout/validation samples as JSONL.""" @@ -561,16 +593,18 @@ def _validate(self): input_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids] sample_inputs.extend(input_texts) - if "multi_modal_inputs" in test_batch.non_tensor_batch.keys(): - test_gen_batch = test_batch.pop( - batch_keys=["input_ids", "attention_mask", "position_ids"], - non_tensor_batch_keys=["raw_prompt_ids", "multi_modal_data", "multi_modal_inputs"], - ) - else: - test_gen_batch = test_batch.pop( - batch_keys=["input_ids", "attention_mask", "position_ids"], - non_tensor_batch_keys=["raw_prompt_ids"] + ["raw_prompt"] if self.async_rollout_mode else [], - ) + batch_keys_to_pop = ["input_ids", "attention_mask", "position_ids"] + non_tensor_batch_keys_to_pop = ["raw_prompt_ids"] + if "multi_modal_inputs" in test_batch.non_tensor_batch: + non_tensor_batch_keys_to_pop.extend(["multi_modal_data", "multi_modal_inputs"]) + if "raw_prompt" in test_batch.non_tensor_batch: + non_tensor_batch_keys_to_pop.append("raw_prompt") + if "tools_kwargs" in test_batch.non_tensor_batch: + non_tensor_batch_keys_to_pop.append("tools_kwargs") + test_gen_batch = test_batch.pop( + batch_keys=batch_keys_to_pop, + non_tensor_batch_keys=non_tensor_batch_keys_to_pop, + ) test_gen_batch.meta_info = { "eos_token_id": self.tokenizer.eos_token_id, @@ -868,20 +902,21 @@ def fit(self): for batch_dict in self.train_dataloader: metrics = {} timing_raw = {} - batch: DataProto = DataProto.from_single_dict(batch_dict) # pop those keys for generation - if "multi_modal_inputs" in batch.non_tensor_batch.keys(): - gen_batch = batch.pop( - batch_keys=["input_ids", "attention_mask", "position_ids"], - non_tensor_batch_keys=["raw_prompt_ids", "multi_modal_data", "multi_modal_inputs"], - ) - else: - gen_batch = batch.pop( - batch_keys=["input_ids", "attention_mask", "position_ids"], - non_tensor_batch_keys=["raw_prompt_ids"] + ["raw_prompt"] if self.async_rollout_mode else [], - ) + batch_keys_to_pop = ["input_ids", "attention_mask", "position_ids"] + non_tensor_batch_keys_to_pop = ["raw_prompt_ids"] + if "multi_modal_inputs" in batch.non_tensor_batch: + non_tensor_batch_keys_to_pop.extend(["multi_modal_data", "multi_modal_inputs"]) + if "raw_prompt" in batch.non_tensor_batch: + non_tensor_batch_keys_to_pop.append("raw_prompt") + if "tools_kwargs" in batch.non_tensor_batch: + non_tensor_batch_keys_to_pop.append("tools_kwargs") + gen_batch = 
batch.pop(
+                    batch_keys=batch_keys_to_pop,
+                    non_tensor_batch_keys=non_tensor_batch_keys_to_pop,
+                )

                 is_last_step = self.global_steps >= self.total_training_steps

@@ -980,7 +1015,9 @@ def fit(self):
                         batch.batch["token_level_rewards"] = batch.batch["token_level_scores"]

                     # compute advantages, executed on the driver process
+                    norm_adv_by_std_in_grpo = self.config.algorithm.get("norm_adv_by_std_in_grpo", True)  # GRPO adv normalization factor
+
                     batch = compute_advantage(
                         batch,
                         adv_estimator=self.config.algorithm.adv_estimator,
@@ -988,6 +1025,7 @@ def fit(self):
                         lam=self.config.algorithm.lam,
                         num_repeat=self.config.actor_rollout_ref.rollout.n,
                         norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,
+                        multi_turn=self.config.actor_rollout_ref.rollout.multi_turn.enable,
                     )

                 # update critic
@@ -1001,6 +1039,7 @@ def fit(self):
                     if self.config.trainer.critic_warmup <= self.global_steps:
                         # update actor
                         with _timer("update_actor", timing_raw):
+                            batch.meta_info["multi_turn"] = self.config.actor_rollout_ref.rollout.multi_turn.enable
                             actor_output = self.actor_rollout_wg.update_actor(batch)
                             actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"])
                             metrics.update(actor_output_metrics)
diff --git a/verl/utils/dataset/rl_dataset.py b/verl/utils/dataset/rl_dataset.py
index a06b73ddb7..8bced80de9 100644
--- a/verl/utils/dataset/rl_dataset.py
+++ b/verl/utils/dataset/rl_dataset.py
@@ -1,4 +1,6 @@
 # Copyright 2024 Bytedance Ltd. and/or its affiliates
+# Copyright 2023-2024 SGLang Team
+# Copyright 2025 ModelBest Inc. and/or its affiliates
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,6 +15,7 @@
 # limitations under the License.

 import copy
+import logging
 import os
 import re
 from collections import defaultdict
@@ -28,6 +31,8 @@
 import verl.utils.torch_functional as verl_F
 from verl.utils.model import compute_position_id_with_mask

+logger = logging.getLogger(__name__)
+

 def collate_fn(data_list: list[dict]) -> dict:
     tensors = defaultdict(list)
@@ -75,16 +80,15 @@ def __init__(
         self.image_key = config.get("image_key", "images")
         self.video_key = config.get("video_key", "videos")
         self.max_prompt_length = config.get("max_prompt_length", 1024)
-
         self.return_raw_chat = config.get("return_raw_chat", False)
         self.truncation = config.get("truncation", "error")
         self.filter_overlong_prompts = config.get("filter_overlong_prompts", True)

         self.num_workers = config.get("filter_overlong_prompts_workers", max(1, os.cpu_count() // 4))
         self.num_workers = min(self.num_workers, os.cpu_count())
-
-        # whether to store the dataset in state_dict()
-        # default not store
+        self.chat_template_func = config.get("chat_template_func", None)
+        self.need_tools_kwargs = config.get("need_tools_kwargs", False)
+        self.filter_prompts = config.get("filter_prompts", True)
         self.serialize_dataset = False
         self._download()
         self._read_files_and_tokenize()
@@ -240,8 +244,12 @@ def __getitem__(self, item):

         # add index for each prompt
         index = row_dict.get("extra_info", {}).get("index", 0)
+        tools_kwargs = row_dict.get("extra_info", {}).get("tools_kwargs", {})
+        need_tools_kwargs = row_dict.get("extra_info", {}).get("need_tools_kwargs", self.need_tools_kwargs)
+        if need_tools_kwargs and not tools_kwargs:
+            logger.warning("tools_kwargs is empty for index %s, data source: %s", index, row_dict["data_source"])
         row_dict["index"] = index
-
+        row_dict["tools_kwargs"] = tools_kwargs
         return row_dict

     def __getstate__(self):
diff --git a/verl/utils/debug/performance.py 
b/verl/utils/debug/performance.py
index f9896e3c0a..f461b3f47a 100644
--- a/verl/utils/debug/performance.py
+++ b/verl/utils/debug/performance.py
@@ -89,5 +89,6 @@ def log(self, func, *args, **kwargs):
         mem_allocated, mem_reserved, mem_used, mem_total = _get_current_mem_info()
         message = f"After {name}, memory allocated (GB): {mem_allocated}, memory reserved (GB): {mem_reserved}, device memory used/total (GB): {mem_used}/{mem_total}"
+        self.logging_function(message)

         return output
diff --git a/verl/workers/actor/dp_actor.py b/verl/workers/actor/dp_actor.py
index aff297c51f..3b26d81fe6 100644
--- a/verl/workers/actor/dp_actor.py
+++ b/verl/workers/actor/dp_actor.py
@@ -1,4 +1,6 @@
 # Copyright 2024 Bytedance Ltd. and/or its affiliates
+# Copyright 2023-2024 SGLang Team
+# Copyright 2025 ModelBest Inc. and/or its affiliates
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -164,7 +166,7 @@ def _optimizer_step(self):

         # if grad_norm is not finite, skip the update
         if not torch.isfinite(grad_norm):
-            print(f"WARN: grad_norm is not finite: {grad_norm}")
+            print(f"WARN: rank {torch.distributed.get_rank()} grad_norm is not finite: {grad_norm}")
             self.actor_optimizer.zero_grad()
         else:
             self.actor_optimizer.step()
@@ -240,8 +242,11 @@ def update_policy(self, data: DataProto):
        self.actor_module.train()

        temperature = data.meta_info["temperature"]  # temperature must be in the data.meta_info to avoid silent error
+       multi_turn = data.meta_info.get("multi_turn", False)

        select_keys = ["responses", "input_ids", "attention_mask", "position_ids", "old_log_probs", "advantages"]
+       if multi_turn:
+           select_keys.append("loss_mask")
        if self.config.use_kl_loss:
            select_keys.append("ref_log_prob")
        batch = data.select(batch_keys=select_keys).batch
@@ -284,7 +289,11 @@ def update_policy(self, data: DataProto):
                responses = data["responses"]
                response_length = responses.size(1)
                attention_mask = data["attention_mask"]
-               response_mask = attention_mask[:, -response_length:]
+               if multi_turn:
+                   response_mask = data["loss_mask"][:, -response_length:]
+               else:
+                   response_mask = attention_mask[:, -response_length:]
+
                old_log_prob = data["old_log_probs"]
                advantages = data["advantages"]

diff --git a/verl/workers/fsdp_workers.py b/verl/workers/fsdp_workers.py
index 322ece5ca5..c1b501c0cb 100644
--- a/verl/workers/fsdp_workers.py
+++ b/verl/workers/fsdp_workers.py
@@ -400,6 +400,33 @@ def _build_rollout(self, trust_remote_code=False):
            )
            log_gpu_memory_usage("After building sharding manager", logger=logger)

+        elif rollout_name == "sglang_async":
+            from verl.workers.rollout.sglang_rollout import AsyncSGLangRollout
+            from verl.workers.sharding_manager.fsdp_sglang import FSDPAsyncSGLangShardingManager
+
+            log_gpu_memory_usage(f"Before building {rollout_name} rollout", logger=None)
+            rollout = AsyncSGLangRollout(
+                actor_module=self.config.model.path,
+                config=self.config.rollout,
+                tokenizer=self.tokenizer,
+                model_hf_config=self.actor_model_config,
+            )
+            log_gpu_memory_usage(f"After building {rollout_name} rollout", logger=None)
+
+            if torch.distributed.get_world_size() == 1:
+                self.config.rollout.load_format = "dummy_hf"
+            rollout_sharding_manager = FSDPAsyncSGLangShardingManager(
+                module=self.actor_module_fsdp,
+                inference_engine=rollout._engine,
+                model_config=self.actor_model_config,
+                full_params="hf" in self.config.rollout.load_format,
+                device_mesh=rollout_device_mesh,
+            )
+            log_gpu_memory_usage("After building sharding manager", logger=None)
+
+        
else: + raise NotImplementedError(f"Rollout name: {self.config.rollout.name} is not supported") + return rollout, rollout_sharding_manager @register(dispatch_mode=Dispatch.ONE_TO_ALL) @@ -540,7 +567,18 @@ def generate_sequences(self, prompts: DataProto): log_gpu_memory_usage("After entering rollout sharding manager", logger=logger) prompts = self.rollout_sharding_manager.preprocess_data(prompts) - output = self.rollout.generate_sequences(prompts=prompts) + + if self.config.rollout.name == "sglang_async": + from verl.workers.rollout.sglang_rollout import AsyncSGLangRollout + + if isinstance(self.rollout, AsyncSGLangRollout) and hasattr(self.rollout, "_tool_schemas") and len(self.rollout._tool_schemas) > 0: + output = self.rollout.generate_sequences_with_tools(prompts=prompts) + else: + output = self.rollout.generate_sequences(prompts=prompts) + else: + output = self.rollout.generate_sequences(prompts=prompts) + log_gpu_memory_usage("After rollout generation", logger=logger) + output = self.rollout_sharding_manager.postprocess_data(output) output = output.to("cpu") diff --git a/verl/workers/rollout/schemas.py b/verl/workers/rollout/schemas.py new file mode 100644 index 0000000000..6847bae6d7 --- /dev/null +++ b/verl/workers/rollout/schemas.py @@ -0,0 +1,217 @@ +# Copyright 2023-2024 SGLang Team +# Copyright 2025 ModelBest Inc. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
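The `generate_sequences` override above routes to the tool-calling path only when the rollout is the async SGLang engine and tool schemas were actually registered. Isolated into a standalone helper (hypothetical, for illustration only), the rule is:

```python
# Dispatch rule applied by ActorRolloutRefWorker.generate_sequences above.
def select_rollout_fn(rollout, rollout_name: str):
    from verl.workers.rollout.sglang_rollout import AsyncSGLangRollout

    if (
        rollout_name == "sglang_async"
        and isinstance(rollout, AsyncSGLangRollout)
        and len(getattr(rollout, "_tool_schemas", [])) > 0
    ):
        # tool schemas registered -> multi-turn, tool-calling generation
        return rollout.generate_sequences_with_tools
    # plain single-turn generation otherwise
    return rollout.generate_sequences
```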
+ +from enum import Enum +from typing import Any, Dict, List, Literal, Optional + +import torch +from pydantic import BaseModel +from transformers import PreTrainedTokenizer + +from verl.tools.schemas import OpenAIFunctionToolCall, OpenAIFunctionToolSchema +from verl.utils.model import compute_position_id_with_mask + + +class FinishReasonTypeEnum(str, Enum): + """The enum for finish reason type.""" + + LENGTH = "length" + STOP = "stop" + TOOL_CALL = "tool_calls" + + @classmethod + def from_str(cls, value: str) -> "FinishReasonTypeEnum": + if value == "stop": + return cls.STOP + elif value == "length": + return cls.LENGTH + elif value == "tool_calls": + return cls.TOOL_CALL + else: + raise ValueError(f"Unsupported finish reason type: {value}") + + +class Message(BaseModel): + role: str + content: str + tool_calls: Optional[List[OpenAIFunctionToolCall]] = None + + +class AsyncRolloutRequestStateEnum(str, Enum): + """The enum for async rollout request state.""" + + PENDING = "pending" + RUNNING = "running" + COMPLETED = "completed" + FAILED = "failed" + TOOL_CALLING = "tool_calling" + + +class AsyncRolloutRequest(BaseModel): + """The data model for async rollout.""" + + batch_data_id: int = 0 + rollout_offset: int = 0 + request_id: str + state: AsyncRolloutRequestStateEnum + messages: List[Message] + tools: Optional[List[OpenAIFunctionToolSchema]] = None + tools_kwargs: Dict[str, Any] = {} + input_ids: List[int] + prompt_ids: List[int] + response_ids: List[int] + attention_mask: List[int] + prompt_attention_mask: List[int] + response_attention_mask: List[int] + position_ids: List[int] + prompt_position_ids: List[int] + response_position_ids: List[int] + loss_mask: List[int] + prompt_loss_mask: List[int] + response_loss_mask: List[int] + reward_scores: Dict[str, float] + max_response_len: int = 8192 + max_model_len: int = 32768 + + format_config: dict = { + "chatml": { + "assistant_prefix_msg": "\n<|im_start|>assistant\n", + "assistant_suffix_msg": "<|im_end|>", + "tool_prefix_msg": "\n<|im_start|>tool\n", + "tool_suffix_msg": "<|im_end|>", + } + } + + def get_generation_prompt(self, tokenizer: PreTrainedTokenizer) -> str: + return tokenizer.apply_chat_template( # type: ignore + conversation=[msg.model_dump() for msg in self.messages], + tools=[tool.model_dump() for tool in self.tools] if self.tools else None, + add_generation_prompt=True, + tokenize=False, + ) + + def add_assistant_message( + self, + tokenizer: PreTrainedTokenizer, + content: str, + tool_calls: Optional[List[OpenAIFunctionToolCall]] = None, + format: Literal["chatml"] = "chatml", + already_over_long: bool = False, + ) -> None: + """Currently, we only support chatml format.""" + msg = Message(role="assistant", content=content, tool_calls=tool_calls) + self.messages.append(msg) + if tool_calls is not None: + content_with_tool_calls: str = tokenizer.apply_chat_template( # type: ignore + conversation=[msg.model_dump()], add_generation_prompt=False, tokenize=False + ) + else: + content_with_tool_calls = content + # TODO: support other formats + if format in self.format_config: + prefix_msg = self.format_config[format]["assistant_prefix_msg"] + prefix_token_ids = tokenizer.encode(prefix_msg, add_special_tokens=False) + suffix_msg = self.format_config[format]["assistant_suffix_msg"] + suffix_token_ids = tokenizer.encode(suffix_msg, add_special_tokens=False) + if tool_calls is not None: + content = content_with_tool_calls.split(f"{prefix_msg}")[-1].split(f"{suffix_msg}")[0] + content_token_ids = tokenizer.encode(content, 
add_special_tokens=False) + if self.input_ids[-len(prefix_token_ids) :] == prefix_token_ids: + append_token_ids = content_token_ids + _loss_mask = [1] * len(content_token_ids) + elif self.input_ids[-len(suffix_token_ids) :] == suffix_token_ids: + append_token_ids = prefix_token_ids + content_token_ids + _loss_mask = [0] * len(prefix_token_ids) + [1] * len(content_token_ids) + else: + max_len = max(len(prefix_token_ids), len(suffix_token_ids)) + raise ValueError( + f"""Unsupported end of message format: + {tokenizer.decode(self.input_ids[-max_len:])}, + {tokenizer.decode(self.input_ids)=}, {self.messages=}""" + ) + if not already_over_long: + append_token_ids += suffix_token_ids + _loss_mask += [1] * len(suffix_token_ids) + self.input_ids += append_token_ids + _attention_mask = [1] * len(append_token_ids) + self.attention_mask += _attention_mask + _delta_position_ids = compute_position_id_with_mask(torch.tensor(_attention_mask)).tolist() + last_position_id = self.position_ids[-1] + _position_ids = [pos_id + last_position_id for pos_id in _delta_position_ids] + self.loss_mask += _loss_mask + self.position_ids += _position_ids + else: + raise ValueError(f"Unsupported format: {format}") + assert len(self.input_ids) == len(self.attention_mask) == len(self.position_ids) == len(self.loss_mask), f"""Request {self.request_id} has different length of {len(self.input_ids)=}, + {len(self.attention_mask)=}, {len(self.position_ids)=}, {len(self.loss_mask)=}""" + + def add_tool_response_message(self, tokenizer: PreTrainedTokenizer, content: str, format: Literal["chatml"] = "chatml") -> None: + """Currently, we only support chatml format.""" + msg = Message(role="tool", content=content) + self.messages.append(msg) + # TODO: support other formats + if format in self.format_config: + prefix_msg = self.format_config[format]["tool_prefix_msg"] + prefix_token_ids = tokenizer.encode(prefix_msg, add_special_tokens=False) + suffix_msg = self.format_config[format]["tool_suffix_msg"] + suffix_token_ids = tokenizer.encode(suffix_msg, add_special_tokens=False) + content_token_ids = tokenizer.encode(content, add_special_tokens=False) + if self.input_ids[-len(prefix_token_ids) :] == prefix_token_ids: + append_token_ids = content_token_ids + suffix_token_ids + elif self.input_ids[-len(suffix_token_ids) :] == suffix_token_ids: + append_token_ids = prefix_token_ids + content_token_ids + suffix_token_ids + else: + raise ValueError(f"Unsupported end of message format: {tokenizer.decode(self.input_ids[-len(prefix_token_ids) :])}") + self.input_ids += append_token_ids + _attention_mask = [1] * len(append_token_ids) + self.attention_mask += _attention_mask + _delta_position_ids = compute_position_id_with_mask(torch.tensor(_attention_mask)).tolist() + last_position_id = self.position_ids[-1] + _position_ids = [pos_id + last_position_id for pos_id in _delta_position_ids] + self.loss_mask += [0] * len(append_token_ids) + self.position_ids += _position_ids + else: + raise ValueError(f"Unsupported format: {format}") + assert len(self.input_ids) == len(self.attention_mask) == len(self.position_ids) == len(self.loss_mask), f"""Request {self.request_id} has different length of {len(self.input_ids)=}, + {len(self.attention_mask)=}, {len(self.position_ids)=}, {len(self.loss_mask)=}""" + + def finalize( + self, + tokenizer: PreTrainedTokenizer, + reward_scores: Dict[str, float], + finish_reason_type: FinishReasonTypeEnum = FinishReasonTypeEnum.STOP, + ) -> None: + self.state = AsyncRolloutRequestStateEnum.COMPLETED + self.reward_scores = 
reward_scores + self.response_ids = self.input_ids[len(self.prompt_ids) :] + if finish_reason_type == FinishReasonTypeEnum.STOP: + pass + elif finish_reason_type == FinishReasonTypeEnum.LENGTH: + pass + else: + raise ValueError(f"Unsupported finalize finish reason type: {finish_reason_type}") + self.truncate_output_ids(tokenizer) + assert len(self.input_ids) == len(self.attention_mask) == len(self.position_ids) == len(self.loss_mask), f"""Request {self.request_id} has different length of {len(self.input_ids)=}, + {len(self.attention_mask)=}, {len(self.position_ids)=}, {len(self.loss_mask)=}""" + + def truncate_output_ids(self, tokenizer: PreTrainedTokenizer) -> None: + self.input_ids = self.input_ids[: self.max_model_len] + self.attention_mask = self.attention_mask[: self.max_model_len] + self.position_ids = self.position_ids[: self.max_model_len] + self.loss_mask = self.loss_mask[: self.max_model_len] + self.response_ids = self.input_ids[len(self.prompt_ids) :][: self.max_response_len] + self.response_attention_mask = self.attention_mask[len(self.prompt_attention_mask) :][: self.max_response_len] + self.response_position_ids = self.position_ids[len(self.prompt_position_ids) :][: self.max_response_len] + self.response_loss_mask = self.loss_mask[len(self.prompt_loss_mask) :][: self.max_response_len] diff --git a/verl/workers/rollout/sglang_rollout/__init__.py b/verl/workers/rollout/sglang_rollout/__init__.py index 43a1eebb4c..fd00061e3b 100644 --- a/verl/workers/rollout/sglang_rollout/__init__.py +++ b/verl/workers/rollout/sglang_rollout/__init__.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and +from .async_sglang_rollout import AsyncSGLangRollout from .sglang_rollout import SGLangRollout -__all__ = ["SGLangRollout"] +__all__ = ["AsyncSGLangRollout", "SGLangRollout"] diff --git a/verl/workers/rollout/sglang_rollout/async_sglang_rollout.py b/verl/workers/rollout/sglang_rollout/async_sglang_rollout.py new file mode 100644 index 0000000000..eb9542f2d2 --- /dev/null +++ b/verl/workers/rollout/sglang_rollout/async_sglang_rollout.py @@ -0,0 +1,730 @@ +# Copyright 2023-2024 SGLang Team +# Copyright 2025 ModelBest Inc. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from __future__ import annotations
+
+import asyncio
+import logging
+import os
+from contextlib import contextmanager
+from copy import deepcopy
+from json import JSONDecodeError
+from typing import TYPE_CHECKING
+from uuid import uuid4
+
+import numpy as np
+import torch
+import torch.distributed as dist
+from omegaconf import DictConfig
+from sglang.srt.entrypoints.engine import Engine
+from sglang.srt.function_call_parser import FunctionCallParser
+from sglang.srt.openai_api.protocol import Tool
+from sglang.srt.patch_torch import monkey_patch_torch_reductions
+from sglang.srt.sampling.sampling_params import SamplingParams
+from sglang.srt.utils import broadcast_pyobj, get_ip, get_open_port
+from tensordict import TensorDict
+from torch.distributed.device_mesh import init_device_mesh
+from torch.nn.utils.rnn import pad_sequence
+from transformers import PreTrainedTokenizer
+
+from verl import DataProto
+from verl.third_party.sglang import parallel_state as sglang_ps
+from verl.tools.base_tool import BaseTool
+from verl.tools.schemas import OpenAIFunctionParsedSchema, OpenAIFunctionToolCall
+from verl.utils.debug import GPUMemoryLogger
+from verl.utils.model import compute_position_id_with_mask
+from verl.utils.net_utils import is_ipv6
+from verl.utils.torch_functional import get_response_mask, pad_sequence_to_length
+from verl.workers.rollout.base import BaseRollout
+from verl.workers.rollout.schemas import (
+    AsyncRolloutRequest,
+    AsyncRolloutRequestStateEnum,
+    FinishReasonTypeEnum,
+    Message,
+)
+from verl.workers.rollout.sglang_rollout.sglang_rollout import _post_process_outputs, _pre_process_inputs
+
+if TYPE_CHECKING:
+    from torch import nn
+
+logger = logging.getLogger(__file__)
+logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
+
+
+def get_tool_call_parser_type(tokenizer: PreTrainedTokenizer) -> str:
+    for parser_type, parser_cls in FunctionCallParser.ToolCallParserEnum.items():
+        parser = parser_cls()
+        if parser.bot_token in tokenizer.get_vocab() and (parser.eot_token == "" or parser.eot_token in tokenizer.get_vocab()):
+            return parser_type
+    else:
+        raise ValueError(f"No tool call parser found for tokenizer {tokenizer}")
+
+
+class AsyncSGLangRollout(BaseRollout):
+    def __init__(
+        self,
+        actor_module: nn.Module | str,
+        config: DictConfig,
+        tokenizer,
+        model_hf_config,
+        port=None,
+        **kwargs,
+    ):
+        """An SGLang rollout. It requires that the module be supported by SGLang.
+
+        Args:
+            actor_module: module here follows huggingface APIs
+            config: DictConfig
+            tokenizer: the task/model tokenizer
+            model_hf_config: the huggingface config used to initialize the generating model in SGLang
+            **kwargs: train_tp, for Megatron Backend to initialize hybrid engine (zero redundancy) process group
+        """
+        super().__init__()
+        self.config = config
+
+        tool_list = None
+        if config.multi_turn.tool_config_path is not None:
+            from omegaconf import OmegaConf
+
+            def initialize_tools(tools_config) -> list:
+                import importlib.util
+                import sys
+
+                from verl.tools.schemas import OpenAIFunctionToolSchema
+
+                tool_list = []
+
+                for tool_config in tools_config.tools:
+                    cls_name = tool_config.class_name
+                    module_name, class_name = cls_name.rsplit(".", 1)
+
+                    if module_name not in sys.modules:
+                        spec = importlib.util.find_spec(module_name)
+                        module = importlib.util.module_from_spec(spec)
+                        sys.modules[module_name] = module
+                        spec.loader.exec_module(module)
+                    else:
+                        module = sys.modules[module_name]
+
+                    tool_cls = getattr(module, class_name)
+
+                    tool_schema_dict = OmegaConf.to_container(tool_config.tool_schema, resolve=True)
+                    tool_schema = OpenAIFunctionToolSchema.parse_obj(tool_schema_dict)
+
+                    tool = tool_cls(config=OmegaConf.to_container(tool_config.config, resolve=True), tool_schema=tool_schema)
+                    tool_list.append(tool)
+
+                return tool_list
+
+            tools_config_file = config.multi_turn.tool_config_path
+            tools_config = OmegaConf.load(tools_config_file)
+            tool_list = initialize_tools(tools_config)
+
+        if tool_list is not None:
+            self._tool_schemas = [tool.get_openai_tool_schema().model_dump() for tool in tool_list]
+            self._tool_map = {tool.name: tool for tool in tool_list}
+            self._tool_call_parser_type = get_tool_call_parser_type(tokenizer)
+            self._sgl_tools = [Tool.model_validate(tool_schema) for tool_schema in self._tool_schemas]
+            self._function_call_parser = FunctionCallParser(
+                self._sgl_tools,
+                self._tool_call_parser_type,
+            )
+        else:
+            self._tool_schemas = []
+            self._tool_map = {}
+            self._tool_call_parser_type = None
+            self._sgl_tools = []
+            self._function_call_parser = None
+
+        assert not (not config.enforce_eager and config.free_cache_engine), "disable CUDA graph (enforce_eager = False) if free cache engine"
+
+        tensor_parallel_size = self.config.get("tensor_model_parallel_size", 1)
+        assert tensor_parallel_size <= dist.get_world_size(), "tensor parallel size should be less than or equal to the world size"
+
+        if kwargs.get("train_tp", None) is not None:
+            # deployed with megatron
+            os.environ["CUDA_TIMER_STREAM_KAFKA_ENABLE"] = "0"
+            os.environ["MEGATRON_IMPORT_TIMERS"] = "0"
+            train_tp = kwargs.get("train_tp", None)
+            num_tp_per_train_tp = train_tp // tensor_parallel_size
+            sglang_ps.initialize_parallel_state(
+                tensor_model_parallel_size=tensor_parallel_size,
+                num_tp_per_train_tp=num_tp_per_train_tp,
+            )
+
+        if not self.config.get("max_model_len", None):
+            self.config.max_model_len = self.config.prompt_length + self.config.response_length
+        assert self.config.max_model_len >= self.config.prompt_length + self.config.response_length, f"""max_model_len should be greater than total sequence length (prompt_length + response_length):
+            {self.config.max_model_len} >= {self.config.prompt_length} + {self.config.response_length}"""
+        assert model_hf_config.max_position_embeddings >= self.config.max_model_len, "model context length should be greater than total sequence length"
+        # currently max_turns stands for the max number of tool calls
+        if self.config.multi_turn.max_turns is None:
self.config.multi_turn.max_turns = self.config.max_model_len // 3 + + tp_size = tensor_parallel_size + world_size = int(os.getenv("WORLD_SIZE", "-1")) + + # init device mesh + device_mesh_kwargs = dict( + mesh_shape=(world_size // tp_size, tp_size, 1), + mesh_dim_names=["dp", "tp", "pp"], + ) + + device_mesh_cpu = init_device_mesh("cpu", **device_mesh_kwargs) + # device_mesh_device = init_device_mesh("cuda", **device_mesh_kwargs) + + # get tp_rank of this process in this tp group + visible_devices = [None] * device_mesh_cpu.size(1) + dist.all_gather_object(visible_devices, os.environ["CUDA_VISIBLE_DEVICES"], device_mesh_cpu.get_group("tp")) + os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(visible_devices) + + # initialize the inference engine + monkey_patch_torch_reductions() + nnodes = -(-tp_size // len(visible_devices)) + if nnodes > 1: + ip = get_ip() + port = get_open_port() if port is None else port + [ip, port] = broadcast_pyobj( + [ip, port], + rank=self._tp_rank, + dist_group=device_mesh_cpu.get_group("tp"), + src=device_mesh_cpu["tp"].mesh[0].item(), + force_cpu_device=False, + ) + dist_init_addr = f"[{ip}]:{port}" if is_ipv6(ip) else f"{ip}:{port}" + else: + dist_init_addr = None + + load_format = "dummy" if config.load_format.startswith("dummy") else config.load_format + self._device_mesh_cpu = device_mesh_cpu + self._tp_rank = device_mesh_cpu["tp"].get_local_rank() + self._tp_size = device_mesh_cpu["tp"].size() + tp_size_per_node = self._tp_size // nnodes + node_rank = self._tp_rank // tp_size_per_node + first_rank_in_node = self._tp_rank % tp_size_per_node == 0 + + if first_rank_in_node: + os.environ["SGLANG_BLOCK_NONZERO_RANK_CHILDREN"] = "0" + self._engine = Engine( + model_path=actor_module, + dtype=config.dtype, + mem_fraction_static=config.gpu_memory_utilization, + enable_memory_saver=True, + base_gpu_id=0, + gpu_id_step=1, + tp_size=self._tp_size, + node_rank=node_rank, + nnodes=nnodes, + load_format=load_format, + dist_init_addr=dist_init_addr, + ) + else: + self._engine = None + + # offload + if self._tp_rank == 0: + self._engine.release_memory_occupation() + + kwargs = dict( + n=1, + max_new_tokens=config.response_length, + presence_penalty=0.0, + frequency_penalty=0.0, + repetition_penalty=1.0, + ) + # supporting adding any sampling params from the config file + for k in config.keys(): + if hasattr(SamplingParams(), str(k)): + kwargs[k] = config.get(k) + print(f"kwargs: {kwargs}") + self.sampling_params = kwargs + + self.tokenizer = tokenizer + self.pad_token_id = tokenizer.pad_token_id + + @contextmanager + def update_sampling_params(self, **kwargs): + # update sampling params + old_sampling_params_args = {} + if kwargs: + for key, value in kwargs.items(): + if key in self.sampling_params: + old_value = self.sampling_params[key] + old_sampling_params_args[key] = old_value + self.sampling_params[key] = value + yield + # roll back to previous sampling params + # if len(old_sampling_params_args): + for key, value in old_sampling_params_args.items(): + self.sampling_params[key] = value + + @GPUMemoryLogger(role="sglang rollout", logger=logger) + @torch.no_grad() + def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto: + # if self.config.free_cache_engine: + + idx = prompts.batch["input_ids"] # (bs, prompt_length) + # left-padded attention_mask + attention_mask = prompts.batch["attention_mask"] + position_ids = prompts.batch["position_ids"] + + # used to construct attention_mask + eos_token_id = prompts.meta_info["eos_token_id"] + + batch_size = 
idx.size(0) + + # Extract non-tensor data + non_tensor_batch = prompts.non_tensor_batch + if "raw_prompt_ids" not in non_tensor_batch: + non_tensor_batch["raw_prompt_ids"] = np.array([_pre_process_inputs(self.pad_token_id, idx[i]) for i in range(batch_size)], dtype=object) + + if "multi_modal_data" in non_tensor_batch: + sglang_inputs = [] + for raw_prompt_ids, multi_modal_data in zip(non_tensor_batch.pop("raw_prompt_ids"), non_tensor_batch.pop("multi_modal_data")): + sglang_inputs.append( + { + "prompt_token_ids": raw_prompt_ids, + "multi_modal_data": multi_modal_data, + "image_data": multi_modal_data.get("image", None) if isinstance(multi_modal_data, dict) else None, + } + ) + else: + sglang_inputs = [{"prompt_token_ids": raw_prompt_ids} for raw_prompt_ids in non_tensor_batch.pop("raw_prompt_ids")] + + # Ensure token IDs are lists + for input_data in sglang_inputs: + if isinstance(input_data["prompt_token_ids"], np.ndarray): + input_data["prompt_token_ids"] = input_data["prompt_token_ids"].tolist() + elif not isinstance(input_data["prompt_token_ids"], list): + raise TypeError(f"prompt_token_ids must be a list or numpy array, got {type(input_data['prompt_token_ids'])}") + + # Extract token IDs and image data for SGLang Engine + idx_list = [input_data["prompt_token_ids"] for input_data in sglang_inputs] + image_list = [input_data.get("image_data", None) for input_data in sglang_inputs] + + do_sample = prompts.meta_info.get("do_sample", True) + is_validate = prompts.meta_info.get("validate", False) + if not do_sample: + kwargs = dict( + n=1, + presence_penalty=0.0, + frequency_penalty=0.0, + repetition_penalty=1.0, + temperature=0, + top_p=1, + top_k=-1, + ignore_eos=False, + min_new_tokens=0, + max_new_tokens=self.config.response_length, + skip_special_tokens=True, + spaces_between_special_tokens=True, + ) + elif is_validate: + kwargs = dict( + top_k=self.config.val_kwargs.top_k, + top_p=self.config.val_kwargs.top_p, + temperature=self.config.val_kwargs.temperature, + n=1, # if validate, already repeat in ray_trainer + ) + + # users can customize different sampling_params at different run + with self.update_sampling_params(**kwargs): + print(f"{self.sampling_params=}") + if self._tp_rank == 0: + loop = asyncio.get_event_loop() + output = loop.run_until_complete( + self._engine.async_generate( + prompt=None, # because we have already convert it to prompt token id + sampling_params=self.sampling_params, + return_logprob=True, + input_ids=idx_list, + image_data=image_list, + ) + ) + else: + output = None + # Most naive implementation, can extract tensor and send via gloo if too slow + [output] = broadcast_pyobj( + data=[output], + rank=self._tp_rank, + dist_group=self._device_mesh_cpu["tp"].get_group(), + src=self._device_mesh_cpu["tp"].mesh[0].item(), + force_cpu_device=False, + ) + out = _post_process_outputs(self.tokenizer, output) + + response = out[0].to(idx.device) + # log_probs = out[1].to(idx.device) + + if response.shape[1] < self.config.response_length: + response = pad_sequence_to_length(response, self.config.response_length, self.pad_token_id) + # log_probs = pad_sequence_to_length(log_probs, self.config.response_length, self.pad_token_id) + + # utilize current sampling params + if self.sampling_params.get("n", 1) > 1 and do_sample: + idx = idx.repeat_interleave(self.sampling_params["n"], dim=0) + attention_mask = attention_mask.repeat_interleave(self.sampling_params["n"], dim=0) + position_ids = position_ids.repeat_interleave(self.sampling_params["n"], dim=0) + batch_size = 
batch_size * self.sampling_params["n"] + if "multi_modal_inputs" in non_tensor_batch.keys(): + non_tensor_batch["multi_modal_inputs"] = np.repeat(non_tensor_batch["multi_modal_inputs"], self.sampling_params["n"], axis=0) + seq = torch.cat([idx, response], dim=-1) + + response_length = response.size(1) + delta_position_id = torch.arange(1, response_length + 1, device=position_ids.device) + delta_position_id = delta_position_id.unsqueeze(0).repeat(batch_size, 1) + + # TODO(sgm): fix position_ids on right_pad + # prompt: left pad + response: right pad + # attention_mask: [0,0,0,0,1,1,1,1, | 1,1,1,0,0,0,0,0] + # position_ids: [0,0,0,0,0,1,2,3, | 4,5,6,7,8,9,10,11] + response_position_ids = position_ids[:, -1:] + delta_position_id + position_ids = torch.cat([position_ids, response_position_ids], dim=-1) + response_attention_mask = get_response_mask(response_id=response, eos_token=eos_token_id, dtype=attention_mask.dtype) + attention_mask = torch.cat((attention_mask, response_attention_mask), dim=-1) + + # all the tp ranks should contain the same data here. data in all ranks are valid + batch = TensorDict( + { + "prompts": idx, + "responses": response, + "input_ids": seq, # here input_ids become the whole sentences + # 'old_log_probs': log_probs, # we will recompute old log prob with actor + "attention_mask": attention_mask, + "position_ids": position_ids, + }, + batch_size=batch_size, + ) + + # free cache engine + if self.config.free_cache_engine and self._engine is not None: + self._engine.tokenizer_manager.flush_cache() + + return DataProto(batch=batch) + + async def _async_rollout_a_request(self, req: AsyncRolloutRequest, do_sample: bool = True, is_validate: bool = False, **kwargs) -> AsyncRolloutRequest: + assert self._tp_rank == 0, "only the master process can call this function" + _req = deepcopy(req) + finish_reason_type = None + output = None + + current_turns = 0 + while current_turns < self.config.multi_turn.max_turns: + if _req.state == AsyncRolloutRequestStateEnum.PENDING: + if _req.tools is not None: + tool_creation_coroutines = [] + for tool_schema in _req.tools: + tool = self._tool_map[tool_schema.function.name] + create_kwargs = _req.tools_kwargs[tool.name].get("create_kwargs", {}) + tool_creation_coroutines.append(tool.create(_req.request_id, **create_kwargs)) + await asyncio.gather(*tool_creation_coroutines) + _req.state = AsyncRolloutRequestStateEnum.RUNNING + elif _req.state == AsyncRolloutRequestStateEnum.TOOL_CALLING: + if _req.messages[-1].tool_calls is not None: + parsed_tool_calls = _req.messages[-1].tool_calls + tool_call_results = await asyncio.gather( + *[ + self._tool_map[tool_call.function.name].execute( + _req.request_id, + tool_call.function.arguments, + **_req.tools_kwargs[tool_call.function.name].get("execute_kwargs", {}), + ) + for tool_call in parsed_tool_calls + ] + ) + for tool_call, (resp, reward, metrics) in zip(parsed_tool_calls, tool_call_results): + _req.add_tool_response_message(self.tokenizer, resp, format=self.config.multi_turn.format) + if len(_req.input_ids) >= self.config.max_model_len: + break + if len(_req.input_ids) >= self.config.max_model_len: + finish_reason_type = FinishReasonTypeEnum.STOP + break + _req.state = AsyncRolloutRequestStateEnum.RUNNING + else: + raise ValueError(f"Unexpected tool calling last message state: {_req.messages[-1]}") + elif _req.state == AsyncRolloutRequestStateEnum.RUNNING: + generation_prompt = _req.get_generation_prompt(self.tokenizer) + if not do_sample: + kwargs = dict( + n=1, + presence_penalty=0.0, + 
frequency_penalty=0.0, + repetition_penalty=1.0, + temperature=0, + top_p=1, + top_k=-1, + ignore_eos=False, + min_new_tokens=0, + max_new_tokens=self.config.response_length, + skip_special_tokens=True, + spaces_between_special_tokens=True, + ) + elif is_validate: + # TODO: try ** + kwargs = { + "top_k": self.config.val_kwargs.top_k, + "top_p": self.config.val_kwargs.top_p, + "temperature": self.config.val_kwargs.temperature, + "n": 1, # if validate, already repeat in ray_trainer + } + if "n" not in kwargs or kwargs["n"] > 1: # group size is supported in preprocess + kwargs["n"] = 1 + # users can customize different sampling_params at different run + with self.update_sampling_params(**kwargs): + output = await self._engine.async_generate( + prompt=generation_prompt, + sampling_params=self.sampling_params, + return_logprob=False, + ) + + content = output["text"] + finish_reason_type = FinishReasonTypeEnum.from_str(output["meta_info"]["finish_reason"]["type"]) + current_turns += 1 + if finish_reason_type == FinishReasonTypeEnum.LENGTH: + _req.add_assistant_message(self.tokenizer, content, already_over_long=True, format=self.config.multi_turn.format) + break + else: + if self._function_call_parser and self._function_call_parser.has_tool_call(content): + finish_reason_type = FinishReasonTypeEnum.TOOL_CALL + _req.state = AsyncRolloutRequestStateEnum.TOOL_CALLING + try: + normed_content, tool_calls = self._function_call_parser.parse_non_stream(content) + except JSONDecodeError: + normed_content = content + tool_calls = [] + except AttributeError: + normed_content = content + tool_calls = [] + parsed_tool_calls = [ + OpenAIFunctionToolCall( + id=str(tool_call.tool_index), + function=OpenAIFunctionParsedSchema(name=tool_call.name, arguments=tool_call.parameters), + ) + for tool_call in tool_calls + ] + if len(parsed_tool_calls) > 0: + _req.add_assistant_message( + self.tokenizer, + normed_content, + tool_calls=parsed_tool_calls, + format=self.config.multi_turn.format, + ) + else: + _req.add_assistant_message(self.tokenizer, content, format=self.config.multi_turn.format) + finish_reason_type = FinishReasonTypeEnum.STOP + _req.state = AsyncRolloutRequestStateEnum.COMPLETED + break + else: + _req.add_assistant_message(self.tokenizer, content, format=self.config.multi_turn.format) + break + + if current_turns >= self.config.multi_turn.max_turns: + finish_reason_type = FinishReasonTypeEnum.STOP + + # Calculate the reward for each tool + async def calc_reward_and_release_fn(name: str, tool: BaseTool): + reward = await tool.calc_reward(_req.request_id, **_req.tools_kwargs[name].get("calc_reward_kwargs", {})) + await tool.release(_req.request_id, **_req.tools_kwargs[name].get("release_kwargs", {})) + return name, reward + + tool_reward_tasks = [] + for name in _req.tools_kwargs.keys(): + tool = self._tool_map[name] + tool_reward_tasks.append(calc_reward_and_release_fn(name, tool)) + tool_reward_scores = await asyncio.gather(*tool_reward_tasks) + tool_reward_scores = dict(tool_reward_scores) + _req.finalize(self.tokenizer, tool_reward_scores, finish_reason_type) + + return _req + + @torch.no_grad() + def generate_sequences_with_tools(self, prompts: DataProto, **kwargs) -> DataProto: + # Async rollout with tools support + do_sample = prompts.meta_info.get("do_sample", True) + is_validate = prompts.meta_info.get("validate", False) + tgt_device = prompts.batch["input_ids"].device + if self._tp_rank == 0: + req_list = self._preprocess_prompt_to_async_rollout_requests( + prompts, + n=1 if is_validate else 
self.config.n, + ) + loop = asyncio.get_event_loop() + output_req_list = loop.run_until_complete( + asyncio.gather( + *[self._async_rollout_a_request(req, do_sample, is_validate, **kwargs) for req in req_list], + ) + ) + sorted_output_req_list = sorted(output_req_list, key=lambda x: (x.batch_data_id, x.rollout_offset)) + else: + sorted_output_req_list = None + + [sorted_output_req_list] = broadcast_pyobj( + data=[sorted_output_req_list], + rank=self._tp_rank, + dist_group=self._device_mesh_cpu["tp"].get_group(), + src=self._device_mesh_cpu["tp"].mesh[0].item(), + force_cpu_device=False, + ) + # Construct the batch data + prompt_ids, response_ids = [], [] + prompt_attention_mask, response_attention_mask = [], [] + prompt_position_ids, response_position_ids = [], [] + prompt_loss_mask, response_loss_mask = [], [] + messages = [] + reward_scores = [] + for req in sorted_output_req_list: + assert req.state == AsyncRolloutRequestStateEnum.COMPLETED, f"Request {req.request_id} is not completed" + assert len(req.input_ids) == len(req.attention_mask) == len(req.position_ids) == len(req.loss_mask), f"""Request {req.request_id} has different length of + {len(req.input_ids)=}, {len(req.attention_mask)=}, {len(req.position_ids)=}, {len(req.loss_mask)=}""" + error_message_lines = [ + f"""Request {req.request_id} has input_ids length {len(req.input_ids)} + greater than max_model_len {self.config.max_model_len}""", + f"Decoded input_ids: {self.tokenizer.decode(req.input_ids)}", + f"Decoded prompt_ids: {self.tokenizer.decode(req.prompt_ids)}", + f"Decoded response_ids: {self.tokenizer.decode(req.response_ids)}", + f"Messages: {req.messages}", + f"Max model length: {req.max_model_len}", + ] + error_message = "\n".join(error_message_lines) + assert len(req.input_ids) <= self.config.max_model_len, error_message + + prompt_ids.append(torch.tensor(req.prompt_ids, dtype=torch.int, device=tgt_device)) + response_ids.append(torch.tensor(req.response_ids, dtype=torch.int, device=tgt_device)) + if len(req.response_ids) > self.config.response_length: + print( + f"""{req.request_id=} has response_ids length {len(req.response_ids)} + greater than max_response_len {self.config.response_length},\n{req=}""" + ) + prompt_attention_mask.append(torch.tensor(req.prompt_attention_mask, dtype=torch.int, device=tgt_device)) + response_attention_mask.append(torch.tensor(req.response_attention_mask, dtype=torch.int, device=tgt_device)) + prompt_position_ids.append(torch.tensor(req.prompt_position_ids, dtype=torch.int, device=tgt_device)) + response_position_ids.append(torch.tensor(req.response_position_ids, dtype=torch.int, device=tgt_device)) + prompt_loss_mask.append(torch.tensor(req.prompt_loss_mask, dtype=torch.int, device=tgt_device)) + response_loss_mask.append(torch.tensor(req.response_loss_mask, dtype=torch.int, device=tgt_device)) + messages.append({"messages": req.messages}) + reward_scores.append(req.reward_scores) + + prompt_ids = pad_sequence(prompt_ids, batch_first=True, padding_value=self.pad_token_id, padding_side="left") + if prompt_ids.shape[1] < self.config.prompt_length: + prompt_ids = pad_sequence_to_length(prompt_ids, self.config.prompt_length, self.pad_token_id, left_pad=True) + response_ids = pad_sequence(response_ids, batch_first=True, padding_value=self.pad_token_id) + if response_ids.shape[1] < self.config.response_length: + response_ids = pad_sequence_to_length(response_ids, self.config.response_length, self.pad_token_id) + prompt_attention_mask = pad_sequence(prompt_attention_mask, batch_first=True, 
padding_value=0, padding_side="left")
+        if prompt_attention_mask.shape[1] < self.config.prompt_length:
+            prompt_attention_mask = pad_sequence_to_length(prompt_attention_mask, self.config.prompt_length, 0, left_pad=True)
+        response_attention_mask = pad_sequence(response_attention_mask, batch_first=True, padding_value=0)
+        if response_attention_mask.shape[1] < self.config.response_length:
+            response_attention_mask = pad_sequence_to_length(response_attention_mask, self.config.response_length, 0)
+        prompt_position_ids = pad_sequence(prompt_position_ids, batch_first=True, padding_value=0, padding_side="left")
+        if prompt_position_ids.shape[1] < self.config.prompt_length:
+            prompt_position_ids = pad_sequence_to_length(prompt_position_ids, self.config.prompt_length, 0, left_pad=True)
+        response_position_ids = pad_sequence(response_position_ids, batch_first=True, padding_value=0)
+        if response_position_ids.shape[1] < self.config.response_length:
+            response_position_ids = pad_sequence_to_length(response_position_ids, self.config.response_length, 0)
+        prompt_loss_mask = pad_sequence(prompt_loss_mask, batch_first=True, padding_value=0, padding_side="left")
+        if prompt_loss_mask.shape[1] < self.config.prompt_length:
+            prompt_loss_mask = pad_sequence_to_length(prompt_loss_mask, self.config.prompt_length, 0, left_pad=True)
+        response_loss_mask = pad_sequence(response_loss_mask, batch_first=True, padding_value=0)
+        if response_loss_mask.shape[1] < self.config.response_length:
+            response_loss_mask = pad_sequence_to_length(response_loss_mask, self.config.response_length, 0)
+
+        input_ids = torch.cat((prompt_ids, response_ids), dim=-1)
+        attention_mask = torch.cat((prompt_attention_mask, response_attention_mask), dim=-1)
+        position_ids = torch.cat((prompt_position_ids, response_position_ids), dim=-1)
+        loss_mask = torch.cat((prompt_loss_mask, response_loss_mask), dim=-1)
+
+        # Construct the batch data
+        batch = TensorDict(
+            {
+                "prompts": prompt_ids,
+                "responses": response_ids,
+                "input_ids": input_ids,  # here input_ids become the whole sentences
+                "attention_mask": attention_mask,
+                "position_ids": position_ids,
+                "loss_mask": loss_mask,
+            },
+            batch_size=len(sorted_output_req_list),
+        )
+
+        return DataProto(batch=batch, non_tensor_batch={"messages": np.array(messages), "reward_scores": np.array(reward_scores)})
+
+    def _preprocess_prompt_to_async_rollout_requests(self, prompts: DataProto, n: int) -> list[AsyncRolloutRequest]:
+        assert "raw_prompt" in prompts.non_tensor_batch, "need data.return_raw_chat=True, because there is no official way to parse messages"
+        req_list = []
+        for data_idx, raw_prompt in enumerate(prompts.non_tensor_batch["raw_prompt"]):
+            for rollout_offset in range(n):
+                if self._tool_schemas:
+                    _tools_kwargs = prompts.non_tensor_batch["tools_kwargs"][data_idx]
+                    _tool_schemas = []
+                    for k in _tools_kwargs.keys():
+                        _tool_schemas.append(self._tool_map[k].get_openai_tool_schema())
+                    prompt_with_chat_template = self.tokenizer.apply_chat_template(
+                        conversation=raw_prompt,
+                        tools=[tool.model_dump() for tool in _tool_schemas],
+                        add_generation_prompt=True,
+                        tokenize=False,
+                        return_tensors="pt",
+                    )
+                    input_data = self.tokenizer(prompt_with_chat_template, return_tensors="pt", add_special_tokens=False)
+                    _input_ids = input_data["input_ids"][0].tolist()
+                    _attention_mask = input_data["attention_mask"][0].tolist()
+                    _position_ids = compute_position_id_with_mask(input_data["attention_mask"][0]).tolist()
+                    if len(_input_ids) > self.config.prompt_length:
+                        logger.warning(
+                            "Prompt %s has length %s greater than max_prompt_len %s",
+                            data_idx,
+                            len(_input_ids),
+                            self.config.prompt_length,
+                        )
+                        _input_ids = _input_ids[: self.config.prompt_length]
+                        _attention_mask = _attention_mask[: self.config.prompt_length]
+                        _position_ids = _position_ids[: self.config.prompt_length]
+                else:
+                    _input_ids = _pre_process_inputs(self.pad_token_id, prompts.batch["input_ids"][data_idx])
+                    _attention_mask = _pre_process_inputs(0, prompts.batch["attention_mask"][data_idx])
+                    _position_ids = compute_position_id_with_mask(torch.tensor(_attention_mask)).tolist()
+                    _tool_schemas = []
+                    _tools_kwargs = {}
+
+                req = AsyncRolloutRequest(
+                    batch_data_id=data_idx,
+                    rollout_offset=rollout_offset,
+                    request_id=str(uuid4()),
+                    state=AsyncRolloutRequestStateEnum.PENDING,
+                    messages=[Message.model_validate(msg) for msg in raw_prompt],
+                    tools=_tool_schemas,
+                    tools_kwargs=_tools_kwargs,
+                    input_ids=_input_ids,
+                    prompt_ids=_input_ids,
+                    response_ids=[],
+                    attention_mask=_attention_mask,
+                    prompt_attention_mask=_attention_mask,
+                    response_attention_mask=[],
+                    position_ids=_position_ids,
+                    prompt_position_ids=_position_ids,
+                    response_position_ids=[],
+                    loss_mask=[0] * len(_input_ids),
+                    prompt_loss_mask=[0] * len(_input_ids),
+                    response_loss_mask=[],
+                    reward_scores={},
+                    max_response_len=self.config.response_length,
+                    max_model_len=min(self.config.max_model_len, self.config.prompt_length + self.config.response_length),
+                )
+
+                error_message = f"Request {req.request_id} has mismatched lengths: input_ids={len(req.input_ids)}, attention_mask={len(req.attention_mask)}, position_ids={len(req.position_ids)}, loss_mask={len(req.loss_mask)}"
+                assert len(req.input_ids) == len(req.attention_mask) == len(req.position_ids) == len(req.loss_mask), error_message
+
+                req_list.append(req)
+
+        return req_list
diff --git a/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py b/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py
index d86ee0536d..4cbba673df 100644
--- a/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py
+++ b/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py
@@ -276,6 +276,9 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:
                 batch_size = batch_size * self.sampling_params.n
                 if "multi_modal_inputs" in non_tensor_batch.keys():
                     non_tensor_batch["multi_modal_inputs"] = _repeat_interleave(non_tensor_batch["multi_modal_inputs"], self.sampling_params.n)
+                # NOTE(linjunrong): for multi-turn https://github.com/volcengine/verl/pull/1037
+                if "tools_kwargs" in non_tensor_batch.keys():
+                    non_tensor_batch["tools_kwargs"] = _repeat_interleave(non_tensor_batch["tools_kwargs"], self.sampling_params.n)

         seq = torch.cat([idx, response], dim=-1)

diff --git a/verl/workers/sharding_manager/fsdp_sglang.py b/verl/workers/sharding_manager/fsdp_sglang.py
index 1d3364605d..bc2ff6775d 100644
--- a/verl/workers/sharding_manager/fsdp_sglang.py
+++ b/verl/workers/sharding_manager/fsdp_sglang.py
@@ -1,4 +1,6 @@
 # Copyright 2023-2024 SGLang Team
+# Copyright 2025 ModelBest Inc. and/or its affiliates
+#
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
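The request preprocessing above initializes `loss_mask` to zeros over the prompt, and the message-append methods in the rollout schema extend it with 1s for assistant tokens and 0s for tool-response tokens. A toy view of the convention the trainer consumes (values fabricated):

```python
import torch

#                          |--- prompt ---|- asst -|-tool-|-asst-|
loss_mask = torch.tensor([[0, 0, 0, 1, 1, 1, 0, 0, 1, 1]])
response_length = 7
response_mask = loss_mask[:, -response_length:]  # what dp_actor and GRPO actually use
print(response_mask)  # tensor([[1, 1, 1, 0, 0, 1, 1]])
```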
diff --git a/verl/workers/sharding_manager/fsdp_sglang.py b/verl/workers/sharding_manager/fsdp_sglang.py
index 1d3364605d..bc2ff6775d 100644
--- a/verl/workers/sharding_manager/fsdp_sglang.py
+++ b/verl/workers/sharding_manager/fsdp_sglang.py
@@ -1,4 +1,6 @@
 # Copyright 2023-2024 SGLang Team
+# Copyright 2025 ModelBest Inc. and/or its affiliates
+#
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
@@ -27,12 +29,18 @@

 import logging
 import os
+from typing import Union

 import torch
+import torch.distributed as dist
+from sglang.srt.entrypoints.engine import Engine
 from sglang.srt.entrypoints.verl_engine import VerlEngine
+from sglang.srt.model_executor.model_runner import LocalSerializedTensor
+from sglang.srt.utils import MultiprocessingSerializer
 from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.fsdp.api import FullStateDictConfig, ShardedStateDictConfig, StateDictType
 from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
+from torch.distributed.tensor import DTensor

 from verl import DataProto
 from verl.protocol import all_gather_data_proto
@@ -48,12 +56,18 @@
 logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))


+def _preprocess_tensor_for_update_weights(tensor: torch.Tensor):
+    if isinstance(tensor, DTensor):
+        return tensor.full_tensor()
+    return tensor
+
+
 class FSDPSGLangShardingManager(BaseShardingManager):
     @check_cuda_is_available()
     def __init__(
         self,
         module: FSDP,
-        inference_engine: VerlEngine,
+        inference_engine: Union[VerlEngine, Engine],
         model_config,
         full_params: bool = False,
         device_mesh: DeviceMesh = None,
@@ -95,10 +109,7 @@ def __enter__(self):
         params = self.module.state_dict()
         log_gpu_memory_usage("After state_dict() in sharding manager memory", logger=logger)
         # Copy, not share memory
-        # load_format = None if self.full_params else "dtensor"
-        self.inference_engine.resume_memory_occupation()
-
-        self.inference_engine.update_weights_from_tensor([(k, v) for k, v in params.items()], load_format=None)
+        self.update_weights(params)
         log_gpu_memory_usage("After sync model weights in sharding manager", logger=logger)
         del params
@@ -114,14 +125,9 @@ def __enter__(self):

     def __exit__(self, exc_type, exc_value, traceback):
         log_gpu_memory_usage("Before SGLang offload in sharding manager", logger=logger)
-        self.inference_engine.release_memory_occupation()
+        self.release_memory()
         log_gpu_memory_usage("After SGLang offload in sharding manager", logger=logger)

-        # self.module.to('cuda')
-        # if torch.distributed.get_rank() == 0:
-        #     print(f'after actor module to cuda in sharding manager memory allocated:
-        #           {torch.cuda.memory_allocated() / 1e9}GB, reserved: {torch.cuda.memory_reserved() / 1e9}GB')
-
         self.module.train()

         # add empty cache after each compute
@@ -132,6 +138,13 @@ def __exit__(self, exc_type, exc_value, traceback):
         self.gen_random_states = torch.cuda.get_rng_state()
         torch.cuda.set_rng_state(self.torch_random_states)

+    def update_weights(self, params):
+        self.inference_engine.resume_memory_occupation()
+        self.inference_engine.update_weights_from_tensor([(k, v) for k, v in params.items()], load_format=None)
+
+    def release_memory(self):
+        self.inference_engine.release_memory_occupation()
+
     def preprocess_data(self, data: DataProto) -> DataProto:
         """All gather across tp group to make each rank has identical input."""
         if self.device_mesh["infer_tp"].mesh.size()[0] == 1:
@@ -154,3 +167,54 @@ def postprocess_data(self, data: DataProto) -> DataProto:
         local_prompts = data.chunk(chunks=tp_size)
         data = local_prompts[tp_rank]
         return data
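`_preprocess_tensor_for_update_weights` exists because FSDP state dicts can hold `DTensor` shards, and `full_tensor()` all-gathers a shard into a plain `torch.Tensor` before it is serialized for the engine. A minimal single-rank sketch of that dispatch, assuming torch >= 2.4 (where `torch.distributed.tensor` is public) and a CPU gloo process group; the names here are illustrative:

```python
import os
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Shard, distribute_tensor

def to_full(tensor: torch.Tensor) -> torch.Tensor:
    # Same dispatch as _preprocess_tensor_for_update_weights: gather DTensors, pass others through.
    return tensor.full_tensor() if isinstance(tensor, DTensor) else tensor

if __name__ == "__main__":
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)
    mesh = init_device_mesh("cpu", (1,))
    sharded = distribute_tensor(torch.arange(8.0), mesh, [Shard(0)])  # a DTensor shard
    full = to_full(sharded)
    assert isinstance(full, torch.Tensor) and not isinstance(full, DTensor)
    dist.destroy_process_group()
```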
+
+
+class FSDPAsyncSGLangShardingManager(FSDPSGLangShardingManager):
+    def __init__(
+        self,
+        module: FSDP,
+        inference_engine: Engine,
+        model_config,
+        full_params: bool = False,
+        device_mesh: DeviceMesh = None,
+        offload_param: bool = False,
+    ):
+        super().__init__(module, inference_engine, model_config, full_params, device_mesh, offload_param)
+
+    def update_weights(self, params):
+        if self.device_mesh["infer_tp"].get_local_rank() == 0:
+            self.inference_engine.resume_memory_occupation()
+
+        # Most naive implementation, borrowed from sglang's VerlEngine; can be optimized a lot if it becomes a bottleneck
+        named_tensors = [(k, v) for k, v in params.items()]
+        load_format = None
+        for tensor_index, (name, tensor) in enumerate(named_tensors):
+            serialized_tensor = MultiprocessingSerializer.serialize(_preprocess_tensor_for_update_weights(tensor))
+
+            if self.device_mesh["infer_tp"].get_local_rank() == 0:
+                gathered_serialized_tensors = [None for _ in range(self.device_mesh["infer_tp"].mesh.size()[0])]
+            else:
+                gathered_serialized_tensors = None
+            dist.gather_object(
+                obj=serialized_tensor,
+                object_gather_list=gathered_serialized_tensors,
+                dst=self.device_mesh["infer_tp"].mesh.tolist()[0],
+                group=self.device_mesh["infer_tp"].get_group(),
+            )
+
+            if self.device_mesh["infer_tp"].get_local_rank() == 0:
+                self.inference_engine.update_weights_from_tensor(
+                    named_tensors=[
+                        (
+                            name,
+                            LocalSerializedTensor(values=gathered_serialized_tensors),
+                        )
+                    ],
+                    load_format=load_format,
+                    flush_cache=tensor_index == len(named_tensors) - 1,
+                )
+
+    def release_memory(self):
+        if self.device_mesh["infer_tp"].get_local_rank() == 0:
+            self.inference_engine.release_memory_occupation()
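The async manager's `update_weights` is a gather-to-rank-0 pattern: every TP rank serializes its local shard, rank 0 collects the blobs with `dist.gather_object`, and only rank 0 talks to the engine. A minimal two-process CPU sketch of that pattern, with `pickle` standing in for `MultiprocessingSerializer` and no real engine on the receiving end:

```python
import pickle
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def worker(rank: int, world_size: int):
    dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29501", rank=rank, world_size=world_size)
    shard = torch.full((2,), float(rank))              # this rank's slice of a weight
    blob = pickle.dumps(shard)                         # serialize locally, like each TP rank does
    gathered = [None] * world_size if rank == 0 else None
    dist.gather_object(blob, gathered, dst=0)          # only the destination provides a gather list
    if rank == 0:
        shards = [pickle.loads(b) for b in gathered]   # rank 0 now holds every shard
        print(torch.cat(shards))                       # -> tensor([0., 0., 1., 1.])
    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)
```

Gathering one tensor at a time bounds peak memory to a single full tensor per rank, which matches the per-tensor loop (and its final `flush_cache`) in the hunk above.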