From 7e38a8cd8ba60cc2be33e56fb227dfd08ecf16af Mon Sep 17 00:00:00 2001
From: wuxibin
Date: Thu, 14 Aug 2025 18:56:22 +0800
Subject: [PATCH 1/7] [rollout] feat: compute reward score in agent loop

---
 .../agent_loop/test_agent_loop_reward.py    | 80 ++++++++++++++++++
 .../agent_loop/test_basic_agent_loop.py     |  6 ++
 .../agent_loop/test_multi_modal.py          |  2 +
 verl/experimental/agent_loop/agent_loop.py  | 83 +++++++++++++++++--
 4 files changed, 165 insertions(+), 6 deletions(-)
 create mode 100644 tests/experimental/agent_loop/test_agent_loop_reward.py

diff --git a/tests/experimental/agent_loop/test_agent_loop_reward.py b/tests/experimental/agent_loop/test_agent_loop_reward.py
new file mode 100644
index 00000000000..063835798cb
--- /dev/null
+++ b/tests/experimental/agent_loop/test_agent_loop_reward.py
@@ -0,0 +1,80 @@
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+
+import ray
+from hydra import compose, initialize_config_dir
+from torchdata.stateful_dataloader import StatefulDataLoader
+from transformers import AutoTokenizer
+
+from tests.experimental.agent_loop.agent_utils import init_agent_loop_manager
+from verl.protocol import DataProto
+from verl.trainer.main_ppo import create_rl_sampler
+from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn
+
+
+def test_agent_loop_compute_score():
+    ray.init()
+
+    with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config")):
+        config = compose("ppo_trainer")
+
+    model_path = "Qwen/Qwen2.5-1.5B-Instruct"
+    config.data.return_raw_chat = True
+    config.actor_rollout_ref.model.path = model_path
+    config.actor_rollout_ref.actor.use_dynamic_bsz = True
+    config.actor_rollout_ref.rollout.name = "sglang"
+    config.actor_rollout_ref.rollout.mode = "async"
+    config.actor_rollout_ref.rollout.prompt_length = 1024
+    config.actor_rollout_ref.rollout.response_length = 4096
+
+    # 1. init agent loop manager
+    agent_loop_manager = init_agent_loop_manager(config)
+
+    # 2. init dataset and dataloader
+    local_folder = os.path.expanduser("~/verl-data/gsm8k/")
+    data_files = [os.path.join(local_folder, "train.parquet")]
+    tokenizer = AutoTokenizer.from_pretrained(model_path)
+
+    dataset = RLHFDataset(
+        data_files=data_files,
+        tokenizer=tokenizer,
+        config=config.data,
+        processor=None,
+    )
+
+    batch_size = 128
+    sampler = create_rl_sampler(config.data, dataset)
+    dataloader = StatefulDataLoader(
+        dataset=dataset,
+        batch_size=batch_size,
+        num_workers=config.data.dataloader_num_workers,
+        drop_last=True,
+        collate_fn=collate_fn,
+        sampler=sampler,
+    )
+
+    # 3. generate_sequences with agent loop
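+    # generate_sequences drives the agent loops, which now also attach the
+    # per-sample reward scores to the returned batch as "rm_scores".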
+    batch_dict = next(iter(dataloader))
+    batch = DataProto.from_single_dict(batch_dict)
+    gen_batch = agent_loop_manager.generate_sequences(prompts=batch)
+
+    rm_scores = gen_batch.batch["rm_scores"]
+    sample_scores = rm_scores.sum(dim=1)
+    assert sample_scores.min() == 0.0, f"min score: {sample_scores.min()}"
+    assert sample_scores.max() == 1.0, f"max score: {sample_scores.max()}"
+    print(f"gsm8k acc: {sample_scores.mean()}")
+
+    print("Test passed!")
+    ray.shutdown()
diff --git a/tests/experimental/agent_loop/test_basic_agent_loop.py b/tests/experimental/agent_loop/test_basic_agent_loop.py
index ac0cd3246cc..0de7804a5d4 100644
--- a/tests/experimental/agent_loop/test_basic_agent_loop.py
+++ b/tests/experimental/agent_loop/test_basic_agent_loop.py
@@ -83,6 +83,8 @@ def test_single_turn(init_config):
         non_tensor_batch={
             "raw_prompt": np.array(raw_prompts),
             "agent_name": np.array(["single_turn_agent"] * len(raw_prompts)),
+            "data_source": np.array(["openai/gsm8k"] * len(raw_prompts)),
+            "reward_model": np.array([{"style": "rule", "ground_truth": "1.0"}] * len(raw_prompts)),
         },
     )
     n = init_config.actor_rollout_ref.rollout.n
@@ -95,6 +97,7 @@ def test_single_turn(init_config):
     assert result.batch["input_ids"].size(1) == seq_len
     assert result.batch["attention_mask"].size(1) == seq_len
     assert result.batch["position_ids"].size(1) == seq_len
+    assert result.batch["rm_scores"].size(1) == result.batch["responses"].size(1)

     # check turns
     num_turns = result.non_tensor_batch["__num_turns__"]
@@ -226,6 +229,8 @@ def test_tool_agent(init_config):
         non_tensor_batch={
             "raw_prompt": np.array([np.array(prompt) for prompt in raw_prompts], dtype=object),
             "agent_name": np.array(["tool_agent"] * len(raw_prompts)),
+            "data_source": np.array(["openai/gsm8k"] * len(raw_prompts)),
+            "reward_model": np.array([{"style": "rule", "ground_truth": "1.0"}] * len(raw_prompts)),
         },
     )
     batch = batch.repeat(n)
@@ -248,6 +253,7 @@ def test_tool_agent(init_config):
     responses = result.batch["responses"]
     response_mask = result.batch["response_mask"]
     attention_mask = result.batch["attention_mask"]
+    assert result.batch["rm_scores"].size(1) == responses.size(1)
     assert responses.size() == response_mask.size(), f"{responses.size()} != {response_mask.size()}"

     response_length = response_mask.size(1)
diff --git a/tests/experimental/agent_loop/test_multi_modal.py b/tests/experimental/agent_loop/test_multi_modal.py
index be56fc6e9cc..35dac568cd2 100644
--- a/tests/experimental/agent_loop/test_multi_modal.py
+++ b/tests/experimental/agent_loop/test_multi_modal.py
@@ -175,6 +175,8 @@ def test_multimodal_tool_agent(init_config):
         non_tensor_batch={
             "raw_prompt": np.array([np.array(prompt) for prompt in raw_prompts], dtype=object),
             "agent_name": np.array(["tool_agent"] * len(raw_prompts)),
+            "data_source": np.array(["openai/gsm8k"] * len(raw_prompts)),
+            "reward_model": np.array([{"style": "rule", "ground_truth": "1.0"}] * len(raw_prompts)),
         },
     )
     batch = batch.repeat(n)
diff --git a/verl/experimental/agent_loop/agent_loop.py b/verl/experimental/agent_loop/agent_loop.py
index 29f8dc67e0e..759c58e69b4 100644
--- a/verl/experimental/agent_loop/agent_loop.py
+++ b/verl/experimental/agent_loop/agent_loop.py
@@ -31,6 +31,7 @@
 from verl.protocol import DataProto
 from verl.single_controller.ray.base import RayWorkerGroup
+from verl.trainer.ppo.reward import load_reward_manager
 from verl.utils import hf_processor, hf_tokenizer
 from verl.utils.fs import copy_to_local
 from verl.utils.model import compute_position_id_with_mask
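
Note for reviewers: `load_reward_manager`, imported above, returns a reward
manager that is a plain callable over a `DataProto` batch. A minimal sketch of
the contract that the new `RewardManagerWorker` below relies on (`config`,
`tokenizer`, and `data` stand in for objects built in this patch; the concrete
manager class depends on `config.reward_model`):

    from verl.trainer.ppo.reward import load_reward_manager

    reward_manager = load_reward_manager(config, tokenizer, num_examine=0)
    reward_tensor = reward_manager(data)  # [batch_size, response_length] per-token rewards
    scores = reward_tensor.sum(dim=-1)    # one scalar score per trajectory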
@@ -125,6 +126,8 @@ class AgentLoopOutput(BaseModel):
     """Response mask, 1 for LLM generated token, 0 for tool response token."""
     multi_modal_data: Optional[dict[str, Any]] = None
     """Multi-modal data for multi-modal tools."""
+    reward_score: float = None
+    """Reward score for the trajectory."""
     num_turns: int = 0
     """Number of chat turns, including user, assistant, tool."""
     metrics: AgentLoopMetrics
@@ -234,6 +237,57 @@ def decorator(subclass: type[AgentLoopBase]) -> type[AgentLoopBase]:
     return decorator


+@ray.remote(num_cpus=1)
+class RewardManagerWorker:
+    """Reward manager worker to compute reward scores asynchronously, overlapping with the agent loop."""
+
+    def __init__(self, config: DictConfig, local_path: str) -> None:
+        tokenizer = hf_tokenizer(local_path, trust_remote_code=True)
+        self.reward_manager = load_reward_manager(
+            config, tokenizer, num_examine=0, **config.reward_model.get("reward_kwargs", {})
+        )
+        self.loop = asyncio.get_event_loop()
+
+    async def compute_score(self, output: AgentLoopOutput, kwargs: dict) -> float:
+        """Compute reward score for agent loop output.
+
+        NOTE: Since `reward_manager.__call__` is a blocking function, we run it in a thread pool to
+        compute multiple samples in parallel.
+
+        Args:
+            output (AgentLoopOutput): Agent loop output.
+            kwargs (dict): Dataset fields from `verl.utils.dataset.RLHFDataset`.
+
+        Returns:
+            float: Reward score.
+        """
+        prompts = torch.tensor(output.prompt_ids, dtype=torch.long).unsqueeze(0)
+        responses = torch.tensor(output.response_ids, dtype=torch.long).unsqueeze(0)
+        attention_mask = torch.ones((1, prompts.shape[1] + responses.shape[1]), dtype=torch.long)
+        batch = TensorDict(
+            {
+                "prompts": prompts,  # [1, prompt_length]
+                "responses": responses,  # [1, response_length]
+                "attention_mask": attention_mask,  # [1, prompt_length + response_length]
+            },
+            batch_size=1,
+        )
+        non_tensor_batch = {
+            **{k: np.array([v]) for k, v in kwargs.items()},
+            "__num_turns__": np.array([output.num_turns]),
+        }
+        data = DataProto(
+            batch=batch,
+            non_tensor_batch=non_tensor_batch,
+        )
+        reward_tensor = await self.loop.run_in_executor(
+            None,
+            self.reward_manager,
+            data,
+        )
+        return reward_tensor.sum(dim=-1).item()
+
+
 @ray.remote
 class AgentLoopWorker:
     """Agent loop worker takes a batch of messages and runs each message in an agent loop."""
@@ -264,6 +318,13 @@ def __init__(self, config: DictConfig, server_handles: list[ray.actor.ActorHandl
             self.processor.chat_template = self.config.actor_rollout_ref.model.custom_chat_template
             self.tokenizer.chat_template = self.config.actor_rollout_ref.model.custom_chat_template

+        self.reward_manager_worker = RewardManagerWorker.options(
+            scheduling_strategy=ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy(
+                node_id=ray.get_runtime_context().get_node_id(),
+                soft=False,
+            ),
+        ).remote(self.config, local_path)
+
         trace_config = self.config.actor_rollout_ref.rollout.get("trace", {})
         RolloutTraceConfig.init(
             self.config.trainer.project_name,
@@ -356,6 +417,10 @@ async def _run_agent_loop(
         )
         output: AgentLoopOutput = await agent_loop.run(sampling_params, **kwargs)

+        # Some AgentLoop may have already computed the reward score, e.g., SWE-agent.
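+        # The reward manager call below runs on a colocated Ray actor (see the
+        # NodeAffinity scheduling above) and is awaited per trajectory, so
+        # scoring overlaps with other in-flight agent loops in this worker.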
+        if output.reward_score is None:
+            output.reward_score = await self.reward_manager_worker.compute_score.remote(output, kwargs)
+
         # NOTE: consistent with batch version of generate_sequences in vllm_rollout_spmd.py
         # prompt_ids: left padded with zeros (e.g., [0,0,0,0,1,2,3,4])
         # response_ids: right padded with zeros (e.g., [5,6,7,8,0,0,0,0])
@@ -455,6 +520,7 @@ async def _run_agent_loop(
             attention_mask=attention_mask,
             multi_modal_inputs=multi_modal_inputs,
             multi_modal_data=output.multi_modal_data,
+            reward_score=output.reward_score,
             num_turns=output.num_turns,
             metrics=output.metrics,
         )
@@ -482,6 +548,14 @@ def _postprocess(self, inputs: list[_InternalAgentLoopOutput]) -> DataProto:
             batch_size=len(inputs),
         )

+        scores = [input.reward_score for input in inputs]
+        if all(score is not None for score in scores):
+            prompt_length = prompt_ids.size(1)
+            response_length = attention_mask[:, prompt_length:].sum(dim=1) - 1
+            rm_scores = torch.zeros_like(response_mask, dtype=torch.float32)
+            rm_scores[torch.arange(response_mask.size(0)), response_length] = torch.tensor(scores, dtype=torch.float32)
+            batch["rm_scores"] = rm_scores
+
         non_tensor_batch = {
             "__num_turns__": np.array([input.num_turns for input in inputs], dtype=np.int32),
         }
@@ -546,9 +620,6 @@ def _initialize_llm_servers(self):
                 for worker in self.worker_group.workers
             ]
         )
-        self.rollout_dp_size = self.worker_group.world_size // self.rollout_tp_size
-        # Store the node IDs for the servers
-        self.server_node_ids = [workers_info[i * self.rollout_tp_size] for i in range(self.rollout_dp_size)]

         assert len(workers_info) == self.worker_group.world_size

         self.async_llm_servers = [None] * self.rollout_dp_size
@@ -594,11 +665,11 @@ def _initialize_llm_servers(self):
     def _init_agent_loop_workers(self):
         self.agent_loop_workers = []
         num_workers = self.config.actor_rollout_ref.rollout.agent.num_workers
-        num_server_nodes = len(self.server_node_ids)
+        node_ids = [node["NodeID"] for node in ray.nodes() if node["Alive"] and node["Resources"]["CPU"] > 0]

         for i in range(num_workers):
-            # Round-robin scheduling over the server nodes
-            node_id = self.server_node_ids[i % num_server_nodes]
+            # Round-robin scheduling over all nodes
+            node_id = node_ids[i % len(node_ids)]
             self.agent_loop_workers.append(
                 AgentLoopWorker.options(
                     name=f"agent_loop_worker_{i}",

From bb309270b16ce8bd7e25756476da461d5bca8f78 Mon Sep 17 00:00:00 2001
From: wuxibin
Date: Thu, 14 Aug 2025 19:11:46 +0800
Subject: [PATCH 2/7] fix unit test

---
 tests/experimental/agent_loop/test_agent_loop_reward.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/experimental/agent_loop/test_agent_loop_reward.py b/tests/experimental/agent_loop/test_agent_loop_reward.py
index 063835798cb..ef3fc962cda 100644
--- a/tests/experimental/agent_loop/test_agent_loop_reward.py
+++ b/tests/experimental/agent_loop/test_agent_loop_reward.py
@@ -34,7 +34,7 @@ def test_agent_loop_compute_score():
     config.data.return_raw_chat = True
     config.actor_rollout_ref.model.path = model_path
     config.actor_rollout_ref.actor.use_dynamic_bsz = True
-    config.actor_rollout_ref.rollout.name = "sglang"
+    config.actor_rollout_ref.rollout.name = os.environ["ROLLOUT_NAME"]
     config.actor_rollout_ref.rollout.mode = "async"
     config.actor_rollout_ref.rollout.prompt_length = 1024
     config.actor_rollout_ref.rollout.response_length = 4096

From d58321e0d56f767c34968ea61decf7e9027318a2 Mon Sep 17 00:00:00 2001
From: wuxibin
Date: Thu, 14 Aug 2025 19:13:52 +0800
Subject: [PATCH 3/7] fix unit test

---
 .github/workflows/sgl.yml  | 1 +
 .github/workflows/vllm.yml | 1 +
 2 files changed, 2 insertions(+)

diff --git a/.github/workflows/sgl.yml b/.github/workflows/sgl.yml
index 8d87de17baf..5795f9c3e7c 100644
--- a/.github/workflows/sgl.yml
+++ b/.github/workflows/sgl.yml
@@ -136,6 +136,7 @@ jobs:
           pytest -s test_sglang_async_rollout_mcp_tools.py
       - name: Test the latest SGLang Rollout async with agent loop
         run: |
+          huggingface-cli download verl-team/gsm8k-v0.4.1 --repo-type dataset --local-dir ~/verl-data/gsm8k
           ROLLOUT_NAME=sglang pytest -svvv tests/experimental/agent_loop
       # Note(haibin.lin): for any new test, please update gpu_unit_tests.yaml to avoid repeated tests
       - name: Test the latest SGLang Rollout async with multimodal delta
diff --git a/.github/workflows/vllm.yml b/.github/workflows/vllm.yml
index 7a0319a44ff..2998c08f09f 100644
--- a/.github/workflows/vllm.yml
+++ b/.github/workflows/vllm.yml
@@ -127,5 +127,6 @@ jobs:
           rm -rf "${OUTPUT_PATH}"
       - name: Test the latest vLLM Rollout async with agent loop
         run: |
+          huggingface-cli download verl-team/gsm8k-v0.4.1 --repo-type dataset --local-dir ~/verl-data/gsm8k
           ROLLOUT_NAME=vllm pytest -svvv tests/experimental/agent_loop
       # Note(haibin.lin): for any new test, please update gpu_unit_tests.yaml to avoid repeated tests

From 16105e43dacb2b949b88fba5ab838cf507385d29 Mon Sep 17 00:00:00 2001
From: wuxibin
Date: Thu, 14 Aug 2025 20:06:46 +0800
Subject: [PATCH 4/7] fix trainer

---
 verl/trainer/ppo/ray_trainer.py | 56 +++++++++++----------------------
 1 file changed, 19 insertions(+), 37 deletions(-)

diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py
index e8ba7fb237f..acf44b0a267 100644
--- a/verl/trainer/ppo/ray_trainer.py
+++ b/verl/trainer/ppo/ray_trainer.py
@@ -625,6 +625,23 @@ def _maybe_log_val_generations(self, inputs, outputs, scores):
         # Log to each configured logger
         self.validation_generations_logger.log(self.config.trainer.logger, samples, self.global_steps)

+    def _get_gen_batch(self, batch: DataProto) -> DataProto:
+        reward_model_keys = set({"data_source", "reward_model", "extra_info", "uid"}) & batch.non_tensor_batch.keys()
+
+        # pop those keys for generation
+        batch_keys_to_pop = ["input_ids", "attention_mask", "position_ids"]
+        non_tensor_batch_keys_to_pop = set(batch.non_tensor_batch.keys()) - reward_model_keys
+        gen_batch = batch.pop(
+            batch_keys=batch_keys_to_pop,
+            non_tensor_batch_keys=list(non_tensor_batch_keys_to_pop),
+        )
+
+        # For agent loop, we need reward model keys to compute score.
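+        # These keys were excluded from the pop above, so they remain in `batch`;
+        # update() copies them (rather than moving them) into the generation batch.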
+        if self.async_rollout_mode:
+            gen_batch.non_tensor_batch.update(batch.non_tensor_batch)
+
+        return gen_batch
+
     def _validate(self):
         data_source_lst = []
         reward_extra_infos_dict: dict[str, list] = defaultdict(list)
@@ -659,23 +676,7 @@ def _validate(self):
             ]
             sample_gts.extend(ground_truths)

-            batch_keys_to_pop = ["input_ids", "attention_mask", "position_ids"]
-            non_tensor_batch_keys_to_pop = ["raw_prompt_ids"]
-            if "multi_modal_data" in test_batch.non_tensor_batch:
-                non_tensor_batch_keys_to_pop.append("multi_modal_data")
-            if "raw_prompt" in test_batch.non_tensor_batch:
-                non_tensor_batch_keys_to_pop.append("raw_prompt")
-            if "tools_kwargs" in test_batch.non_tensor_batch:
-                non_tensor_batch_keys_to_pop.append("tools_kwargs")
-            if "interaction_kwargs" in test_batch.non_tensor_batch:
-                non_tensor_batch_keys_to_pop.append("interaction_kwargs")
-            if "agent_name" in test_batch.non_tensor_batch:
-                non_tensor_batch_keys_to_pop.append("agent_name")
-            test_gen_batch = test_batch.pop(
-                batch_keys=batch_keys_to_pop,
-                non_tensor_batch_keys=non_tensor_batch_keys_to_pop,
-            )
-
+            test_gen_batch = self._get_gen_batch(test_batch)
             test_gen_batch.meta_info = {
                 "eos_token_id": self.tokenizer.eos_token_id,
                 "pad_token_id": self.tokenizer.pad_token_id,
@@ -1107,26 +1108,7 @@ def fit(self):
                     [str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object
                 )

-                # pop those keys for generation
-                batch_keys_to_pop = ["input_ids", "attention_mask", "position_ids"]
-                non_tensor_batch_keys_to_pop = ["raw_prompt_ids"]
-                if "multi_modal_data" in batch.non_tensor_batch:
-                    non_tensor_batch_keys_to_pop.append("multi_modal_data")
-                if "raw_prompt" in batch.non_tensor_batch:
-                    non_tensor_batch_keys_to_pop.append("raw_prompt")
-                if "tools_kwargs" in batch.non_tensor_batch:
-                    non_tensor_batch_keys_to_pop.append("tools_kwargs")
-                if "interaction_kwargs" in batch.non_tensor_batch:
-                    non_tensor_batch_keys_to_pop.append("interaction_kwargs")
-                if "index" in batch.non_tensor_batch:
-                    non_tensor_batch_keys_to_pop.append("index")
-                if "agent_name" in batch.non_tensor_batch:
-                    non_tensor_batch_keys_to_pop.append("agent_name")
-
-                gen_batch = batch.pop(
-                    batch_keys=batch_keys_to_pop,
-                    non_tensor_batch_keys=non_tensor_batch_keys_to_pop,
-                )
+                gen_batch = self._get_gen_batch(batch)

                 # pass global_steps to trace
                 gen_batch.meta_info["global_steps"] = self.global_steps

From 3944d5acdc046db0af1fe440337bc497f0ecacf0 Mon Sep 17 00:00:00 2001
From: wuxibin
Date: Thu, 14 Aug 2025 22:13:21 +0800
Subject: [PATCH 5/7] fix unit test

---
 tests/workers/rollout/utils_sglang.py      | 1 +
 verl/experimental/agent_loop/agent_loop.py | 2 +-
 2 files changed, 2 insertions(+), 1 deletion(-)

diff --git a/tests/workers/rollout/utils_sglang.py b/tests/workers/rollout/utils_sglang.py
index 38b4709fdcf..e4ca383fec5 100644
--- a/tests/workers/rollout/utils_sglang.py
+++ b/tests/workers/rollout/utils_sglang.py
@@ -171,6 +171,7 @@ def get_rollout_config(
             },
             "calculate_log_probs": False,
            "max_model_len": None,
+            "over_sample_rate": 0,
             **sampling_params,
         }
     )
diff --git a/verl/experimental/agent_loop/agent_loop.py b/verl/experimental/agent_loop/agent_loop.py
index 759c58e69b4..4d9dd95dcc3 100644
--- a/verl/experimental/agent_loop/agent_loop.py
+++ b/verl/experimental/agent_loop/agent_loop.py
@@ -418,7 +418,7 @@ async def _run_agent_loop(
         output: AgentLoopOutput = await agent_loop.run(sampling_params, **kwargs)

         # Some AgentLoop may have already computed the reward score, e.g., SWE-agent.
-        if output.reward_score is None:
+        if output.reward_score is None and not self.config.reward_model.enable:
             output.reward_score = await self.reward_manager_worker.compute_score.remote(output, kwargs)

         # NOTE: consistent with batch version of generate_sequences in vllm_rollout_spmd.py

From 2c952ce354d1d69f9ac1cc184cde39532fea1433 Mon Sep 17 00:00:00 2001
From: wuxibin
Date: Fri, 15 Aug 2025 10:05:51 +0800
Subject: [PATCH 6/7] fix unit test

---
 .../experimental/agent_loop/test_agent_loop_reward.py | 11 ++++++++++-
 1 file changed, 10 insertions(+), 1 deletion(-)

diff --git a/tests/experimental/agent_loop/test_agent_loop_reward.py b/tests/experimental/agent_loop/test_agent_loop_reward.py
index ef3fc962cda..f9b13e8b963 100644
--- a/tests/experimental/agent_loop/test_agent_loop_reward.py
+++ b/tests/experimental/agent_loop/test_agent_loop_reward.py
@@ -25,7 +25,16 @@


 def test_agent_loop_compute_score():
-    ray.init()
+    ray.init(
+        runtime_env={
+            "env_vars": {
+                "TOKENIZERS_PARALLELISM": "true",
+                "NCCL_DEBUG": "WARN",
+                "VLLM_LOGGING_LEVEL": "INFO",
+                "VLLM_USE_V1": "1",
+            }
+        }
+    )

     with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config")):
         config = compose("ppo_trainer")

From 1bf24099ba2db97f4046001759236067cda3e9d6 Mon Sep 17 00:00:00 2001
From: wuxibin
Date: Fri, 15 Aug 2025 14:32:12 +0800
Subject: [PATCH 7/7] fix unit test

---
 verl/experimental/agent_loop/agent_loop.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/verl/experimental/agent_loop/agent_loop.py b/verl/experimental/agent_loop/agent_loop.py
index 4d9dd95dcc3..09d5ea76777 100644
--- a/verl/experimental/agent_loop/agent_loop.py
+++ b/verl/experimental/agent_loop/agent_loop.py
@@ -126,7 +126,7 @@ class AgentLoopOutput(BaseModel):
     """Response mask, 1 for LLM generated token, 0 for tool response token."""
     multi_modal_data: Optional[dict[str, Any]] = None
     """Multi-modal data for multi-modal tools."""
-    reward_score: float = None
+    reward_score: Optional[float] = None
     """Reward score for the trajectory."""
     num_turns: int = 0
     """Number of chat turns, including user, assistant, tool."""
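
Note: a minimal standalone sketch (not part of the patches) of how `_postprocess`
in patch 1 places each scalar score on the last valid response token; it uses
`response_mask` directly where the real code slices `attention_mask` past the
prompt:

    import torch

    # Two trajectories, response length 5; 1 marks a valid response token.
    response_mask = torch.tensor([[1, 1, 1, 0, 0],
                                  [1, 1, 1, 1, 0]])
    scores = [1.0, 0.0]

    # Index of the last valid response token per sample.
    last_token = response_mask.sum(dim=1) - 1  # tensor([2, 3])
    rm_scores = torch.zeros_like(response_mask, dtype=torch.float32)
    rm_scores[torch.arange(response_mask.size(0)), last_token] = torch.tensor(scores)

    # The score sits on the last response token, so summing over the sequence
    # dimension recovers it; this is what the tests assert via rm_scores.sum(dim=1).
    print(rm_scores.sum(dim=1))  # tensor([1., 0.])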