Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions verl/utils/profiler/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from ..import_utils import is_nvtx_available
from .performance import GPUMemoryLogger, log_gpu_memory_usage, simple_timer
from .profile import DistProfiler, DistProfilerExtension, Profiler, ProfilerConfig
from .rollout_profile import rollout_profile_args

# Select marker implementations by availability, but keep DistProfiler as our dispatcher
if is_nvtx_available():
Expand All @@ -38,4 +39,5 @@
"ProfilerConfig",
"simple_timer",
"marked_timer",
"rollout_profile_args",
]
91 changes: 91 additions & 0 deletions verl/utils/profiler/rollout_profile.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from typing import Any

from omegaconf import DictConfig, OmegaConf


def rollout_profile_args(config: DictConfig, global_step: int = 1) -> dict[str, Any]:
    """Build profiling kwargs for the configured rollout backend.

    Only the sglang backend is implemented today; the dispatch table below
    reserves room for additional backends (e.g. vllm) later.

    Args:
        config: Global Hydra ``DictConfig``; must contain the rollout section.
        global_step: Current training global step, used to keep the profile
            output of different steps in separate directories.

    Returns:
        A dict of backend-specific profiling parameters.

    Raises:
        NotImplementedError: The rollout backend has no profile builder.
        ValueError: The profiler tool is unsupported or configuration is missing.
    """
    # Backend name -> builder of that backend's profiling kwargs.
    backend_profile_builders = {
        "sglang": _get_sglang_profile_tags,
    }

    backend = config.rollout.name.lower()
    builder = backend_profile_builders.get(backend)
    if builder is None:
        raise NotImplementedError(
            f"Unsupported rollout backend: {config.rollout.name}, "
            f"currently supported: {list(backend_profile_builders.keys())}"
        )
    return builder(config, global_step)


def _get_sglang_profile_tags(config: DictConfig, global_step: int) -> dict[str, Any]:
"""Generate profiling parameters for sglang backend"""
tool_to_activities = {
"torch": ["CPU", "GPU"],
"torch_memory": ["MEM"],
"cuda": ["CUDA_PROFILER"],
"RPD": ["RPD"],
}
profiler_tool = config.rollout.profiler.tool
if profiler_tool not in tool_to_activities:
raise ValueError(
f"Unsupported profiler tool for sglang backend: {profiler_tool}, \
supported tools: {list(tool_to_activities.keys())}"
)

# Profiling by stage of Prefill or Decode
profile_by_stage = OmegaConf.select(config, "rollout.profiler.tool_config.torch.profile_by_stage", default=False)
# Merge profiles from all ranks into a single trace
merge_profiles = OmegaConf.select(config, "rollout.profiler.tool_config.torch.merge_profiles", default=False)
rollout_start_step = OmegaConf.select(config, "rollout.profiler.tool_config.torch.step_start", default=1)
rollout_end_step = OmegaConf.select(config, "rollout.profiler.tool_config.torch.step_end", default=5)
rollout_num_steps = rollout_end_step - rollout_start_step

assert rollout_start_step > 0, f"Rollout start step must be greater than 0 for sglang, but got {rollout_start_step}"
assert rollout_num_steps > 0, f"Rollout num steps must be greater than 0 for sglang, but got {rollout_num_steps}"

base_save_path = config.rollout.profiler.save_path
output_dir = os.path.join(base_save_path, f"rollout_step_{global_step}")
os.makedirs(output_dir, exist_ok=True)

return {
"start_step": rollout_start_step,
"num_steps": rollout_num_steps,
"activities": tool_to_activities[profiler_tool],
"with_stack": True,
"record_shapes": True,
"output_dir": output_dir,
"profile_by_stage": profile_by_stage,
"merge_profiles": merge_profiles,
}
10 changes: 10 additions & 0 deletions verl/workers/megatron_workers.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
GPUMemoryLogger,
ProfilerConfig,
log_gpu_memory_usage,
rollout_profile_args,
simple_timer,
)
from verl.utils.profiler.performance import reduce_timing, topk_reduce_ratio_min_max
Expand Down Expand Up @@ -684,6 +685,12 @@ async def rollout_mode(self):
aggressive_empty_cache(force_sync=True)
set_expandable_segments(False)

