[ray] feat: 1) add multithread execution mode and global thread pool management; 2) add launch_reward_fn_sub_thread #2861
Conversation
- Added `ALL_MULTITHREAD` to the `Execute` enum.
- Updated `get_predefined_execute_fn` to return the appropriate function for the new execution mode.
- Implemented `execute_all_multithread_submit` in `RayWorkerGroup` for parallel method execution.
- Introduced `GlobalThreadPoolManager` to manage a shared thread pool across the application.
- Updated configuration files to include options for the new execution mode and thread pool size.
- Added the `launch_reward_fn_sub_thread` configuration option for asynchronous reward computation.
- Implemented the `_async_compute_reward_wrapper` method in `RayPPOTrainer` for thread-safe reward computation.
- Added conditional thread pool initialization based on the `launch_reward_fn_sub_thread` setting.
- Enhanced the reward computation flow to support three modes:
  - `launch_reward_fn_sub_thread: True` - use the local thread pool for async computation
  - `launch_reward_fn_async: True` - use a Ray remote function for async computation
  - both `False` - synchronous computation
- Added a mutual exclusivity check between `launch_reward_fn_sub_thread` and `launch_reward_fn_async`.
- Implemented proper thread pool cleanup on training completion.
- Added `execute_mode` and `execute_thread_pool_size` options in `ppo_trainer.yaml` for user configuration.
- Added `launch_reward_fn_sub_thread: False` in `reward_model.yaml` for reward computation mode control.
- Renamed `thread_pool_size` to `execute_thread_pool_size` for clarity.
- The thread pool is only initialized when `launch_reward_fn_sub_thread` is enabled, reducing resource usage.
- The reward computation wrapper ensures proper device handling in sub-threads.
- Safe thread pool shutdown with null checks to prevent errors.
- Maintains backward compatibility with existing reward computation methods.
This enhancement aims to improve performance and flexibility in distributed execution scenarios, providing both parallel worker group execution and asynchronous reward computation capabilities.
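The three reward modes listed above could be dispatched roughly as follows. This is a sketch, not the PR's actual API: `dispatch_reward` and `remote_fn` are hypothetical names, with `remote_fn` standing in for a Ray remote function's `.remote` handle.

```python
from concurrent.futures import ThreadPoolExecutor

def dispatch_reward(config, compute_reward, batch, pool=None, remote_fn=None):
    """Pick one of the three reward-computation modes described above."""
    sub_thread = config.get("launch_reward_fn_sub_thread", False)
    ray_async = config.get("launch_reward_fn_async", False)
    if sub_thread and ray_async:
        # The PR adds exactly this mutual-exclusivity check.
        raise ValueError(
            "launch_reward_fn_sub_thread and launch_reward_fn_async are mutually exclusive"
        )
    if sub_thread:
        # Local thread pool: returns a Future; call .result() when the reward is needed.
        return pool.submit(compute_reward, batch)
    if ray_async:
        # Ray remote path: remote_fn stands in for an @ray.remote function's .remote.
        return remote_fn(batch)
    # Both False: synchronous computation.
    return compute_reward(batch)
```

Because the sub-thread mode returns a `Future`, the trainer loop can overlap reward computation with the next generation step and only block when it calls `.result()`.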
Code Review
This pull request introduces a multithreaded execution mode and asynchronous reward computation, which are great for improving performance. My review focuses on improving maintainability, correctness, and robustness. Key suggestions include refactoring duplicated code, fixing a critical bug in a fallback mechanism, correcting a faulty conditional check, and improving exception handling to avoid masking errors. Addressing these points will make the new features more reliable and easier to maintain.
```python
except Exception as e:
    print(f"[EXECUTE MODE ERROR] Failed to use custom {custom_execute_mode}, falling back to default: {e}")
```
Catching a broad Exception can hide underlying issues and make debugging difficult. It's better to catch more specific exceptions that you expect to handle, such as AttributeError if a method is not found, or ValueError from get_predefined_execute_fn. If you must catch a broad exception, consider logging the full traceback for better diagnostics.
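One way to narrow the handler, sketched with a stand-in lookup (only `resolve_execute_fn` and the stub body are hypothetical; the real `get_predefined_execute_fn` lives in verl):

```python
import logging
import traceback

logger = logging.getLogger(__name__)

def get_predefined_execute_fn(mode):
    # Stand-in for verl's lookup; only "all" is known in this sketch.
    known = {"all": lambda methods: [m() for m in methods]}
    if mode not in known:
        raise ValueError(f"Unknown execute mode: {mode}")
    return known[mode]

def resolve_execute_fn(custom_execute_mode, default_fn):
    """Resolve the execute fn, catching only the failures we expect to handle."""
    try:
        return get_predefined_execute_fn(custom_execute_mode)
    except (AttributeError, ValueError) as e:
        # Expected failure modes: unknown mode name or missing method.
        logger.warning("Failed to use custom %r, falling back to default: %s", custom_execute_mode, e)
        return default_fn
    except Exception:
        # Truly unexpected: keep the full traceback so the error stays diagnosable.
        logger.error("Unexpected error resolving execute mode:\n%s", traceback.format_exc())
        raise
```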
```python
except Exception as e:
    print(f"[WARNING] Global thread pool not available, falling back to sync execution: {e}")
```
```python
"active_threads": len(self._thread_pool._threads),
"queue_size": self._thread_pool._work_queue.qsize() if hasattr(self._thread_pool, '_work_queue') else 0
```
Accessing private attributes _threads and _work_queue of ThreadPoolExecutor is fragile and can break with future Python updates. ThreadPoolExecutor does not provide a public API for these stats. While this might work now, it's a maintainability risk. Consider wrapping the executor to track tasks and provide statistics in a safer way if these stats are critical.
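A safer alternative, sketched here as a hypothetical wrapper class, tracks its own counters instead of reaching into `ThreadPoolExecutor` internals:

```python
import threading
from concurrent.futures import ThreadPoolExecutor

class InstrumentedThreadPool:
    """Wraps ThreadPoolExecutor and exposes stats via its own public state."""

    def __init__(self, max_workers):
        self._executor = ThreadPoolExecutor(max_workers=max_workers)
        self._lock = threading.Lock()
        self._pending = 0

    def submit(self, fn, *args, **kwargs):
        with self._lock:
            self._pending += 1
        future = self._executor.submit(fn, *args, **kwargs)
        # Decrement when the task finishes, whatever thread completes it.
        future.add_done_callback(self._on_done)
        return future

    def _on_done(self, _future):
        with self._lock:
            self._pending -= 1

    def stats(self):
        with self._lock:
            return {"pending_tasks": self._pending}

    def shutdown(self, wait=True):
        self._executor.shutdown(wait=wait)
```

Because the counters only use the public `submit`/`add_done_callback` API, this keeps working across Python versions.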
```python
if OmegaConf.select(self.config.trainer, "execute_mode") is not None:
    self.critic_wg.set_execute_mode(self.config.trainer.execute_mode)
self.critic_wg.init_model()

if self.use_reference_policy and not self.ref_in_actor:
    self.ref_policy_wg = all_wg["ref"]
    # Set execute mode for reference policy worker group
    if OmegaConf.select(self.config.trainer, "execute_mode") is not None:
        self.ref_policy_wg.set_execute_mode(self.config.trainer.execute_mode)
    self.ref_policy_wg.init_model()

if self.use_rm:
    self.rm_wg = all_wg["rm"]
    # Set execute mode for reward model worker group
    if OmegaConf.select(self.config.trainer, "execute_mode") is not None:
        self.rm_wg.set_execute_mode(self.config.trainer.execute_mode)
    self.rm_wg.init_model()

# we should create rollout at the end so that vllm can have a better estimation of kv cache memory
self.actor_rollout_wg = all_wg["actor_rollout"]
# Set execute mode for actor rollout worker group
if OmegaConf.select(self.config.trainer, "execute_mode") is not None:
    self.actor_rollout_wg.set_execute_mode(self.config.trainer.execute_mode)
```
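The same three-line check repeats for every worker group; it could be factored into a helper along these lines (a sketch: `apply_execute_mode` and the stub are hypothetical, while the `set_execute_mode`/`execute_mode` names come from the diff):

```python
def apply_execute_mode(worker_group, trainer_config):
    """Propagate trainer.execute_mode to a worker group when it is configured."""
    mode = trainer_config.get("execute_mode")  # OmegaConf.select in the real code
    if mode is not None:
        worker_group.set_execute_mode(mode)
    return worker_group

class WorkerGroupStub:
    """Stand-in for a RayWorkerGroup, recording the mode it receives."""
    def __init__(self):
        self.execute_mode = None

    def set_execute_mode(self, mode):
        self.execute_mode = mode
```

Each `init_model()` call site then shrinks to `apply_execute_mode(wg, cfg).init_model()`, so a future change to the check happens in one place.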
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
What does this PR do?
We found that `ray.remote()` takes a long time to submit tasks with large batches. We can use a thread pool to run these Ray task submissions in the single controller. In an experiment setting of vLLM, GRPO, geo3k, and 8 H100s, we can speed up RL training: 1) reward_fn: 5s -> <1s; 2) execute worker (x3): 18s -> 12s. Here are some experiments of this feature.
1. REWARD FUNCTION
The `reward_compute` block is hidden in the timeline by using `launch_reward_fn_sub_thread=True`.
1.1 `launch_reward_fn_async=True` (default)
1.2 `launch_reward_fn_sub_thread=True` (new)
2. EXECUTE WORKERS
Start times of the workers differ across ranks, and no worker runs until the slowest rank has started. Multithread execute can help reduce this overhead from 18s to 12s.
2.1 `execute_all` (default)
2.2 `trainer.execute_mode=all_multithread` (new)
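The idea behind `all_multithread` can be sketched as issuing the per-worker `.remote()` submissions from a thread pool instead of a serial loop. The function name comes from the PR, but the signature and the stubs below are illustrative:

```python
from concurrent.futures import ThreadPoolExecutor

def execute_all_multithread_submit(workers, method_name, pool, *args, **kwargs):
    """Submit the same remote method on every worker from parallel threads.

    Each .remote() call can take noticeable time with large argument batches,
    so overlapping the submissions hides most of that latency.
    """
    def _submit(worker):
        return getattr(worker, method_name).remote(*args, **kwargs)

    futures = [pool.submit(_submit, w) for w in workers]
    # Preserve worker order; each result is a Ray ObjectRef in the real code.
    return [f.result() for f in futures]
```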
Checklist Before Starting
`[{modules}] {type}: {description}`
Test
API and Usage Example
# Add code snippet or script demonstrating how to use this
Design & Code Changes
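The usage section was left as a placeholder; enabling the new options might look like this (option names are taken from the PR description, while the values and file layout are illustrative):

```yaml
# ppo_trainer.yaml (excerpt)
trainer:
  execute_mode: all_multithread       # omit to keep the default execute_all
  execute_thread_pool_size: 16        # illustrative value

# reward_model.yaml (excerpt)
reward_model:
  launch_reward_fn_sub_thread: True   # mutually exclusive with launch_reward_fn_async
```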
Checklist Before Submitting
Important
Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.
- Run `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`.
- Request CI via the `ci-request` channel in the `verl` Slack workspace. (If not accessible, please try the Feishu group (飞书群).)