-
Notifications
You must be signed in to change notification settings - Fork 3.6k
[ray]{feat}: 1) add multithread execution mode and global thread pool management; 2) add launch_reward_fn_sub_thread #2861
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
501915a
1e79170
8dc0867
7434cee
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -16,7 +16,9 @@ | |
| import logging | ||
| import time | ||
| from copy import deepcopy | ||
| from typing import Any, Optional | ||
| from typing import Any, Optional, Dict | ||
| from concurrent.futures import ThreadPoolExecutor | ||
| import threading | ||
|
|
||
| import ray | ||
| from ray.experimental.state.api import get_actor | ||
|
|
@@ -45,7 +47,24 @@ class Functor: | |
| def __call__(this, *args, **kwargs): | ||
| args, kwargs = dispatch_fn(self, *args, **kwargs) | ||
| padding_count = kwargs.pop(_padding_size_key, 0) | ||
| output = execute_fn(method_name, *args, **kwargs) | ||
| time_start = time.time() | ||
| # Check for external custom execute_mode | ||
| custom_execute_mode = getattr(self, 'custom_execute_mode', None) | ||
| if custom_execute_mode: | ||
| try: | ||
| from verl.single_controller.base.decorator import get_predefined_execute_fn | ||
| execute_config = get_predefined_execute_fn(custom_execute_mode) | ||
| custom_execute_fn = getattr(self, execute_config["execute_fn_name"]) | ||
| output = custom_execute_fn(method_name, *args, **kwargs) | ||
| print(f"[EXECUTE MODE] Using custom {custom_execute_mode} for {method_name}") | ||
| except Exception as e: | ||
| print(f"[EXECUTE MODE ERROR] Failed to use custom {custom_execute_mode}, falling back to default: {e}") | ||
| output = execute_fn(method_name, *args, **kwargs) | ||
| else: | ||
| output = execute_fn(method_name, *args, **kwargs) | ||
|
|
||
| time_end = time.time() | ||
| print(f"[REMOTE EXECUTION] {method_name} on {len(self._workers)} workers submitted in {time_end - time_start:.4f} seconds") | ||
| if blocking: | ||
| output = ray.get(output) | ||
| output = collect_fn(self, output) | ||
|
|
@@ -674,6 +693,58 @@ def execute_all_async(self, method_name: str, *args, **kwargs): | |
|
|
||
| return [self._execute_remote_single_worker(worker, method_name, *args, **kwargs) for worker in self._workers] | ||
|
|
||
| def _make_remote_call_parallel(self, worker, method_name: str, *args, **kwargs): | ||
| """Helper function to make remote call (similar to _execute_remote_single_worker but without timing).""" | ||
| if self.fused_worker_used and method_name not in self.method_names: | ||
| remote_call = getattr(worker, self.fused_worker_execute_fn_name) | ||
| return remote_call.remote(f"{self.sub_cls_name}_fwmn_{method_name}", *args, **kwargs) | ||
| else: | ||
| remote_call = getattr(worker, method_name) | ||
| return remote_call.remote(*args, **kwargs) | ||
|
|
||
| def execute_all_multithread_submit(self, method_name: str, *args, **kwargs): | ||
| """Execute a method on all workers asynchronously with parallel submission. | ||
|
|
||
| This method submits all remote calls in parallel using thread pool, | ||
| then returns remote object references (similar to execute_all_async). | ||
|
|
||
| Args: | ||
| method_name: Name of the method to execute | ||
| *args: Positional arguments for the method | ||
| **kwargs: Keyword arguments for the method | ||
|
|
||
| Returns: | ||
| List of remote object references to the method executions | ||
| """ | ||
| # Handle argument slicing like execute_all_async | ||
| length = len(self._workers) | ||
| # Check if global thread pool is available (only initialized for multithread mode) | ||
| try: | ||
| thread_pool = global_thread_pool_manager.get_thread_pool() | ||
| except Exception as e: | ||
| print(f"[WARNING] Global thread pool not available, falling back to sync execution: {e}") | ||
|
Comment on lines
+724
to
+725
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. |
||
| return self.execute_all_async(method_name, *args, **kwargs) | ||
|
|
||
| if all(isinstance(arg, list) for arg in args) and all(isinstance(kwarg, list) for kwarg in kwargs.values()): | ||
| if all(len(arg) == length for arg in args) and all(len(kwarg) == length for kwarg in kwargs.values()): | ||
| # Submit all remote calls in parallel with sliced args | ||
| futures = [ | ||
| thread_pool.submit( | ||
| lambda i=i: self._make_remote_call_parallel( | ||
| self._workers[i], method_name, | ||
| *(arg[i] for arg in args), | ||
| **{k: v[i] for k, v in kwargs.items()} | ||
| ) | ||
| ) for i in range(length) | ||
| ] | ||
| return [future.result() for future in futures] | ||
|
|
||
| futures = [ | ||
| thread_pool.submit(self._make_remote_call_parallel, worker, method_name, *args, **kwargs) | ||
| for worker in self._workers | ||
| ] | ||
| return [future.result() for future in futures] | ||
|
|
||
| @property | ||
| def master_address(self): | ||
| return self._master_addr | ||
|
|
@@ -689,6 +760,19 @@ def workers(self): | |
| @property | ||
| def world_size(self): | ||
| return self._world_size | ||
|
|
||
| def set_execute_mode(self, execute_mode: str): | ||
| """Set custom execution mode | ||
|
|
||
| Args: | ||
| execute_mode: Execution mode, options: "multithread", "sync", "async", "rank_zero" | ||
| """ | ||
| self.custom_execute_mode = execute_mode | ||
| print(f"[WORKER GROUP] Set execute_mode to: {execute_mode}") | ||
|
|
||
| def get_execute_mode(self) -> str: | ||
| """Get current execution mode""" | ||
| return getattr(self, 'custom_execute_mode', 'default') | ||
|
|
||
|
|
||
| """ | ||
|
|
@@ -906,3 +990,64 @@ def create_colocated_worker_cls_fused(class_dict: dict[str, RayClassWithInitArgs | |
| cia.fused_worker_used = True | ||
|
|
||
| return cia | ||
|
|
||
|
|
||
| class GlobalThreadPoolManager: | ||
| """Global thread pool manager for shared remote call execution""" | ||
|
|
||
| _instance = None | ||
| _lock = threading.Lock() | ||
|
|
||
| def __new__(cls): | ||
| if cls._instance is None: | ||
| with cls._lock: | ||
| if cls._instance is None: | ||
| cls._instance = super().__new__(cls) | ||
| return cls._instance | ||
|
|
||
| def __init__(self): | ||
| if not hasattr(self, '_initialized'): | ||
| self._thread_pool = None | ||
| self._max_workers = 8 # Default max workers | ||
| self._initialized = True | ||
|
|
||
| def get_thread_pool(self, max_workers: Optional[int] = None) -> ThreadPoolExecutor: | ||
| """Get or create the global thread pool""" | ||
| if max_workers is not None: | ||
| self._max_workers = max_workers | ||
|
|
||
| if self._thread_pool is None: | ||
| with self._lock: | ||
| if self._thread_pool is None: | ||
| self._thread_pool = ThreadPoolExecutor( | ||
| max_workers=self._max_workers, | ||
| thread_name_prefix="global_ray_remote" | ||
| ) | ||
| print(f"[GLOBAL THREAD POOL] Created with {self._max_workers} workers") | ||
|
|
||
| return self._thread_pool | ||
|
|
||
| def shutdown(self): | ||
| """Shutdown the global thread pool""" | ||
| if self._thread_pool is not None: | ||
| with self._lock: | ||
| if self._thread_pool is not None: | ||
| self._thread_pool.shutdown(wait=True) | ||
| self._thread_pool = None | ||
| print("[GLOBAL THREAD POOL] Shutdown complete") | ||
|
|
||
| def get_stats(self) -> Dict[str, Any]: | ||
| """Get thread pool statistics""" | ||
| if self._thread_pool is None: | ||
| return {"status": "not_initialized"} | ||
|
|
||
| return { | ||
| "status": "active", | ||
| "max_workers": self._max_workers, | ||
| "active_threads": len(self._thread_pool._threads), | ||
| "queue_size": self._thread_pool._work_queue.qsize() if hasattr(self._thread_pool, '_work_queue') else 0 | ||
|
Comment on lines
+1047
to
+1048
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Accessing private attributes (`_threads`, `_work_queue`) |
||
| } | ||
|
|
||
|
|
||
| # Global instance | ||
| global_thread_pool_manager = GlobalThreadPoolManager() | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Catching a broad `Exception` can hide underlying issues and make debugging difficult. It's better to catch more specific exceptions that you expect to handle, such as `AttributeError` if a method is not found, or `ValueError` from `get_predefined_execute_fn`. If you must catch a broad exception, consider logging the full traceback for better diagnostics.