@@ -10,6 +10,7 @@

import contextlib
import dataclasses
+import inspect
import json
import logging
import pickle
@@ -47,6 +48,9 @@
is_transfer_message,
)
from sglang.multimodal_gen.runtime.pipelines_core import Req
+from sglang.multimodal_gen.runtime.pipelines_core.diffusion_scheduler_utils import (
+    clone_scheduler_runtime,
+)
from sglang.multimodal_gen.runtime.utils.common import get_zmq_socket
from sglang.multimodal_gen.runtime.utils.distributed import broadcast_pyobj
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
@@ -86,6 +90,9 @@
"trajectory_audio_latents",
"timestep",
"step_index",
+# Request scheduler is a local runtime object cloned from the pipeline
+# scheduler template. It may hold live mutable state and is not JSON-safe.
+"scheduler",
"prompt_template",
"max_sequence_length",
# trace_ctx holds live OTel SDK objects that aren't JSON-serializable.
@@ -163,6 +170,45 @@ def _extract_extra_fields(extra: dict, scalar_fields: dict) -> None:
pass


+def _init_request_scheduler_from_template(
+    scheduler_template: Any, req: Req, device: torch.device
+) -> None:
+    scheduler = clone_scheduler_runtime(scheduler_template)
+    extra_kwargs = {}
+    mu = req.extra.get("mu") if hasattr(req, "extra") else None
+    if mu is not None:
+        extra_kwargs["mu"] = mu
+
+    set_timesteps_params = inspect.signature(scheduler.set_timesteps).parameters
+    timesteps = getattr(req, "timesteps", None)
+    sigmas = getattr(req, "sigmas", None)
+    num_steps = getattr(req, "num_inference_steps", None)
+
+    if sigmas is not None and "sigmas" in set_timesteps_params:
+        if isinstance(sigmas, torch.Tensor):
+            sigmas = sigmas.detach().cpu()
+        scheduler.set_timesteps(sigmas=sigmas, device=device, **extra_kwargs)
+    elif timesteps is not None and "timesteps" in set_timesteps_params:
+        if isinstance(timesteps, torch.Tensor):
+            timesteps = timesteps.detach().cpu()
+        scheduler.set_timesteps(timesteps=timesteps, device=device, **extra_kwargs)
+    elif num_steps is not None:
+        scheduler.set_timesteps(num_steps, device=device, **extra_kwargs)
+    else:
+        return
+
+    req.scheduler = scheduler
+    req.timesteps = scheduler.timesteps
+
+
+def _init_disagg_request_scheduler(self: Scheduler, req: Req) -> None:
+    scheduler_template = self.worker.pipeline.get_module("scheduler")
+    if scheduler_template is None:
+        return
+    device = torch.device(f"cuda:{self.worker.local_rank}")
+    _init_request_scheduler_from_template(scheduler_template, req, device)
+
+
def extract_transfer_fields(req) -> tuple[dict, dict]:
"""Extract all transferable fields from a Req, split into tensors and scalars."""
tensor_fields = {}
@@ -817,17 +863,7 @@ def _disagg_prefetch_event_loop(self: Scheduler, role_name: str) -> None:
# Init scheduler timesteps on main thread (safe — no
# concurrent denoising loop can be running here).
if self._disagg_role == RoleType.DENOISER:
-scheduler_mod = self.worker.pipeline.get_module("scheduler")
-num_steps = getattr(req, "num_inference_steps", None)
-if scheduler_mod is not None and num_steps is not None:
-    device = torch.device(f"cuda:{self.worker.local_rank}")
-    extra_kwargs = {}
-    mu = req.extra.get("mu") if hasattr(req, "extra") else None
-    if mu is not None:
-        extra_kwargs["mu"] = mu
-    scheduler_mod.set_timesteps(
-        num_steps, device=device, **extra_kwargs
-    )
+_init_disagg_request_scheduler(self, req)
# Run compute
if self._disagg_role == RoleType.DENOISER:
self._disagg_denoiser_compute(req, request_id, rn)
@@ -1194,15 +1230,7 @@ def _handle_transfer_ready(self: Scheduler, msg: dict) -> None:

# 3. Init scheduler timesteps if denoiser (CPU work, overlapped)
if self._disagg_role == RoleType.DENOISER:
-scheduler_mod = self.worker.pipeline.get_module("scheduler")
-num_steps = getattr(req, "num_inference_steps", None)
-if scheduler_mod is not None and num_steps is not None:
-    device = torch.device(local_device)
-    extra_kwargs = {}
-    mu = req.extra.get("mu") if hasattr(req, "extra") else None
-    if mu is not None:
-        extra_kwargs["mu"] = mu
-    scheduler_mod.set_timesteps(num_steps, device=device, **extra_kwargs)
+_init_disagg_request_scheduler(self, req)

# 4. Wait for load before compute (GPU must see the data)
if load_event is not None:
@@ -1246,15 +1274,7 @@ def _disagg_compute_non_rank0(self: Scheduler, req: Req) -> None:
"""
if self._disagg_role == RoleType.DENOISER:
# Initialize scheduler timesteps (same as rank 0)
-scheduler_mod = self.worker.pipeline.get_module("scheduler")
-num_steps = getattr(req, "num_inference_steps", None)
-if scheduler_mod is not None and num_steps is not None:
-    device = torch.device(f"cuda:{self.worker.local_rank}")
-    extra_kwargs = {}
-    mu = req.extra.get("mu") if hasattr(req, "extra") else None
-    if mu is not None:
-        extra_kwargs["mu"] = mu
-    scheduler_mod.set_timesteps(num_steps, device=device, **extra_kwargs)
+_init_disagg_request_scheduler(self, req)

with self._disagg_trace_dispatch(req):
self.worker.execute_forward([req], return_req=True)
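
Note on the helper added above: `_init_request_scheduler_from_template` probes the cloned scheduler's `set_timesteps` signature before forwarding request-provided `sigmas` or `timesteps`, because diffusers-style schedulers accept different keyword arguments. A minimal, self-contained sketch of that dispatch pattern (`ToyScheduler` and `init_timesteps` are hypothetical, not part of this PR):

import inspect

class ToyScheduler:
    def set_timesteps(self, num_inference_steps=None, sigmas=None, device=None):
        # A real scheduler would build its timestep table here.
        self.timesteps = sigmas if sigmas is not None else list(range(num_inference_steps))

def init_timesteps(scheduler, sigmas=None, num_steps=None, device="cpu"):
    # Forward sigmas only if this scheduler's set_timesteps accepts them.
    params = inspect.signature(scheduler.set_timesteps).parameters
    if sigmas is not None and "sigmas" in params:
        scheduler.set_timesteps(sigmas=sigmas, device=device)
    elif num_steps is not None:
        scheduler.set_timesteps(num_steps, device=device)

s = ToyScheduler()
init_timesteps(s, sigmas=[1.0, 0.5, 0.0])
assert s.timesteps == [1.0, 0.5, 0.0]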

----------------------------------------------------------------------
sglang/multimodal_gen/runtime/pipelines_core/diffusion_scheduler_utils.py (new file)
@@ -0,0 +1,32 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+from copy import deepcopy
+from typing import Any
+
+from sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import Req
+
+
+def clone_scheduler_runtime(scheduler: Any) -> Any:
+    """Create an isolated scheduler runtime from a scheduler template or runtime."""
+    return deepcopy(scheduler)
+
+
+def get_or_create_request_scheduler(
+    batch: Req, scheduler_template: Any, *, isolate: bool = False
+) -> Any:
+    """Return the scheduler runtime for this request.
+
+    Diffusion serving currently executes one request at a time on the normal
+    worker path, so reusing the stage-local scheduler preserves warmup caches
+    and avoids unnecessary deepcopy overhead. Set ``isolate=True`` only when a
+    request can run concurrently or outlive the stage-local scheduler state.
+    """
+    if batch.scheduler is None:
+        batch.scheduler = (
+            clone_scheduler_runtime(scheduler_template)
+            if isolate
+            else scheduler_template
+        )
+    return batch.scheduler
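
Usage sketch for the helper above (hypothetical call site; `req_a`, `req_b`, and `stage_scheduler` are illustrative names): the default path hands back the shared stage-local scheduler, while `isolate=True` deep-copies the template so a concurrently running request cannot corrupt shared state.

shared = get_or_create_request_scheduler(req_a, stage_scheduler)
assert shared is stage_scheduler            # reuse: warmup caches preserved

cloned = get_or_create_request_scheduler(req_b, stage_scheduler, isolate=True)
assert cloned is not stage_scheduler        # isolated copy for concurrent use
assert req_b.scheduler is cloned            # cached on the request afterwards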

----------------------------------------------------------------------
sglang/multimodal_gen/runtime/pipelines_core/schedule_batch.py

@@ -129,6 +129,14 @@ class Req:
timestep: torch.Tensor | float | int | None = None
step_index: int | None = None

+# Request-local scheduler used by the timestep/denoising stages. Optional
+# because the normal worker path executes one request at a time, so it can
+# point at the stage-local scheduler and preserve warmup/device caches.
+# Cloned request-local schedulers are only needed when a request can run
+# concurrently with another request or outlive the stage-local scheduler
+# state, e.g. grouped execution or disaggregation.
+scheduler: Any | None = None
+
eta: float = 0.0
sigmas: list[float] | None = None
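
A short sketch of the lazy-population contract for the field added above (`MiniReq` is a hypothetical stand-in for `Req`): the `None` default keeps plain request construction unchanged and matches the `"scheduler"` entry in the non-transferable field list, so no scheduler object is ever JSON-serialized.

from dataclasses import dataclass
from typing import Any

@dataclass
class MiniReq:
    num_inference_steps: int = 50
    scheduler: Any | None = None  # populated lazily by a stage or by disagg init

req = MiniReq()
assert req.scheduler is None  # nothing to serialize or transfer by default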


----------------------------------------------------------------------

@@ -5,6 +5,9 @@
from sglang.multimodal_gen.runtime.distributed import get_local_torch_device
from sglang.multimodal_gen.runtime.managers.forward_context import set_forward_context
from sglang.multimodal_gen.runtime.models.utils import pred_noise_to_pred_video
+from sglang.multimodal_gen.runtime.pipelines_core.diffusion_scheduler_utils import (
+    get_or_create_request_scheduler,
+)
from sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import Req
from sglang.multimodal_gen.runtime.pipelines_core.stages.denoising import DenoisingStage
from sglang.multimodal_gen.runtime.pipelines_core.stages.validators import (
@@ -58,6 +61,7 @@ def forward(
autocast_enabled = (
target_dtype != torch.float32
) and not server_args.disable_autocast
+scheduler = get_or_create_request_scheduler(batch, self.scheduler)

latent_seq_length = batch.latents.shape[-1] * batch.latents.shape[-2]
patch_ratio = (
@@ -76,7 +80,7 @@
if server_args.pipeline_config.warp_denoising_step:
logger.info("Warping timesteps...")
scheduler_timesteps = torch.cat(
-(self.scheduler.timesteps.cpu(), torch.tensor([0], dtype=torch.float32))
+(scheduler.timesteps.cpu(), torch.tensor([0], dtype=torch.float32))
)
timesteps = scheduler_timesteps[1000 - timesteps]
timesteps = timesteps.to(get_local_torch_device())
@@ -317,7 +321,7 @@ def forward(
pred_noise=pred_noise_btchw.flatten(0, 1),
noise_input_latent=noise_latents.flatten(0, 1),
timestep=t_expand,
-scheduler=self.scheduler,
+scheduler=scheduler,
).unflatten(0, pred_noise_btchw.shape[:2])

if i < len(timesteps) - 1:
@@ -335,7 +339,7 @@
device=self.device,
)
noise_btchw = noise
-noise_latents_btchw = self.scheduler.add_noise(
+noise_latents_btchw = scheduler.add_noise(
pred_video_btchw.flatten(0, 1),
noise_btchw.flatten(0, 1),
next_timestep,

----------------------------------------------------------------------
sglang/multimodal_gen/runtime/pipelines_core/stages/denoising.py

@@ -101,6 +101,7 @@
class DenoisingContext:
"""Loop-scoped state shared across the denoising skeleton and its hooks."""

+scheduler: Any
extra_step_kwargs: dict[str, Any]
target_dtype: torch.dtype
autocast_enabled: bool
@@ -469,6 +470,7 @@ def _handle_boundary_ratio(
self,
server_args,
batch,
+scheduler,
):
"""
(Wan2.2) Calculate timestep to switch from high noise expert to low noise expert
@@ -483,7 +485,10 @@
boundary_ratio = batch.boundary_ratio

if boundary_ratio is not None:
-boundary_timestep = boundary_ratio * self.scheduler.num_train_timesteps
+num_train_timesteps = getattr(scheduler, "num_train_timesteps", None)
+if num_train_timesteps is None:
+    num_train_timesteps = scheduler.config.num_train_timesteps
+boundary_timestep = boundary_ratio * num_train_timesteps
else:
boundary_timestep = None

@@ -498,12 +503,14 @@ def _prepare_denoising_loop(self, batch: Req, server_args: ServerArgs):
"""
assert self.transformer is not None
pipeline = self.pipeline() if self.pipeline else None
+scheduler = batch.scheduler
+assert scheduler is not None

-boundary_timestep = self._handle_boundary_ratio(server_args, batch)
+boundary_timestep = self._handle_boundary_ratio(server_args, batch, scheduler)
# Get timesteps and calculate warmup steps
timesteps = batch.timesteps
num_inference_steps = batch.num_inference_steps
-num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+num_warmup_steps = len(timesteps) - num_inference_steps * scheduler.order

if self.transformer_2 is not None:
assert boundary_timestep is not None, "boundary_timestep must be provided"
@@ -533,7 +540,7 @@ def _prepare_denoising_loop(self, batch: Req, server_args: ServerArgs):

# Prepare extra step kwargs for scheduler
extra_step_kwargs = self.prepare_extra_func_kwargs(
-self.scheduler.step,
+scheduler.step,
{"generator": batch.generator, "eta": batch.eta, "batch": batch},
)

@@ -654,6 +661,7 @@ def _prepare_denoising_loop(self, batch: Req, server_args: ServerArgs):
neg_cond_kwargs = {}

return DenoisingContext(
+scheduler=scheduler,
extra_step_kwargs=extra_step_kwargs,
target_dtype=target_dtype,
autocast_enabled=autocast_enabled,
@@ -676,7 +684,27 @@ def _before_denoising_loop(
self, ctx: DenoisingContext, batch: Req, server_args: ServerArgs
) -> None:
"""Prepare scheduler state before entering the shared denoising loop."""
-self.scheduler.set_begin_index(0)
+self._reset_scheduler_loop_state(ctx.scheduler)
+ctx.scheduler.set_begin_index(0)

+def _reset_scheduler_loop_state(self, scheduler) -> None:
+    if hasattr(scheduler, "_step_index"):
+        scheduler._step_index = None
+    if hasattr(scheduler, "_begin_index"):
+        scheduler._begin_index = None
+    if hasattr(scheduler, "lower_order_nums"):
+        scheduler.lower_order_nums = 0
+    if hasattr(scheduler, "last_sample"):
+        scheduler.last_sample = None
+    if hasattr(scheduler, "this_order"):
+        scheduler.this_order = 0
+
+    solver_order = getattr(getattr(scheduler, "config", None), "solver_order", 0)
+    if solver_order:
+        if hasattr(scheduler, "model_outputs"):
+            scheduler.model_outputs = [None] * solver_order
+        if hasattr(scheduler, "timestep_list"):
+            scheduler.timestep_list = [None] * solver_order

def _prepare_step_state(
self,
@@ -779,7 +807,7 @@ def _run_denoising_step(
)

# 3. Apply scheduler-side input scaling before the model forward.
-latent_model_input = self.scheduler.scale_model_input(
+latent_model_input = ctx.scheduler.scale_model_input(
latent_model_input, step.t_device
)

@@ -804,7 +832,7 @@
batch.noise_pred = noise_pred

# 5. Advance the scheduler state with the predicted noise.
-ctx.latents = self.scheduler.step(
+ctx.latents = ctx.scheduler.step(
model_output=noise_pred,
timestep=step.t_device,
sample=ctx.latents,
@@ -1152,7 +1180,7 @@ def forward(

if step_index == num_timesteps - 1 or (
(step_index + 1) > ctx.num_warmup_steps
-and (step_index + 1) % self.scheduler.order == 0
+and (step_index + 1) % ctx.scheduler.order == 0
and progress_bar is not None
):
progress_bar.update()
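
Why `_reset_scheduler_loop_state` matters: multistep solvers carry per-run history (`_step_index`, `model_outputs`, `timestep_list`, ...), so a scheduler object reused across requests must be scrubbed before the next denoising loop. A toy illustration (`ToySolver` only mimics the relevant attributes; it is not a real scheduler):

class ToySolver:
    def __init__(self, solver_order: int = 2):
        self.config = type("Cfg", (), {"solver_order": solver_order})()
        self._step_index = 7                  # stale cursor from a previous run
        self.lower_order_nums = 2
        self.model_outputs = ["stale"] * solver_order

def reset_loop_state(s) -> None:
    # Mirrors _reset_scheduler_loop_state above, trimmed to the toy's attributes.
    if hasattr(s, "_step_index"):
        s._step_index = None
    if hasattr(s, "lower_order_nums"):
        s.lower_order_nums = 0
    order = getattr(getattr(s, "config", None), "solver_order", 0)
    if order and hasattr(s, "model_outputs"):
        s.model_outputs = [None] * order

s = ToySolver()
reset_loop_state(s)
assert s._step_index is None and s.model_outputs == [None, None]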