if self.config.rollout.profiler.enable and self._do_profile:
await self.rollout.start_profile_auto_stop(
tags=rollout_profile_args(self.config, self._profile_step),
profile_ranks=self.config.rollout.profiler.ranks,
)

if self._is_offload_param:
load_megatron_model_to_gpu(self.actor.actor_module, load_grad=False)
log_gpu_memory_usage("After load actor params during rollout_mode", logger=logger)
Expand Down Expand Up @@ -955,11 +962,14 @@ def async_calls_finalize_fn_exec(self, blocking=False):
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def start_profile(self, **kwargs) -> None:
    """Begin profiling on this rank for the current training step."""
    # Flag profiling as active first so rollout-side hooks see it,
    # then remember which training step this profile belongs to (default: 1).
    self._do_profile = True
    self._profile_step = kwargs.get("profile_step", 1)
    self.profiler.start(**kwargs)

@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def stop_profile(self) -> None:
    """End profiling on this rank for the current training step."""
    # Clear the flag before stopping so rollout-side hooks stop triggering.
    self._do_profile = False
    self.profiler.stop()

@register(dispatch_mode=Dispatch.ONE_TO_ALL)
Expand Down
42 changes: 42 additions & 0 deletions verl/workers/rollout/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,48 @@ def generate_sequences(self, prompts: DataProto) -> DataProto:
"""
raise NotImplementedError

@abstractmethod
async def start_profile_auto_stop(self, **kwargs):
    """
    Abstract method: start profiling with auto-stop (wrapper for start_profile).

    Implementations are expected to forward ``**kwargs`` to ``start_profile``
    after validating that auto-stop information is present.

    Args:
        **kwargs: Must contain a 'tags' dict with:
            - "activities": List of profiled activity types
            - "num_steps": Step count after which profiling stops automatically

    Returns:
        Any: Engine response of profiling start
    """
    pass

@abstractmethod
async def start_profile(self, tags: "dict[str, Any] | None" = None, profile_ranks: "list[int] | None" = None):
    """
    Abstract method: start profiling (only for specified dp ranks).

    Args:
        tags: Profiling config (required: "activities"; optional: "num_steps").
        profile_ranks: Target dp ranks (default: [0]).

    Returns:
        Any: Engine response of profiling start
    """
    pass

@abstractmethod
async def stop_profile(self, profile_ranks: "list[int] | None" = None):
    """
    Abstract method: stop profiling (only for specified dp ranks).

    Args:
        profile_ranks: Target dp ranks (default: [0]).

    Returns:
        Any: Engine response of profiling stop
    """
    pass


_ROLLOUT_REGISTRY = {
("vllm", "async"): "verl.workers.rollout.vllm_rollout.vLLMAsyncRollout",
Expand Down
21 changes: 21 additions & 0 deletions verl/workers/rollout/sglang_rollout/http_server_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -941,6 +941,27 @@ async def async_reward_score(
lora_path=lora_path,
)

async def start_profile(self, tags: Optional[dict[str, Any]] = None) -> dict[str, Any]:
    """Ask the sglang HTTP server to begin profiling.

    Args:
        tags (Optional[dict[str, Any]], optional): Profiling arguments,
            forwarded verbatim as the request payload. Defaults to None.

    Returns:
        Dict[str, Any]: Server response indicating profile status
    """
    payload = tags
    return await self._make_async_request("start_profile", payload=payload)

async def stop_profile(self) -> dict[str, Any]:
    """Ask the sglang HTTP server to stop profiling.

    Takes no arguments; the request carries an empty payload.

    Returns:
        Dict[str, Any]: Server response indicating profile status
    """
    return await self._make_async_request("stop_profile", payload=None)

async def abort_request(self, rid: str = "", abort_all: bool = False) -> dict[str, Any]:
"""Abort a request asynchronously.

