Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions verl/single_controller/base/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ class Execute(DynamicEnum):
def init_predefined_execute_mode():
    """Register the predefined execute modes on the ``Execute`` dynamic enum.

    Registration order determines enum values, so new modes are appended last.
    """
    Execute.register("ALL")
    Execute.register("RANK_ZERO")
    Execute.register("ALL_MULTITHREAD")


# Initialize the two Dynamic Enum Classes
def get_predefined_execute_fn(execute_mode):
    """Map an execute mode to the name of the worker-group method that implements it.

    Note that here we only ask execute_all and execute_rank_zero to be implemented.
    Leave the choice of how these two functions handle argument 'blocking' to users.

    Args:
        execute_mode: An ``Execute`` enum member, or its name as a
            case-insensitive string (e.g. ``"all_multithread"``).

    Returns:
        dict with key ``"execute_fn_name"`` naming the worker-group method.

    Raises:
        ValueError: if a string ``execute_mode`` does not name a registered mode.
        KeyError: if an ``Execute`` member has no registered mapping.
    """
    if isinstance(execute_mode, str):
        try:
            execute_mode = Execute[execute_mode.upper()]
        except KeyError as e:
            # Chain the original KeyError so the lookup failure isn't hidden.
            raise ValueError(f"Unknown execute_mode: {execute_mode}") from e

    predefined_execute_mode_fn = {
        Execute.ALL: {"execute_fn_name": "execute_all"},
        Execute.RANK_ZERO: {"execute_fn_name": "execute_rank_zero"},
        Execute.ALL_MULTITHREAD: {"execute_fn_name": "execute_all_multithread_submit"},
    }
    return predefined_execute_mode_fn[execute_mode]

Expand Down
149 changes: 147 additions & 2 deletions verl/single_controller/ray/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
import logging
import time
from copy import deepcopy
from typing import Any, Optional
from typing import Any, Optional, Dict
from concurrent.futures import ThreadPoolExecutor
import threading

import ray
from ray.experimental.state.api import get_actor
Expand Down Expand Up @@ -45,7 +47,24 @@ class Functor:
def __call__(this, *args, **kwargs):
args, kwargs = dispatch_fn(self, *args, **kwargs)
padding_count = kwargs.pop(_padding_size_key, 0)
output = execute_fn(method_name, *args, **kwargs)
time_start = time.time()
# Check for external custom execute_mode
custom_execute_mode = getattr(self, 'custom_execute_mode', None)
if custom_execute_mode:
try:
from verl.single_controller.base.decorator import get_predefined_execute_fn
execute_config = get_predefined_execute_fn(custom_execute_mode)
custom_execute_fn = getattr(self, execute_config["execute_fn_name"])
output = custom_execute_fn(method_name, *args, **kwargs)
print(f"[EXECUTE MODE] Using custom {custom_execute_mode} for {method_name}")
except Exception as e:
print(f"[EXECUTE MODE ERROR] Failed to use custom {custom_execute_mode}, falling back to default: {e}")
Comment on lines +60 to +61
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Catching a broad Exception can hide underlying issues and make debugging difficult. It's better to catch more specific exceptions that you expect to handle, such as AttributeError if a method is not found, or ValueError from get_predefined_execute_fn. If you must catch a broad exception, consider logging the full traceback for better diagnostics.

output = execute_fn(method_name, *args, **kwargs)
else:
output = execute_fn(method_name, *args, **kwargs)

time_end = time.time()
print(f"[REMOTE EXECUTION] {method_name} on {len(self._workers)} workers submitted in {time_end - time_start:.4f} seconds")
if blocking:
output = ray.get(output)
output = collect_fn(self, output)
Expand Down Expand Up @@ -674,6 +693,58 @@ def execute_all_async(self, method_name: str, *args, **kwargs):

return [self._execute_remote_single_worker(worker, method_name, *args, **kwargs) for worker in self._workers]

def _make_remote_call_parallel(self, worker, method_name: str, *args, **kwargs):
"""Helper function to make remote call (similar to _execute_remote_single_worker but without timing)."""
if self.fused_worker_used and method_name not in self.method_names:
remote_call = getattr(worker, self.fused_worker_execute_fn_name)
return remote_call.remote(f"{self.sub_cls_name}_fwmn_{method_name}", *args, **kwargs)
else:
remote_call = getattr(worker, method_name)
return remote_call.remote(*args, **kwargs)

def execute_all_multithread_submit(self, method_name: str, *args, **kwargs):
    """Execute a method on all workers asynchronously with parallel submission.

    This method submits all remote calls in parallel using the shared global
    thread pool, then returns remote object references (similar to
    ``execute_all_async``). If the global thread pool is unavailable for any
    reason, it degrades to ``execute_all_async``.

    Args:
        method_name: Name of the method to execute
        *args: Positional arguments for the method
        **kwargs: Keyword arguments for the method

    Returns:
        List of remote object references to the method executions
    """
    # Handle argument slicing like execute_all_async
    length = len(self._workers)
    # Check if global thread pool is available (only initialized for multithread mode)
    try:
        thread_pool = global_thread_pool_manager.get_thread_pool()
    except Exception as e:
        # Broad catch is deliberate: any failure here must degrade to the safe
        # sequential path. Print the full traceback so the root cause (config
        # error, missing manager, ...) is not hidden by the fallback.
        import traceback
        traceback.print_exc()
        print(f"[WARNING] Global thread pool not available, falling back to sync execution: {e}")
        return self.execute_all_async(method_name, *args, **kwargs)

    # Mirror execute_all_async's convention: when every positional and keyword
    # argument is a list of length == number of workers, dispatch element i to
    # worker i; otherwise broadcast the same arguments to every worker.
    if all(isinstance(arg, list) for arg in args) and all(isinstance(kwarg, list) for kwarg in kwargs.values()):
        if all(len(arg) == length for arg in args) and all(len(kwarg) == length for kwarg in kwargs.values()):
            # Submit all remote calls in parallel with sliced args.
            # i=i binds the loop index eagerly (late-binding closure pitfall).
            futures = [
                thread_pool.submit(
                    lambda i=i: self._make_remote_call_parallel(
                        self._workers[i], method_name,
                        *(arg[i] for arg in args),
                        **{k: v[i] for k, v in kwargs.items()}
                    )
                ) for i in range(length)
            ]
            # result() only waits for the submission, returning the ObjectRef.
            return [future.result() for future in futures]

    futures = [
        thread_pool.submit(self._make_remote_call_parallel, worker, method_name, *args, **kwargs)
        for worker in self._workers
    ]
    return [future.result() for future in futures]

