Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
3e3562f
only compile forward
AichenF Nov 18, 2025
4414755
change mode of transformer
jianyingzhu Nov 18, 2025
96a6a4a
using max-autotune-no-cudagraphs
AichenF Nov 20, 2025
5f93ca7
restore
AichenF Nov 20, 2025
75d8aaa
Merge branch 'main' into feat/torch_compile
AichenF Nov 20, 2025
1119f6d
Merge branch 'main' into feat/torch_compile
jianyingzhu Nov 21, 2025
fe0da05
Merge branch 'main' into feat/torch_compile
AichenF Nov 27, 2025
d1581da
Merge branch 'main' into feat/torch_compile
jianyingzhu Nov 27, 2025
1191628
remove torch.compile in ComposedPipelineBase
AichenF Nov 27, 2025
be75a73
Merge branch 'feat/torch_compile' of https://github.com/AichenF/sglan…
AichenF Nov 27, 2025
99dbb39
add full stage profile and denoising profile
AichenF Nov 27, 2025
6c9c8b0
perf_logger
AichenF Nov 28, 2025
b2f38fb
Merge branch 'main' into feat/benchmark_profile
AichenF Nov 28, 2025
f8ae187
clean code
AichenF Nov 28, 2025
f148f37
clear cudaevent torchcompile
AichenF Dec 1, 2025
3c78b81
clear code
AichenF Dec 1, 2025
15c8b9d
Merge branch 'main' into feat/benchmark_profile
AichenF Dec 1, 2025
3376654
rename profiler cli
AichenF Dec 1, 2025
99badb0
clean up global profile args
AichenF Dec 2, 2025
2e44c17
clean sampling params
AichenF Dec 2, 2025
2c21b0c
refactor local function
AichenF Dec 2, 2025
0ffc283
refactor denoising
AichenF Dec 2, 2025
3af6d67
restore composed_pipeline_base
AichenF Dec 2, 2025
b8d4669
step profile
AichenF Dec 2, 2025
9d1eb14
clear comments
AichenF Dec 2, 2025
1b06323
fix format
AichenF Dec 2, 2025
e22e7b3
refactor, and add SGLDiffusionProfiler
mickqian Dec 5, 2025
ad9cb45
fix import error
AichenF Dec 7, 2025
43f871f
upd
mickqian Dec 7, 2025
55815f8
upd
mickqian Dec 7, 2025
c9bed7c
upd
mickqian Dec 7, 2025
f39d122
Merge remote-tracking branch 'origin/main' into feat/benchmark_profile
mickqian Dec 7, 2025
1d18f4d
upd
mickqian Dec 7, 2025
34c0aa0
upd
mickqian Dec 7, 2025
1db6481
upd
mickqian Dec 7, 2025
fa75c8d
fix
mickqian Dec 7, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 20 additions & 8 deletions python/sglang/multimodal_gen/configs/sample/sampling_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,8 @@ class SamplingParams:

# Profiling
profile: bool = False
num_profiled_timesteps: int = 2
num_profiled_timesteps: int = 5
profile_all_stages: bool = False

# Debugging
debug: bool = False
Expand Down Expand Up @@ -226,7 +227,7 @@ def _adjust(

if pipeline_config.task_type.is_image_gen():
# settle num_frames
logger.debug(f"Setting num_frames to 1 because this is a image-gen model")
logger.debug(f"Setting num_frames to 1 because this is an image-gen model")
self.num_frames = 1
self.data_type = DataType.IMAGE
else:
Expand Down Expand Up @@ -329,24 +330,35 @@ def add_cli_args(parser: Any) -> Any:
action="store_true",
default=SamplingParams.enable_teacache,
)

# profiling
parser.add_argument(
"--profile",
action="store_true",
default=SamplingParams.profile,
help="Enable torch profiler for denoising stage",
)
parser.add_argument(
"--debug",
action="store_true",
default=SamplingParams.debug,
help="",
)
parser.add_argument(
"--num-profiled-timesteps",
type=int,
default=SamplingParams.num_profiled_timesteps,
help="Number of timesteps to profile after warmup",
)
parser.add_argument(
"--profile-all-stages",
action="store_true",
dest="profile_all_stages",
default=SamplingParams.profile_all_stages,
help="Used with --profile, profile all pipeline stages",
)

parser.add_argument(
"--debug",
action="store_true",
default=SamplingParams.debug,
help="",
)

