diff --git a/.github/workflows/e2e_ppo_trainer.yml b/.github/workflows/e2e_ppo_trainer.yml index cf77df2229b..1fbff6e3e51 100644 --- a/.github/workflows/e2e_ppo_trainer.yml +++ b/.github/workflows/e2e_ppo_trainer.yml @@ -204,7 +204,7 @@ jobs: HF_ENDPOINT: "https://hf-mirror.com" HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable container: - image: ocss884/verl-sglang:ngc-th2.6.0-cu126-sglang0.4.6.post4 + image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6.post5-mcore0.12.0-te2.3 options: --gpus all --shm-size=10g steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 @@ -222,7 +222,7 @@ jobs: ray stop --force ENGINE=sglang bash tests/e2e/ppo_trainer/run_function_reward.sh - e2e_ppo_trainer_sglang_async: + e2e_ppo_trainer_sglang_multiturn_with_tool: runs-on: [L20x8] needs: pre_commit_for_ppo timeout-minutes: 40 # Increase this timeout value as needed @@ -233,36 +233,7 @@ jobs: HF_ENDPOINT: "https://hf-mirror.com" HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable container: - image: ocss884/verl-sglang:ngc-th2.6.0-cu126-sglang0.4.6.post4 - options: --gpus all --shm-size=10g - steps: - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - with: - fetch-depth: 0 - - name: Install the current repository - run: | - pip3 install -e .[test,gpu,sglang] --no-deps - - name: Prepare gsm8k dataset - run: | - ray stop --force - python3 examples/data_preprocess/gsm8k.py - - name: Running GSM8K E2E training tests on 8 L20 GPUs with rmpad using function rm and save ckpt with sglang async - run: | - ray stop --force - ENGINE=sglang_async bash tests/e2e/ppo_trainer/run_function_reward.sh - - e2e_ppo_trainer_sglang_async_with_tool: - runs-on: [L20x8] - needs: pre_commit_for_ppo - timeout-minutes: 40 # Increase this timeout value as needed - env: - HTTP_PROXY: ${{ secrets.PROXY_HTTP }} - HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }} - NO_PROXY: "localhost,127.0.0.1,hf-mirror.com" - HF_ENDPOINT: "https://hf-mirror.com" - HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable - container: - image: ocss884/verl-sglang:ngc-th2.6.0-cu126-sglang0.4.6.post4 + image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6.post5-mcore0.12.0-te2.3 options: --gpus all --shm-size=10g steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 @@ -275,7 +246,7 @@ jobs: run: | ray stop --force python3 examples/data_preprocess/gsm8k_multiturn_w_tool.py --local_dir $HOME/data/gsm8k_verl_sgl_multi_turn_preprocessed - - name: Running GSM8K with tool E2E training tests on 8 L20 GPUs with rmpad using function rm and save ckpt with sglang async + - name: Running GSM8K with tool E2E training tests on 8 L20 GPUs with rmpad using function rm and save ckpt with sglang run: | ray stop --force bash tests/e2e/run_gsm8k_fsdp_sgl_multiturn_w_tool.sh @@ -295,7 +266,7 @@ jobs: HF_ENDPOINT: "https://hf-mirror.com" HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable container: - image: ocss884/verl-sglang:ngc-th2.6.0-cu126-sglang0.4.6.post4 + image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6.post5-mcore0.12.0-te2.3 options: --gpus all --shm-size=50g # Visual dataloader requires large memory steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 @@ -367,7 +338,7 @@ jobs: HF_ENDPOINT: "https://hf-mirror.com" HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable container: - image: ocss884/verl-sglang:ngc-th2.6.0-cu126-sglang0.4.6.post4 + image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6.post5-mcore0.12.0-te2.3 options: --gpus all 
--shm-size=50g # Visual dataloader requires large memory steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 diff --git a/.github/workflows/e2e_ppo_trainer_megatron.yml b/.github/workflows/e2e_ppo_trainer_megatron.yml index f9dc924483e..fbd0c699121 100644 --- a/.github/workflows/e2e_ppo_trainer_megatron.yml +++ b/.github/workflows/e2e_ppo_trainer_megatron.yml @@ -50,7 +50,7 @@ jobs: HF_ENDPOINT: "https://hf-mirror.com" HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable container: - image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6-mcore0.12.0-te2.3 + image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6.post5-mcore0.12.0-te2.3 options: --gpus all --shm-size=10g steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 @@ -92,7 +92,7 @@ jobs: HF_ENDPOINT: "https://hf-mirror.com" HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable container: - image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6-mcore0.12.0-te2.3 + image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6.post5-mcore0.12.0-te2.3 options: --gpus all --shm-size=10g steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 @@ -134,7 +134,7 @@ jobs: HF_ENDPOINT: "https://hf-mirror.com" HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable container: - image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6-mcore0.12.0-te2.3 + image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6.post5-mcore0.12.0-te2.3 options: --gpus all --shm-size=10g steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 @@ -167,7 +167,7 @@ jobs: HF_ENDPOINT: "https://hf-mirror.com" HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable container: - image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6-mcore0.12.0-te2.3 + image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6.post5-mcore0.12.0-te2.3 options: --gpus all --shm-size=10g steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 @@ -206,7 +206,7 @@ jobs: HF_ENDPOINT: "https://hf-mirror.com" HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable container: - image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6-mcore0.12.0-te2.3 + image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6.post5-mcore0.12.0-te2.3 options: --gpus all --shm-size=10g steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 diff --git a/.github/workflows/sgl.yml b/.github/workflows/sgl.yml index 917b7a77909..59ae19cc981 100644 --- a/.github/workflows/sgl.yml +++ b/.github/workflows/sgl.yml @@ -56,7 +56,7 @@ jobs: HF_HUB_ENABLE_HF_TRANSFER: 1 SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK: "True" container: - image: ocss884/verl-sglang:ngc-th2.6.0-cu126-sglang0.4.6.post4 + image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6.post5-mcore0.12.0-te2.3 options: --gpus all --shm-size=10g steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 @@ -73,11 +73,7 @@ jobs: - name: Test the latest SGLang run: | cd tests/workers/rollout - torchrun --nnodes=1 --nproc_per_node=4 $(which pytest) -s test_sglang_spmd.py - - name: Test the latest SGLang async - run: | - cd tests/workers/rollout - torchrun --nnodes=1 --nproc_per_node=2 $(which pytest) -s test_sglang_async_spmd.py + torchrun --nnodes=1 --nproc_per_node=2 $(which pytest) -s test_sglang_spmd.py - name: Test the latest SGLang Rollout async with tool run: | cd tests/workers/rollout diff --git a/docker/Dockerfile.rocm b/docker/Dockerfile.rocm index d8992c4a4e6..faec8aac02a 
100644 --- a/docker/Dockerfile.rocm +++ b/docker/Dockerfile.rocm @@ -6,7 +6,7 @@ # Support - Traing: fsdp; Inference: vllm # FROM rocm/vllm:rocm6.2_mi300_ubuntu20.04_py3.9_vllm_0.6.4 # Support - Traing: fsdp; Inference: vllm, sglang -FROM lmsysorg/sglang:v0.4.6.post4-rocm630 +FROM lmsysorg/sglang:v0.4.6.post5-rocm630 # Set working directory # WORKDIR $PWD/app diff --git a/docker/Dockerfile.sglang b/docker/Dockerfile.sglang index 1a8c16d2bff..11ad4a77da6 100644 --- a/docker/Dockerfile.sglang +++ b/docker/Dockerfile.sglang @@ -36,8 +36,8 @@ RUN pip config set global.index-url "${PIP_INDEX}" && \ pip config set global.extra-index-url "${PIP_INDEX}" && \ python -m pip install --upgrade pip -# Install sglang-0.4.6.post4 and torch-memory-saver -RUN pip install "sglang[all]==0.4.6.post4" --no-cache-dir --find-links https://flashinfer.ai/whl/cu124/torch2.6/flashinfer-python && pip install torch-memory-saver --no-cache-dir +# Install sglang-0.4.6.post5 and torch-memory-saver +RUN pip uninstall -y cuda-python && pip install "sglang[all]==0.4.6.post5" --no-cache-dir --find-links https://flashinfer.ai/whl/cu124/torch2.6/flashinfer-python && pip install torch-memory-saver --no-cache-dir # Install torch-2.6.0 RUN pip install --no-cache-dir torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 tensordict torchdata \ diff --git a/docker/Dockerfile.vllm.sglang.megatron b/docker/Dockerfile.vllm.sglang.megatron index 19df378e270..8921990178d 100644 --- a/docker/Dockerfile.vllm.sglang.megatron +++ b/docker/Dockerfile.vllm.sglang.megatron @@ -56,12 +56,12 @@ RUN aria2c --always-resume=true --max-tries=99999 https://developer.download.nvi update-alternatives --set cuda /usr/local/cuda-12.4 && \ rm -rf /usr/local/cuda-12.6 -# Install torch-2.6.0+cu124 + vllm-0.8.5.post1 + sglang-0.4.6.post4 +# Install torch-2.6.0+cu124 + vllm-0.8.5.post1 + sglang-0.4.6.post5 # torch-2.6.0+cu124: cxx11abi=False # torch-2.6.0+cu126: cxx11abi=True # see https://github.com/flashinfer-ai/flashinfer/issues/911 # Install sglang-0.4.6.post1 and torch-memory-saver -RUN pip install "sglang[all]==0.4.6.post1" --no-cache-dir --find-links https://flashinfer.ai/whl/cu124/torch2.6/flashinfer-python && pip install torch-memory-saver --no-cache-dir +RUN pip install "sglang[all]==0.4.6.post5" --no-cache-dir --find-links https://flashinfer.ai/whl/cu124/torch2.6/flashinfer-python && pip install torch-memory-saver --no-cache-dir RUN pip install --no-cache-dir "vllm==0.8.5.post1" "torch==2.6.0" "torchvision==0.21.0" "torchaudio==2.6.0" "tensordict==0.6.2" torchdata diff --git a/docs/amd_tutorial/amd_build_dockerfile_page.rst b/docs/amd_tutorial/amd_build_dockerfile_page.rst index 8f45f89affa..ff2b4e9c72e 100644 --- a/docs/amd_tutorial/amd_build_dockerfile_page.rst +++ b/docs/amd_tutorial/amd_build_dockerfile_page.rst @@ -22,7 +22,7 @@ docker/Dockerfile.rocm # Support - Traing: fsdp; Inference: vllm # FROM rocm/vllm:rocm6.2_mi300_ubuntu20.04_py3.9_vllm_0.6.4 # Support - Traing: fsdp; Inference: vllm, sglang - FROM lmsysorg/sglang:v0.4.6.post4-rocm630 + FROM lmsysorg/sglang:v0.4.6.post5-rocm630 # Set working directory # WORKDIR $PWD/app diff --git a/docs/sglang_multiturn/multiturn.rst b/docs/sglang_multiturn/multiturn.rst index bbb3e9bbc5d..970ba46c1d2 100644 --- a/docs/sglang_multiturn/multiturn.rst +++ b/docs/sglang_multiturn/multiturn.rst @@ -11,9 +11,9 @@ To enable multi-turn rollout, make sure to configure the following fields in you actor_rollout_ref: rollout: multi_turn: True - name: "sglang_async" + name: "sglang" -These configuration activates 
the sglang_async engine for multi-turn interaction during rollout. +This configuration activates the sglang engine for multi-turn interaction during rollout. Custom Tool Configuration ~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/start/install.rst b/docs/start/install.rst index e324cadd1b4..20a9b6705b8 100644 --- a/docs/start/install.rst +++ b/docs/start/install.rst @@ -42,7 +42,7 @@ For vLLM with Megatron or FSDP, please use the stable version of image ``whatcan For latest vLLM with FSDP, please refer to ``hiyouga/verl:ngc-th2.6.0-cu126-vllm0.8.4-flashinfer0.2.2-cxx11abi0``. -For SGLang with FSDP, please use ``ocss884/verl-sglang:ngc-th2.6.0-cu126-sglang0.4.6.post4`` which is provided by SGLang RL Group. +For SGLang with FSDP, please use ``ocss884/verl-sglang:ngc-th2.6.0-cu126-sglang0.4.6.post5`` which is provided by SGLang RL Group. See files under ``docker/`` for NGC-based image or if you want to build your own. @@ -79,7 +79,7 @@ See files under ``docker/`` for NGC-based image or if you want to build your own. - **Flash Attenttion**: 2.7.4.post1 - **Flash Infer**: 0.2.2.post1 - **vLLM**: 0.8.5 - - **SGLang**: 0.4.6.post4 + - **SGLang**: 0.4.6.post5 - **Megatron-LM**: core_v0.12.0 - **TransformerEngine**: 2.3 - **Ray**: 2.44.1 diff --git a/docs/workers/sglang_worker.rst b/docs/workers/sglang_worker.rst index a9a2b2098c0..df208a45ab8 100644 --- a/docs/workers/sglang_worker.rst +++ b/docs/workers/sglang_worker.rst @@ -21,7 +21,7 @@ Please always follow the following command to install SGLang with verl. .. code-block:: bash pip install --upgrade pip - # Currently 0.4.6.post4, subject to updates at any time, please refer to the latest version specified in `setup.py` + # Currently 0.4.6.post5, subject to updates at any time, please refer to the latest version specified in `setup.py` pip install -e ".[sglang]" You can check the following dependencies are in your environment: @@ -31,8 +31,8 @@ You can check the following dependencies are in your environment: - **PyTorch**: 2.6.0+cu124 - **CUDA**: 12.4 - **flashinfer-python**: 0.2.5+cu124torch2.6 - - **sgLang**: 0.4.6.post4 - - **sgl-kernel**: 0.1.2.post1 + - **sgLang**: 0.4.6.post5 + - **sgl-kernel**: 0.1.4 Using SGLang as the Inference Backend for PPO Training on a Single Machine ------------------------------------------------------------------------- @@ -87,7 +87,7 @@ Why export SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK? 1. ``verl`` initializes a ``SGLangRollout`` module during rollout, which is used to evaluate/generate samples. -2. ``SGLangRollout`` will initialize ``VerlEngine``, and further initialize a ``torch.distributed.DeviceMesh``, used to support Tensor Parallel (TP). +2. ``SGLangRollout`` will initialize ``Engine``, and further initialize a ``torch.distributed.DeviceMesh``, used to support Tensor Parallel (TP). 3. ``DeviceMesh.init()`` internally checks the free GPU memory of all participating devices. If the difference is too large (more than ~10%), it directly reports an error to avoid initialization failures or deadlocks. @@ -111,7 +111,7 @@ Early workers already use up GPU memory → late workers still have empty memory **3. 
SGLang's TP init uses "all-device broadcast", but there's no uniform release timing** -Although ``SGLangRollout`` may only involve subset of GPUs, its ``VerlEngine`` initialization calls ``torch.distributed.init_process_group()`` and broadcasts weights, so: +Although ``SGLangRollout`` may only involve a subset of GPUs, its ``Engine`` initialization calls ``torch.distributed.init_process_group()`` and broadcasts weights, so: - Non-rollout GPUs also join the communication. - Later on, ``DeviceMesh`` init will fail due to "inconsistent memory". diff --git a/examples/sglang_multiturn/config/gsm8k_multiturn_grpo.yaml b/examples/sglang_multiturn/config/gsm8k_multiturn_grpo.yaml index 5191d91b3db..db133f8af77 100644 --- a/examples/sglang_multiturn/config/gsm8k_multiturn_grpo.yaml +++ b/examples/sglang_multiturn/config/gsm8k_multiturn_grpo.yaml @@ -15,7 +15,7 @@ data: actor_rollout_ref: hybrid_engine: True rollout: - name: sglang_async + name: sglang multi_turn: enable: True max_turns: 5 diff --git a/examples/sglang_multiturn/config/gsm8k_multiturn_megatron_grpo.yaml b/examples/sglang_multiturn/config/gsm8k_multiturn_megatron_grpo.yaml index b3c5dcb922d..8609d890166 100644 --- a/examples/sglang_multiturn/config/gsm8k_multiturn_megatron_grpo.yaml +++ b/examples/sglang_multiturn/config/gsm8k_multiturn_megatron_grpo.yaml @@ -15,7 +15,7 @@ data: actor_rollout_ref: hybrid_engine: True rollout: - name: sglang_async + name: sglang multi_turn: enable: True max_turns: 5 diff --git a/examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_multiturn.sh b/examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_multiturn.sh index 7dedada0ec5..a3bcde50c6f 100644 --- a/examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_multiturn.sh +++ b/examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_multiturn.sh @@ -32,7 +32,7 @@ python3 -m verl.trainer.main_ppo \ actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \ actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ - actor_rollout_ref.rollout.name=sglang_async \ + actor_rollout_ref.rollout.name=sglang \ actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \ actor_rollout_ref.rollout.n=16 \ actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \ @@ -41,7 +41,7 @@ python3 -m verl.trainer.main_ppo \ trainer.critic_warmup=0 \ trainer.logger=['console','wandb'] \ trainer.project_name='gsm8k_async_rl' \ - trainer.experiment_name='qwen2.5-3b_function_rm-gsm8k-async-sgl-multi-w-tool-verify-n16' \ + trainer.experiment_name='qwen2.5-3b_function_rm-gsm8k-sgl-multi-w-tool-verify-n16' \ trainer.n_gpus_per_node=8 \ trainer.nnodes=1 \ trainer.save_freq=-1 \ diff --git a/examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_multiturn_4xgpu.sh b/examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_multiturn_4xgpu.sh index 5cca4b471b9..ee17a18b927 100644 --- a/examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_multiturn_4xgpu.sh +++ b/examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_multiturn_4xgpu.sh @@ -32,7 +32,7 @@ python3 -m verl.trainer.main_ppo \ actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \ actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ - actor_rollout_ref.rollout.name=sglang_async \ + actor_rollout_ref.rollout.name=sglang \ actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \ actor_rollout_ref.rollout.n=16 \ actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \ diff --git a/examples/sglang_multiturn/run_qwen2.5-3b_megatron_gsm8k_multiturn.sh 
b/examples/sglang_multiturn/run_qwen2.5-3b_megatron_gsm8k_multiturn.sh index 122b424456a..671d58edd28 100644 --- a/examples/sglang_multiturn/run_qwen2.5-3b_megatron_gsm8k_multiturn.sh +++ b/examples/sglang_multiturn/run_qwen2.5-3b_megatron_gsm8k_multiturn.sh @@ -45,7 +45,7 @@ python3 -m verl.trainer.main_ppo \ actor_rollout_ref.ref.megatron.tensor_model_parallel_size=2 \ actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \ actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ - actor_rollout_ref.rollout.name=sglang_async \ + actor_rollout_ref.rollout.name=sglang \ actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \ actor_rollout_ref.rollout.n=8 \ actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \ @@ -53,7 +53,7 @@ python3 -m verl.trainer.main_ppo \ trainer.critic_warmup=0 \ trainer.logger=['console','wandb'] \ trainer.project_name='gsm8k_async_rl' \ - trainer.experiment_name='qwen2.5-3b_function_rm-gsm8k-async-sgl-multi-w-tool-n8-mcore-v2505201745_seed42' \ + trainer.experiment_name='qwen2.5-3b_function_rm-gsm8k-sgl-multi-w-tool-n8-mcore-v2505201745_seed42' \ trainer.n_gpus_per_node=8 \ trainer.nnodes=1 \ trainer.save_freq=-1 \ diff --git a/requirements_sglang.txt b/requirements_sglang.txt index 2f99c97862c..57d5e0befc1 100644 --- a/requirements_sglang.txt +++ b/requirements_sglang.txt @@ -17,6 +17,6 @@ torchdata torchvision transformers wandb -sglang[all]==0.4.6.post4 +sglang[all]==0.4.6.post5 torch-memory-saver>=0.0.5 -huggingface_hub \ No newline at end of file +huggingface_hub diff --git a/setup.py b/setup.py index 6075bb2af8b..e6d3dc5b4cc 100644 --- a/setup.py +++ b/setup.py @@ -49,7 +49,12 @@ GPU_REQUIRES = ["liger-kernel", "flash-attn"] MATH_REQUIRES = ["math-verify"] # Add math-verify as an optional dependency VLLM_REQUIRES = ["tensordict<=0.6.2", "vllm<=0.8.5"] -SGLANG_REQUIRES = ["tensordict<=0.6.2", "sglang[srt,openai]==0.4.6.post4", "torch-memory-saver>=0.0.5", "torch==2.6.0"] +SGLANG_REQUIRES = [ + "tensordict<=0.6.2", + "sglang[srt,openai]==0.4.6.post5", + "torch-memory-saver>=0.0.5", + "torch==2.6.0", +] extras_require = { "test": TEST_REQUIRES, diff --git a/tests/e2e/run_gsm8k_fsdp_sgl_multiturn_w_tool.sh b/tests/e2e/run_gsm8k_fsdp_sgl_multiturn_w_tool.sh index 2797b2cf5c5..333f6d2bd93 100644 --- a/tests/e2e/run_gsm8k_fsdp_sgl_multiturn_w_tool.sh +++ b/tests/e2e/run_gsm8k_fsdp_sgl_multiturn_w_tool.sh @@ -36,7 +36,7 @@ python3 -m verl.trainer.main_ppo \ actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \ actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ - actor_rollout_ref.rollout.name=sglang_async \ + actor_rollout_ref.rollout.name=sglang \ actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \ actor_rollout_ref.rollout.n=8 \ actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \ @@ -46,7 +46,7 @@ python3 -m verl.trainer.main_ppo \ trainer.critic_warmup=0 \ trainer.logger=['console'] \ trainer.project_name='gsm8k_async_rl' \ - trainer.experiment_name=qwen2.5-3b_function_rm-gsm8k-async-sgl-multi-w-tool-$FSDP_STRATEGY-rebased-0427-verify-n16 \ + trainer.experiment_name=qwen2.5-3b_function_rm-gsm8k-sgl-multi-w-tool-$FSDP_STRATEGY-rebased-0427-verify-n16 \ trainer.n_gpus_per_node=8 \ trainer.nnodes=1 \ trainer.save_freq=-1 \ diff --git a/tests/e2e/run_ppo_trainer_megatron.sh b/tests/e2e/run_ppo_trainer_megatron.sh index 691a9f188de..ac9928cbb0b 100644 --- a/tests/e2e/run_ppo_trainer_megatron.sh +++ b/tests/e2e/run_ppo_trainer_megatron.sh @@ -78,7 +78,7 @@ if [ 
$SKIP_SAVE_HF_MODEL -eq 1 ]; then CHECKPOINT_CONTENTS=['model','optimizer','extra'] fi -ENGINES=("vllm" "sglang_async") +ENGINES=("vllm" "sglang") exp_name="$(basename "${MODEL_ID,,}")-megatron-gsm8k-minimal" diff --git a/tests/workers/rollout/test_async_sglang_server.py b/tests/workers/rollout/test_async_sglang_server.py index a9cc7ade01c..914f527c9e9 100644 --- a/tests/workers/rollout/test_async_sglang_server.py +++ b/tests/workers/rollout/test_async_sglang_server.py @@ -1,3 +1,4 @@ +# Copyright 2023-2024 SGLang Team # Copyright 2025 Bytedance Ltd. and/or its affiliates # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -20,7 +21,6 @@ @patch.dict( "sys.modules", { - "verl.workers.rollout.sglang_rollout.async_sglang_rollout": MagicMock(AsyncSGLangRollout=MagicMock()), "verl.workers.rollout.sglang_rollout.sglang_rollout": MagicMock(SGLangRollout=MagicMock()), }, ) diff --git a/tests/workers/rollout/test_sglang_async_rollout_search_tools.py b/tests/workers/rollout/test_sglang_async_rollout_search_tools.py index 655e6124a13..667d4927103 100644 --- a/tests/workers/rollout/test_sglang_async_rollout_search_tools.py +++ b/tests/workers/rollout/test_sglang_async_rollout_search_tools.py @@ -32,7 +32,7 @@ from verl.tools.schemas import OpenAIFunctionParametersSchema, OpenAIFunctionPropertySchema, OpenAIFunctionSchema, OpenAIFunctionToolSchema from verl.tools.search_tool import SearchTool from verl.workers.rollout.schemas import AsyncRolloutRequest, AsyncRolloutRequestStateEnum, Message -from verl.workers.rollout.sglang_rollout.async_sglang_rollout import AsyncSGLangRollout +from verl.workers.rollout.sglang_rollout.sglang_rollout import SGLangRollout DEFAULT_USER_CONTENT_PREFIX = ( "Answer the given question. You must conduct reasoning inside and " @@ -143,11 +143,11 @@ def search_data_proto(self, search_data, qwen_tokenizer): prompts = DataProto(batch=prompt_dict, non_tensor_batch={"raw_prompt": messages, "tools_kwargs": tools_kwargs, "index": index}) return prompts - @patch.object(AsyncSGLangRollout, "_init_distributed_env", return_value=None) - @patch.object(AsyncSGLangRollout, "_init_inference_engine", return_value=None) - @patch.object(AsyncSGLangRollout, "_init_sampling_params", return_value=None) + @patch.object(SGLangRollout, "_init_distributed_env", return_value=None) + @patch.object(SGLangRollout, "_init_inference_engine", return_value=None) + @patch.object(SGLangRollout, "_init_sampling_params", return_value=None) def test_tools_registration(self, mock_env, mock_engine, mock_sampling, search_rollout_config, qwen_tokenizer, qwen_model_config): - rollout = AsyncSGLangRollout(actor_module="", config=search_rollout_config, tokenizer=qwen_tokenizer, model_hf_config=qwen_model_config) + rollout = SGLangRollout(actor_module="", config=search_rollout_config, tokenizer=qwen_tokenizer, model_hf_config=qwen_model_config) assert len(rollout._tool_schemas) == 1 assert "search" in rollout._tool_map.keys() from verl.tools.search_tool import SearchTool @@ -156,11 +156,11 @@ def test_tools_registration(self, mock_env, mock_engine, mock_sampling, search_r # depend on the tokenizer assert rollout._tool_call_parser_type == "qwen25" - @patch.object(AsyncSGLangRollout, "_init_distributed_env", return_value=None) - @patch.object(AsyncSGLangRollout, "_init_inference_engine", return_value=None) - @patch.object(AsyncSGLangRollout, "_init_sampling_params", return_value=None) + @patch.object(SGLangRollout, "_init_distributed_env", return_value=None) + @patch.object(SGLangRollout, 
"_init_inference_engine", return_value=None) + @patch.object(SGLangRollout, "_init_sampling_params", return_value=None) def test_rollout_req_creation(self, mock_env, mock_engine, mock_sampling, search_rollout_config, qwen_tokenizer, qwen_model_config, search_data_proto): - rollout = AsyncSGLangRollout(actor_module="", config=search_rollout_config, tokenizer=qwen_tokenizer, model_hf_config=qwen_model_config) + rollout = SGLangRollout(actor_module="", config=search_rollout_config, tokenizer=qwen_tokenizer, model_hf_config=qwen_model_config) req_list = rollout._preprocess_prompt_to_async_rollout_requests(search_data_proto, n=1) assert len(req_list) == 1 assert req_list[0].state == AsyncRolloutRequestStateEnum.PENDING @@ -186,12 +186,12 @@ def test_rollout_req_creation(self, mock_env, mock_engine, mock_sampling, search ), ) - @patch.object(AsyncSGLangRollout, "_init_distributed_env", return_value=None) - @patch.object(AsyncSGLangRollout, "_init_inference_engine", return_value=None) - @patch.object(AsyncSGLangRollout, "_init_sampling_params", return_value=None) + @patch.object(SGLangRollout, "_init_distributed_env", return_value=None) + @patch.object(SGLangRollout, "_init_inference_engine", return_value=None) + @patch.object(SGLangRollout, "_init_sampling_params", return_value=None) def test_over_size_case(self, mock_env, mock_engine, mock_sampling, search_rollout_config, qwen_tokenizer, qwen_model_config, search_data_proto, search_data): search_rollout_config.multi_turn.max_turns = 1 - rollout = AsyncSGLangRollout(actor_module="", config=search_rollout_config, tokenizer=qwen_tokenizer, model_hf_config=qwen_model_config) + rollout = SGLangRollout(actor_module="", config=search_rollout_config, tokenizer=qwen_tokenizer, model_hf_config=qwen_model_config) req = rollout._preprocess_prompt_to_async_rollout_requests(search_data_proto, n=1)[0] req = MagicMock(wraps=req, spec=AsyncRolloutRequest) req.finalize = MagicMock() @@ -223,9 +223,9 @@ def test_over_size_case(self, mock_env, mock_engine, mock_sampling, search_rollo ) @patch.object(SearchTool, "execute", new_callable=AsyncMock) - @patch.object(AsyncSGLangRollout, "_init_distributed_env", return_value=None) - @patch.object(AsyncSGLangRollout, "_init_inference_engine", return_value=None) - @patch.object(AsyncSGLangRollout, "_init_sampling_params", return_value=None) + @patch.object(SGLangRollout, "_init_distributed_env", return_value=None) + @patch.object(SGLangRollout, "_init_inference_engine", return_value=None) + @patch.object(SGLangRollout, "_init_sampling_params", return_value=None) def test_tool_call_basic_case(self, mock_sampling, mock_engine, mock_env, mock_execute, search_rollout_config, qwen_tokenizer, qwen_model_config, search_data_proto, search_data): _, expect_turn_array, tool_return_array = search_data @@ -233,7 +233,7 @@ def test_tool_call_basic_case(self, mock_sampling, mock_engine, mock_env, mock_e mock_execute.side_effect = [(msg, 0.0, {"status": "success"}) for msg in tool_return_array] search_rollout_config.multi_turn.max_turns = 10 - rollout = AsyncSGLangRollout(actor_module="", config=search_rollout_config, tokenizer=qwen_tokenizer, model_hf_config=qwen_model_config) + rollout = SGLangRollout(actor_module="", config=search_rollout_config, tokenizer=qwen_tokenizer, model_hf_config=qwen_model_config) rollout._tool_map["search"].retrieval_service_url = "mock://dummy" @@ -272,9 +272,9 @@ def test_tool_call_basic_case(self, mock_sampling, mock_engine, mock_env, mock_e assert search_counter == 2 @patch.object(SearchTool, "execute", 
new_callable=AsyncMock) - @patch.object(AsyncSGLangRollout, "_init_distributed_env", return_value=None) - @patch.object(AsyncSGLangRollout, "_init_inference_engine", return_value=None) - @patch.object(AsyncSGLangRollout, "_init_sampling_params", return_value=None) + @patch.object(SGLangRollout, "_init_distributed_env", return_value=None) + @patch.object(SGLangRollout, "_init_inference_engine", return_value=None) + @patch.object(SGLangRollout, "_init_sampling_params", return_value=None) def test_tool_call_batch_case(self, mock_sampling, mock_engine, mock_env, mock_execute, search_rollout_config, qwen_tokenizer, qwen_model_config, search_data_proto, search_data): _, expect_turn_array, tool_return_array = search_data @@ -285,7 +285,7 @@ def test_tool_call_batch_case(self, mock_sampling, mock_engine, mock_env, mock_e ] * 100 search_rollout_config.multi_turn.max_turns = 10 - rollout = AsyncSGLangRollout( + rollout = SGLangRollout( actor_module="", config=search_rollout_config, tokenizer=qwen_tokenizer, @@ -327,7 +327,7 @@ async def hacked_handle_engine_call(self, _req: AsyncRolloutRequest, *_args, **_ req_turns_counter[_req.batch_data_id] += 1 return await fut - with patch.object(AsyncSGLangRollout, "_handle_engine_call", new=hacked_handle_engine_call): + with patch.object(SGLangRollout, "_handle_engine_call", new=hacked_handle_engine_call): rollout._tp_rank = 0 loop = asyncio.get_event_loop() output_req_list = loop.run_until_complete(asyncio.gather(*[rollout._async_rollout_a_request(r, True, False) for r in req_list])) diff --git a/tests/workers/rollout/test_sglang_async_rollout_sf_tools.py b/tests/workers/rollout/test_sglang_async_rollout_sf_tools.py index 166dc064708..fe027a60e27 100644 --- a/tests/workers/rollout/test_sglang_async_rollout_sf_tools.py +++ b/tests/workers/rollout/test_sglang_async_rollout_sf_tools.py @@ -32,7 +32,7 @@ from verl.tools.sandbox_fusion_tools import TokenBucketWorker from verl.tools.schemas import OpenAIFunctionParametersSchema, OpenAIFunctionPropertySchema, OpenAIFunctionSchema, OpenAIFunctionToolSchema from verl.workers.rollout.schemas import AsyncRolloutRequest, AsyncRolloutRequestStateEnum, Message -from verl.workers.rollout.sglang_rollout.async_sglang_rollout import AsyncSGLangRollout +from verl.workers.rollout.sglang_rollout.sglang_rollout import SGLangRollout sandbox_url = "" @@ -200,11 +200,11 @@ def sandbox_data_proto(self, sandbox_fusion_data, qwen_tokenizer): prompts = DataProto(batch=prompt_dict, non_tensor_batch={"raw_prompt": messages, "tools_kwargs": tools_kwargs, "index": index}) return prompts - @patch.object(AsyncSGLangRollout, "_init_distributed_env", return_value=None) - @patch.object(AsyncSGLangRollout, "_init_inference_engine", return_value=None) - @patch.object(AsyncSGLangRollout, "_init_sampling_params", return_value=None) + @patch.object(SGLangRollout, "_init_distributed_env", return_value=None) + @patch.object(SGLangRollout, "_init_inference_engine", return_value=None) + @patch.object(SGLangRollout, "_init_sampling_params", return_value=None) def test_tools_registration(self, mock_env, mock_engine, mock_sampling, sandbox_fusion_rollout_config, qwen_tokenizer, qwen_model_config): - rollout = AsyncSGLangRollout(actor_module="", config=sandbox_fusion_rollout_config, tokenizer=qwen_tokenizer, model_hf_config=qwen_model_config) + rollout = SGLangRollout(actor_module="", config=sandbox_fusion_rollout_config, tokenizer=qwen_tokenizer, model_hf_config=qwen_model_config) assert len(rollout._tool_schemas) == 1 assert "code_interpreter" in 
rollout._tool_map.keys() from verl.tools.sandbox_fusion_tools import SandboxFusionTool @@ -212,11 +212,11 @@ def test_tools_registration(self, mock_env, mock_engine, mock_sampling, sandbox_ assert isinstance(rollout._tool_map["code_interpreter"], SandboxFusionTool) assert rollout._tool_call_parser_type == "qwen25" - @patch.object(AsyncSGLangRollout, "_init_distributed_env", return_value=None) - @patch.object(AsyncSGLangRollout, "_init_inference_engine", return_value=None) - @patch.object(AsyncSGLangRollout, "_init_sampling_params", return_value=None) + @patch.object(SGLangRollout, "_init_distributed_env", return_value=None) + @patch.object(SGLangRollout, "_init_inference_engine", return_value=None) + @patch.object(SGLangRollout, "_init_sampling_params", return_value=None) def test_rollout_req_creation(self, mock_env, mock_engine, mock_sampling, sandbox_fusion_rollout_config, qwen_tokenizer, qwen_model_config, sandbox_data_proto): - rollout = AsyncSGLangRollout(actor_module="", config=sandbox_fusion_rollout_config, tokenizer=qwen_tokenizer, model_hf_config=qwen_model_config) + rollout = SGLangRollout(actor_module="", config=sandbox_fusion_rollout_config, tokenizer=qwen_tokenizer, model_hf_config=qwen_model_config) req_list = rollout._preprocess_prompt_to_async_rollout_requests(sandbox_data_proto, n=1) assert len(req_list) == 1 assert req_list[0].state == AsyncRolloutRequestStateEnum.PENDING @@ -242,12 +242,12 @@ def test_rollout_req_creation(self, mock_env, mock_engine, mock_sampling, sandbo ), ) - @patch.object(AsyncSGLangRollout, "_init_distributed_env", return_value=None) - @patch.object(AsyncSGLangRollout, "_init_inference_engine", return_value=None) - @patch.object(AsyncSGLangRollout, "_init_sampling_params", return_value=None) + @patch.object(SGLangRollout, "_init_distributed_env", return_value=None) + @patch.object(SGLangRollout, "_init_inference_engine", return_value=None) + @patch.object(SGLangRollout, "_init_sampling_params", return_value=None) def test_over_size_case(self, mock_env, mock_engine, mock_sampling, sandbox_fusion_rollout_config, qwen_tokenizer, qwen_model_config, sandbox_data_proto, sandbox_fusion_data): sandbox_fusion_rollout_config.multi_turn.max_turns = 1 - rollout = AsyncSGLangRollout(actor_module="", config=sandbox_fusion_rollout_config, tokenizer=qwen_tokenizer, model_hf_config=qwen_model_config) + rollout = SGLangRollout(actor_module="", config=sandbox_fusion_rollout_config, tokenizer=qwen_tokenizer, model_hf_config=qwen_model_config) req = rollout._preprocess_prompt_to_async_rollout_requests(sandbox_data_proto, n=1)[0] req = MagicMock(wraps=req, spec=AsyncRolloutRequest) req.finalize = MagicMock() @@ -279,12 +279,12 @@ def test_over_size_case(self, mock_env, mock_engine, mock_sampling, sandbox_fusi ) @skip_if_valid_sandbox(sandbox_url) - @patch.object(AsyncSGLangRollout, "_init_distributed_env", return_value=None) - @patch.object(AsyncSGLangRollout, "_init_inference_engine", return_value=None) - @patch.object(AsyncSGLangRollout, "_init_sampling_params", return_value=None) + @patch.object(SGLangRollout, "_init_distributed_env", return_value=None) + @patch.object(SGLangRollout, "_init_inference_engine", return_value=None) + @patch.object(SGLangRollout, "_init_sampling_params", return_value=None) def test_tool_call_basic_case(self, mock_env, mock_engine, mock_sampling, sandbox_fusion_rollout_config, qwen_tokenizer, qwen_model_config, sandbox_data_proto, sandbox_fusion_data): sandbox_fusion_rollout_config.multi_turn.max_turns = 10 - rollout = 
AsyncSGLangRollout(actor_module="", config=sandbox_fusion_rollout_config, tokenizer=qwen_tokenizer, model_hf_config=qwen_model_config) + rollout = SGLangRollout(actor_module="", config=sandbox_fusion_rollout_config, tokenizer=qwen_tokenizer, model_hf_config=qwen_model_config) self._tool_map["code_interpreter"].sandbox_fusion_url = sandbox_url req = rollout._preprocess_prompt_to_async_rollout_requests(sandbox_data_proto, n=1)[0] req = MagicMock(wraps=req, spec=AsyncRolloutRequest) @@ -323,12 +323,12 @@ def test_tool_call_basic_case(self, mock_env, mock_engine, mock_sampling, sandbo assert code_counter == 2 @skip_if_valid_sandbox(sandbox_url) - @patch.object(AsyncSGLangRollout, "_init_distributed_env", return_value=None) - @patch.object(AsyncSGLangRollout, "_init_inference_engine", return_value=None) - @patch.object(AsyncSGLangRollout, "_init_sampling_params", return_value=None) + @patch.object(SGLangRollout, "_init_distributed_env", return_value=None) + @patch.object(SGLangRollout, "_init_inference_engine", return_value=None) + @patch.object(SGLangRollout, "_init_sampling_params", return_value=None) def test_tool_call_batch_case(self, mock_env, mock_engine, mock_sampling, sandbox_fusion_rollout_config, qwen_tokenizer, qwen_model_config, sandbox_data_proto, sandbox_fusion_data): sandbox_fusion_rollout_config.multi_turn.max_turns = 10 - rollout = AsyncSGLangRollout(actor_module="", config=sandbox_fusion_rollout_config, tokenizer=qwen_tokenizer, model_hf_config=qwen_model_config) + rollout = SGLangRollout(actor_module="", config=sandbox_fusion_rollout_config, tokenizer=qwen_tokenizer, model_hf_config=qwen_model_config) self._tool_map["code_interpreter"].sandbox_fusion_url = sandbox_url req = rollout._preprocess_prompt_to_async_rollout_requests(sandbox_data_proto, n=1)[0] req_nums = 100 @@ -357,7 +357,7 @@ async def hacked_handle_engine_call(self, _req: AsyncRolloutRequest, do_sample: re = await result return re - with patch.object(AsyncSGLangRollout, "_handle_engine_call", new=hacked_handle_engine_call): + with patch.object(SGLangRollout, "_handle_engine_call", new=hacked_handle_engine_call): rollout._tp_rank = 0 loop = asyncio.get_event_loop() output_req_list = loop.run_until_complete( diff --git a/tests/workers/rollout/test_sglang_async_rollout_w_tools.py b/tests/workers/rollout/test_sglang_async_rollout_w_tools.py index 58b1f99ce49..c9f5ad68abd 100644 --- a/tests/workers/rollout/test_sglang_async_rollout_w_tools.py +++ b/tests/workers/rollout/test_sglang_async_rollout_w_tools.py @@ -35,8 +35,8 @@ ) from verl import DataProto -from verl.workers.rollout.sglang_rollout.async_sglang_rollout import AsyncSGLangRollout -from verl.workers.sharding_manager.fsdp_sglang import FSDPAsyncSGLangShardingManager +from verl.workers.rollout.sglang_rollout.sglang_rollout import SGLangRollout +from verl.workers.sharding_manager.fsdp_sglang import FSDPSGLangShardingManager def test_async_sglang_rollout_w_tool(): @@ -78,9 +78,9 @@ def test_async_sglang_rollout_w_tool(): ) rollout_config = get_rollout_config(max_response_length, max_prompt_length, dtype, tensor_parallel_size, None) - rollout = AsyncSGLangRollout(actor_module=local_model_path, config=rollout_config, tokenizer=tokenizer, model_hf_config=actor_model.config) + rollout = SGLangRollout(actor_module=local_model_path, config=rollout_config, tokenizer=tokenizer, model_hf_config=actor_model.config) - rollout_sharding_manager = FSDPAsyncSGLangShardingManager( + rollout_sharding_manager = FSDPSGLangShardingManager( module=fsdp_model, 
inference_engine=rollout._engine, model_config=actor_model.config, @@ -111,7 +111,7 @@ def test_async_sglang_rollout_w_tool(): prompts = rollout_sharding_manager.preprocess_data(prompts) # log_gpu_memory_usage("Before generating sequences", logger=None) - output = rollout.generate_sequences_with_tools(prompts=prompts) + output = rollout.generate_sequences(prompts=prompts) print(f"generated {output.batch['responses'].shape=}") # log_gpu_memory_usage("After generating sequences", logger=None) output = rollout_sharding_manager.postprocess_data(output) diff --git a/tests/workers/rollout/test_sglang_async_spmd.py b/tests/workers/rollout/test_sglang_async_spmd.py deleted file mode 100644 index d8187c96949..00000000000 --- a/tests/workers/rollout/test_sglang_async_spmd.py +++ /dev/null @@ -1,113 +0,0 @@ -# Copyright 2023-2024 SGLang Team -# Copyright 2025 ModelBest Inc. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -usage: torchrun --standalone --nnodes=1 \ - --nproc_per_node=2 $(which pytest) \ - -s test_sglang_async_spmd.py -""" - -import asyncio - -import torch -from sglang.srt.entrypoints.engine import Engine -from sglang.srt.utils import broadcast_pyobj -from torch.distributed.device_mesh import init_device_mesh -from utils_sglang import ( - are_lists_similar, - clean_torchelastic_env, - generate_hf_output, - initialize_global_process_group, - load_tokenizer_and_model, - prepare_inputs, -) - - -def _pre_process_inputs(pad_token_id, prompt_token_ids: torch.Tensor): - non_pad_index = torch.nonzero(prompt_token_ids != pad_token_id, as_tuple=False)[0][0] - token_ids = prompt_token_ids[non_pad_index:].tolist() - return token_ids - - -def test_sglang_spmd(): - assert torch.cuda.device_count() >= 2 - initialize_global_process_group(spmd=True) - clean_torchelastic_env() - - max_prompt_length = 16 - max_response_length = 16 - - local_model_path = "Qwen/Qwen2.5-0.5B" - tokenizer, actor_model = load_tokenizer_and_model(local_model_path) - - preencode_prompts = ["Who won the Champions League in 2019?", "The founder of Apple is", "What's your name?"] - input_ids, attention_mask, _ = prepare_inputs(tokenizer, preencode_prompts, max_prompt_length) - - hf_response_tokens = generate_hf_output(actor_model, input_ids, attention_mask, tokenizer, max_response_length) - - tensor_parallel_size = 2 - inference_device_mesh_cpu = init_device_mesh("cpu", mesh_shape=(1, tensor_parallel_size, 1), mesh_dim_names=["dp", "tp", "pp"]) - tp_rank = inference_device_mesh_cpu["tp"].get_local_rank() - - if tp_rank == 0: - llm = Engine( - model_path=local_model_path, - dtype="bfloat16", - mem_fraction_static=0.5, - enable_memory_saver=True, - tp_size=inference_device_mesh_cpu["tp"].size(), - ) - - input_ids = input_ids.cuda() - idx_list = [] - - pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id - for i in range(input_ids.shape[0]): - idx_list.append(_pre_process_inputs(pad_token_id, input_ids[i])) - - sampling_params = dict( - n=1, - temperature=0, - 
top_p=1, - top_k=-1, - max_new_tokens=max_response_length, - presence_penalty=0.0, - frequency_penalty=0.0, - repetition_penalty=1.0, - skip_special_tokens=True, - spaces_between_special_tokens=True, - ignore_eos=False, - ) - - loop = asyncio.get_event_loop() - outputs = loop.run_until_complete(llm.async_generate(input_ids=idx_list, sampling_params=sampling_params)) - else: - outputs = None - - [outputs] = broadcast_pyobj( - [outputs], - rank=inference_device_mesh_cpu["tp"].get_local_rank(), - src=inference_device_mesh_cpu["tp"].mesh[0].item(), - dist_group=inference_device_mesh_cpu["tp"].get_group(), - force_cpu_device=False, - ) - - sglang_response_tokens = [output["text"] for output in outputs] - - print(f"sglang response: {sglang_response_tokens}") - assert are_lists_similar(hf_response_tokens, sglang_response_tokens) - print("SPMD Test Passed!") - - torch.distributed.barrier() - torch.distributed.destroy_process_group() diff --git a/tests/workers/rollout/test_sglang_spmd.py b/tests/workers/rollout/test_sglang_spmd.py index 40c514dd5eb..0ad6445a908 100644 --- a/tests/workers/rollout/test_sglang_spmd.py +++ b/tests/workers/rollout/test_sglang_spmd.py @@ -12,188 +12,102 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import os +""" +usage: torchrun --standalone --nnodes=1 \ + --nproc_per_node=2 $(which pytest) \ + -s test_sglang_spmd.py +""" + +import asyncio import torch -from sglang.srt.entrypoints.verl_engine import VerlEngine +from sglang.srt.entrypoints.engine import Engine +from sglang.srt.utils import broadcast_pyobj from torch.distributed.device_mesh import init_device_mesh -from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig - -from verl.utils.torch_functional import pad_sequence_to_length - - -def levenshtein(s1, s2): - m, n = len(s1), len(s2) - # Initialize matrix of zeros - dp = [[0] * (n + 1) for _ in range(m + 1)] - # Initialize first column and first row of the matrix - for i in range(m + 1): - dp[i][0] = i # Deletion from s1 to empty string - for j in range(n + 1): - dp[0][j] = j # Insertion to s1 from empty string - # Compute the Levenshtein distance matrix - for i in range(1, m + 1): - for j in range(1, n + 1): - cost = 0 if s1[i - 1] == s2[j - 1] else 1 # No cost if characters match - dp[i][j] = min( - dp[i - 1][j] + 1, # Deletion - dp[i][j - 1] + 1, # Insertion - dp[i - 1][j - 1] + cost, # Substitution - ) - return dp[m][n] - - -def are_lists_similar(a, b): - if len(a) != len(b): - print("The lists are of different lengths.") - return False - - total_length = 0 - total_diff = 0 - - for s1, s2 in zip(a, b): - max_len = max(len(s1), len(s2)) - total_length += max_len - diff = levenshtein(s1, s2) - total_diff += diff - print(f"Comparing strings:\n{s1}\n{s2}\nDifference: {diff} characters\n") - - percentage_difference = (total_diff / total_length) * 100 - print(f"Total difference: {percentage_difference:.2f}%") - - return percentage_difference <= 10 - - -def initialize_global_process_group(timeout_second=36000): - from datetime import timedelta - - import torch.distributed +from utils_sglang import ( + are_lists_similar, + clean_torchelastic_env, + generate_hf_output, + initialize_global_process_group, + load_tokenizer_and_model, + prepare_inputs, +) - # NOTE MODIFIED should provide backend=None to have nccl+gloo - # torch.distributed.init_process_group('nccl', timeout=timedelta(seconds=timeout_second)) - 
torch.distributed.init_process_group(timeout=timedelta(seconds=timeout_second)) - local_rank = int(os.environ["LOCAL_RANK"]) - rank = int(os.environ["RANK"]) - world_size = int(os.environ["WORLD_SIZE"]) - - if torch.distributed.is_initialized(): - torch.cuda.set_device(local_rank) - return local_rank, rank, world_size +def _pre_process_inputs(pad_token_id, prompt_token_ids: torch.Tensor): + non_pad_index = torch.nonzero(prompt_token_ids != pad_token_id, as_tuple=False)[0][0] + token_ids = prompt_token_ids[non_pad_index:].tolist() + return token_ids def test_sglang_spmd(): - assert torch.cuda.device_count() >= 2, "At least 2 GPUs is required to run tp+dp tests." - initialize_global_process_group() - # fill rollout config + assert torch.cuda.device_count() >= 2 + initialize_global_process_group(spmd=True) + clean_torchelastic_env() + max_prompt_length = 16 max_response_length = 16 - # Initialize model and token - local_cache_path = "~/.cache/verl/rlhf" - local_cache_path = os.path.expanduser(local_cache_path) - hdfs_path = "Qwen/Qwen2-7B-Instruct" - from verl.utils.fs import copy_to_local - - local_model_path = copy_to_local(src=hdfs_path, cache_dir=local_cache_path) - tokenizer = AutoTokenizer.from_pretrained(local_model_path, padding_side="left") - - preencode_prompts = [ - "Who won the Champions League in 2019?", - "The founder of Apple is", - "What's your name?", - ] - tokenizer.pad_token = tokenizer.eos_token - prompts = tokenizer(preencode_prompts, return_tensors="pt", padding=True) - input_ids = prompts["input_ids"] - attention_mask = prompts["attention_mask"] - - input_ids = pad_sequence_to_length(input_ids, max_prompt_length, tokenizer.pad_token_id, left_pad=True) - attention_mask = pad_sequence_to_length(attention_mask, max_prompt_length, 0, left_pad=True) - - actor_model = AutoModelForCausalLM.from_pretrained(local_model_path) - actor_model.to(torch.bfloat16) - - sampling_params = dict( - n=1, - temperature=0, - top_p=1, - top_k=-1, - max_new_tokens=max_response_length, - presence_penalty=0.0, - frequency_penalty=0.0, - repetition_penalty=1.0, - skip_special_tokens=True, - spaces_between_special_tokens=True, - ignore_eos=False, + local_model_path = "Qwen/Qwen2.5-0.5B" + tokenizer, actor_model = load_tokenizer_and_model(local_model_path) + + preencode_prompts = ["Who won the Champions League in 2019?", "The founder of Apple is", "What's your name?"] + input_ids, attention_mask, _ = prepare_inputs(tokenizer, preencode_prompts, max_prompt_length) + + hf_response_tokens = generate_hf_output(actor_model, input_ids, attention_mask, tokenizer, max_response_length) + + tensor_parallel_size = 2 + inference_device_mesh_cpu = init_device_mesh("cpu", mesh_shape=(1, tensor_parallel_size, 1), mesh_dim_names=["dp", "tp", "pp"]) + tp_rank = inference_device_mesh_cpu["tp"].get_local_rank() + + if tp_rank == 0: + llm = Engine( + model_path=local_model_path, + dtype="bfloat16", + mem_fraction_static=0.5, + enable_memory_saver=True, + tp_size=inference_device_mesh_cpu["tp"].size(), + ) + + input_ids = input_ids.cuda() + idx_list = [] + + pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id + for i in range(input_ids.shape[0]): + idx_list.append(_pre_process_inputs(pad_token_id, input_ids[i])) + + sampling_params = dict( + n=1, + temperature=0, + top_p=1, + top_k=-1, + max_new_tokens=max_response_length, + presence_penalty=0.0, + frequency_penalty=0.0, + repetition_penalty=1.0, + skip_special_tokens=True, + spaces_between_special_tokens=True, + 
ignore_eos=False, + ) + + loop = asyncio.get_event_loop() + outputs = loop.run_until_complete(llm.async_generate(input_ids=idx_list, sampling_params=sampling_params)) + else: + outputs = None + + [outputs] = broadcast_pyobj( + [outputs], + rank=inference_device_mesh_cpu["tp"].get_local_rank(), + src=inference_device_mesh_cpu["tp"].mesh[0].item(), + dist_group=inference_device_mesh_cpu["tp"].get_group(), + force_cpu_device=False, ) - tensor_parallel_size = 4 - device_mesh_kwargs = dict(mesh_shape=(1, tensor_parallel_size, 1), mesh_dim_names=["dp", "tp", "pp"]) - inference_device_mesh_cpu = init_device_mesh("cpu", **device_mesh_kwargs) - - for k in ["TORCHELASTIC_USE_AGENT_STORE"]: - if k in os.environ: - del os.environ[k] - print("building sglang rollout engine") - llm = VerlEngine( - model_path=local_model_path, - dtype="bfloat16", - mem_fraction_static=0.5, - device_mesh_cpu=inference_device_mesh_cpu["tp"], - base_gpu_id=0, - gpu_id_step=1, - ) - - llm.release_memory_occupation() - print("start generation") - input_ids = input_ids.cuda() - attention_mask = attention_mask.cuda() - batch_size = input_ids.size(0) - - generation_config = GenerationConfig(do_sample=False) - actor_model.cuda() - output = actor_model.generate( - input_ids=input_ids, - attention_mask=attention_mask, - max_new_tokens=max_response_length, - # max_length=max_length, - eos_token_id=tokenizer.eos_token_id, - pad_token_id=tokenizer.pad_token_id, - generation_config=generation_config, - # renormalize_logits=True, - output_scores=False, # this is potentially very large - return_dict_in_generate=True, - use_cache=False, - ) # may OOM when use_cache = True - seq = output.sequences - response = seq[:, max_prompt_length:] - - hf_response_tokens = tokenizer.batch_decode(response) - print(f"hf response: {hf_response_tokens}") - print(f"{sampling_params=}") - idx_list = [] - batch_size = input_ids.shape[0] - - pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id - for i in range(batch_size): - idx_list.append(_pre_process_inputs(pad_token_id, input_ids[i])) - - outputs = llm.generate(input_ids=idx_list, sampling_params=sampling_params) - sglang_response_tokens = [] - - for output in outputs: - print(f"{output=}") - generated_text = output["text"] - sglang_response_tokens.append(generated_text) + sglang_response_tokens = [output["text"] for output in outputs] print(f"sglang response: {sglang_response_tokens}") assert are_lists_similar(hf_response_tokens, sglang_response_tokens), "Strings differ more than 10%:\n" - print("Check Pass") - + print("SPMD Test Passed!") -def _pre_process_inputs(pad_token_id, prompt_token_ids: torch.Tensor): - # remove the left padding in the prompt token_id - non_pad_index = torch.nonzero(prompt_token_ids != pad_token_id, as_tuple=False)[0][0] - token_ids = prompt_token_ids[non_pad_index:].tolist() - return token_ids + torch.distributed.barrier() + torch.distributed.destroy_process_group() diff --git a/tests/workers/rollout/utils_sglang.py b/tests/workers/rollout/utils_sglang.py index 5a08a485737..35c43a83a88 100644 --- a/tests/workers/rollout/utils_sglang.py +++ b/tests/workers/rollout/utils_sglang.py @@ -37,7 +37,7 @@ def levenshtein(s1, s2): return dp[m][n] -def are_lists_similar(a, b): +def are_lists_similar(a, b, threshold=10): if len(a) != len(b): print("The lists are of different lengths.") return False @@ -49,7 +49,7 @@ def are_lists_similar(a, b): total_diff += levenshtein(s1, s2) percentage_difference = (total_diff / total_length) * 100 
print(f"Total difference: {percentage_difference:.2f}%") - return percentage_difference <= 10 + return percentage_difference <= threshold def initialize_global_process_group(timeout_second=36000, spmd=False): diff --git a/verl/trainer/config/ppo_megatron_trainer.yaml b/verl/trainer/config/ppo_megatron_trainer.yaml index 4d2d7333c66..e827c68a6d0 100644 --- a/verl/trainer/config/ppo_megatron_trainer.yaml +++ b/verl/trainer/config/ppo_megatron_trainer.yaml @@ -168,7 +168,7 @@ actor_rollout_ref: n: 1 do_sample: False # default eager for validation multi_turn: - enable: False # should set rollout.name to sglang_async if True + enable: False # set to True for multi-turn tool interaction tasks; should set rollout.name to sglang as well max_turns: null # null for no limit (default max_length // 3) tool_config_path: null # null for no tool format: chatml # chatml, more formats will be supported in the future diff --git a/verl/trainer/config/ppo_trainer.yaml b/verl/trainer/config/ppo_trainer.yaml index f43fed0f8bf..474e1303f2f 100644 --- a/verl/trainer/config/ppo_trainer.yaml +++ b/verl/trainer/config/ppo_trainer.yaml @@ -140,7 +140,7 @@ actor_rollout_ref: n: 1 do_sample: False # default eager for validation multi_turn: - enable: False # should set rollout.name to sglang_async if True + enable: False # set to True for multi-turn tool interaction tasks; should set rollout.name to sglang as well max_turns: null # null for no limit (default max_length // 3) tool_config_path: null # null for no tool format: chatml # chatml, more formats will be supported in the future diff --git a/verl/workers/fsdp_workers.py b/verl/workers/fsdp_workers.py index 880bd9bd75f..d65f1b3c26c 100644 --- a/verl/workers/fsdp_workers.py +++ b/verl/workers/fsdp_workers.py @@ -448,86 +448,39 @@ def _build_rollout(self, trust_remote_code=False): ) log_gpu_memory_usage("After building sharding manager", logger=logger) - elif rollout_name == "sglang": - if self.config.rollout.mode == "sync": - from verl.workers.rollout.sglang_rollout import SGLangRollout - - # NOTE(linjunrong): Due to recent fp8 support in SGLang. Now importing any symbol relate to - # SGLang's model_runner would check CUDA device capability. However, due to verl's setting, - # the main process of ray can not find any CUDA device, which would potentially lead to: - # "RuntimeError: No CUDA GPUs are available". - # For this reason, sharding_manager.__init__ should not import FSDPSGLangShardingManager and - # we import it here use the abs path. 
- # check: https://github.com/sgl-project/sglang/blob/00f42707eaddfc2c0528e5b1e0094025c640b7a0/python/sglang/srt/layers/quantization/fp8_utils.py#L76 - from verl.workers.sharding_manager.fsdp_sglang import FSDPSGLangShardingManager - - log_gpu_memory_usage(f"Before building {rollout_name} rollout", logger=logger) - local_path = copy_to_local(self.config.model.path) - rollout = SGLangRollout( - actor_module=local_path, - config=self.config.rollout, - tokenizer=self.tokenizer, - model_hf_config=self.actor_model_config, - trust_remote_code=trust_remote_code, - ) - log_gpu_memory_usage(f"After building {rollout_name} rollout", logger=logger) - - if torch.distributed.get_world_size() == 1: - self.config.rollout.load_format = "dummy_hf" - rollout_sharding_manager = FSDPSGLangShardingManager( - module=self.actor_module_fsdp, - inference_engine=rollout.inference_engine, - model_config=self.actor_model_config, - full_params="hf" in self.config.rollout.load_format, - device_mesh=rollout_device_mesh, - offload_param=self._is_offload_param, - ) - log_gpu_memory_usage("After building sharding manager", logger=logger) - elif self.config.rollout.mode == "async": - from verl.workers.rollout.sglang_rollout import AsyncSGLangRollout - from verl.workers.sharding_manager.fsdp_sglang import FSDPAsyncSGLangShardingManager - - local_path = copy_to_local(self.config.model.path) - log_gpu_memory_usage(f"Before building {rollout_name} rollout", logger=None) - rollout = AsyncSGLangRollout( - actor_module=local_path, - config=self.config.rollout, - tokenizer=self.tokenizer, - model_hf_config=self.actor_model_config, - trust_remote_code=trust_remote_code, - ) - log_gpu_memory_usage(f"After building {rollout_name} rollout", logger=None) - - if torch.distributed.get_world_size() == 1: - self.config.rollout.load_format = "dummy_hf" - rollout_sharding_manager = FSDPAsyncSGLangShardingManager( - module=self.actor_module_fsdp, - inference_engine=rollout._engine, - model_config=self.actor_model_config, - full_params="hf" in self.config.rollout.load_format, - device_mesh=rollout_device_mesh, - offload_param=self._is_offload_param, + elif rollout_name in ["sglang", "sglang_async"]: + if rollout_name == "sglang_async": + warnings.warn( + "'sglang_async' has been deprecated and merged into 'sglang'. " + "Please use 'sglang' going forward.", + DeprecationWarning, + stacklevel=2, ) - log_gpu_memory_usage("After building sharding manager", logger=None) - elif rollout_name == "sglang_async": - # TODO replace by rollout.mode == "async" - from verl.workers.rollout.sglang_rollout import AsyncSGLangRollout - from verl.workers.sharding_manager.fsdp_sglang import FSDPAsyncSGLangShardingManager + from verl.workers.rollout.sglang_rollout import SGLangRollout + # NOTE(linjunrong): Due to recent fp8 support in SGLang. Now importing any symbol relate to + # SGLang's model_runner would check CUDA device capability. However, due to verl's setting, + # the main process of ray can not find any CUDA device, which would potentially lead to: + # "RuntimeError: No CUDA GPUs are available". + # For this reason, sharding_manager.__init__ should not import FSDPSGLangShardingManager and + # we import it here use the abs path. 
+ # check: https://github.com/sgl-project/sglang/blob/00f42707eaddfc2c0528e5b1e0094025c640b7a0/python/sglang/srt/layers/quantization/fp8_utils.py#L76 + + from verl.workers.sharding_manager.fsdp_sglang import FSDPSGLangShardingManager local_path = copy_to_local(self.config.model.path) - log_gpu_memory_usage(f"Before building {rollout_name} rollout", logger=None) - rollout = AsyncSGLangRollout( + log_gpu_memory_usage(f"Before building {rollout_name} rollout", logger=logger) + rollout = SGLangRollout( actor_module=local_path, config=self.config.rollout, tokenizer=self.tokenizer, model_hf_config=self.actor_model_config, trust_remote_code=trust_remote_code, ) - log_gpu_memory_usage(f"After building {rollout_name} rollout", logger=None) + log_gpu_memory_usage(f"After building {rollout_name} rollout", logger=logger) if torch.distributed.get_world_size() == 1: self.config.rollout.load_format = "dummy_hf" - rollout_sharding_manager = FSDPAsyncSGLangShardingManager( + rollout_sharding_manager = FSDPSGLangShardingManager( module=self.actor_module_fsdp, inference_engine=rollout._engine, model_config=self.actor_model_config, @@ -535,7 +488,7 @@ def _build_rollout(self, trust_remote_code=False): device_mesh=rollout_device_mesh, offload_param=self._is_offload_param, ) - log_gpu_memory_usage("After building sharding manager", logger=None) + log_gpu_memory_usage("After building sharding manager", logger=logger) else: raise NotImplementedError(f"Rollout name: {self.config.rollout.name} is not supported") @@ -696,16 +649,8 @@ def generate_sequences(self, prompts: DataProto): log_gpu_memory_usage("After entering rollout sharding manager", logger=logger) prompts = self.rollout_sharding_manager.preprocess_data(prompts) - - if self.config.rollout.name == "sglang_async": - from verl.workers.rollout.sglang_rollout import AsyncSGLangRollout - - if isinstance(self.rollout, AsyncSGLangRollout) and hasattr(self.rollout, "_tool_schemas") and len(self.rollout._tool_schemas) > 0: - output = self.rollout.generate_sequences_with_tools(prompts=prompts) - else: - output = self.rollout.generate_sequences(prompts=prompts) - else: - output = self.rollout.generate_sequences(prompts=prompts) + output = self.rollout.generate_sequences(prompts=prompts) + log_gpu_memory_usage("After rollout generation", logger=logger) output = self.rollout_sharding_manager.postprocess_data(output) diff --git a/verl/workers/megatron_workers.py b/verl/workers/megatron_workers.py index f31d99cdebd..b14f3bb90f2 100644 --- a/verl/workers/megatron_workers.py +++ b/verl/workers/megatron_workers.py @@ -18,6 +18,7 @@ import logging import os import time +import warnings import torch import torch.distributed @@ -258,7 +259,14 @@ def _build_rollout(self, trust_remote_code=False): weight_converter=weight_converter, ) log_gpu_memory_usage("After building sharding manager", logger=logger) - elif self.config.rollout.name == "sglang": + + elif self.config.rollout.name in ["sglang", "sglang_async"]: + if self.config.rollout.name == "sglang_async": + warnings.warn( + "'sglang_async' has been deprecated and merged into 'sglang'. Please use 'sglang' going forward.", + DeprecationWarning, + stacklevel=2, + ) from verl.workers.rollout.sglang_rollout import SGLangRollout # NOTE(linjunrong): Due to recent fp8 support in SGLang. Now importing any symbol relate to SGLang's model_runner would check CUDA device capability. 
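
Both worker backends now accept the old rollout name only as a deprecated alias before building the merged SGLangRollout. A self-contained sketch of that handling; the helper name resolve_rollout_name is illustrative and not part of the patch:

    import warnings

    def resolve_rollout_name(name: str) -> str:
        # "sglang_async" is kept as a spelling-only alias; the actual rollout is
        # always the merged SGLangRollout selected by "sglang".
        if name == "sglang_async":
            warnings.warn(
                "'sglang_async' has been deprecated and merged into 'sglang'. Please use 'sglang' going forward.",
                DeprecationWarning,
                stacklevel=2,
            )
            return "sglang"
        return name

    assert resolve_rollout_name("sglang_async") == "sglang"
    assert resolve_rollout_name("sglang") == "sglang"
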
@@ -276,45 +284,6 @@ def _build_rollout(self, trust_remote_code=False): local_path = copy_to_local(self.config.model.path) log_gpu_memory_usage(f"Before building {self.config.rollout.name} rollout", logger=None) rollout = SGLangRollout( - actor_module=local_path, - config=self.config.rollout, - tokenizer=self.tokenizer, - model_hf_config=self.actor_model_config, - trust_remote_code=self.config.model.get("trust_remote_code", False), - ) - log_gpu_memory_usage(f"After building {self.config.rollout.name} rollout", logger=None) - - from verl.models.mcore import get_mcore_weight_converter - - weight_converter = get_mcore_weight_converter(self.actor_model_config, self.dtype) - sharding_manager = MegatronSGLangShardingManager( - actor_module=self.actor.actor_module, - inference_engine=rollout.inference_engine, - model_config=self.actor_model_config, - transformer_config=self.tf_config, - layer_name_mapping=layer_name_mapping, - weight_converter=weight_converter, - device_mesh=rollout_device_mesh, - ) - log_gpu_memory_usage("After building sharding manager", logger=logger) - elif self.config.rollout.name == "sglang_async": - from verl.workers.rollout.sglang_rollout import AsyncSGLangRollout - - # NOTE(linjunrong): Due to recent fp8 support in SGLang. Now importing any symbol relate to SGLang's model_runner would check CUDA device capability. - # However, due to verl's setting, the main process of ray can not find any CUDA device, which would potentially lead to: - # "RuntimeError: No CUDA GPUs are available". - # For this reason, sharding_manager.__init__ should not import FSDPSGLangShardingManager and we import it here use the abs path. - # check: https://github.com/sgl-project/sglang/blob/00f42707eaddfc2c0528e5b1e0094025c640b7a0/python/sglang/srt/layers/quantization/fp8_utils.py#L76 - from verl.workers.sharding_manager.megatron_sglang import MegatronAsyncSGLangShardingManager - - infer_tp = self.config.rollout.tensor_model_parallel_size - dp = self.world_size // infer_tp - assert self.world_size % infer_tp == 0, f"rollout world_size: {self.world_size} is not divisible by infer_tp: {infer_tp}" - rollout_device_mesh = init_device_mesh("cpu", mesh_shape=(dp, infer_tp, 1), mesh_dim_names=("dp", "tp", "pp")) - - local_path = copy_to_local(self.config.model.path) - log_gpu_memory_usage(f"Before building {self.config.rollout.name} rollout", logger=None) - rollout = AsyncSGLangRollout( actor_module=local_path, config=self.config.rollout, tokenizer=self.tokenizer, @@ -327,7 +296,7 @@ def _build_rollout(self, trust_remote_code=False): from verl.models.mcore import get_mcore_weight_converter weight_converter = get_mcore_weight_converter(self.actor_model_config, self.dtype) - sharding_manager = MegatronAsyncSGLangShardingManager( + sharding_manager = MegatronSGLangShardingManager( actor_module=self.actor.actor_module, inference_engine=rollout._engine, model_config=self.actor_model_config, @@ -491,16 +460,7 @@ def generate_sequences(self, prompts: DataProto): log_gpu_memory_usage("After entering sharding manager", logger=logger) prompts = self.sharding_manager.preprocess_data(prompts) - # output = self.rollout.generate_sequences(prompts=prompts) - if self.config.rollout.name == "sglang_async": - from verl.workers.rollout.sglang_rollout import AsyncSGLangRollout - - if isinstance(self.rollout, AsyncSGLangRollout) and hasattr(self.rollout, "_tool_schemas") and len(self.rollout._tool_schemas) > 0: - output = self.rollout.generate_sequences_with_tools(prompts=prompts) - else: - output = 
self.rollout.generate_sequences(prompts=prompts) - else: - output = self.rollout.generate_sequences(prompts=prompts) + output = self.rollout.generate_sequences(prompts=prompts) output = self.sharding_manager.postprocess_data(output) output = output.to("cpu") diff --git a/verl/workers/rollout/sglang_rollout/__init__.py b/verl/workers/rollout/sglang_rollout/__init__.py index fd00061e3be..43a1eebb4cd 100644 --- a/verl/workers/rollout/sglang_rollout/__init__.py +++ b/verl/workers/rollout/sglang_rollout/__init__.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and -from .async_sglang_rollout import AsyncSGLangRollout from .sglang_rollout import SGLangRollout -__all__ = ["AsyncSGLangRollout", "SGLangRollout"] +__all__ = ["SGLangRollout"] diff --git a/verl/workers/rollout/sglang_rollout/async_sglang_rollout.py b/verl/workers/rollout/sglang_rollout/async_sglang_rollout.py deleted file mode 100644 index 985268213f7..00000000000 --- a/verl/workers/rollout/sglang_rollout/async_sglang_rollout.py +++ /dev/null @@ -1,876 +0,0 @@ -# Copyright 2023-2024 SGLang Team -# Copyright 2025 ModelBest Inc. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
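
With async_sglang_rollout.py deleted and AsyncSGLangRollout dropped from __all__, downstream imports need a one-line migration. A sketch of the before/after, under the assumption (implied by the worker changes above, which now call generate_sequences unconditionally) that the merged SGLangRollout handles the tool-calling path internally:

    # Before this patch (now raises ImportError):
    #   from verl.workers.rollout.sglang_rollout import AsyncSGLangRollout
    #   output = rollout.generate_sequences_with_tools(prompts=prompts)

    # After this patch: one class, one entry point. The workers above call
    # generate_sequences() unconditionally, so tool dispatch happens inside it.
    from verl.workers.rollout.sglang_rollout import SGLangRollout  # noqa: F401

Keeping a single entry point removes the name-based branching that previously lived in fsdp_workers.py and megatron_workers.py.
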
-from __future__ import annotations - -import asyncio -import logging -import os -import time -from contextlib import contextmanager -from copy import deepcopy -from json import JSONDecodeError -from typing import TYPE_CHECKING, Union -from uuid import uuid4 - -import numpy as np -import torch -import torch.distributed as dist -from omegaconf import DictConfig -from sglang.srt.entrypoints.engine import Engine -from sglang.srt.function_call_parser import FunctionCallParser -from sglang.srt.openai_api.protocol import Tool -from sglang.srt.sampling.sampling_params import SamplingParams -from sglang.srt.utils import get_ip, get_open_port -from tensordict import TensorDict -from torch.distributed.device_mesh import DeviceMesh, init_device_mesh -from torch.nn.utils.rnn import pad_sequence -from transformers import PreTrainedTokenizer - -from verl import DataProto -from verl.third_party.sglang import parallel_state as sglang_ps -from verl.tools.base_tool import BaseTool -from verl.tools.schemas import OpenAIFunctionCallSchema, OpenAIFunctionParsedSchema, OpenAIFunctionToolCall -from verl.utils.debug import GPUMemoryLogger -from verl.utils.model import compute_position_id_with_mask -from verl.utils.net_utils import is_ipv6 -from verl.utils.torch_functional import get_response_mask, pad_sequence_to_length -from verl.workers.rollout.base import BaseRollout -from verl.workers.rollout.schemas import ( - AsyncRolloutRequest, - AsyncRolloutRequestStateEnum, - FinishReasonTypeEnum, - Message, -) -from verl.workers.rollout.sglang_rollout.sglang_rollout import _post_process_outputs, _pre_process_inputs -from verl.workers.rollout.sglang_rollout.utils import broadcast_pyobj - -if TYPE_CHECKING: - from torch import nn - -logger = logging.getLogger(__file__) -logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) - - -def get_tool_call_parser_type(tokenizer: PreTrainedTokenizer) -> str: - for parser_type, parser_cls in FunctionCallParser.ToolCallParserEnum.items(): - parser = parser_cls() - if parser.bot_token in tokenizer.get_vocab() and (parser.eot_token == "" or parser.eot_token in tokenizer.get_vocab()): - return parser_type - else: - raise ValueError(f"No tool call parser found for tokenizer {tokenizer}") - - -class AsyncSGLangRollout(BaseRollout): - def __init__( - self, - actor_module: nn.Module | str, - config: DictConfig, - tokenizer, - model_hf_config, - port=None, - trust_remote_code: bool = False, - device_mesh: DeviceMesh | None = None, - **kwargs, - ): - """A SGLang rollout. It requires the module is supported by the SGLang. 
- - Args: - actor_module: module here follows huggingface APIs - config: DictConfig - tokenizer: the task/model tokenizer - model_hf_config: the huggingface config to initiallize the generating model in SGLang - **kwargs: train_tp, for Megatron Backend to initialize hybrid engine (zero redundancy) process group - """ - super().__init__() - self.config = config - - self._tool_schemas, self._tool_map, self._tool_call_parser_type, self._sgl_tools, self._function_call_parser = self._initialize_tools(config, tokenizer) - assert not (not config.enforce_eager and config.free_cache_engine), "disable CUDA graph (enforce_eager = False) if free cache engine" - logger.info(f"tool_schemas: {self._tool_schemas}, tool_map: {self._tool_map}, tool_call_parser_type: {self._tool_call_parser_type}, sgl_tools: {self._sgl_tools}, function_call_parser: {self._function_call_parser}") - - self._init_distributed_env(device_mesh_cpu=device_mesh, **kwargs) - - self._verify_config(model_hf_config=model_hf_config) - # initialize the inference engine - self._init_inference_engine(trust_remote_code, actor_module, port) - - self._init_sampling_params(**kwargs) - - self.tokenizer = tokenizer - self.pad_token_id = tokenizer.pad_token_id - - def _init_distributed_env(self, device_mesh_cpu, **kwargs): - self._device_mesh_cpu = device_mesh_cpu - os.environ.setdefault("SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK", "true") - self.tensor_parallel_size = self.config.get("tensor_model_parallel_size", 1) - assert self.tensor_parallel_size <= dist.get_world_size(), "tensor parallel size should be less than or equal to the world size" - self.train_tp = kwargs.get("train_tp", None) - if self.train_tp is not None: - # deployed with megatron - os.environ["CUDA_TIMER_STREAM_KAFKA_ENABLE"] = "0" - os.environ["MEGATRON_IMPORT_TIMERS"] = "0" - train_tp = kwargs.get("train_tp", None) - num_tp_per_train_tp = train_tp // self.tensor_parallel_size - sglang_ps.initialize_parallel_state( - tensor_model_parallel_size=self.tensor_parallel_size, - num_tp_per_train_tp=num_tp_per_train_tp, - ) - - tp_size = self.tensor_parallel_size - world_size = int(os.getenv("WORLD_SIZE", "-1")) - - # init device mesh - if self._device_mesh_cpu is None: - device_mesh_kwargs = dict( - mesh_shape=(world_size // tp_size, tp_size, 1), - mesh_dim_names=["dp", "tp", "pp"], - ) - - self._device_mesh_cpu = init_device_mesh("cpu", **device_mesh_kwargs) - - self._rank = self._device_mesh_cpu.get_rank() - self._tp_rank = self._device_mesh_cpu["tp"].get_local_rank() - self._tp_size = self._device_mesh_cpu["tp"].size() - if self._rank == 0: - logger.info(f"_init_distributed_env: :tp_world: {self._tp_size}, global_world: {world_size}") - # get tp_rank of this process in this tp group - visible_devices = [None] * self._device_mesh_cpu.size(1) - - torch.distributed.all_gather_object(visible_devices, os.environ["CUDA_VISIBLE_DEVICES"], self._device_mesh_cpu.get_group("tp")) - self.visible_devices_set = set(",".join(visible_devices).split(",")) - os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(sorted(list(self.visible_devices_set))) - - def _verify_config(self, model_hf_config): - if not self.config.get("max_model_len", None): - self.config.max_model_len = self.config.prompt_length + self.config.response_length - assert self.config.max_model_len >= self.config.prompt_length + self.config.response_length, f"""max_model_len should be greater than total sequence length (prompt_length + response_length): - {self.config.max_model_len} >= {self.config.prompt_length} + 
{self.config.response_length}""" - assert model_hf_config.max_position_embeddings >= self.config.max_model_len, "model context length should be greater than total sequence length" - # currently max_turns stand for max number of tool calls - if self.config.multi_turn.max_turns is None: - self.config.multi_turn.max_turns = self.config.max_model_len // 3 - - def _init_inference_engine(self, trust_remote_code, actor_module, port): - # initialize the inference engine - nnodes = -(-self._tp_size // len(self.visible_devices_set)) - if nnodes > 1: - ip = get_ip() - port = get_open_port() if port is None else port - [ip, port] = broadcast_pyobj( - [ip, port], - rank=self._rank, - dist_group=self._device_mesh_cpu.get_group("tp"), - src=self._device_mesh_cpu["tp"].mesh[0].item(), - force_cpu_device=False, - ) - dist_init_addr = f"[{ip}]:{port}" if is_ipv6(ip) else f"{ip}:{port}" - else: - dist_init_addr = None - - load_format = "dummy" if self.config.load_format.startswith("dummy") else self.config.load_format - tp_size_per_node = self._tp_size // nnodes - node_rank = self._tp_rank // tp_size_per_node - first_rank_in_node = self._tp_rank % tp_size_per_node == 0 - - if first_rank_in_node: - rank = dist.get_rank() - os.environ["SGLANG_BLOCK_NONZERO_RANK_CHILDREN"] = "0" - self._engine = Engine( - model_path=actor_module, - dtype=self.config.dtype, - mem_fraction_static=self.config.gpu_memory_utilization, - enable_memory_saver=True, - base_gpu_id=0, - gpu_id_step=1, - tp_size=self._tp_size, - node_rank=node_rank, - load_format=load_format, - dist_init_addr=dist_init_addr, - nnodes=nnodes, - trust_remote_code=trust_remote_code, - # NOTE(linjunrong): add rank to prevent SGLang generate same port inside PortArgs.init_new - # when random.seed is being set during training - port=30000 + rank, - # NOTE(Chenyang): if you want to debug the SGLang engine output - # please set the following parameters - # Otherwise, it will make the engine run too slow - # log_level="INFO", - # log_requests=True, - # log_requests_level=2, - # max_running_requests=1, - ) - else: - self._engine = None - - self.sharding_manager = None - # offload - if self._tp_rank == 0: - self._engine.release_memory_occupation() - self.is_sleep = True - - def _init_sampling_params(self, **kwargs): - kwargs = dict( - n=1, - max_new_tokens=self.config.response_length, - presence_penalty=0.0, - frequency_penalty=0.0, - repetition_penalty=1.0, - ) - # supporting adding any sampling params from the config file - for k in self.config.keys(): - if hasattr(SamplingParams(), str(k)): - kwargs[k] = self.config.get(k) - self.sampling_params = kwargs - - def _initialize_tools(self, config, tokenizer): - """Initialize tools from configuration. 
- - Args: - config: Configuration object containing tool settings - tokenizer: Tokenizer instance for tool call parsing - - Returns: - tuple: (tool_schemas, tool_map, tool_call_parser_type, sgl_tools, function_call_parser) - """ - if config.multi_turn.tool_config_path is None: - return [], {}, None, [], None - - import importlib.util - import sys - - from omegaconf import OmegaConf - - from verl.tools.schemas import OpenAIFunctionToolSchema - - def initialize_tools_from_config(tools_config) -> list: - tool_list = [] - - for tool_config in tools_config.tools: - cls_name = tool_config.class_name - module_name, class_name = cls_name.rsplit(".", 1) - - if module_name not in sys.modules: - spec = importlib.util.find_spec(module_name) - module = importlib.util.module_from_spec(spec) - sys.modules[module_name] = module - spec.loader.exec_module(module) - else: - module = sys.modules[module_name] - - tool_cls = getattr(module, class_name) - - tool_schema_dict = OmegaConf.to_container(tool_config.tool_schema, resolve=True) - tool_schema = OpenAIFunctionToolSchema.parse_obj(tool_schema_dict) - - tool = tool_cls(config=OmegaConf.to_container(tool_config.config, resolve=True), tool_schema=tool_schema) - tool_list.append(tool) - - return tool_list - - tools_config_file = config.multi_turn.tool_config_path - tools_config = OmegaConf.load(tools_config_file) - tool_list = initialize_tools_from_config(tools_config) - logger.info(f"Initialize tools from configuration.: tool_list: {tool_list}") - tool_schemas = [tool.get_openai_tool_schema().model_dump() for tool in tool_list] - tool_map = {tool.name: tool for tool in tool_list} - tool_call_parser_type = get_tool_call_parser_type(tokenizer) - sgl_tools = [Tool.model_validate(tool_schema) for tool_schema in tool_schemas] - function_call_parser = FunctionCallParser( - sgl_tools, - tool_call_parser_type, - ) - - return tool_schemas, tool_map, tool_call_parser_type, sgl_tools, function_call_parser - - @contextmanager - def update_sampling_params(self, **kwargs): - # update sampling params - old_sampling_params_args = {} - if kwargs: - for key, value in kwargs.items(): - if key in self.sampling_params: - old_value = self.sampling_params[key] - old_sampling_params_args[key] = old_value - self.sampling_params[key] = value - yield - # roll back to previous sampling params - # if len(old_sampling_params_args): - for key, value in old_sampling_params_args.items(): - self.sampling_params[key] = value - - @GPUMemoryLogger(role="sglang async rollout", logger=logger) - @torch.no_grad() - def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto: - # if self.config.free_cache_engine: - - idx = prompts.batch["input_ids"] # (bs, prompt_length) - # left-padded attention_mask - attention_mask = prompts.batch["attention_mask"] - position_ids = prompts.batch["position_ids"] - - # used to construct attention_mask - eos_token_id = prompts.meta_info["eos_token_id"] - - batch_size = idx.size(0) - - # Extract non-tensor data - non_tensor_batch = prompts.non_tensor_batch - if "raw_prompt_ids" not in non_tensor_batch: - non_tensor_batch["raw_prompt_ids"] = np.array([_pre_process_inputs(self.pad_token_id, idx[i]) for i in range(batch_size)], dtype=object) - - if "multi_modal_data" in non_tensor_batch: - sglang_inputs = [] - for raw_prompt_ids, multi_modal_data in zip(non_tensor_batch.pop("raw_prompt_ids"), non_tensor_batch.pop("multi_modal_data")): - sglang_inputs.append( - { - "prompt_token_ids": raw_prompt_ids, - "multi_modal_data": multi_modal_data, - "image_data": 
multi_modal_data.get("image", None) if isinstance(multi_modal_data, dict) else None, - } - ) - else: - sglang_inputs = [{"prompt_token_ids": raw_prompt_ids} for raw_prompt_ids in non_tensor_batch.pop("raw_prompt_ids")] - - # Ensure token IDs are lists - for input_data in sglang_inputs: - if isinstance(input_data["prompt_token_ids"], np.ndarray): - input_data["prompt_token_ids"] = input_data["prompt_token_ids"].tolist() - elif not isinstance(input_data["prompt_token_ids"], list): - raise TypeError(f"prompt_token_ids must be a list or numpy array, got {type(input_data['prompt_token_ids'])}") - - # Extract token IDs and image data for SGLang Engine - idx_list = [input_data["prompt_token_ids"] for input_data in sglang_inputs] - image_list = [input_data.get("image_data", None) for input_data in sglang_inputs] - - do_sample = prompts.meta_info.get("do_sample", True) - is_validate = prompts.meta_info.get("validate", False) - if not do_sample: - kwargs = dict( - n=1, - presence_penalty=0.0, - frequency_penalty=0.0, - repetition_penalty=1.0, - temperature=0, - top_p=1, - top_k=-1, - ignore_eos=False, - min_new_tokens=0, - max_new_tokens=self.config.response_length, - skip_special_tokens=True, - spaces_between_special_tokens=True, - ) - elif is_validate: - kwargs = dict( - top_k=self.config.val_kwargs.top_k, - top_p=self.config.val_kwargs.top_p, - temperature=self.config.val_kwargs.temperature, - n=1, # if validate, already repeat in ray_trainer - ) - - # users can customize different sampling_params at different run - with self.update_sampling_params(**kwargs): - # print(f"{self.sampling_params=}") - if self._tp_rank == 0: - loop = asyncio.get_event_loop() - output = loop.run_until_complete( - self._engine.async_generate( - prompt=None, # because we have already convert it to prompt token id - sampling_params=self.sampling_params, - return_logprob=True, - input_ids=idx_list, - image_data=image_list, - ) - ) - else: - output = None - # Most naive implementation, can extract tensor and send via gloo if too slow - [output] = broadcast_pyobj( - data=[output], - rank=self._rank, - dist_group=self._device_mesh_cpu["tp"].get_group(), - src=self._device_mesh_cpu["tp"].mesh[0].item(), - force_cpu_device=False, - ) - out = _post_process_outputs(self.tokenizer, output) - - response = out[0].to(idx.device) - rollout_log_probs = out[1].to(idx.device) - - if response.shape[1] < self.config.response_length: - response = pad_sequence_to_length(response, self.config.response_length, self.pad_token_id) - rollout_log_probs = pad_sequence_to_length(rollout_log_probs, self.config.response_length, self.pad_token_id) - - # utilize current sampling params - if self.sampling_params.get("n", 1) > 1 and do_sample: - idx = idx.repeat_interleave(self.sampling_params["n"], dim=0) - attention_mask = attention_mask.repeat_interleave(self.sampling_params["n"], dim=0) - position_ids = position_ids.repeat_interleave(self.sampling_params["n"], dim=0) - batch_size = batch_size * self.sampling_params["n"] - _non_tensor_batch = {} - for key, val in non_tensor_batch.items(): - _non_tensor_batch[key] = np.repeat(val, self.sampling_params["n"], axis=0) - else: - _non_tensor_batch = non_tensor_batch - seq = torch.cat([idx, response], dim=-1) - - response_length = response.size(1) - delta_position_id = torch.arange(1, response_length + 1, device=position_ids.device) - delta_position_id = delta_position_id.unsqueeze(0).repeat(batch_size, 1) - - # TODO(sgm): fix position_ids on right_pad - # prompt: left pad + response: right pad - # 
attention_mask: [0,0,0,0,1,1,1,1, | 1,1,1,0,0,0,0,0] - # position_ids: [0,0,0,0,0,1,2,3, | 4,5,6,7,8,9,10,11] - response_position_ids = position_ids[:, -1:] + delta_position_id - position_ids = torch.cat([position_ids, response_position_ids], dim=-1) - response_attention_mask = get_response_mask(response_id=response, eos_token=eos_token_id, dtype=attention_mask.dtype) - attention_mask = torch.cat((attention_mask, response_attention_mask), dim=-1) - - # all the tp ranks should contain the same data here. data in all ranks are valid - batch = TensorDict( - { - "prompts": idx, - "responses": response, - "input_ids": seq, # here input_ids become the whole sentences - "rollout_log_probs": rollout_log_probs, # we will recompute old log prob with actor - "attention_mask": attention_mask, - "position_ids": position_ids, - }, - batch_size=batch_size, - ) - - # free cache engine - if self.config.free_cache_engine and self._engine is not None: - self._engine.flush_cache() - - return DataProto(batch=batch, non_tensor_batch=_non_tensor_batch) - - async def _async_rollout_a_request(self, req: AsyncRolloutRequest, do_sample: bool = True, is_validate: bool = False, **kwargs) -> AsyncRolloutRequest: - assert self._tp_rank == 0, "only the master process can call this function" - _req = deepcopy(req) - finish_reason_type = None - output = None - - current_turns = 0 - while current_turns < self.config.multi_turn.max_turns: - if _req.state == AsyncRolloutRequestStateEnum.PENDING: - await self._handle_pending_state(_req) - _req.state = AsyncRolloutRequestStateEnum.RUNNING - elif _req.state == AsyncRolloutRequestStateEnum.TOOL_CALLING: - if _req.messages[-1].tool_calls is not None: - parsed_tool_calls = _req.messages[-1].tool_calls - tool_call_results = await asyncio.gather( - *[ - self._tool_map[tool_call.function.name].execute( - _req.request_id, - tool_call.function.arguments, - **_req.tools_kwargs[tool_call.function.name].get("execute_kwargs", {}), - ) - for tool_call in parsed_tool_calls - ] - ) - for i, (tool_call, (resp, reward, metrics)) in enumerate(zip(parsed_tool_calls, tool_call_results)): - _req.add_tool_response_message(self.tokenizer, resp, (i == len(parsed_tool_calls) - 1), format=self.config.multi_turn.format) - _req.update_metrics(metrics, tool_call.function.name) - if len(_req.input_ids) >= self.config.max_model_len: - break - if len(_req.input_ids) >= self.config.max_model_len: - finish_reason_type = FinishReasonTypeEnum.STOP - break - _req.state = AsyncRolloutRequestStateEnum.RUNNING - else: - raise ValueError(f"Unexpected tool calling last message state: {_req.messages[-1]}") - elif _req.state == AsyncRolloutRequestStateEnum.RUNNING: - output = await self._handle_engine_call(_req, do_sample, is_validate, **kwargs) - content = output["text"] - finish_reason_type = FinishReasonTypeEnum.from_str(output["meta_info"]["finish_reason"]["type"]) - current_turns += 1 - if finish_reason_type == FinishReasonTypeEnum.LENGTH: - _req.add_assistant_message(self.tokenizer, content, already_over_long=True, format=self.config.multi_turn.format) - break - else: - if self._function_call_parser and self._function_call_parser.has_tool_call(content): - finish_reason_type = FinishReasonTypeEnum.TOOL_CALL - _req.state = AsyncRolloutRequestStateEnum.TOOL_CALLING - try: - normed_content, tool_calls = self._function_call_parser.parse_non_stream(content) - except JSONDecodeError: - normed_content = content - tool_calls = [] - except AttributeError: - normed_content = content - tool_calls = [] - parsed_tool_calls = [] - 
for tool_call in tool_calls: - function, has_decode_error = OpenAIFunctionCallSchema.from_openai_function_parsed_schema(OpenAIFunctionParsedSchema(name=tool_call.name, arguments=tool_call.parameters)) - # Drop the tool call if its arguments has decode error - if has_decode_error: - continue - parsed_tool_calls.append( - OpenAIFunctionToolCall( - id=str(tool_call.tool_index), - function=function, - ) - ) - if len(parsed_tool_calls) > 0: - _req.add_assistant_message( - self.tokenizer, - normed_content, - tool_calls=parsed_tool_calls, - format=self.config.multi_turn.format, - ) - else: - _req.add_assistant_message(self.tokenizer, content, format=self.config.multi_turn.format) - finish_reason_type = FinishReasonTypeEnum.STOP - _req.state = AsyncRolloutRequestStateEnum.COMPLETED - break - else: - _req.add_assistant_message(self.tokenizer, content, format=self.config.multi_turn.format) - break - - if current_turns >= self.config.multi_turn.max_turns: - finish_reason_type = FinishReasonTypeEnum.STOP - - # Calculate the reward for each tool - async def calc_reward_and_release_fn(name: str, tool: BaseTool): - reward = await tool.calc_reward(_req.request_id, **_req.tools_kwargs[name].get("calc_reward_kwargs", {})) - await tool.release(_req.request_id, **_req.tools_kwargs[name].get("release_kwargs", {})) - return name, reward - - tool_reward_tasks = [] - for name in _req.tools_kwargs.keys(): - tool = self._tool_map[name] - tool_reward_tasks.append(calc_reward_and_release_fn(name, tool)) - tool_reward_scores = await asyncio.gather(*tool_reward_tasks) - tool_reward_scores = dict(tool_reward_scores) - _req.finalize(self.tokenizer, tool_reward_scores, finish_reason_type) - - return _req - - async def _handle_engine_call(self, _req: AsyncRolloutRequest, do_sample: bool, is_validate: bool, **kwargs) -> dict: - generation_prompt_ids = _req.get_generation_prompt(self.tokenizer) - max_new_tokens = min(self.config.response_length, self.config.max_model_len - len(generation_prompt_ids) - 1) - if not do_sample: - kwargs = dict( - n=1, - presence_penalty=0.0, - frequency_penalty=0.0, - repetition_penalty=1.0, - temperature=0, - top_p=1, - top_k=-1, - ignore_eos=False, - min_new_tokens=0, - max_new_tokens=self.config.response_length, - skip_special_tokens=True, - spaces_between_special_tokens=True, - ) - elif is_validate: - # TODO: try ** - kwargs = { - "top_k": self.config.val_kwargs.top_k, - "top_p": self.config.val_kwargs.top_p, - "temperature": self.config.val_kwargs.temperature, - "n": 1, # if validate, already repeat in ray_trainer - } - kwargs["max_new_tokens"] = max_new_tokens - if "n" not in kwargs or kwargs["n"] > 1: # group size is supported in preprocess - kwargs["n"] = 1 - # users can customize different sampling_params at different run - with self.update_sampling_params(**kwargs): - output = await self._engine.async_generate( - input_ids=generation_prompt_ids, - sampling_params=self.sampling_params, - return_logprob=False, - ) - return output - - async def _handle_pending_state(self, _req: AsyncRolloutRequest) -> AsyncRolloutRequest: - if _req.tools is not None: - tool_creation_coroutines = [] - for tool_schema in _req.tools: - tool = self._tool_map[tool_schema.function.name] - create_kwargs = _req.tools_kwargs[tool.name].get("create_kwargs", {}) - tool_creation_coroutines.append(tool.create(_req.request_id, **create_kwargs)) - await asyncio.gather(*tool_creation_coroutines) - - @GPUMemoryLogger(role="sglang async rollout", logger=logger) - @torch.no_grad() - def generate_sequences_with_tools(self, 
prompts: DataProto, **kwargs) -> DataProto: - # Async rollout with tools support - do_sample = prompts.meta_info.get("do_sample", True) - is_validate = prompts.meta_info.get("validate", False) - tgt_device = prompts.batch["input_ids"].device - if self._tp_rank == 0: - req_list = self._preprocess_prompt_to_async_rollout_requests( - prompts, - n=1 if is_validate else self.config.n, - ) - loop = asyncio.get_event_loop() - output_req_list = loop.run_until_complete( - asyncio.gather( - *[self._async_rollout_a_request(req, do_sample, is_validate, **kwargs) for req in req_list], - ) - ) - sorted_output_req_list = sorted(output_req_list, key=lambda x: (x.batch_data_id, x.rollout_offset)) - else: - sorted_output_req_list = None - - [sorted_output_req_list] = broadcast_pyobj( - data=[sorted_output_req_list], - rank=self._rank, - dist_group=self._device_mesh_cpu["tp"].get_group(), - src=self._device_mesh_cpu["tp"].mesh[0].item(), - force_cpu_device=False, - ) - # Construct the batch data - prompt_ids, response_ids = [], [] - prompt_attention_mask, response_attention_mask = [], [] - prompt_position_ids, response_position_ids = [], [] - prompt_loss_mask, response_loss_mask = [], [] - messages = [] - reward_scores = [] - for req in sorted_output_req_list: - assert req.state == AsyncRolloutRequestStateEnum.COMPLETED, f"Request {req.request_id} is not completed" - assert len(req.input_ids) == len(req.attention_mask) == len(req.position_ids) == len(req.loss_mask), f"""Request {req.request_id} has different length of - {len(req.input_ids)=}, {len(req.attention_mask)=}, {len(req.position_ids)=}, {len(req.loss_mask)=}""" - error_message_lines = [ - f"""Request {req.request_id} has input_ids length {len(req.input_ids)} - greater than max_model_len {self.config.max_model_len}""", - f"Decoded input_ids: {self.tokenizer.decode(req.input_ids)}", - f"Decoded prompt_ids: {self.tokenizer.decode(req.prompt_ids)}", - f"Decoded response_ids: {self.tokenizer.decode(req.response_ids)}", - f"Messages: {req.messages}", - f"Max model length: {req.max_model_len}", - ] - error_message = "\n".join(error_message_lines) - assert len(req.input_ids) <= self.config.max_model_len, error_message - - prompt_ids.append(torch.tensor(req.prompt_ids, dtype=torch.int, device=tgt_device)) - response_ids.append(torch.tensor(req.response_ids, dtype=torch.int, device=tgt_device)) - if len(req.response_ids) > self.config.response_length: - logger.warning( - f"""{req.request_id=} has response_ids length {len(req.response_ids)} - greater than max_response_len {self.config.response_length},\n{req=}""" - ) - prompt_attention_mask.append(torch.tensor(req.prompt_attention_mask, dtype=torch.int, device=tgt_device)) - response_attention_mask.append(torch.tensor(req.response_attention_mask, dtype=torch.int, device=tgt_device)) - prompt_position_ids.append(torch.tensor(req.prompt_position_ids, dtype=torch.int, device=tgt_device)) - response_position_ids.append(torch.tensor(req.response_position_ids, dtype=torch.int, device=tgt_device)) - prompt_loss_mask.append(torch.tensor(req.prompt_loss_mask, dtype=torch.int, device=tgt_device)) - response_loss_mask.append(torch.tensor(req.response_loss_mask, dtype=torch.int, device=tgt_device)) - messages.append({"messages": req.messages}) - reward_scores.append(req.reward_scores) - - prompt_ids = pad_sequence(prompt_ids, batch_first=True, padding_value=self.pad_token_id, padding_side="left") - if prompt_ids.shape[1] < self.config.prompt_length: - prompt_ids = pad_sequence_to_length(prompt_ids, 
self.config.prompt_length, self.pad_token_id, left_pad=True) - response_ids = pad_sequence(response_ids, batch_first=True, padding_value=self.pad_token_id) - if response_ids.shape[1] < self.config.response_length: - response_ids = pad_sequence_to_length(response_ids, self.config.response_length, self.pad_token_id) - prompt_attention_mask = pad_sequence(prompt_attention_mask, batch_first=True, padding_value=0, padding_side="left") - if prompt_attention_mask.shape[1] < self.config.prompt_length: - prompt_attention_mask = pad_sequence_to_length(prompt_attention_mask, self.config.prompt_length, 0, left_pad=True) - response_attention_mask = pad_sequence(response_attention_mask, batch_first=True, padding_value=0) - if response_attention_mask.shape[1] < self.config.response_length: - response_attention_mask = pad_sequence_to_length(response_attention_mask, self.config.response_length, 0) - prompt_position_ids = pad_sequence(prompt_position_ids, batch_first=True, padding_value=0, padding_side="left") - if prompt_position_ids.shape[1] < self.config.prompt_length: - prompt_position_ids = pad_sequence_to_length(prompt_position_ids, self.config.prompt_length, 0, left_pad=True) - response_length = response_ids.size(1) - delta_position_id = torch.arange(1, response_length + 1, device=response_ids.device) - delta_position_id = delta_position_id.unsqueeze(0).repeat(len(sorted_output_req_list), 1) - response_position_ids = prompt_position_ids[:, -1:] + delta_position_id - prompt_loss_mask = pad_sequence(prompt_loss_mask, batch_first=True, padding_value=0, padding_side="left") - if prompt_loss_mask.shape[1] < self.config.prompt_length: - prompt_loss_mask = pad_sequence_to_length(prompt_loss_mask, self.config.prompt_length, 0, left_pad=True) - response_loss_mask = pad_sequence(response_loss_mask, batch_first=True, padding_value=0) - if response_loss_mask.shape[1] < self.config.response_length: - response_loss_mask = pad_sequence_to_length(response_loss_mask, self.config.response_length, 0) - - input_ids = torch.cat((prompt_ids, response_ids), dim=-1) - attention_mask = torch.cat((prompt_attention_mask, response_attention_mask), dim=-1) - position_ids = torch.cat((prompt_position_ids, response_position_ids), dim=-1) - loss_mask = torch.cat((prompt_loss_mask, response_loss_mask), dim=-1) - - # Construct the batch data - batch = TensorDict( - { - "prompts": prompt_ids, - "responses": response_ids, - "input_ids": input_ids, # here input_ids become the whole sentences - "attention_mask": attention_mask, - "position_ids": position_ids, - "loss_mask": loss_mask, - }, - batch_size=len(sorted_output_req_list), - ) - - # free cache engine - if self.config.free_cache_engine and self._engine is not None and self._tp_rank == 0: - self._engine.flush_cache() - - return DataProto(batch=batch, non_tensor_batch={"messages": np.array(messages), "reward_scores": np.array(reward_scores)}) - - def _preprocess_prompt_to_async_rollout_requests(self, prompts: DataProto, n: int) -> list[AsyncRolloutRequest]: - assert "raw_prompt" in prompts.non_tensor_batch, "need data.return_raw_chat=True, due to no official way do parse_messages" - req_list = [] - for data_idx, raw_prompt in enumerate(prompts.non_tensor_batch["raw_prompt"]): - for rollout_offset in range(n): - if self._tool_schemas: - _tools_kwargs = prompts.non_tensor_batch["tools_kwargs"][data_idx] - _tool_schemas = [] - for k in _tools_kwargs.keys(): - _tool_schemas.append(self._tool_map[k].get_openai_tool_schema()) - prompt_with_chat_template = 
self.tokenizer.apply_chat_template( - conversation=raw_prompt, - tools=[tool.model_dump() for tool in _tool_schemas], - add_generation_prompt=True, - tokenize=False, - return_tensors="pt", - ) - input_data = self.tokenizer(prompt_with_chat_template, return_tensors="pt", add_special_tokens=False) - _input_ids = input_data["input_ids"][0].tolist() - _attention_mask = input_data["attention_mask"][0].tolist() - _position_ids = compute_position_id_with_mask(input_data["attention_mask"][0]).tolist() - if len(_input_ids) > self.config.prompt_length: - logger.warning( - "Prompt {} has length {} greater than max_prompt_len {}", - data_idx, - len(_input_ids), - self.config.prompt_length, - ) - _input_ids = _input_ids[: self.config.prompt_length] - _attention_mask = _attention_mask[: self.config.prompt_length] - _position_ids = _position_ids[: self.config.prompt_length] - else: - _input_ids = _pre_process_inputs(self.pad_token_id, prompts.batch["input_ids"][data_idx]) - _attention_mask = _pre_process_inputs(0, prompts.batch["attention_mask"][data_idx]) - _position_ids = compute_position_id_with_mask(torch.tensor(_attention_mask)).tolist() - _tool_schemas = [] - _tools_kwargs = {} - - req = AsyncRolloutRequest( - batch_data_id=data_idx, - rollout_offset=rollout_offset, - request_id=str(uuid4()), - state=AsyncRolloutRequestStateEnum.PENDING, - messages=[Message.model_validate(msg) for msg in raw_prompt], - tools=_tool_schemas, - tools_kwargs=_tools_kwargs, - input_ids=_input_ids, - prompt_ids=_input_ids, - response_ids=[], - attention_mask=_attention_mask, - prompt_attention_mask=_attention_mask, - response_attention_mask=[], - position_ids=_position_ids, - prompt_position_ids=_position_ids, - response_position_ids=[], - loss_mask=[0] * len(_input_ids), - prompt_loss_mask=[0] * len(_input_ids), - response_loss_mask=[], - reward_scores={}, - max_response_len=self.config.response_length, - max_model_len=min(self.config.max_model_len, self.config.prompt_length + self.config.response_length), - ) - - error_message = f"Request {req.request_id} has mismatched lengths: input_ids={len(req.input_ids)}, attention_mask={len(req.attention_mask)}, position_ids={len(req.position_ids)}, loss_mask={len(req.loss_mask)}" - assert len(req.input_ids) == len(req.attention_mask) == len(req.position_ids) == len(req.loss_mask), error_message - - req_list.append(req) - - return req_list - - def execute_method(self, method: Union[str, bytes], *args, **kwargs): - if method == "chat_completion": - json_request = args[0] - - formatted_messages = [] - for msg in json_request["messages"]: - role = msg.get("role", "user") - content = msg.get("content", "") - formatted_messages.append(f"{role}: {content}") - prompt_str = "\n".join(formatted_messages) - - sampling_params_dict = { - "n": json_request.get("n", 1), - "max_new_tokens": json_request.get("max_completion_tokens", self.config.response_length), - "temperature": json_request.get("temperature", 1.0), - "top_p": json_request.get("top_p", 1.0), - } - output = None - if self._tp_rank == 0: - loop = asyncio.get_event_loop() - output = loop.run_until_complete( - self._engine.async_generate( - prompt=prompt_str, - sampling_params=sampling_params_dict, - return_logprob=True, - ) - ) - output = broadcast_pyobj( - data=[output], - rank=self._rank, - dist_group=self._device_mesh_cpu["tp"].get_group(), - src=self._device_mesh_cpu["tp"].mesh[0].item(), - force_cpu_device=False, - ) - - # only return value from master rank - if self._tp_rank != 0: - return None - # build openai chat 
completion format - choices = [] - id = None - for i, content in enumerate(output): - choices.append( - { - "index": i, - "message": { - "role": "assistant", - "content": content["text"], - }, - "finish_reason": content["meta_info"]["finish_reason"]["type"], - } - ) - id = content["meta_info"]["id"] - - return { - "id": "chatcmpl-" + id, - "object": "chat.completion", - "created": int(time.time()), - "model": json_request.get("model", "sglang_model"), - "choices": choices, - } - else: - raise ValueError(f"not supported method : {method}") - - # this function is left for uniform train-inference resharding - - def resume(self): - if not self.is_sleep: - return - self.sharding_manager.__enter__() # pylint: disable=C2801 - - self.is_sleep = False - - # this function is left for uniform train-inference resharding - def offload(self): - if self.is_sleep: - return - - self.sharding_manager.__exit__(None, None, None) - self.is_sleep = True diff --git a/verl/workers/rollout/sglang_rollout/async_sglang_server.py b/verl/workers/rollout/sglang_rollout/async_sglang_server.py index c29494f794b..b3a36818871 100644 --- a/verl/workers/rollout/sglang_rollout/async_sglang_server.py +++ b/verl/workers/rollout/sglang_rollout/async_sglang_server.py @@ -1,3 +1,4 @@ +# Copyright 2023-2024 SGLang Team # Copyright 2025 Bytedance Ltd. and/or its affiliates # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/verl/workers/rollout/sglang_rollout/sglang_rollout.py b/verl/workers/rollout/sglang_rollout/sglang_rollout.py index af30a568df5..43a601a2acc 100644 --- a/verl/workers/rollout/sglang_rollout/sglang_rollout.py +++ b/verl/workers/rollout/sglang_rollout/sglang_rollout.py @@ -1,16 +1,5 @@ # Copyright 2023-2024 SGLang Team -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== +# Copyright 2025 ModelBest Inc. and/or its affiliates # Copyright 2024 Bytedance Ltd. and/or its affiliates # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -24,45 +13,72 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
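
A recurring pattern in both the removed async rollout above and the merged rollout below is that only TP rank 0 drives the SGLang engine and the Python result object is then shared with the rest of the tensor-parallel group (the patch uses verl's broadcast_pyobj helper for this). A simplified, hypothetical stand-in using plain torch.distributed, not code from the patch:

    import torch.distributed as dist

    def run_on_tp_rank0_and_broadcast(fn, tp_group):
        # Hypothetical helper: rank 0 of the TP group produces the result;
        # every other rank receives it via an object broadcast.
        payload = [fn() if dist.get_rank(tp_group) == 0 else None]
        src = dist.get_global_rank(tp_group, 0)
        dist.broadcast_object_list(payload, src=src, group=tp_group)
        return payload[0]
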
- from __future__ import annotations +import asyncio import logging +import math import os +import time from contextlib import contextmanager from copy import deepcopy -from typing import TYPE_CHECKING +from json import JSONDecodeError +from typing import TYPE_CHECKING, Union +from uuid import uuid4 import numpy as np -import torch.distributed -from omegaconf import DictConfig, OmegaConf -from sglang.srt.entrypoints.verl_engine import VerlEngine +import torch +import torch.distributed as dist +from omegaconf import DictConfig +from sglang.srt.entrypoints.engine import Engine +from sglang.srt.openai_api.protocol import Tool from sglang.srt.sampling.sampling_params import SamplingParams from sglang.srt.utils import get_ip, get_open_port from tensordict import TensorDict -from torch.distributed.device_mesh import init_device_mesh +from torch.distributed.device_mesh import DeviceMesh, init_device_mesh from torch.nn.utils.rnn import pad_sequence +from transformers import PreTrainedTokenizer from verl import DataProto from verl.third_party.sglang import parallel_state as sglang_ps +from verl.tools.base_tool import BaseTool +from verl.tools.schemas import ( + OpenAIFunctionCallSchema, + OpenAIFunctionParsedSchema, + OpenAIFunctionToolCall, +) from verl.utils.debug import GPUMemoryLogger +from verl.utils.model import compute_position_id_with_mask from verl.utils.net_utils import is_ipv6 -from verl.utils.torch_functional import get_response_mask, pad_sequence_to_length +from verl.utils.torch_functional import ( + get_response_mask, + pad_sequence_to_length, +) from verl.workers.rollout.base import BaseRollout +from verl.workers.rollout.schemas import ( + AsyncRolloutRequest, + AsyncRolloutRequestStateEnum, + FinishReasonTypeEnum, + Message, +) from verl.workers.rollout.sglang_rollout.utils import broadcast_pyobj +try: + from sglang.srt.function_call.function_call_parser import FunctionCallParser +except ImportError: + from sglang.srt.function_call_parser import FunctionCallParser -if TYPE_CHECKING: - from torch import nn logger = logging.getLogger(__file__) logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) -# NOTE(sgm): add for verl. We can optimize it by making the dataloader yield List[int] without padding. -def _pre_process_inputs(pad_token_id, prompt_token_ids: torch.Tensor) -> list[int]: +# NOTE(sgm): add for verl. We can optimize it by making +# the dataloader yield List[int] without padding. 
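
The NOTE above refers to the left padding that the dataloader adds to prompts; _pre_process_inputs (defined next) strips it before the token ids are handed to SGLang. A tiny worked example with made-up token ids:

    import torch

    pad_token_id = 0
    prompt_token_ids = torch.tensor([0, 0, 0, 15, 27, 42])  # left-padded prompt
    # same logic as _pre_process_inputs below
    non_pad_index = torch.nonzero(prompt_token_ids != pad_token_id, as_tuple=False)[0][0]
    assert prompt_token_ids[non_pad_index:].tolist() == [15, 27, 42]
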
+def _pre_process_inputs( + pad_token_id, + prompt_token_ids: torch.Tensor, +) -> list[int]: # remove the left padding in the prompt token_id - # pad_token_id = self.llm_engine.tokenizer.pad_token_id if self.llm_engine.tokenizer.pad_token_id is - # not None else self.llm_engine.tokenizer.eos_token_id non_pad_index = torch.nonzero(prompt_token_ids != pad_token_id, as_tuple=False)[0][0] token_ids = prompt_token_ids[non_pad_index:].tolist() return token_ids @@ -71,14 +87,9 @@ def _pre_process_inputs(pad_token_id, prompt_token_ids: torch.Tensor) -> list[in # NOTE(linjunrong): adhoc def _post_process_outputs(tokenizer, output): def _map_each_response(resp): - log_probs = [] - output_token_ids = [] - for log_prob, token_ids, _ in resp["meta_info"]["output_token_logprobs"]: - log_probs.append(log_prob) - output_token_ids.append(token_ids) - log_probs = torch.tensor(log_probs) - output_token_ids = torch.tensor(output_token_ids) - return output_token_ids, log_probs + output_token_logprobs = resp["meta_info"]["output_token_logprobs"] + log_probs, output_token_ids = zip(*[(log_prob, token_ids) for log_prob, token_ids, _ in output_token_logprobs]) + return torch.tensor(output_token_ids), torch.tensor(log_probs) out_map = map(lambda x: _map_each_response(x), output) batched_output_token_ids = [] @@ -93,160 +104,377 @@ def _map_each_response(resp): return batched_output_token_ids, batched_logprobs +def get_tool_call_parser_type(tokenizer: PreTrainedTokenizer) -> str: + items = FunctionCallParser.ToolCallParserEnum.items() + for parser_type, parser_cls in items: + parser = parser_cls() + if parser.bot_token in tokenizer.get_vocab() and (parser.eot_token == "" or parser.eot_token in tokenizer.get_vocab()): + return parser_type + else: + raise ValueError(f"No tool call parser found for tokenizer {tokenizer}") + + class SGLangRollout(BaseRollout): def __init__( self, - actor_module: nn.Module | str, + actor_module: str, config: DictConfig, tokenizer, model_hf_config, port=None, trust_remote_code: bool = False, + device_mesh: DeviceMesh | None = None, **kwargs, ): - """A SGLang rollout. It requires the module is supported by the SGLang. + """Synchronized SGLang rollout engine. Args: - actor_module: module here follows huggingface APIs - config: DictConfig - tokenizer: the task/model tokenizer - model_hf_config: the huggingface config to initiallize the generating model in SGLang - **kwargs: train_tp, for Megatron Backend to initialize hybrid engine (zero redundancy) process group + actor_module: Huggingface model name or path to the model. The + model should be supported by SGLang. + config: A DictConfig object containing SGLang-specific operational + parameters and rollout settings. + Refer to https://docs.sglang.ai/backend/server_arguments.html + tokenizer: The tokenizer instance compatible with the actor_module. + model_hf_config: The Hugging Face model's configuration (e.g., + `transformers.PretrainedConfig`). It provides architectural + details and hyperparameters like `max_position_embeddings`, + used by SGLang for correct model initialization. This is + the model's inherent design, not SGLang's runtime behavior. + port: Optional port for multi-node initialization when nnodes > 1. + trust_remote_code: Whether or not to allow for custom models + defined on the Hub in their own modeling files. + device_mesh: Optional `DeviceMesh` object for distributed setup. + **kwargs: Additional keyword arguments, primarily `train_tp` for + Megatron Backend integration to initialize hybrid engine + process groups. 
""" super().__init__() self.config = config + self._device_mesh_cpu = device_mesh os.environ.setdefault("SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK", "true") + ( + self._tool_schemas, + self._tool_map, + self._tool_call_parser_type, + self._sgl_tools, + self._function_call_parser, + ) = self._initialize_tools(config, tokenizer) + # If turn on `free_cache_engine`, SGLang engine's KV cache + # will be freed after each `generate_sequences` call. assert not (not config.enforce_eager and config.free_cache_engine), "disable CUDA graph (enforce_eager = False) if free cache engine" - tensor_parallel_size = self.config.get("tensor_model_parallel_size", 1) - assert tensor_parallel_size <= torch.distributed.get_world_size(), "tensor parallel size should be less than or equal to the world size" + logger.info(f"tool_schemas: {self._tool_schemas}, tool_map: {self._tool_map}, tool_call_parser_type: {self._tool_call_parser_type}, sgl_tools: {self._sgl_tools}, function_call_parser: {self._function_call_parser}") + + self._init_distributed_env(device_mesh_cpu=device_mesh, **kwargs) + + self._verify_config(model_hf_config=model_hf_config) + # initialize the inference engine + self._init_inference_engine(trust_remote_code, actor_module, port) + + self._init_sampling_params(**kwargs) + + self.tokenizer = tokenizer + self.pad_token_id = tokenizer.pad_token_id - if kwargs.get("train_tp") is not None: + def _init_distributed_env(self, device_mesh_cpu, **kwargs): + self._device_mesh_cpu = device_mesh_cpu + os.environ.setdefault("SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK", "true") + self.tensor_parallel_size = self.config.get("tensor_model_parallel_size", 1) + assert self.tensor_parallel_size <= dist.get_world_size(), "tensor parallel size should be less than or equal to the world size" + self.train_tp = kwargs.get("train_tp", None) + if self.train_tp is not None: # deployed with megatron os.environ["CUDA_TIMER_STREAM_KAFKA_ENABLE"] = "0" os.environ["MEGATRON_IMPORT_TIMERS"] = "0" - train_tp = kwargs.get("train_tp") - num_tp_per_train_tp = train_tp // tensor_parallel_size + train_tp = kwargs.get("train_tp", None) + num_tp_per_train_tp = train_tp // self.tensor_parallel_size sglang_ps.initialize_parallel_state( - tensor_model_parallel_size=tensor_parallel_size, + tensor_model_parallel_size=self.tensor_parallel_size, num_tp_per_train_tp=num_tp_per_train_tp, ) - assert model_hf_config.max_position_embeddings >= config.prompt_length + config.response_length, "model context length should be greater than total sequence length" - tp_size = tensor_parallel_size + tp_size = self.tensor_parallel_size world_size = int(os.getenv("WORLD_SIZE", "-1")) # init device mesh - device_mesh_kwargs = dict( - mesh_shape=(world_size // tp_size, tp_size, 1), - mesh_dim_names=["dp", "tp", "pp"], - ) + if self._device_mesh_cpu is None: + device_mesh_kwargs = dict( + mesh_shape=(world_size // tp_size, tp_size, 1), + mesh_dim_names=["dp", "tp", "pp"], + ) - device_mesh_cpu = init_device_mesh("cpu", **device_mesh_kwargs) - # device_mesh_device = init_device_mesh("cuda", **device_mesh_kwargs) + self._device_mesh_cpu = init_device_mesh("cpu", **device_mesh_kwargs) + self._rank = self._device_mesh_cpu.get_rank() + self._tp_rank = self._device_mesh_cpu["tp"].get_local_rank() + self._tp_size = self._device_mesh_cpu["tp"].size() + if self._rank == 0: + logger.info(f"_init_distributed_env: :tp_world: {self._tp_size}, global_world: {world_size}") # get tp_rank of this process in this tp group - rank = device_mesh_cpu.get_rank() - visible_devices = [None] * 
device_mesh_cpu.size(1) + visible_devices = [None] * self._device_mesh_cpu.size(1) + + torch.distributed.all_gather_object(visible_devices, os.environ["CUDA_VISIBLE_DEVICES"], self._device_mesh_cpu.get_group("tp")) + self.visible_devices_set = set(",".join(visible_devices).split(",")) + os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(sorted(list(self.visible_devices_set))) - torch.distributed.all_gather_object(visible_devices, os.environ["CUDA_VISIBLE_DEVICES"], device_mesh_cpu.get_group("tp")) - visible_devices_set = set(",".join(visible_devices).split(",")) - os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(sorted(list(visible_devices_set))) + def _verify_config(self, model_hf_config): + if not self.config.get("max_model_len", None): + self.config.max_model_len = self.config.prompt_length + self.config.response_length + assert self.config.max_model_len >= self.config.prompt_length + self.config.response_length, f"""max_model_len should be greater than total sequence length (prompt_length + response_length): + {self.config.max_model_len} >= {self.config.prompt_length} + {self.config.response_length}""" + assert model_hf_config.max_position_embeddings >= self.config.max_model_len, "model context length should be greater than total sequence length" + # currently max_turns stand for max number of tool calls + if self.config.multi_turn.max_turns is None: + self.config.multi_turn.max_turns = self.config.max_model_len // 3 - nnodes = -(-tp_size // len(visible_devices_set)) + def _init_inference_engine(self, trust_remote_code, actor_module, port): + # initialize the inference engine + nnodes = -(-self._tp_size // len(self.visible_devices_set)) if nnodes > 1: ip = get_ip() port = get_open_port() if port is None else port [ip, port] = broadcast_pyobj( [ip, port], - rank=rank, - dist_group=device_mesh_cpu.get_group("tp"), - src=device_mesh_cpu["tp"].mesh[0].item(), + rank=self._rank, + dist_group=self._device_mesh_cpu.get_group("tp"), + src=self._device_mesh_cpu["tp"].mesh[0].item(), force_cpu_device=False, ) dist_init_addr = f"[{ip}]:{port}" if is_ipv6(ip) else f"{ip}:{port}" - else: dist_init_addr = None - load_format = "dummy" if config.load_format.startswith("dummy") else config.load_format - # copy it to avoid secretly modifying the engine config - engine_kwargs = {} if "engine_kwargs" not in config or "sglang" not in config.engine_kwargs else OmegaConf.to_container(deepcopy(config.engine_kwargs.sglang)) - engine_kwargs = {key: val for key, val in engine_kwargs.items() if val is not None} - self.inference_engine = VerlEngine( - model_path=actor_module, - dtype=config.dtype, - mem_fraction_static=config.gpu_memory_utilization, - device_mesh_cpu=device_mesh_cpu["tp"], - enable_memory_saver=True, - base_gpu_id=0, - gpu_id_step=1, - load_format=load_format, - dist_init_addr=dist_init_addr, - nnodes=nnodes, - trust_remote_code=trust_remote_code, - # NOTE(linjunrong): add rank to prevent SGLang generate same port inside PortArgs.init_new - # when random.seed is being set during training - port=30000 + rank, - # Note: Enable below to display SGLang engine logs at INFO level - # log_level="INFO", - # Note: Enable below to display ReqInput in details, be careful about the log volume - # log_requests=True, - # Note: Log level for ReqInput, 0 for concise, 1 for log middle leve, 2 for verbose - # log_requests_level=2, - # Note: Enable below to limit the number of running requests - # max_running_requests=1, - **engine_kwargs, - ) + load_format = "dummy" if self.config.load_format.startswith("dummy") else 
self.config.load_format + tp_size_per_node = self._tp_size // nnodes + node_rank = self._tp_rank // tp_size_per_node + first_rank_in_node = self._tp_rank % tp_size_per_node == 0 - # offload - self.inference_engine.release_memory_occupation() + if first_rank_in_node: + rank = dist.get_rank() + os.environ["SGLANG_BLOCK_NONZERO_RANK_CHILDREN"] = "0" + self._engine = Engine( + model_path=actor_module, + dtype=self.config.dtype, + mem_fraction_static=self.config.gpu_memory_utilization, + enable_memory_saver=True, + base_gpu_id=0, + gpu_id_step=1, + tp_size=self._tp_size, + node_rank=node_rank, + load_format=load_format, + dist_init_addr=dist_init_addr, + nnodes=nnodes, + trust_remote_code=trust_remote_code, + # NOTE(linjunrong): add rank to prevent SGLang generate same port inside PortArgs.init_new + # when random.seed is being set during training + port=30000 + rank, + # NOTE(Chenyang): if you want to debug the SGLang engine output + # please set the following parameters + # Otherwise, it will make the engine run too slow + # log_level="INFO", + # log_requests=True, + # log_requests_level=2, + # max_running_requests=1, + ) + else: + self._engine = None + self.sharding_manager = None + # offload + if self._tp_rank == 0: + self._engine.release_memory_occupation() + self.is_sleep = True + + def _init_sampling_params(self, **kwargs): kwargs = dict( n=1, - max_new_tokens=config.response_length, + max_new_tokens=self.config.response_length, presence_penalty=0.0, frequency_penalty=0.0, repetition_penalty=1.0, ) # supporting adding any sampling params from the config file - for k in config.keys(): + for k in self.config.keys(): if hasattr(SamplingParams(), str(k)): - kwargs[k] = config.get(k) - print(f"kwargs: {kwargs}") + kwargs[k] = self.config.get(k) self.sampling_params = kwargs - self.tokenizer = tokenizer - self.pad_token_id = tokenizer.pad_token_id + + def _initialize_tools(self, config, tokenizer): + """Initialize tools from configuration. + + Args: + config: Configuration object containing tool-related settings, + specifically `config.multi_turn.tool_config_path`. + tokenizer: The tokenizer instance used for parsing tool calls from + the model's generated text. + + Returns: + tuple: A tuple containing: + - tool_schemas (list[dict]): OpenAI-formatted JSON schemas + defining each tool's capabilities. + - tool_map (dict[str, BaseTool]): A dictionary mapping tool + names to their executable `BaseTool` objects. + - tool_call_parser_type (str): The identifier for the specific + parser type (e.g., 'json_mode', 'tool_code') used to extract + tool calls. + - sgl_tools (list[sglang.srt.openai_api.protocol.Tool]): Tool + definitions optimized for SGLang's internal engine. + - function_call_parser (sglang.srt.function_call_parser.FunctionCallParser): + The active parser instance responsible for extracting + structured tool calls from model outputs. 
+ """ + if config.multi_turn.tool_config_path is None: + return [], {}, None, [], None + + import importlib.util + import sys + + from omegaconf import OmegaConf + + from verl.tools.schemas import OpenAIFunctionToolSchema + + def initialize_tools_from_config(tools_config) -> list: + tool_list = [] + + for tool_config in tools_config.tools: + cls_name = tool_config.class_name + module_name, class_name = cls_name.rsplit(".", 1) + + if module_name not in sys.modules: + spec = importlib.util.find_spec(module_name) + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + spec.loader.exec_module(module) + else: + module = sys.modules[module_name] + + tool_cls = getattr(module, class_name) + + tool_schema_dict = OmegaConf.to_container(tool_config.tool_schema, resolve=True) + tool_schema = OpenAIFunctionToolSchema.model_validate(tool_schema_dict) + + tool = tool_cls( + config=OmegaConf.to_container(tool_config.config, resolve=True), + tool_schema=tool_schema, + ) + tool_list.append(tool) + + return tool_list + + tools_config_file = config.multi_turn.tool_config_path + tools_config = OmegaConf.load(tools_config_file) + tool_list = initialize_tools_from_config(tools_config) + logger.info(f"Initialize tools from configuration.: tool_list: {tool_list}") + tool_schemas = [tool.get_openai_tool_schema().model_dump() for tool in tool_list] + tool_map = {tool.name: tool for tool in tool_list} + tool_call_parser_type = get_tool_call_parser_type(tokenizer) + sgl_tools = [Tool.model_validate(tool_schema) for tool_schema in tool_schemas] + function_call_parser = FunctionCallParser( + sgl_tools, + tool_call_parser_type, + ) + + return ( + tool_schemas, + tool_map, + tool_call_parser_type, + sgl_tools, + function_call_parser, + ) @contextmanager def update_sampling_params(self, **kwargs): - # update sampling params - old_sampling_params_args = {} - if kwargs: - for key, value in kwargs.items(): - if key in self.sampling_params: - old_value = self.sampling_params[key] - old_sampling_params_args[key] = old_value - self.sampling_params[key] = value - yield - # roll back to previous sampling params - # if len(old_sampling_params_args): - for key, value in old_sampling_params_args.items(): - self.sampling_params[key] = value + """ + Temporarily updates the model's sampling parameters for the + duration of a `with` block. Parameters are automatically fall + back to their original values upon exiting the block. + + Args: + **kwargs: Keyword arguments representing sampling parameters + to be updated. Only parameters that already exist in + `self.sampling_params` will be updated. 
+ """ + # Store original values of parameters that will be updated + old_sampling_params_args = {key: self.sampling_params[key] for key in kwargs if key in self.sampling_params} + + # Update sampling parameters with new values + for key, value in kwargs.items(): + if key in self.sampling_params: + self.sampling_params[key] = value + + try: + yield + # Yield and execute the code within the 'with' block + finally: + # Always restore original values, even if an error + # occurred in the `with` block + for key, value in old_sampling_params_args.items(): + self.sampling_params[key] = value @GPUMemoryLogger(role="sglang rollout", logger=logger) @torch.no_grad() def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto: - # if self.config.free_cache_engine: + if self.config.multi_turn.enable: + return self._req_level_generate_sequences(prompts, **kwargs) + return self._batch_level_generate_sequences(prompts, **kwargs) - idx = prompts.batch["input_ids"] # (bs, prompt_length) - # left-padded attention_mask + @GPUMemoryLogger(role="sglang rollout", logger=logger) + @torch.no_grad() + def _batch_level_generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto: + """Generates sequences for a batch of prompts. + For single-turn generation, all prompts are processed in one request. + For multi-turn generation, each prompt is processed separately via + `_generate_req_level_sequences` for better tool calling control. + `_generate_batch_level_sequences` involves: + 1. Extracting and pre-processing prompt token IDs from the input + `prompts`. This includes handling padding and preparing raw + token ID lists. + 2. Preparing inputs for the SGLang engine, including multi-modal + data if present. + 3. Invoking the SGLang engine (`self._engine.async_generate`, + an async coroutine) with the batch of processed inputs and + specified sampling parameters on the master TP rank. + 4. Broadcasting the results from the master TP rank to all + other TP ranks. + 5. Post-processing the engine's output to format the generated + token IDs and (if applicable) log probabilities. + 6. Constructing the final sequences by concatenating original + prompts with the generated responses. + 7. Updating attention masks and position IDs to reflect the full + concatenated sequences. + 8. If `self.config.free_cache_engine` is true, the SGLang engine's + KV cache is flushed after generation on the master TP rank. + Args: + prompts: A `DataProto` object containing the batch of + input prompts, including tensor data (like `input_ids`, + `attention_mask`) and meta-information (like `eos_token_id`, + `do_sample`). + **kwargs: Additional keyword arguments that can override the + default sampling parameters (e.g., `temperature`, `top_p`, + `max_new_tokens`). These are temporarily applied using + `update_sampling_params`. + Returns: + DataProto: A `DataProto` object containing the batch of + generated sequences. This includes tensors for `prompts` + (original input IDs), `responses` (generated token IDs), + `input_ids` (concatenated prompt and response), + `attention_mask`, and `position_ids` for the full + sequences. + Note that when `n > 1`, each prompt generates multiple sequences, + so we need to replicate its non-tensor data (i.e. raw prompts, + messages, reward scores, etc.) n times to match the expanded + tensor data. This is done in the `_non_tensor_batch` dictionary. 
+ """ + # input ids: (bs, prompt_length), left-padded + idx = prompts.batch["input_ids"] + # attention_mask: (bs, seq_length), left-padded attention_mask = prompts.batch["attention_mask"] position_ids = prompts.batch["position_ids"] - # used to construct attention_mask + # used to generate attention mask for the + # response based on EOS token position eos_token_id = prompts.meta_info["eos_token_id"] batch_size = idx.size(0) @@ -254,22 +482,28 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto: # Extract non-tensor data non_tensor_batch = prompts.non_tensor_batch if "raw_prompt_ids" not in non_tensor_batch: - non_tensor_batch["raw_prompt_ids"] = np.array([_pre_process_inputs(self.pad_token_id, idx[i]) for i in range(batch_size)], dtype=object) + non_tensor_batch["raw_prompt_ids"] = np.array( + [_pre_process_inputs(self.pad_token_id, idx[i]) for i in range(batch_size)], + dtype=object, + ) if "multi_modal_data" in non_tensor_batch: sglang_inputs = [] - for raw_prompt_ids, multi_modal_data in zip(non_tensor_batch.pop("raw_prompt_ids"), non_tensor_batch.pop("multi_modal_data")): + for raw_prompt_ids, multi_modal_data in zip( + non_tensor_batch.pop("raw_prompt_ids"), + non_tensor_batch.pop("multi_modal_data"), + ): sglang_inputs.append( { "prompt_token_ids": raw_prompt_ids, "multi_modal_data": multi_modal_data, - "image_data": multi_modal_data.get("image", None) if isinstance(multi_modal_data, dict) else None, + "image_data": (multi_modal_data.get("image", None) if isinstance(multi_modal_data, dict) else None), } ) else: sglang_inputs = [{"prompt_token_ids": raw_prompt_ids} for raw_prompt_ids in non_tensor_batch.pop("raw_prompt_ids")] - # Ensure token IDs are lists + # Ensure token IDs are lists or numpy arrays for input_data in sglang_inputs: if isinstance(input_data["prompt_token_ids"], np.ndarray): input_data["prompt_token_ids"] = input_data["prompt_token_ids"].tolist() @@ -308,22 +542,35 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto: # users can customize different sampling_params at different run with self.update_sampling_params(**kwargs): # print(f"{self.sampling_params=}") - output = self.inference_engine.generate( - prompt=None, # because we have already convert it to prompt token id - sampling_params=self.sampling_params, - return_logprob=True, - input_ids=idx_list, - image_data=image_list, + if self._tp_rank == 0: + loop = asyncio.get_event_loop() + output = loop.run_until_complete( + self._engine.async_generate( + prompt=None, # because we have already convert it to prompt token id + sampling_params=self.sampling_params, + return_logprob=True, + input_ids=idx_list, + image_data=image_list, + ) + ) + else: + output = None + # Most naive implementation, can extract tensor and send via gloo if too slow + [output] = broadcast_pyobj( + data=[output], + rank=self._rank, + dist_group=self._device_mesh_cpu["tp"].get_group(), + src=self._device_mesh_cpu["tp"].mesh[0].item(), + force_cpu_device=False, ) - out = _post_process_outputs(self.tokenizer, output) response = out[0].to(idx.device) - # log_probs = out[1].to(idx.device) + rollout_log_probs = out[1].to(idx.device) if response.shape[1] < self.config.response_length: response = pad_sequence_to_length(response, self.config.response_length, self.pad_token_id) - # log_probs = pad_sequence_to_length(log_probs, self.config.response_length, self.pad_token_id) + rollout_log_probs = pad_sequence_to_length(rollout_log_probs, self.config.response_length, self.pad_token_id) # utilize current 
sampling params if self.sampling_params.get("n", 1) > 1 and do_sample: @@ -357,7 +604,7 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto: "prompts": idx, "responses": response, "input_ids": seq, # here input_ids become the whole sentences - # 'old_log_probs': log_probs, # we will recompute old log prob with actor + "rollout_log_probs": rollout_log_probs, # we will recompute old log prob with actor "attention_mask": attention_mask, "position_ids": position_ids, }, @@ -365,16 +612,487 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto: ) # free cache engine - if self.config.free_cache_engine and self.inference_engine._engine is not None and self.inference_engine._engine.tokenizer_manager is not None: - self.inference_engine._engine.flush_cache() + if self.config.free_cache_engine and self._engine is not None: + self._engine.flush_cache() return DataProto(batch=batch, non_tensor_batch=_non_tensor_batch) - # this function is left for uniform train-inference resharding - def update_weights(self, params_iter): - self.inference_engine.resume_memory_occupation() - self.inference_engine.update_weights_from_tensor(params_iter, load_format=None) + async def _async_rollout_a_request( + self, + req: AsyncRolloutRequest, + do_sample: bool = True, + is_validate: bool = False, + **kwargs, + ) -> AsyncRolloutRequest: + assert self._tp_rank == 0, "only the master process can call this function" + _req = deepcopy(req) + finish_reason_type = None + output = None + + current_turns = 0 + while current_turns < self.config.multi_turn.max_turns: + if _req.state == AsyncRolloutRequestStateEnum.PENDING: + await self._handle_pending_state(_req) + _req.state = AsyncRolloutRequestStateEnum.RUNNING + elif _req.state == AsyncRolloutRequestStateEnum.TOOL_CALLING: + if _req.messages[-1].tool_calls is not None: + parsed_tool_calls = _req.messages[-1].tool_calls + tool_call_results = await asyncio.gather( + *[ + self._tool_map[tool_call.function.name].execute( + _req.request_id, + tool_call.function.arguments, + **_req.tools_kwargs[tool_call.function.name].get("execute_kwargs", {}), + ) + for tool_call in parsed_tool_calls + ] + ) + for i, (tool_call, (resp, reward, metrics)) in enumerate(zip(parsed_tool_calls, tool_call_results)): + _req.add_tool_response_message( + self.tokenizer, + resp, + (i == len(parsed_tool_calls) - 1), + format=self.config.multi_turn.format, + ) + _req.update_metrics(metrics, tool_call.function.name) + if len(_req.input_ids) >= self.config.max_model_len: + break + if len(_req.input_ids) >= self.config.max_model_len: + finish_reason_type = FinishReasonTypeEnum.STOP + break + _req.state = AsyncRolloutRequestStateEnum.RUNNING + else: + raise ValueError(f"Unexpected tool calling last message state: {_req.messages[-1]}") + elif _req.state == AsyncRolloutRequestStateEnum.RUNNING: + output = await self._handle_engine_call(_req, do_sample, is_validate, **kwargs) + content = output["text"] + finish_reason_type = FinishReasonTypeEnum.from_str(output["meta_info"]["finish_reason"]["type"]) + current_turns += 1 + if finish_reason_type == FinishReasonTypeEnum.LENGTH: + _req.add_assistant_message( + self.tokenizer, + content, + already_over_long=True, + format=self.config.multi_turn.format, + ) + break + else: + if self._function_call_parser and self._function_call_parser.has_tool_call(content): + finish_reason_type = FinishReasonTypeEnum.TOOL_CALL + _req.state = AsyncRolloutRequestStateEnum.TOOL_CALLING + try: + normed_content, tool_calls = 
self._function_call_parser.parse_non_stream(content) + except JSONDecodeError: + normed_content = content + tool_calls = [] + except AttributeError: + normed_content = content + tool_calls = [] + parsed_tool_calls = [] + for tool_call in tool_calls: + function, has_decode_error = OpenAIFunctionCallSchema.from_openai_function_parsed_schema( + OpenAIFunctionParsedSchema( + name=tool_call.name, + arguments=tool_call.parameters, + ) + ) + # Drop the tool call if its arguments has decode error + if has_decode_error: + continue + parsed_tool_calls.append( + OpenAIFunctionToolCall( + id=str(tool_call.tool_index), + function=function, + ) + ) + if len(parsed_tool_calls) > 0: + _req.add_assistant_message( + self.tokenizer, + normed_content, + tool_calls=parsed_tool_calls, + format=self.config.multi_turn.format, + ) + else: + _req.add_assistant_message( + self.tokenizer, + content, + format=self.config.multi_turn.format, + ) + finish_reason_type = FinishReasonTypeEnum.STOP + _req.state = AsyncRolloutRequestStateEnum.COMPLETED + break + else: + _req.add_assistant_message( + self.tokenizer, + content, + format=self.config.multi_turn.format, + ) + break + + if current_turns >= self.config.multi_turn.max_turns: + finish_reason_type = FinishReasonTypeEnum.STOP + + # Calculate the reward for each tool + async def calc_reward_and_release_fn(name: str, tool: BaseTool): + reward = await tool.calc_reward(_req.request_id, **_req.tools_kwargs[name].get("calc_reward_kwargs", {})) + await tool.release(_req.request_id, **_req.tools_kwargs[name].get("release_kwargs", {})) + return name, reward + + tool_reward_tasks = [] + for name in _req.tools_kwargs.keys(): + tool = self._tool_map[name] + tool_reward_tasks.append(calc_reward_and_release_fn(name, tool)) + tool_reward_scores = await asyncio.gather(*tool_reward_tasks) + tool_reward_scores = dict(tool_reward_scores) + _req.finalize(self.tokenizer, tool_reward_scores, finish_reason_type) + + return _req + + async def _handle_engine_call(self, _req: AsyncRolloutRequest, do_sample: bool, is_validate: bool, **kwargs) -> dict: + generation_prompt_ids = _req.get_generation_prompt(self.tokenizer) + max_new_tokens = min(self.config.response_length, self.config.max_model_len - len(generation_prompt_ids) - 1) + if not do_sample: + kwargs = dict( + n=1, + presence_penalty=0.0, + frequency_penalty=0.0, + repetition_penalty=1.0, + temperature=0, + top_p=1, + top_k=-1, + ignore_eos=False, + min_new_tokens=0, + max_new_tokens=self.config.response_length, + skip_special_tokens=True, + spaces_between_special_tokens=True, + ) + elif is_validate: + # TODO: try ** + kwargs = { + "top_k": self.config.val_kwargs.top_k, + "top_p": self.config.val_kwargs.top_p, + "temperature": self.config.val_kwargs.temperature, + "n": 1, # if validate, already repeat in ray_trainer + } + kwargs["max_new_tokens"] = max_new_tokens + if "n" not in kwargs or kwargs["n"] > 1: # group size is supported in preprocess + kwargs["n"] = 1 + # users can customize different sampling_params at different run + with self.update_sampling_params(**kwargs): + output = await self._engine.async_generate( + input_ids=generation_prompt_ids, + sampling_params=self.sampling_params, + return_logprob=False, + ) + return output + + async def _handle_pending_state(self, _req: AsyncRolloutRequest) -> AsyncRolloutRequest: + if _req.tools is not None: + tool_creation_coroutines = [] + for tool_schema in _req.tools: + tool = self._tool_map[tool_schema.function.name] + create_kwargs = _req.tools_kwargs[tool.name].get("create_kwargs", {}) 
+ tool_creation_coroutines.append(tool.create(_req.request_id, **create_kwargs)) + await asyncio.gather(*tool_creation_coroutines) + + @GPUMemoryLogger(role="sglang rollout", logger=logger) + @torch.no_grad() + def generate_sequences_with_tools(self, prompts: DataProto, **kwargs) -> DataProto: + logger.warning( + "`generate_sequences_with_tools` is deprecated, please use `generate_sequences(...)`", + DeprecationWarning, + stacklevel=2, + ) + return self._req_level_generate_sequences(prompts, **kwargs) + + @GPUMemoryLogger(role="sglang rollout", logger=logger) + @torch.no_grad() + def _req_level_generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto: + # Async rollout with tools support + do_sample = prompts.meta_info.get("do_sample", True) + is_validate = prompts.meta_info.get("validate", False) + tgt_device = prompts.batch["input_ids"].device + if self._tp_rank == 0: + req_list = self._preprocess_prompt_to_async_rollout_requests( + prompts, + n=1 if is_validate else self.config.n, + ) + loop = asyncio.get_event_loop() + output_req_list = loop.run_until_complete( + asyncio.gather( + *[self._async_rollout_a_request(req, do_sample, is_validate, **kwargs) for req in req_list], + ) + ) + sorted_output_req_list = sorted(output_req_list, key=lambda x: (x.batch_data_id, x.rollout_offset)) + else: + sorted_output_req_list = None + + [sorted_output_req_list] = broadcast_pyobj( + data=[sorted_output_req_list], + rank=self._rank, + dist_group=self._device_mesh_cpu["tp"].get_group(), + src=self._device_mesh_cpu["tp"].mesh[0].item(), + force_cpu_device=False, + ) + # Construct the batch data + prompt_ids, response_ids = [], [] + prompt_attention_mask, response_attention_mask = [], [] + prompt_position_ids, response_position_ids = [], [] + prompt_loss_mask, response_loss_mask = [], [] + messages = [] + reward_scores = [] + for req in sorted_output_req_list: + assert req.state == AsyncRolloutRequestStateEnum.COMPLETED, f"Request {req.request_id} is not completed" + assert len(req.input_ids) == len(req.attention_mask) == len(req.position_ids) == len(req.loss_mask), f"""Request {req.request_id} has different length of + {len(req.input_ids)=}, {len(req.attention_mask)=}, {len(req.position_ids)=}, {len(req.loss_mask)=}""" + error_message_lines = [ + f"""Request {req.request_id} has input_ids length {len(req.input_ids)} + greater than max_model_len {self.config.max_model_len}""", + f"Decoded input_ids: {self.tokenizer.decode(req.input_ids)}", + f"Decoded prompt_ids: {self.tokenizer.decode(req.prompt_ids)}", + f"Decoded response_ids: {self.tokenizer.decode(req.response_ids)}", + f"Messages: {req.messages}", + f"Max model length: {req.max_model_len}", + ] + error_message = "\n".join(error_message_lines) + assert len(req.input_ids) <= self.config.max_model_len, error_message + + prompt_ids.append(torch.tensor(req.prompt_ids, dtype=torch.int, device=tgt_device)) + response_ids.append(torch.tensor(req.response_ids, dtype=torch.int, device=tgt_device)) + if len(req.response_ids) > self.config.response_length: + logger.warning( + f"""{req.request_id=} has response_ids length {len(req.response_ids)} + greater than max_response_len {self.config.response_length},\n{req=}""" + ) + prompt_attention_mask.append(torch.tensor(req.prompt_attention_mask, dtype=torch.int, device=tgt_device)) + response_attention_mask.append(torch.tensor(req.response_attention_mask, dtype=torch.int, device=tgt_device)) + prompt_position_ids.append(torch.tensor(req.prompt_position_ids, dtype=torch.int, device=tgt_device)) + 
response_position_ids.append(torch.tensor(req.response_position_ids, dtype=torch.int, device=tgt_device)) + prompt_loss_mask.append(torch.tensor(req.prompt_loss_mask, dtype=torch.int, device=tgt_device)) + response_loss_mask.append(torch.tensor(req.response_loss_mask, dtype=torch.int, device=tgt_device)) + messages.append({"messages": req.messages}) + reward_scores.append(req.reward_scores) + + prompt_ids = pad_sequence( + prompt_ids, + batch_first=True, + padding_value=self.pad_token_id, + padding_side="left", + ) + if prompt_ids.shape[1] < self.config.prompt_length: + prompt_ids = pad_sequence_to_length(prompt_ids, self.config.prompt_length, self.pad_token_id, left_pad=True) + response_ids = pad_sequence(response_ids, batch_first=True, padding_value=self.pad_token_id) + if response_ids.shape[1] < self.config.response_length: + response_ids = pad_sequence_to_length(response_ids, self.config.response_length, self.pad_token_id) + prompt_attention_mask = pad_sequence( + prompt_attention_mask, + batch_first=True, + padding_value=0, + padding_side="left", + ) + if prompt_attention_mask.shape[1] < self.config.prompt_length: + prompt_attention_mask = pad_sequence_to_length(prompt_attention_mask, self.config.prompt_length, 0, left_pad=True) + response_attention_mask = pad_sequence(response_attention_mask, batch_first=True, padding_value=0) + if response_attention_mask.shape[1] < self.config.response_length: + response_attention_mask = pad_sequence_to_length(response_attention_mask, self.config.response_length, 0) + prompt_position_ids = pad_sequence(prompt_position_ids, batch_first=True, padding_value=0, padding_side="left") + if prompt_position_ids.shape[1] < self.config.prompt_length: + prompt_position_ids = pad_sequence_to_length(prompt_position_ids, self.config.prompt_length, 0, left_pad=True) + response_length = response_ids.size(1) + delta_position_id = torch.arange(1, response_length + 1, device=response_ids.device) + delta_position_id = delta_position_id.unsqueeze(0).repeat(len(sorted_output_req_list), 1) + response_position_ids = prompt_position_ids[:, -1:] + delta_position_id + prompt_loss_mask = pad_sequence(prompt_loss_mask, batch_first=True, padding_value=0, padding_side="left") + if prompt_loss_mask.shape[1] < self.config.prompt_length: + prompt_loss_mask = pad_sequence_to_length(prompt_loss_mask, self.config.prompt_length, 0, left_pad=True) + response_loss_mask = pad_sequence(response_loss_mask, batch_first=True, padding_value=0) + if response_loss_mask.shape[1] < self.config.response_length: + response_loss_mask = pad_sequence_to_length(response_loss_mask, self.config.response_length, 0) + + input_ids = torch.cat((prompt_ids, response_ids), dim=-1) + attention_mask = torch.cat((prompt_attention_mask, response_attention_mask), dim=-1) + position_ids = torch.cat((prompt_position_ids, response_position_ids), dim=-1) + loss_mask = torch.cat((prompt_loss_mask, response_loss_mask), dim=-1) + + # Construct the batch data + batch = TensorDict( + { + "prompts": prompt_ids, + "responses": response_ids, + "input_ids": input_ids, # here input_ids become the whole sentences + "attention_mask": attention_mask, + "position_ids": position_ids, + "loss_mask": loss_mask, + }, + batch_size=len(sorted_output_req_list), + ) + + # free cache engine + if self.config.free_cache_engine and self._engine is not None and self._tp_rank == 0: + self._engine.flush_cache() + + return DataProto( + batch=batch, + non_tensor_batch={ + "messages": np.array(messages), + "reward_scores": np.array(reward_scores), + }, 
+ ) + + def _preprocess_prompt_to_async_rollout_requests(self, prompts: DataProto, n: int) -> list[AsyncRolloutRequest]: + assert "raw_prompt" in prompts.non_tensor_batch, "need data.return_raw_chat=True, due to no official way do parse_messages" + req_list = [] + for data_idx, raw_prompt in enumerate(prompts.non_tensor_batch["raw_prompt"]): + for rollout_offset in range(n): + if self._tool_schemas: + _tools_kwargs = prompts.non_tensor_batch["tools_kwargs"][data_idx] + _tool_schemas = [] + for k in _tools_kwargs.keys(): + _tool_schemas.append(self._tool_map[k].get_openai_tool_schema()) + prompt_with_chat_template = self.tokenizer.apply_chat_template( + conversation=raw_prompt, + tools=[tool.model_dump() for tool in _tool_schemas], + add_generation_prompt=True, + tokenize=False, + return_tensors="pt", + ) + input_data = self.tokenizer( + prompt_with_chat_template, + return_tensors="pt", + add_special_tokens=False, + ) + _input_ids = input_data["input_ids"][0].tolist() + _attention_mask = input_data["attention_mask"][0].tolist() + _position_ids = compute_position_id_with_mask(input_data["attention_mask"][0]).tolist() + if len(_input_ids) > self.config.prompt_length: + logger.warning( + "Prompt {} has length {} greater than max_prompt_len {}", + data_idx, + len(_input_ids), + self.config.prompt_length, + ) + _input_ids = _input_ids[: self.config.prompt_length] + _attention_mask = _attention_mask[: self.config.prompt_length] + _position_ids = _position_ids[: self.config.prompt_length] + else: + _input_ids = _pre_process_inputs(self.pad_token_id, prompts.batch["input_ids"][data_idx]) + _attention_mask = _pre_process_inputs(0, prompts.batch["attention_mask"][data_idx]) + _position_ids = compute_position_id_with_mask(torch.tensor(_attention_mask)).tolist() + _tool_schemas = [] + _tools_kwargs = {} + + req = AsyncRolloutRequest( + batch_data_id=data_idx, + rollout_offset=rollout_offset, + request_id=str(uuid4()), + state=AsyncRolloutRequestStateEnum.PENDING, + messages=[Message.model_validate(msg) for msg in raw_prompt], + tools=_tool_schemas, + tools_kwargs=_tools_kwargs, + input_ids=_input_ids, + prompt_ids=_input_ids, + response_ids=[], + attention_mask=_attention_mask, + prompt_attention_mask=_attention_mask, + response_attention_mask=[], + position_ids=_position_ids, + prompt_position_ids=_position_ids, + response_position_ids=[], + loss_mask=[0] * len(_input_ids), + prompt_loss_mask=[0] * len(_input_ids), + response_loss_mask=[], + reward_scores={}, + max_response_len=self.config.response_length, + max_model_len=min( + self.config.max_model_len, + self.config.prompt_length + self.config.response_length, + ), + ) + + error_message = f"Request {req.request_id} has mismatched lengths: input_ids={len(req.input_ids)}, attention_mask={len(req.attention_mask)}, position_ids={len(req.position_ids)}, loss_mask={len(req.loss_mask)}" + assert len(req.input_ids) == len(req.attention_mask) == len(req.position_ids) == len(req.loss_mask), error_message + + req_list.append(req) + + return req_list + + def execute_method(self, method: Union[str, bytes], *args, **kwargs): + if method == "chat_completion": + json_request = args[0] + + formatted_messages = [] + for msg in json_request["messages"]: + role = msg.get("role", "user") + content = msg.get("content", "") + formatted_messages.append(f"{role}: {content}") + prompt_str = "\n".join(formatted_messages) + + sampling_params_dict = { + "n": json_request.get("n", 1), + "max_new_tokens": json_request.get("max_completion_tokens", self.config.response_length), + 
"temperature": json_request.get("temperature", 1.0), + "top_p": json_request.get("top_p", 1.0), + } + output = None + if self._tp_rank == 0: + loop = asyncio.get_event_loop() + output = loop.run_until_complete( + self._engine.async_generate( + prompt=prompt_str, + sampling_params=sampling_params_dict, + return_logprob=True, + ) + ) + output = broadcast_pyobj( + data=[output], + rank=self._rank, + dist_group=self._device_mesh_cpu["tp"].get_group(), + src=self._device_mesh_cpu["tp"].mesh[0].item(), + force_cpu_device=False, + ) + + # only return value from master rank + if self._tp_rank != 0: + return None + # build openai chat completion format + choices = [] + id = None + for i, content in enumerate(output): + choices.append( + { + "index": i, + "message": { + "role": "assistant", + "content": content["text"], + }, + "finish_reason": content["meta_info"]["finish_reason"]["type"], + } + ) + id = content["meta_info"]["id"] + + return { + "id": "chatcmpl-" + id, + "object": "chat.completion", + "created": int(time.time()), + "model": json_request.get("model", "sglang_model"), + "choices": choices, + } + else: + raise ValueError(f"not supported method : {method}") + + # this function is left for uniform train-inference resharding + + def resume(self): + if not self.is_sleep: + return + self.sharding_manager.__enter__() # pylint: disable=C2801 + + self.is_sleep = False # this function is left for uniform train-inference resharding def offload(self): - self.inference_engine.release_memory_occupation() + if self.is_sleep: + return + + self.sharding_manager.__exit__(None, None, None) + self.is_sleep = True diff --git a/verl/workers/rollout/sglang_rollout/utils.py b/verl/workers/rollout/sglang_rollout/utils.py index c887e3f9e7d..438facd9e14 100644 --- a/verl/workers/rollout/sglang_rollout/utils.py +++ b/verl/workers/rollout/sglang_rollout/utils.py @@ -34,7 +34,9 @@ def broadcast_pyobj( The `rank` here refer to the source rank on global process group (regardless of dist_group argument). """ - device = torch.device("cuda" if torch.cuda.is_available() and not force_cpu_device else "cpu") + device = torch.device( + "cuda" if torch.cuda.is_available() and not force_cpu_device else "cpu" + ) if rank == src: if len(data) == 0: @@ -44,7 +46,9 @@ def broadcast_pyobj( serialized_data = pickle.dumps(data) size = len(serialized_data) - tensor_data = torch.ByteTensor(np.frombuffer(serialized_data, dtype=np.uint8)).to(device) + tensor_data = torch.ByteTensor( + np.frombuffer(serialized_data, dtype=np.uint8) + ).to(device) tensor_size = torch.tensor([size], dtype=torch.long, device=device) dist.broadcast(tensor_size, src=src, group=dist_group) diff --git a/verl/workers/sharding_manager/fsdp_sglang.py b/verl/workers/sharding_manager/fsdp_sglang.py index 044d973c3ce..3608d932b6c 100644 --- a/verl/workers/sharding_manager/fsdp_sglang.py +++ b/verl/workers/sharding_manager/fsdp_sglang.py @@ -1,18 +1,5 @@ # Copyright 2023-2024 SGLang Team # Copyright 2025 ModelBest Inc. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== # Copyright 2024 Bytedance Ltd. and/or its affiliates # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -29,12 +16,10 @@ import logging import os -from typing import Union import torch import torch.distributed as dist from sglang.srt.entrypoints.engine import Engine -from sglang.srt.entrypoints.verl_engine import VerlEngine from sglang.srt.model_executor.model_runner import LocalSerializedTensor from sglang.srt.utils import MultiprocessingSerializer from torch.distributed.device_mesh import DeviceMesh @@ -66,7 +51,7 @@ class FSDPSGLangShardingManager(BaseShardingManager): def __init__( self, module: FSDP, - inference_engine: Union[VerlEngine, Engine], + inference_engine: Engine, model_config, full_params: bool = False, device_mesh: DeviceMesh = None, @@ -144,44 +129,6 @@ def __exit__(self, exc_type, exc_value, traceback): self.gen_random_states = torch.cuda.get_rng_state() torch.cuda.set_rng_state(self.torch_random_states) - def update_weights(self, params): - self.inference_engine.resume_memory_occupation() - self.inference_engine.update_weights_from_tensor([(k, v) for k, v in params.items()], load_format=None) - - def release_memory(self): - self.inference_engine.release_memory_occupation() - - def preprocess_data(self, data: DataProto) -> DataProto: - """All gather across tp group to make each rank has identical input.""" - if self.tp_size == 1: - return data - - # TODO: Current impl doesn't consider FSDP with torch micro-dp - group = self.device_mesh["infer_tp"].get_group() - - all_gather_data_proto(data=data, process_group=group) - return data - - def postprocess_data(self, data: DataProto) -> DataProto: - """Get chunk data of this tp rank since we do all gather in preprocess.""" - if self.tp_size == 1: - return data - - return data.chunk(chunks=self.tp_size)[self.tp_rank] - - -class FSDPAsyncSGLangShardingManager(FSDPSGLangShardingManager): - def __init__( - self, - module: FSDP, - inference_engine: Engine, - model_config, - full_params: bool = False, - device_mesh: DeviceMesh = None, - offload_param: bool = False, - ): - super().__init__(module, inference_engine, model_config, full_params, device_mesh, offload_param) - def update_weights(self, params): if self.device_mesh["infer_tp"].get_local_rank() == 0: self.inference_engine.resume_memory_occupation() @@ -218,3 +165,21 @@ def update_weights(self, params): def release_memory(self): if self.device_mesh["infer_tp"].get_local_rank() == 0: self.inference_engine.release_memory_occupation() + + def preprocess_data(self, data: DataProto) -> DataProto: + """All gather across tp group to make each rank has identical input.""" + if self.tp_size == 1: + return data + + # TODO: Current impl doesn't consider FSDP with torch micro-dp + group = self.device_mesh["infer_tp"].get_group() + + all_gather_data_proto(data=data, process_group=group) + return data + + def postprocess_data(self, data: DataProto) -> DataProto: + """Get chunk data of this tp rank since we do all gather in preprocess.""" + if self.tp_size == 1: + return data + + return data.chunk(chunks=self.tp_size)[self.tp_rank] diff --git a/verl/workers/sharding_manager/megatron_sglang.py b/verl/workers/sharding_manager/megatron_sglang.py index 0a1352d9e74..4e047e212ff 100644 --- a/verl/workers/sharding_manager/megatron_sglang.py +++ b/verl/workers/sharding_manager/megatron_sglang.py 
@@ -22,7 +22,6 @@ import torch from sglang.srt.entrypoints.engine import Engine -from sglang.srt.entrypoints.verl_engine import VerlEngine from torch import nn from torch.distributed.device_mesh import DeviceMesh @@ -50,7 +49,7 @@ class MegatronSGLangShardingManager(BaseShardingManager): def __init__( self, actor_module: nn.ModuleList, - inference_engine: VerlEngine, + inference_engine: Engine, model_config, transformer_config, layer_name_mapping, @@ -113,50 +112,6 @@ def __exit__(self, exc_type, exc_value, traceback): self.gen_random_states = torch.cuda.get_rng_state() torch.cuda.set_rng_state(self.torch_random_states) - def update_weights(self, params): - self.inference_engine.resume_memory_occupation() - self.inference_engine.update_weights_from_tensor(params, load_format=None) - - def release_memory(self): - self.inference_engine.release_memory_occupation() - - @GPUMemoryLogger(role="megatron sglang sharding_manager", logger=logger) - def preprocess_data(self, data: DataProto) -> DataProto: - # DP_COMPUTE_PROTO: all training ranks are dp, the same as fsdp - if self.infer_tp_size == 1: - return data - all_gather_data_proto(data, self.device_mesh["tp"].get_group()) - return data - - @GPUMemoryLogger(role="megatron sglang sharding_manager", logger=logger) - def postprocess_data(self, data: DataProto) -> DataProto: - # DP_COMPUTE_PROTO: all training ranks are dp, the same as fsdp - if self.infer_tp_size == 1: - return data - return data.chunk(chunks=self.infer_tp_size)[self.device_mesh["tp"].get_local_rank()] - - -class MegatronAsyncSGLangShardingManager(MegatronSGLangShardingManager): - def __init__( - self, - actor_module: nn.ModuleList, - inference_engine: Engine, - model_config, - transformer_config, - layer_name_mapping, - weight_converter, - device_mesh: DeviceMesh = None, - ): - super().__init__( - actor_module, - inference_engine, - model_config, - transformer_config, - layer_name_mapping, - weight_converter, - device_mesh, - ) - def update_weights(self, params): if self.device_mesh["tp"].get_local_rank() == 0: self.inference_engine.resume_memory_occupation() @@ -184,3 +139,18 @@ def update_weights(self, params): def release_memory(self): if self.device_mesh["tp"].get_local_rank() == 0: self.inference_engine.release_memory_occupation() + + @GPUMemoryLogger(role="megatron sglang sharding_manager", logger=logger) + def preprocess_data(self, data: DataProto) -> DataProto: + # DP_COMPUTE_PROTO: all training ranks are dp, the same as fsdp + if self.infer_tp_size == 1: + return data + all_gather_data_proto(data, self.device_mesh["tp"].get_group()) + return data + + @GPUMemoryLogger(role="megatron sglang sharding_manager", logger=logger) + def postprocess_data(self, data: DataProto) -> DataProto: + # DP_COMPUTE_PROTO: all training ranks are dp, the same as fsdp + if self.infer_tp_size == 1: + return data + return data.chunk(chunks=self.infer_tp_size)[self.device_mesh["tp"].get_local_rank()]
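The new `update_sampling_params` context manager in this patch follows a save/override/restore pattern. Below is a minimal, self-contained sketch of that pattern, not the verl implementation itself; the `SamplingConfig` class and its default values are invented purely for illustration.

from contextlib import contextmanager

class SamplingConfig:
    """Toy stand-in for the rollout object that owns `sampling_params`."""

    def __init__(self):
        self.sampling_params = {"temperature": 1.0, "top_p": 1.0, "max_new_tokens": 16}

    @contextmanager
    def update_sampling_params(self, **kwargs):
        # Save only the keys that are about to be overridden.
        old = {k: self.sampling_params[k] for k in kwargs if k in self.sampling_params}
        for k, v in kwargs.items():
            if k in self.sampling_params:
                self.sampling_params[k] = v
        try:
            yield
        finally:
            # Restore the saved values even if the body raised an exception.
            self.sampling_params.update(old)

cfg = SamplingConfig()
with cfg.update_sampling_params(temperature=0.0, top_p=0.9):
    assert cfg.sampling_params["temperature"] == 0.0
assert cfg.sampling_params["temperature"] == 1.0  # restored after the block

The `try/finally` is what distinguishes the new version from the one it replaces: sampling parameters are restored even when generation raises inside the `with` block.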
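`_initialize_tools` resolves each tool's `class_name` (a dotted path from the YAML tool config) to a Python class before instantiating it. A simplified sketch of that resolution step follows; it uses `importlib.import_module` rather than the `find_spec`/`module_from_spec` sequence in the patch, and the `load_class` helper name is illustrative rather than part of verl.

import importlib
import sys

def load_class(dotted_path: str):
    """Resolve 'package.module.ClassName' to the class object, importing the module if needed."""
    module_name, class_name = dotted_path.rsplit(".", 1)
    module = sys.modules.get(module_name) or importlib.import_module(module_name)
    return getattr(module, class_name)

# Resolve a standard-library class the same way a tool's `class_name` entry would be resolved.
ordered_dict_cls = load_class("collections.OrderedDict")
print(ordered_dict_cls(a=1))  # prints an OrderedDict containing {'a': 1}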
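`_async_rollout_a_request` drives each request through a small state machine (PENDING, RUNNING, TOOL_CALLING, COMPLETED), alternating engine calls with tool execution until the model stops or `multi_turn.max_turns` is reached. The sketch below mimics that control flow with dummy coroutines; `fake_generate`, `fake_tool`, and the tag strings are stand-ins and do not correspond to verl or SGLang APIs.

import asyncio
from enum import Enum

class ReqState(Enum):
    PENDING = "pending"
    RUNNING = "running"
    TOOL_CALLING = "tool_calling"
    COMPLETED = "completed"

async def fake_generate(history: str) -> dict:
    """Stand-in for the engine call: requests a tool once, then finishes."""
    await asyncio.sleep(0)
    if "<tool_response>" not in history:
        return {"text": "<tool_call>add 1 2</tool_call>", "finish_reason": "tool_call"}
    return {"text": "The answer is 3.", "finish_reason": "stop"}

async def fake_tool(arguments: str) -> str:
    """Stand-in for BaseTool.execute."""
    await asyncio.sleep(0)
    return str(sum(int(x) for x in arguments.split()))

async def rollout_one(prompt: str, max_turns: int = 4) -> str:
    history, state, turns = prompt, ReqState.PENDING, 0
    while turns < max_turns and state != ReqState.COMPLETED:
        if state == ReqState.PENDING:
            # Tool instances would be created here (cf. _handle_pending_state).
            state = ReqState.RUNNING
        elif state == ReqState.RUNNING:
            out = await fake_generate(history)
            history += out["text"]
            turns += 1
            state = ReqState.TOOL_CALLING if out["finish_reason"] == "tool_call" else ReqState.COMPLETED
        elif state == ReqState.TOOL_CALLING:
            # Execute the parsed tool call and append its response to the context.
            result = await fake_tool("1 2")
            history += f"<tool_response>{result}</tool_response>"
            state = ReqState.RUNNING
    return history

print(asyncio.run(rollout_one("What is 1 + 2? ")))

In the actual rollout the per-request coroutines are gathered on TP rank 0 only, and the finished requests are broadcast to the other ranks before the batch tensors are assembled.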