Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
66b0a0b
Add diffusion performance mode defaults
mickqian May 6, 2026
82f2304
Refactor diffusion performance auto tune policy
mickqian May 7, 2026
a231fca
Apply black formatting to auto tune resolver
mickqian May 7, 2026
ef3eaec
Rename diffusion performance preset to speed
mickqian May 7, 2026
5f9d65f
Collect deployment auto tune hints
mickqian May 8, 2026
ccefae5
upd
mickqian May 8, 2026
3186ec0
auto-tune z-image
mickqian May 8, 2026
8bf7975
Merge remote-tracking branch 'origin/main' into codex/diffusion-perfo…
mickqian May 9, 2026
eb7444e
Merge remote-tracking branch 'origin/main' into codex/diffusion-perfo…
mickqian May 12, 2026
e45d05f
upd
mickqian May 12, 2026
155c914
upd
mickqian May 12, 2026
42c2e9d
upd
mickqian May 12, 2026
9108ba4
upd
mickqian May 12, 2026
62d671b
upd
mickqian May 12, 2026
7974064
upd
mickqian May 12, 2026
132491c
lint
mickqian May 12, 2026
90f5d3e
lint
mickqian May 12, 2026
c909cd9
upd
mickqian May 12, 2026
59eec65
Merge remote-tracking branch 'origin/main' into codex/diffusion-perfo…
mickqian May 12, 2026
c214fd8
upd
mickqian May 12, 2026
50c3997
merge balanced and auto mode
mickqian May 12, 2026
031ba4d
fix
mickqian May 12, 2026
acba38e
Refine diffusion auto performance mode
mickqian May 13, 2026
38b5e04
Tighten auto performance defaults
mickqian May 13, 2026
c69f635
Narrow auto offload residency tuning
mickqian May 13, 2026
4989205
Fix auto layerwise offload ordering
mickqian May 13, 2026
8bfc955
Merge remote-tracking branch 'origin/main' into codex/diffusion-perfo…
mickqian May 13, 2026
c912573
upd
mickqian May 13, 2026
67c65d8
Update auto performance mode unit expectations
mickqian May 13, 2026
a94c8e7
Preserve LTX snapshot DiT offload in auto mode
mickqian May 13, 2026
f974077
Print diffusion performance logs per run
mickqian May 13, 2026
b7eaadc
Remove duplicate baseline improvement logs
mickqian May 13, 2026
337d20c
Disable warmup for LTX HQ CI case
mickqian May 13, 2026
2023929
Avoid LTX snapshot warmup DiT overlap
mickqian May 13, 2026
3091f2e
Tighten stable LTX perf baselines
mickqian May 13, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/diffusion/api/cli.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,13 @@ Use `sglang generate --help` and `sglang serve --help` for the full argument lis

- `--model-path {MODEL}`: model path or Hugging Face model ID
- `--lora-path {PATH}` and `--lora-nickname {NAME}`: load a LoRA adapter
- `--lora-merge-mode {auto|merge|dynamic}`: choose how LoRA is applied. `auto` statically merges regular weights and uses dynamic LoRA for FSDP-sharded weights to avoid full-gather peaks.
- `--num-gpus {N}`: number of GPUs to use
- `--performance-mode {manual|auto|speed|memory}` / `--mode`: preset for latency/throughput and memory defaults. `auto` is the default and keeps safe offload defaults, using FSDP only for validated DiT-offload replacement paths; use `manual` to keep performance-related server args under explicit user control. Explicit offload, FSDP, and parallelism flags take precedence in all modes.
- `--tp-size {N}`: tensor parallelism size, mainly for encoders
- `--sp-degree {N}`: sequence parallelism size
- `--ulysses-degree {N}` and `--ring-degree {N}`: USP parallelism controls
- `--enable-cfg-parallel {true|false}`: enable or explicitly disable CFG parallelism
- `--attention-backend {BACKEND}`: attention backend for native SGLang pipelines
- `--component-attention-backends {MAP}`: per-component attention backend overrides, for example `text_encoder=torch_sdpa,transformer=fa`
- `--attention-backend-config {CONFIG}`: attention backend configuration
Expand Down
12 changes: 7 additions & 5 deletions docs/diffusion/api/openai_api.md
Original file line number Diff line number Diff line change
Expand Up @@ -268,13 +268,13 @@ curl -sS -L "http://localhost:30010/v1/videos/<VIDEO_ID>/content" \
The server supports dynamic loading, merging, and unmerging of LoRA adapters.

