diff --git a/.github/workflows/e2e_ascend.yml b/.github/workflows/e2e_ascend.yml index 77e246bf20a..2f719644799 100644 --- a/.github/workflows/e2e_ascend.yml +++ b/.github/workflows/e2e_ascend.yml @@ -257,8 +257,9 @@ jobs: - name: Preprocess gsm8k dataset run: | python examples/data_preprocess/gsm8k.py --local_dataset_path ${HOME}/.cache/datasets/openai/gsm8k - - name: Running the E2E test with one_step_off_policy algorithm on ASCEND NPU (FSDP2) - run: | - ray stop --force - bash tests/special_npu/run_one_step_off_policy.sh - rm -rf $HOME/ckpts + # TODO(wuxibin): temporary disable until we refactor with checkpoint engine + # - name: Running the E2E test with one_step_off_policy algorithm on ASCEND NPU (FSDP2) + # run: | + # ray stop --force + # bash tests/special_npu/run_one_step_off_policy.sh + # rm -rf $HOME/ckpts diff --git a/.github/workflows/npu_unit_tests.yml b/.github/workflows/npu_unit_tests.yml index 2b2f9256b63..665d6da5439 100644 --- a/.github/workflows/npu_unit_tests.yml +++ b/.github/workflows/npu_unit_tests.yml @@ -12,7 +12,7 @@ # - `special_sanity`: a suite of quick sanity tests # - `special_standalone`: a set of test that are designed to run in dedicated environments -# Accelerators for tests +# Accelerators for tests # - By default tests are run with GPU available, except for the ones under `special_npu`, and any test script whose name ends with `on_cpu.py`. # - For test scripts with `on_cpu.py` name suffix would be tested on CPU resources in linux environment. @@ -67,7 +67,7 @@ concurrency: cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} # Declare permissions just read content. -permissions: +permissions: contents: read jobs: @@ -109,7 +109,7 @@ jobs: - name: Run all NPU unit tests run: | export PYTHONPATH=$PYTHONPATH:/Megatron-LM - pytest -s -x --ignore-glob="*test_special_*.py" --ignore-glob="*on_cpu.py" --ignore-glob="*test_vllm*" --ignore-glob="*_sglang*" --ignore-glob="*_hf_rollout*" --ignore-glob="tests/models/" --ignore-glob="tests/special*" --ignore-glob="tests/experimental" --ignore-glob="tests/workers/reward_model" --ignore-glob="*test_rvdz*" --ignore-glob="*test_ray_collectives*" --ignore-glob="*test_nvtx_profile*" --ignore-glob="*test_nccl*" --ignore-glob="*test_nixl*" tests/ + pytest -s -x --ignore-glob="*test_special_*.py" --ignore-glob="*on_cpu.py" --ignore-glob="*test_vllm*" --ignore-glob="*_sglang*" --ignore-glob="*_hf_rollout*" --ignore-glob="tests/models/" --ignore-glob="tests/special*" --ignore-glob="tests/experimental" --ignore-glob="tests/workers/reward_model" --ignore-glob="*test_rvdz*" --ignore-glob="*test_ray_collectives*" --ignore-glob="*test_nvtx_profile*" --ignore-glob="tests/checkpoint_engine" tests/ - name: Testing FSDP2 actor functionality run: | torchrun --standalone --nnodes=1 --nproc-per-node=2 tests/workers/actor/test_special_dp_actor.py @@ -118,4 +118,4 @@ jobs: torchrun --standalone --nnodes=1 --nproc-per-node=2 tests/workers/critic/test_special_dp_critic.py - name: Running NPU profiling unit tests run: | - pytest -s -x tests/utils/test_special_mstx_profile.py \ No newline at end of file + pytest -s -x tests/utils/test_special_mstx_profile.py diff --git a/.github/workflows/e2e_fully_async_policy.yml b/.github/workflows/stash/e2e_fully_async_policy.yml similarity index 100% rename from .github/workflows/e2e_fully_async_policy.yml rename to .github/workflows/stash/e2e_fully_async_policy.yml diff --git a/.github/workflows/e2e_one_step_off_policy.yml b/.github/workflows/stash/e2e_one_step_off_policy.yml similarity index 100% 
rename from .github/workflows/e2e_one_step_off_policy.yml rename to .github/workflows/stash/e2e_one_step_off_policy.yml diff --git a/examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_multiturn.sh b/examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_multiturn.sh index 39e6ed24dea..a2d17d45ad4 100644 --- a/examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_multiturn.sh +++ b/examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_multiturn.sh @@ -63,6 +63,5 @@ python3 -m verl.trainer.main_ppo \ data.train_files=$HOME/data/gsm8k/train.parquet \ data.val_files=$HOME/data/gsm8k/test.parquet \ actor_rollout_ref.rollout.multi_turn.tool_config_path="$PROJECT_DIR/examples/sglang_multiturn/config/tool_config/gsm8k_tool_config.yaml" \ - trainer.total_epochs=15 \ - actor_rollout_ref.rollout.update_weights_bucket_megabytes=512 $@ + trainer.total_epochs=15 $@ diff --git a/examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_multiturn_server.sh b/examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_multiturn_server.sh index 47ba0f12db1..17f2ed40b8a 100755 --- a/examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_multiturn_server.sh +++ b/examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_multiturn_server.sh @@ -58,6 +58,5 @@ python3 -m verl.trainer.main_ppo \ data.train_files=$HOME/data/gsm8k/train.parquet \ data.val_files=$HOME/data/gsm8k/test.parquet \ actor_rollout_ref.rollout.multi_turn.tool_config_path="$PROJECT_DIR/examples/sglang_multiturn/config/tool_config/gsm8k_tool_config.yaml" \ - trainer.total_epochs=15 \ - actor_rollout_ref.rollout.update_weights_bucket_megabytes=512 $@ + trainer.total_epochs=15 $@ diff --git a/examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_multiturn_vllm_fsdp.sh b/examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_multiturn_vllm_fsdp.sh index cf5e065097f..c3c40b1076c 100644 --- a/examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_multiturn_vllm_fsdp.sh +++ b/examples/sglang_multiturn/run_qwen2.5-3b_gsm8k_multiturn_vllm_fsdp.sh @@ -50,7 +50,6 @@ python3 -m verl.trainer.main_ppo \ data.train_files=$HOME/data/gsm8k/train.parquet \ data.val_files=$HOME/data/gsm8k/test.parquet \ trainer.total_epochs=15 \ - actor_rollout_ref.rollout.update_weights_bucket_megabytes=512 \ actor_rollout_ref.rollout.trace.token2text=False \ actor_rollout_ref.rollout.mode=async \ actor_rollout_ref.rollout.multi_turn.enable=true \ diff --git a/examples/sglang_multiturn/run_qwen3_4b_dapo_multiturn.sh b/examples/sglang_multiturn/run_qwen3_4b_dapo_multiturn.sh index 39948693264..f65905bf7ba 100644 --- a/examples/sglang_multiturn/run_qwen3_4b_dapo_multiturn.sh +++ b/examples/sglang_multiturn/run_qwen3_4b_dapo_multiturn.sh @@ -76,7 +76,6 @@ python3 -m verl.trainer.main_ppo \ actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \ actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \ actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ - actor_rollout_ref.rollout.update_weights_bucket_megabytes=512 \ actor_rollout_ref.rollout.gpu_memory_utilization=0.85 \ actor_rollout_ref.rollout.multi_stage_wake_up=True \ actor_rollout_ref.rollout.multi_turn.enable=True \ diff --git a/tests/experimental/agent_loop/test_standalone_rollout.py b/tests/experimental/agent_loop/test_standalone_rollout.py index a530bae8281..c17e70bc19f 100644 --- a/tests/experimental/agent_loop/test_standalone_rollout.py +++ b/tests/experimental/agent_loop/test_standalone_rollout.py @@ -52,6 +52,7 @@ async def test_standalone_rollout(init_config, tp_size): "NCCL_DEBUG": "WARN", "VLLM_LOGGING_LEVEL": "INFO", "VLLM_USE_V1": "1", + "NCCL_P2P_DISABLE": "1", # disable p2p in L20 } 
} ) diff --git a/tests/special_e2e/ppo_trainer/run_function_reward.sh b/tests/special_e2e/ppo_trainer/run_function_reward.sh index 2160c8b6955..76601d31212 100644 --- a/tests/special_e2e/ppo_trainer/run_function_reward.sh +++ b/tests/special_e2e/ppo_trainer/run_function_reward.sh @@ -21,7 +21,7 @@ ROLLOUT_MODE="async" RETURN_RAW_CHAT="True" SKIP_TOKENIZER_INIT="True" -GPU_MEMORY_UTILIZATION=${GPU_MEMORY_UTILIZATION:-0.8} +GPU_MEMORY_UTILIZATION=${GPU_MEMORY_UTILIZATION:-0.7} ACTOR_FSDP_PARAM_OFFLOAD=${ACTOR_FSDP_PARAM_OFFLOAD:-False} ACTOR_FSDP_OPTIMIZER_OFFLOAD=${ACTOR_FSDP_OPTIMIZER_OFFLOAD:-False} REF_FSDP_PARAM_OFFLOAD=${REF_FSDP_PARAM_OFFLOAD:-True} diff --git a/tests/special_e2e/run_ppo_trainer_megatron.sh b/tests/special_e2e/run_ppo_trainer_megatron.sh index cd8033f132e..a10545ecaaf 100644 --- a/tests/special_e2e/run_ppo_trainer_megatron.sh +++ b/tests/special_e2e/run_ppo_trainer_megatron.sh @@ -196,7 +196,6 @@ python3 -m verl.trainer.main_ppo --config-path=config \ actor_rollout_ref.rollout.tensor_model_parallel_size=$ROLLOUT_TP \ actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ - actor_rollout_ref.rollout.update_weights_bucket_megabytes=128 \ ++actor_rollout_ref.rollout.quantization=${ROLLOUT_QUANTIZATION} \ actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \ actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \ diff --git a/tests/trainer/config/legacy_ppo_megatron_trainer.yaml b/tests/trainer/config/legacy_ppo_megatron_trainer.yaml index 3dd0b8a38d6..88477e211c0 100644 --- a/tests/trainer/config/legacy_ppo_megatron_trainer.yaml +++ b/tests/trainer/config/legacy_ppo_megatron_trainer.yaml @@ -175,7 +175,7 @@ actor_rollout_ref: tensor_model_parallel_size: 2 max_num_batched_tokens: 8192 max_model_len: null - max_num_seqs: 1024 + max_num_seqs: 256 log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu log_prob_micro_batch_size_per_gpu: null log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} diff --git a/tests/trainer/config/legacy_ppo_trainer.yaml b/tests/trainer/config/legacy_ppo_trainer.yaml index 25919bd15d9..cbbc497739b 100644 --- a/tests/trainer/config/legacy_ppo_trainer.yaml +++ b/tests/trainer/config/legacy_ppo_trainer.yaml @@ -466,7 +466,7 @@ actor_rollout_ref: max_model_len: null # max length of sequences - max_num_seqs: 1024 + max_num_seqs: 256 # [Will be deprecated, use log_prob_micro_batch_size_per_gpu] The batch size for one forward pass in the computation of log_prob. Global batch size. 
log_prob_micro_batch_size: null diff --git a/verl/experimental/reward_loop/router/inner_sglang_router.py b/verl/experimental/reward_loop/router/inner_sglang_router.py index 55dc1fac122..e05b17c89fc 100644 --- a/verl/experimental/reward_loop/router/inner_sglang_router.py +++ b/verl/experimental/reward_loop/router/inner_sglang_router.py @@ -21,7 +21,7 @@ import requests from sglang_router.launch_server import RouterArgs, launch_router -from verl.workers.rollout.utils import get_free_port, is_valid_ipv6_address +from verl.utils.net_utils import get_free_port, is_valid_ipv6_address logger = logging.getLogger(__name__) logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) diff --git a/verl/experimental/reward_loop/router/naive_router.py b/verl/experimental/reward_loop/router/naive_router.py index 230519fd3d8..a495c0592e3 100644 --- a/verl/experimental/reward_loop/router/naive_router.py +++ b/verl/experimental/reward_loop/router/naive_router.py @@ -25,7 +25,7 @@ from fastapi import FastAPI, Request from fastapi.responses import JSONResponse -from verl.workers.rollout.utils import get_free_port, is_valid_ipv6_address +from verl.utils.net_utils import get_free_port, is_valid_ipv6_address logger = logging.getLogger(__name__) logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) diff --git a/verl/trainer/config/_generated_ppo_megatron_trainer.yaml b/verl/trainer/config/_generated_ppo_megatron_trainer.yaml index b04a5d7a77b..232ad2a0f09 100644 --- a/verl/trainer/config/_generated_ppo_megatron_trainer.yaml +++ b/verl/trainer/config/_generated_ppo_megatron_trainer.yaml @@ -218,7 +218,7 @@ actor_rollout_ref: pipeline_model_parallel_size: 1 max_num_batched_tokens: 8192 max_model_len: null - max_num_seqs: 1024 + max_num_seqs: 256 enable_chunked_prefill: true enable_prefix_caching: true logprobs_mode: processed_logprobs @@ -268,7 +268,7 @@ actor_rollout_ref: _target_: verl.workers.config.CustomAsyncServerConfig path: null name: null - update_weights_bucket_megabytes: 512 + update_weights_bucket_megabytes: 2048 trace: _target_: verl.workers.config.TraceConfig backend: null diff --git a/verl/trainer/config/_generated_ppo_trainer.yaml b/verl/trainer/config/_generated_ppo_trainer.yaml index 686b5d08d15..1114171680e 100644 --- a/verl/trainer/config/_generated_ppo_trainer.yaml +++ b/verl/trainer/config/_generated_ppo_trainer.yaml @@ -209,7 +209,7 @@ actor_rollout_ref: pipeline_model_parallel_size: 1 max_num_batched_tokens: 8192 max_model_len: null - max_num_seqs: 1024 + max_num_seqs: 256 enable_chunked_prefill: true enable_prefix_caching: true logprobs_mode: processed_logprobs @@ -259,7 +259,7 @@ actor_rollout_ref: _target_: verl.workers.config.CustomAsyncServerConfig path: null name: null - update_weights_bucket_megabytes: 512 + update_weights_bucket_megabytes: 2048 trace: _target_: verl.workers.config.TraceConfig backend: null diff --git a/verl/trainer/config/rollout/rollout.yaml b/verl/trainer/config/rollout/rollout.yaml index 6c24d4b88cc..c4bbf8c52a7 100644 --- a/verl/trainer/config/rollout/rollout.yaml +++ b/verl/trainer/config/rollout/rollout.yaml @@ -65,7 +65,7 @@ max_num_batched_tokens: 8192 max_model_len: null # max length of sequences -max_num_seqs: 1024 +max_num_seqs: 256 # may get higher throughput when set to True. When activated, Please increase max_num_batched_tokens or decrease max_model_len. enable_chunked_prefill: True @@ -253,7 +253,7 @@ agent: # 1. Enable `RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES` # 2. 
Manually set `CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7` # when using Tensor Parallelism (TP) >= 8. -update_weights_bucket_megabytes: 512 +update_weights_bucket_megabytes: 2048 # trace rollout data trace: diff --git a/verl/trainer/constants_ppo.py b/verl/trainer/constants_ppo.py index 7c3a0bf09d6..6fab7e29eb3 100644 --- a/verl/trainer/constants_ppo.py +++ b/verl/trainer/constants_ppo.py @@ -31,6 +31,13 @@ # https://docs.vllm.ai/en/latest/usage/troubleshooting.html?h=nccl_cumem_enable#known-issues # https://github.com/vllm-project/vllm/blob/c6b0a7d3ba03ca414be1174e9bd86a97191b7090/vllm/worker/worker_base.py#L445 "NCCL_CUMEM_ENABLE": "0", + # TODO: disable compile cache due to cache corruption issue + # https://github.com/vllm-project/vllm/issues/31199 + "VLLM_DISABLE_COMPILE_CACHE": "1", + # Needed for multi-processes colocated on same NPU device + # https://www.hiascend.com/document/detail/zh/canncommercial/83RC1/maintenref/envvar/envref_07_0143.html + "HCCL_HOST_SOCKET_PORT_RANGE": "auto", + "HCCL_NPU_SOCKET_PORT_RANGE": "auto", }, } diff --git a/verl/trainer/runtime_env.yaml b/verl/trainer/runtime_env.yaml index 63750cd72f7..7d38fdde25d 100644 --- a/verl/trainer/runtime_env.yaml +++ b/verl/trainer/runtime_env.yaml @@ -3,3 +3,5 @@ excludes: ["/.git/"] env_vars: TORCH_NCCL_AVOID_RECORD_STREAMS: "1" CUDA_DEVICE_MAX_CONNECTIONS: "1" + HCCL_HOST_SOCKET_PORT_RANGE: "auto" + HCCL_NPU_SOCKET_PORT_RANGE: "auto" \ No newline at end of file diff --git a/verl/utils/device.py b/verl/utils/device.py index 66d6b958067..24df2bdab7e 100644 --- a/verl/utils/device.py +++ b/verl/utils/device.py @@ -43,6 +43,14 @@ def is_torch_npu_available(check_device=True) -> bool: is_npu_available = is_torch_npu_available() +def get_resource_name() -> str: + """Function that return ray resource name based on the device type. + Returns: + ray resource name string, either "GPU" or "NPU". + """ + return "GPU" if is_cuda_available else "NPU" + + def get_visible_devices_keyword() -> str: """Get the environment variable name for visible device selection. 
diff --git a/verl/utils/megatron_utils.py b/verl/utils/megatron_utils.py index 7a98b6fcbb5..dc0e16d8610 100644 --- a/verl/utils/megatron_utils.py +++ b/verl/utils/megatron_utils.py @@ -581,6 +581,19 @@ def _iter_opts(opt): v["exp_avg"] = v["exp_avg"].to("cpu", non_blocking=True) if "exp_avg_sq" in v: v["exp_avg_sq"] = v["exp_avg_sq"].to("cpu", non_blocking=True) + + try: + # Free TransformerEngine's dummy weight gradients cache + # https://github.com/NVIDIA/TransformerEngine/blob/release_v2.10/transformer_engine/pytorch/module/base.py#L64 + from transformer_engine.pytorch.module.base import _dummy_wgrads + + _dummy_wgrads.clear() + except ImportError: + pass + + # Free Megatron-LM's global memory buffer + # get_global_memory_buffer().buffer.clear() + gc.collect() get_torch_device().empty_cache() diff --git a/verl/workers/actor/megatron_actor.py b/verl/workers/actor/megatron_actor.py index 0afc84e6ef3..1b9d9861e19 100644 --- a/verl/workers/actor/megatron_actor.py +++ b/verl/workers/actor/megatron_actor.py @@ -838,5 +838,6 @@ def update_policy(self, dataloader: Iterable[DataProto], enable_mtp: bool = Fals RouterReplay.clear_global_router_replay_action() RouterReplay.clear_global_indices() + self.actor_optimizer.zero_grad() get_torch_device().empty_cache() return metrics diff --git a/verl/workers/engine/megatron/transformer_impl.py b/verl/workers/engine/megatron/transformer_impl.py index c91564ba353..a6f38599e2c 100644 --- a/verl/workers/engine/megatron/transformer_impl.py +++ b/verl/workers/engine/megatron/transformer_impl.py @@ -572,6 +572,7 @@ def __enter__(self): def __exit__(self, exc_type, exc_value, traceback): assert isinstance(self.engine, MegatronEngine) + self.engine.optimizer_zero_grad() super().__exit__(exc_type, exc_value, traceback) diff --git a/verl/workers/engine_workers.py b/verl/workers/engine_workers.py index f6ab2a5c68b..4aa5671f7cf 100644 --- a/verl/workers/engine_workers.py +++ b/verl/workers/engine_workers.py @@ -16,7 +16,6 @@ from contextlib import nullcontext from functools import partial from itertools import chain -from typing import Any, Optional import torch from codetiming import Timer @@ -486,8 +485,6 @@ def init_model(self): self.set_dispatch_collect(mesh_name="actor", **self.actor.get_dispatch_collect()) # 3. build rollout engine - # - vllm: vLLMAsyncRollout - # - sglang: ServerAdapter if "rollout" in self.role: rollout_config: RolloutConfig = omega_conf_to_dataclass(self.config.rollout) @@ -595,27 +592,3 @@ async def wake_up(self): # important: need to manually set the random states of each tp to be identical. 
self.torch_random_states = get_torch_device().get_rng_state() get_torch_device().set_rng_state(self.gen_random_states) - - # ============================ vLLM related ============================ - - @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD) - def get_zeromq_address(self): - return self.rollout.get_zeromq_address() - - # ============================ SGLang related ============================ - - @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD, blocking=False) - async def chat_completion(self, json_request): - ret = await self.rollout.chat_completion(json_request) - return ret - - @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD, blocking=False) - async def generate( - self, - prompt_ids: list[int], - sampling_params: dict[str, Any], - request_id: str, - image_data: Optional[list[Any]] = None, - ) -> list[int]: - ret = await self.rollout.generate(prompt_ids, sampling_params, request_id, image_data=image_data) - return ret diff --git a/verl/workers/fsdp_workers.py b/verl/workers/fsdp_workers.py index 83372cdb21c..659936891af 100644 --- a/verl/workers/fsdp_workers.py +++ b/verl/workers/fsdp_workers.py @@ -21,7 +21,6 @@ import os import warnings from dataclasses import asdict -from typing import Any, Optional import numpy as np import psutil @@ -2020,27 +2019,3 @@ async def wake_up(self): async def sleep(self): await self.trainer_mode() return True - - # ============================ vLLM related ============================ - - @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD) - def get_zeromq_address(self): - return self.rollout.get_zeromq_address() - - # ============================ SGLang related ============================ - - @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD, blocking=False) - async def chat_completion(self, json_request): - ret = await self.rollout.chat_completion(json_request) - return ret - - @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD, blocking=False) - async def generate( - self, - prompt_ids: list[int], - sampling_params: dict[str, Any], - request_id: str, - image_data: Optional[list[Any]] = None, - ) -> list[int]: - ret = await self.rollout.generate(prompt_ids, sampling_params, request_id, image_data=image_data) - return ret diff --git a/verl/workers/megatron_workers.py b/verl/workers/megatron_workers.py index d4367c07030..3d1c3843660 100644 --- a/verl/workers/megatron_workers.py +++ b/verl/workers/megatron_workers.py @@ -19,7 +19,6 @@ import logging import os import time -from typing import Any, Optional import psutil import torch @@ -989,30 +988,6 @@ async def sleep(self): await self.trainer_mode() return True - # ============================ vLLM related ============================ - - @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD) - def get_zeromq_address(self): - return self.rollout.get_zeromq_address() - - # ============================ SGLang related ============================ - - @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD, blocking=False) - async def chat_completion(self, json_request): - ret = await self.rollout.chat_completion(json_request) - return ret - - @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD, blocking=False) - async def generate( - self, - prompt_ids: list[int], - sampling_params: dict[str, Any], - request_id: str, - image_data: Optional[list[Any]] = None, - ) -> list[int]: - ret = await self.rollout.generate(prompt_ids, sampling_params, request_id, image_data=image_data) - return ret - class CriticWorker(MegatronWorker, DistProfilerExtension): def 
__init__(self, config: McoreCriticConfig): diff --git a/verl/workers/rollout/base.py b/verl/workers/rollout/base.py index dfd857c15cf..31d5b9736b7 100644 --- a/verl/workers/rollout/base.py +++ b/verl/workers/rollout/base.py @@ -79,7 +79,7 @@ def generate_sequences(self, prompts: DataProto) -> DataProto: _ROLLOUT_REGISTRY = { - ("vllm", "async"): "verl.workers.rollout.vllm_rollout.vLLMAsyncRollout", + ("vllm", "async"): "verl.workers.rollout.vllm_rollout.ServerAdapter", ("sglang", "async"): "verl.workers.rollout.sglang_rollout.sglang_rollout.ServerAdapter", ("trtllm", "async"): "verl.workers.rollout.trtllm_rollout.trtllm_rollout.ServerAdapter", } diff --git a/verl/workers/rollout/sglang_rollout/async_sglang_server.py b/verl/workers/rollout/sglang_rollout/async_sglang_server.py index 75e79e0287e..1e176bafd7b 100644 --- a/verl/workers/rollout/sglang_rollout/async_sglang_server.py +++ b/verl/workers/rollout/sglang_rollout/async_sglang_server.py @@ -41,19 +41,13 @@ from verl.single_controller.ray import RayClassWithInitArgs from verl.utils.config import omega_conf_to_dataclass -from verl.utils.device import ( - get_visible_devices_keyword, -) +from verl.utils.device import get_visible_devices_keyword +from verl.utils.net_utils import get_free_port, is_valid_ipv6_address from verl.utils.profiler.profile import DistProfiler from verl.workers.config import HFModelConfig, RolloutConfig from verl.workers.rollout.replica import RolloutMode, RolloutReplica, TokenOutput from verl.workers.rollout.sglang_rollout.sglang_rollout import ServerAdapter, _set_envs_and_config -from verl.workers.rollout.utils import ( - get_free_port, - get_max_position_embeddings, - is_valid_ipv6_address, - run_unvicorn, -) +from verl.workers.rollout.utils import get_max_position_embeddings, run_unvicorn logger = logging.getLogger(__file__) logger.setLevel(logging.INFO) @@ -93,8 +87,10 @@ def __init__( self.config: RolloutConfig = omega_conf_to_dataclass(config) self.model_config: HFModelConfig = omega_conf_to_dataclass(model_config, dataclass_type=HFModelConfig) - if self.config.max_model_len is None: - self.config.max_model_len = get_max_position_embeddings(self.model_config.hf_config) + self.config.max_model_len = min( + get_max_position_embeddings(self.model_config.hf_config), + self.config.prompt_length + self.config.response_length, + ) self.rollout_mode = rollout_mode self.workers = workers diff --git a/verl/workers/rollout/sglang_rollout/sglang_rollout.py b/verl/workers/rollout/sglang_rollout/sglang_rollout.py index adbb84ec4e6..24da85f20c5 100644 --- a/verl/workers/rollout/sglang_rollout/sglang_rollout.py +++ b/verl/workers/rollout/sglang_rollout/sglang_rollout.py @@ -33,11 +33,11 @@ from sglang.srt.weight_sync.utils import update_weights as sgl_update_weights from torch.distributed.device_mesh import DeviceMesh +from verl.utils.net_utils import is_valid_ipv6_address from verl.workers.config import HFModelConfig, RolloutConfig from verl.workers.rollout.base import BaseRollout from verl.workers.rollout.sglang_rollout.http_server_engine import AsyncHttpServerAdapter from verl.workers.rollout.sglang_rollout.utils import get_named_tensor_buckets -from verl.workers.rollout.utils import is_valid_ipv6_address logger = logging.getLogger(__file__) logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) diff --git a/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py b/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py index fa80e3e39ab..791c14bd1ac 100644 --- 
a/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py +++ b/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py @@ -25,10 +25,11 @@ from verl.single_controller.ray import RayClassWithInitArgs, SubRayResourcePool from verl.utils.config import omega_conf_to_dataclass from verl.utils.device import is_cuda_available +from verl.utils.net_utils import is_valid_ipv6_address from verl.workers.config import HFModelConfig, RolloutConfig from verl.workers.rollout.replica import RolloutMode, RolloutReplica, TokenOutput from verl.workers.rollout.trtllm_rollout.trtllm_rollout import ServerAdapter -from verl.workers.rollout.utils import is_valid_ipv6_address, run_unvicorn +from verl.workers.rollout.utils import run_unvicorn logger = logging.getLogger(__file__) logger.setLevel(logging.INFO) diff --git a/verl/workers/rollout/trtllm_rollout/trtllm_rollout.py b/verl/workers/rollout/trtllm_rollout/trtllm_rollout.py index b3d1d9b62b6..f95980b24f6 100644 --- a/verl/workers/rollout/trtllm_rollout/trtllm_rollout.py +++ b/verl/workers/rollout/trtllm_rollout/trtllm_rollout.py @@ -31,9 +31,9 @@ from torch.distributed.device_mesh import DeviceMesh, init_device_mesh from torch.multiprocessing.reductions import reduce_tensor +from verl.utils.net_utils import is_valid_ipv6_address from verl.workers.config import HFModelConfig, RolloutConfig from verl.workers.rollout.base import BaseRollout -from verl.workers.rollout.utils import is_valid_ipv6_address logger = logging.getLogger(__file__) logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) diff --git a/verl/workers/rollout/utils.py b/verl/workers/rollout/utils.py index c568cd01bce..16dcfc4a5c9 100644 --- a/verl/workers/rollout/utils.py +++ b/verl/workers/rollout/utils.py @@ -12,14 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import asyncio -import ipaddress import logging import os -import socket import uvicorn from fastapi import FastAPI +from verl.utils.net_utils import get_free_port + logger = logging.getLogger(__file__) @@ -35,28 +35,6 @@ def get_max_position_embeddings(hf_config) -> int: return int(max_len) -def is_valid_ipv6_address(address: str) -> bool: - try: - ipaddress.IPv6Address(address) - return True - except ValueError: - return False - - -def get_free_port(address: str) -> tuple[int, socket.socket]: - family = socket.AF_INET - if is_valid_ipv6_address(address): - family = socket.AF_INET6 - - sock = socket.socket(family=family, type=socket.SOCK_STREAM) - sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) - sock.bind((address, 0)) - - port = sock.getsockname()[1] - return port, sock - - async def run_unvicorn(app: FastAPI, server_args, server_address, max_retries=5) -> tuple[int, asyncio.Task]: server_port, server_task = None, None diff --git a/verl/workers/rollout/vllm_rollout/__init__.py b/verl/workers/rollout/vllm_rollout/__init__.py index 799351385af..2ecf113c839 100644 --- a/verl/workers/rollout/vllm_rollout/__init__.py +++ b/verl/workers/rollout/vllm_rollout/__init__.py @@ -14,7 +14,7 @@ import os from importlib.metadata import PackageNotFoundError, version -from .vllm_rollout import vLLMAsyncRollout # noqa: F401 +from .vllm_rollout import ServerAdapter # noqa: F401 def get_version(pkg): diff --git a/verl/workers/rollout/vllm_rollout/utils.py b/verl/workers/rollout/vllm_rollout/utils.py index bb0c9f0b5b2..1d222dc20ec 100644 --- a/verl/workers/rollout/vllm_rollout/utils.py +++ b/verl/workers/rollout/vllm_rollout/utils.py @@ -11,15 +11,57 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import ctypes +import gc import json -from typing import Any +import logging +import os +import platform +import signal +import threading +from types import MethodType +from typing import Any, Callable, TypedDict + +import torch +import zmq + +from verl.utils.device import get_torch_device, is_npu_available +from verl.utils.vllm import TensorLoRARequest, VLLMHijack +from verl.utils.vllm.patch import patch_vllm_moe_model_weight_loader +from verl.utils.vllm.vllm_fp8_utils import apply_vllm_fp8_patches, is_fp8_model, load_quanted_weights + +logger = logging.getLogger(__file__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) # magic numbers that ensure we are using the same LoRA adapter during the rollout and training process VLLM_LORA_INT_ID = 123 VLLM_LORA_NAME = "123" VLLM_LORA_PATH = "simon_lora_path" +VLLM_ASCEND_REQUIRED_ENV_VARS = {"VLLM_ALL2ALL_BACKEND": "flashinfer_all2allv", "VLLM_ASCEND_ENABLE_NZ": "0"} + + +def set_death_signal(): + """Kill the current process when the parent process exits.""" + if platform.system() != "Linux": + return + libc = ctypes.CDLL("libc.so.6") + libc.prctl(1, signal.SIGKILL) + if os.getppid() == 1: + os.kill(os.getpid(), signal.SIGKILL) + + +def get_device_uuid(device_id: int) -> str: + from vllm.platforms import current_platform + + # Convert torch.npu.current_device to its corresponding ASCEND_RT_VISIBLE_DEVICES. 
+ if is_npu_available: + npu_visible_devices = os.environ["ASCEND_RT_VISIBLE_DEVICES"].split(",") + assert device_id < len(npu_visible_devices), f"device_id {device_id} must less than {npu_visible_devices}" + return "NPU-" + npu_visible_devices[device_id] + else: + return current_platform.get_device_uuid(device_id) + def get_vllm_max_lora_rank(lora_rank: int): """ @@ -37,6 +79,174 @@ def get_vllm_max_lora_rank(lora_rank: int): raise ValueError(f"lora_rank must be less than or equal to {vllm_max_lora_ranks[-1]}, but got {lora_rank}") +# https://github.com/vllm-project/vllm/issues/13175 +def monkey_patch_compute_logits(model, vocab_size: int): + original_compute_logits = model.compute_logits + + def compute_logits( + self, + *args, + **kwargs, + ) -> torch.Tensor: + logits = original_compute_logits(*args, **kwargs) + logits[..., vocab_size:] = float("-inf") + return logits + + model.compute_logits = MethodType(compute_logits, model) + + +# copy from https://github.com/vllm-project/vllm/blob/main/examples/offline_inference/rlhf_utils.py +def rebuild_ipc(handle: tuple[Callable, tuple], device_id: int | None = None) -> torch.Tensor: + func, args = handle + list_args = list(args) + if device_id is not None: + # the key is to change device id to the current device id + # in case two processes have different CUDA_VISIBLE_DEVICES + list_args[6] = device_id + buffer = func(*list_args) + return buffer + + +class TensorMetadata(TypedDict): + name: str + shape: torch.Size + dtype: torch.dtype + offset: int + + +class vLLMColocateWorkerExtension: + """ + The class for vLLM's worker to inherit from, in the colocate setting. + By defining an extension class, the code can work no matter what is + the underlying worker class. This way, the code can be compatible + with both vLLM V0 and V1. + NOTE: we define this class in a separate module, and the main module + should pass the full qualified name as `worker_extension_cls` argument. + + Feature support: + 1. LoRA + 2. Online FP8 quantization + """ + + def __new__(cls, **kwargs): + set_death_signal() + + # 1. patch for Lora + VLLMHijack.hijack() + # 2. patch online fp8 quant + if os.environ.get("VERL_VLLM_FP8_QUANT_ENABLED", "0") == "1": + apply_vllm_fp8_patches() + + # TODO: For ascend NPU, when the corresponding vllm-ascend version is upgraded to v0.13.0, + # please remove the VLLM_ASCEND_REQUIRED_ENV_VARS variable replacement action. + # This is only a fix for vllm version < v0.13.0. 
+ if is_npu_available: + for k in VLLM_ASCEND_REQUIRED_ENV_VARS: + if k not in os.environ: + os.environ[k] = VLLM_ASCEND_REQUIRED_ENV_VARS[k] + + return super().__new__(cls) + + def monkey_patch_model(self, vocab_size: int): + # patch compute_logits to avoid sampling OOV token + monkey_patch_compute_logits(self.model_runner.model, vocab_size) + # patch weight loader to support MoE model + patch_vllm_moe_model_weight_loader(self.model_runner.model) + + def update_weights_from_ipc(self, peft_config: dict = None, base_sync_done=False): + """Update the weights of the rollout model.""" + from vllm.platforms import current_platform + + if current_platform.device_type == "npu" and self.device is None: + self.device = torch.device(f"npu:{self.local_rank}") + + # In async mode, make sure the old lora is removed before adding the new one + if peft_config and base_sync_done: + self.remove_lora(VLLM_LORA_INT_ID) + + # build cuda ipc buffer + assert self.device is not None + if not hasattr(self, "_zmq_ctx") or self._zmq_ctx is None: + self._zmq_ctx = zmq.Context() + socket = self._zmq_ctx.socket(zmq.REP) + socket.connect(self._get_zmq_handle()) + handle = socket.recv_pyobj() + buffer: torch.Tensor = rebuild_ipc(handle, self.device.index) + assert buffer.dtype == torch.uint8 + socket.send(b"") + + # receive bucket and update weights + while True: + metadata = socket.recv_pyobj() + weights = [] + for name, meta in metadata["bucket_meta"].items(): + shape, dtype, offset = meta["shape"], meta["dtype"], meta["offset"] + size = dtype.itemsize * shape.numel() + # NOTE: we need to clone the tensor to release CUDA IPC memory + tensor = buffer[offset : offset + size].view(dtype=dtype).view(shape).clone() + weights.append((name, tensor)) + get_torch_device().synchronize() + socket.send(b"") + self._update_weights(weights, peft_config=peft_config, base_sync_done=base_sync_done) + del weights + if metadata["is_last"]: + break + + # clean up + socket.close() + del buffer + gc.collect() + get_torch_device().ipc_collect() + get_torch_device().empty_cache() + + def _update_weights(self, weights: list[tuple[str, torch.Tensor]], peft_config: dict, base_sync_done: bool): + if peft_config and base_sync_done: + weights = dict(weights) + lora_request = TensorLoRARequest( + lora_name=VLLM_LORA_NAME, + lora_int_id=VLLM_LORA_INT_ID, + lora_path=VLLM_LORA_PATH, + peft_config=peft_config, + lora_tensors=weights, + ) + self.add_lora(lora_request) + logger.info(f"vLLM load weights, loaded_params: {len(weights)}") + else: + # Add the FP8 related logic here as sharding manager has been deprecated. 
+ # Check if FP8 quantization is enabled and apply appropriate weight loading + if is_fp8_model(self.model_runner.vllm_config): + logger.info(f"FP8 model detected (async): {self.model_runner.vllm_config.quant_config}") + # Convert bf16 weights to fp8 format before loading + loaded_params = load_quanted_weights(weights, self.model_runner) + logger.info(f"FP8 weights loaded (async), loaded_params: {len(loaded_params)}") + else: + logger.info("Loading standard weights (non-FP8, async)") + self.model_runner.model.load_weights(weights) + + def _get_zmq_handle(self) -> str: + """Get ZMQ handle for communication.""" + if not hasattr(self, "device_uuid") or not self.device_uuid: + self.device_uuid = get_device_uuid(self.device.index) + return f"ipc:///tmp/rl-colocate-zmq-{self.device_uuid}.sock" + + +class SuppressSignalInThread: + def __enter__(self): + self.original_signal = signal.signal + + def no_op_signal(sig, action): + if threading.current_thread() is not threading.main_thread(): + print(f"Ignored signal {sig} in thread {threading.current_thread().name}") + return + return self.original_signal(sig, action) + + signal.signal = no_op_signal + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + signal.signal = self.original_signal + + def build_cli_args_from_config(config: dict[str, Any]) -> list[str]: """ Convert a config dictionary to CLI arguments for vLLM server. diff --git a/verl/workers/rollout/vllm_rollout/vllm_async_server.py b/verl/workers/rollout/vllm_rollout/vllm_async_server.py index bb3c52ecb4e..ff68ef5cef4 100644 --- a/verl/workers/rollout/vllm_rollout/vllm_async_server.py +++ b/verl/workers/rollout/vllm_rollout/vllm_async_server.py @@ -17,20 +17,17 @@ import json import logging import os -from concurrent.futures import Future from pprint import pprint from typing import Any, Callable, Optional -from uuid import uuid4 -import cloudpickle as pickle import numpy as np import ray import vllm.entrypoints.cli.serve -import zmq from packaging import version from ray.actor import ActorHandle from vllm import SamplingParams from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.entrypoints.cli.serve import run_headless from vllm.entrypoints.openai.api_server import ( build_app, init_app_state, @@ -40,27 +37,22 @@ from vllm.outputs import RequestOutput from vllm.usage.usage_lib import UsageContext from vllm.v1.engine.async_llm import AsyncLLM -from vllm.v1.engine.core import EngineCoreProc -from vllm.v1.engine.utils import CoreEngineProcManager -from vllm.v1.executor.abstract import Executor from verl.single_controller.ray import RayClassWithInitArgs from verl.utils.config import omega_conf_to_dataclass +from verl.utils.device import get_resource_name, get_visible_devices_keyword +from verl.utils.net_utils import get_free_port, is_valid_ipv6_address from verl.utils.profiler.profile import DistProfiler from verl.utils.vllm.vllm_fp8_utils import apply_vllm_fp8_patches from verl.workers.config import HFModelConfig, RolloutConfig from verl.workers.rollout.replica import RolloutMode, RolloutReplica, TokenOutput -from verl.workers.rollout.utils import ( - get_free_port, - get_max_position_embeddings, - is_valid_ipv6_address, - run_unvicorn, -) -from verl.workers.rollout.vllm_rollout import vLLMAsyncRollout +from verl.workers.rollout.utils import get_max_position_embeddings, run_unvicorn +from verl.workers.rollout.vllm_rollout import ServerAdapter from verl.workers.rollout.vllm_rollout.utils import ( VLLM_LORA_INT_ID, VLLM_LORA_NAME, VLLM_LORA_PATH, + SuppressSignalInThread, 
build_cli_args_from_config, get_vllm_max_lora_rank, ) @@ -69,7 +61,6 @@ if _VLLM_VERSION > version.parse("0.11.0"): from vllm.utils.argparse_utils import FlexibleArgumentParser - from vllm.utils.network_utils import get_tcp_uri if _VLLM_VERSION == version.parse("0.12.0"): from vllm.entrypoints.harmony_utils import get_encoding @@ -80,101 +71,13 @@ get_encoding() else: - from vllm.utils import FlexibleArgumentParser, get_tcp_uri -if _VLLM_VERSION >= version.parse("0.12.0"): - from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput - from vllm.v1.outputs import ModelRunnerOutput + from vllm.utils import FlexibleArgumentParser + logger = logging.getLogger(__file__) logger.setLevel(logging.INFO) -class ExternalZeroMQDistributedExecutor(Executor): - """An executor that engines are launched by external ray actors.""" - - uses_ray: bool = False - - def _init_executor(self) -> None: - dp_rank_local = self.vllm_config.parallel_config.data_parallel_rank_local - tp_size = self.vllm_config.parallel_config.tensor_parallel_size - - addresses = os.environ["VERL_VLLM_ZMQ_ADDRESSES"].split(",") - addresses = addresses[dp_rank_local * tp_size : (dp_rank_local + 1) * tp_size] - self.context = zmq.Context() - self.sockets = [] - for address in addresses: - socket = self.context.socket(zmq.REQ) - if address.startswith("tcp://["): - socket.setsockopt(zmq.IPV6, 1) - socket.connect(address) - self.sockets.append(socket) - - kwargs = dict( - vllm_config=self.vllm_config, - local_rank=None, - rank=None, - distributed_init_method="env://", - is_driver_worker=True, - ) - self.collective_rpc("init_worker", args=([kwargs],)) - self.collective_rpc("init_device") - self.collective_rpc("load_model") - - if _VLLM_VERSION >= version.parse("0.12.0"): - - def execute_model( - self, scheduler_output: "SchedulerOutput", non_block: bool = False - ) -> "ModelRunnerOutput | None | Future[ModelRunnerOutput | None]": - output = self.collective_rpc("execute_model", args=(scheduler_output,)) - result = output[0] - if non_block: - f = Future() - f.set_result(result) - return f - return result - - def sample_tokens( - self, grammar_output: "GrammarOutput | None", non_block: bool = False - ) -> "ModelRunnerOutput | None | Future[ModelRunnerOutput | None]": - output = self.collective_rpc("sample_tokens", args=(grammar_output,)) - result = output[0] - if non_block: - f = Future() - f.set_result(result) - return f - return result - - def collective_rpc( - self, - method: str | Callable, - timeout: Optional[float] = None, - args: tuple = (), - kwargs: Optional[dict[str, Any]] = None, - **kwargs_extra: Any, - ) -> list[Any]: - if isinstance(method, str): - sent_method = method - else: - sent_method = pickle.dumps(method) - del method - - message = pickle.dumps((sent_method, args, kwargs or {})) - for socket in self.sockets: - socket.send(message, zmq.DONTWAIT) - - outputs = [] - for socket in self.sockets: - outputs.append(pickle.loads(socket.recv())) - - for output in outputs: - if isinstance(output, Exception): - raise output - return outputs - - def check_health(self): - return - - class vLLMHttpServer: """vLLM http server in single node, this is equivalent to launch server with command line: ``` @@ -192,6 +95,7 @@ def __init__( node_rank: int, gpus_per_node: int, nnodes: int, + cuda_visible_devices: str, ): """ Args: @@ -202,13 +106,16 @@ def __init__( node_rank (int): node rank. gpus_per_node (int): number of gpus per node. nnodes (int): number of nodes. + cuda_visible_devices (str): cuda visible devices. 
""" - super().__init__() + os.environ[get_visible_devices_keyword()] = cuda_visible_devices self.config: RolloutConfig = omega_conf_to_dataclass(config) self.model_config: HFModelConfig = omega_conf_to_dataclass(model_config, dataclass_type=HFModelConfig) - if self.config.max_model_len is None: - self.config.max_model_len = get_max_position_embeddings(self.model_config.hf_config) + self.config.max_model_len = min( + get_max_position_embeddings(self.model_config.hf_config), + self.config.prompt_length + self.config.response_length, + ) self.rollout_mode = rollout_mode self.workers = workers @@ -240,30 +147,58 @@ def __init__( # used for data parallel: --data-parallel-address, --data-parallel-rpc-port if self.node_rank == 0: self._master_address = self._server_address + # used for torch.distributed.init_process_group self._master_port, self._master_sock = get_free_port(self._server_address) + # used for data parallel: --data-parallel-address, --data-parallel-rpc-port + self._dp_rpc_port, self._dp_rpc_sock = get_free_port(self._server_address) self._dp_master_port, self._dp_master_sock = get_free_port(self._server_address) - logger.info( - f"vLLMHttpServer, replica_rank: {self.replica_rank}, master address: {self._master_address}, " - f"master port: {self._master_port}, data parallel master port: {self._dp_master_port}" - ) else: self._master_address = None self._master_port = None + self._dp_rpc_port = None + self._dp_master_port = None + + logger.info( + f"vLLMHttpServer, replica_rank: {self.replica_rank}, node_rank: {self.node_rank}, " + f"{get_visible_devices_keyword()}: {cuda_visible_devices}, " + f"master_address: {self._master_address}, master_port: {self._master_port}, " + f"data_parallel_rpc_port: {self._dp_rpc_port}, data_parallel_master_port: {self._dp_master_port}" + ) def get_master_address(self): - """Get master address and port for data parallel.""" - return self._master_address, self._master_port + """Get master address and port for data parallel. + Returns: + tuple: (master_address, master_port, dp_rpc_port) + """ + return self._master_address, self._master_port, self._dp_rpc_port def get_server_address(self): """Get http server address and port.""" assert self._server_port is not None, "http server is not launched, port is None" return self._server_address, self._server_port - async def launch_server(self, master_address: str = None, master_port: int = None): + async def collective_rpc( + self, + method: str | Callable, + timeout: float | None = None, + args: tuple = (), + kwargs: dict[str, Any] | None = None, + ): + await self.engine.collective_rpc( + method=method, + timeout=timeout, + args=args, + kwargs=kwargs, + ) + + async def launch_server(self, master_address: str = None, master_port: int = None, dp_rpc_port: int = None): if self.node_rank != 0: - assert master_address and master_port, "non-master node should provide master address and port" + assert master_address and master_port and dp_rpc_port, ( + "non-master node should provide master_address, master_port and dp_rpc_port" + ) self._master_address = master_address self._master_port = master_port + self._dp_rpc_port = dp_rpc_port # 1. setup vllm serve cli args engine_kwargs = self.config.get("engine_kwargs", {}).get("vllm", {}) or {} @@ -308,6 +243,8 @@ async def launch_server(self, master_address: str = None, master_port: int = Non # Apply vllm fp8 patches # Will remove the patch after vllm support on-the-fly quant for rollout natively. 
apply_vllm_fp8_patches() + # for subprocesses patching + os.environ["VERL_VLLM_FP8_QUANT_ENABLED"] = "1" hf_overrides = {} if quantization is not None and self.config.quantization_config_file is not None: @@ -320,6 +257,8 @@ async def launch_server(self, master_address: str = None, master_port: int = Non "dtype": self.config.dtype, "load_format": self.config.load_format, "skip_tokenizer_init": False, + "distributed_executor_backend": "mp", + "worker_extension_cls": "verl.workers.rollout.vllm_rollout.utils.vLLMColocateWorkerExtension", "trust_remote_code": self.model_config.trust_remote_code, "max_model_len": self.config.max_model_len, "max_num_seqs": self.config.max_num_seqs, @@ -338,6 +277,7 @@ async def launch_server(self, master_address: str = None, master_port: int = Non "quantization": quantization, "hf_overrides": hf_overrides, "scheduling_policy": self.config.scheduling_policy, + "compilation_config": json.dumps({"cudagraph_mode": "FULL_DECODE_ONLY"}), **engine_kwargs, } @@ -375,7 +315,20 @@ async def launch_server(self, master_address: str = None, master_port: int = Non "data_parallel_size_local": data_parallel_size_local, "data_parallel_start_rank": self.node_rank * data_parallel_size_local, "data_parallel_address": self._master_address, - "data_parallel_rpc_port": self._master_port, + "data_parallel_rpc_port": self._dp_rpc_port, + } + ) + + # used for torch.distributed.init_process_group + if self.nnodes > 1: + args.update( + { + "master_addr": self._master_address, + "master_port": self._master_port, + "node_rank": self.node_rank, + "nnodes": self.nnodes, + "data_parallel_address": self._master_address, + "data_parallel_rpc_port": self._dp_rpc_port, } ) @@ -411,21 +364,13 @@ async def launch_server(self, master_address: str = None, master_port: int = Non if server_args.subparser in cmds: cmds[server_args.subparser].validate(server_args) - # 2. setup distributed executor backend - distributed_executor_backend = ExternalZeroMQDistributedExecutor if len(self.workers) > 0 else None - server_args.distributed_executor_backend = distributed_executor_backend - - zmq_addresses = ray.get([worker.get_zeromq_address.remote() for worker in self.workers]) - logger.info( - f"replica_rank={self.replica_rank}, node_rank={self.node_rank}, nnodes={self.nnodes}, " - f"get worker zmq addresses: {zmq_addresses}" - ) - os.environ["VERL_VLLM_ZMQ_ADDRESSES"] = ",".join(zmq_addresses) - # 3. launch server if self.node_rank == 0: + self._master_sock.close() await self.run_server(server_args) else: + # TODO: avoid connect before master_sock close + await asyncio.sleep(3) await self.run_headless(server_args) async def run_server(self, args: argparse.Namespace): @@ -445,6 +390,9 @@ async def run_server(self, args: argparse.Namespace): # Don't keep the dummy data in memory await engine_client.reset_mm_cache() + await engine_client.collective_rpc( + method="monkey_patch_model", kwargs={"vocab_size": len(self.model_config.tokenizer)} + ) app = build_app(args) if _VLLM_VERSION > version.parse("0.11.0"): @@ -458,30 +406,26 @@ async def run_server(self, args: argparse.Namespace): self._server_port, self._server_task = await run_unvicorn(app, args, self._server_address) async def run_headless(self, args: argparse.Namespace): - # Create the EngineConfig. 
- engine_args = vllm.AsyncEngineArgs.from_cli_args(args) - usage_context = UsageContext.OPENAI_API_SERVER - vllm_config = engine_args.create_engine_config(usage_context=usage_context, headless=True) - - parallel_config = vllm_config.parallel_config - local_engine_count = parallel_config.data_parallel_size_local - - host = parallel_config.data_parallel_master_ip - port = engine_args.data_parallel_rpc_port # add to config too - handshake_address = get_tcp_uri(host, port) - - # Create the engines. - self.engine_manager = CoreEngineProcManager( - target_fn=EngineCoreProc.run_engine_core, - local_engine_count=local_engine_count, - start_index=vllm_config.parallel_config.data_parallel_rank, - local_start_index=0, - vllm_config=vllm_config, - local_client=False, - handshake_address=handshake_address, - executor_class=Executor.get_class(vllm_config), - log_stats=not engine_args.disable_log_stats, - ) + """Run headless server in a separate thread.""" + + def run_headless_wrapper(): + with SuppressSignalInThread(): + run_headless(args) + + def on_run_headless_done(future: asyncio.Future): + try: + exc = future.exception() + if exc: + logger.exception(f"run_headless failed with exception: {exc}") + else: + logger.warning("run_headless completed successfully, but it's not expected.") + except Exception as e: + logger.exception(f"get result from run_headless failed: {e}") + finally: + os._exit(1) + + self.task = asyncio.create_task(asyncio.to_thread(run_headless_wrapper)) + self.task.add_done_callback(on_run_headless_done) async def generate( self, @@ -716,7 +660,7 @@ async def abort_request(self, request_id: str, reset_prefix_cache: bool = True) return {"aborted": False, "request_id": request_id, "error": str(e)} -_rollout_worker_actor_cls = ray.remote(vLLMAsyncRollout) +_rollout_worker_actor_cls = ray.remote(ServerAdapter) class vLLMReplica(RolloutReplica): @@ -747,35 +691,47 @@ async def launch_servers(self): f"worker number {len(self.workers)} not equal to world size {self.world_size}" ) - # get node_id of all workers - worker_node_ids = await asyncio.gather( + # NOTE: We always use MP Executor backend whether it's single-node or multi-node. + # For multi-node without DP (e.g TP=16), need vllm>=0.11.1, https://github.com/vllm-project/vllm/pull/23691 + if self.config.data_parallel_size == 1 and self.nnodes > 1: + assert _VLLM_VERSION >= version.parse("0.11.1"), ( + "For multi-node MP Executor, either (1) set data_parallel_size > 1 or (2) upgrade vLLM to >= 0.11.1" + ) + + # get (node_id, CUDA_VISIBLE_DEVICES) of all workers + worker_infos = await asyncio.gather( *[ - worker.__ray_call__.remote(lambda self: ray.get_runtime_context().get_node_id()) + worker.__ray_call__.remote( + lambda self: ( + ray.get_runtime_context().get_node_id(), + ray.get_runtime_context().get_accelerator_ids()[get_resource_name()][0], + ) + ) for worker in self.workers ] ) + worker_cuda_visible_devices = [worker_info[1] for worker_info in worker_infos] + worker_node_ids = [worker_info[0] for worker_info in worker_infos] - # For non-data parallel case, there's only one server whether it's single or multi nodes. 
+ # create server actor in each node with node affinity and cuda visible devices nnodes, gpus_per_replica_node = self.nnodes, self.gpus_per_replica_node - if self.config.data_parallel_size == 1: - nnodes = 1 - gpus_per_replica_node = self.world_size - - # create server actor in each node with node affinity for node_rank in range(nnodes): workers = self.workers[node_rank * gpus_per_replica_node : (node_rank + 1) * gpus_per_replica_node] + node_cuda_visible_devices = ",".join( + worker_cuda_visible_devices[node_rank * gpus_per_replica_node : (node_rank + 1) * gpus_per_replica_node] + ) node_id = worker_node_ids[node_rank * gpus_per_replica_node] name = ( f"vllm_server_{self.replica_rank}_{node_rank}" if not self.is_reward_model else f"vllm_server_reward_{self.replica_rank}_{node_rank}" ) - name = name + f"_{uuid4().hex[:8]}" server = self.server_class.options( scheduling_strategy=ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy( node_id=node_id, soft=False, ), + runtime_env={"env_vars": {"RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES": "1"}}, name=name, ).remote( config=self.config, @@ -786,14 +742,17 @@ async def launch_servers(self): node_rank=node_rank, gpus_per_node=gpus_per_replica_node, nnodes=nnodes, + cuda_visible_devices=node_cuda_visible_devices, ) self.servers.append(server) # launch http server in each node - master_address, master_port = await self.servers[0].get_master_address.remote() + master_address, master_port, dp_rpc_port = await self.servers[0].get_master_address.remote() await asyncio.gather( *[ - server.launch_server.remote(master_address=master_address, master_port=master_port) + server.launch_server.remote( + master_address=master_address, master_port=master_port, dp_rpc_port=dp_rpc_port + ) for server in self.servers ] ) diff --git a/verl/workers/rollout/vllm_rollout/vllm_rollout.py b/verl/workers/rollout/vllm_rollout/vllm_rollout.py index 92b5e94d650..ebbb6e19e48 100644 --- a/verl/workers/rollout/vllm_rollout/vllm_rollout.py +++ b/verl/workers/rollout/vllm_rollout/vllm_rollout.py @@ -26,63 +26,29 @@ - After inference, all the parameters that doesn't belong to this pp rank is freed. 
""" -import getpass +import gc import logging import os -from dataclasses import asdict -from types import MethodType -from typing import Any, Generator +import time +from typing import Any, Generator, Optional -import cloudpickle as pickle import ray import torch -import torch.distributed import zmq -import zmq.asyncio -from filelock import FileLock -from torch.distributed.device_mesh import DeviceMesh -from vllm.config import LoRAConfig - -from verl.utils.ray_utils import get_event_loop - -try: - from vllm.worker.worker_base import WorkerWrapperBase -except ModuleNotFoundError: - # https://github.com/vllm-project/vllm/commit/6a113d9aed8221a9c234535958e70e34ab6cac5b - from vllm.v1.worker.worker_base import WorkerWrapperBase - from packaging import version as vs +from torch.distributed.device_mesh import DeviceMesh +from torch.multiprocessing.reductions import reduce_tensor from verl import DataProto from verl.third_party.vllm import VLLM_SLEEP_LEVEL, get_version -from verl.utils.device import is_npu_available -from verl.utils.distributed import initialize_global_process_group_ray -from verl.utils.ray_utils import ray_noset_visible_devices -from verl.utils.vllm import TensorLoRARequest, VLLMHijack, is_version_ge -from verl.utils.vllm.vllm_fp8_utils import apply_vllm_fp8_patches, is_fp8_model, load_quanted_weights +from verl.utils.device import get_device_id, get_device_name, get_torch_device +from verl.utils.torch_dtypes import PrecisionType from verl.workers.config import HFModelConfig, RolloutConfig from verl.workers.rollout.base import BaseRollout -from verl.workers.rollout.utils import get_free_port, is_valid_ipv6_address -from verl.workers.rollout.vllm_rollout.utils import ( - VLLM_LORA_INT_ID, - VLLM_LORA_NAME, - VLLM_LORA_PATH, - get_vllm_max_lora_rank, -) +from verl.workers.rollout.vllm_rollout.utils import TensorMetadata, get_device_uuid logger = logging.getLogger(__file__) -logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) - -VLLM_ASCEND_REQUIRED_ENV_VARS = {"VLLM_ALL2ALL_BACKEND": "flashinfer_all2allv", "VLLM_ASCEND_ENABLE_NZ": "0"} - -# TODO -# 1. support pp in vllm -# 2. passing tokenizer is not necessary? no encoding/decoding is happending here -# 3. simplify init logics - - -if is_version_ge(pkg="vllm", minver="0.7.3"): - VLLMHijack.hijack() +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "INFO")) def _check_vllm_version_for_sleep_level(): @@ -95,24 +61,11 @@ def _check_vllm_version_for_sleep_level(): return vs.parse(current_version) >= vs.parse(minver) -# https://github.com/vllm-project/vllm/issues/13175 -def _monkey_patch_compute_logits(model, vocab_size: int): - original_compute_logits = model.compute_logits - - def compute_logits( - self, - *args, - **kwargs, - ) -> torch.Tensor: - logits = original_compute_logits(*args, **kwargs) - logits[..., vocab_size:] = float("-inf") - return logits - - model.compute_logits = MethodType(compute_logits, model) - - -class vLLMAsyncRollout(BaseRollout): - """vLLMAsyncRollout is a thin wrapper of WorkerWrapperBase, which is engine in single worker process.""" +class ServerAdapter(BaseRollout): + """ + vLLM server adapter used in native async mode, serve as a client to request vLLM server + to resume/release/update weights and kv_cache. 
+ """ def __init__( self, @@ -121,14 +74,18 @@ def __init__( device_mesh: DeviceMesh, ): super().__init__(config, model_config, device_mesh) - self.tokenizer = self.model_config.tokenizer - self.inference_engine: WorkerWrapperBase = None - self.address = self._init_zeromq() - self.lora_config = ( - {"max_loras": 1, "max_lora_rank": get_vllm_max_lora_rank(self.model_config.lora_rank)} - if self.model_config.lora_rank > 0 - else {} + self.server_handle: ray.actor.ActorHandle = None + + rank = int(os.environ["RANK"]) + local_world_size = int(os.environ["RAY_LOCAL_WORLD_SIZE"]) + rollout_world_size = ( + self.config.tensor_model_parallel_size + * self.config.data_parallel_size + * self.config.pipeline_model_parallel_size ) + self.replica_rank = rank // rollout_world_size + self.rollout_rank = rank % rollout_world_size + self.node_rank = self.rollout_rank // local_world_size if config.layered_summon or (config.expert_parallel_size > 1 and not _check_vllm_version_for_sleep_level()): logger.warning("Setting the sleep level to 1 may cause a memory overflow.") @@ -136,109 +93,39 @@ def __init__( else: self.sleep_level = VLLM_SLEEP_LEVEL - def _init_zeromq(self) -> str: - tensor_parallel_size = self.config.tensor_model_parallel_size - - # single node: ipc, multi nodes: tcp - local_world_size = int(os.environ["RAY_LOCAL_WORLD_SIZE"]) - socket_type = "ipc" if tensor_parallel_size <= local_world_size else "tcp" - - # File lock to prevent multiple workers listen to same port - with FileLock(f"/tmp/verl_vllm_zmq_{getpass.getuser()}.lock"): - context = zmq.asyncio.Context() - self.socket = context.socket(zmq.REP) - if socket_type == "ipc": - pid = os.getpid() - address = f"ipc:///tmp/verl_vllm_zmq_{pid}_{getpass.getuser()}.ipc" - else: - ip = ray.util.get_node_ip_address().strip("[]") - port, sock = get_free_port(ip) - if is_valid_ipv6_address(ip): - address = f"tcp://[{ip}]:{port}" - self.socket.setsockopt(zmq.IPV6, 1) - else: - address = f"tcp://{ip}:{port}" - self.socket.bind(address) - - loop = get_event_loop() - self.zmq_loop_task = loop.create_task(self._loop_forever()) + self.device_uuid = get_device_uuid(get_device_id()) + self.zmq_context = zmq.Context() + self.zmq_handle = f"ipc:///tmp/rl-colocate-zmq-{self.device_uuid}.sock" - return address - - async def _loop_forever(self): - while True: - try: - message = await self.socket.recv() - method, args, kwargs = pickle.loads(message) - result = await self._execute_method(method, *args, **kwargs) - await self.socket.send(pickle.dumps(result)) - except Exception as e: - logger.exception(f"vLLMAsyncRollout _loop_forever error: {e}") - await self.socket.send(pickle.dumps(e)) - break - - def _build_inference_engine(self) -> WorkerWrapperBase: - """Create a vLLM worker wrapper across vLLM versions. - - vLLM changed WorkerWrapperBase signature (i.e., removing vllm_config from - __init__). We keep a small runtime fallback to support multiple versions. + async def _execute_method( + self, + method: str, + non_block: bool = False, + timeout: Optional[float] = None, + args: tuple = (), + kwargs: Optional[dict] = None, + ) -> Any: + """Execute method on inference engine via ray. - https://github.com/vllm-project/vllm/commit/aafd4d23548ae54adeca1d4898cc15a4d2c390ac + Args: + method: The method name to execute on the server. + non_block: If True, execute the method asynchronously and return immediately. + timeout: Timeout for the collective_rpc call. + args: Positional arguments for the method. + kwargs: Keyword arguments for the method. 
+ + Returns: + The result of the method execution, or None if non_block=True. """ - try: - return WorkerWrapperBase(vllm_config=self.vllm_config) - except TypeError: - return WorkerWrapperBase() + if self.rollout_rank != 0: + return None - def _init_worker(self, all_kwargs: list[dict[str, Any]]): - """Initialize worker engine.""" - # TODO: For ascend NPU, when the corresponding vllm-ascend version is upgraded to v0.13.0, - # please remove the VLLM_ASCEND_REQUIRED_ENV_VARS variable replacement action. - # This is only a fix for vllm version < v0.13.0. - if is_npu_available: - for k in VLLM_ASCEND_REQUIRED_ENV_VARS: - if k not in os.environ: - os.environ[k] = VLLM_ASCEND_REQUIRED_ENV_VARS[k] + # Lazy init http server adapter because http server is launched after hybrid engine. + if self.server_handle is None: + self.server_handle = ray.get_actor(f"vllm_server_{self.replica_rank}_{self.node_rank}") - if not torch.distributed.is_initialized(): - initialize_global_process_group_ray() - all_kwargs[0]["rank"] = int(os.environ["RANK"]) - device_name = "NPU" if is_npu_available else "GPU" - all_kwargs[0]["local_rank"] = ( - 0 - if not ray_noset_visible_devices() - else int(ray.get_runtime_context().get_accelerator_ids()[device_name][0]) - ) - self.vllm_config = all_kwargs[0]["vllm_config"] - if self.lora_config: - lora_dtype = getattr(torch, self.config.dtype) - self.vllm_config.lora_config = LoRAConfig(lora_dtype=lora_dtype, **self.lora_config) - if self.config.quantization is not None: - _SUPPORTED_QUANTIZATION = ["fp8", "torchao"] - if self.config.quantization not in _SUPPORTED_QUANTIZATION: - raise ValueError( - f"Currently only support {_SUPPORTED_QUANTIZATION} quantization, got: {self.config.quantization}" - ) - - if self.config.quantization == "fp8": - # Apply vllm fp8 patches - # Will remove the patch after vllm support on-the-fly quant for rollout natively. - apply_vllm_fp8_patches() - - self.inference_engine = self._build_inference_engine() - self.inference_engine.init_worker(all_kwargs) - - def _load_model(self, *args, **kwargs): - self.inference_engine.load_model(*args, **kwargs) - _monkey_patch_compute_logits(self.inference_engine.worker.model_runner.model, len(self.tokenizer)) - - async def _execute_method(self, method: str | bytes, *args, **kwargs): - if method == "init_worker": - return self._init_worker(*args, **kwargs) - elif method == "load_model": - return self._load_model(*args, **kwargs) - else: - return self.inference_engine.execute_method(method, *args, **kwargs) + future = self.server_handle.collective_rpc.remote(method, timeout=timeout, args=args, kwargs=kwargs) + return future if non_block else await future async def resume(self, tags: list[str]): """Resume rollout weights or kv cache in GPU memory. @@ -247,55 +134,84 @@ async def resume(self, tags: list[str]): tags: weights or kv_cache. """ if self.config.free_cache_engine: - self.inference_engine.wake_up(tags=tags) + await self._execute_method("wake_up", kwargs={"tags": tags}) async def release(self): """Release weights and kv cache in GPU memory.""" if self.config.free_cache_engine: - self.inference_engine.sleep(level=self.sleep_level) + await self._execute_method("sleep", kwargs={"level": self.sleep_level}) + @torch.no_grad() async def update_weights(self, weights: Generator[tuple[str, torch.Tensor], None, None], **kwargs): - """Update the weights of the rollout model. 
+ """Update model weights via CUDA IPC to inference workers.""" + start_time = time.time() + future = await self._execute_method( + "update_weights_from_ipc", + non_block=True, + kwargs=kwargs, + ) - Args: - weights: A generator that yields the name of the weight tensor and the tensor itself. - """ - peft_config, base_sync_done = kwargs.get("peft_config", None), kwargs.get("base_sync_done", False) - if peft_config and base_sync_done: - # In async mode, make sure the old lora is removed before adding the new one - self.inference_engine.worker.remove_lora(VLLM_LORA_INT_ID) - weights = dict(weights) - lora_request = TensorLoRARequest( - lora_name=VLLM_LORA_NAME, - lora_int_id=VLLM_LORA_INT_ID, - lora_path=VLLM_LORA_PATH, - peft_config=asdict(peft_config), - lora_tensors=weights, + # build cuda ipc buffer + bucket_size_mb = self.config.update_weights_bucket_megabytes + bucket_size = int(bucket_size_mb) << 20 + buffer = torch.empty(bucket_size, dtype=torch.uint8, device=f"{get_device_name()}:0") + handle = reduce_tensor(buffer) + s = self.zmq_context.socket(zmq.REQ) + s.bind(self.zmq_handle) + s.send_pyobj(handle) + s.recv() + + # send bucket weights + offset = 0 + bucket_meta: dict[str, TensorMetadata] = {} + dtype = PrecisionType.to_dtype(self.config.dtype) + for name, weight in weights: + # model parameters are in fp32 full precision + weight = weight.to(dtype, non_blocking=True) + + # fill the tensor bucket + if offset + weight.nbytes > bucket_size: + get_torch_device().synchronize() + s.send_pyobj({"bucket_meta": bucket_meta, "is_last": False}) + s.recv() + bucket_meta = {} + offset = 0 + + # TODO: slice embedding layer weight into chunks + assert offset + weight.nbytes <= bucket_size, ( + f"Weight {name}({weight.shape}, {weight.dtype}) is too large to fit in the bucket." + f"Please increase rollout.update_weights_bucket_megabytes({bucket_size_mb} MB)." ) - self.inference_engine.worker.add_lora(lora_request) - logger.info(f"vLLM load weights, loaded_params: {len(weights)}") - else: - from verl.utils.vllm.patch import patch_vllm_moe_model_weight_loader - - model_runner = self.inference_engine.worker.model_runner - model = model_runner.model - patch_vllm_moe_model_weight_loader(model) - - # Add the FP8 related logic here as sharding manager has been deprecated. 
- # Check if FP8 quantization is enabled and apply appropriate weight loading - if is_fp8_model(model_runner.vllm_config): - logger.info(f"FP8 model detected (async): {model_runner.vllm_config.quant_config}") - # Convert bf16 weights to fp8 format before loading - loaded_params = load_quanted_weights(weights, model_runner) - logger.info(f"FP8 weights loaded (async), loaded_params: {len(loaded_params)}") - else: - logger.info("Loading standard weights (non-FP8, async)") - model.load_weights(weights) + bucket_meta[name] = { + "name": name, + "shape": weight.shape, + "dtype": weight.dtype, + "offset": offset, + } + buffer[offset : offset + weight.nbytes].copy_(weight.view(-1).view(torch.uint8), non_blocking=True) + offset += weight.nbytes + + # send the last bucket + get_torch_device().synchronize() + s.send_pyobj({"bucket_meta": bucket_meta, "is_last": True}) + s.recv() + + # clean up + s.close() + del buffer + gc.collect() + get_torch_device().ipc_collect() + get_torch_device().empty_cache() + if future is not None: + await future + + if self.replica_rank == 0 and self.rollout_rank == 0: + logger.info(f"update_weights done, time cost: {time.time() - start_time:.2f}s") def generate_sequences(self, prompts: DataProto) -> DataProto: """Batch generate sequences in sync mode. - Note: vLLMAsyncRollout uses async server mode and does not support synchronous + Note: ServerAdapter uses async server mode and does not support synchronous generation. Since SPMD mode was retired (PR #4411), the generation workflow should use the async server interface instead. @@ -303,14 +219,9 @@ def generate_sequences(self, prompts: DataProto) -> DataProto: NotImplementedError: Always raised as sync generation is not supported. """ raise NotImplementedError( - "vLLMAsyncRollout does not support synchronous generate_sequences(). " + "ServerAdapter does not support synchronous generate_sequences(). " "The vLLM SPMD mode was retired in PR #4411. For batch generation, " "please use the async server interface via vLLMReplica and AsyncLLMServerManager, " "or use HFRollout for synchronous generation. " "See https://github.com/volcengine/verl/issues/4682 for more details." ) - - # ==================== server mode public methods ==================== - - def get_zeromq_address(self): - return self.address
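The launch_servers change above pins one named vLLM server actor per node through NodeAffinitySchedulingStrategy, sets RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES so Ray does not rewrite CUDA_VISIBLE_DEVICES for the actor (the explicit node_cuda_visible_devices string is passed in instead), and relies on the fixed actor name so ServerAdapter can later fetch the handle with ray.get_actor. A minimal, self-contained sketch of that placement pattern follows; the toy actor and names are illustrative, not verl's:

# Illustrative sketch (not verl code): create one named actor per alive node
# with hard node affinity, keep Ray from rewriting CUDA_VISIBLE_DEVICES, and
# fetch the actor later by name from another process.
import ray
from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy


@ray.remote
class EchoServer:
    def ping(self) -> str:
        return ray.get_runtime_context().get_node_id()


ray.init()
servers = []
alive_nodes = [node for node in ray.nodes() if node["Alive"]]
for node_rank, node in enumerate(alive_nodes):
    server = EchoServer.options(
        scheduling_strategy=NodeAffinitySchedulingStrategy(node_id=node["NodeID"], soft=False),
        # let the actor manage device visibility itself instead of Ray
        runtime_env={"env_vars": {"RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES": "1"}},
        name=f"echo_server_{node_rank}",
    ).remote()
    servers.append(server)  # keep handles alive so named lookup keeps working

# any process in the same Ray job/namespace can now look the actor up by name
handle = ray.get_actor("echo_server_0")
print(ray.get(handle.ping.remote()))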
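ServerAdapter.update_weights streams training weights through a fixed-size uint8 staging buffer that is shared with the vLLM server over CUDA IPC (reduce_tensor) and coordinated over a per-device ZMQ REQ/REP socket: tensors are copied into the buffer back to back, and a bucket of metadata is flushed whenever the next tensor would exceed update_weights_bucket_megabytes. The sketch below isolates just that bucketing arithmetic as a CPU-only toy, with no CUDA IPC or ZMQ, so the flush condition is easy to verify; all names are illustrative:

# Simplified sketch of the bucketing logic in ServerAdapter.update_weights:
# pack named tensors into a fixed-size byte buffer and emit (metadata, bytes)
# whenever the next tensor would not fit. CPU-only toy, names illustrative.
from typing import Iterator

import torch


def pack_into_buckets(
    weights: Iterator[tuple[str, torch.Tensor]], bucket_size: int
) -> Iterator[tuple[dict, torch.Tensor]]:
    buffer = torch.empty(bucket_size, dtype=torch.uint8)
    bucket_meta, offset = {}, 0
    for name, weight in weights:
        flat = weight.contiguous().view(-1).view(torch.uint8)  # raw bytes of the tensor
        if bucket_meta and offset + flat.nbytes > bucket_size:
            yield bucket_meta, buffer[:offset].clone()  # flush the full bucket
            bucket_meta, offset = {}, 0
        assert offset + flat.nbytes <= bucket_size, f"{name} does not fit in a single bucket"
        bucket_meta[name] = {"shape": tuple(weight.shape), "dtype": weight.dtype, "offset": offset}
        buffer[offset : offset + flat.nbytes].copy_(flat)
        offset += flat.nbytes
    if bucket_meta:
        yield bucket_meta, buffer[:offset].clone()  # last (partial) bucket


# toy usage: two small fp32 tensors packed into one 64-byte bucket
params = [("a", torch.randn(4)), ("b", torch.randn(8))]
for meta, payload in pack_into_buckets(iter(params), bucket_size=64):
    print({k: v["offset"] for k, v in meta.items()}, payload.nbytes)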
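The matching server-side handler (update_weights_from_ipc) is not part of this diff, so the following is only a hedged guess at the consumer's shape: rebuild the shared buffer from the reduce_tensor handle, then view each named slice according to the bucket metadata. The metadata fields mirror the sender side above; everything else, including the function name, is an assumption:

# Hypothetical consumer-side sketch (the real update_weights_from_ipc handler
# is not shown in this diff). A reduce_tensor() handle is a (rebuild_fn, args)
# pair; calling rebuild_fn(*args) in the receiving process maps the same CUDA
# allocation, after which each weight is just a typed view into the buffer.
import torch


def rebuild_bucket(handle, bucket_meta: dict) -> dict[str, torch.Tensor]:
    rebuild_fn, args = handle
    buffer = rebuild_fn(*args)  # shared uint8 buffer, no copy
    tensors = {}
    for name, meta in bucket_meta.items():
        numel = 1
        for dim in meta["shape"]:
            numel *= dim
        nbytes = numel * torch.empty((), dtype=meta["dtype"]).element_size()
        chunk = buffer[meta["offset"] : meta["offset"] + nbytes]
        # assumes each offset is a multiple of the dtype's element size,
        # which holds when the sender packs same-dtype weights back to back
        tensors[name] = chunk.view(meta["dtype"]).view(meta["shape"])
    return tensors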