@property
def master_address(self):
return self._master_addr
Expand All @@ -689,6 +760,19 @@ def workers(self):
@property
def world_size(self):
return self._world_size

def set_execute_mode(self, execute_mode: str):
    """Set custom execution mode for subsequent dispatched calls.

    Args:
        execute_mode: Execution mode, options: "multithread", "sync", "async", "rank_zero"
    """
    self.custom_execute_mode = execute_mode
    message = "[WORKER GROUP] Set execute_mode to: " + str(execute_mode)
    print(message)

def get_execute_mode(self) -> str:
    """Return the current execution mode, or 'default' if none was set."""
    try:
        return self.custom_execute_mode
    except AttributeError:
        return 'default'


"""
Expand Down Expand Up @@ -906,3 +990,64 @@ def create_colocated_worker_cls_fused(class_dict: dict[str, RayClassWithInitArgs
cia.fused_worker_used = True

return cia


class GlobalThreadPoolManager:
    """Process-wide singleton owning a shared ThreadPoolExecutor for remote-call submission."""

    _instance = None
    _lock = threading.Lock()

    def __new__(cls):
        # Double-checked locking: the unlocked fast path avoids contention once
        # the singleton exists; the locked re-check prevents duplicate creation.
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        # __init__ runs on every GlobalThreadPoolManager() call (singleton
        # __new__ returns the same object), so guard against re-initialization.
        if not hasattr(self, '_initialized'):
            self._thread_pool = None
            self._max_workers = 8  # Default max workers
            self._initialized = True

    def get_thread_pool(self, max_workers: Optional[int] = None) -> ThreadPoolExecutor:
        """Get or create the global thread pool.

        Note: ``max_workers`` only takes effect on the call that actually
        creates the pool; once created, the pool size is fixed until
        ``shutdown`` is called.
        """
        if max_workers is not None:
            self._max_workers = max_workers

        if self._thread_pool is None:
            with self._lock:
                if self._thread_pool is None:
                    self._thread_pool = ThreadPoolExecutor(
                        max_workers=self._max_workers,
                        thread_name_prefix="global_ray_remote"
                    )
                    print(f"[GLOBAL THREAD POOL] Created with {self._max_workers} workers")

        return self._thread_pool

    def shutdown(self):
        """Shutdown the global thread pool, waiting for queued work to complete."""
        if self._thread_pool is not None:
            with self._lock:
                if self._thread_pool is not None:
                    self._thread_pool.shutdown(wait=True)
                    self._thread_pool = None
                    print("[GLOBAL THREAD POOL] Shutdown complete")

    def get_stats(self) -> Dict[str, Any]:
        """Get thread pool statistics.

        NOTE(review): ``_threads`` and ``_work_queue`` are CPython
        implementation details of ThreadPoolExecutor with no public
        equivalent; both accesses are guarded so a stdlib change degrades to
        0 instead of raising.
        """
        if self._thread_pool is None:
            return {"status": "not_initialized"}

        return {
            "status": "active",
            "max_workers": self._max_workers,
            "active_threads": len(getattr(self._thread_pool, '_threads', ())),
            "queue_size": self._thread_pool._work_queue.qsize() if hasattr(self._thread_pool, '_work_queue') else 0
        }


# Module-level singleton instance shared by all worker groups for
# multithreaded remote-call submission.
global_thread_pool_manager = GlobalThreadPoolManager()
12 changes: 12 additions & 0 deletions verl/trainer/config/ppo_trainer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,18 @@ trainer:
# Timeout (in seconds) for Ray worker to wait for registration
ray_wait_register_center_timeout: 300

# Execution mode for Ray worker groups
# Options: "all"(default), "all_multithread", "rank_zero"
# "all": Execute all workers asynchronously
# "all_multithread": Use ThreadPoolExecutor for parallel submission
# "rank_zero": Execute only on rank 0
execute_mode: all

# Thread pool size for all_multithread execution mode
# If null, will be calculated as min(8, max(2, total_gpus))
# Recommended: set to worker count to avoid blocking
execute_thread_pool_size: null

# Device to run training on (e.g., "cuda", "cpu")
device: cuda

Expand Down
4 changes: 4 additions & 0 deletions verl/trainer/config/reward_model/reward_model.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ reward_manager: naive
# custom reward function executed async on CPU, during log_prob
launch_reward_fn_async: False

# Whether to launch the custom reward function in a separate thread during log_prob,
# keeping the batch read-only while the reward is being computed
launch_reward_fn_sub_thread: False

# Cloud/local sandbox fusion configuration for custom reward logic
sandbox_fusion:

Expand Down
Loading