diff --git a/verl/single_controller/base/decorator.py b/verl/single_controller/base/decorator.py index 303d9ed9045..38c46537b93 100644 --- a/verl/single_controller/base/decorator.py +++ b/verl/single_controller/base/decorator.py @@ -66,6 +66,7 @@ class Execute(DynamicEnum): def init_predefined_execute_mode(): Execute.register("ALL") Execute.register("RANK_ZERO") + Execute.register("ALL_MULTITHREAD") # Initialize the two Dynamic Enum Classes @@ -446,9 +447,16 @@ def get_predefined_execute_fn(execute_mode): Note that here we only asks execute_all and execute_rank_zero to be implemented Leave the choice of how these two functions handle argument 'blocking' to users """ + if isinstance(execute_mode, str): + try: + execute_mode = Execute[execute_mode.upper()] + except KeyError: + raise ValueError(f"Unknown execute_mode: {execute_mode}") + predefined_execute_mode_fn = { Execute.ALL: {"execute_fn_name": "execute_all"}, Execute.RANK_ZERO: {"execute_fn_name": "execute_rank_zero"}, + Execute.ALL_MULTITHREAD: {"execute_fn_name": "execute_all_multithread_submit"}, } return predefined_execute_mode_fn[execute_mode] diff --git a/verl/single_controller/ray/base.py b/verl/single_controller/ray/base.py index e6a51453246..fb269ea0236 100644 --- a/verl/single_controller/ray/base.py +++ b/verl/single_controller/ray/base.py @@ -16,7 +16,9 @@ import logging import time from copy import deepcopy -from typing import Any, Optional +from typing import Any, Optional, Dict +from concurrent.futures import ThreadPoolExecutor +import threading import ray from ray.experimental.state.api import get_actor @@ -45,7 +47,24 @@ class Functor: def __call__(this, *args, **kwargs): args, kwargs = dispatch_fn(self, *args, **kwargs) padding_count = kwargs.pop(_padding_size_key, 0) - output = execute_fn(method_name, *args, **kwargs) + time_start = time.time() + # Check for external custom execute_mode + custom_execute_mode = getattr(self, 'custom_execute_mode', None) + if custom_execute_mode: + try: + from 
def _make_remote_call_parallel(self, worker, method_name: str, *args, **kwargs):
    """Issue ``method_name`` on one worker and return whatever ``.remote()`` returns.

    When fused workers are in use and ``method_name`` is not one of
    ``self.method_names``, the call is routed through the fused-worker entry
    point (``self.fused_worker_execute_fn_name``) with the qualified name
    ``"<sub_cls_name>_fwmn_<method_name>"`` as its first argument.

    Args:
        worker: Worker handle whose attributes expose a ``.remote`` callable.
        method_name: Name of the method to invoke on the worker.
        *args: Positional arguments forwarded to the remote call.
        **kwargs: Keyword arguments forwarded to the remote call.

    Returns:
        The value returned by ``.remote(...)``.
    """
    if self.fused_worker_used and method_name not in self.method_names:
        remote_call = getattr(worker, self.fused_worker_execute_fn_name)
        return remote_call.remote(f"{self.sub_cls_name}_fwmn_{method_name}", *args, **kwargs)
    return getattr(worker, method_name).remote(*args, **kwargs)


def execute_all_multithread_submit(self, method_name: str, *args, **kwargs):
    """Submit ``method_name`` to every worker through the shared thread pool.

    Each worker's ``.remote()`` call is issued from a pool thread so the
    submissions overlap; the caller then waits for every submission to be
    made and returns the resulting handles in worker order.

    Argument handling: if every positional and keyword value is a ``list``
    whose length equals the number of workers, worker ``i`` receives element
    ``i`` of each value. Otherwise all workers receive the arguments
    unchanged.

    Args:
        method_name: Name of the method to execute.
        *args: Positional arguments for the method.
        **kwargs: Keyword arguments for the method.

    Returns:
        List with one entry per worker holding what ``.remote()`` returned.
        An exception raised while submitting to a worker propagates from here.
    """
    try:
        thread_pool = global_thread_pool_manager.get_thread_pool()
    except Exception as e:
        # Deliberate best-effort: without a pool, submit from the calling thread.
        print(f"[WARNING] Global thread pool not available, falling back to sync execution: {e}")
        return self.execute_all_async(method_name, *args, **kwargs)

    workers = self._workers
    length = len(workers)
    # Slice only when *every* argument is a list with one element per worker.
    sliced = all(
        isinstance(value, list) and len(value) == length
        for value in (*args, *kwargs.values())
    )

    futures = []
    for rank, worker in enumerate(workers):
        if sliced:
            call_args = tuple(arg[rank] for arg in args)
            call_kwargs = {key: value[rank] for key, value in kwargs.items()}
        else:
            call_args, call_kwargs = args, kwargs
        futures.append(
            thread_pool.submit(
                self._make_remote_call_parallel, worker, method_name, *call_args, **call_kwargs
            )
        )
    # result() waits for the submission itself, not for the remote work.
    return [future.result() for future in futures]


