diff --git a/python/sglang/multimodal_gen/runtime/disaggregation/scheduler_mixin.py b/python/sglang/multimodal_gen/runtime/disaggregation/scheduler_mixin.py
index d9e6506190a6..615dfca3c043 100644
--- a/python/sglang/multimodal_gen/runtime/disaggregation/scheduler_mixin.py
+++ b/python/sglang/multimodal_gen/runtime/disaggregation/scheduler_mixin.py
@@ -10,6 +10,7 @@
 import contextlib
 import dataclasses
+import inspect
 import json
 import logging
 import pickle
@@ -47,6 +48,9 @@
     is_transfer_message,
 )
 from sglang.multimodal_gen.runtime.pipelines_core import Req
+from sglang.multimodal_gen.runtime.pipelines_core.diffusion_scheduler_utils import (
+    clone_scheduler_runtime,
+)
 from sglang.multimodal_gen.runtime.utils.common import get_zmq_socket
 from sglang.multimodal_gen.runtime.utils.distributed import broadcast_pyobj
 from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
@@ -86,6 +90,9 @@
     "trajectory_audio_latents",
     "timestep",
     "step_index",
+    # The request scheduler is a local runtime object cloned from the pipeline
+    # scheduler template. It may hold live mutable state and is not JSON-safe.
+    "scheduler",
     "prompt_template",
     "max_sequence_length",
     # trace_ctx holds live OTel SDK objects that aren't JSON-serializable.
@@ -163,6 +170,45 @@ def _extract_extra_fields(extra: dict, scalar_fields: dict) -> None:
             pass
 
 
+def _init_request_scheduler_from_template(
+    scheduler_template: Any, req: Req, device: torch.device
+) -> None:
+    scheduler = clone_scheduler_runtime(scheduler_template)
+    extra_kwargs = {}
+    mu = req.extra.get("mu") if hasattr(req, "extra") else None
+    if mu is not None:
+        extra_kwargs["mu"] = mu
+
+    set_timesteps_params = inspect.signature(scheduler.set_timesteps).parameters
+    timesteps = getattr(req, "timesteps", None)
+    sigmas = getattr(req, "sigmas", None)
+    num_steps = getattr(req, "num_inference_steps", None)
+
+    if sigmas is not None and "sigmas" in set_timesteps_params:
+        if isinstance(sigmas, torch.Tensor):
+            sigmas = sigmas.detach().cpu()
+        scheduler.set_timesteps(sigmas=sigmas, device=device, **extra_kwargs)
+    elif timesteps is not None and "timesteps" in set_timesteps_params:
+        if isinstance(timesteps, torch.Tensor):
+            timesteps = timesteps.detach().cpu()
+        scheduler.set_timesteps(timesteps=timesteps, device=device, **extra_kwargs)
+    elif num_steps is not None:
+        scheduler.set_timesteps(num_steps, device=device, **extra_kwargs)
+    else:
+        return
+
+    req.scheduler = scheduler
+    req.timesteps = scheduler.timesteps
+
+
+def _init_disagg_request_scheduler(self: Scheduler, req: Req) -> None:
+    scheduler_template = self.worker.pipeline.get_module("scheduler")
+    if scheduler_template is None:
+        return
+    device = torch.device(f"cuda:{self.worker.local_rank}")
+    _init_request_scheduler_from_template(scheduler_template, req, device)
+
+
 def extract_transfer_fields(req) -> tuple[dict, dict]:
     """Extract all transferable fields from a Req, split into tensors and scalars."""
     tensor_fields = {}
@@ -817,17 +863,7 @@ def _disagg_prefetch_event_loop(self: Scheduler, role_name: str) -> None:
                 # Init scheduler timesteps on main thread (safe — no
                 # concurrent denoising loop can be running here).
                 if self._disagg_role == RoleType.DENOISER:
-                    scheduler_mod = self.worker.pipeline.get_module("scheduler")
-                    num_steps = getattr(req, "num_inference_steps", None)
-                    if scheduler_mod is not None and num_steps is not None:
-                        device = torch.device(f"cuda:{self.worker.local_rank}")
-                        extra_kwargs = {}
-                        mu = req.extra.get("mu") if hasattr(req, "extra") else None
-                        if mu is not None:
-                            extra_kwargs["mu"] = mu
-                        scheduler_mod.set_timesteps(
-                            num_steps, device=device, **extra_kwargs
-                        )
+                    _init_disagg_request_scheduler(self, req)
                 # Run compute
                 if self._disagg_role == RoleType.DENOISER:
                     self._disagg_denoiser_compute(req, request_id, rn)
@@ -1194,15 +1230,7 @@ def _handle_transfer_ready(self: Scheduler, msg: dict) -> None:
 
     # 3. Init scheduler timesteps if denoiser (CPU work, overlapped)
     if self._disagg_role == RoleType.DENOISER:
-        scheduler_mod = self.worker.pipeline.get_module("scheduler")
-        num_steps = getattr(req, "num_inference_steps", None)
-        if scheduler_mod is not None and num_steps is not None:
-            device = torch.device(local_device)
-            extra_kwargs = {}
-            mu = req.extra.get("mu") if hasattr(req, "extra") else None
-            if mu is not None:
-                extra_kwargs["mu"] = mu
-            scheduler_mod.set_timesteps(num_steps, device=device, **extra_kwargs)
+        _init_disagg_request_scheduler(self, req)
 
     # 4. Wait for load before compute (GPU must see the data)
     if load_event is not None:
@@ -1246,15 +1274,7 @@ def _disagg_compute_non_rank0(self: Scheduler, req: Req) -> None:
     """
    if self._disagg_role == RoleType.DENOISER:
        # Initialize scheduler timesteps (same as rank 0)
-        scheduler_mod = self.worker.pipeline.get_module("scheduler")
-        num_steps = getattr(req, "num_inference_steps", None)
-        if scheduler_mod is not None and num_steps is not None:
-            device = torch.device(f"cuda:{self.worker.local_rank}")
-            extra_kwargs = {}
-            mu = req.extra.get("mu") if hasattr(req, "extra") else None
-            if mu is not None:
-                extra_kwargs["mu"] = mu
-            scheduler_mod.set_timesteps(num_steps, device=device, **extra_kwargs)
+        _init_disagg_request_scheduler(self, req)
 
    with self._disagg_trace_dispatch(req):
        self.worker.execute_forward([req], return_req=True)
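Note: the dispatch in `_init_request_scheduler_from_template` is easiest to see in isolation. Below is a minimal, self-contained sketch of the same `inspect.signature` gate, with a toy `_SigmaScheduler` and `init_from_spec` helper standing in for a real diffusers-style scheduler (both names are illustrative, not part of the patch):

```python
import inspect

import torch


class _SigmaScheduler:
    """Toy stand-in for a scheduler whose set_timesteps accepts `sigmas`."""

    def set_timesteps(self, num_inference_steps=None, device=None, sigmas=None):
        # device is accepted but ignored in this toy.
        if sigmas is not None:
            # Drop the terminal sigma and scale to the usual 0-1000 range.
            self.timesteps = torch.as_tensor(sigmas)[:-1] * 1000
        else:
            self.timesteps = torch.linspace(999, 0, num_inference_steps)


def init_from_spec(scheduler, num_steps=None, sigmas=None, device="cpu"):
    # Same gate as the patch: forward a kwarg only when the scheduler's
    # set_timesteps signature declares it, preferring an explicit schedule
    # over a bare step count.
    params = inspect.signature(scheduler.set_timesteps).parameters
    if sigmas is not None and "sigmas" in params:
        scheduler.set_timesteps(sigmas=sigmas, device=device)
    elif num_steps is not None:
        scheduler.set_timesteps(num_steps, device=device)
    return scheduler.timesteps


print(init_from_spec(_SigmaScheduler(), sigmas=[1.0, 0.5, 0.0]))
# tensor([1000.,  500.])
```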
+ """ + if batch.scheduler is None: + batch.scheduler = ( + clone_scheduler_runtime(scheduler_template) + if isolate + else scheduler_template + ) + return batch.scheduler diff --git a/python/sglang/multimodal_gen/runtime/pipelines_core/schedule_batch.py b/python/sglang/multimodal_gen/runtime/pipelines_core/schedule_batch.py index 7f5b332e940e..b2bfc02e5a4f 100644 --- a/python/sglang/multimodal_gen/runtime/pipelines_core/schedule_batch.py +++ b/python/sglang/multimodal_gen/runtime/pipelines_core/schedule_batch.py @@ -129,6 +129,14 @@ class Req: timestep: torch.Tensor | float | int | None = None step_index: int | None = None + # request-local scheduler used by timestep/denoising stages. + # This is optional because the normal worker path executes one request at a time, so it can + # point at the stage-local scheduler and preserve warmup/device caches. + # Request-local cloned schedulers are only needed when a request can run + # concurrently with another request or outlive the stage-local scheduler + # state, such as grouped execution or disaggregation. + scheduler: Any | None = None + eta: float = 0.0 sigmas: list[float] | None = None diff --git a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/causal_denoising.py b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/causal_denoising.py index f8640186e166..414968e666bc 100644 --- a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/causal_denoising.py +++ b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/causal_denoising.py @@ -5,6 +5,9 @@ from sglang.multimodal_gen.runtime.distributed import get_local_torch_device from sglang.multimodal_gen.runtime.managers.forward_context import set_forward_context from sglang.multimodal_gen.runtime.models.utils import pred_noise_to_pred_video +from sglang.multimodal_gen.runtime.pipelines_core.diffusion_scheduler_utils import ( + get_or_create_request_scheduler, +) from sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import Req from sglang.multimodal_gen.runtime.pipelines_core.stages.denoising import DenoisingStage from sglang.multimodal_gen.runtime.pipelines_core.stages.validators import ( @@ -58,6 +61,7 @@ def forward( autocast_enabled = ( target_dtype != torch.float32 ) and not server_args.disable_autocast + scheduler = get_or_create_request_scheduler(batch, self.scheduler) latent_seq_length = batch.latents.shape[-1] * batch.latents.shape[-2] patch_ratio = ( @@ -76,7 +80,7 @@ def forward( if server_args.pipeline_config.warp_denoising_step: logger.info("Warping timesteps...") scheduler_timesteps = torch.cat( - (self.scheduler.timesteps.cpu(), torch.tensor([0], dtype=torch.float32)) + (scheduler.timesteps.cpu(), torch.tensor([0], dtype=torch.float32)) ) timesteps = scheduler_timesteps[1000 - timesteps] timesteps = timesteps.to(get_local_torch_device()) @@ -317,7 +321,7 @@ def forward( pred_noise=pred_noise_btchw.flatten(0, 1), noise_input_latent=noise_latents.flatten(0, 1), timestep=t_expand, - scheduler=self.scheduler, + scheduler=scheduler, ).unflatten(0, pred_noise_btchw.shape[:2]) if i < len(timesteps) - 1: @@ -335,7 +339,7 @@ def forward( device=self.device, ) noise_btchw = noise - noise_latents_btchw = self.scheduler.add_noise( + noise_latents_btchw = scheduler.add_noise( pred_video_btchw.flatten(0, 1), noise_btchw.flatten(0, 1), next_timestep, diff --git a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising.py b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising.py index 4cc697e07ff2..2dc23b4c2e63 100644 --- 
diff --git a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising.py b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising.py
index 4cc697e07ff2..2dc23b4c2e63 100644
--- a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising.py
+++ b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising.py
@@ -101,6 +101,7 @@ class DenoisingContext:
     """Loop-scoped state shared across the denoising skeleton and its hooks."""
 
+    scheduler: Any
     extra_step_kwargs: dict[str, Any]
     target_dtype: torch.dtype
     autocast_enabled: bool
@@ -469,6 +470,7 @@ def _handle_boundary_ratio(
         self,
         server_args,
         batch,
+        scheduler,
     ):
         """
         (Wan2.2) Calculate timestep to switch from high noise expert to low noise expert
@@ -483,7 +485,10 @@ def _handle_boundary_ratio(
         boundary_ratio = batch.boundary_ratio
 
         if boundary_ratio is not None:
-            boundary_timestep = boundary_ratio * self.scheduler.num_train_timesteps
+            num_train_timesteps = getattr(scheduler, "num_train_timesteps", None)
+            if num_train_timesteps is None:
+                num_train_timesteps = scheduler.config.num_train_timesteps
+            boundary_timestep = boundary_ratio * num_train_timesteps
         else:
             boundary_timestep = None
 
@@ -498,12 +503,14 @@ def _prepare_denoising_loop(self, batch: Req, server_args: ServerArgs):
         """
         assert self.transformer is not None
         pipeline = self.pipeline() if self.pipeline else None
+        scheduler = batch.scheduler
+        assert scheduler is not None
 
-        boundary_timestep = self._handle_boundary_ratio(server_args, batch)
+        boundary_timestep = self._handle_boundary_ratio(server_args, batch, scheduler)
 
         # Get timesteps and calculate warmup steps
         timesteps = batch.timesteps
         num_inference_steps = batch.num_inference_steps
-        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+        num_warmup_steps = len(timesteps) - num_inference_steps * scheduler.order
 
         if self.transformer_2 is not None:
             assert boundary_timestep is not None, "boundary_timestep must be provided"
@@ -533,7 +540,7 @@ def _prepare_denoising_loop(self, batch: Req, server_args: ServerArgs):
 
         # Prepare extra step kwargs for scheduler
         extra_step_kwargs = self.prepare_extra_func_kwargs(
-            self.scheduler.step,
+            scheduler.step,
             {"generator": batch.generator, "eta": batch.eta, "batch": batch},
         )
@@ -654,6 +661,7 @@ def _prepare_denoising_loop(self, batch: Req, server_args: ServerArgs):
             neg_cond_kwargs = {}
 
         return DenoisingContext(
+            scheduler=scheduler,
             extra_step_kwargs=extra_step_kwargs,
             target_dtype=target_dtype,
             autocast_enabled=autocast_enabled,
@@ -676,7 +684,27 @@ def _before_denoising_loop(
         self, ctx: DenoisingContext, batch: Req, server_args: ServerArgs
     ) -> None:
         """Prepare scheduler state before entering the shared denoising loop."""
-        self.scheduler.set_begin_index(0)
+        self._reset_scheduler_loop_state(ctx.scheduler)
+        ctx.scheduler.set_begin_index(0)
+
+    def _reset_scheduler_loop_state(self, scheduler) -> None:
+        if hasattr(scheduler, "_step_index"):
+            scheduler._step_index = None
+        if hasattr(scheduler, "_begin_index"):
+            scheduler._begin_index = None
+        if hasattr(scheduler, "lower_order_nums"):
+            scheduler.lower_order_nums = 0
+        if hasattr(scheduler, "last_sample"):
+            scheduler.last_sample = None
+        if hasattr(scheduler, "this_order"):
+            scheduler.this_order = 0
+
+        solver_order = getattr(getattr(scheduler, "config", None), "solver_order", 0)
+        if solver_order:
+            if hasattr(scheduler, "model_outputs"):
+                scheduler.model_outputs = [None] * solver_order
+            if hasattr(scheduler, "timestep_list"):
+                scheduler.timestep_list = [None] * solver_order
 
     def _prepare_step_state(
         self,
@@ -779,7 +807,7 @@ def _run_denoising_step(
         )
 
         # 3. Apply scheduler-side input scaling before the model forward.
-        latent_model_input = self.scheduler.scale_model_input(
+        latent_model_input = ctx.scheduler.scale_model_input(
             latent_model_input, step.t_device
         )
@@ -804,7 +832,7 @@ def _run_denoising_step(
             batch.noise_pred = noise_pred
 
         # 5. Advance the scheduler state with the predicted noise.
-        ctx.latents = self.scheduler.step(
+        ctx.latents = ctx.scheduler.step(
             model_output=noise_pred,
             timestep=step.t_device,
             sample=ctx.latents,
@@ -1152,7 +1180,7 @@ def forward(
 
             if step_index == num_timesteps - 1 or (
                 (step_index + 1) > ctx.num_warmup_steps
-                and (step_index + 1) % self.scheduler.order == 0
+                and (step_index + 1) % ctx.scheduler.order == 0
                 and progress_bar is not None
             ):
                 progress_bar.update()
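Note: `_reset_scheduler_loop_state` exists because a reused scheduler runtime still carries the previous request's position and, for multistep (DPM-style) solvers, its history buffers. A reduced sketch of the same defensive reset, applied to a hypothetical stub that mimics the relevant fields:

```python
class _MultistepStub:
    """Stub carrying the fields a DPM-style multistep solver keeps across steps."""

    class config:
        solver_order = 2

    def __init__(self):
        self._step_index = 5                 # position left by a previous request
        self.lower_order_nums = 2            # warmup counter for low-order steps
        self.model_outputs = ["old", "old"]  # model-output history buffer


def reset_loop_state(scheduler):
    # Same pattern as the patch: touch only the fields that exist, and size
    # history buffers from config.solver_order.
    if hasattr(scheduler, "_step_index"):
        scheduler._step_index = None
    if hasattr(scheduler, "lower_order_nums"):
        scheduler.lower_order_nums = 0
    order = getattr(getattr(scheduler, "config", None), "solver_order", 0)
    if order and hasattr(scheduler, "model_outputs"):
        scheduler.model_outputs = [None] * order


stub = _MultistepStub()
reset_loop_state(stub)
assert stub._step_index is None
assert stub.lower_order_nums == 0
assert stub.model_outputs == [None, None]
```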
diff --git a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising_av.py b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising_av.py
index c8c488c3fc8a..0df3e14bb6ec 100644
--- a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising_av.py
+++ b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising_av.py
@@ -1,9 +1,10 @@
-import copy
-
 import torch
 from diffusers.utils.torch_utils import randn_tensor
 
 from sglang.multimodal_gen.configs.pipeline_configs.ltx_2 import is_ltx23_native_variant
+from sglang.multimodal_gen.runtime.pipelines_core.diffusion_scheduler_utils import (
+    clone_scheduler_runtime,
+)
 from sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import Req
 from sglang.multimodal_gen.runtime.pipelines_core.stages.ltx_2_denoising import (
     LTX2DenoisingStage,
@@ -331,13 +332,12 @@ def forward(self, batch: Req, server_args: ServerArgs) -> Req:
                 device=batch.audio_latents.device, dtype=torch.float32
             )
 
-        original_scheduler = self.scheduler
+        original_batch_scheduler = batch.scheduler
         original_batch_timesteps = batch.timesteps
         original_batch_num_inference_steps = batch.num_inference_steps
-        self.scheduler = copy.deepcopy(original_scheduler)
-        distilled_device = self.scheduler.sigmas.device
-        num_steps = len(self.distilled_sigmas) - 1
+        scheduler = clone_scheduler_runtime(original_batch_scheduler or self.scheduler)
+        distilled_device = scheduler.sigmas.device
 
         # Inject `0.0011` before the terminal `0.0` to avoid the
         # `sigma_next==0` singularity in res2s' `(sample - denoised) /
         # (sigma - sigma_next)`. Official `res2s_denoising_loop` does this
@@ -354,15 +354,18 @@ def forward(self, batch: Req, server_args: ServerArgs) -> Req:
             )
         else:
             scheduler_sigmas = self.distilled_sigmas
-        self.scheduler.sigmas = scheduler_sigmas.to(distilled_device)
-        self.scheduler.num_inference_steps = num_steps
-        self.scheduler.timesteps = (self.distilled_sigmas[:num_steps] * 1000).to(
+
+        scheduler.sigmas = scheduler_sigmas
+        num_steps = len(self.distilled_sigmas) - 1
+        scheduler.num_inference_steps = num_steps
+        scheduler.timesteps = (self.distilled_sigmas[:num_steps] * 1000).to(
             distilled_device
         )
-        self.scheduler._step_index = None
-        self.scheduler._begin_index = None
+        scheduler._step_index = None
+        scheduler._begin_index = None
 
-        batch.timesteps = self.scheduler.timesteps
+        batch.scheduler = scheduler
+        batch.timesteps = scheduler.timesteps
         batch.num_inference_steps = num_steps
         original_do_cfg = batch.do_classifier_free_guidance
         batch.do_classifier_free_guidance = False
@@ -370,7 +373,7 @@ def forward(self, batch: Req, server_args: ServerArgs) -> Req:
         try:
             batch = super().forward(batch, server_args)
         finally:
-            self.scheduler = original_scheduler
+            batch.scheduler = original_batch_scheduler
             batch.timesteps = original_batch_timesteps
             batch.num_inference_steps = original_batch_num_inference_steps
             batch.do_classifier_free_guidance = original_do_cfg
diff --git a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising_dmd.py b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising_dmd.py
index 1b0223d513ec..955fa752132e 100644
--- a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising_dmd.py
+++ b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising_dmd.py
@@ -70,6 +70,7 @@ def forward(
         num_warmup_steps = prepared_vars["num_warmup_steps"]
         latents = prepared_vars["latents"]
         video_raw_latent_shape = latents.shape
+        scheduler = self.scheduler
 
         timesteps = torch.tensor(
             server_args.pipeline_config.dmd_denoising_steps,
@@ -112,7 +113,7 @@ def forward(
                     self._select_and_manage_model(
                         t_int=t_int,
                         boundary_timestep=self._handle_boundary_ratio(
-                            server_args, batch
+                            server_args, batch, scheduler
                         ),
                         server_args=server_args,
                         batch=batch,
@@ -173,7 +174,7 @@ def forward(
                     pred_noise=pred_noise.flatten(0, 1),
                     noise_input_latent=noise_latents.flatten(0, 1),
                     timestep=t_expand,
-                    scheduler=self.scheduler,
+                    scheduler=scheduler,
                 ).unflatten(0, pred_noise.shape[:2])
 
                 if i < len(timesteps) - 1:
@@ -186,7 +187,7 @@ def forward(
                         generator=batch.generator[0],
                         device=self.device,
                     )
-                    latents = self.scheduler.add_noise(
+                    latents = scheduler.add_noise(
                         pred_video.flatten(0, 1),
                         noise.flatten(0, 1),
                         next_timestep,
@@ -197,7 +198,7 @@ def forward(
                 # Update progress bar
                 if i == len(timesteps) - 1 or (
                     (i + 1) > num_warmup_steps
-                    and (i + 1) % self.scheduler.order == 0
+                    and (i + 1) % scheduler.order == 0
                     and progress_bar is not None
                 ):
                     progress_bar.update()
@@ -274,6 +275,7 @@ def _handle_boundary_ratio(
         self,
         server_args,
         batch,
+        scheduler,
     ):
         """
         (Wan2.2) Calculate timestep to switch from high noise expert to low noise expert
@@ -288,7 +290,10 @@ def _handle_boundary_ratio(
         boundary_ratio = batch.boundary_ratio
 
         if boundary_ratio is not None:
-            boundary_timestep = boundary_ratio * self.scheduler.num_train_timesteps
+            num_train_timesteps = getattr(scheduler, "num_train_timesteps", None)
+            if num_train_timesteps is None:
+                num_train_timesteps = scheduler.config.num_train_timesteps
+            boundary_timestep = boundary_ratio * num_train_timesteps
         else:
            boundary_timestep = None
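Note: the repeated `num_train_timesteps` lookup deserves a word. diffusers schedulers usually expose the value only on `scheduler.config`, while some custom schedulers surface it as a plain attribute. A minimal sketch of the fallback, with both scheduler shapes invented for illustration (the 0.875 ratio is likewise just an example value for a two-expert, Wan2.2-style switch):

```python
class _AttrScheduler:
    num_train_timesteps = 1000  # exposed directly, e.g. a custom scheduler


class _ConfigScheduler:
    class config:
        num_train_timesteps = 1000  # exposed via config, diffusers style


def resolve_num_train_timesteps(scheduler):
    # The patch's fallback: prefer the attribute, then read the config.
    value = getattr(scheduler, "num_train_timesteps", None)
    if value is None:
        value = scheduler.config.num_train_timesteps
    return value


for sched in (_AttrScheduler(), _ConfigScheduler()):
    assert resolve_num_train_timesteps(sched) == 1000

# The boundary timestep for the expert switch is then ratio * N:
assert 0.875 * resolve_num_train_timesteps(_ConfigScheduler()) == 875.0
```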
diff --git a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/hunyuan3d_paint.py b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/hunyuan3d_paint.py
index fe69d1ac97f6..970a0e7290a6 100644
--- a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/hunyuan3d_paint.py
+++ b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/hunyuan3d_paint.py
@@ -775,16 +775,17 @@ def _prepare_denoising_inputs(
             prompt_embeds = self.transformer.learned_text_clip_gen.repeat(1, 1, 1)
             negative_prompt_embeds = torch.zeros_like(prompt_embeds)
 
+        scheduler = self.scheduler
         if self.is_turbo:
             bsz = 3
             index = torch.arange(29, -1, -bsz, device=device).long()
             timesteps = self.solver.ddim_timesteps[index]
-            self.scheduler.set_timesteps(timesteps=timesteps.cpu(), device=device)
-            timesteps = self.scheduler.timesteps
+            scheduler.set_timesteps(timesteps=timesteps.cpu(), device=device)
+            timesteps = scheduler.timesteps
         else:
             timesteps, num_steps = retrieve_timesteps(
-                self.scheduler, num_steps, device, None, None
+                scheduler, num_steps, device, None, None
             )
 
         num_channels_latents = self.transformer.config.in_channels
@@ -797,9 +798,10 @@ def _prepare_denoising_inputs(
         latents = randn_tensor(
             latent_shape, generator=generator, device=device, dtype=prompt_embeds.dtype
         )
-        latents = latents * self.scheduler.init_noise_sigma
+        latents = latents * scheduler.init_noise_sigma
 
         return {
+            "scheduler": scheduler,
             "timesteps": timesteps,
             "latents": latents,
             "prompt_embeds": prompt_embeds,
@@ -826,6 +828,7 @@ def _denoise_loop(
         do_cfg: bool,
         generator: torch.Generator,
         num_channels_latents: int,
+        scheduler: Any,
     ) -> torch.Tensor:
         import inspect
@@ -833,9 +836,9 @@ def _denoise_loop(
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
 
         extra_step_kwargs = {}
-        if "eta" in inspect.signature(self.scheduler.step).parameters:
+        if "eta" in inspect.signature(scheduler.step).parameters:
             extra_step_kwargs["eta"] = 0.0
-        if "generator" in inspect.signature(self.scheduler.step).parameters:
+        if "generator" in inspect.signature(scheduler.step).parameters:
             extra_step_kwargs["generator"] = generator
 
         for step_idx, t in enumerate(timesteps):
@@ -844,7 +847,7 @@ def _denoise_loop(
                 latent_model_input = rearrange(
                     latent_model_input, "b n c h w -> (b n) c h w"
                 )
-                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+                latent_model_input = scheduler.scale_model_input(latent_model_input, t)
                 latent_model_input = rearrange(
                     latent_model_input, "(b n) c h w -> b n c h w", n=num_in_batch
                 )
@@ -872,7 +875,7 @@ def _denoise_loop(
                     noise_pred_text - noise_pred_uncond
                 )
 
-            latents = self.scheduler.step(
+            latents = scheduler.step(
                 noise_pred,
                 t,
                 latents[:, :num_channels_latents, :, :],
@@ -915,6 +918,7 @@ def forward(self, batch: Req, server_args: ServerArgs) -> Req:
             do_cfg=prepared["do_cfg"],
             generator=prepared["generator"],
             num_channels_latents=prepared["num_channels_latents"],
+            scheduler=prepared["scheduler"],
         )
 
         multiview_textures = self._decode_latents(latents)
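Note: the `inspect.signature(scheduler.step)` checks in `_denoise_loop` are the usual diffusers-style compatibility shim: `eta` and `generator` exist only on some schedulers' `step()`. A compact sketch with two invented scheduler shapes:

```python
import inspect


class _DDIMLike:
    def step(self, model_output, timestep, sample, eta=0.0, generator=None):
        return sample


class _EulerLike:
    def step(self, model_output, timestep, sample):
        return sample


def build_extra_step_kwargs(scheduler, generator):
    # Pass eta/generator only to schedulers whose step() declares them,
    # so one denoising loop serves both scheduler families.
    params = inspect.signature(scheduler.step).parameters
    extra = {}
    if "eta" in params:
        extra["eta"] = 0.0
    if "generator" in params:
        extra["generator"] = generator
    return extra


assert set(build_extra_step_kwargs(_DDIMLike(), "gen")) == {"eta", "generator"}
assert build_extra_step_kwargs(_EulerLike(), "gen") == {}
```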
diff --git a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/hunyuan3d_shape.py b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/hunyuan3d_shape.py
index 6c3be73c1af0..75fd02472198 100644
--- a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/hunyuan3d_shape.py
+++ b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/hunyuan3d_shape.py
@@ -154,12 +154,12 @@ def _validate_input(self, batch: Req, server_args: ServerArgs) -> None:
         if batch.num_outputs_per_prompt != 1:
             raise ValueError("Hunyuan3D only supports num_outputs_per_prompt=1.")
 
-    def _prepare_latents(self, batch_size, dtype, device, generator):
+    def _prepare_latents(self, batch_size, dtype, device, generator, scheduler):
         from diffusers.utils.torch_utils import randn_tensor
 
         shape = (batch_size, *self.vae.latent_shape)
         latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
-        return latents * getattr(self.scheduler, "init_noise_sigma", 1.0)
+        return latents * getattr(scheduler, "init_noise_sigma", 1.0)
 
     def forward(self, batch: Req, server_args: ServerArgs) -> Req:
         # 1. Input validation
@@ -196,10 +196,11 @@ def cat_recursive(a, b):
             cond = cat_recursive(cond, un_cond)
 
         # 4. Latent and timestep preparation
+        scheduler = self.scheduler
         batch_size = image.shape[0]
         sigmas = np.linspace(0, 1, batch.num_inference_steps)
         timesteps, _ = retrieve_timesteps(
-            self.scheduler,
+            scheduler,
             batch.num_inference_steps,
             device,
             sigmas=sigmas,
@@ -209,7 +210,7 @@ def cat_recursive(a, b):
         if generator is None and batch.seed is not None:
             generator = torch.Generator(device=device).manual_seed(batch.seed)
 
-        latents = self._prepare_latents(batch_size, dtype, device, generator)
+        latents = self._prepare_latents(batch_size, dtype, device, generator, scheduler)
 
         guidance = None
         if hasattr(self.model, "guidance_embed") and self.model.guidance_embed is True:
@@ -221,6 +222,7 @@ def cat_recursive(a, b):
         batch.prompt_embeds = [cond]
         batch.do_classifier_free_guidance = do_cfg
         batch.timesteps = timesteps
+        batch.scheduler = scheduler
         batch.latents = latents
         batch.extra["shape_guidance"] = guidance
         batch.extra["shape_image"] = image
@@ -252,6 +254,8 @@ def _prepare_denoising_loop(self, batch: Req, server_args: ServerArgs):
         """Prepare Hunyuan3D-specific variables for the base denoising loop."""
         assert self.transformer is not None
         pipeline = self.pipeline() if self.pipeline else None
+        scheduler = batch.scheduler
+        assert scheduler is not None
         cache_dit_num_inference_steps = batch.extra.get(
             "cache_dit_num_inference_steps", batch.num_inference_steps
         )
@@ -285,10 +289,10 @@ def _prepare_denoising_loop(self, batch: Req, server_args: ServerArgs):
         guidance = batch.extra.get("shape_guidance")
 
         num_inference_steps = batch.num_inference_steps
-        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+        num_warmup_steps = len(timesteps) - num_inference_steps * scheduler.order
 
         extra_step_kwargs = self.prepare_extra_func_kwargs(
-            self.scheduler.step,
+            scheduler.step,
             {"generator": batch.generator, "eta": batch.eta},
         )
@@ -300,6 +304,7 @@ def _prepare_denoising_loop(self, batch: Req, server_args: ServerArgs):
 
         return {
             "extra_step_kwargs": extra_step_kwargs,
+            "scheduler": scheduler,
             "target_dtype": target_dtype,
             "autocast_enabled": autocast_enabled,
             "timesteps": timesteps,
@@ -329,7 +334,8 @@ def _predict_noise(
     ):
         """Hunyuan3D-specific noise prediction with normalized timestep."""
         cond = kwargs.get("encoder_hidden_states")
-        timestep_norm = timestep / self.scheduler.config.num_train_timesteps
+        scheduler = kwargs.get("scheduler")
+        timestep_norm = timestep / scheduler.config.num_train_timesteps
         return current_model(latent_model_input, timestep_norm, cond, guidance=guidance)
 
     def _predict_noise_with_cfg(
@@ -371,6 +377,7 @@ def _predict_noise_with_cfg(
             timestep=timestep_expanded,
             target_dtype=target_dtype,
             guidance=guidance,
+            scheduler=batch.scheduler,
             encoder_hidden_states=cond,
         )
diff --git a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/ltx_2_denoising.py b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/ltx_2_denoising.py
index 9f1e68265e88..f95a1aac3061 100644
--- a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/ltx_2_denoising.py
+++ b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/ltx_2_denoising.py
@@ -1,4 +1,3 @@
-import copy
 from contextlib import contextmanager
 from dataclasses import dataclass, field
 
@@ -10,6 +9,9 @@
 )
 from sglang.multimodal_gen.runtime.distributed import get_sp_world_size
 from sglang.multimodal_gen.runtime.managers.forward_context import set_forward_context
+from sglang.multimodal_gen.runtime.pipelines_core.diffusion_scheduler_utils import (
+    clone_scheduler_runtime,
+)
 from sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import Req
 from sglang.multimodal_gen.runtime.pipelines_core.stages.denoising import (
     DenoisingContext,
@@ -1083,7 +1085,7 @@ def _prepare_denoising_loop(
         )
         ctx.audio_latents = batch.audio_latents
         # Video and audio keep separate scheduler state throughout the denoising loop.
-        ctx.audio_scheduler = copy.deepcopy(self.scheduler)
+        ctx.audio_scheduler = clone_scheduler_runtime(ctx.scheduler)
 
         if ctx.use_ltx23_legacy_one_stage:
             batch.ltx23_audio_replicated_for_sp = False
@@ -1209,7 +1211,7 @@ def _run_denoising_step(
             raise ValueError("LTX-2 audio scheduler was not prepared.")
 
         # 1. Read the scheduler sigma pair and derive the Euler delta.
-        sigmas = getattr(self.scheduler, "sigmas", None)
+        sigmas = getattr(ctx.scheduler, "sigmas", None)
         if sigmas is None or not isinstance(sigmas, torch.Tensor):
             raise ValueError("Expected scheduler.sigmas to be a tensor for LTX-2.")
         sigma = sigmas[step.step_index].to(
@@ -1391,7 +1393,7 @@ def _stage2_midpoint_model_call(
                 midpoint_model_call=_stage2_midpoint_model_call,
             )
         else:
-            ctx.latents = self.scheduler.step(
+            ctx.latents = ctx.scheduler.step(
                 model_video, step.t_device, ctx.latents, return_dict=False
             )[0]
             ctx.audio_latents = ctx.audio_scheduler.step(
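Note: the LTX-2 change preserves one invariant: video and audio follow the same schedule but must not share step-index state, so the audio scheduler is now cloned from the request scheduler (`ctx.scheduler`) instead of the stage-wide `self.scheduler`. A toy illustration of why the clone is load-bearing (`_EulerState` is invented for this sketch):

```python
from copy import deepcopy


class _EulerState:
    """Toy scheduler whose step() advances an internal index."""

    def __init__(self):
        self._step_index = 0

    def step(self, latents):
        self._step_index += 1
        return latents


video = _EulerState()
audio = deepcopy(video)  # what clone_scheduler_runtime(ctx.scheduler) does

for _ in range(3):
    video.step(latents=None)
    audio.step(latents=None)

# Independent state: each modality sits at step 3. A shared instance would
# have been double-stepped to 6 and desynchronized the sigma indexing.
assert video._step_index == audio._step_index == 3
```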
diff --git a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/model_specific_stages/glm_image.py b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/model_specific_stages/glm_image.py
index 48476a881bfd..47246b419d64 100644
--- a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/model_specific_stages/glm_image.py
+++ b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/model_specific_stages/glm_image.py
@@ -775,22 +775,23 @@ def forward(
         )
 
         # Prepare timesteps
+        scheduler = self.scheduler
         image_seq_len = (
             (height // self.vae_scale_factor) * (width // self.vae_scale_factor)
         ) // (self.transformer.config.patch_size**2)
         timesteps = np.linspace(
-            self.scheduler.config.num_train_timesteps, 1.0, num_inference_steps + 1
+            scheduler.config.num_train_timesteps, 1.0, num_inference_steps + 1
         )[:-1]
         timesteps = timesteps.astype(np.int64).astype(np.float32)
-        sigmas = timesteps / self.scheduler.config.num_train_timesteps
+        sigmas = timesteps / scheduler.config.num_train_timesteps
         mu = calculate_shift(
             image_seq_len,
-            self.scheduler.config.get("base_image_seq_len", 256),
-            self.scheduler.config.get("base_shift", 0.25),
-            self.scheduler.config.get("max_shift", 0.75),
+            scheduler.config.get("base_image_seq_len", 256),
+            scheduler.config.get("base_shift", 0.25),
+            scheduler.config.get("max_shift", 0.75),
         )
         timesteps, num_inference_steps = retrieve_timesteps(
-            self.scheduler, num_inference_steps, device, timesteps, sigmas, mu=mu
+            scheduler, num_inference_steps, device, timesteps, sigmas, mu=mu
         )
         self._num_timesteps = len(timesteps)
 
@@ -800,6 +801,7 @@ def forward(
         batch.negative_prompt_embeds = [negative_prompt_embeds]
         batch.latents = latents
         batch.timesteps = timesteps
+        batch.scheduler = scheduler
         batch.num_inference_steps = num_inference_steps
         batch.sigmas = sigmas.tolist()  # Convert numpy array to list for validation
         batch.generator = generator
diff --git a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/model_specific_stages/helios_denoising.py b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/model_specific_stages/helios_denoising.py
index 123122716760..e208f364314e 100644
--- a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/model_specific_stages/helios_denoising.py
+++ b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/model_specific_stages/helios_denoising.py
@@ -15,6 +15,9 @@
 
 from sglang.multimodal_gen.runtime.distributed import get_local_torch_device
 from sglang.multimodal_gen.runtime.managers.forward_context import set_forward_context
+from sglang.multimodal_gen.runtime.pipelines_core.diffusion_scheduler_utils import (
+    get_or_create_request_scheduler,
+)
 from sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import Req
 from sglang.multimodal_gen.runtime.pipelines_core.stages.base import (
     PipelineStage,
@@ -126,6 +129,7 @@ def _denoise_one_chunk(
         batch=None,
         server_args=None,
         global_step_offset=0,
+        scheduler=None,
     ):
         """Denoise a single chunk with full timestep loop."""
         batch_size = latents.shape[0]
@@ -226,9 +230,7 @@ def _denoise_one_chunk(
                     noise_pred - noise_uncond
                 )
 
-            latents = self.scheduler.step(
-                noise_pred, t, latents, return_dict=False
-            )[0]
+            latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0]
 
         return latents
 
@@ -258,6 +260,7 @@ def _denoise_one_chunk_stage2(
         batch=None,
         server_args=None,
         global_step_offset=0,
+        scheduler=None,
     ):
         """Denoise a single chunk using pyramid super-resolution (Stage 2)."""
         batch_size, num_channel, num_frames, height, width = latents.shape
@@ -292,14 +295,14 @@ def _denoise_one_chunk_stage2(
             )
             mu = calculate_shift(image_seq_len)
 
-            self.scheduler.set_timesteps(
+            scheduler.set_timesteps(
                 pyramid_num_inference_steps_list[i_s],
                 i_s,
                 device=device,
                 mu=mu,
                 is_amplify_first_chunk=is_amplify_first_chunk,
             )
-            timesteps = self.scheduler.timesteps
+            timesteps = scheduler.timesteps
 
             if i_s > 0:
                 # Upsample 2x nearest-neighbor
@@ -317,7 +320,7 @@ def _denoise_one_chunk_stage2(
                 ).permute(0, 2, 1, 3, 4)
 
                 # Renoise with correlated block noise
-                ori_sigma = 1 - self.scheduler.ori_start_sigmas[i_s]
+                ori_sigma = 1 - scheduler.ori_start_sigmas[i_s]
                 alpha = 1 / (math.sqrt(1 + (1 / gamma)) * (1 - ori_sigma) + ori_sigma)
                 beta = alpha * (1 - ori_sigma) / math.sqrt(gamma)
@@ -428,7 +431,7 @@ def _denoise_one_chunk_stage2(
                     noise_pred - noise_uncond
                 )
 
-                latents = self.scheduler.step(
+                latents = scheduler.step(
                     noise_pred,
                     t,
                     latents,
@@ -439,8 +442,8 @@ def _denoise_one_chunk_stage2(
                         if start_point_list is not None
                         else None
                     ),
-                    dmd_sigmas=self.scheduler.sigmas,
-                    dmd_timesteps=self.scheduler.timesteps,
+                    dmd_sigmas=scheduler.sigmas,
+                    dmd_timesteps=scheduler.timesteps,
                     all_timesteps=timesteps,
                 )[0]
 
@@ -451,6 +454,7 @@ def _denoise_one_chunk_stage2(
     def forward(self, batch: Req, server_args: ServerArgs) -> Req:
         """Run the Helios chunked denoising loop."""
         pipeline_config = server_args.pipeline_config
+        scheduler = get_or_create_request_scheduler(batch, self.scheduler)
         device = (
             batch.latents.device
             if hasattr(batch, "latents") and batch.latents is not None
@@ -671,13 +675,14 @@ def forward(self, batch: Req, server_args: ServerArgs) -> Req:
                     batch=batch,
                     server_args=server_args,
                     global_step_offset=global_step_offset,
+                    scheduler=scheduler,
                 )
             else:
                 # Stage 1: Standard flat denoising
-                self.scheduler.set_timesteps(
+                scheduler.set_timesteps(
                     num_inference_steps, device=device, sigmas=sigmas, mu=mu
                 )
-                timesteps = self.scheduler.timesteps
+                timesteps = scheduler.timesteps
 
                 latents = self._denoise_one_chunk(
                     latents=latents,
@@ -700,6 +705,7 @@ def forward(self, batch: Req, server_args: ServerArgs) -> Req:
                     batch=batch,
                     server_args=server_args,
                     global_step_offset=global_step_offset,
+                    scheduler=scheduler,
                 )
 
             global_step_offset += num_inference_steps
diff --git a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/model_specific_stages/mova.py b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/model_specific_stages/mova.py
index b8b0cfb62b02..a33242f08c17 100644
--- a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/model_specific_stages/mova.py
+++ b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/model_specific_stages/mova.py
@@ -127,19 +127,21 @@ def __init__(self, scheduler) -> None:
         self.scheduler = scheduler
 
     def forward(self, batch: Req, server_args: ServerArgs) -> Req:
-        self.scheduler.set_timesteps(
+        scheduler = self.scheduler
+        scheduler.set_timesteps(
             batch.num_inference_steps,
             denoising_strength=1.0,
-            shift=getattr(batch, "sigma_shift", self.scheduler.shift),
+            shift=getattr(batch, "sigma_shift", scheduler.shift),
         )
-        self.scheduler.set_pair_postprocess_by_name(
+        scheduler.set_pair_postprocess_by_name(
             "dual_sigma_shift",
             visual_shift=getattr(batch, "visual_shift", 5.0),
             audio_shift=getattr(batch, "audio_shift", 5.0),
         )
-        paired = self.scheduler.get_pairs()
+        paired = scheduler.get_pairs()
         batch.paired_timesteps = paired
         batch.timesteps = paired
+        batch.scheduler = scheduler
         return batch
 
@@ -349,13 +351,17 @@ def _manage_device_placement(
             model_to_use.to(get_local_torch_device())
 
     def _select_visual_dit(
-        self, timestep: float, boundary_ratio: float | None, server_args: ServerArgs
+        self,
+        timestep: float,
+        boundary_ratio: float | None,
+        server_args: ServerArgs,
+        scheduler,
     ):
         if boundary_ratio is None or self.video_dit_2 is None:
             self._manage_device_placement(self.video_dit, None, server_args)
             return self.video_dit
 
-        boundary_timestep = boundary_ratio * self.scheduler.num_train_timesteps
+        boundary_timestep = boundary_ratio * scheduler.num_train_timesteps
         if timestep >= boundary_timestep:
             current_model = self.video_dit
             model_to_offload = self.video_dit_2
@@ -405,6 +411,9 @@ def forward(self, batch: Req, server_args: ServerArgs) -> Req:
         paired_timesteps = batch.paired_timesteps
         if paired_timesteps is None:
             raise ValueError("paired_timesteps must be set for MOVA")
+        scheduler = batch.scheduler
+        if scheduler is None:
+            raise ValueError("scheduler must be set for MOVA denoising")
 
         y = batch.y if batch.y is not None else batch.image_latent
         if getattr(self.video_dit, "require_vae_embedding", False) and y is None:
@@ -417,7 +426,7 @@ def forward(self, batch: Req, server_args: ServerArgs) -> Req:
         is_warmup = batch.is_warmup
 
         extra_step_kwargs = self.prepare_extra_func_kwargs(
-            self.scheduler.step_from_to,
+            scheduler.step_from_to,
             getattr(batch, "extra_step_kwargs", None) or {},
         )
 
@@ -441,7 +450,7 @@ def forward(self, batch: Req, server_args: ServerArgs) -> Req:
             audio_timestep = pair_t
 
             cur_visual_dit = self._select_visual_dit(
-                timestep.item(), boundary_ratio, server_args
+                timestep.item(), boundary_ratio, server_args, scheduler
             )
 
             timestep = timestep.unsqueeze(0).to(device=get_local_torch_device())
@@ -569,14 +578,14 @@ def forward(self, batch: Req, server_args: ServerArgs) -> Req:
                 next_timestep = None
                 next_audio_timestep = None
 
-            batch.latents = self.scheduler.step_from_to(
+            batch.latents = scheduler.step_from_to(
                 visual_noise_pred,
                 timestep,
                 next_timestep,
                 batch.latents,
                 **extra_step_kwargs,
             )
-            batch.audio_latents = self.scheduler.step_from_to(
+            batch.audio_latents = scheduler.step_from_to(
                 audio_noise_pred,
                 audio_timestep,
                 next_audio_timestep,
diff --git a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/model_specific_stages/qwen_image_layered.py b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/model_specific_stages/qwen_image_layered.py
index 2cb35deadb3f..6e82852a66fb 100644
--- a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/model_specific_stages/qwen_image_layered.py
+++ b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/model_specific_stages/qwen_image_layered.py
@@ -488,12 +488,13 @@ def forward(
         ]
 
         # 5. Prepare timesteps
+        scheduler = self.scheduler
         sigmas = np.linspace(1.0, 0, num_inference_steps + 1)[:-1]
         image_seq_len = latents.shape[1]
         base_seqlen = 256 * 256 / 16 / 16
         mu = (image_latents.shape[1] / base_seqlen) ** 0.5
         timesteps, num_inference_steps = retrieve_timesteps(
-            self.scheduler,
+            scheduler,
             num_inference_steps,
             device,
             sigmas=sigmas,
@@ -518,8 +519,10 @@ def forward(
         batch.negative_prompt_embeds_mask = [negative_prompt_embeds_mask]
         batch.latents = latents
         batch.image_latent = image_latents
+        batch.timesteps = timesteps
+        batch.scheduler = scheduler
         batch.num_inference_steps = num_inference_steps
-        batch.sigmas = sigmas.tolist()  # Convert numpy array to list for validation
+        batch.sigmas = None
         batch.generator = torch.manual_seed(0)
         batch.original_condition_image_size = image_size
         batch.raw_latent_shape = latents.shape
diff --git a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/timestep_preparation.py b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/timestep_preparation.py
index 44cf8c1196fd..1af0201c1c7a 100644
--- a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/timestep_preparation.py
+++ b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/timestep_preparation.py
@@ -13,6 +13,9 @@
 import torch
 
 from sglang.multimodal_gen.runtime.distributed import get_local_torch_device
+from sglang.multimodal_gen.runtime.pipelines_core.diffusion_scheduler_utils import (
+    get_or_create_request_scheduler,
+)
 from sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import Req
 from sglang.multimodal_gen.runtime.pipelines_core.stages.base import (
     PipelineStage,
@@ -68,7 +71,10 @@ def forward(
         Returns:
             The batch with prepared timesteps.
         """
-        scheduler = self.scheduler
+        if batch.scheduler is not None and batch.timesteps is not None:
+            return batch
+
+        scheduler = get_or_create_request_scheduler(batch, self.scheduler)
         device = get_local_torch_device()
         num_inference_steps = batch.num_inference_steps
         timesteps = batch.timesteps
@@ -133,6 +139,7 @@ def forward(
 
         # Update batch with prepared timesteps
         batch.timesteps = timesteps
+        batch.scheduler = scheduler
         if not batch.is_warmup:
             self.log_debug("timesteps: %s", timesteps)
         return batch
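Note: taken together, the pieces compose as in the reduced sketch below. The disaggregation path attaches a prepared request scheduler up front, and `TimestepPreparationStage` becomes a no-op for such requests while still serving the normal path. `prepare_timesteps` and `_Batch` are hypothetical reductions of the real stage and `Req`:

```python
class _Batch:
    def __init__(self):
        self.scheduler = None
        self.timesteps = None


def prepare_timesteps(batch, stage_scheduler):
    # Mirrors the new guard in TimestepPreparationStage.forward: skip when a
    # request-local scheduler already carries prepared timesteps.
    if batch.scheduler is not None and batch.timesteps is not None:
        return batch
    batch.scheduler = stage_scheduler
    batch.timesteps = [999, 500, 0]  # placeholder for scheduler.set_timesteps()
    return batch


stage_scheduler = object()

cold = prepare_timesteps(_Batch(), stage_scheduler)
assert cold.scheduler is stage_scheduler  # normal path prepares in place

warm = _Batch()
warm.scheduler, warm.timesteps = object(), [875, 300, 0]  # attached upstream
prepare_timesteps(warm, stage_scheduler)
assert warm.scheduler is not stage_scheduler  # upstream state is preserved
```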