diff --git a/tests/experimental/agent_loop/agent_utils.py b/tests/experimental/agent_loop/agent_utils.py
index 3c708c42cfb..fa4504c6af0 100644
--- a/tests/experimental/agent_loop/agent_utils.py
+++ b/tests/experimental/agent_loop/agent_utils.py
@@ -19,7 +19,7 @@
 from verl.single_controller.ray import RayClassWithInitArgs, RayWorkerGroup
 from verl.single_controller.ray.base import create_colocated_worker_cls
 from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role
-from verl.workers.fsdp_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker
+from verl.workers.fsdp_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker, RewardModelWorker


 def init_agent_loop_manager(config: DictConfig) -> AgentLoopManager | RayWorkerGroup:
@@ -30,6 +30,9 @@ def init_agent_loop_manager(config: DictConfig) -> AgentLoopManager | RayWorkerG
     role_worker_mapping = {
         Role.ActorRollout: ray.remote(actor_rollout_cls),
     }
+    if config.reward_model.enable:
+        role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker)
+
     global_pool_id = "global_pool"
     resource_pool_spec = {
         global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,
@@ -37,6 +40,15 @@ def init_agent_loop_manager(config: DictConfig) -> AgentLoopManager | RayWorkerG
     mapping = {
         Role.ActorRollout: global_pool_id,
     }
+    if config.reward_model.enable_resource_pool:
+        mapping[Role.RewardModel] = "reward_pool"
+        if config.reward_model.n_gpus_per_node <= 0:
+            raise ValueError("config.reward_model.n_gpus_per_node must be greater than 0")
+        if config.reward_model.nnodes <= 0:
+            raise ValueError("config.reward_model.nnodes must be greater than 0")
+
+        reward_pool = [config.reward_model.n_gpus_per_node] * config.reward_model.nnodes
+        resource_pool_spec["reward_pool"] = reward_pool
     resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)
     resource_pool_manager.create_resource_pool()
     resource_pool_to_cls = {pool: {} for pool in resource_pool_manager.resource_pool_dict.values()}
@@ -48,6 +60,12 @@ def init_agent_loop_manager(config: DictConfig) -> AgentLoopManager | RayWorkerG
     )
     resource_pool_to_cls[resource_pool]["actor_rollout"] = actor_rollout_cls

+    if config.reward_model.enable:
+        # we create a RM here
+        resource_pool = resource_pool_manager.get_resource_pool(Role.RewardModel)
+        rm_cls = RayClassWithInitArgs(role_worker_mapping[Role.RewardModel], config=config.reward_model)
+        resource_pool_to_cls[resource_pool]["rm"] = rm_cls
+
     all_wg = {}
     for resource_pool, class_dict in resource_pool_to_cls.items():
         worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict)
@@ -60,10 +78,16 @@ def init_agent_loop_manager(config: DictConfig) -> AgentLoopManager | RayWorkerG
     if config.actor_rollout_ref.rollout.mode == "sync":
         return actor_rollout_wg

+    if config.reward_model.enable_resource_pool and config.reward_model.enable:
+        rm_wg = all_wg["rm"]
+        rm_wg.init_model()
+    else:
+        rm_wg = None
     # =========================== 2. Create AgentLoopManager ===========================
     agent_loop_manager = AgentLoopManager(
         config=config,
         worker_group=actor_rollout_wg,
+        rm_wg=rm_wg,
     )

     return agent_loop_manager
diff --git a/tests/experimental/agent_loop/test_agent_loop_reward_model.py b/tests/experimental/agent_loop/test_agent_loop_reward_model.py
new file mode 100644
index 00000000000..cbc8c1a786d
--- /dev/null
+++ b/tests/experimental/agent_loop/test_agent_loop_reward_model.py
@@ -0,0 +1,96 @@
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+
+import ray
+from hydra import compose, initialize_config_dir
+from torchdata.stateful_dataloader import StatefulDataLoader
+from transformers import AutoTokenizer
+
+from tests.experimental.agent_loop.agent_utils import init_agent_loop_manager
+from verl.protocol import DataProto
+from verl.trainer.main_ppo import create_rl_sampler
+from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn
+
+
+def test_agent_loop_compute_score_with_model():
+    ray.init(
+        runtime_env={
+            "env_vars": {
+                "TOKENIZERS_PARALLELISM": "true",
+                "NCCL_DEBUG": "WARN",
+                "VLLM_LOGGING_LEVEL": "INFO",
+                "VLLM_USE_V1": "1",
+            }
+        }
+    )
+
+    with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config")):
+        config = compose("ppo_trainer")
+
+    model_path = "Qwen/Qwen2.5-1.5B-Instruct"
+    config.data.return_raw_chat = True
+    config.actor_rollout_ref.model.path = model_path
+    config.actor_rollout_ref.actor.use_dynamic_bsz = True
+    config.actor_rollout_ref.rollout.name = os.environ["ROLLOUT_NAME"]
+    config.actor_rollout_ref.rollout.mode = "async"
+    config.actor_rollout_ref.rollout.prompt_length = 1024
+    config.actor_rollout_ref.rollout.response_length = 4096
+    config.reward_model.enable = True
+    config.reward_model.model.path = model_path
+    config.reward_model.use_dynamic_bsz = True
+    config.reward_model.forward_max_token_len_per_gpu = 6000
+    config.reward_model.micro_batch_size_per_gpu = 40
+    config.reward_model.enable_resource_pool = True
+    config.reward_model.n_gpus_per_node = 1
+    config.reward_model.nnodes = 1
+    config.reward_model.model.trust_remote_code = True
+    config.reward_model.model.input_tokenizer = None
+    config.trainer.n_gpus_per_node = 4
+    config.trainer.nnodes = 1
+    # 1. init agent loop manager
+    agent_loop_manager = init_agent_loop_manager(config)
+
+    # 2. init dataset and dataloader
+    local_folder = os.path.expanduser("~/verl-data/gsm8k/")
+    data_files = [os.path.join(local_folder, "train.parquet")]
+    tokenizer = AutoTokenizer.from_pretrained(model_path)
+
+    dataset = RLHFDataset(
+        data_files=data_files,
+        tokenizer=tokenizer,
+        config=config.data,
+        processor=None,
+    )
+
+    batch_size = 128
+    sampler = create_rl_sampler(config.data, dataset)
+    dataloader = StatefulDataLoader(
+        dataset=dataset,
+        batch_size=batch_size,
+        num_workers=config.data.dataloader_num_workers,
+        drop_last=True,
+        collate_fn=collate_fn,
+        sampler=sampler,
+    )
+
+    # 3. generate_sequences with agent loop
+    batch_dict = next(iter(dataloader))
+    batch = DataProto.from_single_dict(batch_dict)
+    gen_batch = agent_loop_manager.generate_sequences(prompts=batch)
+
+    rm_scores = gen_batch.batch["rm_scores"]
+    sample_scores = rm_scores.sum(dim=1)
+    print(sample_scores)
+    ray.shutdown()
diff --git a/verl/experimental/agent_loop/agent_loop.py b/verl/experimental/agent_loop/agent_loop.py
index 13526046a0d..1fe43ccddfd 100644
--- a/verl/experimental/agent_loop/agent_loop.py
+++ b/verl/experimental/agent_loop/agent_loop.py
@@ -15,8 +15,11 @@
 import heapq
 import logging
 import os
+import queue
 import random
+import threading
 from abc import ABC, abstractmethod
+from concurrent.futures import Future
 from typing import Any, Optional

 import hydra
@@ -243,52 +246,106 @@ def decorator(subclass: type[AgentLoopBase]) -> type[AgentLoopBase]:
     return decorator


+@ray.remote(num_cpus=1)
+class BatchExecutor:
+    """Batch executor that collects individual requests and executes them as a single batched call."""
+
+    def __init__(self, batch_func, micro_batch_size=1, max_batch_size=None):
+        """
+
+        Args:
+            batch_func: batch processing function.
+            micro_batch_size (int, optional): each dispatched batch is padded to a multiple of this size. Defaults to 1.
+            max_batch_size: maximum number of items to execute in one batch.
+        """
+        self._q = queue.Queue()
+        self._batch_func = batch_func
+        self._max_batch = max_batch_size
+        self._micro_batch_size = micro_batch_size
+
+        self._worker = threading.Thread(target=self._worker_loop, daemon=True)
+        self._worker.start()
+
+    async def submit_task(self, item):
+        """
+        Submit one item and await its result from the batched call.
+        Args:
+            item: function input
+
+        Returns:
+            The output of batch_func corresponding to this item.
+        """
+        fut = Future()
+        self._q.put((item, fut))
+        async_fut = asyncio.wrap_future(fut)
+        res = await async_fut
+        return res
+
+    def _worker_loop(self):
+        while True:
+            # 1. Fetch a full batch (block until at least one)
+            first, first_fut = self._q.get()
+            items = [first]
+            futs = [first_fut]
+
+            # Take the remaining tasks at once
+            while True:
+                try:
+                    next_item, next_fut = self._q.get_nowait()
+                    items.append(next_item)
+                    futs.append(next_fut)
+                    if self._max_batch and len(items) >= self._max_batch:
+                        break
+                except queue.Empty:
+                    while len(items) % self._micro_batch_size != 0:
+                        next_item, next_fut = self._q.get()
+                        items.append(next_item)
+                        futs.append(next_fut)
+                        if self._max_batch and len(items) >= self._max_batch:
+                            break
+                    break
+
+            try:
+                results = self._batch_func(items)
+            except Exception as e:
+                for f in futs:
+                    f.set_exception(e)
+            else:
+                for f, r in zip(futs, results, strict=False):
+                    f.set_result(r)
+
+
 @ray.remote(num_cpus=1)
 class RewardManagerWorker:
     """Reward manager worker to compute reward score asynchronously to overlap with agent loop."""

-    def __init__(self, config: DictConfig, local_path: str) -> None:
+    def __init__(self, config: DictConfig, local_path: str, rm_executor: BatchExecutor = None) -> None:
         tokenizer = hf_tokenizer(local_path, trust_remote_code=True)
         self.reward_manager = load_reward_manager(
             config, tokenizer, num_examine=0, **config.reward_model.get("reward_kwargs", {})
         )
+        self.rm_executor = rm_executor
         self.loop = asyncio.get_event_loop()

-    async def compute_score(self, output: AgentLoopOutput, kwargs: dict) -> dict:
+    async def compute_score(
+        self,
+        data: DataProto,
+    ) -> dict:
         """Compute reward score for agent loop output.

         NOTE: Since `reward_manager.__call__` is blocking function, we run it in thread
         pool to compute multiple samples in parallel.

         Args:
-            output (AgentLoopOutput): Agent loop output.
-            kwargs (dict): Dataset fields from `verl.utils.dataset.RLHFDataset`.
+            data: DataProto input to the reward function.
+
         Returns:
             dict: Reward score and reward extra info.
         """
-        prompts = torch.tensor(output.prompt_ids, dtype=torch.long).unsqueeze(0)
-        responses = torch.tensor(output.response_ids, dtype=torch.long).unsqueeze(0)
-        attention_mask = torch.ones((1, prompts.shape[1] + responses.shape[1]), dtype=torch.long)
-        batch = TensorDict(
-            {
-                "prompts": prompts,  # [1, prompt_length]
-                "responses": responses,  # [1, response_length]
-                "attention_mask": attention_mask,  # [1, prompt_length + response_length]
-            },
-            batch_size=1,
-        )
-        non_tensor_batch = {
-            **{k: np.array([v]) for k, v in kwargs.items()},
-            "__num_turns__": np.array([output.num_turns]),
-        }
-        data = DataProto(
-            batch=batch,
-            non_tensor_batch=non_tensor_batch,
-        )
         result = await self.loop.run_in_executor(
             None,
-            self.reward_manager,
+            self.reward_wrapper,
             data,
             True,  # return_dict
         )
@@ -297,12 +354,31 @@ async def compute_score(self, output: AgentLoopOutput, kwargs: dict) -> dict:
         reward_extra_info = {k: v[0] for k, v in result.get("reward_extra_info", {}).items()}
         return {"reward_score": reward_score, "reward_extra_info": reward_extra_info}

+    def reward_wrapper(self, data: DataProto, return_dict=False) -> torch.Tensor:
+        """Assemble the reward function and the reward model into a single callable exposed to the event loop.
+
+
+        Args:
+            return_dict: whether to return the result as a dict
+            data: DataProto to compute reward scores for
+
+        Returns:
+            torch.Tensor: Reward score tensor.
+        """
+        if self.rm_executor is not None:
+            res = ray.get(self.rm_executor.submit_task.remote(data))
+            data = data.union(res)
+
+        return self.reward_manager(data, return_dict)
+

 @ray.remote
 class AgentLoopWorker:
     """Agent loop worker takes a batch of messages and run each message in an agent loop."""

-    def __init__(self, config: DictConfig, server_handles: list[ray.actor.ActorHandle]):
+    def __init__(
+        self, config: DictConfig, server_handles: list[ray.actor.ActorHandle], rm_executor: BatchExecutor = None
+    ):
         """Initialize agent loop manager.

         Args:
@@ -311,6 +387,7 @@ def __init__(self, config: DictConfig, server_handles: list[ray.actor.ActorHandl
         """
         self.config = config
         self.server_manager = AsyncLLMServerManager(config, server_handles)
+        self.rm_executor = rm_executor

         model_path = config.actor_rollout_ref.model.path
         self.model_name = "/".join(model_path.split("/")[-2:])
@@ -333,7 +410,7 @@ def __init__(self, config: DictConfig, server_handles: list[ray.actor.ActorHandl
                 node_id=ray.get_runtime_context().get_node_id(),
                 soft=False,
             ),
-        ).remote(self.config, local_path)
+        ).remote(self.config, local_path, self.rm_executor)

         trace_config = self.config.actor_rollout_ref.rollout.get("trace", {})
         RolloutTraceConfig.init(
@@ -429,10 +506,6 @@ async def _run_agent_loop(
             output: AgentLoopOutput = await agent_loop.run(sampling_params, **kwargs)

             # Some AgentLoop may have already computed the reward score, e.g SWE-agent.
-            if output.reward_score is None and not self.config.reward_model.enable:
-                result = await self.reward_manager_worker.compute_score.remote(output, kwargs)
-                output.reward_score = result["reward_score"]
-                output.extra_fields["reward_extra_info"] = result["reward_extra_info"]

             # NOTE: consistent with batch version of generate_sequences in vllm_rollout_spmd.py
             # prompt_ids: left padded with zeros (e.g., [0,0,0,0,1,2,3,4])
@@ -528,6 +601,31 @@ async def _run_agent_loop(
                 ).unsqueeze(0)  # (1, 3, seq_len)
             else:
                 position_ids = compute_position_id_with_mask(attention_mask)  # (1, seq_len)
+            enable_async_reward = (
+                self.rm_executor is not None and self.config.reward_model.enable_resource_pool
+            ) or not self.config.reward_model.enable
+            if output.reward_score is None and enable_async_reward:
+                batch = TensorDict(
+                    {
+                        "prompts": prompt_output["input_ids"],  # [1, prompt_length]
+                        "responses": response_output["input_ids"],  # [1, response_length]
+                        "attention_mask": attention_mask,  # [1, prompt_length + response_length]
+                        "input_ids": input_ids,  # [1, prompt_length + response_length]
+                        "position_ids": position_ids,
+                    },
+                    batch_size=1,
+                )
+                non_tensor_batch = {
+                    **{k: np.array([v]) for k, v in kwargs.items()},
+                    "__num_turns__": np.array([output.num_turns]),
+                }
+                data = DataProto(
+                    batch=batch,
+                    non_tensor_batch=non_tensor_batch,
+                )
+                result = await self.reward_manager_worker.compute_score.remote(data)
+                output.reward_score = result["reward_score"]
+                output.extra_fields["reward_extra_info"] = result["reward_extra_info"]

             return _InternalAgentLoopOutput(
                 prompt_ids=prompt_output["input_ids"],
@@ -628,7 +726,7 @@ async def get_trajectory_info(step, index, validate):
 class AgentLoopManager:
     """Agent loop manager that manages a group of agent loop workers."""

-    def __init__(self, config: DictConfig, worker_group: RayWorkerGroup):
+    def __init__(self, config: DictConfig, worker_group: RayWorkerGroup, rm_wg: RayWorkerGroup = None):
         """Initialize agent loop manager.

         Args:
@@ -637,6 +735,29 @@ def __init__(self, config: DictConfig, worker_group: RayWorkerGroup):
         """
         self.config = config
         self.worker_group = worker_group
+        self.rm_executor = None
+        self.rm_micro_batch_size = None
+        if rm_wg:
+
+            def batch_fn(data_list: list[DataProto]) -> list[torch.Tensor]:
+                new_data_list = []
+                for data in data_list:
+                    temp_non_tensor_batch = {"__num_turns__": data.non_tensor_batch["__num_turns__"]}
+                    temp_data = DataProto(batch=data.batch, non_tensor_batch=temp_non_tensor_batch)
+                    new_data_list.append(temp_data)
+
+                new_batch = DataProto.concat(new_data_list)
+                out_data = rm_wg.compute_rm_score(new_batch)
+                return out_data.split(1)
+
+            self.rm_executor = BatchExecutor.options(
+                scheduling_strategy=ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy(
+                    node_id=ray.get_runtime_context().get_node_id(),
+                    soft=False,
+                ),
+            ).remote(batch_fn, rm_wg.world_size)
+
+            self.rm_micro_batch_size = rm_wg.world_size

         self._initialize_llm_servers()
         self._init_agent_loop_workers()
@@ -710,7 +831,7 @@ def _init_agent_loop_workers(self):
                     scheduling_strategy=ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy(
                         node_id=node_id, soft=True
                     ),
-                ).remote(self.config, self.async_llm_servers)
+                ).remote(self.config, self.async_llm_servers, self.rm_executor)
             )

     def generate_sequences(self, prompts: DataProto) -> DataProto:
@@ -722,6 +843,11 @@ def generate_sequences(self, prompts: DataProto) -> DataProto:
         Returns:
             DataProto: Output batch.
""" + + if self.rm_micro_batch_size and len(prompts) % self.rm_micro_batch_size != 0: + raise ValueError( + f"The length of prompts {len(prompts)} cannot divide the world size of rm_wg {self.rm_micro_batch_size}" + ) if self.config.actor_rollout_ref.rollout.free_cache_engine: self.wake_up() chunkes = prompts.chunk(len(self.agent_loop_workers)) diff --git a/verl/trainer/config/_generated_ppo_megatron_trainer.yaml b/verl/trainer/config/_generated_ppo_megatron_trainer.yaml index 03d4d5cacf1..6c1ab9fb215 100644 --- a/verl/trainer/config/_generated_ppo_megatron_trainer.yaml +++ b/verl/trainer/config/_generated_ppo_megatron_trainer.yaml @@ -386,6 +386,9 @@ critic: data_loader_seed: ${oc.select:actor_rollout_ref.actor.data_loader_seed,null} reward_model: enable: false + enable_resource_pool: false + n_gpus_per_node: 0 + nnodes: 0 strategy: megatron model: input_tokenizer: ${actor_rollout_ref.model.path} diff --git a/verl/trainer/config/_generated_ppo_trainer.yaml b/verl/trainer/config/_generated_ppo_trainer.yaml index 3c7a73f7e37..a2c168fce52 100644 --- a/verl/trainer/config/_generated_ppo_trainer.yaml +++ b/verl/trainer/config/_generated_ppo_trainer.yaml @@ -354,6 +354,9 @@ critic: grad_clip: 1.0 reward_model: enable: false + enable_resource_pool: false + n_gpus_per_node: 0 + nnodes: 0 strategy: fsdp model: input_tokenizer: ${actor_rollout_ref.model.path} diff --git a/verl/trainer/config/reward_model/reward_model.yaml b/verl/trainer/config/reward_model/reward_model.yaml index 08ae37ac9db..e9ffc60fbc6 100644 --- a/verl/trainer/config/reward_model/reward_model.yaml +++ b/verl/trainer/config/reward_model/reward_model.yaml @@ -6,6 +6,12 @@ # If False, the following parameters are not effective enable: False +# Whether to deploy the model to a separate resource pool. +# If true, n_gpus_per_node & nnodes will be used to determine the resource node. +enable_resource_pool: False +n_gpus_per_node: 0 +nnodes: 0 + # FSDP strategy: "fsdp" or "fsdp2" strategy: ??? 
diff --git a/verl/trainer/main_ppo.py b/verl/trainer/main_ppo.py
index 7ab01b456f7..fcf748ba99e 100644
--- a/verl/trainer/main_ppo.py
+++ b/verl/trainer/main_ppo.py
@@ -171,6 +171,16 @@ def init_resource_pool_mgr(self, config):
         resource_pool_spec = {
             global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,
         }
+        # TODO: the new registration method could be used here to support dynamic registration of roles
+        if config.reward_model.enable_resource_pool:
+            if config.reward_model.n_gpus_per_node <= 0:
+                raise ValueError("config.reward_model.n_gpus_per_node must be greater than 0")
+            if config.reward_model.nnodes <= 0:
+                raise ValueError("config.reward_model.nnodes must be greater than 0")
+
+            reward_pool = [config.reward_model.n_gpus_per_node] * config.reward_model.nnodes
+            resource_pool_spec["reward_pool"] = reward_pool
+
         self.mapping[Role.ActorRollout] = global_pool_id
         self.mapping[Role.Critic] = global_pool_id
         from verl.trainer.ppo.ray_trainer import ResourcePoolManager
@@ -190,7 +200,10 @@ def add_reward_model_worker(self, config):
         else:
             raise NotImplementedError
         self.role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker)
-        self.mapping[Role.RewardModel] = "global_pool"
+        if config.reward_model.enable_resource_pool:
+            self.mapping[Role.RewardModel] = "reward_pool"
+        else:
+            self.mapping[Role.RewardModel] = "global_pool"

     def add_ref_policy_worker(self, config, ref_policy_cls):
         """Add reference policy worker if KL loss or KL reward is used."""
diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py
index d2508a1259c..c953e5c1ed5 100644
--- a/verl/trainer/ppo/ray_trainer.py
+++ b/verl/trainer/ppo/ray_trainer.py
@@ -720,6 +720,7 @@ def init_workers(self):
             self.ref_policy_wg = all_wg["ref"]
             self.ref_policy_wg.init_model()

+        self.rm_wg = None
        if self.use_rm:
            self.rm_wg = all_wg["rm"]
            self.rm_wg.init_model()
@@ -735,8 +736,7 @@ def init_workers(self):

             self.async_rollout_mode = True
             self.async_rollout_manager = AgentLoopManager(
-                config=self.config,
-                worker_group=self.actor_rollout_wg,
+                config=self.config, worker_group=self.actor_rollout_wg, rm_wg=self.rm_wg
             )

     def _save_checkpoint(self):
@@ -1023,7 +1023,7 @@ def fit(self):

                 with marked_timer("reward", timing_raw, color="yellow"):
                     # compute reward model score
-                    if self.use_rm:
+                    if self.use_rm and "rm_scores" not in batch.batch.keys():
                         reward_tensor = self.rm_wg.compute_rm_score(batch)
                         batch = batch.union(reward_tensor)
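
Note on the batching behaviour: the BatchExecutor added in verl/experimental/agent_loop/agent_loop.py collects concurrent submit_task calls and only dispatches a batch once its size is a multiple of micro_batch_size, which is why generate_sequences now rejects prompt batches that are not divisible by the rm_wg world size. A toy, self-contained sketch of that behaviour (the scale_batch function and the sizes are illustrative only, not part of this patch):

    import asyncio

    import ray

    from verl.experimental.agent_loop.agent_loop import BatchExecutor


    def scale_batch(items: list[int]) -> list[int]:
        # Stand-in for rm_wg.compute_rm_score: one output per input, computed for the whole batch.
        return [x * 10 for x in items]


    async def main():
        # micro_batch_size=4: the worker thread keeps pulling items until the batch size is a multiple of 4.
        executor = BatchExecutor.remote(scale_batch, micro_batch_size=4)
        # 8 concurrent submissions, a multiple of 4, mirroring the divisibility check in generate_sequences.
        results = await asyncio.gather(*[executor.submit_task.remote(i) for i in range(8)])
        print(results)  # [0, 10, 20, 30, 40, 50, 60, 70]


    ray.init()
    asyncio.run(main())
    ray.shutdown()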