parser.add_argument(
"--prompt",
type=str,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from sglang.multimodal_gen.runtime.distributed.parallel_state import (
get_cfg_group,
get_classifier_free_guidance_rank,
get_world_rank,
)
from sglang.multimodal_gen.runtime.pipelines_core import Req
from sglang.multimodal_gen.runtime.pipelines_core.executors.pipeline_executor import (
Expand All @@ -20,6 +21,9 @@
)
from sglang.multimodal_gen.runtime.server_args import ServerArgs
from sglang.multimodal_gen.runtime.utils.distributed import broadcast_pyobj
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger

logger = init_logger(__name__)


class ParallelExecutor(PipelineExecutor):
Expand Down Expand Up @@ -48,14 +52,16 @@ def collect_from_main(self, batches: list[Req]):
src=self.worker.cfg_group.ranks[0],
)

def execute(
def _execute(
self,
stages: List[PipelineStage],
batch: Req,
server_args: ServerArgs,
) -> Req:
"""
Execute all pipeline stages respecting their declared parallelism type.
"""
rank = get_classifier_free_guidance_rank()
cfg_rank = get_classifier_free_guidance_rank()
cfg_group = get_cfg_group()

# TODO: decide when to gather on main when CFG_PARALLEL -> MAIN_RANK_ONLY
Expand All @@ -65,14 +71,8 @@ def execute(

if paradigm == StageParallelismType.MAIN_RANK_ONLY:
if rank == 0:
# Only main rank executes, others just wait
batch = stage(batch, server_args)
# obj_list = [batch] if rank == 0 else []
#
# broadcasted_list = broadcast_pyobj(
# obj_list, rank=rank, dist_group=cfg_group.cpu_group, src=0
# )
# if rank != 0:
# batch = broadcasted_list[0]
torch.distributed.barrier()

elif paradigm == StageParallelismType.CFG_PARALLEL:
Expand All @@ -88,5 +88,22 @@ def execute(

elif paradigm == StageParallelismType.REPLICATED:
batch = stage(batch, server_args)
return batch

def execute(
self,
stages: List[PipelineStage],
batch: Req,
server_args: ServerArgs,
) -> Req:
rank = get_classifier_free_guidance_rank()

if batch.profile and batch.profile_all_stages:
world_rank = get_world_rank()
else:
world_rank = 0

with self.profile_execution(batch, check_rank=rank, dump_rank=world_rank):
batch = self._execute(stages, batch, server_args)

return batch
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,19 @@
Base class for all pipeline executors.
"""

import contextlib
from abc import ABC, abstractmethod
from typing import List
from typing import TYPE_CHECKING, List

from sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import Req
from sglang.multimodal_gen.runtime.pipelines_core.stages import PipelineStage
from sglang.multimodal_gen.runtime.server_args import ServerArgs
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
from sglang.multimodal_gen.runtime.utils.perf_logger import StageProfiler
from sglang.multimodal_gen.runtime.utils.profiler import SGLDiffusionProfiler

if TYPE_CHECKING:
# Only for type checkers; avoids runtime circular import
from sglang.multimodal_gen.runtime.pipelines_core.stages.base import PipelineStage

logger = init_logger(__name__)

Expand Down Expand Up @@ -41,7 +46,7 @@ def __init__(self, server_args):
@abstractmethod
def execute(
self,
stages: List[PipelineStage],
stages: List["PipelineStage"],
batch: Req,
server_args: ServerArgs,
) -> Req:
Expand All @@ -57,3 +62,28 @@ def execute(
The processed batch.
"""
raise NotImplementedError

@contextlib.contextmanager
def profile_execution(self, batch: Req, check_rank: int = 0, dump_rank: int = 0):
"""
Context manager for profiling execution.
"""
do_profile = batch.profile

if not do_profile:
yield
return

request_id = batch.request_id
profiler = SGLDiffusionProfiler(
request_id=request_id,
rank=check_rank,
full_profile=batch.profile_all_stages,
num_steps=batch.num_profiled_timesteps,
num_inference_steps=batch.num_inference_steps,
)
try:
yield
finally:
should_export = check_rank == 0
profiler.stop(export_trace=should_export, dump_rank=dump_rank)
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@

from sglang.multimodal_gen.runtime.pipelines_core.executors.pipeline_executor import (
PipelineExecutor,
SGLDiffusionProfiler,
Timer,
logger,
)
from sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import Req
from sglang.multimodal_gen.runtime.pipelines_core.stages import PipelineStage
Expand All @@ -21,19 +21,35 @@ class SyncExecutor(PipelineExecutor):
A simple synchronous executor that runs stages sequentially.
"""

def execute(
def run_profile_all_stages(
self,
stages: List[PipelineStage],
batch: Req,
server_args: ServerArgs,
) -> Req:
"""
Execute the pipeline stages sequentially.
Execute all pipeline stages sequentially.
"""
logger.info("Running pipeline stages sequentially with SyncExecutor.")

for stage in stages:
with Timer(stage.__class__.__name__):
batch = stage(batch, server_args)

profiler = SGLDiffusionProfiler.get_instance()
if profiler:
profiler.step_stage()
return batch

def execute(
self,
stages: List[PipelineStage],
batch: Req,
server_args: ServerArgs,
) -> Req:
"""
Execute the pipeline stages sequentially.
"""

with self.profile_execution(batch, check_rank=0, dump_rank=0):
batch = self.run_profile_all_stages(stages, batch, server_args)

return batch
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,8 @@ class Req:

# profile
profile: bool = False
num_profiled_timesteps: int = 8
profile_all_stages: bool = False
num_profiled_timesteps: int = None

# debugging
debug: bool = False
Expand Down
61 changes: 4 additions & 57 deletions python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,13 @@

import inspect
import math
import os
import time
import weakref
from collections.abc import Iterable
from functools import lru_cache
from typing import Any

import torch
import torch.profiler
from einops import rearrange
from tqdm.auto import tqdm

Expand All @@ -35,7 +33,6 @@
from sglang.multimodal_gen.runtime.distributed.parallel_state import (
get_cfg_group,
get_classifier_free_guidance_rank,
get_world_rank,
)
from sglang.multimodal_gen.runtime.layers.attention.backends.flash_attn import (
FlashAttentionBackend,
Expand All @@ -62,6 +59,7 @@
from sglang.multimodal_gen.runtime.server_args import ServerArgs
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
from sglang.multimodal_gen.runtime.utils.perf_logger import StageProfiler
from sglang.multimodal_gen.runtime.utils.profiler import SGLDiffusionProfiler
from sglang.multimodal_gen.utils import dict_to_3d_list, masks_like

try:
Expand Down Expand Up @@ -745,57 +743,10 @@ def _postprocess_sp_latents(
trajectory_tensor = trajectory_tensor[:, :, :orig_s, :]
return latents, trajectory_tensor

def start_profile(self, batch: Req):
if not batch.profile:
return

logger.info("Starting Profiler...")
# Build activities dynamically to avoid CUDA hangs when CUDA is unavailable
activities = [torch.profiler.ProfilerActivity.CPU]
if torch.cuda.is_available():
activities.append(torch.profiler.ProfilerActivity.CUDA)

self.profiler = torch.profiler.profile(
activities=activities,
schedule=torch.profiler.schedule(
skip_first=0,
wait=0,
warmup=1,
active=batch.num_profiled_timesteps,
repeat=5,
),
on_trace_ready=None,
record_shapes=True,
with_stack=True,
)
self.profiler.start()

def step_profile(self):
if self.profiler:
if torch.cuda.is_available():
torch.cuda.synchronize()
self.profiler.step()

def stop_profile(self, batch: Req):
try:
if self.profiler:
logger.info("Stopping Profiler...")
if torch.cuda.is_available():
torch.cuda.synchronize()
self.profiler.stop()
request_id = batch.request_id if batch.request_id else "profile_trace"
log_dir = f"./logs"
os.makedirs(log_dir, exist_ok=True)

rank = get_world_rank()
trace_path = os.path.abspath(
os.path.join(log_dir, f"{request_id}-rank{rank}.trace.json.gz")
)
logger.info(f"Saving profiler traces to: {trace_path}")
self.profiler.export_chrome_trace(trace_path)
torch.distributed.barrier()
except Exception as e:
logger.error(f"{e}")
profiler = SGLDiffusionProfiler.get_instance()
if profiler:
profiler.step_denoising_step()

def _manage_device_placement(
self,
Expand Down Expand Up @@ -968,8 +919,6 @@ def forward(
# Run denoising loop
denoising_start_time = time.time()

self.start_profile(batch=batch)

# to avoid device-sync caused by timestep comparison
timesteps_cpu = timesteps.cpu()
num_timesteps = timesteps_cpu.shape[0]
Expand Down Expand Up @@ -1069,8 +1018,6 @@ def forward(

self.step_profile()

self.stop_profile(batch)

denoising_end_time = time.time()

if num_timesteps > 0:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,6 @@ def forward(
prompt_embeds = prepared_vars["prompt_embeds"]

denoising_loop_start_time = time.time()
self.start_profile(batch=batch)

with self.progress_bar(total=len(timesteps)) as progress_bar:
for i, t in enumerate(timesteps):
# Skip if interrupted
Expand Down Expand Up @@ -186,7 +184,6 @@ def forward(

self.step_profile()

self.stop_profile(batch)
denoising_loop_end_time = time.time()
if len(timesteps) > 0:
self.log_info(
Expand Down
3 changes: 3 additions & 0 deletions python/sglang/multimodal_gen/runtime/utils/perf_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,11 @@
from sglang.multimodal_gen.runtime.utils.logging_utils import (
_SGLDiffusionLogger,
get_is_main_process,
init_logger,
)

logger = init_logger(__name__)


@dataclasses.dataclass
class RequestTimings:
Expand Down
Loading
Loading