Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1017,6 +1017,7 @@ def forward(
logger=logger,
metrics=batch.metrics,
perf_dump_path_provided=batch.perf_dump_path is not None,
record_as_step=True,
):
t_int = int(t_host.item())
t_device = timesteps[i]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,7 @@ def forward(self, batch: Req, server_args: ServerArgs) -> Req:
logger=logger,
metrics=batch.metrics,
perf_dump_path_provided=batch.perf_dump_path is not None,
record_as_step=True,
):
t_int = int(t_host.item())
t_device = timesteps[i]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ def forward(
logger=logger,
metrics=batch.metrics,
perf_dump_path_provided=batch.perf_dump_path is not None,
record_as_step=True,
):
t_int = int(t.item())
if self.transformer_2 is not None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,7 @@ def forward(self, batch: Req, server_args: ServerArgs) -> Req:
logger=logger,
metrics=metrics,
perf_dump_path_provided=perf_dump_path_provided,
record_as_step=True,
):
pair_t = paired_timesteps[idx_step]
if getattr(pair_t, "shape", None) == (2,):
Expand Down
19 changes: 11 additions & 8 deletions python/sglang/multimodal_gen/runtime/utils/perf_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,8 @@ def record_stage(self, stage_name: str, duration_s: float):
"""Records the duration of a pipeline stage"""
self.stages[stage_name] = duration_s * 1000 # Store as milliseconds

def record_steps(self, index: int, duration_s: float):
"""Records the duration of a denoising step"""
assert index == len(self.steps)
def record_step(self, duration_s: float):
    """Append one denoising-step duration, stored in milliseconds, in execution order."""
    duration_ms = duration_s * 1000
    self.steps.append(duration_ms)

def record_memory_snapshot(self, checkpoint_name: str, snapshot: MemorySnapshot):
Expand Down Expand Up @@ -192,6 +191,7 @@ def __init__(
log_stage_start_end: bool = False,
perf_dump_path_provided: bool = False,
capture_memory: bool = False,
record_as_step: bool = False,
):
self.stage_name = stage_name
self.metrics = metrics
Expand All @@ -200,6 +200,10 @@ def __init__(
self.log_timing = perf_dump_path_provided or envs.SGLANG_DIFFUSION_STAGE_LOGGING
self.log_stage_start_end = log_stage_start_end
self.capture_memory = capture_memory
self.record_as_step = record_as_step

def _should_record_as_step(self) -> bool:
return self.record_as_step or self.stage_name.startswith("denoising_step_")
Comment on lines +205 to +206
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The logic in _should_record_as_step uses startswith("denoising_step_"). The previous implementation used "denoising_step_" in self.stage_name. While startswith is generally more precise, this is a slight change in behavior for any legacy custom stage names that might have contained but not started with that string. If such names exist, they will no longer be automatically categorized as steps unless record_as_step=True is explicitly passed.


def __enter__(self):
if self.log_stage_start_end:
Expand All @@ -211,7 +215,7 @@ def __enter__(self):
if (self.log_timing and self.metrics) or self.log_stage_start_end:
if (
os.environ.get("SGLANG_DIFFUSION_SYNC_STAGE_PROFILING", "0") == "1"
and self.stage_name.startswith("denoising_step_")
and self._should_record_as_step()
and torch.get_device_module().is_available()
):
torch.get_device_module().synchronize()
Expand All @@ -225,7 +229,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):

if (
os.environ.get("SGLANG_DIFFUSION_SYNC_STAGE_PROFILING", "0") == "1"
and self.stage_name.startswith("denoising_step_")
and self._should_record_as_step()
and torch.get_device_module().is_available()
):
torch.get_device_module().synchronize()
Expand All @@ -247,9 +251,8 @@ def __exit__(self, exc_type, exc_val, exc_tb):
)

if self.log_timing and self.metrics:
if "denoising_step_" in self.stage_name:
index = int(self.stage_name[len("denoising_step_") :])
self.metrics.record_steps(index, execution_time_s)
if self._should_record_as_step():
self.metrics.record_step(execution_time_s)
else:
self.metrics.record_stage(self.stage_name, execution_time_s)

Expand Down
Loading