def set_execute_mode(self, execute_mode: str):
    """Set the custom execution mode used when dispatching to this worker group.

    The value is stored as ``self.custom_execute_mode`` without validation.
    The dispatch wrapper resolves it through ``get_predefined_execute_fn``,
    which upper-cases strings and looks them up among the registered modes
    (``ALL``, ``RANK_ZERO``, ``ALL_MULTITHREAD``); if that lookup or the
    custom call fails, the wrapper falls back to the default execute function.

    Args:
        execute_mode: Execution mode, options: "all", "all_multithread", "rank_zero"
    """
    self.custom_execute_mode = execute_mode
    print(f"[WORKER GROUP] Set execute_mode to: {execute_mode}")


def get_execute_mode(self) -> str:
    """Return the custom execution mode, or ``'default'`` if none was set."""
    return getattr(self, 'custom_execute_mode', 'default')
class GlobalThreadPoolManager:
    """Singleton owner of one shared ``ThreadPoolExecutor`` for remote-call submission.

    Every ``GlobalThreadPoolManager()`` call returns the same instance. The
    executor is created lazily on the first ``get_thread_pool`` call and can
    be torn down with ``shutdown``; a later ``get_thread_pool`` call creates a
    fresh executor.
    """

    _instance = None
    # One lock guards both singleton creation and all pool state.
    _lock = threading.Lock()

    def __new__(cls):
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        # __init__ runs on every GlobalThreadPoolManager() call; only the
        # first one may set up state, otherwise a live pool would be dropped.
        with self._lock:
            if not hasattr(self, '_initialized'):
                self._thread_pool = None
                self._max_workers = 8  # Default max workers
                self._initialized = True

    def get_thread_pool(self, max_workers: Optional[int] = None) -> ThreadPoolExecutor:
        """Get or create the global thread pool.

        Args:
            max_workers: Size to use if the pool has to be created. When a
                pool already exists it is returned unchanged and a differing
                ``max_workers`` is ignored (a message is printed), so the
                size reported by ``get_stats`` always matches the live pool.

        Returns:
            The shared executor.

        Raises:
            ValueError: Propagated from ``ThreadPoolExecutor`` for a
                non-positive size; the manager's state is left untouched.
        """
        with self._lock:
            pool = self._thread_pool
            if pool is None:
                size = self._max_workers if max_workers is None else max_workers
                pool = ThreadPoolExecutor(
                    max_workers=size,
                    thread_name_prefix="global_ray_remote",
                )
                # Record the size only after the executor was built successfully.
                self._thread_pool = pool
                self._max_workers = size
                print(f"[GLOBAL THREAD POOL] Created with {size} workers")
            elif max_workers is not None and max_workers != self._max_workers:
                print(
                    f"[GLOBAL THREAD POOL] Already created with {self._max_workers} workers, "
                    f"ignoring requested size {max_workers}"
                )
            return pool

    def shutdown(self):
        """Shut down the global thread pool, waiting for queued work to finish.

        Does nothing when no pool exists.
        """
        with self._lock:
            pool, self._thread_pool = self._thread_pool, None
        if pool is None:
            return
        # Wait outside the lock so tasks still running in the pool can call
        # back into this manager without deadlocking.
        pool.shutdown(wait=True)
        print("[GLOBAL THREAD POOL] Shutdown complete")

    def get_stats(self) -> dict[str, Any]:
        """Get thread pool statistics.

        Returns:
            ``{"status": "not_initialized"}`` when no pool exists, otherwise a
            dict with ``status``, ``max_workers``, ``active_threads`` and
            ``queue_size``. The last two are read from private executor
            attributes (``_threads``, ``_work_queue``) and are best-effort.
        """
        with self._lock:
            pool = self._thread_pool
            max_workers = self._max_workers
        if pool is None:
            return {"status": "not_initialized"}

        work_queue = getattr(pool, "_work_queue", None)
        return {
            "status": "active",
            "max_workers": max_workers,
            "active_threads": len(getattr(pool, "_threads", ())),
            "queue_size": work_queue.qsize() if work_queue is not None else 0,
        }


