diff --git a/python/sglang/multimodal_gen/configs/sample/sampling_params.py b/python/sglang/multimodal_gen/configs/sample/sampling_params.py index da728169f70f..3a7a8ad49e87 100644 --- a/python/sglang/multimodal_gen/configs/sample/sampling_params.py +++ b/python/sglang/multimodal_gen/configs/sample/sampling_params.py @@ -126,7 +126,8 @@ class SamplingParams: # Profiling profile: bool = False - num_profiled_timesteps: int = 2 + num_profiled_timesteps: int = 5 + profile_all_stages: bool = False # Debugging debug: bool = False @@ -226,7 +227,7 @@ def _adjust( if pipeline_config.task_type.is_image_gen(): # settle num_frames - logger.debug(f"Setting num_frames to 1 because this is a image-gen model") + logger.debug(f"Setting num_frames to 1 because this is an image-gen model") self.num_frames = 1 self.data_type = DataType.IMAGE else: @@ -329,24 +330,35 @@ def add_cli_args(parser: Any) -> Any: action="store_true", default=SamplingParams.enable_teacache, ) + + # profiling parser.add_argument( "--profile", action="store_true", default=SamplingParams.profile, help="Enable torch profiler for denoising stage", ) - parser.add_argument( - "--debug", - action="store_true", - default=SamplingParams.debug, - help="", - ) parser.add_argument( "--num-profiled-timesteps", type=int, default=SamplingParams.num_profiled_timesteps, help="Number of timesteps to profile after warmup", ) + parser.add_argument( + "--profile-all-stages", + action="store_true", + dest="profile_all_stages", + default=SamplingParams.profile_all_stages, + help="Used with --profile, profile all pipeline stages", + ) + + parser.add_argument( + "--debug", + action="store_true", + default=SamplingParams.debug, + help="", + ) + parser.add_argument( "--prompt", type=str, diff --git a/python/sglang/multimodal_gen/runtime/pipelines_core/executors/parallel_executor.py b/python/sglang/multimodal_gen/runtime/pipelines_core/executors/parallel_executor.py index dabe531c666f..52d85d32be18 100644 --- 
a/python/sglang/multimodal_gen/runtime/pipelines_core/executors/parallel_executor.py +++ b/python/sglang/multimodal_gen/runtime/pipelines_core/executors/parallel_executor.py @@ -8,6 +8,7 @@ from sglang.multimodal_gen.runtime.distributed.parallel_state import ( get_cfg_group, get_classifier_free_guidance_rank, + get_world_rank, ) from sglang.multimodal_gen.runtime.pipelines_core import Req from sglang.multimodal_gen.runtime.pipelines_core.executors.pipeline_executor import ( @@ -20,6 +21,9 @@ ) from sglang.multimodal_gen.runtime.server_args import ServerArgs from sglang.multimodal_gen.runtime.utils.distributed import broadcast_pyobj +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) class ParallelExecutor(PipelineExecutor): @@ -48,14 +52,16 @@ def collect_from_main(self, batches: list[Req]): src=self.worker.cfg_group.ranks[0], ) - def execute( + def _execute( self, stages: List[PipelineStage], batch: Req, server_args: ServerArgs, ) -> Req: + """ + Execute all pipeline stages respecting their declared parallelism type. 
+ """ rank = get_classifier_free_guidance_rank() - cfg_rank = get_classifier_free_guidance_rank() cfg_group = get_cfg_group() # TODO: decide when to gather on main when CFG_PARALLEL -> MAIN_RANK_ONLY @@ -65,14 +71,8 @@ def execute( if paradigm == StageParallelismType.MAIN_RANK_ONLY: if rank == 0: + # Only main rank executes, others just wait batch = stage(batch, server_args) - # obj_list = [batch] if rank == 0 else [] - # - # broadcasted_list = broadcast_pyobj( - # obj_list, rank=rank, dist_group=cfg_group.cpu_group, src=0 - # ) - # if rank != 0: - # batch = broadcasted_list[0] torch.distributed.barrier() elif paradigm == StageParallelismType.CFG_PARALLEL: @@ -88,5 +88,22 @@ def execute( elif paradigm == StageParallelismType.REPLICATED: batch = stage(batch, server_args) + return batch + + def execute( + self, + stages: List[PipelineStage], + batch: Req, + server_args: ServerArgs, + ) -> Req: + rank = get_classifier_free_guidance_rank() + + if batch.profile and batch.profile_all_stages: + world_rank = get_world_rank() + else: + world_rank = 0 + + with self.profile_execution(batch, check_rank=rank, dump_rank=world_rank): + batch = self._execute(stages, batch, server_args) return batch diff --git a/python/sglang/multimodal_gen/runtime/pipelines_core/executors/pipeline_executor.py b/python/sglang/multimodal_gen/runtime/pipelines_core/executors/pipeline_executor.py index 917af3203ad7..772981380bd3 100644 --- a/python/sglang/multimodal_gen/runtime/pipelines_core/executors/pipeline_executor.py +++ b/python/sglang/multimodal_gen/runtime/pipelines_core/executors/pipeline_executor.py @@ -5,14 +5,19 @@ Base class for all pipeline executors. 
""" +import contextlib from abc import ABC, abstractmethod -from typing import List +from typing import TYPE_CHECKING, List from sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import Req -from sglang.multimodal_gen.runtime.pipelines_core.stages import PipelineStage from sglang.multimodal_gen.runtime.server_args import ServerArgs from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger from sglang.multimodal_gen.runtime.utils.perf_logger import StageProfiler +from sglang.multimodal_gen.runtime.utils.profiler import SGLDiffusionProfiler + +if TYPE_CHECKING: + # Only for type checkers; avoids runtime circular import + from sglang.multimodal_gen.runtime.pipelines_core.stages.base import PipelineStage logger = init_logger(__name__) @@ -41,7 +46,7 @@ def __init__(self, server_args): @abstractmethod def execute( self, - stages: List[PipelineStage], + stages: List["PipelineStage"], batch: Req, server_args: ServerArgs, ) -> Req: @@ -57,3 +62,28 @@ def execute( The processed batch. """ raise NotImplementedError + + @contextlib.contextmanager + def profile_execution(self, batch: Req, check_rank: int = 0, dump_rank: int = 0): + """ + Context manager for profiling execution. 
+ """ + do_profile = batch.profile + + if not do_profile: + yield + return + + request_id = batch.request_id + profiler = SGLDiffusionProfiler( + request_id=request_id, + rank=check_rank, + full_profile=batch.profile_all_stages, + num_steps=batch.num_profiled_timesteps, + num_inference_steps=batch.num_inference_steps, + ) + try: + yield + finally: + should_export = check_rank == 0 + profiler.stop(export_trace=should_export, dump_rank=dump_rank) diff --git a/python/sglang/multimodal_gen/runtime/pipelines_core/executors/sync_executor.py b/python/sglang/multimodal_gen/runtime/pipelines_core/executors/sync_executor.py index 5e429b3e8768..47491a111b49 100644 --- a/python/sglang/multimodal_gen/runtime/pipelines_core/executors/sync_executor.py +++ b/python/sglang/multimodal_gen/runtime/pipelines_core/executors/sync_executor.py @@ -8,8 +8,8 @@ from sglang.multimodal_gen.runtime.pipelines_core.executors.pipeline_executor import ( PipelineExecutor, + SGLDiffusionProfiler, Timer, - logger, ) from sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import Req from sglang.multimodal_gen.runtime.pipelines_core.stages import PipelineStage @@ -21,19 +21,35 @@ class SyncExecutor(PipelineExecutor): A simple synchronous executor that runs stages sequentially. """ - def execute( + def run_profile_all_stages( self, stages: List[PipelineStage], batch: Req, server_args: ServerArgs, ) -> Req: """ - Execute the pipeline stages sequentially. + Execute all pipeline stages sequentially. """ - logger.info("Running pipeline stages sequentially with SyncExecutor.") - for stage in stages: with Timer(stage.__class__.__name__): batch = stage(batch, server_args) + profiler = SGLDiffusionProfiler.get_instance() + if profiler: + profiler.step_stage() + return batch + + def execute( + self, + stages: List[PipelineStage], + batch: Req, + server_args: ServerArgs, + ) -> Req: + """ + Execute the pipeline stages sequentially. 
+ """ + + with self.profile_execution(batch, check_rank=0, dump_rank=0): + batch = self.run_profile_all_stages(stages, batch, server_args) + return batch diff --git a/python/sglang/multimodal_gen/runtime/pipelines_core/schedule_batch.py b/python/sglang/multimodal_gen/runtime/pipelines_core/schedule_batch.py index 26bb392d6535..ec75af00fba8 100644 --- a/python/sglang/multimodal_gen/runtime/pipelines_core/schedule_batch.py +++ b/python/sglang/multimodal_gen/runtime/pipelines_core/schedule_batch.py @@ -171,7 +171,8 @@ class Req: # profile profile: bool = False - num_profiled_timesteps: int = 8 + profile_all_stages: bool = False + num_profiled_timesteps: int = None # debugging debug: bool = False diff --git a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising.py b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising.py old mode 100644 new mode 100755 index 55727461fdac..399d1bd2902d --- a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising.py +++ b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising.py @@ -7,7 +7,6 @@ import inspect import math -import os import time import weakref from collections.abc import Iterable @@ -15,7 +14,6 @@ from typing import Any import torch -import torch.profiler from einops import rearrange from tqdm.auto import tqdm @@ -35,7 +33,6 @@ from sglang.multimodal_gen.runtime.distributed.parallel_state import ( get_cfg_group, get_classifier_free_guidance_rank, - get_world_rank, ) from sglang.multimodal_gen.runtime.layers.attention.backends.flash_attn import ( FlashAttentionBackend, @@ -62,6 +59,7 @@ from sglang.multimodal_gen.runtime.server_args import ServerArgs from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger from sglang.multimodal_gen.runtime.utils.perf_logger import StageProfiler +from sglang.multimodal_gen.runtime.utils.profiler import SGLDiffusionProfiler from sglang.multimodal_gen.utils import dict_to_3d_list, masks_like try: @@ -745,57 +743,10 
@@ def _postprocess_sp_latents( trajectory_tensor = trajectory_tensor[:, :, :orig_s, :] return latents, trajectory_tensor - def start_profile(self, batch: Req): - if not batch.profile: - return - - logger.info("Starting Profiler...") - # Build activities dynamically to avoid CUDA hangs when CUDA is unavailable - activities = [torch.profiler.ProfilerActivity.CPU] - if torch.cuda.is_available(): - activities.append(torch.profiler.ProfilerActivity.CUDA) - - self.profiler = torch.profiler.profile( - activities=activities, - schedule=torch.profiler.schedule( - skip_first=0, - wait=0, - warmup=1, - active=batch.num_profiled_timesteps, - repeat=5, - ), - on_trace_ready=None, - record_shapes=True, - with_stack=True, - ) - self.profiler.start() - def step_profile(self): - if self.profiler: - if torch.cuda.is_available(): - torch.cuda.synchronize() - self.profiler.step() - - def stop_profile(self, batch: Req): - try: - if self.profiler: - logger.info("Stopping Profiler...") - if torch.cuda.is_available(): - torch.cuda.synchronize() - self.profiler.stop() - request_id = batch.request_id if batch.request_id else "profile_trace" - log_dir = f"./logs" - os.makedirs(log_dir, exist_ok=True) - - rank = get_world_rank() - trace_path = os.path.abspath( - os.path.join(log_dir, f"{request_id}-rank{rank}.trace.json.gz") - ) - logger.info(f"Saving profiler traces to: {trace_path}") - self.profiler.export_chrome_trace(trace_path) - torch.distributed.barrier() - except Exception as e: - logger.error(f"{e}") + profiler = SGLDiffusionProfiler.get_instance() + if profiler: + profiler.step_denoising_step() def _manage_device_placement( self, @@ -968,8 +919,6 @@ def forward( # Run denoising loop denoising_start_time = time.time() - self.start_profile(batch=batch) - # to avoid device-sync caused by timestep comparison timesteps_cpu = timesteps.cpu() num_timesteps = timesteps_cpu.shape[0] @@ -1069,8 +1018,6 @@ def forward( self.step_profile() - self.stop_profile(batch) - denoising_end_time = 
time.time() if num_timesteps > 0: diff --git a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising_dmd.py b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising_dmd.py index ec9dad38ab19..538bab8a5a8d 100644 --- a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising_dmd.py +++ b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising_dmd.py @@ -91,8 +91,6 @@ def forward( prompt_embeds = prepared_vars["prompt_embeds"] denoising_loop_start_time = time.time() - self.start_profile(batch=batch) - with self.progress_bar(total=len(timesteps)) as progress_bar: for i, t in enumerate(timesteps): # Skip if interrupted @@ -186,7 +184,6 @@ def forward( self.step_profile() - self.stop_profile(batch) denoising_loop_end_time = time.time() if len(timesteps) > 0: self.log_info( diff --git a/python/sglang/multimodal_gen/runtime/utils/perf_logger.py b/python/sglang/multimodal_gen/runtime/utils/perf_logger.py index 16f94255d906..ea53a88bd944 100644 --- a/python/sglang/multimodal_gen/runtime/utils/perf_logger.py +++ b/python/sglang/multimodal_gen/runtime/utils/perf_logger.py @@ -18,8 +18,11 @@ from sglang.multimodal_gen.runtime.utils.logging_utils import ( _SGLDiffusionLogger, get_is_main_process, + init_logger, ) +logger = init_logger(__name__) + @dataclasses.dataclass class RequestTimings:
diff --git a/python/sglang/multimodal_gen/runtime/utils/profiler.py b/python/sglang/multimodal_gen/runtime/utils/profiler.py
new file mode 100644
index 000000000000..ba42114e535c
--- /dev/null
+++ b/python/sglang/multimodal_gen/runtime/utils/profiler.py
@@ -0,0 +1,181 @@
+"""Thin torch.profiler wrapper shared by pipeline executors and stages."""
+
+import os
+
+import torch
+
+from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
+
+logger = init_logger(__name__)
+
+
+class SGLDiffusionProfiler:
+    """
+    A wrapper around torch.profiler to simplify usage in pipelines.
+    Supports both full profiling and scheduled profiling.
+
+    1. if profile_all_stages is on: profile all stages, including all denoising steps
+    2. otherwise: profile {num_profiled_timesteps} denoising steps; every denoising
+       step is profiled if num_profiled_timesteps is -1 or unset
+
+    Creating an instance starts profiling immediately and registers it as the
+    active profiler returned by get_instance(); stop() unregisters it.
+    """
+
+    _instance = None
+
+    def __init__(
+        self,
+        request_id: str | None = None,
+        rank: int = 0,
+        full_profile: bool = False,
+        num_steps: int | None = None,
+        num_inference_steps: int | None = None,
+        log_dir: str = "./logs",
+    ):
+        """
+        Configure torch.profiler, start it and register this instance.
+
+        Args:
+            request_id: Used in log messages and the trace file name;
+                "profile_trace" is used when it is empty or None.
+            rank: Rank used in the trace file name when stop() gets no dump_rank.
+            full_profile: If True, record without a schedule (all stages);
+                otherwise record with a per-denoising-step schedule.
+            num_steps: Number of denoising steps to record in scheduled mode;
+                -1 or None means num_inference_steps.
+            num_inference_steps: Total number of denoising steps of the request.
+            log_dir: Directory the trace file is written to.
+        """
+        self.request_id = request_id or "profile_trace"
+        self.rank = rank
+        self.log_dir = log_dir
+        self.has_stopped = False
+
+        try:
+            os.makedirs(self.log_dir, exist_ok=True)
+        except OSError as e:
+            # Non-fatal: profiling still runs, and export retries and reports errors.
+            logger.warning(f"Could not create profiler log dir {self.log_dir}: {e}")
+
+        num_actual_steps = None
+        if not full_profile:
+            # Both -1 and None (the Req default) mean "every denoising step".
+            num_actual_steps = (
+                num_inference_steps if num_steps in (None, -1) else num_steps
+            )
+            if num_actual_steps is None:
+                # Without a step count no schedule can be built; record everything.
+                logger.warning(
+                    "No step count available for scheduled profiling; "
+                    "falling back to profiling all stages"
+                )
+                full_profile = True
+        self.full_profile = full_profile
+
+        activities = [torch.profiler.ProfilerActivity.CPU]
+        # Request CUDA activity only when a CUDA device is available.
+        if torch.cuda.is_available():
+            activities.append(torch.profiler.ProfilerActivity.CUDA)
+
+        common_torch_profiler_args = dict(
+            activities=activities,
+            record_shapes=True,
+            with_stack=True,
+            on_trace_ready=None,
+        )
+        if self.full_profile:
+            # profile all stages
+            self.profiler = torch.profiler.profile(**common_torch_profiler_args)
+            self.profile_mode_id = "full stages"
+        else:
+            # profile denoising stage only
+            warmup = 1
+            num_active_steps = num_actual_steps + warmup
+            self.profiler = torch.profiler.profile(
+                **common_torch_profiler_args,
+                schedule=torch.profiler.schedule(
+                    skip_first=0,
+                    wait=0,
+                    warmup=warmup,
+                    active=num_active_steps,
+                    repeat=1,
+                ),
+            )
+            self.profile_mode_id = f"{num_actual_steps} steps"
+
+        logger.info(
+            f"Profiling request: {self.request_id} for {self.profile_mode_id}..."
+        )
+
+        # Register only after a successful start so that a failed start cannot
+        # leave a never-started profiler visible through get_instance().
+        self.start()
+        SGLDiffusionProfiler._instance = self
+
+    def start(self):
+        """Start the underlying torch profiler."""
+        logger.info("Starting Profiler...")
+        self.profiler.start()
+
+    def _step(self):
+        self.profiler.step()
+
+    def step_stage(self):
+        """Mark the end of a pipeline stage (only steps in full-profile mode)."""
+        if self.full_profile:
+            self._step()
+
+    def step_denoising_step(self):
+        """Mark the end of a denoising step (only steps in scheduled mode)."""
+        if not self.full_profile:
+            self._step()
+
+    @classmethod
+    def get_instance(cls) -> "SGLDiffusionProfiler | None":
+        """Return the active profiler, or None when no profiling is running."""
+        return cls._instance
+
+    def stop(self, export_trace: bool = True, dump_rank: int | None = None):
+        """
+        Stop profiling and optionally export the trace; later calls are no-ops.
+
+        Args:
+            export_trace: Whether to write the chrome trace to log_dir.
+            dump_rank: Rank used in the trace file name; defaults to self.rank.
+        """
+        if self.has_stopped:
+            return
+        self.has_stopped = True
+        logger.info("Stopping Profiler...")
+        try:
+            if torch.cuda.is_available():
+                torch.cuda.synchronize()
+            self.profiler.stop()
+
+            if export_trace:
+                self._export_trace(dump_rank)
+        finally:
+            # Always unregister, even if stopping fails, so stages do not keep
+            # stepping a dead profiler; never clobber a newer active instance.
+            if SGLDiffusionProfiler._instance is self:
+                SGLDiffusionProfiler._instance = None
+
+    def _export_trace(self, dump_rank: int | None = None):
+        """Write the chrome trace to log_dir; failures are logged, not raised."""
+        if dump_rank is None:
+            dump_rank = self.rank
+
+        try:
+            os.makedirs(self.log_dir, exist_ok=True)
+            sanitized_profile_mode_id = self.profile_mode_id.replace(" ", "_")
+            trace_path = os.path.abspath(
+                os.path.join(
+                    self.log_dir,
+                    f"{self.request_id}-{sanitized_profile_mode_id}-global-rank{dump_rank}.trace.json.gz",
+                )
+            )
+            logger.info(f"Saving profiler traces to: {trace_path}")
+            self.profiler.export_chrome_trace(trace_path)
+        except Exception as e:
+            logger.error(f"Failed to export trace: {e}")