**Important Notes:**
- Mutual Exclusion: Only one LoRA can be *merged* (active) at a time
- Switching: To switch LoRAs, you must first `unmerge` the current one, then `set` the new one
- Mutual Exclusion: Only one LoRA configuration can be active per target at a time
- Switching: To switch LoRAs, deactivate the current LoRA with `unmerge_lora_weights`, then `set` the new one
- Caching: The server caches loaded LoRA weights in memory. Switching back to a previously loaded LoRA (same path) has little cost

**Set LoRA Adapter**

Loads one or more LoRA adapters and merges their weights into the model. Supports both single LoRA (backward compatible) and multiple LoRA adapters.
Loads one or more LoRA adapters and applies them to the model. By default, regular weights are statically merged, while FSDP-sharded weights use dynamic LoRA to avoid full-gather memory peaks.

**Endpoint:** `POST /v1/set_lora`

Expand All @@ -287,6 +287,7 @@ Loads one or more LoRA adapters and merges their weights into the model. Support
- `"transformer_2"`: Apply only to transformer_2 (low noise for Wan2.2)
- `"critic"`: Apply only to the critic model
- `strength` (float or list of floats, optional): LoRA strength for merge, default 1.0. If a list, must match the length of `lora_nickname`. Values < 1.0 reduce the effect, values > 1.0 amplify the effect
- `merge_mode` (string, optional): `"auto"` (default server policy), `"merge"` (force static merge), or `"dynamic"` (apply LoRA at forward time)

**Single LoRA Example:**

Expand Down Expand Up @@ -331,15 +332,15 @@ curl -X POST http://localhost:30010/v1/set_lora \
> When using multiple LoRAs:
> - All list parameters (`lora_nickname`, `lora_path`, `target`, `strength`) must have the same length
> - If `target` or `strength` is a single value, it will be applied to all LoRAs
> - Multiple LoRAs applied to the same target will be merged in order
> - Multiple LoRAs applied to the same target are applied in order


**Merge LoRA Weights**

Manually merges the currently set LoRA weights into the base model.

> [!NOTE]
> `set_lora` automatically performs a merge, so this is typically only needed if you have manually unmerged but want to re-apply the same LoRA without calling `set_lora` again.*
> With FSDP-sharded weights, manual merge may require a full-gather and can OOM. Use `set_lora` with `merge_mode="auto"` or `"dynamic"` for the lower-peak path.

**Endpoint:** `POST /v1/merge_lora_weights`

Expand Down Expand Up @@ -395,6 +396,7 @@ curl -sS -X GET "http://localhost:30010/v1/list_loras"
"nickname": "lora2",
"path": "tarn59/pixel_art_style_lora_z_image_turbo",
"merged": true,
"mode": "merged",
"strength": 1.0
}
]
Expand Down
94 changes: 94 additions & 0 deletions docs/diffusion/performance/deployment_cookbook.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
# Deployment Cookbook

This page gives practical defaults for choosing CPU offload, FSDP, CFG parallelism, SP, and TP.

## Quick Rule

Use the simplest setting that fits your memory target:

| Goal | Recommended setting |
|--------------------------------------------|------------------------------------------------------------------------------------|
| Fastest single-GPU run when the model fits | Disable CPU offload and do not use FSDP. |
| Lower single-GPU memory usage | Use component CPU offload, or layerwise DiT offload for supported Wan/MOVA models. |
| Faster multi-GPU Qwen/Wan CFG generation | Use FSDP with CFG parallelism and disable CPU offload. |
| Sequence length or video-shape scaling | Use SP/Ulysses/Ring when the model benefits from sequence parallelism. |
| TP compatibility or encoder-heavy paths | Set TP explicitly; do not treat TP as the default latency optimization. |

Base the decision on available memory on the selected GPU(s).

- For multi-GPU deployment: the least-free selected GPU is the bottleneck. A busy 80GiB GPU can behave like a much smaller GPU.
- For single-GPU deployment: FSDP shards DiT weights across multiple GPUs. It is not useful for keeping a single-GPU deployment on one GPU; for that case use CPU offload.

## Performance Modes

`--performance-mode` applies safe presets without overriding explicit offload, FSDP, or parallelism flags. `auto` is the default. Use `manual` when you need to keep performance-related server args under explicit user control. `--mode` is a short alias.