# Global instance
global_thread_pool_manager = GlobalThreadPoolManager()
a/verl/trainer/config/ppo_trainer.yaml +++ b/verl/trainer/config/ppo_trainer.yaml @@ -321,6 +321,18 @@ trainer: # Timeout (in seconds) for Ray worker to wait for registration ray_wait_register_center_timeout: 300 + # Execution mode for Ray worker groups + # Options: "all"(default), "all_multithread", "rank_zero" + # "all": Execute all workers asynchronously + # "all_multithread": Use ThreadPoolExecutor for parallel submission + # "rank_zero": Execute only on rank 0 + execute_mode: all + + # Thread pool size for all_multithread execution mode + # If null, will be calculated as min(8, max(2, total_gpus)) + # Recommended: set to worker count to avoid blocking + execute_thread_pool_size: null + # Device to run training on (e.g., "cuda", "cpu") device: cuda diff --git a/verl/trainer/config/reward_model/reward_model.yaml b/verl/trainer/config/reward_model/reward_model.yaml index 7fd597c0fa3..b213b0ccaa7 100644 --- a/verl/trainer/config/reward_model/reward_model.yaml +++ b/verl/trainer/config/reward_model/reward_model.yaml @@ -53,6 +53,10 @@ reward_manager: naive # custom reward function executed async on CPU, during log_prob launch_reward_fn_async: False +# Whether to launch custom reward function in a new thread during log_prob +# keep batch read-only during reward computation +launch_reward_fn_sub_thread: False + # Cloud/local sandbox fusion configuration for custom reward logic sandbox_fusion: diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py index 6b88f190ede..94e96fb2756 100644 --- a/verl/trainer/ppo/ray_trainer.py +++ b/verl/trainer/ppo/ray_trainer.py @@ -27,7 +27,8 @@ from dataclasses import dataclass, field from enum import Enum from pprint import pprint -from typing import Optional +from typing import Optional, Callable +from concurrent.futures import ThreadPoolExecutor, Future import numpy as np import ray @@ -361,6 +362,26 @@ def __init__( experiment_name=self.config.trainer.experiment_name, ) + # Initialize thread pool for 
def _async_compute_reward_wrapper(self, batch: DataProto, reward_fn: Callable) -> tuple[torch.Tensor, dict]:
    """
    Run ``compute_reward`` on ``batch`` so it can be submitted to a sub-thread.

    The batch is passed through its ``to('cpu')`` method first when it has
    one, and a tensor result is moved to CPU before being returned.

    Args:
        batch: Input data
        reward_fn: Reward computation function

    Returns:
        tuple: (reward_tensor, reward_extra_infos_dict)
    """
    # NOTE(review): whether ``to`` copies or moves the batch in place is not
    # visible here - confirm before relying on the caller's batch being untouched.
    move_to = getattr(batch, 'to', None)
    if callable(move_to):
        batch = move_to('cpu')

    scores, extra_infos = compute_reward(batch, reward_fn)

    # Non-tensor results are returned as-is.
    if isinstance(scores, torch.Tensor):
        scores = scores.cpu()
    return scores, extra_infos
def print_resource_pool_stats(self):
    """Print the global thread pool statistics when running in "all_multithread" mode.

    Reads ``trainer.execute_mode`` from the config (treated as "all" when
    absent) and prints nothing for any other mode.
    """
    mode = OmegaConf.select(self.config.trainer, "execute_mode", default="all")
    if mode != "all_multithread":
        return

    # Imported here, as in the rest of this file, only when the mode needs it.
    from verl.single_controller.ray.base import global_thread_pool_manager

    pool_stats = global_thread_pool_manager.get_stats()
    print(f"[THREAD POOL STATS] {pool_stats}")
rule-based rm reward_extra_infos_dict: dict[str, list] - if self.config.reward_model.launch_reward_fn_async: + if self.config.reward_model.launch_reward_fn_sub_thread: + # Get result from sub-thread + reward_tensor, reward_extra_infos_dict = future_reward.result() + elif self.config.reward_model.launch_reward_fn_async: + # Get result from Ray remote function reward_tensor, reward_extra_infos_dict = ray.get(future_reward) + else: + # reward_tensor and reward_extra_infos_dict are already computed synchronously + pass batch.batch["token_level_scores"] = reward_tensor if reward_extra_infos_dict: @@ -1380,6 +1465,12 @@ def fit(self): if is_last_step: pprint(f"Final validation metrics: {last_val_metrics}") progress_bar.close() + # Clean up reward thread pool if it exists + if self.reward_thread_pool is not None: + self.reward_thread_pool.shutdown(wait=True) + # Clean up global thread pool + from verl.single_controller.ray.base import global_thread_pool_manager + global_thread_pool_manager.shutdown() return # this is experimental and may be changed/removed in the future