Expand Down
72 changes: 72 additions & 0 deletions verl/workers/rollout/sglang_rollout/sglang_rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,3 +200,75 @@ async def update_weights(self, weights: Generator[tuple[str, torch.Tensor], None

if self.device_mesh["infer_tp"].get_local_rank() == 0:
await self._engine.flush_cache()

async def start_profile_auto_stop(self, **kwargs):
    """
    Start performance profiling with auto-stop functionality.

    Thin wrapper around start_profile() that enforces the presence of
    'num_steps' in tags so the engine stops profiling on its own.

    Args:
        **kwargs: Keyword arguments passed to start_profile().
            Must include 'tags' dict containing:
            - "activities": List of activity types to profile (e.g., ["cpu", "gpu"])
            - "num_steps": Number of steps after which profiling should auto-stop
    Raises:
        AssertionError: If 'tags' is missing from kwargs or lacks "num_steps".
    return: Engine response of profiling start
    """
    # Explicit raise (not `assert`): validation must survive `python -O`,
    # which strips assert statements. AssertionError type is kept for
    # backward compatibility with the documented contract.
    if "num_steps" not in kwargs.get("tags", {}):
        raise AssertionError("Missing required 'num_steps' in tags for auto-stop profiling for sglang.")
    return await self.start_profile(**kwargs)

async def start_profile(self, tags: "dict[str, Any] | None" = None, profile_ranks: "list[int] | None" = None):
    """
    Start performance profiling (only executed by specified rank processes)
    Args:
        tags: Profiling configuration tags, must contain "activities" (types of activities to profile),
            optional "num_steps" (steps for auto-stop)
        profile_ranks: List of dp ranks to perform profiling, default = [0]
    return: Engine response of profiling start
    """
    profile_ranks = profile_ranks or [0]
    tags = tags or {}
    # Only the tp-rank-0 process of a selected dp rank talks to the engine;
    # every other process is a no-op and returns None.
    if (
        self.device_mesh["infer_tp"].get_local_rank() == 0
        and self.device_mesh["dp"].get_local_rank() in profile_ranks
    ):
        assert tags.get("activities") is not None, "Please specify the activities to profile."
        await self._init_server_adapter()

        response = await self._engine.start_profile(tags=tags)
        if response:
            # Track state so stop_profile() knows whether a stop call is needed:
            # when "num_steps" is present the engine stops itself (auto-stop).
            self._profiling = True
            self._rollout_profile_auto_stop = "num_steps" in tags
        else:
            # NOTE(review): failure path stores None rather than False —
            # both are falsy so stop_profile() behaves the same either way.
            self._profiling = None
        logger.debug(f"Start profile done for rank {self.device_mesh['dp'].get_local_rank()}. Response: {response}")
        return response
    else:
        return None

async def stop_profile(self, profile_ranks: "list[int] | None" = None):
    """
    Stop performance profiling (only executed by specified rank processes)
    Args:
        profile_ranks: List of dp ranks to stop profiling, default = [0]
    return: Engine response of profiling stop, or None when this rank is not
        selected, profiling is not active, or auto-stop is in effect.
    """
    profile_ranks = profile_ranks or [0]
    # Mirror start_profile: only the tp-rank-0 process of a selected dp rank acts.
    if (
        self.device_mesh["infer_tp"].get_local_rank() == 0
        and self.device_mesh["dp"].get_local_rank() in profile_ranks
    ):
        logger.debug(f"Trying to stop rollout profile for rank {self.device_mesh['dp'].get_local_rank()}.")
        # getattr defaults guard against stop_profile() being called before any
        # start_profile(), in which case these flags were never initialized.
        # Auto-stop profiles are stopped by the engine itself, so skip them too.
        if not getattr(self, "_profiling", False) or getattr(self, "_rollout_profile_auto_stop", False):
            return None

        await self._init_server_adapter()
        response = await self._engine.stop_profile()
        if response:
            self._profiling = False
        return response
    else:
        return None