| Mode | Meaning |
|------------|---------------------------------------------------------------------------------------------------------------------------|
| `manual` | Keeps performance-related server args under explicit user control. |
| `auto` | Default. Keeps legacy safe offload defaults and uses FSDP/CFG only on validated multi-GPU deployments where FSDP can replace DiT offload. |
| `speed` | Favors GPU-resident execution for lower latency and higher throughput. Disables CPU offload when unset; may OOM. |
| `memory` | Favors lower GPU memory. Uses component offload, or Wan/MOVA layerwise DiT offload when supported. |

`auto` checks selected GPU memory before applying FSDP. In multi-GPU runs it uses the least available memory across selected GPUs, and only turns on FSDP automatically when doing so can replace DiT offload. Text encoder, image encoder, and other component residency still follow the offload policy unless the model marks a high-memory resident path as safe. When the model default uses CFG and the user did not set a parallelism policy, `auto` may also enable CFG parallelism. `speed` intentionally does not check memory; it is the mode for users who prefer latency/throughput and accept OOM risk.

The modes tune residency for native pipeline components declared to the component residency manager. Today this covers the major DiT, text/image encoder, VAE, vocoder, and upsampler components; DiT can use layerwise offload when supported, while text encoders use either resident execution or component CPU offload. Do not assume text-encoder layerwise offload unless a model implements and validates it.

NOTE:
The preset is intentionally coarse. A future continuous value such as `0.0` to `1.0` could express the speed-memory tradeoff more precisely, but it would need model-specific memory models and clearer user expectations. Until then, use the preset plus explicit flags for overrides.

Examples:

```bash
sglang generate \
--model-path Qwen/Qwen-Image \
--num-gpus 2 \
--performance-mode auto
```

```bash
sglang generate \
--model-path Wan-AI/Wan2.1-T2V-1.3B-Diffusers \
--performance-mode memory
```

Explicit flags win over the mode:

```bash
sglang generate \
--model-path Qwen/Qwen-Image \
--num-gpus 2 \
--performance-mode auto \
--use-fsdp-inference false
```

In this example, `auto` will not re-enable FSDP. The same applies to parallelism; for example, `--enable-cfg-parallel false` keeps CFG parallelism disabled.

## Interpreting The Levers

**No offload** keeps model components resident on GPU. It is usually fastest when memory is sufficient.

**Component CPU offload** lowers GPU memory by moving large components to CPU. It is simple and robust, but it usually trades latency for memory.

**Layerwise DiT offload** lowers DiT memory further for supported Wan/MOVA models by moving DiT layers between CPU and GPU. It can be the best single-GPU memory mode, but may increase latency and lower throughput.

**FSDP** shards DiT weights across multiple GPUs and all-gathers weights during forward. It can reduce DiT CPU offload cost on multi-GPU deployments, especially for validated Wan I2V workloads.

FSDP sharding granularity matters. SGLang Diffusion prefers sharding direct repeated transformer block entries such as `transformer_blocks.0` or `blocks.0`. Coarser sharding lowers wrapper count but can increase all-gather peak memory; finer sharding can reduce transient memory but adds communication and scheduling overhead. If a model does not define an explicit sharding rule, the loader falls back to repeated block class names and common direct numbered block paths.

**CFG parallelism** splits positive and negative CFG branches across GPUs. For Qwen/Wan workloads with normal step counts, this is the most reliable multi-GPU speedup observed so far.

**SP/Ulysses/Ring** splits sequence work. It can help video workloads, but validated Qwen/Wan runs showed CFG parallelism outperforming SP for latency.

**TP** is supported for compatibility and some model structures, but current measurements do not make it the default latency path for Qwen/Wan.

## Current Benchmark Takeaways

Observed regular-scale trends:

- Z-Image: single-GPU no-offload was faster than FSDP/SP in the tested setting; keep FSDP off unless memory or parallelism requires it.
- Qwen-Image: keep the default non-FSDP path unless a specific FSDP/SP/Ring setting has been benchmarked on the target hardware.
- Wan: FSDP can replace DiT offload on validated multi-GPU workloads, while text/image encoders may still need component offload. Keep model-specific precision checks before making FSDP automatic for a path.
- Component offload mainly reduced memory; it did not improve latency in the tested no-offload-vs-offload runs.

Always benchmark with your actual resolution, frame count, step count, and GPU type before locking production defaults.
2 changes: 2 additions & 0 deletions docs/diffusion/performance/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ This section covers the main performance levers for SGLang Diffusion: attention
## Start Here

