diff --git a/.github/workflows/vllm.yml b/.github/workflows/vllm.yml index 869b4e7f3ec..ee355b2fe1d 100644 --- a/.github/workflows/vllm.yml +++ b/.github/workflows/vllm.yml @@ -84,4 +84,8 @@ jobs: cd tests/generation export OUTPUT_PATH="${HOME}/data/gen/qwen_05_gen_test.parquet" MODEL_ID=Qwen/Qwen2.5-0.5B-Instruct NGPUS_PER_NODE=1 GEN_TP=1 bash ./run_gen_qwen05.sh - rm -rf "${OUTPUT_PATH}" \ No newline at end of file + rm -rf "${OUTPUT_PATH}" + - name: Running multi-turn rollout tests on 8 L20 GPUs + run: | + pip3 install --upgrade vllm==0.8.3 tensordict==0.7.2 + python3 tests/rollout/test_vllm_multi_turn.py diff --git a/examples/grpo_trainer/run_qwen2-7b_seq_balance.sh b/examples/grpo_trainer/run_qwen2-7b_seq_balance.sh index 2c6d6477e28..2976534698a 100644 --- a/examples/grpo_trainer/run_qwen2-7b_seq_balance.sh +++ b/examples/grpo_trainer/run_qwen2-7b_seq_balance.sh @@ -3,10 +3,18 @@ set -x # If you are using vllm<=0.6.3, you might need to set the following environment variable to avoid bugs: # export VLLM_ATTENTION_BACKEND=XFORMERS +# For async rollout mode, dataset should return raw chat. +rollout_mode="sync" +if [ "$rollout_mode" = "async" ]; then + return_raw_chat="True" + chat_scheduler=examples.ppo_trainer.naive_chat_scheduler.NaiveChatCompletionScheduler +fi + python3 -m verl.trainer.main_ppo \ algorithm.adv_estimator=grpo \ data.train_files=$HOME/data/gsm8k/train.parquet \ data.val_files=$HOME/data/gsm8k/test.parquet \ + data.return_raw_chat=$return_raw_chat \ data.train_batch_size=1024 \ data.max_prompt_length=512 \ data.max_response_length=1024 \ @@ -27,6 +35,8 @@ python3 -m verl.trainer.main_ppo \ actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.mode=$rollout_mode \ + actor_rollout_ref.rollout.chat_scheduler=$chat_scheduler \ actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ actor_rollout_ref.rollout.n=5 \ actor_rollout_ref.ref.fsdp_config.param_offload=True \ diff --git a/examples/ppo_trainer/naive_chat_scheduler.py b/examples/ppo_trainer/naive_chat_scheduler.py new file mode 100644 index 00000000000..15f4323ebbd --- /dev/null +++ b/examples/ppo_trainer/naive_chat_scheduler.py @@ -0,0 +1,151 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import asyncio +from typing import Any, Dict, List + +import torch +from omegaconf import DictConfig +from openai.types.chat.chat_completion import ChatCompletion +from tensordict import TensorDict + +from verl.protocol import DataProto +from verl.workers.rollout.async_server import ChatCompletionScheduler + + +class NaiveChatCompletionScheduler(ChatCompletionScheduler): + """ + A very naive implementation of ChatCompletionScheduler for demo purpose, + only do single-turn chat completion. 
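+    Multi-turn extensions can be built on top of this: the completion callback may
+    append tool results to the conversation and resubmit it via
+    `submit_chat_completions` (see the NOTE in `generate_sequences` below).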
+ """ + + def __init__( + self, + config: DictConfig, + model_path: str, + server_addresses: List[str], + max_cache_size: int = 10000, + ): + super().__init__(config, model_path, server_addresses, max_cache_size) + + async def generate_sequences(self, batch: DataProto, **sampling_params) -> DataProto: + kwargs = dict( + n=self.config.n, + max_completion_tokens=self.config.response_length, + temperature=self.config.temperature, + top_p=self.config.top_p, + ) + + do_sample = batch.meta_info.get("do_sample", True) + is_validate = batch.meta_info.get("validate", False) + if not do_sample or is_validate: + kwargs["n"] = 1 + kwargs["temperature"] = 0 + + kwargs.update(sampling_params) + print(f"[NaiveChatCompletionScheduler] generate_sequences sampling params: {kwargs}") + + async def callback(completions: ChatCompletion, info: Dict[str, Any], exception: Exception): + conversation, batch_conversations, batch_index = ( + info["conversation"], + info["batch_conversations"], + info["batch_index"], + ) + + conversations = [] + for choice in completions.choices: + chat = conversation.copy() + chat.append({"role": choice.message.role, "content": choice.message.content}) + conversations.append(chat) + batch_conversations[batch_index] = conversations + + # NOTE: we can call tools and resubmit chat completions here. + # call_tools(completions, info) + # await self.submit_chat_completions(callback2, ...) + + tasks, batch_conversations = [], [None] * len(batch) + for batch_index, conversation in enumerate(batch.non_tensor_batch["raw_prompt"]): + # raw_prompt: [{"role": "user", "content": ""}, ["role": "assistant", "content"], ...] + tasks.append( + asyncio.create_task( + self.submit_chat_completions( + callback=callback, + callback_additional_info={ + "batch_conversations": batch_conversations, + "batch_index": batch_index, + "conversation": list(conversation), + }, + model=self.model_name, + messages=conversation, + **kwargs, + ) + ) + ) + await asyncio.gather(*tasks) + print("[NaiveChatCompletionScheduler] generate_sequences done") + + return self._postprocess(batch, batch_conversations, kwargs["n"]) + + def _postprocess( + self, batch: DataProto, batch_conversations: List[List[List[Dict[str, str]]]], n: int + ) -> DataProto: + # NOTE: consistent with batch version of generate_sequences in vllm_rollout_spmd.py + # prompts: left pad + # responses: right pad + # input_ids: prompt + response + # attention_mask: [0,0,0,0,1,1,1,1, | 1,1,1,0,0,0,0,0] + # position_ids: [0,0,0,0,0,1,2,3, | 4,5,6,7,8,9,10,11] + + # prompts: [prompt] from input dataset + prompts = [ + self.tokenizer.apply_chat_template(prompt, add_generation_prompt=True, tokenize=False) + for prompt in batch.non_tensor_batch["raw_prompt"] + ] + + # flatten batch_conversations if n > 1 + assert len(batch_conversations) == len(prompts) + batch_conversations = [conversation for conversations in batch_conversations for conversation in conversations] + assert len(batch_conversations) == len(prompts) * n + + # sequences: [prompt + response] + sequences = [ + self.tokenizer.apply_chat_template(conversation, add_generation_prompt=False, tokenize=False) + for conversation in batch_conversations + ] + + # responses: [response] + # TODO: mask out tools calling tokens? 
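+        # Sequence i was generated from prompt i // n (each prompt has n samples),
+        # so slice off that prompt's rendered text to recover the response only.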
+ responses = [sequence[len(prompts[i // n]) :] for i, sequence in enumerate(sequences)] + + prompts = self.tokenizer(prompts, return_tensors="pt", padding="longest", padding_side="left") + responses = self.tokenizer(responses, return_tensors="pt", padding="longest", padding_side="right") + if n > 1: + prompts["input_ids"] = prompts["input_ids"].repeat_interleave(n, dim=0) + prompts["attention_mask"] = prompts["attention_mask"].repeat_interleave(n, dim=0) + + input_ids = torch.cat([prompts["input_ids"], responses["input_ids"]], dim=1) + attention_mask = torch.cat([prompts["attention_mask"], responses["attention_mask"]], dim=1) + position_ids = (attention_mask.cumsum(dim=1) - 1) * attention_mask + + batch = TensorDict( + { + "prompts": prompts["input_ids"], + "responses": responses["input_ids"], + "input_ids": input_ids, + "attention_mask": attention_mask, + "position_ids": position_ids, + }, + batch_size=len(input_ids), + ) + + return DataProto(batch=batch) diff --git a/examples/ppo_trainer/run_qwen2-7b_seq_balance.sh b/examples/ppo_trainer/run_qwen2-7b_seq_balance.sh index a5824a542d5..abc490acdc7 100644 --- a/examples/ppo_trainer/run_qwen2-7b_seq_balance.sh +++ b/examples/ppo_trainer/run_qwen2-7b_seq_balance.sh @@ -8,10 +8,18 @@ math_test_path=$HOME/data/math/test.parquet train_files="['$gsm8k_train_path', '$math_train_path']" test_files="['$gsm8k_test_path', '$math_test_path']" +# For async rollout mode, dataset should return raw chat. +rollout_mode="sync" +if [ "$rollout_mode" = "async" ]; then + return_raw_chat="True" + chat_scheduler=examples.ppo_trainer.naive_chat_scheduler.NaiveChatCompletionScheduler +fi + python3 -m verl.trainer.main_ppo \ algorithm.adv_estimator=gae \ data.train_files="$train_files" \ data.val_files="$test_files" \ + data.return_raw_chat=$return_raw_chat \ data.train_batch_size=4096 \ data.max_prompt_length=4096 \ data.max_response_length=4096 \ @@ -29,6 +37,8 @@ python3 -m verl.trainer.main_ppo \ actor_rollout_ref.actor.use_kl_loss=False \ actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.mode=$rollout_mode \ + actor_rollout_ref.rollout.chat_scheduler=$chat_scheduler \ actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \ actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=24000 \ critic.optim.lr=1e-5 \ diff --git a/recipe/dapo/src/dapo_ray_trainer.py b/recipe/dapo/src/dapo_ray_trainer.py index 20294662dda..76a48cc57ff 100644 --- a/recipe/dapo/src/dapo_ray_trainer.py +++ b/recipe/dapo/src/dapo_ray_trainer.py @@ -40,7 +40,7 @@ class RayDAPOTrainer(RayPPOTrainer): Note that this trainer runs on the driver process on a single CPU/GPU node. """ - def fit(self): + async def fit(self): """ The training loop of PPO. 
The driver process only need to call the compute functions of the worker group through RPC diff --git a/recipe/dapo/src/main_dapo.py b/recipe/dapo/src/main_dapo.py index c7eaebdc80b..152d61c5441 100644 --- a/recipe/dapo/src/main_dapo.py +++ b/recipe/dapo/src/main_dapo.py @@ -76,7 +76,8 @@ def run_ppo(config) -> None: @ray.remote(num_cpus=1) # please make sure main_task is not scheduled on head class TaskRunner: - def run(self, config): + + async def run(self, config): # print initial config from pprint import pprint @@ -201,7 +202,7 @@ def run(self, config): val_reward_fn=val_reward_fn, ) trainer.init_workers() - trainer.fit() + await trainer.fit() if __name__ == "__main__": diff --git a/recipe/prime/main_prime.py b/recipe/prime/main_prime.py index 5f912e374de..f34fe8b6607 100644 --- a/recipe/prime/main_prime.py +++ b/recipe/prime/main_prime.py @@ -29,6 +29,8 @@ Note that we don't combine the main with ray_trainer as ray_trainer is used by other main. """ +import asyncio + import hydra import ray @@ -53,6 +55,10 @@ def run_prime(config, compute_score=None): @ray.remote(num_cpus=1) # please make sure main_task is not scheduled on head def main_task(config, compute_score=None): + asyncio.run(_main_task(config, compute_score)) + + +async def _main_task(config, compute_score=None): # print initial config from pprint import pprint @@ -142,7 +148,7 @@ def main_task(config, compute_score=None): val_reward_fn=val_reward_fn, ) trainer.init_workers() - trainer.fit() + await trainer.fit() if __name__ == "__main__": diff --git a/recipe/prime/prime_ray_trainer.py b/recipe/prime/prime_ray_trainer.py index a21416e199a..873660105c5 100644 --- a/recipe/prime/prime_ray_trainer.py +++ b/recipe/prime/prime_ray_trainer.py @@ -331,7 +331,7 @@ def _load_checkpoint(self): if isinstance(self.train_dataloader.dataset, RLHFDataset): self.train_dataloader.dataset.resume_dataset_state() - def fit(self): + async def fit(self): """ The training loop of PPO. The driver process only need to call the compute functions of the worker group through RPC to construct the PPO dataflow. diff --git a/tests/ray/test_worker_group_basics.py b/tests/ray/test_worker_group_basics.py index 02a5b94ebb9..d0cecb515bf 100644 --- a/tests/ray/test_worker_group_basics.py +++ b/tests/ray/test_worker_group_basics.py @@ -68,7 +68,9 @@ def foo_custom(self, x, y): @ray.remote(num_gpus=0.1) def remote_call_wg(worker_names): class_with_args = RayClassWithInitArgs(cls=TestActor, x=2) - worker_group = RayWorkerGroup.from_detached(worker_names=worker_names, ray_cls_with_init=class_with_args) + worker_group = RayWorkerGroup.from_detached( + worker_names=worker_names, ray_cls_with_init=class_with_args, name_prefix=None + ) print(worker_group.worker_names) output_ref = worker_group.foo_custom(x=[1, 2], y=[5, 6]) diff --git a/tests/rollout/test_vllm_multi_turn.py b/tests/rollout/test_vllm_multi_turn.py new file mode 100644 index 00000000000..bc6a154ed01 --- /dev/null +++ b/tests/rollout/test_vllm_multi_turn.py @@ -0,0 +1,151 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import os +from typing import Any, Dict + +import ray +from omegaconf import OmegaConf +from openai.types.chat.chat_completion import ChatCompletion + +from verl.single_controller.ray import RayClassWithInitArgs, RayWorkerGroup +from verl.single_controller.ray.base import create_colocated_worker_cls +from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role +from verl.workers.fsdp_workers import AsyncActorRolloutRefWorker +from verl.workers.rollout.async_server import AsyncLLMServerManager + + +async def test_vllm_multi_turn(): + config = OmegaConf.load("verl/trainer/config/ppo_trainer.yaml") + model_path = "Qwen/Qwen2-7B-Instruct" + model_name = "/".join(model_path.split("/")[-2:]) + config.actor_rollout_ref.model.path = model_path + config.actor_rollout_ref.rollout.mode = "async" + config.actor_rollout_ref.rollout.chat_scheduler = "verl.workers.rollout.async_server.ChatCompletionScheduler" + config.actor_rollout_ref.rollout.prompt_length = 4096 + config.actor_rollout_ref.rollout.response_length = 4096 + + # test sleep/wake_up with fsdp offload + config.actor_rollout_ref.actor.fsdp_config.param_offload = True + config.actor_rollout_ref.actor.fsdp_config.optimizer_offload = True + + # =========================== 1. Create hybrid ActorRollout workers =========================== + # make openai client happy + os.environ["no_proxy"] = "" + os.environ["http_proxy"] = "" + os.environ["https_proxy"] = "" + + ray.init( + runtime_env={ + "env_vars": { + "TOKENIZERS_PARALLELISM": "true", + "NCCL_DEBUG": "WARN", + "VLLM_LOGGING_LEVEL": "WARN", + "VLLM_USE_V1": "1", + } + } + ) + role_worker_mapping = { + Role.ActorRollout: ray.remote(AsyncActorRolloutRefWorker), + } + global_pool_id = "global_pool" + resource_pool_spec = { + global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes, + } + mapping = { + Role.ActorRollout: global_pool_id, + } + resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping) + resource_pool_manager.create_resource_pool() + resource_pool_to_cls = {pool: {} for pool in resource_pool_manager.resource_pool_dict.values()} + + # create actor and rollout + resource_pool = resource_pool_manager.get_resource_pool(Role.ActorRollout) + actor_rollout_cls = RayClassWithInitArgs( + cls=role_worker_mapping[Role.ActorRollout], config=config.actor_rollout_ref, role="actor_rollout" + ) + resource_pool_to_cls[resource_pool]["actor_rollout"] = actor_rollout_cls + + all_wg = {} + wg_dicts = [] + for resource_pool, class_dict in resource_pool_to_cls.items(): + worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict) + wg_dict = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=worker_dict_cls) + spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys()) + all_wg.update(spawn_wg) + wg_dicts.append(wg_dict) + actor_rollout_wg = all_wg["actor_rollout"] + actor_rollout_wg.init_model() + + # =========================== 2. Create AsyncLLMServerManager&ChatScheduler =========================== + async_rollout_manager = AsyncLLMServerManager( + config=config.actor_rollout_ref, + worker_group=actor_rollout_wg, + ) + async_chat_scheduler = async_rollout_manager.chat_scheduler + + # test sleep and wake_up + await async_rollout_manager.sleep() + await async_rollout_manager.wake_up() + + # =========================== 3. 
Multi turn rollout =========================== + async def callback(completions: ChatCompletion, info: Dict[str, Any], exception: Exception): + assert exception is None, f"exception: {exception}" + messages, round = info["messages"], info["round"] + message = completions.choices[0].message + messages.append({"role": message.role, "content": message.content}) + print(f"[round={round}] role: {message.role}, content: {message.content}") + + extra_headers = {"x-request-id": completions.id} + if round == 0: + messages.append({"role": "user", "content": "What is your name?"}) + await async_chat_scheduler.submit_chat_completions( + callback=callback, + callback_additional_info={"messages": messages, "round": 1}, + model=model_name, + messages=messages, + extra_headers=extra_headers, + ) + elif round == 1: + messages.append({"role": "user", "content": "What is your favorite color?"}) + await async_chat_scheduler.submit_chat_completions( + callback=callback, + callback_additional_info={"messages": messages, "round": 2}, + model=model_name, + messages=messages, + extra_headers=extra_headers, + ) + else: + print("Done!") + + messages = [ + {"role": "user", "content": "Let's play a role playing game. Your name is Bob, your favorite color is red."} + ] + await async_chat_scheduler.submit_chat_completions( + callback=callback, + callback_additional_info={"messages": messages, "round": 0}, + model=model_name, + messages=messages, + ) + assert len(messages) == 6 + for round, message in enumerate(messages): + if round % 2 == 0: + assert message["role"] == "user" + else: + assert message["role"] == "assistant" + + +if __name__ == "__main__": + asyncio.run(test_vllm_multi_turn()) diff --git a/verl/single_controller/base/decorator.py b/verl/single_controller/base/decorator.py index 9d95c77d485..f5856212c82 100644 --- a/verl/single_controller/base/decorator.py +++ b/verl/single_controller/base/decorator.py @@ -37,6 +37,9 @@ class Dispatch(Enum): DP_COMPUTE_PROTO_WITH_FUNC = 10 DP_COMPUTE_METRIC = 11 + # This is a special dispatch mode for vllm ExternalRayDistributedExecutor + DIRECT_ROLLOUT_METHOD = 12 + class Execute(Enum): ALL = 0 @@ -65,6 +68,10 @@ def dispatch_one_to_all(worker_group, *args, **kwargs): return args, kwargs +def dummy_direct_rollout_call(worker_group, *args, **kwargs): + raise NotImplementedError("Direct rollout call is forbidden.") + + def dispatch_all_to_all(worker_group, *args, **kwargs): return args, kwargs @@ -356,6 +363,10 @@ def get_predefined_dispatch_fn(dispatch_mode): "collect_fn": collect_dp_compute_data_proto, }, Dispatch.DP_COMPUTE_METRIC: {"dispatch_fn": dispatch_dp_compute_data_proto, "collect_fn": collect_dp_compute}, + Dispatch.DIRECT_ROLLOUT_METHOD: { + "dispatch_fn": dummy_direct_rollout_call, + "collect_fn": dummy_direct_rollout_call, + }, } return predefined_dispatch_mode_fn[dispatch_mode] diff --git a/verl/single_controller/base/register_center/ray.py b/verl/single_controller/base/register_center/ray.py index de7f702a894..8ff70bd36a8 100644 --- a/verl/single_controller/base/register_center/ray.py +++ b/verl/single_controller/base/register_center/ray.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from typing import Dict + import ray @@ -19,10 +21,18 @@ class WorkerGroupRegisterCenter: def __init__(self, rank_zero_info): self.rank_zero_info = rank_zero_info + # rank -> node_id + self.workers_info: Dict[int, str] = {} def get_rank_zero_info(self): return self.rank_zero_info + def set_worker_info(self, rank, node_id) -> None: + self.workers_info[rank] = node_id + + def get_worker_info(self) -> Dict[int, str]: + return self.workers_info + def create_worker_group_register_center(name, info): return WorkerGroupRegisterCenter.options(name=name).remote(info) diff --git a/verl/single_controller/base/worker.py b/verl/single_controller/base/worker.py index 59ff599f061..0655d0e7a22 100644 --- a/verl/single_controller/base/worker.py +++ b/verl/single_controller/base/worker.py @@ -19,6 +19,8 @@ import socket from dataclasses import dataclass +import ray + from .decorator import Dispatch, Execute, register @@ -125,6 +127,11 @@ def _configure_before_init(self, register_center_name: str, rank: int): ) os.environ.update(rank_zero_info) + else: + self.register_center = ray.get_actor(register_center_name) + + # set worker info for node affinity scheduling + ray.get(self.register_center.set_worker_info.remote(rank, ray.get_runtime_context().get_node_id())) def __init__(self, cuda_visible_devices=None) -> None: # construct a meta from envrionment variable. Note that the import must be inside the class because it is executed remotely diff --git a/verl/single_controller/ray/base.py b/verl/single_controller/ray/base.py index 6bf7ddc1f72..eb2cf0a7771 100644 --- a/verl/single_controller/ray/base.py +++ b/verl/single_controller/ray/base.py @@ -13,8 +13,10 @@ # limitations under the License. import logging +import os import time from typing import Any, Dict, List, Optional, Tuple +from unittest.mock import patch import ray from ray.experimental.state.api import get_actor @@ -23,6 +25,7 @@ from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy, PlacementGroupSchedulingStrategy from verl.single_controller.base import ClassWithInitArgs, ResourcePool, Worker, WorkerGroup +from verl.single_controller.base.decorator import MAGIC_ATTR, Dispatch __all__ = ["Worker"] @@ -300,17 +303,23 @@ def _init_with_resource_pool(self, resource_pool, ray_cls_with_init, bin_pack, d elapsed = int(time.time() - start_time) if elapsed % 30 == 0: logging.warning( - f"Waiting for register center actor {actor_name} to be ready. " - f"Elapsed time: {elapsed} seconds out of {self._ray_wait_register_center_timeout} seconds." + "Waiting for register center actor %s to be ready. " + "Elapsed time: %s seconds out of %s seconds.", + actor_name, + elapsed, + self._ray_wait_register_center_timeout, ) time.sleep(1) if register_center_actor is None: raise TimeoutError( - f"Failed to get register_center_actor {actor_name} in {list_named_actors(all_namespaces=True)} " + f"Failed to get register_center_actor {actor_name} " + f"in {list_named_actors(all_namespaces=True)} " f"for {self._ray_wait_register_center_timeout} seconds. " - "Ensure that any lingering Ray resources from previous runs are cleaned up (e.g., by restarting the Ray cluster), " - "or adjust the waiting time by modifying the config `trainer.ray_wait_register_center_timeout`." + "Ensure that any lingering Ray resources from previous " + "runs are cleaned up (e.g., by restarting the Ray cluster), " + "or adjust the waiting time by modifying the config " + "`trainer.ray_wait_register_center_timeout`." 
) rank_zero_info = ray.get(register_center_actor.get_rank_zero_info.remote()) @@ -323,9 +332,14 @@ def worker_names(self): return self._worker_names @classmethod - def from_detached(cls, worker_names=None, ray_cls_with_init=None): + def from_detached( + cls, + name_prefix, + worker_names=None, + ray_cls_with_init=None, + ): worker_group = cls( - resource_pool=None, ray_cls_with_init=ray_cls_with_init, name_prefix=None, worker_names=worker_names + resource_pool=None, ray_cls_with_init=ray_cls_with_init, name_prefix=name_prefix, worker_names=worker_names ) return worker_group @@ -350,7 +364,9 @@ def _rebind_actor_methods(worker_group, actor_name): new_worker_group_dict = {} for prefix in prefix_set: new_worker_group = self.from_detached( - worker_names=self._worker_names, ray_cls_with_init=self.ray_cls_with_init + name_prefix=self.name_prefix, + worker_names=self._worker_names, + ray_cls_with_init=self.ray_cls_with_init, ) _rebind_actor_methods(new_worker_group, prefix) @@ -374,8 +390,9 @@ def execute_all_sync(self, method_name: str, *args, **kwargs): return ray.get(self.execute_all_async(method_name, *args, **kwargs)) def execute_all_async(self, method_name: str, *args, **kwargs): - # Here, we assume that if all arguments in args and kwargs are lists, and their lengths match len(self._workers), - # we'll distribute each element in these lists to the corresponding worker + # Here, we assume that if all arguments in args and kwargs are lists, + # and their lengths match len(self._workers), we'll distribute each + # element in these lists to the corresponding worker # print(f"execute_all_async: method {method_name}({args}, {kwargs})") length = len(self._workers) if all(isinstance(arg, list) for arg in args) and all(isinstance(kwarg, list) for kwarg in kwargs.values()): @@ -413,17 +430,13 @@ def world_size(self): with code written in separate ray.Actors. """ -import os -from unittest.mock import patch - -from verl.single_controller.base.decorator import MAGIC_ATTR - def _bind_workers_method_to_parent(cls, key, user_defined_cls): """ Binds the methods of each worker to the WorkerDict. 
Note that we only bind public methods that are decorated by register """ + for method_name in dir(user_defined_cls): try: method = getattr(user_defined_cls, method_name) @@ -434,22 +447,30 @@ def _bind_workers_method_to_parent(cls, key, user_defined_cls): if hasattr(method, MAGIC_ATTR): - def generate_function(name): + def generate_function(name, key=key): def func(self, *args, **kwargs): # dispatch to the actual worker return getattr(self.worker_dict[key], name)(*args, **kwargs) - return func + return func # noqa: B023 func = generate_function(method_name) # pass MAGIC_ATTR for outer worker group - setattr(func, MAGIC_ATTR, getattr(method, MAGIC_ATTR)) + attrs = getattr(method, MAGIC_ATTR) + setattr(func, MAGIC_ATTR, attrs) try: - method_name_with_prefix = key + "_" + method_name - setattr(cls, method_name_with_prefix, func) - # print(f'Binding {method_name_with_prefix}') - except Exception: - raise ValueError(f"Fail to set method_name {method_name}") + # bind direct rollout method to class without prefix + if attrs["dispatch_mode"] == Dispatch.DIRECT_ROLLOUT_METHOD and "rollout" in key: + assert not hasattr(cls, method_name), ( + f"conflict direct rollout method {method_name} with role {key}" + ) + setattr(cls, method_name, func) + print(f"bind role {key} method {method_name} to class {cls}") + else: + method_name_with_prefix = key + "_" + method_name + setattr(cls, method_name_with_prefix, func) + except Exception as e: + raise ValueError(f"Fail to set method_name {method_name}") from e def _unwrap_ray_remote(cls): @@ -458,6 +479,19 @@ def _unwrap_ray_remote(cls): return cls +def _determine_fsdp_megatron_base_class(mros: List): + """ + - megatron: base class should be MegatronWorker + - fsdp: base class should be Worker + """ + for cls in mros[0]: + if cls.__name__ == "MegatronWorker": + return cls + if cls.__name__ == "Worker": + return cls + raise ValueError(f"Cannot determine base class for {mros}") + + def create_colocated_worker_cls(class_dict: dict[str, RayClassWithInitArgs]): """ This function should return a class instance that delegates the calls to every @@ -465,14 +499,13 @@ def create_colocated_worker_cls(class_dict: dict[str, RayClassWithInitArgs]): """ cls_dict = {} init_args_dict = {} - worker_cls = None + worker_cls = _determine_fsdp_megatron_base_class( + [cls.cls.__ray_actor_class__.__mro__ for cls in class_dict.values()] + ) + assert issubclass(worker_cls, Worker), f"worker_cls {worker_cls} should be a subclass of Worker" + print(f"colocated worker base class {worker_cls}") + for key, cls in class_dict.items(): - if worker_cls is None: - worker_cls = cls.cls.__ray_actor_class__.__base__ - else: - assert worker_cls == cls.cls.__ray_actor_class__.__base__, ( - "the worker class should be the same when share the same process" - ) cls_dict[key] = cls.cls init_args_dict[key] = {"args": cls.args, "kwargs": cls.kwargs} @@ -486,7 +519,8 @@ def __init__(self): for key, user_defined_cls in cls_dict.items(): user_defined_cls = _unwrap_ray_remote(user_defined_cls) # directly instantiate the class without remote - # in worker class, e.g. when DISABLE_WORKER_INIT == 1 it will return immediately + # in worker class, e.g. 
+ # when DISABLE_WORKER_INIT == 1 it will return immediately with patch.dict(os.environ, {"DISABLE_WORKER_INIT": "1"}): self.worker_dict[key] = user_defined_cls( *init_args_dict[key].get("args", ()), **init_args_dict[key].get("kwargs", {}) diff --git a/verl/trainer/config/generation.yaml b/verl/trainer/config/generation.yaml index 8b542f542d4..d3068e886c9 100644 --- a/verl/trainer/config/generation.yaml +++ b/verl/trainer/config/generation.yaml @@ -14,6 +14,7 @@ model: external_lib: null rollout: name: vllm + mode: sync # sync: LLM, async: AsyncLLM temperature: 1.0 top_k: 50 # 0 for hf rollout, -1 for vllm rollout top_p: 0.7 diff --git a/verl/trainer/config/ppo_megatron_trainer.yaml b/verl/trainer/config/ppo_megatron_trainer.yaml index 4c84f315bef..838663886bc 100644 --- a/verl/trainer/config/ppo_megatron_trainer.yaml +++ b/verl/trainer/config/ppo_megatron_trainer.yaml @@ -94,6 +94,7 @@ actor_rollout_ref: log_prob_micro_batch_size_per_gpu: null rollout: name: vllm + mode: sync # sync: LLM, async: AsyncLLM temperature: 1.0 top_k: -1 # 0 for hf rollout, -1 for vllm rollout top_p: 1 diff --git a/verl/trainer/config/ppo_trainer.yaml b/verl/trainer/config/ppo_trainer.yaml index 34957b06c45..ca14b378bf9 100644 --- a/verl/trainer/config/ppo_trainer.yaml +++ b/verl/trainer/config/ppo_trainer.yaml @@ -83,6 +83,8 @@ actor_rollout_ref: ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size rollout: name: vllm + mode: sync # sync: LLM, async: AsyncLLM + chat_scheduler: null # async chat scheduler, e.g examples.ppo_trainer.naive_chat_scheduler.NaiveChatCompletionScheduler temperature: 1.0 top_k: -1 # 0 for hf rollout, -1 for vllm rollout top_p: 1 diff --git a/verl/trainer/main_ppo.py b/verl/trainer/main_ppo.py index 1721c88af52..5ab0f688475 100644 --- a/verl/trainer/main_ppo.py +++ b/verl/trainer/main_ppo.py @@ -83,7 +83,7 @@ def run_ppo(config) -> None: @ray.remote(num_cpus=1) # please make sure main_task is not scheduled on head class TaskRunner: - def run(self, config): + async def run(self, config): # print initial config from pprint import pprint @@ -108,8 +108,13 @@ def run(self, config): if config.actor_rollout_ref.actor.strategy == "fsdp": assert config.actor_rollout_ref.actor.strategy == config.critic.strategy from verl.single_controller.ray import RayWorkerGroup - from verl.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker + from verl.workers.fsdp_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker, CriticWorker + actor_rollout_cls = ( + AsyncActorRolloutRefWorker + if config.actor_rollout_ref.rollout.mode == "async" + else ActorRolloutRefWorker + ) ray_worker_group_cls = RayWorkerGroup elif config.actor_rollout_ref.actor.strategy == "megatron": @@ -117,6 +122,7 @@ def run(self, config): from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup from verl.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker + actor_rollout_cls = ActorRolloutRefWorker ray_worker_group_cls = NVMegatronRayWorkerGroup else: @@ -125,7 +131,7 @@ def run(self, config): from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role role_worker_mapping = { - Role.ActorRollout: ray.remote(ActorRolloutRefWorker), + Role.ActorRollout: ray.remote(actor_rollout_cls), Role.Critic: ray.remote(CriticWorker), } @@ -176,7 +182,7 @@ def run(self, config): val_reward_fn=val_reward_fn, ) trainer.init_workers() - trainer.fit() + await trainer.fit() if __name__ == "__main__": diff --git a/verl/trainer/ppo/ray_trainer.py 
b/verl/trainer/ppo/ray_trainer.py index 2c15b8e7fc6..35fefdf17a7 100644 --- a/verl/trainer/ppo/ray_trainer.py +++ b/verl/trainer/ppo/ray_trainer.py @@ -56,6 +56,7 @@ from verl.utils.seqlen_balancing import get_seqlen_balanced_partitions, log_seqlen_unbalance from verl.utils.torch_functional import masked_mean from verl.utils.tracking import ValidationGenerationsLogger +from verl.workers.rollout.async_server import AsyncLLMServerManager WorkerType = Type[Worker] @@ -581,7 +582,7 @@ def _maybe_log_val_generations(self, inputs, outputs, scores): # Log to each configured logger self.validation_generations_logger.log(self.config.trainer.logger, samples, self.global_steps) - def _validate(self): + async def _validate(self): data_source_lst = [] reward_extra_infos_dict: dict[str, list] = defaultdict(list) @@ -616,7 +617,7 @@ def _validate(self): else: test_gen_batch = test_batch.pop( batch_keys=["input_ids", "attention_mask", "position_ids"], - non_tensor_batch_keys=["raw_prompt_ids"], + non_tensor_batch_keys=["raw_prompt_ids"] + ["raw_prompt"] if self.async_rollout_mode else [], ) test_gen_batch.meta_info = { @@ -630,7 +631,14 @@ def _validate(self): # pad to be divisible by dp_size test_gen_batch_padded, pad_size = pad_dataproto_to_divisor(test_gen_batch, self.actor_rollout_wg.world_size) - test_output_gen_batch_padded = self.actor_rollout_wg.generate_sequences(test_gen_batch_padded) + if not self.async_rollout_mode: + test_output_gen_batch_padded = self.actor_rollout_wg.generate_sequences(test_gen_batch_padded) + else: + await self.async_rollout_manager.wake_up() + test_output_gen_batch_padded = await self.async_rollout_manager.generate_sequences( + test_gen_batch_padded + ) + await self.async_rollout_manager.sleep() # unpad test_output_gen_batch = unpad_dataproto(test_output_gen_batch_padded, pad_size=pad_size) @@ -770,6 +778,15 @@ def init_workers(self): self.actor_rollout_wg = all_wg["actor_rollout"] self.actor_rollout_wg.init_model() + # create async rollout manager and request scheduler + self.async_rollout_mode = False + if self.config.actor_rollout_ref.rollout.mode == "async": + self.async_rollout_mode = True + self.async_rollout_manager = AsyncLLMServerManager( + config=self.config.actor_rollout_ref, + worker_group=self.actor_rollout_wg, + ) + def _save_checkpoint(self): # path: given_path + `/global_step_{global_steps}` + `/actor` local_global_step_folder = os.path.join( @@ -899,7 +916,7 @@ def _balance_batch(self, batch: DataProto, metrics, logging_prefix="global_seqle ) metrics.update(global_balance_stats) - def fit(self): + async def fit(self): """ The training loop of PPO. The driver process only need to call the compute functions of the worker group through RPC @@ -925,7 +942,7 @@ def fit(self): # perform validation before training # currently, we only support validation using the reward_function. 
if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True): - val_metrics = self._validate() + val_metrics = await self._validate() pprint(f"Initial validation metrics: {val_metrics}") logger.log(data=val_metrics, step=self.global_steps) if self.config.trainer.get("val_only", False): @@ -954,7 +971,7 @@ def fit(self): else: gen_batch = batch.pop( batch_keys=["input_ids", "attention_mask", "position_ids"], - non_tensor_batch_keys=["raw_prompt_ids"], + non_tensor_batch_keys=["raw_prompt_ids"] + ["raw_prompt"] if self.async_rollout_mode else [], ) is_last_step = self.global_steps >= self.total_training_steps @@ -962,7 +979,12 @@ def fit(self): with _timer("step", timing_raw): # generate a batch with _timer("gen", timing_raw): - gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch) + if not self.async_rollout_mode: + gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch) + else: + await self.async_rollout_manager.wake_up() + gen_batch_output = await self.async_rollout_manager.generate_sequences(gen_batch) + await self.async_rollout_manager.sleep() if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX: with _timer("gen_max", timing_raw): @@ -1105,7 +1127,7 @@ def fit(self): and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0) ): with _timer("testing", timing_raw): - val_metrics: dict = self._validate() + val_metrics: dict = await self._validate() if is_last_step: last_val_metrics = val_metrics metrics.update(val_metrics) @@ -1117,10 +1139,12 @@ def fit(self): self._save_checkpoint() # training metrics - metrics.update({ - 'training/global_step': self.global_steps, - 'training/epoch': epoch, - }) + metrics.update( + { + "training/global_step": self.global_steps, + "training/epoch": epoch, + } + ) # collect metrics metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic)) metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw)) diff --git a/verl/utils/vllm_utils.py b/verl/utils/vllm_utils.py index ee518d5e33e..e395db1cf57 100644 --- a/verl/utils/vllm_utils.py +++ b/verl/utils/vllm_utils.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+ def patch_vllm_moe_model_weight_loader(model): # this is a work around to load the weight of vllm fused moe model # it is from a bug from vllm 0.8.2 @@ -29,10 +30,9 @@ def patch_vllm_moe_model_weight_loader(model): # (False, 'model.layers.0.post_attention_layernorm.weight') use default # (False, 'model.layers.0.mlp.experts.w13_weight') use mlp.experts.weight_loader # (False, 'model.layers.0.mlp.experts.w2_weight') use mlp.experts.weight_loader - from vllm.model_executor.models.deepseek_v2 import (DeepseekV2ForCausalLM, - DeepseekV3ForCausalLM) + from vllm.model_executor.models.deepseek_v2 import DeepseekV2ForCausalLM, DeepseekV3ForCausalLM from vllm.model_executor.models.qwen2_moe import Qwen2MoeForCausalLM - + if not isinstance(model, (Qwen2MoeForCausalLM, DeepseekV2ForCausalLM, DeepseekV3ForCausalLM)): return for layer in model.model.layers: @@ -40,4 +40,4 @@ def patch_vllm_moe_model_weight_loader(model): param_dict = dict(mlp.named_parameters()) for name, param in param_dict.items(): if "w13_weight" in name or "w2_weight" in name: - param.weight_loader = mlp.experts.weight_loader \ No newline at end of file + param.weight_loader = mlp.experts.weight_loader diff --git a/verl/workers/fsdp_workers.py b/verl/workers/fsdp_workers.py index d1153c87172..b565f4b4e62 100644 --- a/verl/workers/fsdp_workers.py +++ b/verl/workers/fsdp_workers.py @@ -18,6 +18,7 @@ import logging import os import warnings +from typing import Union import psutil import torch @@ -125,7 +126,8 @@ def __init__(self, config: DictConfig, role: str): self.config.actor.ppo_mini_batch_size *= self.config.rollout.n self.config.actor.ppo_mini_batch_size //= self.device_mesh.size() // self.ulysses_sequence_parallel_size assert self.config.actor.ppo_mini_batch_size > 0, ( - f"ppo_mini_batch_size {self.config.actor.ppo_mini_batch_size} should be larger than 0 after normalization" + f"ppo_mini_batch_size {self.config.actor.ppo_mini_batch_size} " + "should be larger than 0 after normalization" ) # micro bsz if self.config.actor.ppo_micro_batch_size is not None: @@ -136,10 +138,12 @@ def __init__(self, config: DictConfig, role: str): if self.config.actor.ppo_micro_batch_size_per_gpu is not None: assert self.config.actor.ppo_mini_batch_size % self.config.actor.ppo_micro_batch_size_per_gpu == 0, ( - f"normalized ppo_mini_batch_size {self.config.actor.ppo_mini_batch_size} should be divisible by ppo_micro_batch_size_per_gpu {self.config.actor.ppo_micro_batch_size_per_gpu}" + f"normalized ppo_mini_batch_size {self.config.actor.ppo_mini_batch_size} should be " + f"divisible by ppo_micro_batch_size_per_gpu {self.config.actor.ppo_micro_batch_size_per_gpu}" ) assert self.config.actor.ppo_mini_batch_size // self.config.actor.ppo_micro_batch_size_per_gpu > 0, ( - f"normalized ppo_mini_batch_size {self.config.actor.ppo_mini_batch_size} should be larger than ppo_micro_batch_size_per_gpu {self.config.actor.ppo_micro_batch_size_per_gpu}" + f"normalized ppo_mini_batch_size {self.config.actor.ppo_mini_batch_size} should be " + f"larger than ppo_micro_batch_size_per_gpu {self.config.actor.ppo_micro_batch_size_per_gpu}" ) # normalize rollout config @@ -349,7 +353,7 @@ def _build_rollout(self, trust_remote_code=False): # TODO: a sharding manager that do nothing? 
elif rollout_name == "vllm": - from verl.workers.rollout.vllm_rollout import vllm_mode, vLLMRollout + from verl.workers.rollout.vllm_rollout import vllm_mode, vLLMAsyncRollout, vLLMRollout from verl.workers.sharding_manager import FSDPVLLMShardingManager log_gpu_memory_usage(f"Before building {rollout_name} rollout", logger=None) @@ -362,7 +366,8 @@ def _build_rollout(self, trust_remote_code=False): model_hf_config=self.actor_model_config, ) elif vllm_mode == "spmd": - rollout = vLLMRollout( + vllm_rollout_cls = vLLMRollout if self.config.rollout.mode == "sync" else vLLMAsyncRollout + rollout = vllm_rollout_cls( model_path=local_path, config=self.config.rollout, tokenizer=self.tokenizer, @@ -381,16 +386,19 @@ def _build_rollout(self, trust_remote_code=False): model_config=self.actor_model_config, full_params="hf" in self.config.rollout.load_format, device_mesh=rollout_device_mesh, + offload_param=self._is_offload_param, ) log_gpu_memory_usage("After building sharding manager", logger=None) elif rollout_name == "sglang": from verl.workers.rollout.sglang_rollout import SGLangRollout - # NOTE(linjunrong): Due to recent fp8 support in SGLang. Now importing any symbol relate to SGLang's model_runner would check CUDA device capability. - # However, due to veRL's setting, the main process of ray can not find any CUDA device, which would potentially lead to: + # NOTE(linjunrong): Due to recent fp8 support in SGLang. Now importing any symbol relate to + # SGLang's model_runner would check CUDA device capability. However, due to veRL's setting, + # the main process of ray can not find any CUDA device, which would potentially lead to: # "RuntimeError: No CUDA GPUs are available". - # For this reason, sharding_manager.__init__ should not import FSDPSGLangShardingManager and we import it here use the abs path. + # For this reason, sharding_manager.__init__ should not import FSDPSGLangShardingManager and + # we import it here use the abs path. 
# check: https://github.com/sgl-project/sglang/blob/00f42707eaddfc2c0528e5b1e0094025c640b7a0/python/sglang/srt/layers/quantization/fp8_utils.py#L76 from verl.workers.sharding_manager.fsdp_sglang import FSDPSGLangShardingManager @@ -412,6 +420,7 @@ def _build_rollout(self, trust_remote_code=False): model_config=self.actor_model_config, full_params="hf" in self.config.rollout.load_format, device_mesh=rollout_device_mesh, + offload_param=self._is_offload_param, ) log_gpu_memory_usage("After building sharding manager", logger=None) @@ -547,8 +556,6 @@ def generate_sequences(self, prompts: DataProto): prompts = prompts.to(torch.cuda.current_device()) assert self._is_rollout - if self._is_offload_param: - load_fsdp_model_to_gpu(self.actor_module_fsdp) meta_info = { "eos_token_id": self.generation_config.eos_token_id @@ -560,11 +567,7 @@ def generate_sequences(self, prompts: DataProto): } prompts.meta_info.update(meta_info) with self.rollout_sharding_manager: - # after parameters sync with rollout, offload actor model to CPU - if self._is_offload_param: - offload_fsdp_model_to_cpu(self.actor_module_fsdp) - if self._is_offload_optimizer: - offload_fsdp_optimizer(optimizer=self.actor_optimizer) + log_gpu_memory_usage("After entering rollout sharding manager", logger=logger) prompts = self.rollout_sharding_manager.preprocess_data(prompts) output = self.rollout.generate_sequences(prompts=prompts) @@ -716,10 +719,12 @@ def __init__(self, config): if self.config.ppo_micro_batch_size_per_gpu is not None: assert self.config.ppo_mini_batch_size % self.config.ppo_micro_batch_size_per_gpu == 0, ( - f"normalized ppo_mini_batch_size {self.config.ppo_mini_batch_size} should be divisible by ppo_micro_batch_size_per_gpu {self.config.ppo_micro_batch_size_per_gpu}" + f"normalized ppo_mini_batch_size {self.config.ppo_mini_batch_size} should be " + f"divisible by ppo_micro_batch_size_per_gpu {self.config.ppo_micro_batch_size_per_gpu}" ) assert self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu > 0, ( - f"normalized ppo_mini_batch_size {self.config.ppo_mini_batch_size} should be larger than ppo_micro_batch_size_per_gpu {self.config.ppo_micro_batch_size_per_gpu}" + f"normalized ppo_mini_batch_size {self.config.ppo_mini_batch_size} should be " + f"larger than ppo_micro_batch_size_per_gpu {self.config.ppo_micro_batch_size_per_gpu}" ) def _build_critic_model_optimizer(self, config): @@ -1283,3 +1288,35 @@ def compute_rm_score(self, data: DataProto): output = output.to("cpu") return output + + +# ================================= Async related workers ================================= +class AsyncActorRolloutRefWorker(ActorRolloutRefWorker): + def _build_rollout(self, trust_remote_code=False): + rollout, rollout_sharding_manager = super()._build_rollout(trust_remote_code) + + # NOTE: rollout is not actually initialized here, it's deferred + # to be initialized by AsyncvLLMServer. 
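+        # Each vLLM DP replica spans tensor_model_parallel_size hybrid workers, so the
+        # global RANK decomposes into a DP replica index (RANK // tp_size) and a
+        # TP rank within that replica (RANK % tp_size).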
+ + self.vllm_tp_size = self.config.rollout.tensor_model_parallel_size + self.vllm_dp_rank = int(os.environ["RANK"]) // self.vllm_tp_size + self.vllm_tp_rank = int(os.environ["RANK"]) % self.vllm_tp_size + + # used for sleep/wake_up + rollout.sharding_manager = rollout_sharding_manager + + return rollout, rollout_sharding_manager + + @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) + def generate_sequences(self, prompts: DataProto): + raise NotImplementedError("AsyncActorRolloutRefWorker does not support generate_sequences") + + @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD) + def execute_method(self, method: Union[str, bytes], *args, **kwargs): + """Called by ExternalRayDistributedExecutor collective_rpc.""" + if self.vllm_tp_rank == 0 and method != "execute_model": + print( + f"[DP={self.vllm_dp_rank},TP={self.vllm_tp_rank}] " + f"execute_method: {method if isinstance(method, str) else 'Callable'}" + ) + return self.rollout.execute_method(method, *args, **kwargs) diff --git a/verl/workers/rollout/async_server.py b/verl/workers/rollout/async_server.py new file mode 100644 index 00000000000..866ca366dfd --- /dev/null +++ b/verl/workers/rollout/async_server.py @@ -0,0 +1,314 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import asyncio +import heapq +import importlib +import logging +import os +import socket +from abc import ABC, abstractmethod +from contextlib import asynccontextmanager +from typing import Any, Callable, Dict, List, Tuple, Type +from uuid import uuid4 + +import aiohttp +import fastapi +import ray +import uvicorn +from cachetools import LRUCache +from omegaconf import DictConfig +from openai import AsyncOpenAI +from openai.types.chat.chat_completion import ChatCompletion +from starlette.requests import Request + +from verl.protocol import DataProto +from verl.single_controller.ray.base import RayWorkerGroup +from verl.utils import hf_tokenizer +from verl.utils.fs import copy_to_local + +logger = logging.getLogger(__file__) + + +def _get_free_port(): + with socket.socket() as sock: + sock.bind(("", 0)) + return sock.getsockname()[1] + + +class AsyncServerBase(ABC): + """Base class for AsyncServer.""" + + def __init__(self): + self.address = ray._private.services.get_node_ip_address() + self.port = None + self.server_ready = asyncio.Event() + asyncio.create_task(self._start_fastapi_server()) + + async def _start_fastapi_server(self): + @asynccontextmanager + async def lifespan(app: fastapi.FastAPI): + print("FastAPI startup") + self.server_ready.set() + yield + + # There's no way to gracefully restart uvicorn server if port is already in use, + # so we exit the process directly and let AsyncLLMServerManager restart it. 
+ print("FastAPI shutdown, maybe address already in use, exit process immediately.") + os._exit(-1) + + app = fastapi.FastAPI(lifespan=lifespan) + app.router.add_api_route("/v1/chat/completions", self.chat_completion, methods=["POST"]) + + self.port = _get_free_port() + config = uvicorn.Config(app, host=["::", "0.0.0.0"], port=self.port, log_level="warning") + server = uvicorn.Server(config) + await server.serve() + + async def get_server_address(self) -> Tuple[str, int]: + """Get FastAPI server address.""" + await self.server_ready.wait() + return f"{self.address}:{self.port}" + + @abstractmethod + async def chat_completion(self, raw_request: Request): + """OpenAI chat completion API. + + API reference: https://platform.openai.com/docs/api-reference/chat/create + """ + raise NotImplementedError + + @abstractmethod + async def init_engine(self): + """Init async LLM engine.""" + raise NotImplementedError + + @abstractmethod + async def wake_up(self): + """Wake up engine to load model weights and build kv cache.""" + raise NotImplementedError + + @abstractmethod + async def sleep(self): + """Sleep engine to offload model weights and discard kv cache.""" + raise NotImplementedError + + +class ChatCompletionScheduler: + def __init__( + self, + config: DictConfig, + model_path: str, + server_addresses: List[str], + max_cache_size: int = 10000, + ): + """ + Args: + config: DictConfig, rollout config. + model_path: str, model path. + server_addresses: List[str], server addresses. + max_cache_size: int, max cache size of request_id to address mapping. + """ + self.config = config + self.model_name = "/".join(model_path.split("/")[-2:]) + local_path = copy_to_local(model_path) + self.tokenizer = hf_tokenizer(local_path, trust_remote_code=True) + + # Least requests load balancing + self.weighted_addresses = [[0, address] for address in server_addresses] + heapq.heapify(self.weighted_addresses) + + # LRU cache to map request_id to address + self.request_id_to_address = LRUCache(maxsize=max_cache_size) + + async def submit_chat_completions( + self, + callback: Callable[[ChatCompletion, Dict[str, Any], Exception], None], + callback_additional_info: Dict[str, Any], + **chat_complete_request, + ): + """ + Submit a chat completion request to the server with the least number of requests. + + Args: + callback: Callable[[ChatCompletion, Dict[str, Any], Exception], None], async callback function + to handle the response. The callback function should have the following signature: + + ```python + async def callback(completions: ChatCompletion, info: Dict[str, Any], exception: Exception): + ... + ``` + - completions: chat completion response from server. + - info: user provided `callback_additional_info`. + - exception: exception raise from OpenAI client if request failed, otherwise None. + + **CAUTION**: the callback function must be async and non-blocking, if you have any blocking operation, + please move to seperate thread or process pool to avoid blocking the event loop. + + callback_additional_info: Dict[str, Any], additional info to pass to the callback function. + + **chat_complete_request: dict, request parameters same as OpenAI AsyncCompletions.create. 
+ OpenAI API reference: https://platform.openai.com/docs/api-reference/chat/create + """ + if "extra_headers" not in chat_complete_request: + chat_complete_request["extra_headers"] = {} + + extra_headers = chat_complete_request["extra_headers"] + request_id = extra_headers.get("x-request-id", None) + if request_id: + if request_id.startswith("chatcmpl-"): + request_id = request_id[len("chatcmpl-") :] + extra_headers["x-request-id"] = request_id + + address = self.request_id_to_address[request_id] + else: + address = self.weighted_addresses[0][1] + self.weighted_addresses[0][0] += 1 + heapq.heapreplace(self.weighted_addresses, self.weighted_addresses[0]) + + request_id = uuid4().hex + self.request_id_to_address[request_id] = address + chat_complete_request["extra_headers"]["x-request-id"] = request_id + + completions, exception = None, None + try: + # TODO: OpenAI client uses httpx, seems to have performance issue in high concurrency requests. + completions = await self._chat_completions_openai(address, **chat_complete_request) + except Exception as e: + # Let user handle the exception + exception = e + + await callback(completions, callback_additional_info, exception) + + async def _chat_completions_openai(self, address: str, **chat_complete_request) -> ChatCompletion: + client = AsyncOpenAI( + base_url=f"http://{address}/v1", + api_key="token-abc123", + ) + return await client.chat.completions.create(**chat_complete_request) + + async def _chat_completions_aiohttp(self, address: str, **chat_complete_request) -> ChatCompletion: + try: + session = aiohttp.ClientSession() + async with session.post( + url=f"http://{address}/v1/chat/completions", + headers={"Authorization": "Bearer token-abc123"}, + json=chat_complete_request, + ) as resp: + data = await resp.json() + return ChatCompletion(**data) + finally: + await session.close() + + async def generate_sequences(self, prompts: DataProto, **sampling_params) -> DataProto: + raise NotImplementedError + + +class AsyncLLMServerManager: + """AsyncLLMServerManager manage a group of vllm instances, i.e AsyncvLLMServer.""" + + def __init__(self, config: DictConfig, worker_group: RayWorkerGroup): + """Initialize AsyncLLMServerManager. + + Args: + config: DictConfig, actor_rollout_ref config. + worker_group: RayWorkerGroup, worker group of AsyncActorRolloutRefWorker. + """ + self.config = config + self.worker_group = worker_group + + self.rollout_tp_size = self.config.rollout.tensor_model_parallel_size + self.rollout_dp_size = self.worker_group.world_size // self.rollout_tp_size + + register_center = ray.get_actor(f"{self.worker_group.name_prefix}_register_center") + workers_info = ray.get(register_center.get_worker_info.remote()) + assert len(workers_info) == self.worker_group.world_size + + self.async_llm_servers = [None] * self.rollout_dp_size + self.server_addresses = [None] * self.rollout_dp_size + + server_class = async_server_class( + rollout_backend=self.config.rollout.name, + ) + + # Start all server instances, restart if address already in use. 
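+        # Track which DP ranks still lack a reachable server; keep recreating their
+        # AsyncvLLMServer actors until every rank reports a usable address.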
+ unready_dp_ranks = set(range(self.rollout_dp_size)) + while len(unready_dp_ranks) > 0: + workers = { + rollout_dp_rank: server_class.options( + # make sure AsyncvLLMServer colocates with its corresponding workers + scheduling_strategy=ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy( + node_id=workers_info[rollout_dp_rank * self.rollout_tp_size], + soft=False, + ), + name=f"async_llm_worker_{rollout_dp_rank}", + ).remote(config, self.rollout_dp_size, rollout_dp_rank, self.worker_group.name_prefix) + for rollout_dp_rank in unready_dp_ranks + } + + for rollout_dp_rank, worker in workers.items(): + try: + address = ray.get(worker.get_server_address.remote()) + self.server_addresses[rollout_dp_rank] = address + self.async_llm_servers[rollout_dp_rank] = worker + unready_dp_ranks.remove(rollout_dp_rank) + except Exception: + ray.kill(worker) + print(f"worker {rollout_dp_rank} failed, maybe address already in use, restarting...") + + # All server instances are ready, init AsyncLLM engine. + ray.get([worker.init_engine.remote() for worker in self.async_llm_servers]) + + # Init user provided chat scheduler. + self.chat_scheduler = self._init_chat_scheduler() + + def _init_chat_scheduler(self) -> ChatCompletionScheduler: + module_path, class_name = self.config.rollout.chat_scheduler.rsplit(".", 1) + module = importlib.import_module(module_path) + scheduler_cls = getattr(module, class_name) + return scheduler_cls( + config=self.config.rollout, + model_path=self.config.model.path, + server_addresses=self.server_addresses, + ) + + async def wake_up(self): + """Wake up all vllm instances.""" + await asyncio.gather(*[worker.wake_up.remote() for worker in self.async_llm_servers]) + + async def sleep(self): + """Sleep all vllm instances.""" + await asyncio.gather(*[worker.sleep.remote() for worker in self.async_llm_servers]) + + async def generate_sequences(self, prompts: DataProto, **sampling_params) -> DataProto: + """Generate sequences via chat scheduler.""" + return await self.chat_scheduler.generate_sequences(prompts, **sampling_params) + + +def async_server_class(rollout_backend: str) -> Type[AsyncServerBase]: + """Get async server class. + + Args: + rollout_backend: str, rollout backend, should be "vllm" or "sglang". + + Returns: + Type[AsyncServerBase]: async server class. + """ + if rollout_backend == "vllm": + from verl.workers.rollout.vllm_rollout.vllm_async_server import AsyncvLLMServer + + return AsyncvLLMServer + elif rollout_backend == "sglang": + raise NotImplementedError + else: + raise NotImplementedError diff --git a/verl/workers/rollout/vllm_rollout/__init__.py b/verl/workers/rollout/vllm_rollout/__init__.py index 1c0fe1255a4..66da822ea86 100644 --- a/verl/workers/rollout/vllm_rollout/__init__.py +++ b/verl/workers/rollout/vllm_rollout/__init__.py @@ -11,15 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import os from importlib.metadata import PackageNotFoundError, version -### -# [SUPPORT AMD:] -import torch - -### - def get_version(pkg): try: @@ -34,7 +28,8 @@ def get_version(pkg): ### # package_version = get_version(package_name) # [SUPPORT AMD:] -if "AMD" in torch.cuda.get_device_name(): +# Do not call any torch.cuda* API here, or ray actor creation import class will fail. 
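+# ROCM_PATH is used as a CUDA-free heuristic for detecting an AMD/ROCm environment.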
+if "ROCM_PATH" in os.environ: import re package_version = version(package_name) @@ -45,8 +40,8 @@ def get_version(pkg): if package_version <= "0.6.3": vllm_mode = "customized" - from .fire_vllm_rollout import FIREvLLMRollout - from .vllm_rollout import vLLMRollout + from .fire_vllm_rollout import FIREvLLMRollout # noqa: F401 + from .vllm_rollout import vLLMRollout # noqa: F401 else: vllm_mode = "spmd" - from .vllm_rollout_spmd import vLLMRollout + from .vllm_rollout_spmd import vLLMAsyncRollout, vLLMRollout # noqa: F401 diff --git a/verl/workers/rollout/vllm_rollout/vllm_async_server.py b/verl/workers/rollout/vllm_rollout/vllm_async_server.py new file mode 100644 index 00000000000..4f551c07116 --- /dev/null +++ b/verl/workers/rollout/vllm_rollout/vllm_async_server.py @@ -0,0 +1,243 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import cloudpickle +import ray +from omegaconf import DictConfig +from starlette.requests import Request +from starlette.responses import JSONResponse, StreamingResponse +from vllm import SamplingParams +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.entrypoints.openai.protocol import ChatCompletionRequest, ChatCompletionResponse, ErrorResponse +from vllm.entrypoints.openai.serving_chat import OpenAIServingChat +from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels +from vllm.v1.engine.async_llm import AsyncLLM +from vllm.v1.executor.abstract import Executor +from vllm.worker.worker_base import WorkerWrapperBase + +from verl.utils.fs import copy_to_local +from verl.workers.rollout.async_server import AsyncServerBase + +logger = logging.getLogger(__file__) + + +class ExternalRayDistributedExecutor(Executor): + """An executor that engines are launched by external ray actors.""" + + uses_ray: bool = False + + def _init_executor(self) -> None: + assert self.vllm_config.instance_id is not None, "instance_id must be set for external ray actors." + + fields = self.vllm_config.instance_id.split(":") + assert len(fields) == 4, ( + f"instance_id: {self.vllm_config.instance_id} must be in " + f"the format of :::." + ) + namespace, wg_prefix, vllm_dp_size, vllm_dp_rank = fields[0], fields[1], int(fields[2]), int(fields[3]) + + # Make sure subprocess in same namespace as parent actor. + # actor name format: {name_prefix}WorkerDict_{pg_idx}:{local_rank} + ray.init(namespace=namespace) + actor_names = [ + actor_name for actor_name in ray.util.list_named_actors() if actor_name.startswith(f"{wg_prefix}WorkerDict") + ] + + vllm_tp_size = self.vllm_config.parallel_config.tensor_parallel_size + assert len(actor_names) == vllm_dp_size * vllm_tp_size, ( + f"instance_id: {self.vllm_config.instance_id} has {len(actor_names)} actors, " + f"but vllm_dp_size: {vllm_dp_size} * vllm_tp_size: {vllm_tp_size} = " + f"{vllm_dp_size * vllm_tp_size} is expected." 
+ )
+
+ def get_pg_index_and_local_rank(actor_name) -> Tuple[int, int]:
+ fields = actor_name.split(":")
+ assert len(fields) == 2, f"invalid actor name: {actor_name}"
+ pg_index, local_rank = int(fields[0].split("_")[-1]), int(fields[1])
+ return pg_index, local_rank
+
+ # sort actor names by pg_index and local_rank
+ actor_names = sorted(actor_names, key=get_pg_index_and_local_rank)
+ actor_names = actor_names[vllm_dp_rank * vllm_tp_size : (vllm_dp_rank + 1) * vllm_tp_size]
+ self.workers: List[WorkerWrapperBase] = [ray.get_actor(actor_name) for actor_name in actor_names]
+ print(f"instance_id: {self.vllm_config.instance_id} initializes with external actors: {actor_names}")
+
+ kwargs = dict(
+ vllm_config=self.vllm_config,
+ local_rank=None,
+ rank=None,
+ distributed_init_method="env://",
+ is_driver_worker=True,
+ )
+ self.collective_rpc("init_worker", args=([kwargs],))
+ self.collective_rpc("init_device")
+ self.collective_rpc("load_model")
+ print(f"instance_id: {self.vllm_config.instance_id} initialization finished.")
+
+ def collective_rpc(
+ self,
+ method: Union[str, Callable],
+ timeout: Optional[float] = None,
+ args: Tuple = (),
+ kwargs: Optional[Dict[str, Any]] = None,
+ ) -> List[Any]:
+ # TODO(wuxibin): support ray compiled graph
+ if isinstance(method, str):
+ sent_method = method
+ else:
+ sent_method = cloudpickle.dumps(method)
+ del method
+
+ outputs = ray.get(
+ [worker.execute_method.remote(sent_method, *args, **(kwargs or {})) for worker in self.workers]
+ )
+ return outputs
+
+ def check_health(self):
+ return
+
+
+@ray.remote(num_cpus=1)
+class AsyncvLLMServer(AsyncServerBase):
+ """
+ AsyncvLLMServer is a wrapper around AsyncLLM; it uses ExternalRayDistributedExecutor to launch engines
+ in hybrid rollout workers, i.e. AsyncActorRolloutRefWorker.
+
+ AsyncvLLMServer works as follows:
+ 1. Start the FastAPI server first.
+ 2. Initialize AsyncLLM with ExternalRayDistributedExecutor.
+ 3. AsyncLLM spawns EngineCore in a subprocess.
+ 4. EngineCore initializes ExternalRayDistributedExecutor.
+ 5. ExternalRayDistributedExecutor looks up its corresponding actors by name.
+ 6. ExternalRayDistributedExecutor initializes the executor: init_worker, init_device, load_model.
+
+ For vLLM AsyncLLM design, see: https://github.com/vllm-project/vllm/pull/9826
+ """
+
+ def __init__(self, config: DictConfig, vllm_dp_size: int, vllm_dp_rank: int, wg_prefix: str):
+ """
+ Args:
+ config: DictConfig, actor_rollout_ref config.
+ vllm_dp_size: int, vllm data parallel size.
+ vllm_dp_rank: int, vllm data parallel rank.
+ wg_prefix: str, worker group prefix, used to look up actors.
+ """ + super().__init__() + + self.config = config + self.vllm_dp_size = vllm_dp_size + self.vllm_dp_rank = vllm_dp_rank + self.wg_prefix = wg_prefix + self.engine: AsyncLLM = None + + async def init_engine(self): + """Init vLLM AsyncLLM engine.""" + config = self.config + model_path = config.model.path + model_name = "/".join(model_path.split("/")[-2:]) + local_path = copy_to_local(model_path) + trust_remote_code = config.model.get("trust_remote_code", False) + config = config.rollout + + tensor_parallel_size = config.get("tensor_model_parallel_size", 1) + max_num_batched_tokens = config.get("max_num_batched_tokens", 8192) + max_model_len = config.max_model_len if config.max_model_len else config.prompt_length + config.response_length + max_model_len = int(max_model_len) + + if max_num_batched_tokens < max_model_len and config.enable_chunked_prefill: + raise ValueError( + "Enable chunked prefill, max_num_batched_tokens is smaller than max_model_len, \ + please increase max_num_batched_tokens or disable chunked prefill" + ) + + # Override default generation config from hugging face model config, + # user can still override them by passing kwargs in each request. + kwargs = dict( + n=1, + logprobs=0, + max_tokens=config.response_length, + ) + for k in config.keys(): + if hasattr(SamplingParams(), str(k)): + kwargs[k] = config.get(k) + print(f"override_generation_config: {kwargs}") + + engine_args = AsyncEngineArgs( + model=local_path, + enable_sleep_mode=True, + override_generation_config=kwargs, + tensor_parallel_size=tensor_parallel_size, + distributed_executor_backend=ExternalRayDistributedExecutor, + dtype=config.dtype, + enforce_eager=config.enforce_eager, + gpu_memory_utilization=config.gpu_memory_utilization, + disable_custom_all_reduce=True, + disable_mm_preprocessor_cache=True, + skip_tokenizer_init=False, + max_model_len=max_model_len, + load_format="auto", + disable_log_stats=config.disable_log_stats, + max_num_batched_tokens=max_num_batched_tokens, + enable_chunked_prefill=config.enable_chunked_prefill, + enable_prefix_caching=True, + trust_remote_code=trust_remote_code, + seed=self.vllm_dp_rank, + ) + + # init async llm engine + vllm_config = engine_args.create_engine_config() + namespace = ray.get_runtime_context().namespace + vllm_config.instance_id = f"{namespace}:{self.wg_prefix}:{self.vllm_dp_size}:{self.vllm_dp_rank}" + self.engine = AsyncLLM.from_vllm_config(vllm_config) + + # build serving chat + model_config = self.engine.model_config + BASE_MODEL_PATHS = [BaseModelPath(name=model_name, model_path=model_path)] + models = OpenAIServingModels(self.engine, model_config, BASE_MODEL_PATHS) + self.openai_serving_chat = OpenAIServingChat( + self.engine, + model_config, + models, + "assistant", + request_logger=None, + chat_template=None, + chat_template_content_format="auto", + ) + + async def chat_completion(self, raw_request: Request): + """OpenAI-compatible HTTP endpoint. 
+ + API reference: https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html + """ + request_json = await raw_request.json() + request = ChatCompletionRequest(**request_json) + generator = await self.openai_serving_chat.create_chat_completion(request, raw_request) + + if isinstance(generator, ErrorResponse): + return JSONResponse(content=generator.model_dump(), status_code=generator.code) + if request.stream: + return StreamingResponse(content=generator, media_type="text/event-stream") + else: + assert isinstance(generator, ChatCompletionResponse) + return JSONResponse(content=generator.model_dump()) + + async def wake_up(self): + await self.engine.wake_up() + + async def sleep(self): + # TODO: https://github.com/vllm-project/vllm/issues/17103 + await self.engine.reset_prefix_cache() + await self.engine.sleep() diff --git a/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py b/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py index 236a321a682..9b9567cca37 100644 --- a/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py +++ b/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py @@ -19,7 +19,8 @@ When working with Megatron: - Use Megatron weight loader - During training, only the current pp stage holds the parameters -- Before inference, broadcast the parameters of the current pp rank to all other pp ranks (all pp ranks holds all the parameters) +- Before inference, broadcast the parameters of the current pp rank + to all other pp ranks (all pp ranks holds all the parameters) - Bind the parameters to the inference engine - Do inference in tp. pp is treated as additional dp - After inference, all the parameters that doesn't belong to this pp rank is freed. @@ -28,7 +29,7 @@ import logging import os from contextlib import contextmanager -from typing import Any, List, Union +from typing import Any, Dict, List, Union import numpy as np import torch @@ -37,12 +38,14 @@ from tensordict import TensorDict from vllm import LLM, SamplingParams from vllm.distributed import parallel_state as vllm_ps +from vllm.worker.worker_base import WorkerWrapperBase from verl import DataProto from verl.third_party.vllm import vllm_version from verl.utils.debug import GPUMemoryLogger from verl.utils.torch_functional import get_response_mask, pad_2d_list_to_length from verl.workers.rollout.base import BaseRollout +from verl.workers.sharding_manager import FSDPVLLMShardingManager logger = logging.getLogger(__file__) logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) @@ -56,7 +59,8 @@ # NOTE(sgm): add for verl. We can optimize it by making the dataloader yield List[int] without padding. 
def _pre_process_inputs(pad_token_id, prompt_token_ids: torch.Tensor) -> List[int]:
 # remove the left padding in the prompt token_id
- # pad_token_id = self.llm_engine.tokenizer.pad_token_id if self.llm_engine.tokenizer.pad_token_id is not None else self.llm_engine.tokenizer.eos_token_id
+ # pad_token_id = self.llm_engine.tokenizer.pad_token_id if self.llm_engine.tokenizer.pad_token_id
+ # is not None else self.llm_engine.tokenizer.eos_token_id
 non_pad_index = torch.nonzero(prompt_token_ids != pad_token_id, as_tuple=False)[0][0]
 token_ids = prompt_token_ids[non_pad_index:].tolist()
 return token_ids
@@ -338,3 +342,57 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:
 self.inference_engine.free_cache_engine()
 return DataProto(batch=batch, non_tensor_batch=non_tensor_batch)
+
+
+class vLLMAsyncRollout:
+ """vLLMAsyncRollout is a thin wrapper around WorkerWrapperBase,
+ which is the engine in a single worker process.
+ """
+
+ def __init__(self, *args, **kwargs):
+ # Engine is deferred to be initialized in init_worker
+ self.inference_engine: WorkerWrapperBase = None
+ self.sharding_manager: FSDPVLLMShardingManager = None
+ self.is_sleep = False
+
+ def init_worker(self, all_kwargs: List[Dict[str, Any]]):
+ """Initialize worker engine."""
+ all_kwargs[0]["rank"] = int(os.environ["RANK"])
+ all_kwargs[0]["local_rank"] = 0
+
+ self.vllm_config = all_kwargs[0]["vllm_config"]
+ self.inference_engine = WorkerWrapperBase(vllm_config=self.vllm_config)
+ self.inference_engine.init_worker(all_kwargs)
+
+ def load_model(self, *args, **kwargs):
+ self.inference_engine.load_model(*args, **kwargs)
+
+ # inference engine is initialized now, update sharding manager
+ self.sharding_manager.inference_engine = self.inference_engine
+ self.sharding_manager.model_runner = self.inference_engine.worker.model_runner
+
+ def sleep(self, *args, **kwargs):
+ """Offload model weights and discard kv cache."""
+ if self.is_sleep:
+ return
+ self.sharding_manager.__exit__(None, None, None)
+ self.is_sleep = True
+
+ def wake_up(self, *args, **kwargs):
+ """Load model weights and build kv cache."""
+ if not self.is_sleep:
+ return
+ self.sharding_manager.__enter__() # pylint: disable=C2801
+ self.is_sleep = False
+
+ def execute_method(self, method: Union[str, bytes], *args, **kwargs):
+ if method == "init_worker":
+ return self.init_worker(*args, **kwargs)
+ elif method == "load_model":
+ return self.load_model(*args, **kwargs)
+ elif method == "sleep":
+ return self.sleep(*args, **kwargs)
+ elif method == "wake_up":
+ return self.wake_up(*args, **kwargs)
+ else:
+ return self.inference_engine.execute_method(method, *args, **kwargs)
diff --git a/verl/workers/sharding_manager/fsdp_sglang.py b/verl/workers/sharding_manager/fsdp_sglang.py
index c9b6222883c..4037b573e28 100644
--- a/verl/workers/sharding_manager/fsdp_sglang.py
+++ b/verl/workers/sharding_manager/fsdp_sglang.py
@@ -37,6 +37,7 @@
 from verl import DataProto
 from verl.protocol import all_gather_data_proto
 from verl.utils.debug import log_gpu_memory_usage
+from verl.utils.fsdp_utils import load_fsdp_model_to_gpu, offload_fsdp_model_to_cpu
 from verl.utils.torch_functional import broadcast_dict_tensor
 from .base import BaseShardingManager
@@ -55,11 +56,13 @@ def __init__(
 model_config,
 full_params: bool = False,
 device_mesh: DeviceMesh = None,
+ offload_param: bool = False,
 ):
 self.module = module
 self.inference_engine = inference_engine
 self.model_config = model_config
 self.device_mesh = device_mesh
+ self.offload_param = offload_param
 # Full params
self.full_params = full_params @@ -88,6 +91,8 @@ def __init__( def __enter__(self): torch.cuda.empty_cache() log_gpu_memory_usage("Before state_dict() in sharding manager memory", logger=logger) + if self.offload_param: + load_fsdp_model_to_gpu(self.module) params = self.module.state_dict() log_gpu_memory_usage("After state_dict() in sharding manager memory", logger=logger) # Copy, not share memory @@ -98,15 +103,11 @@ def __enter__(self): log_gpu_memory_usage("After sync model weights in sharding manager", logger=logger) del params + if self.offload_param: + offload_fsdp_model_to_cpu(self.module) torch.cuda.empty_cache() log_gpu_memory_usage("After del state_dict and empty_cache in sharding manager", logger=logger) - # TODO: offload FSDP model weights - # self.module.cpu() - # torch.cuda.empty_cache() - # if torch.distributed.get_rank() == 0: - # print(f'after model to cpu in sharding manager memory allocated: {torch.cuda.memory_allocated() / 1e9}GB, reserved: {torch.cuda.memory_reserved() / 1e9}GB') - # important: need to manually set the random states of each tp to be identical. if self.device_mesh is not None: self.torch_random_states = torch.cuda.get_rng_state() diff --git a/verl/workers/sharding_manager/fsdp_vllm.py b/verl/workers/sharding_manager/fsdp_vllm.py index 060ca349582..1a9b587f6d0 100644 --- a/verl/workers/sharding_manager/fsdp_vllm.py +++ b/verl/workers/sharding_manager/fsdp_vllm.py @@ -18,17 +18,15 @@ import torch from torch.distributed.device_mesh import DeviceMesh -from torch.distributed.fsdp.api import (FullStateDictConfig, - ShardedStateDictConfig, StateDictType) -from torch.distributed.fsdp.fully_sharded_data_parallel import \ - FullyShardedDataParallel as FSDP +from torch.distributed.fsdp.api import FullStateDictConfig, ShardedStateDictConfig, StateDictType +from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP from verl import DataProto from verl.protocol import all_gather_data_proto -from verl.third_party.vllm import LLM +from verl.third_party.vllm import LLM, vllm_version from verl.third_party.vllm import parallel_state as vllm_ps -from verl.third_party.vllm import vllm_version from verl.utils.debug import GPUMemoryLogger, log_gpu_memory_usage +from verl.utils.fsdp_utils import load_fsdp_model_to_gpu, offload_fsdp_model_to_cpu from verl.utils.vllm_utils import patch_vllm_moe_model_weight_loader from .base import BaseShardingManager @@ -45,11 +43,17 @@ def __init__( model_config, full_params: bool = False, device_mesh: DeviceMesh = None, + offload_param: bool = False, ): self.module = module + # For AsyncLLM, inference_engine and model_runner are defer intialized in vLLMAsyncRollout.load_model self.inference_engine = inference_engine + self.model_runner = ( + inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner if inference_engine else None + ) self.model_config = model_config self.device_mesh = device_mesh + self.offload_param = offload_param # Full params self.full_params = full_params @@ -64,8 +68,8 @@ def __init__( state_dict_config=ShardedStateDictConfig(), ) - self.tp_size = vllm_ps.get_tensor_model_parallel_world_size() - self.tp_rank = vllm_ps.get_tensor_model_parallel_rank() + self.tp_size = self.device_mesh["infer_tp"].size() + self.tp_rank = self.device_mesh["infer_tp"].get_local_rank() # Note that torch_random_states may be different on each dp rank self.torch_random_states = torch.cuda.get_rng_state() @@ -90,6 +94,8 @@ def __enter__(self): torch.cuda.empty_cache() 
log_gpu_memory_usage("Before state_dict() in sharding manager memory", logger=logger) + if self.offload_param: + load_fsdp_model_to_gpu(self.module) params = self.module.state_dict() log_gpu_memory_usage("After state_dict() in sharding manager memory", logger=logger) # Copy, not share memory @@ -112,6 +118,8 @@ def __enter__(self): self.update_params(params) log_gpu_memory_usage("After sync model weights in sharding manager", logger=logger) del params + if self.offload_param: + offload_fsdp_model_to_cpu(self.module) torch.cuda.empty_cache() if "tags" in inspect.signature(self.inference_engine.wake_up).parameters: @@ -119,12 +127,6 @@ def __enter__(self): log_gpu_memory_usage("After del state_dict and empty_cache in sharding manager", logger=logger) - # TODO: offload FSDP model weights - # self.module.cpu() - # torch.cuda.empty_cache() - # if torch.distributed.get_rank() == 0: - # print(f'after model to cpu in sharding manager memory allocated: {torch.cuda.memory_allocated() / 1e9}GB, reserved: {torch.cuda.memory_reserved() / 1e9}GB') - # important: need to manually set the random states of each tp to be identical. if self.device_mesh is not None: self.torch_random_states = torch.cuda.get_rng_state() @@ -141,10 +143,6 @@ def __exit__(self, exc_type, exc_value, traceback): else: self.inference_engine.sleep(level=1) - # self.module.to('cuda') - # if torch.distributed.get_rank() == 0: - # print(f'after actor module to cuda in sharding manager memory allocated: {torch.cuda.memory_allocated() / 1e9}GB, reserved: {torch.cuda.memory_reserved() / 1e9}GB') - self.module.train() # add empty cache after each compute @@ -182,10 +180,13 @@ def postprocess_data(self, data: DataProto) -> DataProto: return data.chunk(chunks=self.tp_size)[self.tp_rank] def update_params(self, updated_params): - model = self.inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner.model + model = self.model_runner.model patch_vllm_moe_model_weight_loader(model) world_size = torch.distributed.get_world_size() loaded_params = model.load_weights( - ((name, param.full_tensor() if world_size != 1 and hasattr(param, "full_tensor") else param) for name, param in updated_params.items()) + ( + (name, param.full_tensor() if world_size != 1 and hasattr(param, "full_tensor") else param) + for name, param in updated_params.items() + ) ) - logger.info(f"vLLM load weights, loaded_params: {len(loaded_params)}") + logger.info("vLLM load weights, loaded_params: %d", len(loaded_params)) diff --git a/verl/workers/sharding_manager/megatron_vllm.py b/verl/workers/sharding_manager/megatron_vllm.py index e919ba634fb..74f827b0c2e 100644 --- a/verl/workers/sharding_manager/megatron_vllm.py +++ b/verl/workers/sharding_manager/megatron_vllm.py @@ -36,11 +36,17 @@ from verl.third_party.vllm import parallel_state as vllm_ps from verl.utils.debug import GPUMemoryLogger from verl.utils.megatron_utils import ( - broadcast_from_megatron_pp, broadcast_str_from_megatron_pp, - convert_megatron_model_to_transformers_model, get_model, unwrap_model) -from verl.utils.memory_buffer import (build_memory_buffer, - build_memory_reference_from_module, - get_weight_buffer_meta_from_module) + broadcast_from_megatron_pp, + broadcast_str_from_megatron_pp, + convert_megatron_model_to_transformers_model, + get_model, + unwrap_model, +) +from verl.utils.memory_buffer import ( + build_memory_buffer, + build_memory_reference_from_module, + get_weight_buffer_meta_from_module, +) from verl.utils.model import normalize_model_name from verl.utils.torch_functional 
import allgather_dict_tensors from verl.utils.vllm_utils import patch_vllm_moe_model_weight_loader
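
A minimal usage sketch of the wake_up -> generate_sequences -> sleep cycle exposed by AsyncLLMServerManager, assuming `server_manager` is an already constructed AsyncLLMServerManager and `gen_batch` is a DataProto whose non_tensor_batch carries "raw_prompt" (i.e. data.return_raw_chat=True); names other than the manager's own methods are placeholders:

import asyncio


async def rollout_step(server_manager, gen_batch):
    # Load weights and rebuild the kv cache on every vLLM instance.
    await server_manager.wake_up()
    try:
        # Extra sampling params are forwarded to the user-provided chat scheduler
        # configured via actor_rollout_ref.rollout.chat_scheduler.
        output = await server_manager.generate_sequences(gen_batch, temperature=1.0, top_p=1.0)
    finally:
        # Offload weights and free the kv cache again.
        await server_manager.sleep()
    return output


# e.g. output = asyncio.run(rollout_step(server_manager, gen_batch))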
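
Each AsyncvLLMServer also exposes a plain OpenAI-compatible /v1/chat/completions endpoint, which is what ChatCompletionScheduler._chat_completions_openai calls; sticky routing via the "x-request-id" header is handled inside submit_chat_completions. A minimal sketch of querying one server instance directly, where the address and model name below are placeholders (in practice they come from AsyncLLMServerManager.server_addresses and the rollout model config):

import asyncio

from openai import AsyncOpenAI


async def query_server(address: str = "127.0.0.1:12345", model: str = "Qwen/Qwen2.5-0.5B-Instruct"):
    # Same base_url / dummy api_key convention as _chat_completions_openai above.
    client = AsyncOpenAI(base_url=f"http://{address}/v1", api_key="token-abc123")
    completion = await client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": "What is 1 + 1?"}],
        max_completion_tokens=64,
    )
    return completion.choices[0].message.content


# e.g. print(asyncio.run(query_server()))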