26 changes: 25 additions & 1 deletion tests/experimental/agent_loop/agent_utils.py
@@ -19,7 +19,7 @@
 from verl.single_controller.ray import RayClassWithInitArgs, RayWorkerGroup
 from verl.single_controller.ray.base import create_colocated_worker_cls
 from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role
-from verl.workers.fsdp_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker
+from verl.workers.fsdp_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker, RewardModelWorker


@@ -30,13 +30,25 @@ def init_agent_loop_manager(config: DictConfig) -> AgentLoopManager | RayWorkerGroup:
     role_worker_mapping = {
         Role.ActorRollout: ray.remote(actor_rollout_cls),
     }
+    if config.reward_model.enable:
+        role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker)
+
     global_pool_id = "global_pool"
     resource_pool_spec = {
         global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,
     }
     mapping = {
         Role.ActorRollout: global_pool_id,
     }
+    if config.reward_model.enable_resource_pool:
+        mapping[Role.RewardModel] = "reward_pool"
+        if config.reward_model.n_gpus_per_node <= 0:
+            raise ValueError("config.reward_model.n_gpus_per_node must be greater than 0")
+        if config.reward_model.nnodes <= 0:
+            raise ValueError("config.reward_model.nnodes must be greater than 0")
+
+        reward_pool = [config.reward_model.n_gpus_per_node] * config.reward_model.nnodes
+        resource_pool_spec["reward_pool"] = reward_pool
     resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)
     resource_pool_manager.create_resource_pool()
     resource_pool_to_cls = {pool: {} for pool in resource_pool_manager.resource_pool_dict.values()}
@@ -48,6 +60,12 @@ def init_agent_loop_manager(config: DictConfig) -> AgentLoopManager | RayWorkerGroup:
     )
     resource_pool_to_cls[resource_pool]["actor_rollout"] = actor_rollout_cls
 
+    if config.reward_model.enable:
+        # we create a RM here
+        resource_pool = resource_pool_manager.get_resource_pool(Role.RewardModel)
+        rm_cls = RayClassWithInitArgs(role_worker_mapping[Role.RewardModel], config=config.reward_model)
+        resource_pool_to_cls[resource_pool]["rm"] = rm_cls
+
     all_wg = {}
     for resource_pool, class_dict in resource_pool_to_cls.items():
         worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict)
@@ -60,10 +78,16 @@ def init_agent_loop_manager(config: DictConfig) -> AgentLoopManager | RayWorkerGroup:
     if config.actor_rollout_ref.rollout.mode == "sync":
         return actor_rollout_wg
 
+    if config.reward_model.enable_resource_pool and config.reward_model.enable:
+        rm_wg = all_wg["rm"]
+        rm_wg.init_model()
+    else:
+        rm_wg = None
     # =========================== 2. Create AgentLoopManager ===========================
     agent_loop_manager = AgentLoopManager(
         config=config,
         worker_group=actor_rollout_wg,
+        rm_wg=rm_wg,
     )
 
     return agent_loop_manager
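
Note for reviewers: the dedicated reward-model pool above is driven entirely by config. Below is a minimal sketch of the overrides that exercise the new path, using only key names visible in this diff; the values are illustrative, not defaults from ppo_trainer.yaml.

# Sketch: config overrides for a separate reward-model resource pool.
# Key names come from the diff above; values here are illustrative only.
from omegaconf import OmegaConf

overrides = OmegaConf.create(
    {
        "reward_model": {
            "enable": True,                # register RewardModelWorker under Role.RewardModel
            "enable_resource_pool": True,  # place it in its own "reward_pool"
            "n_gpus_per_node": 1,          # must be > 0, otherwise init raises ValueError
            "nnodes": 1,                   # must be > 0, otherwise init raises ValueError
        },
        "trainer": {"n_gpus_per_node": 4, "nnodes": 1},
    }
)
# Merged into the trainer config, init_agent_loop_manager(config) then creates
# two pools: "global_pool" (4 GPUs) for actor/rollout and "reward_pool" (1 GPU)
# for the RewardModelWorker, and passes the RM worker group to AgentLoopManager.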
96 changes: 96 additions & 0 deletions tests/experimental/agent_loop/test_agent_loop_reward_model.py
@@ -0,0 +1,96 @@
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+
+import ray
+from hydra import compose, initialize_config_dir
+from torchdata.stateful_dataloader import StatefulDataLoader
+from transformers import AutoTokenizer
+
+from tests.experimental.agent_loop.agent_utils import init_agent_loop_manager
+from verl.protocol import DataProto
+from verl.trainer.main_ppo import create_rl_sampler
+from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn
+
+
+def test_agent_loop_compute_score_with_model():
+    ray.init(
+        runtime_env={
+            "env_vars": {
+                "TOKENIZERS_PARALLELISM": "true",
+                "NCCL_DEBUG": "WARN",
+                "VLLM_LOGGING_LEVEL": "INFO",
+                "VLLM_USE_V1": "1",
+            }
+        }
+    )
+
+    with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config")):
+        config = compose("ppo_trainer")
+
+    model_path = "Qwen/Qwen2.5-1.5B-Instruct"
+    config.data.return_raw_chat = True
+    config.actor_rollout_ref.model.path = model_path
+    config.actor_rollout_ref.actor.use_dynamic_bsz = True
+    config.actor_rollout_ref.rollout.name = os.environ["ROLLOUT_NAME"]
+    config.actor_rollout_ref.rollout.mode = "async"
+    config.actor_rollout_ref.rollout.prompt_length = 1024
+    config.actor_rollout_ref.rollout.response_length = 4096
+    config.reward_model.enable = True
+    config.reward_model.model.path = model_path
+    config.reward_model.use_dynamic_bsz = True
+    config.reward_model.forward_max_token_len_per_gpu = 6000
+    config.reward_model.micro_batch_size_per_gpu = 40
+    config.reward_model.enable_resource_pool = True
+    config.reward_model.n_gpus_per_node = 1
+    config.reward_model.nnodes = 1
+    config.reward_model.model.trust_remote_code = True
+    config.reward_model.model.input_tokenizer = None
+    config.trainer.n_gpus_per_node = 4
+    config.trainer.nnodes = 1
+    # 1. init agent loop manager
+    agent_loop_manager = init_agent_loop_manager(config)
+
+    # 2. init dataset and dataloader
+    local_folder = os.path.expanduser("~/verl-data/gsm8k/")
+    data_files = [os.path.join(local_folder, "train.parquet")]
+    tokenizer = AutoTokenizer.from_pretrained(model_path)
+
+    dataset = RLHFDataset(
+        data_files=data_files,
+        tokenizer=tokenizer,
+        config=config.data,
+        processor=None,
+    )
+
+    batch_size = 128
+    sampler = create_rl_sampler(config.data, dataset)
+    dataloader = StatefulDataLoader(
+        dataset=dataset,
+        batch_size=batch_size,
+        num_workers=config.data.dataloader_num_workers,
+        drop_last=True,
+        collate_fn=collate_fn,
+        sampler=sampler,
+    )
+
+    # 3. generate_sequences with agent loop
+    batch_dict = next(iter(dataloader))
+    batch = DataProto.from_single_dict(batch_dict)
+    gen_batch = agent_loop_manager.generate_sequences(prompts=batch)
+
+    rm_scores = gen_batch.batch["rm_scores"]
+    sample_scores = rm_scores.sum(dim=1)
+    print(sample_scores)
+    ray.shutdown()
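
For anyone trying this locally: the test reads the rollout backend from the ROLLOUT_NAME environment variable and expects the GSM8K parquet at ~/verl-data/gsm8k/train.parquet (both preconditions are visible in the test body above; neither is created by this diff). A hedged sketch of one way to invoke it:

# Sketch: run the new test from Python. Assumes pytest is installed and the
# GSM8K dataset has already been prepared under ~/verl-data/gsm8k/.
import os
import sys

import pytest

os.environ["ROLLOUT_NAME"] = "vllm"  # assumption: vLLM backend; the CI matrix may differ
sys.exit(pytest.main(["-s", "tests/experimental/agent_loop/test_agent_loop_reward_model.py"]))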