diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py index 0fb205e98e3..f3fd798132d 100644 --- a/verl/trainer/ppo/ray_trainer.py +++ b/verl/trainer/ppo/ray_trainer.py @@ -1438,7 +1438,7 @@ def fit(self): gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch_output) else: if curr_step_profile: - self.async_rollout_manager.start_profile() + self.async_rollout_manager.start_profile(global_step=self.global_steps) gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch_output) if curr_step_profile: self.async_rollout_manager.stop_profile() diff --git a/verl/utils/profiler/config.py b/verl/utils/profiler/config.py index 732f20b8e92..4430d758698 100644 --- a/verl/utils/profiler/config.py +++ b/verl/utils/profiler/config.py @@ -36,6 +36,8 @@ def __post_init__(self) -> None: class TorchProfilerToolConfig(BaseConfig): """Torch profiler tool config.""" + step_start: int = 0 + step_end: int = -1 # options: cuda, cpu, memory, shapes, stack contents: list[str] = field(default_factory=list) discrete: bool = False @@ -43,9 +45,10 @@ class TorchProfilerToolConfig(BaseConfig): def __post_init__(self) -> None: """config validation logics go here""" + __support_contents = ["cuda", "cpu", "memory", "shapes", "stack", "profile-by-stage", "merge-profiles"] for content in self.contents: - assert content in ["cuda", "cpu", "memory", "shapes", "stack"], ( - f"Profiler contents only supports cuda, cpu, memory, shapes, stack, but gets {content}" + assert content in __support_contents, ( + f"Profiler contents only supports {__support_contents}, but gets {content}" ) assert isinstance(self.contents, list), f"Profiler contents must be of type list, got {type(self.contents)}" diff --git a/verl/workers/rollout/replica.py b/verl/workers/rollout/replica.py index e7e800622a6..728e6d69b0f 100644 --- a/verl/workers/rollout/replica.py +++ b/verl/workers/rollout/replica.py @@ -235,7 +235,7 @@ async def clear_kv_cache(self): async def start_profile(self, **kwargs): """Start profiling on the replica.""" - await asyncio.gather(*[server.start_profile.remote() for server in self.servers]) + await asyncio.gather(*[server.start_profile.remote(**kwargs) for server in self.servers]) async def stop_profile(self): """Stop profiling on the replica.""" diff --git a/verl/workers/rollout/sglang_rollout/async_sglang_server.py b/verl/workers/rollout/sglang_rollout/async_sglang_server.py index 75e79e0287e..7ec73181dd4 100644 --- a/verl/workers/rollout/sglang_rollout/async_sglang_server.py +++ b/verl/workers/rollout/sglang_rollout/async_sglang_server.py @@ -61,6 +61,87 @@ visible_devices_keyword = get_visible_devices_keyword() +class SGLangProfilerArgsBuilder: + """Builder for SGLang profiling parameters, decoupling profiler parameter logic from the core service class.""" + + def __init__( + self, + profiler_controller: DistProfiler, + rollout_config: RolloutConfig, + replica_rank: int, + ): + self.profiler_controller = profiler_controller + self.rollout_config = rollout_config + self.replica_rank = replica_rank + self.auto_stop_profiling = False + + def build_profile_args(self, **kwargs) -> dict[str, Any]: + global_step = kwargs.pop("global_step", 0) + config = self.profiler_controller.tool_config + contents = self.profiler_controller.tool_config.contents + + save_path = os.path.join( + self.rollout_config.profiler.save_path, + f"rollout_step_{global_step}", + f"agent_loop_replica_{self.replica_rank}", + ) + os.makedirs(save_path, exist_ok=True) + + profiler_tool = self.rollout_config.profiler.tool + activities: Optional[list[str]] = None + if contents and profiler_tool: + activities_tmp = [] + check_map = { + "cpu": ("CPU", "torch"), + "cuda|gpu": ("GPU", "torch"), + "MEM": ("MEM", "torch_memory"), + } + for key, (act, tool) in check_map.items(): + if any(k in contents for k in key.split("|")): + activities_tmp.append(act) + if profiler_tool != tool: + raise ValueError(f"{act} profiling requires '{tool}' (got '{profiler_tool}')") + for unsupported in ("CUDA_PROFILER", "RPD"): + if unsupported in contents: + raise NotImplementedError(f"{unsupported} profiling is not supported") + activities = activities_tmp if len(activities_tmp) > 0 else activities + + with_stack = bool(contents) and "stack" in contents + record_shapes = bool(contents) and "shapes" in contents + # Profiling by stage of Prefill or Decode + profile_by_stage = bool(contents) and "profile-by-stage" in contents + # Merge profiles from all ranks into a single trace + merge_profiles = bool(contents) and "merge-profiles" in contents + + # Rollout start step must be greater than 0 for sglang + rollout_start_step = config.step_start if config.step_end is not None else 1 + rollout_end_step = config.step_end if config.step_end is not None else -1 + rollout_num_steps = rollout_end_step - rollout_start_step + self.auto_stop_profiling = rollout_num_steps > 0 + + # num_steps must be greater than 0 or None in SGLang. + rollout_num_steps = None if rollout_num_steps <= 0 else rollout_num_steps + + if rollout_num_steps is None and profile_by_stage: + raise Exception( + "profile_by_stage requires rollout_num_steps to be set (possible limitation in sglang <= 0.5.5)" + ) + + # start_step must be greater than 0 for sglang + rollout_start_step = max(rollout_start_step, 1) + + return { + "start_step": rollout_start_step, + "num_steps": rollout_num_steps, + "activities": activities, + "with_stack": with_stack, + "record_shapes": record_shapes, + "output_dir": save_path, + "profile_by_stage": profile_by_stage, + "merge_profiles": merge_profiles, + }, self.auto_stop_profiling + + class SGLangHttpServer: """SGLang http server in single node, this is equivalent to launch server with command line: ``` @@ -395,20 +476,17 @@ async def start_profile(self, **kwargs): and self.profiler_controller.check_this_rank() and self.profiler_controller.is_discrete_mode() ): - contents = self.profiler_controller.tool_config.contents - save_path = os.path.join(self.config.profiler.save_path, f"agent_loop_replica_{self.replica_rank}") - await self.tokenizer_manager.start_profile( - output_dir=save_path, - with_stack=contents is None or "stack" in contents, - record_shapes=contents is None or "shapes" in contents, - **kwargs, - ) + profile_args, self._auto_stop_profiling = SGLangProfilerArgsBuilder( + profiler_controller=self.profiler_controller, rollout_config=self.config, replica_rank=self.replica_rank + ).build_profile_args(**kwargs) + await self.tokenizer_manager.start_profile(**profile_args) async def stop_profile(self): if ( self.profiler_controller.check_enable() and self.profiler_controller.check_this_rank() and self.profiler_controller.is_discrete_mode() + and not self._auto_stop_profiling ): await self.tokenizer_manager.stop_profile() diff --git a/verl/workers/rollout/vllm_rollout/vllm_async_server.py b/verl/workers/rollout/vllm_rollout/vllm_async_server.py index bb3c52ecb4e..4a0eb0d3fb3 100644 --- a/verl/workers/rollout/vllm_rollout/vllm_async_server.py +++ b/verl/workers/rollout/vllm_rollout/vllm_async_server.py @@ -609,6 +609,8 @@ async def sleep(self): logger.info("skip sleep in standalone mode") async def start_profile(self, **kwargs): + # TODO: Persist global_step to engine server-created file/path + kwargs.pop("global_step") if ( self.profiler_controller.check_enable() and self.profiler_controller.check_this_rank()