diff --git a/examples/ppo_trainer/run_qwen2-7b_rm_legacy.sh b/examples/ppo_trainer/run_qwen2-7b_rm_legacy.sh
new file mode 100644
index 00000000000..99574a33c96
--- /dev/null
+++ b/examples/ppo_trainer/run_qwen2-7b_rm_legacy.sh
@@ -0,0 +1,63 @@
+# download datasets and models
+# python3 examples/data_preprocess/gsm8k.py
+# python3 examples/data_preprocess/math_dataset.py
+# huggingface-cli download Skywork/Skywork-Reward-V2-Llama-3.2-3B --local-dir $HOME/models/Skywork-Reward-V2-Llama-3.2-3B
+# huggingface-cli download Qwen/Qwen2.5-3B-Instruct --local-dir $HOME/models/Qwen2.5-3B-Instruct
+
+gsm8k_train_path=$HOME/data/gsm8k/train.parquet
+gsm8k_test_path=$HOME/data/gsm8k/test.parquet
+math_train_path=$HOME/data/math/train.parquet
+math_test_path=$HOME/data/math/test.parquet
+
+train_files="['$gsm8k_train_path', '$math_train_path']"
+test_files="['$gsm8k_test_path', '$math_test_path']"
+
+python3 -m verl.trainer.main_ppo \
+    algorithm.adv_estimator=gae \
+    data.train_files="$train_files" \
+    data.val_files="$test_files" \
+    data.train_batch_size=1024 \
+    data.max_prompt_length=1024 \
+    data.max_response_length=2048 \
+    data.filter_overlong_prompts=True \
+    data.truncation='error' \
+    data.return_raw_chat=True \
+    actor_rollout_ref.model.path="$HOME/models/Qwen2.5-3B-Instruct" \
+    actor_rollout_ref.actor.optim.lr=1e-6 \
+    actor_rollout_ref.model.use_remove_padding=True \
+    actor_rollout_ref.actor.optim.lr_warmup_steps_ratio=0.1 \
+    actor_rollout_ref.actor.ppo_mini_batch_size=256 \
+    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \
+    actor_rollout_ref.actor.use_kl_loss=False \
+    actor_rollout_ref.model.enable_gradient_checkpointing=True \
+    actor_rollout_ref.actor.fsdp_config.param_offload=False \
+    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
+    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \
+    actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
+    actor_rollout_ref.rollout.name=vllm \
+    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
+    critic.optim.lr=1e-5 \
+    critic.model.use_remove_padding=True \
+    critic.optim.lr_warmup_steps_ratio=0.05 \
+    critic.model.path="$HOME/models/Qwen2.5-3B-Instruct" \
+    critic.model.enable_gradient_checkpointing=True \
+    critic.ppo_micro_batch_size_per_gpu=32 \
+    critic.model.fsdp_config.param_offload=False \
+    critic.model.fsdp_config.optimizer_offload=False \
+    reward_model.enable=True \
+    reward_model.model.path="$HOME/models/Skywork-Reward-V2-Llama-3.2-3B" \
+    reward_model.use_reward_loop=False \
+    reward_model.model.use_remove_padding=True \
+    reward_model.model.fsdp_config.param_offload=True \
+    reward_model.micro_batch_size_per_gpu=32 \
+    algorithm.use_kl_in_reward=False \
+    trainer.critic_warmup=0 \
+    trainer.logger='["console","wandb"]' \
+    trainer.project_name='verl_test_qwen25_rm' \
+    trainer.val_before_train=True \
+    trainer.experiment_name='legacy_fsdp_reward_model' \
+    trainer.n_gpus_per_node=8 \
+    trainer.nnodes=1 \
+    trainer.save_freq=-1 \
+    trainer.test_freq=10 \
+    trainer.total_epochs=15 $@
diff --git a/examples/ppo_trainer/run_qwen2-7b_rm_reward_loop_colocate.sh b/examples/ppo_trainer/run_qwen2-7b_rm_reward_loop_colocate.sh
new file mode 100644
index 00000000000..9641fdcb907
--- /dev/null
+++ b/examples/ppo_trainer/run_qwen2-7b_rm_reward_loop_colocate.sh
@@ -0,0 +1,69 @@
+# download datasets and models
+# python3 examples/data_preprocess/gsm8k.py
+# python3 examples/data_preprocess/math_dataset.py
+# huggingface-cli download Skywork/Skywork-Reward-V2-Llama-3.2-3B --local-dir $HOME/models/Skywork-Reward-V2-Llama-3.2-3B
+# huggingface-cli download Qwen/Qwen2.5-3B-Instruct --local-dir $HOME/models/Qwen2.5-3B-Instruct
+
+gsm8k_train_path=$HOME/data/gsm8k/train.parquet
+gsm8k_test_path=$HOME/data/gsm8k/test.parquet
+math_train_path=$HOME/data/math/train.parquet
+math_test_path=$HOME/data/math/test.parquet
+
+train_files="['$gsm8k_train_path', '$math_train_path']"
+test_files="['$gsm8k_test_path', '$math_test_path']"
+
+python3 -m verl.trainer.main_ppo \
+    algorithm.adv_estimator=gae \
+    data.train_files="$train_files" \
+    data.val_files="$test_files" \
+    data.train_batch_size=1024 \
+    data.max_prompt_length=1024 \
+    data.max_response_length=2048 \
+    data.filter_overlong_prompts=True \
+    data.truncation='error' \
+    data.return_raw_chat=True \
+    actor_rollout_ref.model.path="$HOME/models/Qwen2.5-3B-Instruct" \
+    actor_rollout_ref.actor.optim.lr=1e-6 \
+    actor_rollout_ref.model.use_remove_padding=True \
+    actor_rollout_ref.actor.optim.lr_warmup_steps_ratio=0.1 \
+    actor_rollout_ref.actor.ppo_mini_batch_size=256 \
+    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \
+    actor_rollout_ref.actor.use_kl_loss=False \
+    actor_rollout_ref.model.enable_gradient_checkpointing=True \
+    actor_rollout_ref.actor.fsdp_config.param_offload=False \
+    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
+    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \
+    actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
+    actor_rollout_ref.rollout.name=vllm \
+    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
+    critic.optim.lr=1e-5 \
+    critic.model.use_remove_padding=True \
+    critic.optim.lr_warmup_steps_ratio=0.05 \
+    critic.model.path="$HOME/models/Qwen2.5-3B-Instruct" \
+    critic.model.enable_gradient_checkpointing=True \
+    critic.ppo_micro_batch_size_per_gpu=32 \
+    critic.model.fsdp_config.param_offload=False \
+    critic.model.fsdp_config.optimizer_offload=False \
+    reward_model.enable=True \
+    reward_model.model.path="$HOME/models/Skywork-Reward-V2-Llama-3.2-3B" \
+    reward_model.use_reward_loop=True \
+    reward_model.rollout.name=vllm \
+    reward_model.rollout.gpu_memory_utilization=0.8 \
+    reward_model.rollout.prompt_length=4096 \
+    reward_model.rollout.response_length=4096 \
+    reward_model.rollout.tensor_model_parallel_size=1 \
+    reward_model.num_workers=8 \
+    reward_model.model.use_remove_padding=True \
+    reward_model.model.fsdp_config.param_offload=True \
+    reward_model.micro_batch_size_per_gpu=32 \
+    algorithm.use_kl_in_reward=False \
+    trainer.critic_warmup=0 \
+    trainer.logger='["console","wandb"]' \
+    trainer.project_name='verl_test_qwen25_rm' \
+    trainer.val_before_train=False \
+    trainer.experiment_name='reward_loop_colocate_reward_model' \
+    trainer.n_gpus_per_node=8 \
+    trainer.nnodes=1 \
+    trainer.save_freq=-1 \
+    trainer.test_freq=10 \
+    trainer.total_epochs=15 $@
diff --git a/tests/special_e2e/ppo_trainer/run_model_reward.sh b/tests/special_e2e/ppo_trainer/run_model_reward.sh
index 09d6757b511..ba36a6e394b 100644
--- a/tests/special_e2e/ppo_trainer/run_model_reward.sh
+++ b/tests/special_e2e/ppo_trainer/run_model_reward.sh
@@ -79,6 +79,7 @@ python3 -m verl.trainer.main_ppo \
     critic.model.fsdp_config.param_offload=False \
     critic.model.fsdp_config.optimizer_offload=False \
     reward_model.enable=True \
+    reward_model.use_reward_loop=False \
     reward_model.ulysses_sequence_parallel_size="${SP_SIZE}" \
     reward_model.model.path="${MODEL_PATH}" \
     reward_model.model.use_remove_padding="${RM_PAD}" \
diff --git a/tests/special_e2e/run_ppo_trainer_megatron.sh b/tests/special_e2e/run_ppo_trainer_megatron.sh
index a88500aba40..f01eba4f659 100644
--- a/tests/special_e2e/run_ppo_trainer_megatron.sh
+++ b/tests/special_e2e/run_ppo_trainer_megatron.sh
@@ -244,6 +244,7 @@ python3 -m verl.trainer.main_ppo --config-path=config \
     critic.profiler.ranks=$PROFILE_RANKS \
     critic.profiler.all_ranks=$PROFILE_RANKS_ALL \
     reward_model.enable=True \
+    reward_model.use_reward_loop=False \
     reward_model.model.path="${MODEL_PATH}" \
     reward_model.micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \
     reward_model.megatron.use_mbridge=${USE_MBRIDGE} \
diff --git a/verl/experimental/reward/reward_manager.py b/verl/experimental/reward/reward_manager.py
index 52e7403ab6e..c5c95c80946 100644
--- a/verl/experimental/reward/reward_manager.py
+++ b/verl/experimental/reward/reward_manager.py
@@ -136,6 +136,14 @@ async def _preprocess_reward_inputs(self, data: DataProto) -> str:
             add_generation_prompt=False,
             tokenize=False,
         )
+
+        # the llama tokenizer adds a bos token by default
+        # this workaround can be removed in vllm >= 0.11.2, where "add_special_tokens": False can be passed
+        if self.reward_model_tokenizer.bos_token is not None and rm_prompt.startswith(
+            self.reward_model_tokenizer.bos_token
+        ):
+            rm_prompt = rm_prompt[len(self.reward_model_tokenizer.bos_token) :]
+
         return rm_prompt

     async def compute_score_disrm(self, data: DataProto) -> dict:
@@ -148,7 +156,7 @@ async def compute_score_disrm(self, data: DataProto) -> dict:
             "model": model_name,
             "input": disrm_prompt,
             "activation": False,
-            "add_special_tokens": False,
+            # "add_special_tokens": False,  # vllm >= 0.11.2
         }
         output = await self._post_request(payloads, "classify")
         rm_score = output["data"][-1]["probs"][-1]
@@ -187,7 +195,7 @@ def __init__(self, config: DictConfig, rm_resource_pool: RayResourcePool = None)
     def _init_reward_loop_workers(self):
         self.reward_loop_workers = []

-        num_workers = self.config.reward_model.get("num_workers", 1)
+        num_workers = self.config.reward_model.num_workers
         node_ids = [node["NodeID"] for node in ray.nodes() if node["Alive"] and node["Resources"].get("CPU", 0) > 0]

         for i in range(num_workers):
diff --git a/verl/single_controller/ray/base.py b/verl/single_controller/ray/base.py
index feb73a5430e..f1bdb553d5f 100644
--- a/verl/single_controller/ray/base.py
+++ b/verl/single_controller/ray/base.py
@@ -220,13 +220,15 @@ def split_resource_pool(
     else:
         start_bundle_idx_list = np.cumsum([0] + split_size_list[:-1])

+    # ensure resource_pool.pgs has been initialized
+    placement_groups = resource_pool.get_placement_groups()
    split_resource_pools = [
        SubRayResourcePool(
            process_on_nodes=resource_pool.store,
            use_gpu=resource_pool.use_gpu,
            name_prefix=f"{resource_pool.name_prefix}_split_{split_idx}",
            max_colocate_count=resource_pool.max_colocate_count,
-            placement_groups=resource_pool.pgs,
+            placement_groups=placement_groups,
            start_bundle_index=start_bundle_idx_list[split_idx],
            subgroup_world_size=split_size_list[split_idx],
        )
diff --git a/verl/trainer/config/_generated_ppo_megatron_trainer.yaml b/verl/trainer/config/_generated_ppo_megatron_trainer.yaml
index 89849b8f124..81bdd74fd66 100644
--- a/verl/trainer/config/_generated_ppo_megatron_trainer.yaml
+++ b/verl/trainer/config/_generated_ppo_megatron_trainer.yaml
@@ -521,7 +521,7 @@ critic:
 reward_model:
   enable: false
   enable_resource_pool: false
-  n_gpus_per_node: 0
+  n_gpus_per_node: 8
   nnodes: 0
   strategy: megatron
   model:
@@ -572,6 +572,7 @@ reward_model:
     dtype: bfloat16
     load_weight: true
   use_reward_loop: true
+  num_workers: 1
   rollout:
     _target_: verl.workers.config.RolloutConfig
     name: ???
@@ -592,9 +593,9 @@ reward_model:
     enable_chunked_prefill: true
     enable_prefix_caching: true
     disable_log_stats: true
-    skip_tokenizer_init: true
-    prompt_length: 512
-    response_length: 512
+    skip_tokenizer_init: false
+    prompt_length: 2048
+    response_length: 2048
 algorithm:
   rollout_correction:
     rollout_is: null
diff --git a/verl/trainer/config/_generated_ppo_trainer.yaml b/verl/trainer/config/_generated_ppo_trainer.yaml
index ddfe8e6aa8b..732c35a1ccd 100644
--- a/verl/trainer/config/_generated_ppo_trainer.yaml
+++ b/verl/trainer/config/_generated_ppo_trainer.yaml
@@ -455,7 +455,7 @@ critic:
 reward_model:
   enable: false
   enable_resource_pool: false
-  n_gpus_per_node: 0
+  n_gpus_per_node: 8
   nnodes: 0
   strategy: fsdp
   model:
@@ -496,6 +496,7 @@ reward_model:
       tool_config: ${oc.select:actor_rollout_ref.actor.profiler.tool_config,null}
   ulysses_sequence_parallel_size: 1
   use_reward_loop: true
+  num_workers: 1
   rollout:
     _target_: verl.workers.config.RolloutConfig
     name: ???
@@ -516,9 +517,9 @@ reward_model:
     enable_chunked_prefill: true
     enable_prefix_caching: true
     disable_log_stats: true
-    skip_tokenizer_init: true
-    prompt_length: 512
-    response_length: 512
+    skip_tokenizer_init: false
+    prompt_length: 2048
+    response_length: 2048
 algorithm:
   rollout_correction:
     rollout_is: null
diff --git a/verl/trainer/config/reward_model/dp_reward_loop.yaml b/verl/trainer/config/reward_model/dp_reward_loop.yaml
index da046b7bada..04fb106df1c 100644
--- a/verl/trainer/config/reward_model/dp_reward_loop.yaml
+++ b/verl/trainer/config/reward_model/dp_reward_loop.yaml
@@ -8,7 +8,8 @@ enable: False

 # Whether to deploy the model to a separate resource pool.
 enable_resource_pool: False
-n_gpus_per_node: 0
+n_gpus_per_node: 8
+num_workers: 1
 nnodes: 0

 model:
@@ -36,7 +37,7 @@ rollout:
   enable_chunked_prefill: true
   enable_prefix_caching: true
   disable_log_stats: true
-  skip_tokenizer_init: true
+  skip_tokenizer_init: false

-  prompt_length: 512
-  response_length: 512
\ No newline at end of file
+  prompt_length: 2048
+  response_length: 2048
\ No newline at end of file
diff --git a/verl/trainer/config/reward_model/megatron_reward_loop.yaml b/verl/trainer/config/reward_model/megatron_reward_loop.yaml
index 1169e9e915c..f99b94abcc4 100644
--- a/verl/trainer/config/reward_model/megatron_reward_loop.yaml
+++ b/verl/trainer/config/reward_model/megatron_reward_loop.yaml
@@ -8,7 +8,8 @@ enable: False

 # Whether to deploy the model to a separate resource pool.
 enable_resource_pool: False
-n_gpus_per_node: 0
+n_gpus_per_node: 8
+num_workers: 1
 nnodes: 0

 model:
@@ -36,7 +37,7 @@ rollout:
   enable_chunked_prefill: true
   enable_prefix_caching: true
   disable_log_stats: true
-  skip_tokenizer_init: true
+  skip_tokenizer_init: false

-  prompt_length: 512
-  response_length: 512
\ No newline at end of file
+  prompt_length: 2048
+  response_length: 2048
\ No newline at end of file
diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py
index e439a76d361..ed4df5cc2d0 100644
--- a/verl/trainer/ppo/ray_trainer.py
+++ b/verl/trainer/ppo/ray_trainer.py
@@ -323,7 +323,10 @@ def __init__(
         self.role_worker_mapping = role_worker_mapping
         self.resource_pool_manager = resource_pool_manager
         self.use_reference_policy = need_reference_policy(self.role_worker_mapping)
+        # legacy reward model implementation
         self.use_rm = need_reward_model(self.role_worker_mapping)
+        self.use_reward_loop = self.config.reward_model.use_reward_loop
+
         self.use_critic = need_critic(self.config)
         self.ray_worker_group_cls = ray_worker_group_cls
         self.device_name = device_name if device_name else self.config.trainer.device
@@ -711,11 +714,37 @@ def init_workers(self):
             self.resource_pool_to_cls[resource_pool][str(Role.RefPolicy)] = ref_policy_cls

         # create a reward model if reward_fn is None
-        if self.use_rm:
-            # we create a RM here
-            resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel)
-            rm_cls = RayClassWithInitArgs(self.role_worker_mapping[Role.RewardModel], config=self.config.reward_model)
-            self.resource_pool_to_cls[resource_pool][str(Role.RewardModel)] = rm_cls
+        # for the legacy discriminative reward model, we create a reward model worker here
+        # for the reward-loop discriminative reward model, we create a reward loop manager here
+        if not self.use_reward_loop:
+            # the legacy path only handles the reward-model-based scenario
+            if self.use_rm:
+                # we create a RM here
+                resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel)
+                rm_cls = RayClassWithInitArgs(
+                    self.role_worker_mapping[Role.RewardModel], config=self.config.reward_model
+                )
+                self.resource_pool_to_cls[resource_pool][str(Role.RewardModel)] = rm_cls
+        else:
+            # the reward loop handles hybrid reward scenarios (rule, disrm, genrm, ...)
+            can_reward_loop_parallelize = self.config.actor_rollout_ref.rollout.mode == "async" and (
+                not self.use_rm or self.config.reward_model.enable_resource_pool
+            )
+            # decide whether reward model scoring can run asynchronously in parallel with actor rollout
+            # two conditions allow the reward model to be parallelized with actor rollout:
+            # 1. the reward model is not enabled (rule-based rewards can always be parallelized)
+            # 2. the reward model is enabled but runs on an extra resource pool
+            # if we cannot parallelize, fall back to synchronous mode and launch a reward loop manager here
+            # otherwise, in parallel mode, a reward worker is launched per rollout worker (in the agent loop, not here)
+            if not can_reward_loop_parallelize:
+                from verl.experimental.reward import RewardLoopManager
+
+                self.config.reward_model.n_gpus_per_node = self.config.trainer.n_gpus_per_node
+                resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel)
+                self.reward_loop_manager = RewardLoopManager(
+                    config=self.config,
+                    rm_resource_pool=resource_pool,
+                )

         # initialize WorkerGroup
         # NOTE: if you want to use a different resource pool for each role, which can support different parallel size,
@@ -764,7 +793,7 @@ def init_workers(self):

         self.rm_wg = None
         # initalization of rm_wg will be deprecated in the future
-        if self.use_rm:
+        if self.use_rm and not self.use_reward_loop:
             self.rm_wg = all_wg[str(Role.RewardModel)]
             self.rm_wg.init_model()

@@ -923,7 +952,7 @@ def _start_profiling(self, do_profile: bool) -> None:
             self.ref_policy_wg.start_profile(profile_step=self.global_steps)
         if self.use_critic:
             self.critic_wg.start_profile(profile_step=self.global_steps)
-        if self.use_rm:
+        if self.use_rm and not self.use_reward_loop:
             self.rm_wg.start_profile(profile_step=self.global_steps)

     def _stop_profiling(self, do_profile: bool) -> None:
@@ -934,7 +963,7 @@ def _stop_profiling(self, do_profile: bool) -> None:
             self.ref_policy_wg.stop_profile()
         if self.use_critic:
             self.critic_wg.stop_profile()
-        if self.use_rm:
+        if self.use_rm and not self.use_reward_loop:
             self.rm_wg.stop_profile()

     def _balance_batch(self, batch: DataProto, metrics, logging_prefix="global_seqlen", keep_minibatch=False):
@@ -1085,7 +1114,11 @@ def fit(self):
                         # compute reward model score on batch
                         rm_scores = None
                         if self.use_rm and "rm_scores" not in batch.batch.keys():
-                            rm_scores = self.rm_wg.compute_rm_score(batch)
+                            if not self.use_reward_loop:
+                                rm_scores = self.rm_wg.compute_rm_score(batch)
+                            else:
+                                assert self.reward_loop_manager is not None, "RewardLoopManager is None"
+                                rm_scores = self.reward_loop_manager.compute_rm_score(batch)
                             batch = batch.union(rm_scores)
                         reward_baseline_tensor, _ = compute_reward(batch, self.reward_fn)
                         reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1)
@@ -1117,7 +1150,11 @@ def fit(self):
                 with marked_timer("reward", timing_raw, color="yellow"):
                     # compute reward model score
                     if self.use_rm and "rm_scores" not in batch.batch.keys():
-                        reward_tensor = self.rm_wg.compute_rm_score(batch)
+                        if not self.use_reward_loop:
+                            reward_tensor = self.rm_wg.compute_rm_score(batch)
+                        else:
+                            assert self.reward_loop_manager is not None, "RewardLoopManager is None"
+                            reward_tensor = self.reward_loop_manager.compute_rm_score(batch)
                         batch = batch.union(reward_tensor)

                     if self.config.reward_model.launch_reward_fn_async:
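
Reviewer note: the dispatch this patch adds to `RayPPOTrainer` comes down to two switches. `reward_model.use_reward_loop` picks the reward-loop path over the legacy `rm_wg` worker group, and the loop can only run in parallel with actor rollout when rollout is async and the reward model is either disabled or placed on its own resource pool. The snippet below is a minimal, standalone sketch of that decision logic; the class and function names are hypothetical stand-ins for illustration, not verl APIs.

```python
from dataclasses import dataclass


@dataclass
class RewardSetup:
    """Hypothetical stand-in for the relevant trainer config fields."""

    use_reward_loop: bool       # reward_model.use_reward_loop
    use_rm: bool                # a discriminative reward model worker is requested
    rollout_mode: str           # actor_rollout_ref.rollout.mode: "sync" or "async"
    enable_resource_pool: bool  # reward_model.enable_resource_pool


def reward_dispatch(cfg: RewardSetup) -> str:
    """Mirror the branch added in init_workers(): which reward path gets constructed."""
    if not cfg.use_reward_loop:
        # legacy path: a dedicated RM worker group (rm_wg) scores batches synchronously
        return "legacy_rm_worker_group" if cfg.use_rm else "rule_based_reward_fn"
    # reward-loop path: parallel with rollout only if rollout is async and the RM
    # either is not used or lives on its own resource pool
    can_parallelize = cfg.rollout_mode == "async" and (not cfg.use_rm or cfg.enable_resource_pool)
    return "reward_loop_per_rollout_worker" if can_parallelize else "synchronous_reward_loop_manager"


if __name__ == "__main__":
    # colocated reward-loop RM (as in run_qwen2-7b_rm_reward_loop_colocate.sh): no extra
    # resource pool, so the trainer falls back to a synchronous RewardLoopManager
    print(reward_dispatch(RewardSetup(True, True, "sync", False)))
    # rule-based reward with async rollout can be scored in parallel with generation
    print(reward_dispatch(RewardSetup(True, False, "async", False)))
```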