diff --git a/.github/workflows/reward_model_sglang.yml b/.github/workflows/reward_model_sglang.yml index 646bad42465..cce6083bbf8 100644 --- a/.github/workflows/reward_model_sglang.yml +++ b/.github/workflows/reward_model_sglang.yml @@ -98,7 +98,7 @@ jobs: - name: Install the current repository run: | pip3 install -e .[test] - pip3 install sglang-router==0.2.2 + pip3 install sglang-router==0.1.8 - name: Prepare gsm8k dataset run: | ray stop --force @@ -111,10 +111,6 @@ jobs: run: | unset http_proxy https_proxy HTTP_PROXY HTTPS_PROXY ROLLOUT_NAME=sglang pytest -s -x tests/experimental/reward/test_agent_loop_reward_manager.py - - name: Running sglang agent loop with reward model colocate tests on 8 L20 GPUs - run: | - unset http_proxy https_proxy HTTP_PROXY HTTPS_PROXY - ROLLOUT_NAME=sglang pytest -s -x tests/experimental/reward/test_agent_reward_loop_colocate.py cleanup: runs-on: ubuntu-latest diff --git a/.github/workflows/reward_model_vllm.yml b/.github/workflows/reward_model_vllm.yml index ef9c7da113f..0a15a5f63b5 100644 --- a/.github/workflows/reward_model_vllm.yml +++ b/.github/workflows/reward_model_vllm.yml @@ -110,10 +110,6 @@ jobs: run: | unset http_proxy https_proxy HTTP_PROXY HTTPS_PROXY ROLLOUT_NAME=vllm pytest -s -x tests/experimental/reward/test_agent_loop_reward_manager.py - - name: Running vllm agent loop with reward model colocate tests on 8 L20 GPUs - run: | - unset http_proxy https_proxy HTTP_PROXY HTTPS_PROXY - ROLLOUT_NAME=vllm pytest -s -x tests/experimental/reward/test_agent_reward_loop_colocate.py cleanup: runs-on: ubuntu-latest diff --git a/tests/experimental/agent_loop/agent_utils.py b/tests/experimental/agent_loop/agent_utils.py index f85e9aa341a..fa4504c6af0 100644 --- a/tests/experimental/agent_loop/agent_utils.py +++ b/tests/experimental/agent_loop/agent_utils.py @@ -79,14 +79,15 @@ def init_agent_loop_manager(config: DictConfig) -> AgentLoopManager | RayWorkerG return actor_rollout_wg if config.reward_model.enable_resource_pool and 
config.reward_model.enable: - rm_resource_pool = resource_pool_manager.get_resource_pool(Role.RewardModel) + rm_wg = all_wg["rm"] + rm_wg.init_model() else: - rm_resource_pool = None + rm_wg = None # =========================== 2. Create AgentLoopManager =========================== agent_loop_manager = AgentLoopManager( config=config, worker_group=actor_rollout_wg, - rm_resource_pool=rm_resource_pool, + rm_wg=rm_wg, ) return agent_loop_manager diff --git a/tests/experimental/reward/test_agent_reward_loop_colocate.py b/tests/experimental/reward/test_agent_reward_loop_colocate.py deleted file mode 100644 index ad5110b6e60..00000000000 --- a/tests/experimental/reward/test_agent_reward_loop_colocate.py +++ /dev/null @@ -1,136 +0,0 @@ -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-import os - -import ray -from hydra import compose, initialize_config_dir -from torchdata.stateful_dataloader import StatefulDataLoader -from transformers import AutoTokenizer - -from verl.experimental.agent_loop import AgentLoopManager -from verl.experimental.reward.reward_model import RewardModelManager -from verl.protocol import DataProto -from verl.single_controller.ray import RayClassWithInitArgs, RayWorkerGroup -from verl.trainer.main_ppo import create_rl_sampler -from verl.trainer.ppo.ray_trainer import ResourcePoolManager -from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn -from verl.workers.fsdp_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker - - -def test_agent_loop_reward_manager(): - ray.init( - runtime_env={ - "env_vars": { - "TOKENIZERS_PARALLELISM": "true", - "NCCL_DEBUG": "WARN", - "VLLM_LOGGING_LEVEL": "INFO", - "VLLM_USE_V1": "1", - } - } - ) - with initialize_config_dir(config_dir=os.path.abspath("recipe/fapo/config")): - config = compose("rm_config") - - rollout_model_path = os.path.expanduser("~/models/Qwen/Qwen2.5-0.5B-Instruct") - reward_model_path = os.path.expanduser("~/models/Qwen/Qwen2.5-1.5B-Instruct") - - # actor_rollout_ref config - config.data.return_raw_chat = True - config.data.max_prompt_length = 1024 - config.data.max_response_length = 4096 - config.actor_rollout_ref.model.path = rollout_model_path - config.actor_rollout_ref.actor.use_dynamic_bsz = True - config.actor_rollout_ref.rollout.name = os.getenv("ROLLOUT_NAME", "vllm") - config.actor_rollout_ref.rollout.mode = "async" - config.actor_rollout_ref.rollout.tensor_model_parallel_size = 2 - config.actor_rollout_ref.rollout.gpu_memory_utilization = 0.8 - config.actor_rollout_ref.rollout.enforce_eager = True - config.actor_rollout_ref.rollout.prompt_length = 1024 - config.actor_rollout_ref.rollout.response_length = 4096 - config.actor_rollout_ref.rollout.skip_tokenizer_init = True - config.trainer.n_gpus_per_node = 8 - config.trainer.nnodes = 1 - 
- config.reward_model.reward_manager = "dapo" - config.reward_model.enable = True - config.reward_model.enable_resource_pool = False - config.reward_model.n_gpus_per_node = 8 - config.reward_model.model.path = reward_model_path - config.reward_model.rollout.name = os.getenv("ROLLOUT_NAME", "vllm") - config.reward_model.rollout.gpu_memory_utilization = 0.8 - config.reward_model.rollout.tensor_model_parallel_size = 2 - config.reward_model.rollout.skip_tokenizer_init = False - config.reward_model.rollout.prompt_length = 5120 - config.reward_model.rollout.response_length = 4096 - config.custom_reward_function.path = "tests/experimental/reward/reward_fn.py" - config.custom_reward_function.name = "compute_score_gsm8k" - - # 1. init reward model manager - actor_rollout_cls = ( - AsyncActorRolloutRefWorker if config.actor_rollout_ref.rollout.mode == "async" else ActorRolloutRefWorker - ) - global_pool_id = "global_pool" - resource_pool_spec = { - global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes, - } - resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=None) - resource_pool_manager.create_resource_pool() - resource_pool = resource_pool_manager.resource_pool_dict[global_pool_id] - actor_rollout_cls = RayClassWithInitArgs( - cls=ray.remote(actor_rollout_cls), config=config.actor_rollout_ref, role="actor_rollout" - ) - actor_rollout_wg = RayWorkerGroup( - resource_pool=resource_pool, - ray_cls_with_init=actor_rollout_cls, - ) - actor_rollout_wg.init_model() - - agent_loop_manager = AgentLoopManager(config, worker_group=actor_rollout_wg) - reward_model_manager = RewardModelManager(config.reward_model, resource_pool=resource_pool) - - # 2. 
init test data - local_folder = os.path.expanduser("~/data/gsm8k/") - data_files = [os.path.join(local_folder, "train.parquet")] - tokenizer = AutoTokenizer.from_pretrained(rollout_model_path) - - dataset = RLHFDataset( - data_files=data_files, - tokenizer=tokenizer, - config=config.data, - processor=None, - ) - - batch_size = 64 - sampler = create_rl_sampler(config.data, dataset) - dataloader = StatefulDataLoader( - dataset=dataset, - batch_size=batch_size, - num_workers=config.data.dataloader_num_workers, - drop_last=True, - collate_fn=collate_fn, - sampler=sampler, - ) - - # 3. generate responses - batch_dict = next(iter(dataloader)) - batch = DataProto.from_single_dict(batch_dict) - gen_batch = agent_loop_manager.generate_sequences(prompts=batch) - sampling_params = {"temperature": 0.0, "top_p": 1.0, "max_tokens": 1024} - genrm_outputs = reward_model_manager.generate_sequences(gen_batch, sampling_params=sampling_params) - - print(genrm_outputs[0]) - - print("done") - - ray.shutdown() diff --git a/verl/experimental/agent_loop/agent_loop.py b/verl/experimental/agent_loop/agent_loop.py index b758c95ff96..8650f6c4660 100644 --- a/verl/experimental/agent_loop/agent_loop.py +++ b/verl/experimental/agent_loop/agent_loop.py @@ -34,7 +34,7 @@ from verl.experimental.agent_loop.utils import resolve_config_path from verl.experimental.reward import RewardManagerWorker from verl.protocol import DataProto -from verl.single_controller.ray.base import RayResourcePool, RayWorkerGroup +from verl.single_controller.ray.base import RayWorkerGroup from verl.utils import hf_processor, hf_tokenizer from verl.utils.fs import copy_to_local from verl.utils.model import compute_position_id_with_mask @@ -686,15 +686,12 @@ async def get_trajectory_info(step, index, validate): class AgentLoopManager: """Agent loop manager that manages a group of agent loop workers.""" - def __init__( - self, config: DictConfig, worker_group: RayWorkerGroup = None, rm_resource_pool: RayResourcePool = None - ): 
+ def __init__(self, config: DictConfig, worker_group: RayWorkerGroup = None, rm_wg: RayWorkerGroup = None): """Initialize agent loop manager. Args: config (DictConfig): trainer config. worker_group (RayWorkerGroup): ActorRolloutRef worker group for hybrid mode; None for standalone mode. - rm_resource_pool (RayResourcePool): Resource pool for reward model (Standalone mode). """ self.config = config self.worker_group = worker_group @@ -703,9 +700,7 @@ def __init__( if self.config.reward_model.enable and self.config.reward_model.enable_resource_pool: from verl.experimental.reward import RewardModelManager - # TODO (dyy): current rm is colocated with the legacy fsdp/megatron rm - # future pr will depericate fsdp/megatron rm and init RewardModelManager in standalone mode - self.reward_model_manager = RewardModelManager(config.reward_model, rm_resource_pool) + self.reward_model_manager = RewardModelManager(config.reward_model, rm_wg) self.reward_router_address = self.reward_model_manager.get_router_address() # for recipe to change diff --git a/verl/experimental/reward/reward_model.py b/verl/experimental/reward/reward_model.py index 7711f361ea3..25b5b37b5ea 100644 --- a/verl/experimental/reward/reward_model.py +++ b/verl/experimental/reward/reward_model.py @@ -21,7 +21,7 @@ from openai.types.chat import ChatCompletion from verl import DataProto -from verl.single_controller.ray.base import RayResourcePool +from verl.single_controller.ray.base import RayWorkerGroup from verl.workers.config import HFModelConfig, RewardModelConfig from verl.workers.rollout.replica import get_rollout_replica_class @@ -32,16 +32,16 @@ class RewardModelManager: """Reward model manager.""" - def __init__(self, config: RewardModelConfig, resource_pool: RayResourcePool = None): + def __init__(self, config: RewardModelConfig, worker_group: RayWorkerGroup = None): """ Initialize the reward model manager. Args: config (RewardModelConfig): Reward model configuration. 
- resource_pool (RayResourcePool, optional): Resource pool. Defaults to None. + worker_group (RayWorkerGroup, optional): Worker group. Defaults to None. """ self.config = config - self.resource_pool = resource_pool + self.worker_group = worker_group self._initialize_llm_servers() self._initialize_router() if self.config.rollout.free_cache_engine: @@ -50,8 +50,8 @@ def __init__(self, config: RewardModelConfig, resource_pool: RayResourcePool = N def _initialize_llm_servers(self): rollout_world_size = self.config.rollout.tensor_model_parallel_size world_size = ( - self.resource_pool.world_size - if self.resource_pool # colocate mode + self.worker_group.world_size + if self.worker_group # colocate mode else self.config.n_gpus_per_node * self.config.nnodes # standalone mode ) num_replicas = world_size // rollout_world_size @@ -74,11 +74,10 @@ def _initialize_llm_servers(self): ) for replica_rank in range(num_replicas) ] - if self.resource_pool: - self._run_all([server.init_colocated(self.resource_pool) for server in self.rollout_replicas]) + if self.worker_group: + self._run_all([server.init_colocated(self.worker_group) for server in self.rollout_replicas]) else: self._run_all([server.init_standalone() for server in self.rollout_replicas]) - self.server_handles = [server._server_handle for server in self.rollout_replicas] self.server_addresses = [server._server_address for server in self.rollout_replicas] @@ -124,8 +123,6 @@ async def chat_complete(self, chat_complete_request: dict): await session.close() def generate_sequences(self, prompts: DataProto, sampling_params: dict): - if self.config.rollout.free_cache_engine: - self.wake_up() chat_complete_requests = [ { "model": self.config.model.path, @@ -136,6 +133,4 @@ def generate_sequences(self, prompts: DataProto, sampling_params: dict): ] tasks = [self.chat_complete(chat_complete_request) for chat_complete_request in chat_complete_requests] results = self._run_all(tasks) - if self.config.rollout.free_cache_engine: - 
self.sleep() return results diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py index 00488c947be..34176d87ef2 100644 --- a/verl/trainer/ppo/ray_trainer.py +++ b/verl/trainer/ppo/ray_trainer.py @@ -84,9 +84,8 @@ def create_resource_pool(self): # For FSDP backend, we recommend using max_colocate_count=1 that merge all WorkerGroups into one. # For Megatron backend, we recommend using max_colocate_count>1 # that can utilize different WorkerGroup for differnt models - # max_colocate_count = 3: actor_critic_ref, rollout, reward model (optional) resource_pool = RayResourcePool( - process_on_nodes=process_on_nodes, use_gpu=True, max_colocate_count=3, name_prefix=resource_pool_name + process_on_nodes=process_on_nodes, use_gpu=True, max_colocate_count=1, name_prefix=resource_pool_name ) self.resource_pool_dict[resource_pool_name] = resource_pool @@ -779,15 +778,8 @@ def init_workers(self): from verl.experimental.agent_loop import AgentLoopManager self.async_rollout_mode = True - if self.config.reward_model.enable and self.config.reward_model.enable_resource_pool: - rm_resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel) - else: - rm_resource_pool = None - self.async_rollout_manager = AgentLoopManager( - config=self.config, - worker_group=self.actor_rollout_wg, - rm_resource_pool=rm_resource_pool, + config=self.config, worker_group=self.actor_rollout_wg, rm_wg=self.rm_wg ) def _save_checkpoint(self): diff --git a/verl/workers/rollout/replica.py b/verl/workers/rollout/replica.py index 223b8c0ccc5..a78ef42fcce 100644 --- a/verl/workers/rollout/replica.py +++ b/verl/workers/rollout/replica.py @@ -120,8 +120,9 @@ async def init_hybrid(self, worker_group: RayWorkerGroup): ] await self.launch_servers() + # TODO(@dyy): init with resource_pool? # TODO(sgm): this should be the default solution, but need to make the RolloutMode more clear. 
- async def init_colocated(self, resource_pool: RayResourcePool): + async def init_colocated(self, worker_group: RayWorkerGroup): """Init colocated rollout server, rollout engine and hybrid engine colocated in same ray placement group but in separate processes. @@ -129,19 +130,9 @@ async def init_colocated(self, resource_pool: RayResourcePool): resource_pool: RayResourcePool, ray placement group where hybrid engine processes have been launched. """ self.rollout_mode = RolloutMode.COLOCATED - self.resource_pool = resource_pool - - worker_group = RayWorkerGroup( - resource_pool=self.resource_pool, - ray_cls_with_init=self.get_ray_class_with_init_args(), - bin_pack=False, - name_prefix=f"rollout_colocate_{self.replica_rank}" - if not self.is_reward_model - else f"rollout_reward_colocate_{self.replica_rank}", - replica_rank=self.replica_rank, - replica_world_size=self.world_size, - ) - self.workers = worker_group.workers + self.workers = worker_group.workers[ + self.world_size * self.replica_rank : self.world_size * (self.replica_rank + 1) + ] await self.launch_servers() async def init_standalone(self): diff --git a/verl/workers/rollout/sglang_rollout/async_sglang_server.py b/verl/workers/rollout/sglang_rollout/async_sglang_server.py index 44645c36b8e..1d3d657d925 100644 --- a/verl/workers/rollout/sglang_rollout/async_sglang_server.py +++ b/verl/workers/rollout/sglang_rollout/async_sglang_server.py @@ -205,6 +205,7 @@ async def wake_up(self): await asyncio.gather(*[worker.wake_up.remote() for worker in self.workers]) elif self.rollout_mode == RolloutMode.COLOCATED: # Directly call engine to wake up without sync weights. + # FIXME(@wuxibin): sglang seems to resume with random weights. 
obj = ResumeMemoryOccupationReqInput(tags=["kv_cache", "weights"]) await self.tokenizer_manager.resume_memory_occupation(obj, None) await self.tokenizer_manager.flush_cache() diff --git a/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py b/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py index 6a8cd55b9c6..a14314188af 100644 --- a/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py +++ b/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py @@ -623,6 +623,8 @@ async def _execute_method(self, method: str | bytes, *args, **kwargs): return self._init_worker(*args, **kwargs) elif method == "load_model": return self._load_model(*args, **kwargs) + elif method == "sleep" or method == "wake_up": + raise ValueError("wake_up and sleep should not be called through ZeroMQ") else: return self.inference_engine.execute_method(method, *args, **kwargs)