2 changes: 1 addition & 1 deletion python/pyproject.toml
@@ -96,7 +96,7 @@ diffusion = [
"st_attn==0.0.7",
"vsa==0.0.4",
"runai_model_streamer",
"cache-dit==1.1.8"
"cache-dit==1.2.0"
]

tracing = [
@@ -463,7 +463,9 @@ def from_user_sampling_params_args(model_path: str, server_args, *args, **kwargs
# Re-raise if it's not a safetensors file issue
raise

user_sampling_params = SamplingParams(*args, **kwargs)
user_kwargs = dict(kwargs)
user_kwargs.pop("diffusers_kwargs", None)
user_sampling_params = SamplingParams(*args, **user_kwargs)
# TODO: refactor
sampling_params._merge_with_user_params(user_sampling_params)
sampling_params._adjust(server_args)
5 changes: 4 additions & 1 deletion python/sglang/multimodal_gen/docs/attention_backends.md
@@ -8,13 +8,16 @@ Attention backends are defined by `AttentionBackendEnum` (`sglang.multimodal_gen

Backend selection is performed by the shared attention layers (e.g. `LocalAttention` / `USPAttention` / `UlyssesAttention` in `sglang.multimodal_gen.runtime.layers.attention.layer`) and therefore applies to any model component using these layers (e.g. diffusion transformer / DiT and encoders).

When using the diffusers backend, `--attention-backend` is passed through to diffusers'
`set_attention_backend` (e.g., `flash`, `_flash_3_hub`, `sage`, `xformers`, `native`).
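
For example, routing a diffusers pipeline through FlashAttention (illustrative command; the model path is just an example):

```bash
sglang generate --backend diffusers \
  --model-path Qwen/Qwen-Image \
  --attention-backend flash \
  --prompt "A beautiful sunset over the mountains"
```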

- **CUDA**: prefers FlashAttention (FA3/FA4) when supported; otherwise falls back to PyTorch SDPA.
- **ROCm**: uses FlashAttention when available; otherwise falls back to PyTorch SDPA.
- **MPS**: always uses PyTorch SDPA.

## Backend options

The CLI accepts the lowercase names of `AttentionBackendEnum`. The table below lists the backends implemented by the built-in platforms. `fa3`/`fa4` are accepted as aliases for `fa`.
For SGLang-native pipelines, the CLI accepts the lowercase names of `AttentionBackendEnum`. The table below lists the backends implemented by the built-in platforms. `fa3`/`fa4` are accepted as aliases for `fa`.

| CLI value | Enum value | Notes |
|---|---|---|
70 changes: 68 additions & 2 deletions python/sglang/multimodal_gen/docs/cache/cache_dit.md
@@ -24,6 +24,72 @@ sglang generate --model-path Qwen/Qwen-Image \
--prompt "A beautiful sunset over the mountains"
```

## Diffusers Backend Configuration

Cache-DiT supports loading acceleration configs from a custom YAML file. For
diffusers pipelines, pass the YAML/JSON path via `--cache-dit-config`. This
flow requires cache-dit >= 1.2.0 (`cache_dit.load_configs`).
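
Under the hood, the diffusers backend does roughly the equivalent of the following (minimal sketch based on this PR's loader code; the pipeline class, dtype, and file name are illustrative, and the config file is one like those shown below):

```python
import torch
import cache_dit
from diffusers import DiffusionPipeline

# Load the diffusers pipeline as usual.
pipe = DiffusionPipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16)

# load_configs (cache-dit >= 1.2.0) parses the YAML/JSON file into kwargs for
# enable_cache; this is the same file that --cache-dit-config points to.
cache_options = cache_dit.load_configs("config.yaml")
pipe = cache_dit.enable_cache(pipe, **cache_options)
```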

### Single GPU inference

Define a `config.yaml` file that contains:

```yaml
cache_config:
  max_warmup_steps: 8
  warmup_interval: 2
  max_cached_steps: -1
  max_continuous_cached_steps: 2
  Fn_compute_blocks: 1
  Bn_compute_blocks: 0
  residual_diff_threshold: 0.12
  enable_taylorseer: true
  taylorseer_order: 1
```

Then apply the config with:

```bash
sglang generate --backend diffusers \
--model-path Qwen/Qwen-Image \
--cache-dit-config config.yaml \
--prompt "A beautiful sunset over the mountains"
```

### Distributed inference

Define a `parallel_config.yaml` file that contains:

```yaml
cache_config:
  max_warmup_steps: 8
  warmup_interval: 2
  max_cached_steps: -1
  max_continuous_cached_steps: 2
  Fn_compute_blocks: 1
  Bn_compute_blocks: 0
  residual_diff_threshold: 0.12
  enable_taylorseer: true
  taylorseer_order: 1
parallelism_config:
  ulysses_size: auto
  parallel_kwargs:
    attention_backend: native
    extra_parallel_modules: ["text_encoder", "vae"]
```

`ulysses_size: auto` means cache-dit auto-detects the world size; otherwise, set it
to a specific integer (e.g., `4`).
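
For example, to pin the Ulysses degree rather than auto-detect it (snippet from the same file):

```yaml
parallelism_config:
  ulysses_size: 4
```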

Then apply the distributed config with:

```bash
sglang generate --backend diffusers \
--model-path Qwen/Qwen-Image \
--cache-dit-config parallel_config.yaml \
--prompt "A futuristic cityscape at sunset"
```

## Advanced Configuration

### DBCache Parameters
@@ -151,8 +217,8 @@ SGLang Diffusion x Cache-DiT supports almost all models originally supported in

## Limitations

- **Single GPU only**: Distributed support (TP/SP) is not yet validated; Cache-DiT will be automatically disabled when
`world_size > 1`
- **SGLang-native pipelines**: Distributed support (TP/SP) is not yet validated; Cache-DiT will be automatically
disabled when `world_size > 1`.
- **SCM minimum steps**: SCM requires >= 8 inference steps to be effective
- **Model support**: Only models registered in Cache-DiT's BlockAdapterRegister are supported

2 changes: 2 additions & 0 deletions python/sglang/multimodal_gen/docs/cli.md
@@ -21,6 +21,8 @@ The SGLang-diffusion CLI provides a quick way to access the inference pipeline f
- `--sp-degree {SP_SIZE}`: Sequence parallelism size (typically should match the number of GPUs)
- `--ulysses-degree {ULYSSES_DEGREE}`: The degree of DeepSpeed-Ulysses-style SP in USP
- `--ring-degree {RING_DEGREE}`: The degree of ring attention-style SP in USP
- `--attention-backend {BACKEND}`: Attention backend to use. For SGLang-native pipelines, use `fa`, `torch_sdpa`, `sage_attn`, etc. For diffusers pipelines, use diffusers backend names such as `flash`, `_flash_3_hub`, `sage`, or `xformers`.
- `--cache-dit-config {PATH}`: Path to a Cache-DiT YAML/JSON config (diffusers backend only)
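
For example, a diffusers-backend run combining both flags (illustrative; see the cache-dit docs for the config file format):

```bash
sglang generate --backend diffusers \
  --model-path Qwen/Qwen-Image \
  --attention-backend flash \
  --cache-dit-config config.yaml \
  --prompt "A beautiful sunset over the mountains"
```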


### Sampling Parameters
@@ -32,6 +32,7 @@
)
from sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import Req
from sglang.multimodal_gen.runtime.pipelines_core.stages import PipelineStage
from sglang.multimodal_gen.runtime.platforms import AttentionBackendEnum
from sglang.multimodal_gen.runtime.server_args import ServerArgs
from sglang.multimodal_gen.runtime.utils.hf_diffusers_utils import maybe_download_model
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
@@ -410,7 +411,7 @@ def _load_diffusers_pipeline(self, model_path: str, server_args: ServerArgs) ->
load_kwargs["device_map"] = device_map

# Add quantization config if provided (e.g., BitsAndBytesConfig for 4/8-bit)
config = getattr(server_args, "pipeline_config", None)
config = server_args.pipeline_config
if config is not None:
quant_config = getattr(config, "quantization_config", None)
if quant_config is not None:
@@ -470,12 +471,15 @@ def _load_diffusers_pipeline(self, model_path: str, server_args: ServerArgs) ->
# Apply attention backend if specified
self._apply_attention_backend(pipe, server_args)

# Apply cache-dit acceleration if configured
pipe = self._apply_cache_dit(pipe, server_args)

logger.info("Loaded diffusers pipeline: %s", pipe.__class__.__name__)
return pipe

def _apply_vae_optimizations(self, pipe: Any, server_args: ServerArgs) -> None:
"""Apply VAE memory optimizations (tiling, slicing) from pipeline config."""
config = getattr(server_args, "pipeline_config", None)
config = server_args.pipeline_config
if config is None:
return

@@ -499,16 +503,30 @@ def _apply_attention_backend(self, pipe: Any, server_args: ServerArgs) -> None:
See: https://huggingface.co/docs/diffusers/main/en/optimization/attention_backends
Available backends: flash, _flash_3_hub, sage, xformers, native, etc.
"""
backend = getattr(server_args, "diffusers_attention_backend", None)
backend = server_args.attention_backend

if backend is None:
config = getattr(server_args, "pipeline_config", None)
config = server_args.pipeline_config
if config is not None:
backend = getattr(config, "diffusers_attention_backend", None)

if backend is None:
return

backend = backend.lower()
sglang_backends = {e.name.lower() for e in AttentionBackendEnum} | {
"fa3",
"fa4",
}
if backend in sglang_backends:
logger.debug(
"Skipping diffusers attention backend '%s' because it matches a "
"SGLang backend name. Use diffusers backend names when running "
"the diffusers backend.",
backend,
)
return

for component_name in ["transformer", "unet"]:
component = getattr(pipe, component_name, None)
if component is not None and hasattr(component, "set_attention_backend"):
@@ -525,6 +543,44 @@ def _apply_attention_backend(self, pipe: Any, server_args: ServerArgs) -> None:
e,
)

def _apply_cache_dit(self, pipe: Any, server_args: ServerArgs) -> Any:
"""Enable cache-dit for diffusers pipeline if configured."""
cache_dit_config = server_args.cache_dit_config
if not cache_dit_config:
return pipe

try:
import cache_dit
except ImportError as e:
raise RuntimeError(
"cache-dit is required for --cache-dit-config. "
"Install it with `pip install cache-dit`."
) from e

if not hasattr(cache_dit, "load_configs"):
raise RuntimeError(
"cache-dit>=1.2.0 is required for --cache-dit-config. "
"Please upgrade cache-dit."
)

try:
cache_options = cache_dit.load_configs(cache_dit_config)
except Exception as e:
raise ValueError(
"Failed to load cache-dit config. Provide a YAML/JSON path (or a dict "
"supported by cache-dit>=1.2.0)."
) from e

try:
pipe = cache_dit.enable_cache(pipe, **cache_options)
except Exception:
# cache-dit is an external integration and can raise a variety of errors.
logger.exception("Failed to enable cache-dit for diffusers pipeline")
raise
Contributor comment on lines +576 to +579 (severity: medium):
The try...except Exception block is very broad. While logger.exception is helpful for debugging, catching a generic Exception can mask specific issues that might arise from cache_dit.enable_cache. Consider catching more specific exceptions if known, or at least adding a comment explaining why a broad exception is necessary here (e.g., due to the external nature of the cache_dit library and its potential to raise various exceptions).


logger.info("Enabled cache-dit for diffusers pipeline")
return pipe

def _get_device_map(self, server_args: ServerArgs) -> str | None:
"""
Determine device_map for pipeline loading.
@@ -540,7 +596,7 @@ def _get_dtype(self, server_args: ServerArgs) -> torch.dtype:
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16

if hasattr(server_args, "pipeline_config") and server_args.pipeline_config:
dit_precision = getattr(server_args.pipeline_config, "dit_precision", None)
dit_precision = server_args.pipeline_config.dit_precision
if dit_precision == "fp16":
dtype = torch.float16
elif dit_precision == "bf16":
24 changes: 18 additions & 6 deletions python/sglang/multimodal_gen/runtime/server_args.py
@@ -235,7 +235,9 @@ class ServerArgs:

# Attention
attention_backend: str = None
diffusers_attention_backend: str = None # for diffusers backend only
cache_dit_config: str | dict[str, Any] | None = (
None # cache-dit config for diffusers
)

# Distributed executor backend
nccl_port: Optional[int] = None
@@ -452,15 +454,25 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
"--attention-backend",
type=str,
default=None,
choices=[e.name.lower() for e in AttentionBackendEnum] + ["fa3", "fa4"],
help="The attention backend to use. If not specified, the backend is automatically selected based on hardware and installed packages.",
help=(
"The attention backend to use. For SGLang-native pipelines, use "
"values like fa, torch_sdpa, sage_attn, etc. For diffusers pipelines, "
"use diffusers attention backend names such as flash, _flash_3_hub, "
"sage, or xformers."
),
)
parser.add_argument(
"--diffusers-attention-backend",
type=str,
dest="attention_backend",
default=None,
help="Attention backend for diffusers pipelines (e.g., flash, _flash_3_hub, sage, xformers). "
"See: https://huggingface.co/docs/diffusers/main/en/optimization/attention_backends",
help=argparse.SUPPRESS,
)
parser.add_argument(
Collaborator comment from @mickqian (Jan 9, 2026):
This might function well, but we use env vars for the non-diffusers backend; introducing such a new arg could cause confusion.

"--cache-dit-config",
type=str,
default=ServerArgs.cache_dit_config,
help="Path to a Cache-DiT YAML/JSON config. Enables cache-dit for diffusers backend.",
)

# HuggingFace specific parameters
@@ -999,7 +1011,7 @@ def check_server_args(self) -> None:
raise ValueError("pipeline_config is not set in ServerArgs")

self.pipeline_config.check_pipeline_config()
if self.attention_backend is None:
if self.attention_backend is None and self.backend != Backend.DIFFUSERS:
self._set_default_attention_backend()

# parallelism