- Use [Attention Backends](attention_backends.md) to choose the best backend for your model and hardware.
- Use [Deployment Cookbook](deployment_cookbook.md) to choose CPU offload, FSDP, CFG parallelism, SP, and TP.
- Use [Caching Acceleration](cache/index.md) to reduce denoising cost with Cache-DiT or TeaCache.
- Use [Profiling](profiling.md) when you need to diagnose a bottleneck rather than guess.

Expand All @@ -26,6 +27,7 @@ This section covers the main performance levers for SGLang Diffusion: attention
:maxdepth: 1

attention_backends
deployment_cookbook
cache/index
profiling
```
Expand Down
17 changes: 10 additions & 7 deletions python/sglang/benchmark/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,13 +118,16 @@ def download_and_cache_file(url: str, filename: Optional[str] = None):
chunk_size = 1024 # Download in chunks of 1KB

# Use tqdm to display the progress bar
with open(filename, "wb") as f, tqdm(
desc=filename,
total=total_size,
unit="B",
unit_scale=True,
unit_divisor=1024,
) as bar:
with (
open(filename, "wb") as f,
tqdm(
desc=filename,
total=total_size,
unit="B",
unit_scale=True,
unit_divisor=1024,
) as bar,
):
for chunk in response.iter_content(chunk_size=chunk_size):
f.write(chunk)
bar.update(len(chunk))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def forward(
ref_latents=None,
additional_t_cond=None,
transformer_options={},
**kwargs
**kwargs,
):
"""Forward pass for QwenImageEdit model."""
latents, orig_shape = self._pack_latents(x)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class WanVideoArchConfig(DiTArchConfig):
r"^blocks\.(\d+)\.attn1\.norm_q\.(.*)$": r"blocks.\1.norm_q.\2",
r"^blocks\.(\d+)\.attn1\.norm_k\.(.*)$": r"blocks.\1.norm_k.\2",
r"^blocks\.(\d+)\.attn1\.attn_op\.local_attn\.proj_l\.(.*)$": r"blocks.\1.attn1.local_attn.proj_l.\2",
r"^blocks\.(\d+)\.attn2\.norm_added_q\.(.*)$": "",
r"^blocks\.(\d+)\.attn2\.to_out\.0\.(.*)$": r"blocks.\1.attn2.to_out.\2",
r"^blocks\.(\d+)\.ffn\.net\.0\.proj\.(.*)$": r"blocks.\1.ffn.fc_in.\2",
r"^blocks\.(\d+)\.ffn\.net\.2\.(.*)$": r"blocks.\1.ffn.fc_out.\2",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
)
from sglang.multimodal_gen.configs.models.encoders import BaseEncoderOutput
from sglang.multimodal_gen.configs.models.encoders.t5 import T5Config
from sglang.multimodal_gen.configs.pipeline_configs.model_deployment_config import (
ModelDeploymentConfig,
)
from sglang.multimodal_gen.configs.sample.sampling_params import DataType
from sglang.multimodal_gen.configs.utils import update_config_from_args
from sglang.multimodal_gen.runtime.distributed.cfg_policy import CFGPolicy
Expand Down Expand Up @@ -240,6 +243,9 @@ class PipelineConfig:
# image encoding
image_encoder_extra_args: dict = field(default_factory=lambda: {})

def get_model_deployment_config(self) -> ModelDeploymentConfig:
return ModelDeploymentConfig()

def postprocess_image(self, image):
return image.last_hidden_state

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
ModelTaskType,
PipelineConfig,
)
from sglang.multimodal_gen.configs.pipeline_configs.model_deployment_config import (
ModelDeploymentConfig,
)
from sglang.multimodal_gen.runtime.distributed import (
get_sp_parallel_rank,
get_sp_world_size,
Expand Down Expand Up @@ -187,6 +190,12 @@ def vae_scale_factor(self):
def vae_temporal_compression(self):
return self.vae_config.arch_config.temporal_compression_ratio

def get_model_deployment_config(self) -> ModelDeploymentConfig:
return ModelDeploymentConfig(
auto_disable_component_offload_min_available_memory_gb=70,
auto_disable_component_offload_components=("dit",),
)

def prepare_latent_shape(self, batch, batch_size, num_frames):
"""Return unpacked latent shape [B, C, F, H, W]."""
height = batch.height // self.vae_scale_factor
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
"""
ModelDeploymentConfig provides model-specific config on how to deploy a model optimally

"""

from dataclasses import dataclass
from typing import Literal

OffloadComponentName = Literal["dit", "text_encoder", "image_encoder"]


@dataclass(frozen=True)
class ModelDeploymentConfig:
auto_dit_layerwise_offload: bool = False
# if the available memory is bigger than this value, keep dit resident instead of apply layerwise-offload
auto_dit_layerwise_offload_high_memory_disable_gb: float | None = None
auto_disable_component_offload_min_available_memory_gb: float | None = None
# keep this explicit because large encoders can OOM even when DiT fits resident
auto_disable_component_offload_components: tuple[OffloadComponentName, ...] = (
"dit",
"text_encoder",
"image_encoder",
)
fsdp_auto_min_available_memory_gb: float | None = None
fsdp_auto_requires_cfg: bool = True
fsdp_auto_requires_default_parallelism: bool = True
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
ModelTaskType,
PipelineConfig,
)
from sglang.multimodal_gen.configs.pipeline_configs.model_deployment_config import (
ModelDeploymentConfig,
)
from sglang.multimodal_gen.configs.pipeline_configs.wan import t5_postprocess_text
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger

Expand Down Expand Up @@ -52,6 +55,12 @@ class MOVAPipelineConfig(PipelineConfig):
time_division_factor: int = 4
time_division_remainder: int = 1

def get_model_deployment_config(self) -> ModelDeploymentConfig:
return ModelDeploymentConfig(
auto_dit_layerwise_offload=True,
auto_dit_layerwise_offload_high_memory_disable_gb=130,
)

def _center_crop_and_resize(
self, image: torch.Tensor | Image.Image, target_height: int, target_width: int
) -> torch.Tensor | Image.Image:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ class QwenImagePipelineConfig(QwenImageRolloutPipelineMixin, ImagePipelineConfig
postprocess_text_funcs: tuple[Callable[[str], str], ...] = field(
default_factory=lambda: (qwen_image_postprocess_text,)
)

text_encoder_extra_args: list[dict] = field(
default_factory=lambda: [
dict(
Expand Down
15 changes: 15 additions & 0 deletions python/sglang/multimodal_gen/configs/pipeline_configs/wan.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
ModelTaskType,
PipelineConfig,
)
from sglang.multimodal_gen.configs.pipeline_configs.model_deployment_config import (
ModelDeploymentConfig,
)
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger

logger = init_logger(__name__)
Expand Down Expand Up @@ -91,6 +94,12 @@ def __post_init__(self):
self.vae_config.load_encoder = False
self.vae_config.load_decoder = True

def get_model_deployment_config(self) -> ModelDeploymentConfig:
return ModelDeploymentConfig(
auto_dit_layerwise_offload=True,
auto_dit_layerwise_offload_high_memory_disable_gb=130,
)


@dataclass
class TurboWanT2V480PConfig(WanT2V480PConfig):
Expand Down Expand Up @@ -136,6 +145,12 @@ def __post_init__(self) -> None:
self.vae_config.load_encoder = True
self.vae_config.load_decoder = True

def get_model_deployment_config(self) -> ModelDeploymentConfig:
return ModelDeploymentConfig(
auto_dit_layerwise_offload=True,
auto_dit_layerwise_offload_high_memory_disable_gb=130,
)


@dataclass
class WanI2V720PConfig(WanI2V480PConfig):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
TextConditioningOutput,
pad_text_embeddings_with_mask,
)
from sglang.multimodal_gen.configs.pipeline_configs.model_deployment_config import (
ModelDeploymentConfig,
)
from sglang.multimodal_gen.configs.post_training.pipeline_configs import (
ZImageRolloutPipelineMixin,
)
Expand Down Expand Up @@ -80,6 +83,9 @@ class ZImagePipelineConfig(ZImageRolloutPipelineMixin, ImagePipelineConfig):
PATCH_SIZE: int = 2
F_PATCH_SIZE: int = 1

def get_model_deployment_config(self) -> ModelDeploymentConfig:
return ModelDeploymentConfig(fsdp_auto_min_available_memory_gb=40)

def tokenize_prompt(self, prompts: list[str], tokenizer, tok_kwargs) -> dict:
rendered_prompts = [
tokenizer.apply_chat_template(
Expand Down
Loading
Loading