From d4773b8100f97d6c08edba8bdcdd8f72bcc21153 Mon Sep 17 00:00:00 2001 From: wuxibin Date: Thu, 17 Apr 2025 19:05:18 +0800 Subject: [PATCH 01/10] feat: introduce vLLM AsyncLLM to support multi-turn rollout --- examples/ppo_trainer/naive_chat_scheduler.py | 72 ++++ .../ppo_trainer/run_qwen2-7b_seq_balance.sh | 10 + recipe/dapo/src/dapo_ray_trainer.py | 2 +- recipe/dapo/src/main_dapo.py | 5 +- recipe/prime/main_prime.py | 8 +- recipe/prime/prime_ray_trainer.py | 2 +- tests/rollout/test_vllm_multi_turn.py | 149 ++++++++ verl/single_controller/base/decorator.py | 11 + .../base/register_center/ray.py | 10 + verl/single_controller/base/worker.py | 7 + verl/single_controller/ray/base.py | 67 +++- verl/trainer/config/ppo_trainer.yaml | 2 + verl/trainer/main_ppo.py | 8 +- verl/trainer/ppo/ray_trainer.py | 31 +- verl/workers/fsdp_async_workers.py | 326 ++++++++++++++++++ verl/workers/fsdp_workers.py | 27 +- verl/workers/rollout/chat_scheduler.py | 113 ++++++ verl/workers/rollout/vllm_rollout/__init__.py | 2 +- .../rollout/vllm_rollout/vllm_rollout_spmd.py | 59 +++- verl/workers/sharding_manager/fsdp_vllm.py | 8 +- 20 files changed, 870 insertions(+), 49 deletions(-) create mode 100644 examples/ppo_trainer/naive_chat_scheduler.py create mode 100644 tests/rollout/test_vllm_multi_turn.py create mode 100644 verl/workers/fsdp_async_workers.py create mode 100644 verl/workers/rollout/chat_scheduler.py diff --git a/examples/ppo_trainer/naive_chat_scheduler.py b/examples/ppo_trainer/naive_chat_scheduler.py new file mode 100644 index 00000000000..ee81b52d22b --- /dev/null +++ b/examples/ppo_trainer/naive_chat_scheduler.py @@ -0,0 +1,72 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import asyncio +from typing import Any, Dict, List + +from omegaconf import DictConfig +from openai.types.chat.chat_completion import ChatCompletion + +from verl.protocol import DataProto +from verl.workers.rollout.chat_scheduler import ChatCompletionScheduler + + +class NaiveChatCompletionScheduler(ChatCompletionScheduler): + + def __init__(self, config: DictConfig, model_path: str, server_addresses: List[str], max_cache_size: int = 10000): + super().__init__(config, model_path, server_addresses, max_cache_size) + + async def generate_sequences(self, prompts: DataProto, **sampling_params) -> DataProto: + kwargs = dict( + n=self.config.n, + max_completion_tokens=self.config.response_length, + temperature=self.config.temperature, + top_p=self.config.top_p, + ) + + do_sample = prompts.meta_info.get('do_sample', True) + is_validate = prompts.meta_info.get('validate', False) + if not do_sample or is_validate: + kwargs["n"] = 1 + kwargs["temperature"] = 0 + + kwargs.update(sampling_params) + print(f"[NaiveChatCompletionScheduler] generate_sequences sampling params: {kwargs}") + + async def callback(completions: ChatCompletion, info: Dict[str, Any]): + info["all_completions"][info["index"]] = completions + + # NOTE: we can call tools and resubmit chat completions here. 
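+            # A multi-turn scheduler would typically inspect completions.choices[0].message,
+            # run any requested tool calls, append the tool results to the message list, and
+            # resubmit with a follow-up callback, e.g. (call_tools/callback2 are hypothetical):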
+ # call_tools(completions, info) + # await self.submit_chat_completions(callback2, ...) + + tasks, all_completions = [], [None] * len(prompts) + for i, prompt in enumerate(prompts.non_tensor_batch["raw_prompt"]): + # raw_prompt: [{"role": "user", "content": ""}, ["role": "assistant", "content"], ...] + tasks.append( + asyncio.create_task( + self.submit_chat_completions( + callback=callback, + callback_additional_info={ + "all_completions": all_completions, + "index": i + }, + model=self.model_name, + messages=prompt, + **kwargs, + ))) + await asyncio.gather(*tasks) + + print("[NaiveChatCompletionScheduler] generate_sequences done") + # TODO: completions => DataProto + return all_completions diff --git a/examples/ppo_trainer/run_qwen2-7b_seq_balance.sh b/examples/ppo_trainer/run_qwen2-7b_seq_balance.sh index a5824a542d5..abc490acdc7 100644 --- a/examples/ppo_trainer/run_qwen2-7b_seq_balance.sh +++ b/examples/ppo_trainer/run_qwen2-7b_seq_balance.sh @@ -8,10 +8,18 @@ math_test_path=$HOME/data/math/test.parquet train_files="['$gsm8k_train_path', '$math_train_path']" test_files="['$gsm8k_test_path', '$math_test_path']" +# For async rollout mode, dataset should return raw chat. +rollout_mode="sync" +if [ "$rollout_mode" = "async" ]; then + return_raw_chat="True" + chat_scheduler=examples.ppo_trainer.naive_chat_scheduler.NaiveChatCompletionScheduler +fi + python3 -m verl.trainer.main_ppo \ algorithm.adv_estimator=gae \ data.train_files="$train_files" \ data.val_files="$test_files" \ + data.return_raw_chat=$return_raw_chat \ data.train_batch_size=4096 \ data.max_prompt_length=4096 \ data.max_response_length=4096 \ @@ -29,6 +37,8 @@ python3 -m verl.trainer.main_ppo \ actor_rollout_ref.actor.use_kl_loss=False \ actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.mode=$rollout_mode \ + actor_rollout_ref.rollout.chat_scheduler=$chat_scheduler \ actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \ actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=24000 \ critic.optim.lr=1e-5 \ diff --git a/recipe/dapo/src/dapo_ray_trainer.py b/recipe/dapo/src/dapo_ray_trainer.py index 20294662dda..76a48cc57ff 100644 --- a/recipe/dapo/src/dapo_ray_trainer.py +++ b/recipe/dapo/src/dapo_ray_trainer.py @@ -40,7 +40,7 @@ class RayDAPOTrainer(RayPPOTrainer): Note that this trainer runs on the driver process on a single CPU/GPU node. """ - def fit(self): + async def fit(self): """ The training loop of PPO. The driver process only need to call the compute functions of the worker group through RPC diff --git a/recipe/dapo/src/main_dapo.py b/recipe/dapo/src/main_dapo.py index c7eaebdc80b..152d61c5441 100644 --- a/recipe/dapo/src/main_dapo.py +++ b/recipe/dapo/src/main_dapo.py @@ -76,7 +76,8 @@ def run_ppo(config) -> None: @ray.remote(num_cpus=1) # please make sure main_task is not scheduled on head class TaskRunner: - def run(self, config): + + async def run(self, config): # print initial config from pprint import pprint @@ -201,7 +202,7 @@ def run(self, config): val_reward_fn=val_reward_fn, ) trainer.init_workers() - trainer.fit() + await trainer.fit() if __name__ == "__main__": diff --git a/recipe/prime/main_prime.py b/recipe/prime/main_prime.py index 5f912e374de..f34fe8b6607 100644 --- a/recipe/prime/main_prime.py +++ b/recipe/prime/main_prime.py @@ -29,6 +29,8 @@ Note that we don't combine the main with ray_trainer as ray_trainer is used by other main. 
""" +import asyncio + import hydra import ray @@ -53,6 +55,10 @@ def run_prime(config, compute_score=None): @ray.remote(num_cpus=1) # please make sure main_task is not scheduled on head def main_task(config, compute_score=None): + asyncio.run(_main_task(config, compute_score)) + + +async def _main_task(config, compute_score=None): # print initial config from pprint import pprint @@ -142,7 +148,7 @@ def main_task(config, compute_score=None): val_reward_fn=val_reward_fn, ) trainer.init_workers() - trainer.fit() + await trainer.fit() if __name__ == "__main__": diff --git a/recipe/prime/prime_ray_trainer.py b/recipe/prime/prime_ray_trainer.py index a21416e199a..873660105c5 100644 --- a/recipe/prime/prime_ray_trainer.py +++ b/recipe/prime/prime_ray_trainer.py @@ -331,7 +331,7 @@ def _load_checkpoint(self): if isinstance(self.train_dataloader.dataset, RLHFDataset): self.train_dataloader.dataset.resume_dataset_state() - def fit(self): + async def fit(self): """ The training loop of PPO. The driver process only need to call the compute functions of the worker group through RPC to construct the PPO dataflow. diff --git a/tests/rollout/test_vllm_multi_turn.py b/tests/rollout/test_vllm_multi_turn.py new file mode 100644 index 00000000000..79d902ad8cf --- /dev/null +++ b/tests/rollout/test_vllm_multi_turn.py @@ -0,0 +1,149 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +from typing import Any, Dict + +import ray +from omegaconf import OmegaConf +from openai.types.chat.chat_completion import ChatCompletion + +from verl.single_controller.ray import RayClassWithInitArgs, RayWorkerGroup +from verl.single_controller.ray.base import Worker, create_colocated_worker_cls +from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role +from verl.workers.fsdp_async_workers import AsyncActorRolloutRefWorker, AsyncLLMManager +from verl.workers.rollout.chat_scheduler import ChatCompletionScheduler + + +async def test_vllm_multi_turn(): + config = OmegaConf.load("verl/trainer/config/ppo_trainer.yaml") + model_path = "Qwen/Qwen2-7B-Instruct" + model_name = "/".join(model_path.split("/")[-2:]) + config.actor_rollout_ref.model.path = model_path + config.actor_rollout_ref.rollout.mode = "async" + config.actor_rollout_ref.rollout.prompt_length = 4096 + config.actor_rollout_ref.rollout.response_length = 4096 + + # =========================== 1. 
Create hybrid ActorRollout workers =========================== + ray.init( + runtime_env={ + 'env_vars': { + 'TOKENIZERS_PARALLELISM': 'true', + 'NCCL_DEBUG': 'WARN', + 'VLLM_LOGGING_LEVEL': 'WARN', + 'VLLM_USE_V1': '1', + } + }) + role_worker_mapping = { + Role.ActorRollout: ray.remote(AsyncActorRolloutRefWorker), + } + global_pool_id = 'global_pool' + resource_pool_spec = { + global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes, + } + mapping = { + Role.ActorRollout: global_pool_id, + } + resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping) + resource_pool_manager.create_resource_pool() + resource_pool_to_cls = {pool: {} for pool in resource_pool_manager.resource_pool_dict.values()} + + # create actor and rollout + resource_pool = resource_pool_manager.get_resource_pool(Role.ActorRollout) + actor_rollout_cls = RayClassWithInitArgs(cls=role_worker_mapping[Role.ActorRollout], + config=config.actor_rollout_ref, + role='actor_rollout') + resource_pool_to_cls[resource_pool]['actor_rollout'] = actor_rollout_cls + + all_wg = {} + wg_dicts = [] + for resource_pool, class_dict in resource_pool_to_cls.items(): + worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict, worker_cls=Worker) + wg_dict = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=worker_dict_cls) + spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys()) + all_wg.update(spawn_wg) + wg_dicts.append(wg_dict) + actor_rollout_wg = all_wg['actor_rollout'] + actor_rollout_wg.init_model() + + # =========================== 2. Create AsyncLLMManager&ChatScheduler =========================== + async_rollout_manager = AsyncLLMManager( + config=config.actor_rollout_ref, + worker_group=actor_rollout_wg, + ) + + async_chat_scheduler = ChatCompletionScheduler( + config=config.actor_rollout_ref.rollout, + model_path=config.actor_rollout_ref.model.path, + server_addresses=async_rollout_manager.server_addresses, + ) + + # =========================== 3. Multi turn rollout =========================== + async def callback(completions: ChatCompletion, info: Dict[str, Any]): + messages, round = info["messages"], info["round"] + message = completions.choices[0].message + messages.append({"role": message.role, "content": message.content}) + print(f"[round={round}] role: {message.role}, content: {message.content}") + + extra_headers = {"x-request-id": completions.id} + if round == 0: + messages.append({"role": "user", "content": "What is your name?"}) + await async_chat_scheduler.submit_chat_completions( + callback=callback, + callback_additional_info={ + "messages": messages, + "round": 1 + }, + model=model_name, + messages=messages, + extra_headers=extra_headers, + ) + elif round == 1: + messages.append({"role": "user", "content": "What is your favorite color?"}) + await async_chat_scheduler.submit_chat_completions( + callback=callback, + callback_additional_info={ + "messages": messages, + "round": 2 + }, + model=model_name, + messages=messages, + extra_headers=extra_headers, + ) + else: + print("Done!") + + messages = [{ + "role": "user", + "content": "Let's play a role playing game. Your name is Bob, your favorite color is red." 
+ }] + await async_chat_scheduler.submit_chat_completions( + callback=callback, + callback_additional_info={ + "messages": messages, + "round": 0 + }, + model=model_name, + messages=messages, + ) + assert len(messages) == 6 + for round, message in enumerate(messages): + if round % 2 == 0: + assert message["role"] == "user" + else: + assert message["role"] == "assistant" + + +if __name__ == "__main__": + asyncio.run(test_vllm_multi_turn()) diff --git a/verl/single_controller/base/decorator.py b/verl/single_controller/base/decorator.py index 9d95c77d485..f5856212c82 100644 --- a/verl/single_controller/base/decorator.py +++ b/verl/single_controller/base/decorator.py @@ -37,6 +37,9 @@ class Dispatch(Enum): DP_COMPUTE_PROTO_WITH_FUNC = 10 DP_COMPUTE_METRIC = 11 + # This is a special dispatch mode for vllm ExternalRayDistributedExecutor + DIRECT_ROLLOUT_METHOD = 12 + class Execute(Enum): ALL = 0 @@ -65,6 +68,10 @@ def dispatch_one_to_all(worker_group, *args, **kwargs): return args, kwargs +def dummy_direct_rollout_call(worker_group, *args, **kwargs): + raise NotImplementedError("Direct rollout call is forbidden.") + + def dispatch_all_to_all(worker_group, *args, **kwargs): return args, kwargs @@ -356,6 +363,10 @@ def get_predefined_dispatch_fn(dispatch_mode): "collect_fn": collect_dp_compute_data_proto, }, Dispatch.DP_COMPUTE_METRIC: {"dispatch_fn": dispatch_dp_compute_data_proto, "collect_fn": collect_dp_compute}, + Dispatch.DIRECT_ROLLOUT_METHOD: { + "dispatch_fn": dummy_direct_rollout_call, + "collect_fn": dummy_direct_rollout_call, + }, } return predefined_dispatch_mode_fn[dispatch_mode] diff --git a/verl/single_controller/base/register_center/ray.py b/verl/single_controller/base/register_center/ray.py index de7f702a894..23a44bf6dc7 100644 --- a/verl/single_controller/base/register_center/ray.py +++ b/verl/single_controller/base/register_center/ray.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Dict, Tuple + import ray @@ -19,10 +21,18 @@ class WorkerGroupRegisterCenter: def __init__(self, rank_zero_info): self.rank_zero_info = rank_zero_info + # rank -> node_id + self.workers_info: Dict[int, str] = {} def get_rank_zero_info(self): return self.rank_zero_info + def set_worker_info(self, rank, node_id) -> None: + self.workers_info[rank] = node_id + + def get_worker_info(self) -> Dict[int, str]: + return self.workers_info + def create_worker_group_register_center(name, info): return WorkerGroupRegisterCenter.options(name=name).remote(info) diff --git a/verl/single_controller/base/worker.py b/verl/single_controller/base/worker.py index 59ff599f061..0655d0e7a22 100644 --- a/verl/single_controller/base/worker.py +++ b/verl/single_controller/base/worker.py @@ -19,6 +19,8 @@ import socket from dataclasses import dataclass +import ray + from .decorator import Dispatch, Execute, register @@ -125,6 +127,11 @@ def _configure_before_init(self, register_center_name: str, rank: int): ) os.environ.update(rank_zero_info) + else: + self.register_center = ray.get_actor(register_center_name) + + # set worker info for node affinity scheduling + ray.get(self.register_center.set_worker_info.remote(rank, ray.get_runtime_context().get_node_id())) def __init__(self, cuda_visible_devices=None) -> None: # construct a meta from envrionment variable. 
Note that the import must be inside the class because it is executed remotely diff --git a/verl/single_controller/ray/base.py b/verl/single_controller/ray/base.py index 6bf7ddc1f72..edaf6c7eaa4 100644 --- a/verl/single_controller/ray/base.py +++ b/verl/single_controller/ray/base.py @@ -323,10 +323,16 @@ def worker_names(self): return self._worker_names @classmethod - def from_detached(cls, worker_names=None, ray_cls_with_init=None): - worker_group = cls( - resource_pool=None, ray_cls_with_init=ray_cls_with_init, name_prefix=None, worker_names=worker_names - ) + def from_detached( + cls, + name_prefix, + worker_names=None, + ray_cls_with_init=None, + ): + worker_group = cls(resource_pool=None, + ray_cls_with_init=ray_cls_with_init, + name_prefix=name_prefix, + worker_names=worker_names) return worker_group def spawn(self, prefix_set): @@ -350,7 +356,9 @@ def _rebind_actor_methods(worker_group, actor_name): new_worker_group_dict = {} for prefix in prefix_set: new_worker_group = self.from_detached( - worker_names=self._worker_names, ray_cls_with_init=self.ray_cls_with_init + name_prefix=self.name_prefix, + worker_names=self._worker_names, + ray_cls_with_init=self.ray_cls_with_init, ) _rebind_actor_methods(new_worker_group, prefix) @@ -416,7 +424,7 @@ def world_size(self): import os from unittest.mock import patch -from verl.single_controller.base.decorator import MAGIC_ATTR +from verl.single_controller.base.decorator import MAGIC_ATTR, Dispatch def _bind_workers_method_to_parent(cls, key, user_defined_cls): @@ -424,6 +432,7 @@ def _bind_workers_method_to_parent(cls, key, user_defined_cls): Binds the methods of each worker to the WorkerDict. Note that we only bind public methods that are decorated by register """ + for method_name in dir(user_defined_cls): try: method = getattr(user_defined_cls, method_name) @@ -443,12 +452,19 @@ def func(self, *args, **kwargs): func = generate_function(method_name) # pass MAGIC_ATTR for outer worker group - setattr(func, MAGIC_ATTR, getattr(method, MAGIC_ATTR)) + attrs = getattr(method, MAGIC_ATTR) + setattr(func, MAGIC_ATTR, attrs) try: - method_name_with_prefix = key + "_" + method_name - setattr(cls, method_name_with_prefix, func) - # print(f'Binding {method_name_with_prefix}') - except Exception: + # bind direct rollout method to class without prefix + if attrs["dispatch_mode"] == Dispatch.DIRECT_ROLLOUT_METHOD and "rollout" in key: + assert not hasattr(cls, method_name), \ + f"conflict direct rollout method {method_name} with role {key}" + setattr(cls, method_name, func) + print(f"bind role {key} method {method_name} to class {cls}") + else: + method_name_with_prefix = key + '_' + method_name + setattr(cls, method_name_with_prefix, func) + except Exception as e: raise ValueError(f"Fail to set method_name {method_name}") @@ -458,21 +474,34 @@ def _unwrap_ray_remote(cls): return cls -def create_colocated_worker_cls(class_dict: dict[str, RayClassWithInitArgs]): +def _nearest_common_base(mros: List): + last_common = object + min_len = min([len(mro) for mro in mros]) - 1 # exclude final derived class + + for i in range(min_len): + mro = mros[0][i] + for j in range(1, len(mros)): + if mro != mros[j][i]: + return last_common + last_common = mro + + return last_common + + +def create_colocated_worker_cls(class_dict: dict[str, RayClassWithInitArgs], worker_cls: type = None): """ This function should return a class instance that delegates the calls to every cls in cls_dict """ cls_dict = {} init_args_dict = {} - worker_cls = None + if worker_cls is None: + 
worker_cls = _nearest_common_base( + [list(reversed(cls.cls.__ray_actor_class__.__mro__)) for cls in class_dict.values()]) + assert issubclass(worker_cls, Worker), f"worker_cls {worker_cls} should be a subclass of Worker" + print(f"find nearest common base class {worker_cls}") + for key, cls in class_dict.items(): - if worker_cls is None: - worker_cls = cls.cls.__ray_actor_class__.__base__ - else: - assert worker_cls == cls.cls.__ray_actor_class__.__base__, ( - "the worker class should be the same when share the same process" - ) cls_dict[key] = cls.cls init_args_dict[key] = {"args": cls.args, "kwargs": cls.kwargs} diff --git a/verl/trainer/config/ppo_trainer.yaml b/verl/trainer/config/ppo_trainer.yaml index 34957b06c45..ee2d181fee5 100644 --- a/verl/trainer/config/ppo_trainer.yaml +++ b/verl/trainer/config/ppo_trainer.yaml @@ -83,6 +83,8 @@ actor_rollout_ref: ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size rollout: name: vllm + mode: "sync" # sync: LLM, async: AsyncLLM + chat_scheduler: null # async chat scheduler, e.g examples.ppo_trainer.naive_chat_scheduler.NaiveChatCompletionScheduler temperature: 1.0 top_k: -1 # 0 for hf rollout, -1 for vllm rollout top_p: 1 diff --git a/verl/trainer/main_ppo.py b/verl/trainer/main_ppo.py index 1721c88af52..f6e57530cfd 100644 --- a/verl/trainer/main_ppo.py +++ b/verl/trainer/main_ppo.py @@ -83,7 +83,8 @@ def run_ppo(config) -> None: @ray.remote(num_cpus=1) # please make sure main_task is not scheduled on head class TaskRunner: - def run(self, config): + + async def run(self, config): # print initial config from pprint import pprint @@ -112,6 +113,9 @@ def run(self, config): ray_worker_group_cls = RayWorkerGroup + if config.actor_rollout_ref.rollout.mode == "async": + from verl.workers.fsdp_async_workers import AsyncActorRolloutRefWorker as ActorRolloutRefWorker + elif config.actor_rollout_ref.actor.strategy == "megatron": assert config.actor_rollout_ref.actor.strategy == config.critic.strategy from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup @@ -176,7 +180,7 @@ def run(self, config): val_reward_fn=val_reward_fn, ) trainer.init_workers() - trainer.fit() + await trainer.fit() if __name__ == "__main__": diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py index 2c15b8e7fc6..f44961b4291 100644 --- a/verl/trainer/ppo/ray_trainer.py +++ b/verl/trainer/ppo/ray_trainer.py @@ -17,6 +17,7 @@ """ import json +import importlib import os import uuid from collections import defaultdict @@ -770,6 +771,25 @@ def init_workers(self): self.actor_rollout_wg = all_wg["actor_rollout"] self.actor_rollout_wg.init_model() + # create async rollout manager and request scheduler + self.async_rollout_mode = False + if self.config.actor_rollout_ref.rollout.get("mode", "sync") == 'async': + from verl.workers.fsdp_async_workers import AsyncLLMManager + + self.async_rollout_mode = True + self.async_rollout_manager = AsyncLLMManager( + config=self.config.actor_rollout_ref, + worker_group=self.actor_rollout_wg, + ) + module_path, class_name = self.config.actor_rollout_ref.rollout.chat_scheduler.rsplit(".", 1) + module = importlib.import_module(module_path) + scheduler_cls = getattr(module, class_name) + self.async_chat_scheduler = scheduler_cls( + config=self.config.actor_rollout_ref.rollout, + model_path=self.config.actor_rollout_ref.model.path, + server_addresses=self.async_rollout_manager.server_addresses, + ) + def _save_checkpoint(self): # path: given_path + 
`/global_step_{global_steps}` + `/actor` local_global_step_folder = os.path.join( @@ -899,7 +919,7 @@ def _balance_batch(self, batch: DataProto, metrics, logging_prefix="global_seqle ) metrics.update(global_balance_stats) - def fit(self): + async def fit(self): """ The training loop of PPO. The driver process only need to call the compute functions of the worker group through RPC @@ -954,7 +974,7 @@ def fit(self): else: gen_batch = batch.pop( batch_keys=["input_ids", "attention_mask", "position_ids"], - non_tensor_batch_keys=["raw_prompt_ids"], + non_tensor_batch_keys=["raw_prompt_ids"] + ["raw_prompt"] if self.async_rollout_mode else [], ) is_last_step = self.global_steps >= self.total_training_steps @@ -962,7 +982,12 @@ def fit(self): with _timer("step", timing_raw): # generate a batch with _timer("gen", timing_raw): - gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch) + if not self.async_rollout_mode: + gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch) + else: + await self.async_rollout_manager.wake_up() + gen_batch_output = await self.async_chat_scheduler.generate_sequences(gen_batch) + await self.async_rollout_manager.sleep() if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX: with _timer("gen_max", timing_raw): diff --git a/verl/workers/fsdp_async_workers.py b/verl/workers/fsdp_async_workers.py new file mode 100644 index 00000000000..458d3e13314 --- /dev/null +++ b/verl/workers/fsdp_async_workers.py @@ -0,0 +1,326 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
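+# This module wires vLLM's AsyncLLM (v1 engine) into verl's hybrid FSDP workers:
+#   - ExternalRayDistributedExecutor: vLLM executor that reuses the already-launched
+#     hybrid rollout worker actors instead of spawning its own ray workers.
+#   - AsyncLLMWorker: one AsyncLLM engine plus an OpenAI-compatible FastAPI server
+#     per rollout DP rank.
+#   - AsyncActorRolloutRefWorker: hybrid worker that exposes execute_method for the
+#     external executor's collective_rpc.
+#   - AsyncLLMManager: spawns AsyncLLMWorker instances colocated with their workers
+#     and drives wake_up/sleep across all instances.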
+import asyncio +import logging +import os +import random +import socket +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import cloudpickle +import fastapi +import ray +import uvicorn +from omegaconf import DictConfig +from starlette.requests import Request +from starlette.responses import JSONResponse, StreamingResponse +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.entrypoints.openai.protocol import ChatCompletionRequest, ChatCompletionResponse, ErrorResponse +from vllm.entrypoints.openai.serving_chat import OpenAIServingChat +from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels +from vllm.v1.engine.async_llm import AsyncLLM +from vllm.v1.executor.abstract import Executor +from vllm.worker.worker_base import WorkerWrapperBase + +from verl import DataProto +from verl.single_controller.base import Worker +from verl.single_controller.base.decorator import Dispatch, register +from verl.single_controller.ray.base import RayWorkerGroup +from verl.utils.fs import copy_to_local +from verl.workers.fsdp_workers import ActorRolloutRefWorker +from verl.workers.sharding_manager import FSDPVLLMShardingManager + +logger = logging.getLogger(__file__) + + +def _get_free_port(): + with socket.socket() as sock: + sock.bind(("", 0)) + return sock.getsockname()[1] + + +class ExternalRayDistributedExecutor(Executor): + """An executor that engines are launched by external ray actors.""" + uses_ray: bool = False + + def _init_executor(self) -> None: + assert self.vllm_config.instance_id is not None, \ + "instance_id must be set for external ray actors." + + fields = self.vllm_config.instance_id.split(":") + assert len(fields) == 4, \ + f"instance_id: {self.vllm_config.instance_id} must be in " \ + f"the format of :::." + namespace, wg_prefix, vllm_dp_size, vllm_dp_rank = fields[0], fields[1], int(fields[2]), int(fields[3]) + + # Make sure subprocess in same namespace as parent actor. + # actor name format: {name_prefix}WorkerDict_{pg_idx}:{local_rank} + ray.init(namespace=namespace) + actor_names = [ + actor_name for actor_name in ray.util.list_named_actors() if actor_name.startswith(f"{wg_prefix}WorkerDict") + ] + + vllm_tp_size = self.vllm_config.parallel_config.tensor_parallel_size + assert len(actor_names) == vllm_dp_size * vllm_tp_size, \ + f"instance_id: {self.vllm_config.instance_id} has {len(actor_names)} actors, " \ + f"but vllm_dp_size: {vllm_dp_size} * vllm_tp_size: {vllm_tp_size} = " \ + f"{vllm_dp_size * vllm_tp_size} is expected." 
+ + def get_pg_index_and_local_rank(actor_name) -> Tuple[int, int]: + fields = actor_name.split(":") + assert len(fields) == 2, f"invalid actor name: {actor_name}" + pg_index, local_rank = int(fields[0].split("_")[-1]), int(fields[1]) + return pg_index, local_rank + + # sort actor names by pg_index and local_rank + actor_names = sorted(actor_names, key=get_pg_index_and_local_rank) + actor_names = actor_names[vllm_dp_rank * vllm_tp_size:(vllm_dp_rank + 1) * vllm_tp_size] + self.workers: List[WorkerWrapperBase] = [ray.get_actor(actor_name) for actor_name in actor_names] + print(f"instance_id: {self.vllm_config.instance_id} intializes with external actors: {actor_names}") + + kwargs = dict( + vllm_config=self.vllm_config, + local_rank=None, + rank=None, + distributed_init_method="env://", + is_driver_worker=True, + ) + self.collective_rpc("init_worker", args=([kwargs],)) + self.collective_rpc("init_device") + self.collective_rpc("load_model") + print(f"instance_id: {self.vllm_config.instance_id} intializes finished.") + + def collective_rpc(self, + method: Union[str, Callable], + timeout: Optional[float] = None, + args: Tuple = (), + kwargs: Optional[Dict[str, Any]] = None) -> List[Any]: + # TODO(wuxibin): support ray compiled graph + if isinstance(method, str): + sent_method = method + else: + sent_method = cloudpickle.dumps(method) + del method + + outputs = ray.get( + [worker.execute_method.remote(sent_method, *args, **(kwargs or {})) for worker in self.workers]) + return outputs + + def check_health(self): + return + + +@ray.remote(num_cpus=1) +class AsyncLLMWorker: + """ + AsyncLLMWorker is a wrapper for AsyncLLM, it uses ExternalRayDistributedExecutor to launch engines + in hybrid rollout workers, i.e AsyncActorRolloutRefWorker. + + It works as follows: + 1. Initialize AsyncLLM with ExternalRayDistributedExecutor. + 2. AsyncLLM spawn EngineCore in subprocess. + 3. EngineCore initialize ExternalRayDistributedExecutor. + 4. ExternalRayDistributedExecutor lookup its corresponding actors by name. + 5. ExternalRayDistributedExecutor init executor: init_worker, init_device, load_model. + 6. AsyncLLM initialize done, start FastAPI server. + + For vLLM AsyncLLM design, see: https://github.com/vllm-project/vllm/pull/9826 + """ + + def __init__(self, config: DictConfig, vllm_dp_size: int, vllm_dp_rank: int, wg_prefix: str): + """ + Args: + config: DictConfig, actor_rollout_ref config. + vllm_dp_size: int, vllm data parallel size. + vllm_dp_rank: int, vllm data parallel rank. + wg_prefix: str, worker group prefix, used to lookup actors. 
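+
+        Note:
+            The engine's vllm_config.instance_id is set to
+            "{namespace}:{wg_prefix}:{vllm_dp_size}:{vllm_dp_rank}" so that the
+            ExternalRayDistributedExecutor created inside the EngineCore subprocess
+            can look up its colocated worker actors by name.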
+ """ + model_path = config.model.path + model_name = "/".join(model_path.split("/")[-2:]) + local_path = copy_to_local(model_path) + trust_remote_code = config.model.get('trust_remote_code', False) + config = config.rollout + + tensor_parallel_size = config.get('tensor_model_parallel_size', 1) + max_num_batched_tokens = config.get('max_num_batched_tokens', 8192) + max_model_len = config.max_model_len if config.max_model_len \ + else config.prompt_length + config.response_length + max_model_len = int(max_model_len) + + if max_num_batched_tokens < max_model_len and config.enable_chunked_prefill: + raise ValueError('Enable chunked prefill, max_num_batched_tokens is smaller than max_model_len, \ + please increase max_num_batched_tokens or disable chunked prefill') + + engine_args = AsyncEngineArgs( + model=local_path, + enable_sleep_mode=True, + tensor_parallel_size=tensor_parallel_size, + distributed_executor_backend=ExternalRayDistributedExecutor, + dtype=config.dtype, + enforce_eager=config.enforce_eager, + gpu_memory_utilization=config.gpu_memory_utilization, + disable_custom_all_reduce=True, + disable_mm_preprocessor_cache=True, + skip_tokenizer_init=False, + max_model_len=max_model_len, + load_format="auto", + disable_log_stats=config.disable_log_stats, + max_num_batched_tokens=max_num_batched_tokens, + enable_chunked_prefill=config.enable_chunked_prefill, + enable_prefix_caching=True, + trust_remote_code=trust_remote_code, + seed=vllm_dp_rank, + ) + + # init async llm engine + vllm_config = engine_args.create_engine_config() + namespace = ray.get_runtime_context().namespace + vllm_config.instance_id = f"{namespace}:{wg_prefix}:{vllm_dp_size}:{vllm_dp_rank}" + self.engine = AsyncLLM.from_vllm_config(vllm_config) + + # build serving chat + model_config = self.engine.model_config + BASE_MODEL_PATHS = [BaseModelPath(name=model_name, model_path=model_path)] + models = OpenAIServingModels(self.engine, model_config, BASE_MODEL_PATHS) + self.openai_serving_chat = OpenAIServingChat( + self.engine, + model_config, + models, + "assistant", + request_logger=None, + chat_template=None, + chat_template_content_format="auto", + ) + + # start FastAPI server + self.address = ray._private.services.get_node_ip_address() + self.port = None + self.server_ready = asyncio.Event() + asyncio.create_task(self._start_fastapi_server()) + + async def chat_completion(self, raw_request: Request): + """OpenAI-compatible HTTP endpoint. 
+ + API reference: https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html + """ + request_json = await raw_request.json() + request = ChatCompletionRequest(**request_json) + generator = await self.openai_serving_chat.create_chat_completion(request, raw_request) + + if isinstance(generator, ErrorResponse): + return JSONResponse(content=generator.model_dump(), status_code=generator.code) + if request.stream: + return StreamingResponse(content=generator, media_type="text/event-stream") + else: + assert isinstance(generator, ChatCompletionResponse) + return JSONResponse(content=generator.model_dump()) + + async def _start_fastapi_server(self): + app = fastapi.FastAPI() + app.router.add_api_route("/v1/chat/completions", self.chat_completion, methods=["POST"]) + + # TODO: random sleep to reduce port conflict, retry if port is already in use + asyncio.sleep(random.uniform(0, 3)) + self.port = _get_free_port() + config = uvicorn.Config(app, host=["::", "0.0.0.0"], port=self.port) + server = uvicorn.Server(config) + self.server_ready.set() + await server.serve() + + async def get_server_address(self) -> Tuple[str, int]: + await self.server_ready.wait() + return f"{self.address}:{self.port}" + + async def wake_up(self): + await self.engine.wake_up() + + async def sleep(self): + await self.engine.sleep() + + +class AsyncActorRolloutRefWorker(ActorRolloutRefWorker): + + def _build_rollout(self, trust_remote_code=False): + rollout, rollout_sharding_manager = super()._build_rollout(trust_remote_code) + + # NOTE: rollout is not actually initialized here, it's deferred + # to be initialized by AsyncLLMWorker. + + self.vllm_tp_size = self.config.rollout.tensor_model_parallel_size + self.vllm_dp_rank = int(os.environ["RANK"]) // self.vllm_tp_size + self.vllm_tp_rank = int(os.environ["RANK"]) % self.vllm_tp_size + + # used for sleep/wake_up + rollout.sharding_manager = rollout_sharding_manager + + return rollout, rollout_sharding_manager + + @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) + def generate_sequences(self, prompts: DataProto): + raise NotImplementedError("AsyncActorRolloutRefWorker does not support generate_sequences") + + @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD) + def execute_method(self, method: Union[str, bytes], *args, **kwargs): + """Called by ExternalRayDistributedExecutor collective_rpc.""" + if self.vllm_tp_rank == 0 and method != "execute_model": + print(f"[DP={self.vllm_dp_rank},TP={self.vllm_tp_rank}] " + f"execute_method: {method if isinstance(method, str) else 'Callable'}") + return self.rollout.execute_method(method, *args, **kwargs) + + +class AsyncLLMManager: + """AsyncLLMManager manage a group of vllm instances, i.e AsyncLLMWorker.""" + + def __init__(self, config: DictConfig, worker_group: RayWorkerGroup): + """Initialize AsyncLLMManager. + + Args: + config: DictConfig, actor_rollout_ref config. + worker_group: RayWorkerGroup, worker group of AsyncActorRolloutRefWorker. 
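+
+        Note:
+            rollout_dp_size is worker_group.world_size // tensor_model_parallel_size.
+            Each AsyncLLMWorker is pinned (NodeAffinitySchedulingStrategy, soft=False)
+            to the node of the first worker in its TP group, using the rank -> node_id
+            mapping reported to the worker group's register center.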
+ """ + self.config = config + self.worker_group = worker_group + + self.rollout_tp_size = self.config.rollout.tensor_model_parallel_size + self.rollout_dp_size = self.worker_group.world_size // self.rollout_tp_size + + register_center = ray.get_actor(f"{self.worker_group.name_prefix}_register_center") + workers_info = ray.get(register_center.get_worker_info.remote()) + assert len(workers_info) == self.worker_group.world_size + + # make sure AsyncLLMWorker colocates with its corresponding workers + self.async_llm_workers = [ + AsyncLLMWorker.options( + scheduling_strategy=ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy( + node_id=workers_info[rollout_dp_rank * self.rollout_tp_size], + soft=False, + ), + name=f"async_llm_worker_{rollout_dp_rank}", + ).remote(config, self.rollout_dp_size, rollout_dp_rank, self.worker_group.name_prefix) + for rollout_dp_rank in range(self.rollout_dp_size) + ] + self.server_addresses = ray.get([worker.get_server_address.remote() for worker in self.async_llm_workers]) + + @property + def server_address(self): + """Ruturn FastAPI server addresses of all vllm instances.""" + return self.server_addresses + + async def wake_up(self): + """Wake up all vllm instances.""" + await asyncio.gather(*[worker.wake_up.remote() for worker in self.async_llm_workers]) + + async def sleep(self): + """Sleep all vllm instances.""" + await asyncio.gather(*[worker.sleep.remote() for worker in self.async_llm_workers]) diff --git a/verl/workers/fsdp_workers.py b/verl/workers/fsdp_workers.py index d1153c87172..44941f068bb 100644 --- a/verl/workers/fsdp_workers.py +++ b/verl/workers/fsdp_workers.py @@ -349,27 +349,24 @@ def _build_rollout(self, trust_remote_code=False): # TODO: a sharding manager that do nothing? elif rollout_name == "vllm": - from verl.workers.rollout.vllm_rollout import vllm_mode, vLLMRollout + from verl.workers.rollout.vllm_rollout import vllm_mode, vLLMAsyncRollout, vLLMRollout from verl.workers.sharding_manager import FSDPVLLMShardingManager log_gpu_memory_usage(f"Before building {rollout_name} rollout", logger=None) local_path = copy_to_local(self.config.model.path) if vllm_mode == "customized": - rollout = vLLMRollout( - actor_module=self.actor_module_fsdp, - config=self.config.rollout, - tokenizer=self.tokenizer, - model_hf_config=self.actor_model_config, - ) + rollout = vLLMRollout(actor_module=self.actor_module_fsdp, + config=self.config.rollout, + tokenizer=self.tokenizer, + model_hf_config=self.actor_model_config) elif vllm_mode == "spmd": - rollout = vLLMRollout( - model_path=local_path, - config=self.config.rollout, - tokenizer=self.tokenizer, - model_hf_config=self.actor_model_config, - device_mesh=rollout_device_mesh, - trust_remote_code=trust_remote_code, - ) + vllm_rollout_cls = vLLMRollout if self.config.rollout.mode == "sync" else vLLMAsyncRollout + rollout = vllm_rollout_cls(model_path=local_path, + config=self.config.rollout, + tokenizer=self.tokenizer, + model_hf_config=self.actor_model_config, + device_mesh=rollout_device_mesh, + trust_remote_code=trust_remote_code) else: raise NotImplementedError("vllm_mode must be 'customized' or 'spmd'") log_gpu_memory_usage(f"After building {rollout_name} rollout", logger=None) diff --git a/verl/workers/rollout/chat_scheduler.py b/verl/workers/rollout/chat_scheduler.py new file mode 100644 index 00000000000..2580e4343b6 --- /dev/null +++ b/verl/workers/rollout/chat_scheduler.py @@ -0,0 +1,113 @@ +# Copyright 2024 Bytedance Ltd. 
and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import heapq +from typing import Any, Callable, Dict, List +from uuid import uuid4 + +import aiohttp +from cachetools import LRUCache +from omegaconf import DictConfig +from openai import AsyncOpenAI +from openai.types.chat.chat_completion import ChatCompletion + +from verl.protocol import DataProto + + +class ChatCompletionScheduler: + + def __init__(self, config: DictConfig, model_path: str, server_addresses: List[str], max_cache_size: int = 10000): + """ + Args: + config: DictConfig, rollout config. + model_path: str, model path. + server_addresses: List[str], server addresses. + max_cache_size: int, max cache size of request_id to address mapping. + """ + self.config = config + self.model_name = "/".join(model_path.split("/")[-2:]) + + # Least requests load balancing + self.weighted_addresses = [[0, address] for address in server_addresses] + heapq.heapify(self.weighted_addresses) + + # LRU cache to map request_id to address + self.request_id_to_address = LRUCache(maxsize=max_cache_size) + + async def submit_chat_completions( + self, + callback: Callable[[ChatCompletion, Dict[str, Any]], None], + callback_additional_info: Dict[str, Any], + **chat_complete_request, + ): + """ + Submit a chat completion request to the server with the least number of requests. + + Args: + callback: Callable[[ChatCompletion, Dict[str, Any]], None], async callback function to handle the response. + + **CAUTION**: the callback function must be async and non-blocking, if you have any blocking operation, + please move to seperate thread or process pool to avoid blocking the event loop. + + callback_additional_info: Dict[str, Any], additional info to pass to the callback function. + + **chat_complete_request: dict, request parameters same as OpenAI AsyncCompletions.create. + OpenAI API reference: https://platform.openai.com/docs/api-reference/chat/create + """ + if "extra_headers" not in chat_complete_request: + chat_complete_request["extra_headers"] = {} + + extra_headers = chat_complete_request["extra_headers"] + request_id = extra_headers.get("x-request-id", None) + if request_id: + if request_id.startswith("chatcmpl-"): + request_id = request_id[len("chatcmpl-"):] + extra_headers["x-request-id"] = request_id + + address = self.request_id_to_address[request_id] + else: + address = self.weighted_addresses[0][1] + self.weighted_addresses[0][0] += 1 + heapq.heapreplace(self.weighted_addresses, self.weighted_addresses[0]) + + request_id = uuid4().hex + self.request_id_to_address[request_id] = address + chat_complete_request["extra_headers"]["x-request-id"] = request_id + + # TODO: OpenAI client uses httpx, seems to have performance issue in high concurrency requests. 
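+        # _chat_completions_aiohttp(address, **chat_complete_request) below is a drop-in
+        # alternative that posts the request directly with aiohttp.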
+ completions = await self._chat_completions_openai(address, **chat_complete_request) + + await callback(completions, callback_additional_info) + + async def _chat_completions_openai(self, address: str, **chat_complete_request) -> ChatCompletion: + client = AsyncOpenAI( + base_url=f"http://{address}/v1", + api_key="token-abc123", + ) + return await client.chat.completions.create(**chat_complete_request) + + async def _chat_completions_aiohttp(self, address: str, **chat_complete_request) -> ChatCompletion: + try: + session = aiohttp.ClientSession() + async with session.post( + url=f"http://{address}/v1/chat/completions", + headers={"Authorization": "Bearer token-abc123"}, + json=chat_complete_request, + ) as resp: + data = await resp.json() + return ChatCompletion(**data) + finally: + await session.close() + + async def generate_sequences(self, prompts: DataProto, **sampling_params) -> DataProto: + raise NotImplementedError diff --git a/verl/workers/rollout/vllm_rollout/__init__.py b/verl/workers/rollout/vllm_rollout/__init__.py index 1c0fe1255a4..dbf5cbd9fc0 100644 --- a/verl/workers/rollout/vllm_rollout/__init__.py +++ b/verl/workers/rollout/vllm_rollout/__init__.py @@ -49,4 +49,4 @@ def get_version(pkg): from .vllm_rollout import vLLMRollout else: vllm_mode = "spmd" - from .vllm_rollout_spmd import vLLMRollout + from .vllm_rollout_spmd import vLLMAsyncRollout, vLLMRollout diff --git a/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py b/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py index 236a321a682..865df0abb9b 100644 --- a/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py +++ b/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py @@ -28,21 +28,24 @@ import logging import os from contextlib import contextmanager -from typing import Any, List, Union +from typing import Any, Dict, List, Union import numpy as np import torch import torch.distributed from omegaconf import DictConfig from tensordict import TensorDict +from torch import nn from vllm import LLM, SamplingParams from vllm.distributed import parallel_state as vllm_ps +from vllm.worker.worker_base import WorkerWrapperBase from verl import DataProto from verl.third_party.vllm import vllm_version from verl.utils.debug import GPUMemoryLogger from verl.utils.torch_functional import get_response_mask, pad_2d_list_to_length from verl.workers.rollout.base import BaseRollout +from verl.workers.sharding_manager import FSDPVLLMShardingManager logger = logging.getLogger(__file__) logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) @@ -338,3 +341,57 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto: self.inference_engine.free_cache_engine() return DataProto(batch=batch, non_tensor_batch=non_tensor_batch) + + +class vLLMAsyncRollout: + """vLLMAsyncRollout is a thin wrapper of WorkerWrapperBase, + which is engine in single worker process. 
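+
+    The vLLM EngineCore drives it through execute_method: "init_worker" builds the
+    WorkerWrapperBase, "load_model" loads weights and hands the inference engine and
+    model_runner to the FSDPVLLMShardingManager, and "sleep"/"wake_up" offload or
+    restore weights and kv cache by exiting/entering that sharding manager.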
+ """ + + def __init__(self, *args, **kwargs): + # Engine is deferred to be initialized in init_worker + self.inference_engine: WorkerWrapperBase = None + self.sharding_manager: FSDPVLLMShardingManager = None + self.is_sleep = False + + def init_worker(self, all_kwargs: List[Dict[str, Any]]): + """Initialize worker engine.""" + all_kwargs[0]["rank"] = int(os.environ["RANK"]) + all_kwargs[0]["local_rank"] = 0 + + self.vllm_config = all_kwargs[0]["vllm_config"] + self.inference_engine = WorkerWrapperBase(vllm_config=self.vllm_config) + self.inference_engine.init_worker(all_kwargs) + + def load_model(self, *args, **kwargs): + self.inference_engine.load_model(*args, **kwargs) + + # inference engine is intialized now, update sharding manager + self.sharding_manager.inference_engine = self.inference_engine + self.sharding_manager.model_runner = self.inference_engine.worker.model_runner + + def sleep(self, *args, **kwargs): + """Offload model weights and discard kv cache.""" + if self.is_sleep: + return + self.sharding_manager.__exit__(None, None, None) + self.is_sleep = True + + def wake_up(self, *args, **kwargs): + """Load model weights and build kv cache.""" + if not self.is_sleep: + return + self.sharding_manager.__enter__() # pylint: disable=C2801 + self.is_sleep = False + + def execute_method(self, method: Union[str, bytes], *args, **kwargs): + if method == "init_worker": + return self.init_worker(*args, **kwargs) + elif method == "load_model": + return self.load_model(*args, **kwargs) + elif method == "sleep": + return self.sleep(*args, **kwargs) + elif method == "wake_up": + return self.wake_up(*args, **kwargs) + else: + return self.inference_engine.execute_method(method, *args, **kwargs) diff --git a/verl/workers/sharding_manager/fsdp_vllm.py b/verl/workers/sharding_manager/fsdp_vllm.py index 060ca349582..547ad7d01ab 100644 --- a/verl/workers/sharding_manager/fsdp_vllm.py +++ b/verl/workers/sharding_manager/fsdp_vllm.py @@ -47,7 +47,9 @@ def __init__( device_mesh: DeviceMesh = None, ): self.module = module + # For AsyncLLM, inference_engine and model_runner are defer intialized in vLLMAsyncRollout.load_model self.inference_engine = inference_engine + self.model_runner = inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner if inference_engine else None self.model_config = model_config self.device_mesh = device_mesh @@ -64,8 +66,8 @@ def __init__( state_dict_config=ShardedStateDictConfig(), ) - self.tp_size = vllm_ps.get_tensor_model_parallel_world_size() - self.tp_rank = vllm_ps.get_tensor_model_parallel_rank() + self.tp_size = self.device_mesh["infer_tp"].size() + self.tp_rank = self.device_mesh["infer_tp"].get_local_rank() # Note that torch_random_states may be different on each dp rank self.torch_random_states = torch.cuda.get_rng_state() @@ -182,7 +184,7 @@ def postprocess_data(self, data: DataProto) -> DataProto: return data.chunk(chunks=self.tp_size)[self.tp_rank] def update_params(self, updated_params): - model = self.inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner.model + model = self.model_runner.model patch_vllm_moe_model_weight_loader(model) world_size = torch.distributed.get_world_size() loaded_params = model.load_weights( From e3586632b249ec770e2933c753949716dfec7041 Mon Sep 17 00:00:00 2001 From: wuxibin Date: Mon, 21 Apr 2025 18:44:36 +0800 Subject: [PATCH 02/10] move FSDP param load/offload into sharding manager --- tests/rollout/test_vllm_multi_turn.py | 59 +++++++------ .../base/register_center/ray.py | 2 +- 
verl/single_controller/ray/base.py | 83 ++++++++++--------- verl/trainer/config/generation.yaml | 1 + verl/trainer/config/ppo_megatron_trainer.yaml | 1 + verl/trainer/ppo/ray_trainer.py | 2 +- verl/workers/fsdp_async_workers.py | 61 +++++++------- verl/workers/fsdp_workers.py | 34 ++++---- verl/workers/sharding_manager/fsdp_sglang.py | 13 +-- verl/workers/sharding_manager/fsdp_vllm.py | 17 ++-- 10 files changed, 144 insertions(+), 129 deletions(-) diff --git a/tests/rollout/test_vllm_multi_turn.py b/tests/rollout/test_vllm_multi_turn.py index 79d902ad8cf..f59b9791d5b 100644 --- a/tests/rollout/test_vllm_multi_turn.py +++ b/tests/rollout/test_vllm_multi_turn.py @@ -20,7 +20,7 @@ from openai.types.chat.chat_completion import ChatCompletion from verl.single_controller.ray import RayClassWithInitArgs, RayWorkerGroup -from verl.single_controller.ray.base import Worker, create_colocated_worker_cls +from verl.single_controller.ray.base import create_colocated_worker_cls from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role from verl.workers.fsdp_async_workers import AsyncActorRolloutRefWorker, AsyncLLMManager from verl.workers.rollout.chat_scheduler import ChatCompletionScheduler @@ -35,20 +35,25 @@ async def test_vllm_multi_turn(): config.actor_rollout_ref.rollout.prompt_length = 4096 config.actor_rollout_ref.rollout.response_length = 4096 + # test sleep/wake_up with fsdp offload + config.actor_rollout_ref.actor.fsdp_config.param_offload = True + config.actor_rollout_ref.actor.fsdp_config.optimizer_offload = True + # =========================== 1. Create hybrid ActorRollout workers =========================== ray.init( runtime_env={ - 'env_vars': { - 'TOKENIZERS_PARALLELISM': 'true', - 'NCCL_DEBUG': 'WARN', - 'VLLM_LOGGING_LEVEL': 'WARN', - 'VLLM_USE_V1': '1', + "env_vars": { + "TOKENIZERS_PARALLELISM": "true", + "NCCL_DEBUG": "WARN", + "VLLM_LOGGING_LEVEL": "WARN", + "VLLM_USE_V1": "1", } - }) + } + ) role_worker_mapping = { Role.ActorRollout: ray.remote(AsyncActorRolloutRefWorker), } - global_pool_id = 'global_pool' + global_pool_id = "global_pool" resource_pool_spec = { global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes, } @@ -61,20 +66,20 @@ async def test_vllm_multi_turn(): # create actor and rollout resource_pool = resource_pool_manager.get_resource_pool(Role.ActorRollout) - actor_rollout_cls = RayClassWithInitArgs(cls=role_worker_mapping[Role.ActorRollout], - config=config.actor_rollout_ref, - role='actor_rollout') - resource_pool_to_cls[resource_pool]['actor_rollout'] = actor_rollout_cls + actor_rollout_cls = RayClassWithInitArgs( + cls=role_worker_mapping[Role.ActorRollout], config=config.actor_rollout_ref, role="actor_rollout" + ) + resource_pool_to_cls[resource_pool]["actor_rollout"] = actor_rollout_cls all_wg = {} wg_dicts = [] for resource_pool, class_dict in resource_pool_to_cls.items(): - worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict, worker_cls=Worker) + worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict) wg_dict = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=worker_dict_cls) spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys()) all_wg.update(spawn_wg) wg_dicts.append(wg_dict) - actor_rollout_wg = all_wg['actor_rollout'] + actor_rollout_wg = all_wg["actor_rollout"] actor_rollout_wg.init_model() # =========================== 2. 
Create AsyncLLMManager&ChatScheduler =========================== @@ -89,6 +94,10 @@ async def test_vllm_multi_turn(): server_addresses=async_rollout_manager.server_addresses, ) + # test sleep and wake_up + async_rollout_manager.sleep() + async_rollout_manager.wake_up() + # =========================== 3. Multi turn rollout =========================== async def callback(completions: ChatCompletion, info: Dict[str, Any]): messages, round = info["messages"], info["round"] @@ -101,10 +110,7 @@ async def callback(completions: ChatCompletion, info: Dict[str, Any]): messages.append({"role": "user", "content": "What is your name?"}) await async_chat_scheduler.submit_chat_completions( callback=callback, - callback_additional_info={ - "messages": messages, - "round": 1 - }, + callback_additional_info={"messages": messages, "round": 1}, model=model_name, messages=messages, extra_headers=extra_headers, @@ -113,10 +119,7 @@ async def callback(completions: ChatCompletion, info: Dict[str, Any]): messages.append({"role": "user", "content": "What is your favorite color?"}) await async_chat_scheduler.submit_chat_completions( callback=callback, - callback_additional_info={ - "messages": messages, - "round": 2 - }, + callback_additional_info={"messages": messages, "round": 2}, model=model_name, messages=messages, extra_headers=extra_headers, @@ -124,16 +127,12 @@ async def callback(completions: ChatCompletion, info: Dict[str, Any]): else: print("Done!") - messages = [{ - "role": "user", - "content": "Let's play a role playing game. Your name is Bob, your favorite color is red." - }] + messages = [ + {"role": "user", "content": "Let's play a role playing game. Your name is Bob, your favorite color is red."} + ] await async_chat_scheduler.submit_chat_completions( callback=callback, - callback_additional_info={ - "messages": messages, - "round": 0 - }, + callback_additional_info={"messages": messages, "round": 0}, model=model_name, messages=messages, ) diff --git a/verl/single_controller/base/register_center/ray.py b/verl/single_controller/base/register_center/ray.py index 23a44bf6dc7..8ff70bd36a8 100644 --- a/verl/single_controller/base/register_center/ray.py +++ b/verl/single_controller/base/register_center/ray.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, Tuple +from typing import Dict import ray diff --git a/verl/single_controller/ray/base.py b/verl/single_controller/ray/base.py index edaf6c7eaa4..eb2cf0a7771 100644 --- a/verl/single_controller/ray/base.py +++ b/verl/single_controller/ray/base.py @@ -13,8 +13,10 @@ # limitations under the License. import logging +import os import time from typing import Any, Dict, List, Optional, Tuple +from unittest.mock import patch import ray from ray.experimental.state.api import get_actor @@ -23,6 +25,7 @@ from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy, PlacementGroupSchedulingStrategy from verl.single_controller.base import ClassWithInitArgs, ResourcePool, Worker, WorkerGroup +from verl.single_controller.base.decorator import MAGIC_ATTR, Dispatch __all__ = ["Worker"] @@ -300,17 +303,23 @@ def _init_with_resource_pool(self, resource_pool, ray_cls_with_init, bin_pack, d elapsed = int(time.time() - start_time) if elapsed % 30 == 0: logging.warning( - f"Waiting for register center actor {actor_name} to be ready. " - f"Elapsed time: {elapsed} seconds out of {self._ray_wait_register_center_timeout} seconds." 
+ "Waiting for register center actor %s to be ready. " + "Elapsed time: %s seconds out of %s seconds.", + actor_name, + elapsed, + self._ray_wait_register_center_timeout, ) time.sleep(1) if register_center_actor is None: raise TimeoutError( - f"Failed to get register_center_actor {actor_name} in {list_named_actors(all_namespaces=True)} " + f"Failed to get register_center_actor {actor_name} " + f"in {list_named_actors(all_namespaces=True)} " f"for {self._ray_wait_register_center_timeout} seconds. " - "Ensure that any lingering Ray resources from previous runs are cleaned up (e.g., by restarting the Ray cluster), " - "or adjust the waiting time by modifying the config `trainer.ray_wait_register_center_timeout`." + "Ensure that any lingering Ray resources from previous " + "runs are cleaned up (e.g., by restarting the Ray cluster), " + "or adjust the waiting time by modifying the config " + "`trainer.ray_wait_register_center_timeout`." ) rank_zero_info = ray.get(register_center_actor.get_rank_zero_info.remote()) @@ -329,10 +338,9 @@ def from_detached( worker_names=None, ray_cls_with_init=None, ): - worker_group = cls(resource_pool=None, - ray_cls_with_init=ray_cls_with_init, - name_prefix=name_prefix, - worker_names=worker_names) + worker_group = cls( + resource_pool=None, ray_cls_with_init=ray_cls_with_init, name_prefix=name_prefix, worker_names=worker_names + ) return worker_group def spawn(self, prefix_set): @@ -382,8 +390,9 @@ def execute_all_sync(self, method_name: str, *args, **kwargs): return ray.get(self.execute_all_async(method_name, *args, **kwargs)) def execute_all_async(self, method_name: str, *args, **kwargs): - # Here, we assume that if all arguments in args and kwargs are lists, and their lengths match len(self._workers), - # we'll distribute each element in these lists to the corresponding worker + # Here, we assume that if all arguments in args and kwargs are lists, + # and their lengths match len(self._workers), we'll distribute each + # element in these lists to the corresponding worker # print(f"execute_all_async: method {method_name}({args}, {kwargs})") length = len(self._workers) if all(isinstance(arg, list) for arg in args) and all(isinstance(kwarg, list) for kwarg in kwargs.values()): @@ -421,11 +430,6 @@ def world_size(self): with code written in separate ray.Actors. 
""" -import os -from unittest.mock import patch - -from verl.single_controller.base.decorator import MAGIC_ATTR, Dispatch - def _bind_workers_method_to_parent(cls, key, user_defined_cls): """ @@ -443,12 +447,12 @@ def _bind_workers_method_to_parent(cls, key, user_defined_cls): if hasattr(method, MAGIC_ATTR): - def generate_function(name): + def generate_function(name, key=key): def func(self, *args, **kwargs): # dispatch to the actual worker return getattr(self.worker_dict[key], name)(*args, **kwargs) - return func + return func # noqa: B023 func = generate_function(method_name) # pass MAGIC_ATTR for outer worker group @@ -457,15 +461,16 @@ def func(self, *args, **kwargs): try: # bind direct rollout method to class without prefix if attrs["dispatch_mode"] == Dispatch.DIRECT_ROLLOUT_METHOD and "rollout" in key: - assert not hasattr(cls, method_name), \ + assert not hasattr(cls, method_name), ( f"conflict direct rollout method {method_name} with role {key}" + ) setattr(cls, method_name, func) print(f"bind role {key} method {method_name} to class {cls}") else: - method_name_with_prefix = key + '_' + method_name + method_name_with_prefix = key + "_" + method_name setattr(cls, method_name_with_prefix, func) except Exception as e: - raise ValueError(f"Fail to set method_name {method_name}") + raise ValueError(f"Fail to set method_name {method_name}") from e def _unwrap_ray_remote(cls): @@ -474,32 +479,31 @@ def _unwrap_ray_remote(cls): return cls -def _nearest_common_base(mros: List): - last_common = object - min_len = min([len(mro) for mro in mros]) - 1 # exclude final derived class - - for i in range(min_len): - mro = mros[0][i] - for j in range(1, len(mros)): - if mro != mros[j][i]: - return last_common - last_common = mro - - return last_common +def _determine_fsdp_megatron_base_class(mros: List): + """ + - megatron: base class should be MegatronWorker + - fsdp: base class should be Worker + """ + for cls in mros[0]: + if cls.__name__ == "MegatronWorker": + return cls + if cls.__name__ == "Worker": + return cls + raise ValueError(f"Cannot determine base class for {mros}") -def create_colocated_worker_cls(class_dict: dict[str, RayClassWithInitArgs], worker_cls: type = None): +def create_colocated_worker_cls(class_dict: dict[str, RayClassWithInitArgs]): """ This function should return a class instance that delegates the calls to every cls in cls_dict """ cls_dict = {} init_args_dict = {} - if worker_cls is None: - worker_cls = _nearest_common_base( - [list(reversed(cls.cls.__ray_actor_class__.__mro__)) for cls in class_dict.values()]) + worker_cls = _determine_fsdp_megatron_base_class( + [cls.cls.__ray_actor_class__.__mro__ for cls in class_dict.values()] + ) assert issubclass(worker_cls, Worker), f"worker_cls {worker_cls} should be a subclass of Worker" - print(f"find nearest common base class {worker_cls}") + print(f"colocated worker base class {worker_cls}") for key, cls in class_dict.items(): cls_dict[key] = cls.cls @@ -515,7 +519,8 @@ def __init__(self): for key, user_defined_cls in cls_dict.items(): user_defined_cls = _unwrap_ray_remote(user_defined_cls) # directly instantiate the class without remote - # in worker class, e.g. when DISABLE_WORKER_INIT == 1 it will return immediately + # in worker class, e.g. 
+ # when DISABLE_WORKER_INIT == 1 it will return immediately with patch.dict(os.environ, {"DISABLE_WORKER_INIT": "1"}): self.worker_dict[key] = user_defined_cls( *init_args_dict[key].get("args", ()), **init_args_dict[key].get("kwargs", {}) diff --git a/verl/trainer/config/generation.yaml b/verl/trainer/config/generation.yaml index 8b542f542d4..3d74db9db69 100644 --- a/verl/trainer/config/generation.yaml +++ b/verl/trainer/config/generation.yaml @@ -14,6 +14,7 @@ model: external_lib: null rollout: name: vllm + mode: "sync" # sync: LLM, async: AsyncLLM temperature: 1.0 top_k: 50 # 0 for hf rollout, -1 for vllm rollout top_p: 0.7 diff --git a/verl/trainer/config/ppo_megatron_trainer.yaml b/verl/trainer/config/ppo_megatron_trainer.yaml index 4c84f315bef..b87695213b8 100644 --- a/verl/trainer/config/ppo_megatron_trainer.yaml +++ b/verl/trainer/config/ppo_megatron_trainer.yaml @@ -94,6 +94,7 @@ actor_rollout_ref: log_prob_micro_batch_size_per_gpu: null rollout: name: vllm + mode: "sync" # sync: LLM, async: AsyncLLM temperature: 1.0 top_k: -1 # 0 for hf rollout, -1 for vllm rollout top_p: 1 diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py index f44961b4291..5c03fb601e6 100644 --- a/verl/trainer/ppo/ray_trainer.py +++ b/verl/trainer/ppo/ray_trainer.py @@ -773,7 +773,7 @@ def init_workers(self): # create async rollout manager and request scheduler self.async_rollout_mode = False - if self.config.actor_rollout_ref.rollout.get("mode", "sync") == 'async': + if self.config.actor_rollout_ref.rollout.mode == "async": from verl.workers.fsdp_async_workers import AsyncLLMManager self.async_rollout_mode = True diff --git a/verl/workers/fsdp_async_workers.py b/verl/workers/fsdp_async_workers.py index 458d3e13314..fd7bbb36a79 100644 --- a/verl/workers/fsdp_async_workers.py +++ b/verl/workers/fsdp_async_workers.py @@ -34,12 +34,10 @@ from vllm.worker.worker_base import WorkerWrapperBase from verl import DataProto -from verl.single_controller.base import Worker from verl.single_controller.base.decorator import Dispatch, register from verl.single_controller.ray.base import RayWorkerGroup from verl.utils.fs import copy_to_local from verl.workers.fsdp_workers import ActorRolloutRefWorker -from verl.workers.sharding_manager import FSDPVLLMShardingManager logger = logging.getLogger(__file__) @@ -52,16 +50,17 @@ def _get_free_port(): class ExternalRayDistributedExecutor(Executor): """An executor that engines are launched by external ray actors.""" + uses_ray: bool = False def _init_executor(self) -> None: - assert self.vllm_config.instance_id is not None, \ - "instance_id must be set for external ray actors." + assert self.vllm_config.instance_id is not None, "instance_id must be set for external ray actors." fields = self.vllm_config.instance_id.split(":") - assert len(fields) == 4, \ - f"instance_id: {self.vllm_config.instance_id} must be in " \ + assert len(fields) == 4, ( + f"instance_id: {self.vllm_config.instance_id} must be in " f"the format of :::." + ) namespace, wg_prefix, vllm_dp_size, vllm_dp_rank = fields[0], fields[1], int(fields[2]), int(fields[3]) # Make sure subprocess in same namespace as parent actor. 
@@ -72,10 +71,11 @@ def _init_executor(self) -> None: ] vllm_tp_size = self.vllm_config.parallel_config.tensor_parallel_size - assert len(actor_names) == vllm_dp_size * vllm_tp_size, \ - f"instance_id: {self.vllm_config.instance_id} has {len(actor_names)} actors, " \ - f"but vllm_dp_size: {vllm_dp_size} * vllm_tp_size: {vllm_tp_size} = " \ + assert len(actor_names) == vllm_dp_size * vllm_tp_size, ( + f"instance_id: {self.vllm_config.instance_id} has {len(actor_names)} actors, " + f"but vllm_dp_size: {vllm_dp_size} * vllm_tp_size: {vllm_tp_size} = " f"{vllm_dp_size * vllm_tp_size} is expected." + ) def get_pg_index_and_local_rank(actor_name) -> Tuple[int, int]: fields = actor_name.split(":") @@ -85,7 +85,7 @@ def get_pg_index_and_local_rank(actor_name) -> Tuple[int, int]: # sort actor names by pg_index and local_rank actor_names = sorted(actor_names, key=get_pg_index_and_local_rank) - actor_names = actor_names[vllm_dp_rank * vllm_tp_size:(vllm_dp_rank + 1) * vllm_tp_size] + actor_names = actor_names[vllm_dp_rank * vllm_tp_size : (vllm_dp_rank + 1) * vllm_tp_size] self.workers: List[WorkerWrapperBase] = [ray.get_actor(actor_name) for actor_name in actor_names] print(f"instance_id: {self.vllm_config.instance_id} intializes with external actors: {actor_names}") @@ -101,11 +101,13 @@ def get_pg_index_and_local_rank(actor_name) -> Tuple[int, int]: self.collective_rpc("load_model") print(f"instance_id: {self.vllm_config.instance_id} intializes finished.") - def collective_rpc(self, - method: Union[str, Callable], - timeout: Optional[float] = None, - args: Tuple = (), - kwargs: Optional[Dict[str, Any]] = None) -> List[Any]: + def collective_rpc( + self, + method: Union[str, Callable], + timeout: Optional[float] = None, + args: Tuple = (), + kwargs: Optional[Dict[str, Any]] = None, + ) -> List[Any]: # TODO(wuxibin): support ray compiled graph if isinstance(method, str): sent_method = method @@ -114,7 +116,8 @@ def collective_rpc(self, del method outputs = ray.get( - [worker.execute_method.remote(sent_method, *args, **(kwargs or {})) for worker in self.workers]) + [worker.execute_method.remote(sent_method, *args, **(kwargs or {})) for worker in self.workers] + ) return outputs def check_health(self): @@ -126,7 +129,7 @@ class AsyncLLMWorker: """ AsyncLLMWorker is a wrapper for AsyncLLM, it uses ExternalRayDistributedExecutor to launch engines in hybrid rollout workers, i.e AsyncActorRolloutRefWorker. - + It works as follows: 1. Initialize AsyncLLM with ExternalRayDistributedExecutor. 2. AsyncLLM spawn EngineCore in subprocess. 
@@ -149,18 +152,19 @@ def __init__(self, config: DictConfig, vllm_dp_size: int, vllm_dp_rank: int, wg_ model_path = config.model.path model_name = "/".join(model_path.split("/")[-2:]) local_path = copy_to_local(model_path) - trust_remote_code = config.model.get('trust_remote_code', False) + trust_remote_code = config.model.get("trust_remote_code", False) config = config.rollout - tensor_parallel_size = config.get('tensor_model_parallel_size', 1) - max_num_batched_tokens = config.get('max_num_batched_tokens', 8192) - max_model_len = config.max_model_len if config.max_model_len \ - else config.prompt_length + config.response_length + tensor_parallel_size = config.get("tensor_model_parallel_size", 1) + max_num_batched_tokens = config.get("max_num_batched_tokens", 8192) + max_model_len = config.max_model_len if config.max_model_len else config.prompt_length + config.response_length max_model_len = int(max_model_len) if max_num_batched_tokens < max_model_len and config.enable_chunked_prefill: - raise ValueError('Enable chunked prefill, max_num_batched_tokens is smaller than max_model_len, \ - please increase max_num_batched_tokens or disable chunked prefill') + raise ValueError( + "Enable chunked prefill, max_num_batched_tokens is smaller than max_model_len, \ + please increase max_num_batched_tokens or disable chunked prefill" + ) engine_args = AsyncEngineArgs( model=local_path, @@ -250,7 +254,6 @@ async def sleep(self): class AsyncActorRolloutRefWorker(ActorRolloutRefWorker): - def _build_rollout(self, trust_remote_code=False): rollout, rollout_sharding_manager = super()._build_rollout(trust_remote_code) @@ -274,8 +277,10 @@ def generate_sequences(self, prompts: DataProto): def execute_method(self, method: Union[str, bytes], *args, **kwargs): """Called by ExternalRayDistributedExecutor collective_rpc.""" if self.vllm_tp_rank == 0 and method != "execute_model": - print(f"[DP={self.vllm_dp_rank},TP={self.vllm_tp_rank}] " - f"execute_method: {method if isinstance(method, str) else 'Callable'}") + print( + f"[DP={self.vllm_dp_rank},TP={self.vllm_tp_rank}] " + f"execute_method: {method if isinstance(method, str) else 'Callable'}" + ) return self.rollout.execute_method(method, *args, **kwargs) @@ -287,7 +292,7 @@ def __init__(self, config: DictConfig, worker_group: RayWorkerGroup): Args: config: DictConfig, actor_rollout_ref config. - worker_group: RayWorkerGroup, worker group of AsyncActorRolloutRefWorker. + worker_group: RayWorkerGroup, worker group of AsyncActorRolloutRefWorker. 
""" self.config = config self.worker_group = worker_group diff --git a/verl/workers/fsdp_workers.py b/verl/workers/fsdp_workers.py index 44941f068bb..1890e3842d9 100644 --- a/verl/workers/fsdp_workers.py +++ b/verl/workers/fsdp_workers.py @@ -355,18 +355,22 @@ def _build_rollout(self, trust_remote_code=False): log_gpu_memory_usage(f"Before building {rollout_name} rollout", logger=None) local_path = copy_to_local(self.config.model.path) if vllm_mode == "customized": - rollout = vLLMRollout(actor_module=self.actor_module_fsdp, - config=self.config.rollout, - tokenizer=self.tokenizer, - model_hf_config=self.actor_model_config) + rollout = vLLMRollout( + actor_module=self.actor_module_fsdp, + config=self.config.rollout, + tokenizer=self.tokenizer, + model_hf_config=self.actor_model_config, + ) elif vllm_mode == "spmd": vllm_rollout_cls = vLLMRollout if self.config.rollout.mode == "sync" else vLLMAsyncRollout - rollout = vllm_rollout_cls(model_path=local_path, - config=self.config.rollout, - tokenizer=self.tokenizer, - model_hf_config=self.actor_model_config, - device_mesh=rollout_device_mesh, - trust_remote_code=trust_remote_code) + rollout = vllm_rollout_cls( + model_path=local_path, + config=self.config.rollout, + tokenizer=self.tokenizer, + model_hf_config=self.actor_model_config, + device_mesh=rollout_device_mesh, + trust_remote_code=trust_remote_code, + ) else: raise NotImplementedError("vllm_mode must be 'customized' or 'spmd'") log_gpu_memory_usage(f"After building {rollout_name} rollout", logger=None) @@ -378,6 +382,7 @@ def _build_rollout(self, trust_remote_code=False): model_config=self.actor_model_config, full_params="hf" in self.config.rollout.load_format, device_mesh=rollout_device_mesh, + offload_param=self._is_offload_param, ) log_gpu_memory_usage("After building sharding manager", logger=None) @@ -409,6 +414,7 @@ def _build_rollout(self, trust_remote_code=False): model_config=self.actor_model_config, full_params="hf" in self.config.rollout.load_format, device_mesh=rollout_device_mesh, + offload_param=self._is_offload_param, ) log_gpu_memory_usage("After building sharding manager", logger=None) @@ -544,8 +550,6 @@ def generate_sequences(self, prompts: DataProto): prompts = prompts.to(torch.cuda.current_device()) assert self._is_rollout - if self._is_offload_param: - load_fsdp_model_to_gpu(self.actor_module_fsdp) meta_info = { "eos_token_id": self.generation_config.eos_token_id @@ -557,11 +561,7 @@ def generate_sequences(self, prompts: DataProto): } prompts.meta_info.update(meta_info) with self.rollout_sharding_manager: - # after parameters sync with rollout, offload actor model to CPU - if self._is_offload_param: - offload_fsdp_model_to_cpu(self.actor_module_fsdp) - if self._is_offload_optimizer: - offload_fsdp_optimizer(optimizer=self.actor_optimizer) + log_gpu_memory_usage("After entering rollout sharding manager", logger=logger) prompts = self.rollout_sharding_manager.preprocess_data(prompts) output = self.rollout.generate_sequences(prompts=prompts) diff --git a/verl/workers/sharding_manager/fsdp_sglang.py b/verl/workers/sharding_manager/fsdp_sglang.py index c9b6222883c..4037b573e28 100644 --- a/verl/workers/sharding_manager/fsdp_sglang.py +++ b/verl/workers/sharding_manager/fsdp_sglang.py @@ -37,6 +37,7 @@ from verl import DataProto from verl.protocol import all_gather_data_proto from verl.utils.debug import log_gpu_memory_usage +from verl.utils.fsdp_utils import load_fsdp_model_to_gpu, offload_fsdp_model_to_cpu from verl.utils.torch_functional import broadcast_dict_tensor 
from .base import BaseShardingManager @@ -55,11 +56,13 @@ def __init__( model_config, full_params: bool = False, device_mesh: DeviceMesh = None, + offload_param: bool = False, ): self.module = module self.inference_engine = inference_engine self.model_config = model_config self.device_mesh = device_mesh + self.offload_param = offload_param # Full params self.full_params = full_params @@ -88,6 +91,8 @@ def __init__( def __enter__(self): torch.cuda.empty_cache() log_gpu_memory_usage("Before state_dict() in sharding manager memory", logger=logger) + if self.offload_param: + load_fsdp_model_to_gpu(self.module) params = self.module.state_dict() log_gpu_memory_usage("After state_dict() in sharding manager memory", logger=logger) # Copy, not share memory @@ -98,15 +103,11 @@ def __enter__(self): log_gpu_memory_usage("After sync model weights in sharding manager", logger=logger) del params + if self.offload_param: + offload_fsdp_model_to_cpu(self.module) torch.cuda.empty_cache() log_gpu_memory_usage("After del state_dict and empty_cache in sharding manager", logger=logger) - # TODO: offload FSDP model weights - # self.module.cpu() - # torch.cuda.empty_cache() - # if torch.distributed.get_rank() == 0: - # print(f'after model to cpu in sharding manager memory allocated: {torch.cuda.memory_allocated() / 1e9}GB, reserved: {torch.cuda.memory_reserved() / 1e9}GB') - # important: need to manually set the random states of each tp to be identical. if self.device_mesh is not None: self.torch_random_states = torch.cuda.get_rng_state() diff --git a/verl/workers/sharding_manager/fsdp_vllm.py b/verl/workers/sharding_manager/fsdp_vllm.py index 547ad7d01ab..315d6525bed 100644 --- a/verl/workers/sharding_manager/fsdp_vllm.py +++ b/verl/workers/sharding_manager/fsdp_vllm.py @@ -30,6 +30,7 @@ from verl.third_party.vllm import vllm_version from verl.utils.debug import GPUMemoryLogger, log_gpu_memory_usage from verl.utils.vllm_utils import patch_vllm_moe_model_weight_loader +from verl.utils.fsdp_utils import load_fsdp_model_to_gpu, offload_fsdp_model_to_cpu from .base import BaseShardingManager @@ -45,13 +46,17 @@ def __init__( model_config, full_params: bool = False, device_mesh: DeviceMesh = None, + offload_param: bool = False, ): self.module = module # For AsyncLLM, inference_engine and model_runner are defer intialized in vLLMAsyncRollout.load_model self.inference_engine = inference_engine - self.model_runner = inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner if inference_engine else None + self.model_runner = ( + inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner if inference_engine else None + ) self.model_config = model_config self.device_mesh = device_mesh + self.offload_param = offload_param # Full params self.full_params = full_params @@ -92,6 +97,8 @@ def __enter__(self): torch.cuda.empty_cache() log_gpu_memory_usage("Before state_dict() in sharding manager memory", logger=logger) + if self.offload_param: + load_fsdp_model_to_gpu(self.module) params = self.module.state_dict() log_gpu_memory_usage("After state_dict() in sharding manager memory", logger=logger) # Copy, not share memory @@ -114,6 +121,8 @@ def __enter__(self): self.update_params(params) log_gpu_memory_usage("After sync model weights in sharding manager", logger=logger) del params + if self.offload_param: + offload_fsdp_model_to_cpu(self.module) torch.cuda.empty_cache() if "tags" in inspect.signature(self.inference_engine.wake_up).parameters: @@ -121,12 +130,6 @@ def __enter__(self): 
log_gpu_memory_usage("After del state_dict and empty_cache in sharding manager", logger=logger) - # TODO: offload FSDP model weights - # self.module.cpu() - # torch.cuda.empty_cache() - # if torch.distributed.get_rank() == 0: - # print(f'after model to cpu in sharding manager memory allocated: {torch.cuda.memory_allocated() / 1e9}GB, reserved: {torch.cuda.memory_reserved() / 1e9}GB') - # important: need to manually set the random states of each tp to be identical. if self.device_mesh is not None: self.torch_random_states = torch.cuda.get_rng_state() From a61d8cceada8c50d79d89cab03f0735facb58983 Mon Sep 17 00:00:00 2001 From: wuxibin Date: Tue, 22 Apr 2025 23:35:24 +0800 Subject: [PATCH 03/10] develop naive chat scheduler --- examples/ppo_trainer/naive_chat_scheduler.py | 111 ++++++++++++++++--- tests/ray/test_worker_group_basics.py | 4 +- tests/rollout/test_vllm_multi_turn.py | 1 + verl/trainer/ppo/ray_trainer.py | 1 + verl/workers/rollout/chat_scheduler.py | 21 +++- 5 files changed, 116 insertions(+), 22 deletions(-) diff --git a/examples/ppo_trainer/naive_chat_scheduler.py b/examples/ppo_trainer/naive_chat_scheduler.py index ee81b52d22b..6ff14363ed7 100644 --- a/examples/ppo_trainer/naive_chat_scheduler.py +++ b/examples/ppo_trainer/naive_chat_scheduler.py @@ -14,19 +14,33 @@ import asyncio from typing import Any, Dict, List +import torch from omegaconf import DictConfig from openai.types.chat.chat_completion import ChatCompletion +from tensordict import TensorDict +from transformers import PreTrainedTokenizer from verl.protocol import DataProto from verl.workers.rollout.chat_scheduler import ChatCompletionScheduler class NaiveChatCompletionScheduler(ChatCompletionScheduler): + """ + A very naive implementation of ChatCompletionScheduler for demo purpose, + only do single-turn chat completion. 
+ """ - def __init__(self, config: DictConfig, model_path: str, server_addresses: List[str], max_cache_size: int = 10000): - super().__init__(config, model_path, server_addresses, max_cache_size) + def __init__( + self, + config: DictConfig, + model_path: str, + tokenizer: PreTrainedTokenizer, + server_addresses: List[str], + max_cache_size: int = 10000, + ): + super().__init__(config, model_path, tokenizer, server_addresses, max_cache_size) - async def generate_sequences(self, prompts: DataProto, **sampling_params) -> DataProto: + async def generate_sequences(self, batch: DataProto, **sampling_params) -> DataProto: kwargs = dict( n=self.config.n, max_completion_tokens=self.config.response_length, @@ -34,8 +48,8 @@ async def generate_sequences(self, prompts: DataProto, **sampling_params) -> Dat top_p=self.config.top_p, ) - do_sample = prompts.meta_info.get('do_sample', True) - is_validate = prompts.meta_info.get('validate', False) + do_sample = batch.meta_info.get("do_sample", True) + is_validate = batch.meta_info.get("validate", False) if not do_sample or is_validate: kwargs["n"] = 1 kwargs["temperature"] = 0 @@ -44,29 +58,96 @@ async def generate_sequences(self, prompts: DataProto, **sampling_params) -> Dat print(f"[NaiveChatCompletionScheduler] generate_sequences sampling params: {kwargs}") async def callback(completions: ChatCompletion, info: Dict[str, Any]): - info["all_completions"][info["index"]] = completions + conversation, batch_conversations, batch_index = ( + info["conversation"], + info["batch_conversations"], + info["batch_index"], + ) + + conversations = [] + for choice in completions.choices: + chat = conversation.copy() + chat.append({"role": choice.message.role, "content": choice.message.content}) + conversations.append(chat) + batch_conversations[batch_index] = conversations # NOTE: we can call tools and resubmit chat completions here. # call_tools(completions, info) # await self.submit_chat_completions(callback2, ...) - tasks, all_completions = [], [None] * len(prompts) - for i, prompt in enumerate(prompts.non_tensor_batch["raw_prompt"]): + tasks, batch_conversations = [], [None] * len(batch) + for batch_index, conversation in enumerate(batch.non_tensor_batch["raw_prompt"]): # raw_prompt: [{"role": "user", "content": ""}, ["role": "assistant", "content"], ...] 
tasks.append( asyncio.create_task( self.submit_chat_completions( callback=callback, callback_additional_info={ - "all_completions": all_completions, - "index": i + "batch_conversations": batch_conversations, + "batch_index": batch_index, + "conversation": list(conversation), }, model=self.model_name, - messages=prompt, + messages=conversation, **kwargs, - ))) + ) + ) + ) await asyncio.gather(*tasks) - print("[NaiveChatCompletionScheduler] generate_sequences done") - # TODO: completions => DataProto - return all_completions + + return self._postprocess(batch, batch_conversations, kwargs["n"]) + + def _postprocess( + self, batch: DataProto, batch_conversations: List[List[List[Dict[str, str]]]], n: int + ) -> DataProto: + # NOTE: consistent with batch version of generate_sequences in vllm_rollout_spmd.py + # prompts: left pad + # responses: right pad + # input_ids: prompt + response + # attention_mask: [0,0,0,0,1,1,1,1, | 1,1,1,0,0,0,0,0] + # position_ids: [0,0,0,0,0,1,2,3, | 4,5,6,7,8,9,10,11] + + # prompts: [prompt] from input dataset + prompts = [ + self.tokenizer.apply_chat_template(prompt, add_generation_prompt=True, tokenize=False) + for prompt in batch.non_tensor_batch["raw_prompt"] + ] + + # flatten batch_conversations if n > 1 + assert len(batch_conversations) == len(prompts) + batch_conversations = [conversation for conversations in batch_conversations for conversation in conversations] + assert len(batch_conversations) == len(prompts) * n + + # sequences: [prompt + response] + sequences = [ + self.tokenizer.apply_chat_template(conversation, add_generation_prompt=False, tokenize=False) + for conversation in batch_conversations + ] + + # responses: [response] + # TODO: mask out tools calling tokens? + responses = [sequence[len(prompts[i // n]) :] for i, sequence in enumerate(sequences)] + + prompts = self.tokenizer(prompts, return_tensors="pt", padding="longest", padding_side="left") + responses = self.tokenizer(responses, return_tensors="pt", padding="longest", padding_side="right") + if n > 1: + prompts["input_ids"] = prompts["input_ids"].repeat_interleave(n, dim=0) + prompts["attention_mask"] = prompts["attention_mask"].repeat_interleave(n, dim=0) + + input_ids = torch.cat([prompts["input_ids"], responses["input_ids"]], dim=1) + attention_mask = torch.cat([prompts["attention_mask"], responses["attention_mask"]], dim=1) + position_ids = (attention_mask.cumsum(dim=1) - 1) * attention_mask + + batch = TensorDict( + { + "prompts": prompts["input_ids"], + "responses": responses["input_ids"], + "input_ids": input_ids, + "attention_mask": attention_mask, + "position_ids": position_ids, + }, + batch_size=len(input_ids), + ) + + return batch diff --git a/tests/ray/test_worker_group_basics.py b/tests/ray/test_worker_group_basics.py index 02a5b94ebb9..d0cecb515bf 100644 --- a/tests/ray/test_worker_group_basics.py +++ b/tests/ray/test_worker_group_basics.py @@ -68,7 +68,9 @@ def foo_custom(self, x, y): @ray.remote(num_gpus=0.1) def remote_call_wg(worker_names): class_with_args = RayClassWithInitArgs(cls=TestActor, x=2) - worker_group = RayWorkerGroup.from_detached(worker_names=worker_names, ray_cls_with_init=class_with_args) + worker_group = RayWorkerGroup.from_detached( + worker_names=worker_names, ray_cls_with_init=class_with_args, name_prefix=None + ) print(worker_group.worker_names) output_ref = worker_group.foo_custom(x=[1, 2], y=[5, 6]) diff --git a/tests/rollout/test_vllm_multi_turn.py b/tests/rollout/test_vllm_multi_turn.py index f59b9791d5b..19eb0a61aa9 100644 --- 
a/tests/rollout/test_vllm_multi_turn.py +++ b/tests/rollout/test_vllm_multi_turn.py @@ -91,6 +91,7 @@ async def test_vllm_multi_turn(): async_chat_scheduler = ChatCompletionScheduler( config=config.actor_rollout_ref.rollout, model_path=config.actor_rollout_ref.model.path, + tokenizer=None, server_addresses=async_rollout_manager.server_addresses, ) diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py index 5c03fb601e6..4f8e3231e7a 100644 --- a/verl/trainer/ppo/ray_trainer.py +++ b/verl/trainer/ppo/ray_trainer.py @@ -787,6 +787,7 @@ def init_workers(self): self.async_chat_scheduler = scheduler_cls( config=self.config.actor_rollout_ref.rollout, model_path=self.config.actor_rollout_ref.model.path, + tokenizer=self.tokenizer, server_addresses=self.async_rollout_manager.server_addresses, ) diff --git a/verl/workers/rollout/chat_scheduler.py b/verl/workers/rollout/chat_scheduler.py index 2580e4343b6..8fdc94adbd0 100644 --- a/verl/workers/rollout/chat_scheduler.py +++ b/verl/workers/rollout/chat_scheduler.py @@ -20,22 +20,31 @@ from omegaconf import DictConfig from openai import AsyncOpenAI from openai.types.chat.chat_completion import ChatCompletion +from transformers import PreTrainedTokenizer from verl.protocol import DataProto class ChatCompletionScheduler: - - def __init__(self, config: DictConfig, model_path: str, server_addresses: List[str], max_cache_size: int = 10000): + def __init__( + self, + config: DictConfig, + model_path: str, + tokenizer: PreTrainedTokenizer, + server_addresses: List[str], + max_cache_size: int = 10000, + ): """ Args: config: DictConfig, rollout config. model_path: str, model path. + tokenizer: PreTrainedTokenizer, tokenizer. server_addresses: List[str], server addresses. max_cache_size: int, max cache size of request_id to address mapping. 
""" self.config = config self.model_name = "/".join(model_path.split("/")[-2:]) + self.tokenizer = tokenizer # Least requests load balancing self.weighted_addresses = [[0, address] for address in server_addresses] @@ -71,7 +80,7 @@ async def submit_chat_completions( request_id = extra_headers.get("x-request-id", None) if request_id: if request_id.startswith("chatcmpl-"): - request_id = request_id[len("chatcmpl-"):] + request_id = request_id[len("chatcmpl-") :] extra_headers["x-request-id"] = request_id address = self.request_id_to_address[request_id] @@ -100,9 +109,9 @@ async def _chat_completions_aiohttp(self, address: str, **chat_complete_request) try: session = aiohttp.ClientSession() async with session.post( - url=f"http://{address}/v1/chat/completions", - headers={"Authorization": "Bearer token-abc123"}, - json=chat_complete_request, + url=f"http://{address}/v1/chat/completions", + headers={"Authorization": "Bearer token-abc123"}, + json=chat_complete_request, ) as resp: data = await resp.json() return ChatCompletion(**data) From c3d2249a54b368df3ac9b606293db7704cc3acfa Mon Sep 17 00:00:00 2001 From: wuxibin Date: Wed, 23 Apr 2025 19:42:28 +0800 Subject: [PATCH 04/10] restart uvicorn server if address already in use --- .../grpo_trainer/run_qwen2-7b_seq_balance.sh | 10 ++ examples/ppo_trainer/naive_chat_scheduler.py | 4 +- tests/rollout/test_vllm_multi_turn.py | 7 +- verl/trainer/ppo/ray_trainer.py | 13 ++- verl/workers/fsdp_async_workers.py | 97 +++++++++++++------ verl/workers/rollout/chat_scheduler.py | 24 ++++- 6 files changed, 111 insertions(+), 44 deletions(-) diff --git a/examples/grpo_trainer/run_qwen2-7b_seq_balance.sh b/examples/grpo_trainer/run_qwen2-7b_seq_balance.sh index 2c6d6477e28..2976534698a 100644 --- a/examples/grpo_trainer/run_qwen2-7b_seq_balance.sh +++ b/examples/grpo_trainer/run_qwen2-7b_seq_balance.sh @@ -3,10 +3,18 @@ set -x # If you are using vllm<=0.6.3, you might need to set the following environment variable to avoid bugs: # export VLLM_ATTENTION_BACKEND=XFORMERS +# For async rollout mode, dataset should return raw chat. 
+rollout_mode="sync" +if [ "$rollout_mode" = "async" ]; then + return_raw_chat="True" + chat_scheduler=examples.ppo_trainer.naive_chat_scheduler.NaiveChatCompletionScheduler +fi + python3 -m verl.trainer.main_ppo \ algorithm.adv_estimator=grpo \ data.train_files=$HOME/data/gsm8k/train.parquet \ data.val_files=$HOME/data/gsm8k/test.parquet \ + data.return_raw_chat=$return_raw_chat \ data.train_batch_size=1024 \ data.max_prompt_length=512 \ data.max_response_length=1024 \ @@ -27,6 +35,8 @@ python3 -m verl.trainer.main_ppo \ actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.mode=$rollout_mode \ + actor_rollout_ref.rollout.chat_scheduler=$chat_scheduler \ actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ actor_rollout_ref.rollout.n=5 \ actor_rollout_ref.ref.fsdp_config.param_offload=True \ diff --git a/examples/ppo_trainer/naive_chat_scheduler.py b/examples/ppo_trainer/naive_chat_scheduler.py index 6ff14363ed7..5e00bca30cc 100644 --- a/examples/ppo_trainer/naive_chat_scheduler.py +++ b/examples/ppo_trainer/naive_chat_scheduler.py @@ -57,7 +57,7 @@ async def generate_sequences(self, batch: DataProto, **sampling_params) -> DataP kwargs.update(sampling_params) print(f"[NaiveChatCompletionScheduler] generate_sequences sampling params: {kwargs}") - async def callback(completions: ChatCompletion, info: Dict[str, Any]): + async def callback(completions: ChatCompletion, info: Dict[str, Any], exception: Exception): conversation, batch_conversations, batch_index = ( info["conversation"], info["batch_conversations"], @@ -150,4 +150,4 @@ def _postprocess( batch_size=len(input_ids), ) - return batch + return DataProto(batch=batch) diff --git a/tests/rollout/test_vllm_multi_turn.py b/tests/rollout/test_vllm_multi_turn.py index 19eb0a61aa9..43990ed4dc4 100644 --- a/tests/rollout/test_vllm_multi_turn.py +++ b/tests/rollout/test_vllm_multi_turn.py @@ -96,11 +96,12 @@ async def test_vllm_multi_turn(): ) # test sleep and wake_up - async_rollout_manager.sleep() - async_rollout_manager.wake_up() + await async_rollout_manager.sleep() + await async_rollout_manager.wake_up() # =========================== 3. 
Multi turn rollout =========================== - async def callback(completions: ChatCompletion, info: Dict[str, Any]): + async def callback(completions: ChatCompletion, info: Dict[str, Any], exception: Exception): + assert exception is None, f"exception: {exception}" messages, round = info["messages"], info["round"] message = completions.choices[0].message messages.append({"role": message.role, "content": message.content}) diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py index 4f8e3231e7a..b2e7ddb41bd 100644 --- a/verl/trainer/ppo/ray_trainer.py +++ b/verl/trainer/ppo/ray_trainer.py @@ -582,7 +582,7 @@ def _maybe_log_val_generations(self, inputs, outputs, scores): # Log to each configured logger self.validation_generations_logger.log(self.config.trainer.logger, samples, self.global_steps) - def _validate(self): + async def _validate(self): data_source_lst = [] reward_extra_infos_dict: dict[str, list] = defaultdict(list) @@ -617,7 +617,7 @@ def _validate(self): else: test_gen_batch = test_batch.pop( batch_keys=["input_ids", "attention_mask", "position_ids"], - non_tensor_batch_keys=["raw_prompt_ids"], + non_tensor_batch_keys=["raw_prompt_ids"] + ["raw_prompt"] if self.async_rollout_mode else [], ) test_gen_batch.meta_info = { @@ -631,7 +631,12 @@ def _validate(self): # pad to be divisible by dp_size test_gen_batch_padded, pad_size = pad_dataproto_to_divisor(test_gen_batch, self.actor_rollout_wg.world_size) - test_output_gen_batch_padded = self.actor_rollout_wg.generate_sequences(test_gen_batch_padded) + if not self.async_rollout_mode: + test_output_gen_batch_padded = self.actor_rollout_wg.generate_sequences(test_gen_batch_padded) + else: + await self.async_rollout_manager.wake_up() + test_output_gen_batch_padded = await self.async_chat_scheduler.generate_sequences(test_gen_batch_padded) + await self.async_rollout_manager.sleep() # unpad test_output_gen_batch = unpad_dataproto(test_output_gen_batch_padded, pad_size=pad_size) @@ -1131,7 +1136,7 @@ async def fit(self): and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0) ): with _timer("testing", timing_raw): - val_metrics: dict = self._validate() + val_metrics: dict = await self._validate() if is_last_step: last_val_metrics = val_metrics metrics.update(val_metrics) diff --git a/verl/workers/fsdp_async_workers.py b/verl/workers/fsdp_async_workers.py index fd7bbb36a79..39f80eaab3b 100644 --- a/verl/workers/fsdp_async_workers.py +++ b/verl/workers/fsdp_async_workers.py @@ -14,8 +14,8 @@ import asyncio import logging import os -import random import socket +from contextlib import asynccontextmanager from typing import Any, Callable, Dict, List, Optional, Tuple, Union import cloudpickle @@ -131,12 +131,12 @@ class AsyncLLMWorker: in hybrid rollout workers, i.e AsyncActorRolloutRefWorker. It works as follows: - 1. Initialize AsyncLLM with ExternalRayDistributedExecutor. - 2. AsyncLLM spawn EngineCore in subprocess. - 3. EngineCore initialize ExternalRayDistributedExecutor. - 4. ExternalRayDistributedExecutor lookup its corresponding actors by name. - 5. ExternalRayDistributedExecutor init executor: init_worker, init_device, load_model. - 6. AsyncLLM initialize done, start FastAPI server. + 1. Start FastAPI server first. + 2. Initialize AsyncLLM with ExternalRayDistributedExecutor. + 3. AsyncLLM spawn EngineCore in subprocess. + 4. EngineCore initialize ExternalRayDistributedExecutor. + 5. ExternalRayDistributedExecutor lookup its corresponding actors by name. + 6. 
ExternalRayDistributedExecutor init executor: init_worker, init_device, load_model. For vLLM AsyncLLM design, see: https://github.com/vllm-project/vllm/pull/9826 """ @@ -149,6 +149,21 @@ def __init__(self, config: DictConfig, vllm_dp_size: int, vllm_dp_rank: int, wg_ vllm_dp_rank: int, vllm data parallel rank. wg_prefix: str, worker group prefix, used to lookup actors. """ + self.config = config + self.vllm_dp_size = vllm_dp_size + self.vllm_dp_rank = vllm_dp_rank + self.wg_prefix = wg_prefix + self.engine: AsyncLLM = None + + # start FastAPI server + self.address = ray._private.services.get_node_ip_address() + self.port = None + self.server_ready = asyncio.Event() + asyncio.create_task(self._start_fastapi_server()) + + async def init_async_llm(self): + """Init vLLM AsyncLLM engine.""" + config = self.config model_path = config.model.path model_name = "/".join(model_path.split("/")[-2:]) local_path = copy_to_local(model_path) @@ -184,13 +199,13 @@ def __init__(self, config: DictConfig, vllm_dp_size: int, vllm_dp_rank: int, wg_ enable_chunked_prefill=config.enable_chunked_prefill, enable_prefix_caching=True, trust_remote_code=trust_remote_code, - seed=vllm_dp_rank, + seed=self.vllm_dp_rank, ) # init async llm engine vllm_config = engine_args.create_engine_config() namespace = ray.get_runtime_context().namespace - vllm_config.instance_id = f"{namespace}:{wg_prefix}:{vllm_dp_size}:{vllm_dp_rank}" + vllm_config.instance_id = f"{namespace}:{self.wg_prefix}:{self.vllm_dp_size}:{self.vllm_dp_rank}" self.engine = AsyncLLM.from_vllm_config(vllm_config) # build serving chat @@ -207,12 +222,6 @@ def __init__(self, config: DictConfig, vllm_dp_size: int, vllm_dp_rank: int, wg_ chat_template_content_format="auto", ) - # start FastAPI server - self.address = ray._private.services.get_node_ip_address() - self.port = None - self.server_ready = asyncio.Event() - asyncio.create_task(self._start_fastapi_server()) - async def chat_completion(self, raw_request: Request): """OpenAI-compatible HTTP endpoint. @@ -231,11 +240,20 @@ async def chat_completion(self, raw_request: Request): return JSONResponse(content=generator.model_dump()) async def _start_fastapi_server(self): - app = fastapi.FastAPI() + @asynccontextmanager + async def lifespan(app: fastapi.FastAPI): + print("FastAPI startup") + self.server_ready.set() + yield + + # There's no way to gracefully restart uvicorn server if port is already in use, + # so we exit the process directly and let AsyncLLMManager restart it. 
+ print("FastAPI shutdown, maybe address already in use, exit process immediately.") + os._exit(-1) + + app = fastapi.FastAPI(lifespan=lifespan) app.router.add_api_route("/v1/chat/completions", self.chat_completion, methods=["POST"]) - # TODO: random sleep to reduce port conflict, retry if port is already in use - asyncio.sleep(random.uniform(0, 3)) self.port = _get_free_port() config = uvicorn.Config(app, host=["::", "0.0.0.0"], port=self.port) server = uvicorn.Server(config) @@ -243,6 +261,7 @@ async def _start_fastapi_server(self): await server.serve() async def get_server_address(self) -> Tuple[str, int]: + """Get FastAPI server address.""" await self.server_ready.wait() return f"{self.address}:{self.port}" @@ -304,18 +323,36 @@ def __init__(self, config: DictConfig, worker_group: RayWorkerGroup): workers_info = ray.get(register_center.get_worker_info.remote()) assert len(workers_info) == self.worker_group.world_size - # make sure AsyncLLMWorker colocates with its corresponding workers - self.async_llm_workers = [ - AsyncLLMWorker.options( - scheduling_strategy=ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy( - node_id=workers_info[rollout_dp_rank * self.rollout_tp_size], - soft=False, - ), - name=f"async_llm_worker_{rollout_dp_rank}", - ).remote(config, self.rollout_dp_size, rollout_dp_rank, self.worker_group.name_prefix) - for rollout_dp_rank in range(self.rollout_dp_size) - ] - self.server_addresses = ray.get([worker.get_server_address.remote() for worker in self.async_llm_workers]) + self.async_llm_workers = [None] * self.rollout_dp_size + self.server_addresses = [None] * self.rollout_dp_size + + # Start all server instances, restart if address already in use. + unready_dp_ranks = set(range(self.rollout_dp_size)) + while len(unready_dp_ranks) > 0: + workers = { + rollout_dp_rank: AsyncLLMWorker.options( + # make sure AsyncLLMWorker colocates with its corresponding workers + scheduling_strategy=ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy( + node_id=workers_info[rollout_dp_rank * self.rollout_tp_size], + soft=False, + ), + name=f"async_llm_worker_{rollout_dp_rank}", + ).remote(config, self.rollout_dp_size, rollout_dp_rank, self.worker_group.name_prefix) + for rollout_dp_rank in unready_dp_ranks + } + + for rollout_dp_rank, worker in workers.items(): + try: + address = ray.get(worker.get_server_address.remote()) + self.server_addresses[rollout_dp_rank] = address + self.async_llm_workers[rollout_dp_rank] = worker + unready_dp_ranks.remove(rollout_dp_rank) + except Exception: + ray.kill(worker) + print(f"worker {rollout_dp_rank} failed, maybe address already in use, restarting...") + + # All server instances are ready, init AsyncLLM engine. + ray.get([worker.init_async_llm.remote() for worker in self.async_llm_workers]) @property def server_address(self): diff --git a/verl/workers/rollout/chat_scheduler.py b/verl/workers/rollout/chat_scheduler.py index 8fdc94adbd0..f9e5b5d2280 100644 --- a/verl/workers/rollout/chat_scheduler.py +++ b/verl/workers/rollout/chat_scheduler.py @@ -55,7 +55,7 @@ def __init__( async def submit_chat_completions( self, - callback: Callable[[ChatCompletion, Dict[str, Any]], None], + callback: Callable[[ChatCompletion, Dict[str, Any], Exception], None], callback_additional_info: Dict[str, Any], **chat_complete_request, ): @@ -63,7 +63,16 @@ async def submit_chat_completions( Submit a chat completion request to the server with the least number of requests. 
Args: - callback: Callable[[ChatCompletion, Dict[str, Any]], None], async callback function to handle the response. + callback: Callable[[ChatCompletion, Dict[str, Any], Exception], None], async callback function + to handle the response. The callback function should have the following signature: + + ```python + async def callback(completions: ChatCompletion, info: Dict[str, Any], exception: Exception): + ... + ``` + - completions: chat completion response from server. + - info: user provided `callback_additional_info`. + - exception: exception raise from OpenAI client if request failed, otherwise None. **CAUTION**: the callback function must be async and non-blocking, if you have any blocking operation, please move to seperate thread or process pool to avoid blocking the event loop. @@ -93,10 +102,15 @@ async def submit_chat_completions( self.request_id_to_address[request_id] = address chat_complete_request["extra_headers"]["x-request-id"] = request_id - # TODO: OpenAI client uses httpx, seems to have performance issue in high concurrency requests. - completions = await self._chat_completions_openai(address, **chat_complete_request) + completions, exception = None, None + try: + # TODO: OpenAI client uses httpx, seems to have performance issue in high concurrency requests. + completions = await self._chat_completions_openai(address, **chat_complete_request) + except Exception as e: + # Let user handle the exception + exception = e - await callback(completions, callback_additional_info) + await callback(completions, callback_additional_info, exception) async def _chat_completions_openai(self, address: str, **chat_complete_request) -> ChatCompletion: client = AsyncOpenAI( From 54ca37c702adc77e0ef4375f7e980fc9b383482c Mon Sep 17 00:00:00 2001 From: wuxibin Date: Thu, 24 Apr 2025 11:04:36 +0800 Subject: [PATCH 05/10] align sampling params --- verl/workers/fsdp_async_workers.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/verl/workers/fsdp_async_workers.py b/verl/workers/fsdp_async_workers.py index 39f80eaab3b..02966c4bc1b 100644 --- a/verl/workers/fsdp_async_workers.py +++ b/verl/workers/fsdp_async_workers.py @@ -25,6 +25,7 @@ from omegaconf import DictConfig from starlette.requests import Request from starlette.responses import JSONResponse, StreamingResponse +from vllm import SamplingParams from vllm.engine.arg_utils import AsyncEngineArgs from vllm.entrypoints.openai.protocol import ChatCompletionRequest, ChatCompletionResponse, ErrorResponse from vllm.entrypoints.openai.serving_chat import OpenAIServingChat @@ -181,9 +182,22 @@ async def init_async_llm(self): please increase max_num_batched_tokens or disable chunked prefill" ) + # Override default generation config from hugging face model config, + # user can still override them by passing kwargs in each request. 
+ kwargs = dict( + n=1, + logprobs=0, + max_tokens=config.response_length, + ) + for k in config.keys(): + if hasattr(SamplingParams(), str(k)): + kwargs[k] = config.get(k) + print(f"override_generation_config: {kwargs}") + engine_args = AsyncEngineArgs( model=local_path, enable_sleep_mode=True, + override_generation_config=kwargs, tensor_parallel_size=tensor_parallel_size, distributed_executor_backend=ExternalRayDistributedExecutor, dtype=config.dtype, From ce0f7aa33964058cac0a8d6d3eb091336f63ae22 Mon Sep 17 00:00:00 2001 From: shengguangming Date: Thu, 24 Apr 2025 15:52:43 +0800 Subject: [PATCH 06/10] [misc] fix validation before train --- verl/trainer/ppo/ray_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py index b2e7ddb41bd..bb61845ef15 100644 --- a/verl/trainer/ppo/ray_trainer.py +++ b/verl/trainer/ppo/ray_trainer.py @@ -951,7 +951,7 @@ async def fit(self): # perform validation before training # currently, we only support validation using the reward_function. if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True): - val_metrics = self._validate() + val_metrics = await self._validate() pprint(f"Initial validation metrics: {val_metrics}") logger.log(data=val_metrics, step=self.global_steps) if self.config.trainer.get("val_only", False): From fe2eb23e57982d2983a4bfce1f7f40c2942a4d30 Mon Sep 17 00:00:00 2001 From: wuxibin Date: Thu, 24 Apr 2025 21:16:01 +0800 Subject: [PATCH 07/10] reset prefix cache before sleep --- verl/workers/fsdp_async_workers.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/verl/workers/fsdp_async_workers.py b/verl/workers/fsdp_async_workers.py index 02966c4bc1b..9d49b0916a9 100644 --- a/verl/workers/fsdp_async_workers.py +++ b/verl/workers/fsdp_async_workers.py @@ -269,9 +269,8 @@ async def lifespan(app: fastapi.FastAPI): app.router.add_api_route("/v1/chat/completions", self.chat_completion, methods=["POST"]) self.port = _get_free_port() - config = uvicorn.Config(app, host=["::", "0.0.0.0"], port=self.port) + config = uvicorn.Config(app, host=["::", "0.0.0.0"], port=self.port, log_level="warning") server = uvicorn.Server(config) - self.server_ready.set() await server.serve() async def get_server_address(self) -> Tuple[str, int]: @@ -283,6 +282,8 @@ async def wake_up(self): await self.engine.wake_up() async def sleep(self): + # TODO: https://github.com/vllm-project/vllm/issues/17103 + await self.engine.reset_prefix_cache() await self.engine.sleep() From c07f06a6a0633132a0ac25cdf463140787eed0c3 Mon Sep 17 00:00:00 2001 From: wuxibin Date: Fri, 25 Apr 2025 11:43:14 +0800 Subject: [PATCH 08/10] abstract AsyncServer base class --- .github/workflows/vllm.yml | 6 +- examples/ppo_trainer/naive_chat_scheduler.py | 6 +- tests/rollout/test_vllm_multi_turn.py | 20 +- verl/trainer/main_ppo.py | 14 +- verl/trainer/ppo/ray_trainer.py | 31 +- verl/utils/vllm_utils.py | 8 +- verl/workers/fsdp_workers.py | 56 +++- verl/workers/rollout/async_server.py | 314 ++++++++++++++++++ verl/workers/rollout/chat_scheduler.py | 136 -------- verl/workers/rollout/vllm_rollout/__init__.py | 17 +- .../vllm_rollout/vllm_async_server.py} | 154 +-------- .../rollout/vllm_rollout/vllm_rollout_spmd.py | 7 +- verl/workers/sharding_manager/fsdp_vllm.py | 22 +- .../workers/sharding_manager/megatron_vllm.py | 16 +- 14 files changed, 439 insertions(+), 368 deletions(-) create mode 100644 verl/workers/rollout/async_server.py delete mode 100644 
verl/workers/rollout/chat_scheduler.py rename verl/workers/{fsdp_async_workers.py => rollout/vllm_rollout/vllm_async_server.py} (61%) diff --git a/.github/workflows/vllm.yml b/.github/workflows/vllm.yml index 869b4e7f3ec..e58bfa6a088 100644 --- a/.github/workflows/vllm.yml +++ b/.github/workflows/vllm.yml @@ -84,4 +84,8 @@ jobs: cd tests/generation export OUTPUT_PATH="${HOME}/data/gen/qwen_05_gen_test.parquet" MODEL_ID=Qwen/Qwen2.5-0.5B-Instruct NGPUS_PER_NODE=1 GEN_TP=1 bash ./run_gen_qwen05.sh - rm -rf "${OUTPUT_PATH}" \ No newline at end of file + rm -rf "${OUTPUT_PATH}" + - name: Running multi-turn rollout tests on 8 L20 GPUs + run: | + pip3 install --upgrade vllm==0.8.3 + python3 tests/rollout/test_vllm_multi_turn.py diff --git a/examples/ppo_trainer/naive_chat_scheduler.py b/examples/ppo_trainer/naive_chat_scheduler.py index 5e00bca30cc..15f4323ebbd 100644 --- a/examples/ppo_trainer/naive_chat_scheduler.py +++ b/examples/ppo_trainer/naive_chat_scheduler.py @@ -18,10 +18,9 @@ from omegaconf import DictConfig from openai.types.chat.chat_completion import ChatCompletion from tensordict import TensorDict -from transformers import PreTrainedTokenizer from verl.protocol import DataProto -from verl.workers.rollout.chat_scheduler import ChatCompletionScheduler +from verl.workers.rollout.async_server import ChatCompletionScheduler class NaiveChatCompletionScheduler(ChatCompletionScheduler): @@ -34,11 +33,10 @@ def __init__( self, config: DictConfig, model_path: str, - tokenizer: PreTrainedTokenizer, server_addresses: List[str], max_cache_size: int = 10000, ): - super().__init__(config, model_path, tokenizer, server_addresses, max_cache_size) + super().__init__(config, model_path, server_addresses, max_cache_size) async def generate_sequences(self, batch: DataProto, **sampling_params) -> DataProto: kwargs = dict( diff --git a/tests/rollout/test_vllm_multi_turn.py b/tests/rollout/test_vllm_multi_turn.py index 43990ed4dc4..9869128bb73 100644 --- a/tests/rollout/test_vllm_multi_turn.py +++ b/tests/rollout/test_vllm_multi_turn.py @@ -22,8 +22,8 @@ from verl.single_controller.ray import RayClassWithInitArgs, RayWorkerGroup from verl.single_controller.ray.base import create_colocated_worker_cls from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role -from verl.workers.fsdp_async_workers import AsyncActorRolloutRefWorker, AsyncLLMManager -from verl.workers.rollout.chat_scheduler import ChatCompletionScheduler +from verl.workers.fsdp_workers import AsyncActorRolloutRefWorker +from verl.workers.rollout.async_server import AsyncLLMServerManager async def test_vllm_multi_turn(): @@ -32,6 +32,7 @@ async def test_vllm_multi_turn(): model_name = "/".join(model_path.split("/")[-2:]) config.actor_rollout_ref.model.path = model_path config.actor_rollout_ref.rollout.mode = "async" + config.actor_rollout_ref.rollout.chat_scheduler = "verl.workers.rollout.async_server.ChatCompletionScheduler" config.actor_rollout_ref.rollout.prompt_length = 4096 config.actor_rollout_ref.rollout.response_length = 4096 @@ -47,6 +48,9 @@ async def test_vllm_multi_turn(): "NCCL_DEBUG": "WARN", "VLLM_LOGGING_LEVEL": "WARN", "VLLM_USE_V1": "1", + "no_proxy": "", + "http_proxy": "", + "https_proxy": "", } } ) @@ -82,18 +86,12 @@ async def test_vllm_multi_turn(): actor_rollout_wg = all_wg["actor_rollout"] actor_rollout_wg.init_model() - # =========================== 2. Create AsyncLLMManager&ChatScheduler =========================== - async_rollout_manager = AsyncLLMManager( + # =========================== 2. 
Create AsyncLLMServerManager&ChatScheduler =========================== + async_rollout_manager = AsyncLLMServerManager( config=config.actor_rollout_ref, worker_group=actor_rollout_wg, ) - - async_chat_scheduler = ChatCompletionScheduler( - config=config.actor_rollout_ref.rollout, - model_path=config.actor_rollout_ref.model.path, - tokenizer=None, - server_addresses=async_rollout_manager.server_addresses, - ) + async_chat_scheduler = async_rollout_manager.chat_scheduler # test sleep and wake_up await async_rollout_manager.sleep() diff --git a/verl/trainer/main_ppo.py b/verl/trainer/main_ppo.py index f6e57530cfd..5ab0f688475 100644 --- a/verl/trainer/main_ppo.py +++ b/verl/trainer/main_ppo.py @@ -83,7 +83,6 @@ def run_ppo(config) -> None: @ray.remote(num_cpus=1) # please make sure main_task is not scheduled on head class TaskRunner: - async def run(self, config): # print initial config from pprint import pprint @@ -109,18 +108,21 @@ async def run(self, config): if config.actor_rollout_ref.actor.strategy == "fsdp": assert config.actor_rollout_ref.actor.strategy == config.critic.strategy from verl.single_controller.ray import RayWorkerGroup - from verl.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker + from verl.workers.fsdp_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker, CriticWorker + actor_rollout_cls = ( + AsyncActorRolloutRefWorker + if config.actor_rollout_ref.rollout.mode == "async" + else ActorRolloutRefWorker + ) ray_worker_group_cls = RayWorkerGroup - if config.actor_rollout_ref.rollout.mode == "async": - from verl.workers.fsdp_async_workers import AsyncActorRolloutRefWorker as ActorRolloutRefWorker - elif config.actor_rollout_ref.actor.strategy == "megatron": assert config.actor_rollout_ref.actor.strategy == config.critic.strategy from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup from verl.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker + actor_rollout_cls = ActorRolloutRefWorker ray_worker_group_cls = NVMegatronRayWorkerGroup else: @@ -129,7 +131,7 @@ async def run(self, config): from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role role_worker_mapping = { - Role.ActorRollout: ray.remote(ActorRolloutRefWorker), + Role.ActorRollout: ray.remote(actor_rollout_cls), Role.Critic: ray.remote(CriticWorker), } diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py index bb61845ef15..35fefdf17a7 100644 --- a/verl/trainer/ppo/ray_trainer.py +++ b/verl/trainer/ppo/ray_trainer.py @@ -17,7 +17,6 @@ """ import json -import importlib import os import uuid from collections import defaultdict @@ -57,6 +56,7 @@ from verl.utils.seqlen_balancing import get_seqlen_balanced_partitions, log_seqlen_unbalance from verl.utils.torch_functional import masked_mean from verl.utils.tracking import ValidationGenerationsLogger +from verl.workers.rollout.async_server import AsyncLLMServerManager WorkerType = Type[Worker] @@ -635,7 +635,9 @@ async def _validate(self): test_output_gen_batch_padded = self.actor_rollout_wg.generate_sequences(test_gen_batch_padded) else: await self.async_rollout_manager.wake_up() - test_output_gen_batch_padded = await self.async_chat_scheduler.generate_sequences(test_gen_batch_padded) + test_output_gen_batch_padded = await self.async_rollout_manager.generate_sequences( + test_gen_batch_padded + ) await self.async_rollout_manager.sleep() # unpad @@ -779,22 +781,11 @@ def init_workers(self): # create async rollout manager and request scheduler self.async_rollout_mode = False if 
self.config.actor_rollout_ref.rollout.mode == "async": - from verl.workers.fsdp_async_workers import AsyncLLMManager - self.async_rollout_mode = True - self.async_rollout_manager = AsyncLLMManager( + self.async_rollout_manager = AsyncLLMServerManager( config=self.config.actor_rollout_ref, worker_group=self.actor_rollout_wg, ) - module_path, class_name = self.config.actor_rollout_ref.rollout.chat_scheduler.rsplit(".", 1) - module = importlib.import_module(module_path) - scheduler_cls = getattr(module, class_name) - self.async_chat_scheduler = scheduler_cls( - config=self.config.actor_rollout_ref.rollout, - model_path=self.config.actor_rollout_ref.model.path, - tokenizer=self.tokenizer, - server_addresses=self.async_rollout_manager.server_addresses, - ) def _save_checkpoint(self): # path: given_path + `/global_step_{global_steps}` + `/actor` @@ -992,7 +983,7 @@ async def fit(self): gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch) else: await self.async_rollout_manager.wake_up() - gen_batch_output = await self.async_chat_scheduler.generate_sequences(gen_batch) + gen_batch_output = await self.async_rollout_manager.generate_sequences(gen_batch) await self.async_rollout_manager.sleep() if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX: @@ -1148,10 +1139,12 @@ async def fit(self): self._save_checkpoint() # training metrics - metrics.update({ - 'training/global_step': self.global_steps, - 'training/epoch': epoch, - }) + metrics.update( + { + "training/global_step": self.global_steps, + "training/epoch": epoch, + } + ) # collect metrics metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic)) metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw)) diff --git a/verl/utils/vllm_utils.py b/verl/utils/vllm_utils.py index ee518d5e33e..e395db1cf57 100644 --- a/verl/utils/vllm_utils.py +++ b/verl/utils/vllm_utils.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+ def patch_vllm_moe_model_weight_loader(model): # this is a work around to load the weight of vllm fused moe model # it is from a bug from vllm 0.8.2 @@ -29,10 +30,9 @@ def patch_vllm_moe_model_weight_loader(model): # (False, 'model.layers.0.post_attention_layernorm.weight') use default # (False, 'model.layers.0.mlp.experts.w13_weight') use mlp.experts.weight_loader # (False, 'model.layers.0.mlp.experts.w2_weight') use mlp.experts.weight_loader - from vllm.model_executor.models.deepseek_v2 import (DeepseekV2ForCausalLM, - DeepseekV3ForCausalLM) + from vllm.model_executor.models.deepseek_v2 import DeepseekV2ForCausalLM, DeepseekV3ForCausalLM from vllm.model_executor.models.qwen2_moe import Qwen2MoeForCausalLM - + if not isinstance(model, (Qwen2MoeForCausalLM, DeepseekV2ForCausalLM, DeepseekV3ForCausalLM)): return for layer in model.model.layers: @@ -40,4 +40,4 @@ def patch_vllm_moe_model_weight_loader(model): param_dict = dict(mlp.named_parameters()) for name, param in param_dict.items(): if "w13_weight" in name or "w2_weight" in name: - param.weight_loader = mlp.experts.weight_loader \ No newline at end of file + param.weight_loader = mlp.experts.weight_loader diff --git a/verl/workers/fsdp_workers.py b/verl/workers/fsdp_workers.py index 1890e3842d9..b565f4b4e62 100644 --- a/verl/workers/fsdp_workers.py +++ b/verl/workers/fsdp_workers.py @@ -18,6 +18,7 @@ import logging import os import warnings +from typing import Union import psutil import torch @@ -125,7 +126,8 @@ def __init__(self, config: DictConfig, role: str): self.config.actor.ppo_mini_batch_size *= self.config.rollout.n self.config.actor.ppo_mini_batch_size //= self.device_mesh.size() // self.ulysses_sequence_parallel_size assert self.config.actor.ppo_mini_batch_size > 0, ( - f"ppo_mini_batch_size {self.config.actor.ppo_mini_batch_size} should be larger than 0 after normalization" + f"ppo_mini_batch_size {self.config.actor.ppo_mini_batch_size} " + "should be larger than 0 after normalization" ) # micro bsz if self.config.actor.ppo_micro_batch_size is not None: @@ -136,10 +138,12 @@ def __init__(self, config: DictConfig, role: str): if self.config.actor.ppo_micro_batch_size_per_gpu is not None: assert self.config.actor.ppo_mini_batch_size % self.config.actor.ppo_micro_batch_size_per_gpu == 0, ( - f"normalized ppo_mini_batch_size {self.config.actor.ppo_mini_batch_size} should be divisible by ppo_micro_batch_size_per_gpu {self.config.actor.ppo_micro_batch_size_per_gpu}" + f"normalized ppo_mini_batch_size {self.config.actor.ppo_mini_batch_size} should be " + f"divisible by ppo_micro_batch_size_per_gpu {self.config.actor.ppo_micro_batch_size_per_gpu}" ) assert self.config.actor.ppo_mini_batch_size // self.config.actor.ppo_micro_batch_size_per_gpu > 0, ( - f"normalized ppo_mini_batch_size {self.config.actor.ppo_mini_batch_size} should be larger than ppo_micro_batch_size_per_gpu {self.config.actor.ppo_micro_batch_size_per_gpu}" + f"normalized ppo_mini_batch_size {self.config.actor.ppo_mini_batch_size} should be " + f"larger than ppo_micro_batch_size_per_gpu {self.config.actor.ppo_micro_batch_size_per_gpu}" ) # normalize rollout config @@ -389,10 +393,12 @@ def _build_rollout(self, trust_remote_code=False): elif rollout_name == "sglang": from verl.workers.rollout.sglang_rollout import SGLangRollout - # NOTE(linjunrong): Due to recent fp8 support in SGLang. Now importing any symbol relate to SGLang's model_runner would check CUDA device capability. 
- # However, due to veRL's setting, the main process of ray can not find any CUDA device, which would potentially lead to: + # NOTE(linjunrong): Due to recent fp8 support in SGLang. Now importing any symbol relate to + # SGLang's model_runner would check CUDA device capability. However, due to veRL's setting, + # the main process of ray can not find any CUDA device, which would potentially lead to: # "RuntimeError: No CUDA GPUs are available". - # For this reason, sharding_manager.__init__ should not import FSDPSGLangShardingManager and we import it here use the abs path. + # For this reason, sharding_manager.__init__ should not import FSDPSGLangShardingManager and + # we import it here use the abs path. # check: https://github.com/sgl-project/sglang/blob/00f42707eaddfc2c0528e5b1e0094025c640b7a0/python/sglang/srt/layers/quantization/fp8_utils.py#L76 from verl.workers.sharding_manager.fsdp_sglang import FSDPSGLangShardingManager @@ -713,10 +719,12 @@ def __init__(self, config): if self.config.ppo_micro_batch_size_per_gpu is not None: assert self.config.ppo_mini_batch_size % self.config.ppo_micro_batch_size_per_gpu == 0, ( - f"normalized ppo_mini_batch_size {self.config.ppo_mini_batch_size} should be divisible by ppo_micro_batch_size_per_gpu {self.config.ppo_micro_batch_size_per_gpu}" + f"normalized ppo_mini_batch_size {self.config.ppo_mini_batch_size} should be " + f"divisible by ppo_micro_batch_size_per_gpu {self.config.ppo_micro_batch_size_per_gpu}" ) assert self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu > 0, ( - f"normalized ppo_mini_batch_size {self.config.ppo_mini_batch_size} should be larger than ppo_micro_batch_size_per_gpu {self.config.ppo_micro_batch_size_per_gpu}" + f"normalized ppo_mini_batch_size {self.config.ppo_mini_batch_size} should be " + f"larger than ppo_micro_batch_size_per_gpu {self.config.ppo_micro_batch_size_per_gpu}" ) def _build_critic_model_optimizer(self, config): @@ -1280,3 +1288,35 @@ def compute_rm_score(self, data: DataProto): output = output.to("cpu") return output + + +# ================================= Async related workers ================================= +class AsyncActorRolloutRefWorker(ActorRolloutRefWorker): + def _build_rollout(self, trust_remote_code=False): + rollout, rollout_sharding_manager = super()._build_rollout(trust_remote_code) + + # NOTE: rollout is not actually initialized here, it's deferred + # to be initialized by AsyncvLLMServer. 
+ + self.vllm_tp_size = self.config.rollout.tensor_model_parallel_size + self.vllm_dp_rank = int(os.environ["RANK"]) // self.vllm_tp_size + self.vllm_tp_rank = int(os.environ["RANK"]) % self.vllm_tp_size + + # used for sleep/wake_up + rollout.sharding_manager = rollout_sharding_manager + + return rollout, rollout_sharding_manager + + @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) + def generate_sequences(self, prompts: DataProto): + raise NotImplementedError("AsyncActorRolloutRefWorker does not support generate_sequences") + + @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD) + def execute_method(self, method: Union[str, bytes], *args, **kwargs): + """Called by ExternalRayDistributedExecutor collective_rpc.""" + if self.vllm_tp_rank == 0 and method != "execute_model": + print( + f"[DP={self.vllm_dp_rank},TP={self.vllm_tp_rank}] " + f"execute_method: {method if isinstance(method, str) else 'Callable'}" + ) + return self.rollout.execute_method(method, *args, **kwargs) diff --git a/verl/workers/rollout/async_server.py b/verl/workers/rollout/async_server.py new file mode 100644 index 00000000000..a15f71b5d14 --- /dev/null +++ b/verl/workers/rollout/async_server.py @@ -0,0 +1,314 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import asyncio +import heapq +import importlib +import logging +import os +import socket +from abc import ABC, abstractmethod +from contextlib import asynccontextmanager +from typing import Any, Callable, Dict, List, Tuple, Type +from uuid import uuid4 + +import aiohttp +import fastapi +import ray +import uvicorn +from cachetools import LRUCache +from omegaconf import DictConfig +from openai import AsyncOpenAI +from openai.types.chat.chat_completion import ChatCompletion +from starlette.requests import Request + +from verl.protocol import DataProto +from verl.single_controller.ray.base import RayWorkerGroup +from verl.utils import hf_tokenizer +from verl.utils.fs import copy_to_local + +logger = logging.getLogger(__file__) + + +def _get_free_port(): + with socket.socket() as sock: + sock.bind(("", 0)) + return sock.getsockname()[1] + + +class AsyncServerBase(ABC): + """Base class for AsyncServer.""" + + def __init__(self): + self.address = ray._private.services.get_node_ip_address() + self.port = None + self.server_ready = asyncio.Event() + asyncio.create_task(self._start_fastapi_server()) + + async def _start_fastapi_server(self): + @asynccontextmanager + async def lifespan(app: fastapi.FastAPI): + print("FastAPI startup") + self.server_ready.set() + yield + + # There's no way to gracefully restart uvicorn server if port is already in use, + # so we exit the process directly and let AsyncLLMServerManager restart it. 
+ print("FastAPI shutdown, maybe address already in use, exit process immediately.") + os._exit(-1) + + app = fastapi.FastAPI(lifespan=lifespan) + app.router.add_api_route("/v1/chat/completions", self.chat_completion, methods=["POST"]) + + self.port = _get_free_port() + config = uvicorn.Config(app, host=["::", "0.0.0.0"], port=self.port, log_level="warning") + server = uvicorn.Server(config) + await server.serve() + + async def get_server_address(self) -> Tuple[str, int]: + """Get FastAPI server address.""" + await self.server_ready.wait() + return f"{self.address}:{self.port}" + + @abstractmethod + async def chat_completion(self, raw_request: Request): + """OpenAI chat completion API. + + API reference: https://platform.openai.com/docs/api-reference/chat/create + """ + raise NotImplementedError + + @abstractmethod + async def init_engine(self): + """Init async LLM engine.""" + raise NotImplementedError + + @abstractmethod + async def wake_up(self): + """Wake up engine to load model weights and build kv cache.""" + raise NotImplementedError + + @abstractmethod + async def sleep(self): + """Sleep engine to offload model weights and discard kv cache.""" + raise NotImplementedError + + +class ChatCompletionScheduler: + def __init__( + self, + config: DictConfig, + model_path: str, + server_addresses: List[str], + max_cache_size: int = 10000, + ): + """ + Args: + config: DictConfig, rollout config. + model_path: str, model path. + server_addresses: List[str], server addresses. + max_cache_size: int, max cache size of request_id to address mapping. + """ + self.config = config + self.model_name = "/".join(model_path.split("/")[-2:]) + local_path = copy_to_local(model_path) + self.tokenizer = hf_tokenizer(local_path, trust_remote_code=True) + + # Least requests load balancing + self.weighted_addresses = [[0, address] for address in server_addresses] + heapq.heapify(self.weighted_addresses) + + # LRU cache to map request_id to address + self.request_id_to_address = LRUCache(maxsize=max_cache_size) + + async def submit_chat_completions( + self, + callback: Callable[[ChatCompletion, Dict[str, Any], Exception], None], + callback_additional_info: Dict[str, Any], + **chat_complete_request, + ): + """ + Submit a chat completion request to the server with the least number of requests. + + Args: + callback: Callable[[ChatCompletion, Dict[str, Any], Exception], None], async callback function + to handle the response. The callback function should have the following signature: + + ```python + async def callback(completions: ChatCompletion, info: Dict[str, Any], exception: Exception): + ... + ``` + - completions: chat completion response from server. + - info: user provided `callback_additional_info`. + - exception: exception raise from OpenAI client if request failed, otherwise None. + + **CAUTION**: the callback function must be async and non-blocking, if you have any blocking operation, + please move to seperate thread or process pool to avoid blocking the event loop. + + callback_additional_info: Dict[str, Any], additional info to pass to the callback function. + + **chat_complete_request: dict, request parameters same as OpenAI AsyncCompletions.create. 
+ OpenAI API reference: https://platform.openai.com/docs/api-reference/chat/create + """ + if "extra_headers" not in chat_complete_request: + chat_complete_request["extra_headers"] = {} + + extra_headers = chat_complete_request["extra_headers"] + request_id = extra_headers.get("x-request-id", None) + if request_id: + if request_id.startswith("chatcmpl-"): + request_id = request_id[len("chatcmpl-") :] + extra_headers["x-request-id"] = request_id + + address = self.request_id_to_address[request_id] + else: + address = self.weighted_addresses[0][1] + self.weighted_addresses[0][0] += 1 + heapq.heapreplace(self.weighted_addresses, self.weighted_addresses[0]) + + request_id = uuid4().hex + self.request_id_to_address[request_id] = address + chat_complete_request["extra_headers"]["x-request-id"] = request_id + + completions, exception = None, None + try: + # TODO: OpenAI client uses httpx, seems to have performance issue in high concurrency requests. + completions = await self._chat_completions_openai(address, **chat_complete_request) + except Exception as e: + # Let user handle the exception + exception = e + + await callback(completions, callback_additional_info, exception) + + async def _chat_completions_openai(self, address: str, **chat_complete_request) -> ChatCompletion: + client = AsyncOpenAI( + base_url=f"http://{address}/v1", + api_key="token-abc123", + ) + return await client.chat.completions.create(**chat_complete_request) + + async def _chat_completions_aiohttp(self, address: str, **chat_complete_request) -> ChatCompletion: + try: + session = aiohttp.ClientSession() + async with session.post( + url=f"http://{address}/v1/chat/completions", + headers={"Authorization": "Bearer token-abc123"}, + json=chat_complete_request, + ) as resp: + data = await resp.json() + return ChatCompletion(**data) + finally: + await session.close() + + async def generate_sequences(self, prompts: DataProto, **sampling_params) -> DataProto: + raise NotImplementedError + + +class AsyncLLMServerManager: + """AsyncLLMServerManager manage a group of vllm instances, i.e AsyncvLLMServer.""" + + def __init__(self, config: DictConfig, worker_group: RayWorkerGroup): + """Initialize AsyncLLMServerManager. + + Args: + config: DictConfig, actor_rollout_ref config. + worker_group: RayWorkerGroup, worker group of AsyncActorRolloutRefWorker. + """ + self.config = config + self.worker_group = worker_group + + self.rollout_tp_size = self.config.rollout.tensor_model_parallel_size + self.rollout_dp_size = self.worker_group.world_size // self.rollout_tp_size + + register_center = ray.get_actor(f"{self.worker_group.name_prefix}_register_center") + workers_info = ray.get(register_center.get_worker_info.remote()) + assert len(workers_info) == self.worker_group.world_size + + self.async_llm_workers = [None] * self.rollout_dp_size + self.server_addresses = [None] * self.rollout_dp_size + + server_class = async_server_class( + rollout_backend=self.config.rollout.name, + ) + + # Start all server instances, restart if address already in use. 
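The `submit_chat_completions` docstring above fixes the callback contract (completion, user-provided info, optional exception), and the `x-request-id` handling keeps a resubmitted request on the replica that served the original one. A hedged usage sketch from a custom scheduler's point of view; the `results`/`index` keys are made-up names, and reusing `completion.id` for sticky routing assumes the server echoes the submitted request id into the `chatcmpl-...` completion id, as the prefix stripping above suggests:

```python
# Sketch of a conforming async callback plus a single submission (illustrative names only).
from typing import Any, Dict

from openai.types.chat.chat_completion import ChatCompletion


async def on_completion(completion: ChatCompletion, info: Dict[str, Any], exception: Exception):
    # Must stay non-blocking; push heavy work to a thread/process pool instead.
    if exception is not None:
        info["results"][info["index"]] = exception
        return
    info["results"][info["index"]] = completion
    # Assumption: passing completion.id back via extra_headers["x-request-id"] on the next
    # turn reuses the cached request_id -> address mapping, i.e. the same vLLM replica.
    info["sticky_request_id"] = completion.id


async def run_one_prompt(scheduler, messages):
    info = {"results": [None], "index": 0}
    await scheduler.submit_chat_completions(
        callback=on_completion,
        callback_additional_info=info,
        model=scheduler.model_name,
        messages=messages,
        temperature=1.0,
    )
    return info["results"][0]
```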
+ unready_dp_ranks = set(range(self.rollout_dp_size)) + while len(unready_dp_ranks) > 0: + workers = { + rollout_dp_rank: server_class.options( + # make sure AsyncvLLMServer colocates with its corresponding workers + scheduling_strategy=ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy( + node_id=workers_info[rollout_dp_rank * self.rollout_tp_size], + soft=False, + ), + name=f"async_llm_worker_{rollout_dp_rank}", + ).remote(config, self.rollout_dp_size, rollout_dp_rank, self.worker_group.name_prefix) + for rollout_dp_rank in unready_dp_ranks + } + + for rollout_dp_rank, worker in workers.items(): + try: + address = ray.get(worker.get_server_address.remote()) + self.server_addresses[rollout_dp_rank] = address + self.async_llm_workers[rollout_dp_rank] = worker + unready_dp_ranks.remove(rollout_dp_rank) + except Exception: + ray.kill(worker) + print(f"worker {rollout_dp_rank} failed, maybe address already in use, restarting...") + + # All server instances are ready, init AsyncLLM engine. + ray.get([worker.init_engine.remote() for worker in self.async_llm_workers]) + + # Init user provided chat scheduler. + self.chat_scheduler = self._init_chat_scheduler() + + def _init_chat_scheduler(self) -> ChatCompletionScheduler: + module_path, class_name = self.config.rollout.chat_scheduler.rsplit(".", 1) + module = importlib.import_module(module_path) + scheduler_cls = getattr(module, class_name) + return scheduler_cls( + config=self.config.rollout, + model_path=self.config.model.path, + server_addresses=self.server_addresses, + ) + + async def wake_up(self): + """Wake up all vllm instances.""" + await asyncio.gather(*[worker.wake_up.remote() for worker in self.async_llm_workers]) + + async def sleep(self): + """Sleep all vllm instances.""" + await asyncio.gather(*[worker.sleep.remote() for worker in self.async_llm_workers]) + + async def generate_sequences(self, prompts: DataProto, **sampling_params) -> DataProto: + """Generate sequences via chat scheduler.""" + return await self.chat_scheduler.generate_sequences(prompts, **sampling_params) + + +def async_server_class(rollout_backend: str) -> Type[AsyncServerBase]: + """Get async server class. + + Args: + rollout_backend: str, rollout backend, should be "vllm" or "sglang". + + Returns: + Type[AsyncServerBase]: async server class. + """ + if rollout_backend == "vllm": + from verl.workers.rollout.vllm_rollout.vllm_async_server import AsyncvLLMServer + + return AsyncvLLMServer + elif rollout_backend == "sglang": + raise NotImplementedError + else: + raise NotImplementedError diff --git a/verl/workers/rollout/chat_scheduler.py b/verl/workers/rollout/chat_scheduler.py deleted file mode 100644 index f9e5b5d2280..00000000000 --- a/verl/workers/rollout/chat_scheduler.py +++ /dev/null @@ -1,136 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
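A small worked example of the DP/TP bookkeeping shared by `AsyncActorRolloutRefWorker` (ranks derived from the `RANK` environment variable) and `AsyncLLMServerManager` (one server per DP replica, co-located with the first worker of its TP group); the world size and TP size are assumed values for illustration:

```python
# Assumed setup: 8 hybrid rollout workers, rollout.tensor_model_parallel_size = 2.
world_size = 8
tp_size = 2
dp_size = world_size // tp_size  # 4 AsyncvLLMServer instances, one per DP replica

for rank in range(world_size):
    dp_rank = rank // tp_size  # which vLLM engine this worker belongs to
    tp_rank = rank % tp_size   # its position inside that engine's TP group
    print(f"RANK={rank} -> DP={dp_rank}, TP={tp_rank}")

# Each server is scheduled on the node of the first worker in its TP group,
# i.e. the worker at global rank dp_rank * tp_size.
first_worker_ranks = [dp_rank * tp_size for dp_rank in range(dp_size)]  # [0, 2, 4, 6]
```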
-import heapq -from typing import Any, Callable, Dict, List -from uuid import uuid4 - -import aiohttp -from cachetools import LRUCache -from omegaconf import DictConfig -from openai import AsyncOpenAI -from openai.types.chat.chat_completion import ChatCompletion -from transformers import PreTrainedTokenizer - -from verl.protocol import DataProto - - -class ChatCompletionScheduler: - def __init__( - self, - config: DictConfig, - model_path: str, - tokenizer: PreTrainedTokenizer, - server_addresses: List[str], - max_cache_size: int = 10000, - ): - """ - Args: - config: DictConfig, rollout config. - model_path: str, model path. - tokenizer: PreTrainedTokenizer, tokenizer. - server_addresses: List[str], server addresses. - max_cache_size: int, max cache size of request_id to address mapping. - """ - self.config = config - self.model_name = "/".join(model_path.split("/")[-2:]) - self.tokenizer = tokenizer - - # Least requests load balancing - self.weighted_addresses = [[0, address] for address in server_addresses] - heapq.heapify(self.weighted_addresses) - - # LRU cache to map request_id to address - self.request_id_to_address = LRUCache(maxsize=max_cache_size) - - async def submit_chat_completions( - self, - callback: Callable[[ChatCompletion, Dict[str, Any], Exception], None], - callback_additional_info: Dict[str, Any], - **chat_complete_request, - ): - """ - Submit a chat completion request to the server with the least number of requests. - - Args: - callback: Callable[[ChatCompletion, Dict[str, Any], Exception], None], async callback function - to handle the response. The callback function should have the following signature: - - ```python - async def callback(completions: ChatCompletion, info: Dict[str, Any], exception: Exception): - ... - ``` - - completions: chat completion response from server. - - info: user provided `callback_additional_info`. - - exception: exception raise from OpenAI client if request failed, otherwise None. - - **CAUTION**: the callback function must be async and non-blocking, if you have any blocking operation, - please move to seperate thread or process pool to avoid blocking the event loop. - - callback_additional_info: Dict[str, Any], additional info to pass to the callback function. - - **chat_complete_request: dict, request parameters same as OpenAI AsyncCompletions.create. - OpenAI API reference: https://platform.openai.com/docs/api-reference/chat/create - """ - if "extra_headers" not in chat_complete_request: - chat_complete_request["extra_headers"] = {} - - extra_headers = chat_complete_request["extra_headers"] - request_id = extra_headers.get("x-request-id", None) - if request_id: - if request_id.startswith("chatcmpl-"): - request_id = request_id[len("chatcmpl-") :] - extra_headers["x-request-id"] = request_id - - address = self.request_id_to_address[request_id] - else: - address = self.weighted_addresses[0][1] - self.weighted_addresses[0][0] += 1 - heapq.heapreplace(self.weighted_addresses, self.weighted_addresses[0]) - - request_id = uuid4().hex - self.request_id_to_address[request_id] = address - chat_complete_request["extra_headers"]["x-request-id"] = request_id - - completions, exception = None, None - try: - # TODO: OpenAI client uses httpx, seems to have performance issue in high concurrency requests. 
- completions = await self._chat_completions_openai(address, **chat_complete_request) - except Exception as e: - # Let user handle the exception - exception = e - - await callback(completions, callback_additional_info, exception) - - async def _chat_completions_openai(self, address: str, **chat_complete_request) -> ChatCompletion: - client = AsyncOpenAI( - base_url=f"http://{address}/v1", - api_key="token-abc123", - ) - return await client.chat.completions.create(**chat_complete_request) - - async def _chat_completions_aiohttp(self, address: str, **chat_complete_request) -> ChatCompletion: - try: - session = aiohttp.ClientSession() - async with session.post( - url=f"http://{address}/v1/chat/completions", - headers={"Authorization": "Bearer token-abc123"}, - json=chat_complete_request, - ) as resp: - data = await resp.json() - return ChatCompletion(**data) - finally: - await session.close() - - async def generate_sequences(self, prompts: DataProto, **sampling_params) -> DataProto: - raise NotImplementedError diff --git a/verl/workers/rollout/vllm_rollout/__init__.py b/verl/workers/rollout/vllm_rollout/__init__.py index dbf5cbd9fc0..66da822ea86 100644 --- a/verl/workers/rollout/vllm_rollout/__init__.py +++ b/verl/workers/rollout/vllm_rollout/__init__.py @@ -11,15 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import os from importlib.metadata import PackageNotFoundError, version -### -# [SUPPORT AMD:] -import torch - -### - def get_version(pkg): try: @@ -34,7 +28,8 @@ def get_version(pkg): ### # package_version = get_version(package_name) # [SUPPORT AMD:] -if "AMD" in torch.cuda.get_device_name(): +# Do not call any torch.cuda* API here, or ray actor creation import class will fail. +if "ROCM_PATH" in os.environ: import re package_version = version(package_name) @@ -45,8 +40,8 @@ def get_version(pkg): if package_version <= "0.6.3": vllm_mode = "customized" - from .fire_vllm_rollout import FIREvLLMRollout - from .vllm_rollout import vLLMRollout + from .fire_vllm_rollout import FIREvLLMRollout # noqa: F401 + from .vllm_rollout import vLLMRollout # noqa: F401 else: vllm_mode = "spmd" - from .vllm_rollout_spmd import vLLMAsyncRollout, vLLMRollout + from .vllm_rollout_spmd import vLLMAsyncRollout, vLLMRollout # noqa: F401 diff --git a/verl/workers/fsdp_async_workers.py b/verl/workers/rollout/vllm_rollout/vllm_async_server.py similarity index 61% rename from verl/workers/fsdp_async_workers.py rename to verl/workers/rollout/vllm_rollout/vllm_async_server.py index 9d49b0916a9..4f551c07116 100644 --- a/verl/workers/fsdp_async_workers.py +++ b/verl/workers/rollout/vllm_rollout/vllm_async_server.py @@ -11,17 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import asyncio import logging -import os -import socket -from contextlib import asynccontextmanager from typing import Any, Callable, Dict, List, Optional, Tuple, Union import cloudpickle -import fastapi import ray -import uvicorn from omegaconf import DictConfig from starlette.requests import Request from starlette.responses import JSONResponse, StreamingResponse @@ -34,21 +28,12 @@ from vllm.v1.executor.abstract import Executor from vllm.worker.worker_base import WorkerWrapperBase -from verl import DataProto -from verl.single_controller.base.decorator import Dispatch, register -from verl.single_controller.ray.base import RayWorkerGroup from verl.utils.fs import copy_to_local -from verl.workers.fsdp_workers import ActorRolloutRefWorker +from verl.workers.rollout.async_server import AsyncServerBase logger = logging.getLogger(__file__) -def _get_free_port(): - with socket.socket() as sock: - sock.bind(("", 0)) - return sock.getsockname()[1] - - class ExternalRayDistributedExecutor(Executor): """An executor that engines are launched by external ray actors.""" @@ -126,12 +111,12 @@ def check_health(self): @ray.remote(num_cpus=1) -class AsyncLLMWorker: +class AsyncvLLMServer(AsyncServerBase): """ - AsyncLLMWorker is a wrapper for AsyncLLM, it uses ExternalRayDistributedExecutor to launch engines + AsyncvLLMServer is a wrapper for AsyncLLM, it uses ExternalRayDistributedExecutor to launch engines in hybrid rollout workers, i.e AsyncActorRolloutRefWorker. - It works as follows: + AsyncvLLMServer works as follows: 1. Start FastAPI server first. 2. Initialize AsyncLLM with ExternalRayDistributedExecutor. 3. AsyncLLM spawn EngineCore in subprocess. @@ -150,19 +135,15 @@ def __init__(self, config: DictConfig, vllm_dp_size: int, vllm_dp_rank: int, wg_ vllm_dp_rank: int, vllm data parallel rank. wg_prefix: str, worker group prefix, used to lookup actors. """ + super().__init__() + self.config = config self.vllm_dp_size = vllm_dp_size self.vllm_dp_rank = vllm_dp_rank self.wg_prefix = wg_prefix self.engine: AsyncLLM = None - # start FastAPI server - self.address = ray._private.services.get_node_ip_address() - self.port = None - self.server_ready = asyncio.Event() - asyncio.create_task(self._start_fastapi_server()) - - async def init_async_llm(self): + async def init_engine(self): """Init vLLM AsyncLLM engine.""" config = self.config model_path = config.model.path @@ -253,31 +234,6 @@ async def chat_completion(self, raw_request: Request): assert isinstance(generator, ChatCompletionResponse) return JSONResponse(content=generator.model_dump()) - async def _start_fastapi_server(self): - @asynccontextmanager - async def lifespan(app: fastapi.FastAPI): - print("FastAPI startup") - self.server_ready.set() - yield - - # There's no way to gracefully restart uvicorn server if port is already in use, - # so we exit the process directly and let AsyncLLMManager restart it. 
- print("FastAPI shutdown, maybe address already in use, exit process immediately.") - os._exit(-1) - - app = fastapi.FastAPI(lifespan=lifespan) - app.router.add_api_route("/v1/chat/completions", self.chat_completion, methods=["POST"]) - - self.port = _get_free_port() - config = uvicorn.Config(app, host=["::", "0.0.0.0"], port=self.port, log_level="warning") - server = uvicorn.Server(config) - await server.serve() - - async def get_server_address(self) -> Tuple[str, int]: - """Get FastAPI server address.""" - await self.server_ready.wait() - return f"{self.address}:{self.port}" - async def wake_up(self): await self.engine.wake_up() @@ -285,99 +241,3 @@ async def sleep(self): # TODO: https://github.com/vllm-project/vllm/issues/17103 await self.engine.reset_prefix_cache() await self.engine.sleep() - - -class AsyncActorRolloutRefWorker(ActorRolloutRefWorker): - def _build_rollout(self, trust_remote_code=False): - rollout, rollout_sharding_manager = super()._build_rollout(trust_remote_code) - - # NOTE: rollout is not actually initialized here, it's deferred - # to be initialized by AsyncLLMWorker. - - self.vllm_tp_size = self.config.rollout.tensor_model_parallel_size - self.vllm_dp_rank = int(os.environ["RANK"]) // self.vllm_tp_size - self.vllm_tp_rank = int(os.environ["RANK"]) % self.vllm_tp_size - - # used for sleep/wake_up - rollout.sharding_manager = rollout_sharding_manager - - return rollout, rollout_sharding_manager - - @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) - def generate_sequences(self, prompts: DataProto): - raise NotImplementedError("AsyncActorRolloutRefWorker does not support generate_sequences") - - @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD) - def execute_method(self, method: Union[str, bytes], *args, **kwargs): - """Called by ExternalRayDistributedExecutor collective_rpc.""" - if self.vllm_tp_rank == 0 and method != "execute_model": - print( - f"[DP={self.vllm_dp_rank},TP={self.vllm_tp_rank}] " - f"execute_method: {method if isinstance(method, str) else 'Callable'}" - ) - return self.rollout.execute_method(method, *args, **kwargs) - - -class AsyncLLMManager: - """AsyncLLMManager manage a group of vllm instances, i.e AsyncLLMWorker.""" - - def __init__(self, config: DictConfig, worker_group: RayWorkerGroup): - """Initialize AsyncLLMManager. - - Args: - config: DictConfig, actor_rollout_ref config. - worker_group: RayWorkerGroup, worker group of AsyncActorRolloutRefWorker. - """ - self.config = config - self.worker_group = worker_group - - self.rollout_tp_size = self.config.rollout.tensor_model_parallel_size - self.rollout_dp_size = self.worker_group.world_size // self.rollout_tp_size - - register_center = ray.get_actor(f"{self.worker_group.name_prefix}_register_center") - workers_info = ray.get(register_center.get_worker_info.remote()) - assert len(workers_info) == self.worker_group.world_size - - self.async_llm_workers = [None] * self.rollout_dp_size - self.server_addresses = [None] * self.rollout_dp_size - - # Start all server instances, restart if address already in use. 
- unready_dp_ranks = set(range(self.rollout_dp_size)) - while len(unready_dp_ranks) > 0: - workers = { - rollout_dp_rank: AsyncLLMWorker.options( - # make sure AsyncLLMWorker colocates with its corresponding workers - scheduling_strategy=ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy( - node_id=workers_info[rollout_dp_rank * self.rollout_tp_size], - soft=False, - ), - name=f"async_llm_worker_{rollout_dp_rank}", - ).remote(config, self.rollout_dp_size, rollout_dp_rank, self.worker_group.name_prefix) - for rollout_dp_rank in unready_dp_ranks - } - - for rollout_dp_rank, worker in workers.items(): - try: - address = ray.get(worker.get_server_address.remote()) - self.server_addresses[rollout_dp_rank] = address - self.async_llm_workers[rollout_dp_rank] = worker - unready_dp_ranks.remove(rollout_dp_rank) - except Exception: - ray.kill(worker) - print(f"worker {rollout_dp_rank} failed, maybe address already in use, restarting...") - - # All server instances are ready, init AsyncLLM engine. - ray.get([worker.init_async_llm.remote() for worker in self.async_llm_workers]) - - @property - def server_address(self): - """Ruturn FastAPI server addresses of all vllm instances.""" - return self.server_addresses - - async def wake_up(self): - """Wake up all vllm instances.""" - await asyncio.gather(*[worker.wake_up.remote() for worker in self.async_llm_workers]) - - async def sleep(self): - """Sleep all vllm instances.""" - await asyncio.gather(*[worker.sleep.remote() for worker in self.async_llm_workers]) diff --git a/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py b/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py index 865df0abb9b..9b9567cca37 100644 --- a/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py +++ b/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py @@ -19,7 +19,8 @@ When working with Megatron: - Use Megatron weight loader - During training, only the current pp stage holds the parameters -- Before inference, broadcast the parameters of the current pp rank to all other pp ranks (all pp ranks holds all the parameters) +- Before inference, broadcast the parameters of the current pp rank + to all other pp ranks (all pp ranks holds all the parameters) - Bind the parameters to the inference engine - Do inference in tp. pp is treated as additional dp - After inference, all the parameters that doesn't belong to this pp rank is freed. @@ -35,7 +36,6 @@ import torch.distributed from omegaconf import DictConfig from tensordict import TensorDict -from torch import nn from vllm import LLM, SamplingParams from vllm.distributed import parallel_state as vllm_ps from vllm.worker.worker_base import WorkerWrapperBase @@ -59,7 +59,8 @@ # NOTE(sgm): add for verl. We can optimize it by making the dataloader yield List[int] without padding. 
def _pre_process_inputs(pad_token_id, prompt_token_ids: torch.Tensor) -> List[int]: # remove the left padding in the prompt token_id - # pad_token_id = self.llm_engine.tokenizer.pad_token_id if self.llm_engine.tokenizer.pad_token_id is not None else self.llm_engine.tokenizer.eos_token_id + # pad_token_id = self.llm_engine.tokenizer.pad_token_id if self.llm_engine.tokenizer.pad_token_id + # is not None else self.llm_engine.tokenizer.eos_token_id non_pad_index = torch.nonzero(prompt_token_ids != pad_token_id, as_tuple=False)[0][0] token_ids = prompt_token_ids[non_pad_index:].tolist() return token_ids diff --git a/verl/workers/sharding_manager/fsdp_vllm.py b/verl/workers/sharding_manager/fsdp_vllm.py index 315d6525bed..1a9b587f6d0 100644 --- a/verl/workers/sharding_manager/fsdp_vllm.py +++ b/verl/workers/sharding_manager/fsdp_vllm.py @@ -18,19 +18,16 @@ import torch from torch.distributed.device_mesh import DeviceMesh -from torch.distributed.fsdp.api import (FullStateDictConfig, - ShardedStateDictConfig, StateDictType) -from torch.distributed.fsdp.fully_sharded_data_parallel import \ - FullyShardedDataParallel as FSDP +from torch.distributed.fsdp.api import FullStateDictConfig, ShardedStateDictConfig, StateDictType +from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP from verl import DataProto from verl.protocol import all_gather_data_proto -from verl.third_party.vllm import LLM +from verl.third_party.vllm import LLM, vllm_version from verl.third_party.vllm import parallel_state as vllm_ps -from verl.third_party.vllm import vllm_version from verl.utils.debug import GPUMemoryLogger, log_gpu_memory_usage -from verl.utils.vllm_utils import patch_vllm_moe_model_weight_loader from verl.utils.fsdp_utils import load_fsdp_model_to_gpu, offload_fsdp_model_to_cpu +from verl.utils.vllm_utils import patch_vllm_moe_model_weight_loader from .base import BaseShardingManager @@ -146,10 +143,6 @@ def __exit__(self, exc_type, exc_value, traceback): else: self.inference_engine.sleep(level=1) - # self.module.to('cuda') - # if torch.distributed.get_rank() == 0: - # print(f'after actor module to cuda in sharding manager memory allocated: {torch.cuda.memory_allocated() / 1e9}GB, reserved: {torch.cuda.memory_reserved() / 1e9}GB') - self.module.train() # add empty cache after each compute @@ -191,6 +184,9 @@ def update_params(self, updated_params): patch_vllm_moe_model_weight_loader(model) world_size = torch.distributed.get_world_size() loaded_params = model.load_weights( - ((name, param.full_tensor() if world_size != 1 and hasattr(param, "full_tensor") else param) for name, param in updated_params.items()) + ( + (name, param.full_tensor() if world_size != 1 and hasattr(param, "full_tensor") else param) + for name, param in updated_params.items() + ) ) - logger.info(f"vLLM load weights, loaded_params: {len(loaded_params)}") + logger.info("vLLM load weights, loaded_params: %d", len(loaded_params)) diff --git a/verl/workers/sharding_manager/megatron_vllm.py b/verl/workers/sharding_manager/megatron_vllm.py index e919ba634fb..74f827b0c2e 100644 --- a/verl/workers/sharding_manager/megatron_vllm.py +++ b/verl/workers/sharding_manager/megatron_vllm.py @@ -36,11 +36,17 @@ from verl.third_party.vllm import parallel_state as vllm_ps from verl.utils.debug import GPUMemoryLogger from verl.utils.megatron_utils import ( - broadcast_from_megatron_pp, broadcast_str_from_megatron_pp, - convert_megatron_model_to_transformers_model, get_model, unwrap_model) -from verl.utils.memory_buffer 
import (build_memory_buffer, - build_memory_reference_from_module, - get_weight_buffer_meta_from_module) + broadcast_from_megatron_pp, + broadcast_str_from_megatron_pp, + convert_megatron_model_to_transformers_model, + get_model, + unwrap_model, +) +from verl.utils.memory_buffer import ( + build_memory_buffer, + build_memory_reference_from_module, + get_weight_buffer_meta_from_module, +) from verl.utils.model import normalize_model_name from verl.utils.torch_functional import allgather_dict_tensors from verl.utils.vllm_utils import patch_vllm_moe_model_weight_loader From 9d49758bee06e0f184ea5a3292f459f3d4b75dc4 Mon Sep 17 00:00:00 2001 From: wuxibin Date: Fri, 25 Apr 2025 14:39:55 +0800 Subject: [PATCH 09/10] fix unit test --- .github/workflows/vllm.yml | 2 +- verl/trainer/config/generation.yaml | 2 +- verl/trainer/config/ppo_megatron_trainer.yaml | 2 +- verl/trainer/config/ppo_trainer.yaml | 2 +- verl/workers/rollout/async_server.py | 10 +++++----- 5 files changed, 9 insertions(+), 9 deletions(-) diff --git a/.github/workflows/vllm.yml b/.github/workflows/vllm.yml index e58bfa6a088..ee355b2fe1d 100644 --- a/.github/workflows/vllm.yml +++ b/.github/workflows/vllm.yml @@ -87,5 +87,5 @@ jobs: rm -rf "${OUTPUT_PATH}" - name: Running multi-turn rollout tests on 8 L20 GPUs run: | - pip3 install --upgrade vllm==0.8.3 + pip3 install --upgrade vllm==0.8.3 tensordict==0.7.2 python3 tests/rollout/test_vllm_multi_turn.py diff --git a/verl/trainer/config/generation.yaml b/verl/trainer/config/generation.yaml index 3d74db9db69..d3068e886c9 100644 --- a/verl/trainer/config/generation.yaml +++ b/verl/trainer/config/generation.yaml @@ -14,7 +14,7 @@ model: external_lib: null rollout: name: vllm - mode: "sync" # sync: LLM, async: AsyncLLM + mode: sync # sync: LLM, async: AsyncLLM temperature: 1.0 top_k: 50 # 0 for hf rollout, -1 for vllm rollout top_p: 0.7 diff --git a/verl/trainer/config/ppo_megatron_trainer.yaml b/verl/trainer/config/ppo_megatron_trainer.yaml index b87695213b8..838663886bc 100644 --- a/verl/trainer/config/ppo_megatron_trainer.yaml +++ b/verl/trainer/config/ppo_megatron_trainer.yaml @@ -94,7 +94,7 @@ actor_rollout_ref: log_prob_micro_batch_size_per_gpu: null rollout: name: vllm - mode: "sync" # sync: LLM, async: AsyncLLM + mode: sync # sync: LLM, async: AsyncLLM temperature: 1.0 top_k: -1 # 0 for hf rollout, -1 for vllm rollout top_p: 1 diff --git a/verl/trainer/config/ppo_trainer.yaml b/verl/trainer/config/ppo_trainer.yaml index ee2d181fee5..ca14b378bf9 100644 --- a/verl/trainer/config/ppo_trainer.yaml +++ b/verl/trainer/config/ppo_trainer.yaml @@ -83,7 +83,7 @@ actor_rollout_ref: ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size rollout: name: vllm - mode: "sync" # sync: LLM, async: AsyncLLM + mode: sync # sync: LLM, async: AsyncLLM chat_scheduler: null # async chat scheduler, e.g examples.ppo_trainer.naive_chat_scheduler.NaiveChatCompletionScheduler temperature: 1.0 top_k: -1 # 0 for hf rollout, -1 for vllm rollout diff --git a/verl/workers/rollout/async_server.py b/verl/workers/rollout/async_server.py index a15f71b5d14..866ca366dfd 100644 --- a/verl/workers/rollout/async_server.py +++ b/verl/workers/rollout/async_server.py @@ -234,7 +234,7 @@ def __init__(self, config: DictConfig, worker_group: RayWorkerGroup): workers_info = ray.get(register_center.get_worker_info.remote()) assert len(workers_info) == self.worker_group.world_size - self.async_llm_workers = [None] * self.rollout_dp_size + self.async_llm_servers = [None] * 
self.rollout_dp_size self.server_addresses = [None] * self.rollout_dp_size server_class = async_server_class( @@ -260,14 +260,14 @@ def __init__(self, config: DictConfig, worker_group: RayWorkerGroup): try: address = ray.get(worker.get_server_address.remote()) self.server_addresses[rollout_dp_rank] = address - self.async_llm_workers[rollout_dp_rank] = worker + self.async_llm_servers[rollout_dp_rank] = worker unready_dp_ranks.remove(rollout_dp_rank) except Exception: ray.kill(worker) print(f"worker {rollout_dp_rank} failed, maybe address already in use, restarting...") # All server instances are ready, init AsyncLLM engine. - ray.get([worker.init_engine.remote() for worker in self.async_llm_workers]) + ray.get([worker.init_engine.remote() for worker in self.async_llm_servers]) # Init user provided chat scheduler. self.chat_scheduler = self._init_chat_scheduler() @@ -284,11 +284,11 @@ def _init_chat_scheduler(self) -> ChatCompletionScheduler: async def wake_up(self): """Wake up all vllm instances.""" - await asyncio.gather(*[worker.wake_up.remote() for worker in self.async_llm_workers]) + await asyncio.gather(*[worker.wake_up.remote() for worker in self.async_llm_servers]) async def sleep(self): """Sleep all vllm instances.""" - await asyncio.gather(*[worker.sleep.remote() for worker in self.async_llm_workers]) + await asyncio.gather(*[worker.sleep.remote() for worker in self.async_llm_servers]) async def generate_sequences(self, prompts: DataProto, **sampling_params) -> DataProto: """Generate sequences via chat scheduler.""" From edbeb429c9759ece5b14211f4bd29d3ed6671fcf Mon Sep 17 00:00:00 2001 From: wuxibin Date: Fri, 25 Apr 2025 15:33:21 +0800 Subject: [PATCH 10/10] fix unit test --- tests/rollout/test_vllm_multi_turn.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tests/rollout/test_vllm_multi_turn.py b/tests/rollout/test_vllm_multi_turn.py index 9869128bb73..bc6a154ed01 100644 --- a/tests/rollout/test_vllm_multi_turn.py +++ b/tests/rollout/test_vllm_multi_turn.py @@ -13,6 +13,7 @@ # limitations under the License. import asyncio +import os from typing import Any, Dict import ray @@ -41,6 +42,11 @@ async def test_vllm_multi_turn(): config.actor_rollout_ref.actor.fsdp_config.optimizer_offload = True # =========================== 1. Create hybrid ActorRollout workers =========================== + # make openai client happy + os.environ["no_proxy"] = "" + os.environ["http_proxy"] = "" + os.environ["https_proxy"] = "" + ray.init( runtime_env={ "env_vars": { @@ -48,9 +54,6 @@ async def test_vllm_multi_turn(): "NCCL_DEBUG": "WARN", "VLLM_LOGGING_LEVEL": "WARN", "VLLM_USE_V1": "1", - "no_proxy": "", - "http_proxy": "", - "https_proxy": "", } } )
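Taken together, enabling the async path is a configuration change on top of the defaults shown in `ppo_trainer.yaml` above. A hedged sketch of the relevant overrides expressed with OmegaConf; the scheduler path is the example from the config comment, and the final merge call is only indicative:

```python
# Minimal sketch of the rollout overrides that switch a run onto the AsyncLLM path.
from omegaconf import OmegaConf

async_overrides = OmegaConf.create(
    {
        "actor_rollout_ref": {
            "rollout": {
                "mode": "async",  # sync: LLM, async: AsyncLLM
                # Any importable ChatCompletionScheduler subclass can be plugged in here.
                "chat_scheduler": "examples.ppo_trainer.naive_chat_scheduler.NaiveChatCompletionScheduler",
            }
        }
    }
)

# e.g. config = OmegaConf.merge(config, async_overrides) before constructing the trainer.
```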