From 3ae5d3e15ff40b796a809e3bbd55cdde693c6d42 Mon Sep 17 00:00:00 2001
From: Chi
Date: Wed, 7 Jan 2026 09:06:58 -0500
Subject: [PATCH 1/7] run diffusers models via cache-dit

Signed-off-by: Chi
---
 .../multimodal_gen/docs/cache/cache_dit.md    |  37 ++++-
 python/sglang/multimodal_gen/docs/cli.md      |   1 +
 .../runtime/pipelines/diffusers_pipeline.py   | 126 ++++++++++++++++++
 .../multimodal_gen/runtime/server_args.py     |   9 ++
 4 files changed, 171 insertions(+), 2 deletions(-)

diff --git a/python/sglang/multimodal_gen/docs/cache/cache_dit.md b/python/sglang/multimodal_gen/docs/cache/cache_dit.md
index 01038b6caa12..dcf555bb1582 100644
--- a/python/sglang/multimodal_gen/docs/cache/cache_dit.md
+++ b/python/sglang/multimodal_gen/docs/cache/cache_dit.md
@@ -24,6 +24,39 @@ sglang generate --model-path Qwen/Qwen-Image \
     --prompt "A beautiful sunset over the mountains"
 ```
 
+## Diffusers Backend
+
+To enable Cache-DiT for the diffusers backend, pass a cache-dit config file:
+
+```bash
+sglang generate --model-path Qwen/Qwen-Image \
+    --backend diffusers \
+    --cache-dit-config ./cache_dit_config.yaml \
+    --prompt "A curious raccoon"
+```
+
+Cache-DiT accepts configs in its native format. SGLang also supports a nested layout
+with `cache_config` and `parallelism_config`:
+
+```yaml
+cache_config:
+  cache_type: DBCache
+  Fn_compute_blocks: 1
+  Bn_compute_blocks: 0
+  max_warmup_steps: 8
+  warmup_interval: 2
+  max_cached_steps: -1
+  max_continuous_cached_steps: 2
+  residual_diff_threshold: 0.12
+  enable_taylorseer: true
+  taylorseer_order: 1
+parallelism_config:
+  ulysses_size: 4
+  parallel_kwargs:
+    attention_backend: native
+    extra_parallel_modules: ["text_encoder", "vae"]
+```
+
 ## Advanced Configuration
 
 ### DBCache Parameters
@@ -151,8 +184,8 @@ SGLang Diffusion x Cache-DiT supports almost all models originally supported in
 
 ## Limitations
 
-- **Single GPU only**: Distributed support (TP/SP) is not yet validated; Cache-DiT will be automatically disabled when
-  `world_size > 1`
+- **SGLang-native pipelines**: Distributed support (TP/SP) is not yet validated; Cache-DiT will be automatically
+  disabled when `world_size > 1`. The diffusers backend uses cache-dit's native parallelism when configured.
- **SCM minimum steps**: SCM requires >= 8 inference steps to be effective - **Model support**: Only models registered in Cache-DiT's BlockAdapterRegister are supported diff --git a/python/sglang/multimodal_gen/docs/cli.md b/python/sglang/multimodal_gen/docs/cli.md index dcd012f75cc7..a521ec81aab9 100644 --- a/python/sglang/multimodal_gen/docs/cli.md +++ b/python/sglang/multimodal_gen/docs/cli.md @@ -21,6 +21,7 @@ The SGLang-diffusion CLI provides a quick way to access the inference pipeline f - `--sp-degree {SP_SIZE}`: Sequence parallelism size (typically should match the number of GPUs) - `--ulysses-degree {ULYSSES_DEGREE}`: The degree of DeepSpeed-Ulysses-style SP in USP - `--ring-degree {RING_DEGREE}`: The degree of ring attention-style SP in USP +- `--cache-dit-config {PATH}`: Path to a Cache-DiT YAML/JSON config (diffusers backend only) ### Sampling Parameters diff --git a/python/sglang/multimodal_gen/runtime/pipelines/diffusers_pipeline.py b/python/sglang/multimodal_gen/runtime/pipelines/diffusers_pipeline.py index a962063f2407..f9068983d2d7 100644 --- a/python/sglang/multimodal_gen/runtime/pipelines/diffusers_pipeline.py +++ b/python/sglang/multimodal_gen/runtime/pipelines/diffusers_pipeline.py @@ -470,6 +470,9 @@ def _load_diffusers_pipeline(self, model_path: str, server_args: ServerArgs) -> # Apply attention backend if specified self._apply_attention_backend(pipe, server_args) + # Apply cache-dit acceleration if configured + pipe = self._apply_cache_dit(pipe, server_args) + logger.info("Loaded diffusers pipeline: %s", pipe.__class__.__name__) return pipe @@ -525,6 +528,129 @@ def _apply_attention_backend(self, pipe: Any, server_args: ServerArgs) -> None: e, ) + def _apply_cache_dit(self, pipe: Any, server_args: ServerArgs) -> Any: + """Enable cache-dit for diffusers pipeline if configured.""" + cache_dit_config = getattr(server_args, "cache_dit_config", None) + if not cache_dit_config: + return pipe + + try: + import cache_dit + except ImportError as e: + raise RuntimeError( + "cache-dit is required for --cache-dit-config. " + "Install it with `pip install cache-dit`." 
+ ) from e + + cache_options = self._load_cache_dit_options(cache_dit, cache_dit_config) + parallelism_config = cache_options.get("parallelism_config") + if parallelism_config is not None: + self._resolve_cache_dit_extra_parallel_modules( + pipe, + parallelism_config, + cache_dit, + ) + + try: + pipe = cache_dit.enable_cache(pipe, **cache_options) + except Exception: + logger.exception("Failed to enable cache-dit for diffusers pipeline") + raise + + logger.info("Enabled cache-dit for diffusers pipeline") + return pipe + + def _load_cache_dit_options(self, cache_dit: Any, cache_dit_config: Any) -> dict: + """Load cache-dit options from a path or dict.""" + if isinstance(cache_dit_config, str): + config_dict = ServerArgs.load_config_file(cache_dit_config) + elif isinstance(cache_dit_config, dict): + config_dict = cache_dit_config + else: + raise ValueError( + "cache_dit_config must be a file path or a dict, got " + f"{type(cache_dit_config).__name__}" + ) + + if not isinstance(config_dict, dict): + raise ValueError( + "cache_dit_config must resolve to a dict, got " + f"{type(config_dict).__name__}" + ) + + if "cache_config" in config_dict: + cache_options = cache_dit.load_options(config_dict["cache_config"]) + parallelism_config = config_dict.get("parallelism_config") + if parallelism_config is not None: + cache_options["parallelism_config"] = cache_dit.ParallelismConfig( + **parallelism_config + ) + return cache_options + + return cache_dit.load_options(config_dict) + + def _resolve_cache_dit_extra_parallel_modules( + self, + pipe: Any, + parallelism_config: Any, + cache_dit: Any, + ) -> None: + parallel_kwargs = getattr(parallelism_config, "parallel_kwargs", None) + if not isinstance(parallel_kwargs, dict): + return + + extra_modules = parallel_kwargs.get("extra_parallel_modules") + if not extra_modules: + return + + resolved_modules = [] + for module in extra_modules: + if isinstance(module, str): + resolved = self._lookup_parallel_module(pipe, module, cache_dit) + if resolved is None: + logger.warning( + "cache-dit extra_parallel_modules entry '%s' could not be resolved", + module, + ) + continue + resolved_modules.append(resolved) + else: + resolved_modules.append(module) + + parallel_kwargs["extra_parallel_modules"] = resolved_modules + + def _lookup_parallel_module( + self, + pipe: Any, + name: str, + cache_dit: Any, + ) -> Any | None: + if name == "text_encoder": + return self._get_default_text_encoder(pipe, cache_dit) + + if hasattr(pipe, name): + return getattr(pipe, name) + + return None + + def _get_default_text_encoder(self, pipe: Any, cache_dit: Any) -> Any | None: + try: + from cache_dit.serve.utils import get_text_encoder_from_pipe + + encoder, _ = get_text_encoder_from_pipe(pipe) + if encoder is not None: + return encoder + except Exception: + pass + + for attr in ["text_encoder_2", "text_encoder_3", "text_encoder"]: + if hasattr(pipe, attr): + encoder = getattr(pipe, attr) + if encoder is not None: + return encoder + + return None + def _get_device_map(self, server_args: ServerArgs) -> str | None: """ Determine device_map for pipeline loading. 
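
The nested-config path above is easier to follow in isolation. Below is a minimal
sketch, assuming the `cache_dit.load_options` and `cache_dit.ParallelismConfig`
calls behave exactly as this hunk invokes them; the config values are illustrative,
not defaults:

```python
# Sketch of _load_cache_dit_options' nested-layout branch (illustrative values).
import cache_dit

config_dict = {
    "cache_config": {"Fn_compute_blocks": 1, "residual_diff_threshold": 0.12},
    "parallelism_config": {
        "ulysses_size": 4,
        "parallel_kwargs": {"extra_parallel_modules": ["text_encoder", "vae"]},
    },
}

# Only the cache_config sub-dict goes through load_options; the
# parallelism_config dict is promoted to a ParallelismConfig object.
cache_options = cache_dit.load_options(config_dict["cache_config"])
cache_options["parallelism_config"] = cache_dit.ParallelismConfig(
    **config_dict["parallelism_config"]
)

# extra_parallel_modules still holds strings at this point;
# _resolve_cache_dit_extra_parallel_modules swaps them for the live pipeline
# submodules (e.g. pipe.vae) before cache_dit.enable_cache is called.
```

Keeping the YAML entries as strings keeps the config file serializable; the
string-to-module resolution happens only once the pipeline object exists.
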
diff --git a/python/sglang/multimodal_gen/runtime/server_args.py b/python/sglang/multimodal_gen/runtime/server_args.py index f95c09efffdf..c1d1ca83908b 100644 --- a/python/sglang/multimodal_gen/runtime/server_args.py +++ b/python/sglang/multimodal_gen/runtime/server_args.py @@ -236,6 +236,9 @@ class ServerArgs: # Attention attention_backend: str = None diffusers_attention_backend: str = None # for diffusers backend only + cache_dit_config: str | dict[str, Any] | None = ( + None # cache-dit config for diffusers + ) # Distributed executor backend nccl_port: Optional[int] = None @@ -462,6 +465,12 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: help="Attention backend for diffusers pipelines (e.g., flash, _flash_3_hub, sage, xformers). " "See: https://huggingface.co/docs/diffusers/main/en/optimization/attention_backends", ) + parser.add_argument( + "--cache-dit-config", + type=str, + default=ServerArgs.cache_dit_config, + help="Path to a Cache-DiT YAML/JSON config. Enables cache-dit for diffusers backend.", + ) # HuggingFace specific parameters parser.add_argument( From 53815e4b6084d24c404dd8ef3ceed338f14552db Mon Sep 17 00:00:00 2001 From: Chi Date: Wed, 7 Jan 2026 09:20:31 -0500 Subject: [PATCH 2/7] fix --- .../multimodal_gen/docs/cache/cache_dit.md | 33 ------------------- 1 file changed, 33 deletions(-) diff --git a/python/sglang/multimodal_gen/docs/cache/cache_dit.md b/python/sglang/multimodal_gen/docs/cache/cache_dit.md index dcf555bb1582..523f0e50b9ad 100644 --- a/python/sglang/multimodal_gen/docs/cache/cache_dit.md +++ b/python/sglang/multimodal_gen/docs/cache/cache_dit.md @@ -24,39 +24,6 @@ sglang generate --model-path Qwen/Qwen-Image \ --prompt "A beautiful sunset over the mountains" ``` -## Diffusers Backend - -To enable Cache-DiT for the diffusers backend, pass a cache-dit config file: - -```bash -sglang generate --model-path Qwen/Qwen-Image \ - --backend diffusers \ - --cache-dit-config ./cache_dit_config.yaml \ - --prompt "A curious raccoon" -``` - -Cache-DiT accepts its native config format. SGLang also supports a nested layout -with `cache_config` and `parallelism_config`: - -```yaml -cache_config: - cache_type: DBCache - Fn_compute_blocks: 1 - Bn_compute_blocks: 0 - max_warmup_steps: 8 - warmup_interval: 2 - max_cached_steps: -1 - max_continuous_cached_steps: 2 - residual_diff_threshold: 0.12 - enable_taylorseer: true - taylorseer_order: 1 -parallelism_config: - ulysses_size: 4 - parallel_kwargs: - attention_backend: native - extra_parallel_modules: ["text_encoder", "vae"] -``` - ## Advanced Configuration ### DBCache Parameters From bde7ea32413ee6bcebc85f320e3a1cad7f7ad8c3 Mon Sep 17 00:00:00 2001 From: Chi Date: Wed, 7 Jan 2026 09:22:38 -0500 Subject: [PATCH 3/7] fix --- python/sglang/multimodal_gen/docs/cache/cache_dit.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/multimodal_gen/docs/cache/cache_dit.md b/python/sglang/multimodal_gen/docs/cache/cache_dit.md index 523f0e50b9ad..405da4e97074 100644 --- a/python/sglang/multimodal_gen/docs/cache/cache_dit.md +++ b/python/sglang/multimodal_gen/docs/cache/cache_dit.md @@ -152,7 +152,7 @@ SGLang Diffusion x Cache-DiT supports almost all models originally supported in ## Limitations - **SGLang-native pipelines**: Distributed support (TP/SP) is not yet validated; Cache-DiT will be automatically - disabled when `world_size > 1`. The diffusers backend uses cache-dit's native parallelism when configured. + disabled when `world_size > 1`. 
- **SCM minimum steps**: SCM requires >= 8 inference steps to be effective - **Model support**: Only models registered in Cache-DiT's BlockAdapterRegister are supported From 35539295962358a7fd175ae9a4e717f33e2c62bc Mon Sep 17 00:00:00 2001 From: Chi Date: Wed, 7 Jan 2026 19:00:42 -0500 Subject: [PATCH 4/7] fix --- .../multimodal_gen/runtime/pipelines/diffusers_pipeline.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/python/sglang/multimodal_gen/runtime/pipelines/diffusers_pipeline.py b/python/sglang/multimodal_gen/runtime/pipelines/diffusers_pipeline.py index f9068983d2d7..5812d52004d5 100644 --- a/python/sglang/multimodal_gen/runtime/pipelines/diffusers_pipeline.py +++ b/python/sglang/multimodal_gen/runtime/pipelines/diffusers_pipeline.py @@ -554,6 +554,7 @@ def _apply_cache_dit(self, pipe: Any, server_args: ServerArgs) -> Any: try: pipe = cache_dit.enable_cache(pipe, **cache_options) except Exception: + # cache-dit is an external integration and can raise a variety of errors. logger.exception("Failed to enable cache-dit for diffusers pipeline") raise @@ -641,7 +642,10 @@ def _get_default_text_encoder(self, pipe: Any, cache_dit: Any) -> Any | None: if encoder is not None: return encoder except Exception: - pass + logger.debug( + "cache-dit get_text_encoder_from_pipe failed; falling back to attribute lookup", + exc_info=True, + ) for attr in ["text_encoder_2", "text_encoder_3", "text_encoder"]: if hasattr(pipe, attr): From 74031c4b82bb570e0874d7df378fa8cf7c46ce3b Mon Sep 17 00:00:00 2001 From: Chi Date: Thu, 8 Jan 2026 14:25:35 -0500 Subject: [PATCH 5/7] update sampling params --- .../sglang/multimodal_gen/configs/sample/sampling_params.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/sglang/multimodal_gen/configs/sample/sampling_params.py b/python/sglang/multimodal_gen/configs/sample/sampling_params.py index c478bab77be8..1f6ccb9ce2f3 100644 --- a/python/sglang/multimodal_gen/configs/sample/sampling_params.py +++ b/python/sglang/multimodal_gen/configs/sample/sampling_params.py @@ -463,7 +463,9 @@ def from_user_sampling_params_args(model_path: str, server_args, *args, **kwargs # Re-raise if it's not a safetensors file issue raise - user_sampling_params = SamplingParams(*args, **kwargs) + user_kwargs = dict(kwargs) + user_kwargs.pop("diffusers_kwargs", None) + user_sampling_params = SamplingParams(*args, **user_kwargs) # TODO: refactor sampling_params._merge_with_user_params(user_sampling_params) sampling_params._adjust(server_args) From 073f328975330466f8345014176654decf1fe88b Mon Sep 17 00:00:00 2001 From: qimcis Date: Sun, 18 Jan 2026 02:00:00 -0500 Subject: [PATCH 6/7] fix Signed-off-by: qimcis --- python/pyproject.toml | 2 +- .../multimodal_gen/docs/attention_backends.md | 5 +- .../multimodal_gen/docs/cache/cache_dit.md | 66 +++++++++ python/sglang/multimodal_gen/docs/cli.md | 1 + .../runtime/pipelines/diffusers_pipeline.py | 130 ++++-------------- .../multimodal_gen/runtime/server_args.py | 15 +- 6 files changed, 109 insertions(+), 110 deletions(-) diff --git a/python/pyproject.toml b/python/pyproject.toml index 25cfdba0b9e9..46834ea7b106 100755 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -96,7 +96,7 @@ diffusion = [ "st_attn==0.0.7", "vsa==0.0.4", "runai_model_streamer", - "cache-dit==1.1.8" + "cache-dit==1.2.0" ] tracing = [ diff --git a/python/sglang/multimodal_gen/docs/attention_backends.md b/python/sglang/multimodal_gen/docs/attention_backends.md index b5f858137921..6e191cfca52e 100644 --- 
a/python/sglang/multimodal_gen/docs/attention_backends.md +++ b/python/sglang/multimodal_gen/docs/attention_backends.md @@ -8,13 +8,16 @@ Attention backends are defined by `AttentionBackendEnum` (`sglang.multimodal_gen Backend selection is performed by the shared attention layers (e.g. `LocalAttention` / `USPAttention` / `UlyssesAttention` in `sglang.multimodal_gen.runtime.layers.attention.layer`) and therefore applies to any model component using these layers (e.g. diffusion transformer / DiT and encoders). +When using the diffusers backend, `--attention-backend` is passed through to diffusers' +`set_attention_backend` (e.g., `flash`, `_flash_3_hub`, `sage`, `xformers`, `native`). + - **CUDA**: prefers FlashAttention (FA3/FA4) when supported; otherwise falls back to PyTorch SDPA. - **ROCm**: uses FlashAttention when available; otherwise falls back to PyTorch SDPA. - **MPS**: always uses PyTorch SDPA. ## Backend options -The CLI accepts the lowercase names of `AttentionBackendEnum`. The table below lists the backends implemented by the built-in platforms. `fa3`/`fa4` are accepted as aliases for `fa`. +For SGLang-native pipelines, the CLI accepts the lowercase names of `AttentionBackendEnum`. The table below lists the backends implemented by the built-in platforms. `fa3`/`fa4` are accepted as aliases for `fa`. | CLI value | Enum value | Notes | |---|---|---| diff --git a/python/sglang/multimodal_gen/docs/cache/cache_dit.md b/python/sglang/multimodal_gen/docs/cache/cache_dit.md index 405da4e97074..9e0a0f66a7a9 100644 --- a/python/sglang/multimodal_gen/docs/cache/cache_dit.md +++ b/python/sglang/multimodal_gen/docs/cache/cache_dit.md @@ -24,6 +24,72 @@ sglang generate --model-path Qwen/Qwen-Image \ --prompt "A beautiful sunset over the mountains" ``` +## Diffusers Backend Configuration + +Cache-DiT supports loading acceleration configs from a custom YAML file. For +diffusers pipelines, pass the YAML/JSON path via `--cache-dit-config`. This +flow requires cache-dit >= 1.2.0 (`cache_dit.load_configs`). + +### Single GPU inference + +Define a `config.yaml` file that contains: + +```yaml +cache_config: + max_warmup_steps: 8 + warmup_interval: 2 + max_cached_steps: -1 + max_continuous_cached_steps: 2 + Fn_compute_blocks: 1 + Bn_compute_blocks: 0 + residual_diff_threshold: 0.12 + enable_taylorseer: true + taylorseer_order: 1 +``` + +Then apply the config with: + +```bash +sglang generate --backend diffusers \ + --model-path Qwen/Qwen-Image \ + --cache-dit-config config.yaml \ + --prompt "A beautiful sunset over the mountains" +``` + +### Distributed inference + +Define a `parallel_config.yaml` file that contains: + +```yaml +cache_config: + max_warmup_steps: 8 + warmup_interval: 2 + max_cached_steps: -1 + max_continuous_cached_steps: 2 + Fn_compute_blocks: 1 + Bn_compute_blocks: 0 + residual_diff_threshold: 0.12 + enable_taylorseer: true + taylorseer_order: 1 +parallelism_config: + ulysses_size: auto + parallel_kwargs: + attention_backend: native + extra_parallel_modules: ["text_encoder", "vae"] +``` + +`ulysses_size: auto` means cache-dit will auto-detect the world_size. Otherwise, +set it to a specific integer (e.g., `4`). 
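+
+Under the hood, `--cache-dit-config` resolves to roughly the following
+(a minimal sketch of the calls this backend makes; the model path matches the
+example above, and error handling is omitted):
+
+```python
+import cache_dit
+from diffusers import DiffusionPipeline
+
+pipe = DiffusionPipeline.from_pretrained("Qwen/Qwen-Image")
+# load_configs requires cache-dit>=1.2.0 and accepts the YAML/JSON path.
+options = cache_dit.load_configs("parallel_config.yaml")
+pipe = cache_dit.enable_cache(pipe, **options)
+```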
+ +Then apply the distributed config with: + +```bash +sglang generate --backend diffusers \ + --model-path Qwen/Qwen-Image \ + --cache-dit-config parallel_config.yaml \ + --prompt "A futuristic cityscape at sunset" +``` + ## Advanced Configuration ### DBCache Parameters diff --git a/python/sglang/multimodal_gen/docs/cli.md b/python/sglang/multimodal_gen/docs/cli.md index a521ec81aab9..189afaf3afe3 100644 --- a/python/sglang/multimodal_gen/docs/cli.md +++ b/python/sglang/multimodal_gen/docs/cli.md @@ -21,6 +21,7 @@ The SGLang-diffusion CLI provides a quick way to access the inference pipeline f - `--sp-degree {SP_SIZE}`: Sequence parallelism size (typically should match the number of GPUs) - `--ulysses-degree {ULYSSES_DEGREE}`: The degree of DeepSpeed-Ulysses-style SP in USP - `--ring-degree {RING_DEGREE}`: The degree of ring attention-style SP in USP +- `--attention-backend {BACKEND}`: Attention backend to use. For SGLang-native pipelines use `fa`, `torch_sdpa`, `sage_attn`, etc. For diffusers pipelines use diffusers backend names like `flash`, `_flash_3_hub`, `sage`, `xformers`. - `--cache-dit-config {PATH}`: Path to a Cache-DiT YAML/JSON config (diffusers backend only) diff --git a/python/sglang/multimodal_gen/runtime/pipelines/diffusers_pipeline.py b/python/sglang/multimodal_gen/runtime/pipelines/diffusers_pipeline.py index 5812d52004d5..e628762d5d85 100644 --- a/python/sglang/multimodal_gen/runtime/pipelines/diffusers_pipeline.py +++ b/python/sglang/multimodal_gen/runtime/pipelines/diffusers_pipeline.py @@ -32,6 +32,7 @@ ) from sglang.multimodal_gen.runtime.pipelines_core.schedule_batch import Req from sglang.multimodal_gen.runtime.pipelines_core.stages import PipelineStage +from sglang.multimodal_gen.runtime.platforms import AttentionBackendEnum from sglang.multimodal_gen.runtime.server_args import ServerArgs from sglang.multimodal_gen.runtime.utils.hf_diffusers_utils import maybe_download_model from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger @@ -502,7 +503,7 @@ def _apply_attention_backend(self, pipe: Any, server_args: ServerArgs) -> None: See: https://huggingface.co/docs/diffusers/main/en/optimization/attention_backends Available backends: flash, _flash_3_hub, sage, xformers, native, etc. """ - backend = getattr(server_args, "diffusers_attention_backend", None) + backend = getattr(server_args, "attention_backend", None) if backend is None: config = getattr(server_args, "pipeline_config", None) @@ -512,6 +513,20 @@ def _apply_attention_backend(self, pipe: Any, server_args: ServerArgs) -> None: if backend is None: return + backend = backend.lower() + sglang_backends = {e.name.lower() for e in AttentionBackendEnum} | { + "fa3", + "fa4", + } + if backend in sglang_backends: + logger.debug( + "Skipping diffusers attention backend '%s' because it matches a " + "SGLang backend name. Use diffusers backend names when running " + "the diffusers backend.", + backend, + ) + return + for component_name in ["transformer", "unet"]: component = getattr(pipe, component_name, None) if component is not None and hasattr(component, "set_attention_backend"): @@ -542,15 +557,20 @@ def _apply_cache_dit(self, pipe: Any, server_args: ServerArgs) -> Any: "Install it with `pip install cache-dit`." 
) from e - cache_options = self._load_cache_dit_options(cache_dit, cache_dit_config) - parallelism_config = cache_options.get("parallelism_config") - if parallelism_config is not None: - self._resolve_cache_dit_extra_parallel_modules( - pipe, - parallelism_config, - cache_dit, + if not hasattr(cache_dit, "load_configs"): + raise RuntimeError( + "cache-dit>=1.2.0 is required for --cache-dit-config. " + "Please upgrade cache-dit." ) + try: + cache_options = cache_dit.load_configs(cache_dit_config) + except Exception as e: + raise ValueError( + "Failed to load cache-dit config. Provide a YAML/JSON path (or a dict " + "supported by cache-dit>=1.2.0)." + ) from e + try: pipe = cache_dit.enable_cache(pipe, **cache_options) except Exception: @@ -561,100 +581,6 @@ def _apply_cache_dit(self, pipe: Any, server_args: ServerArgs) -> Any: logger.info("Enabled cache-dit for diffusers pipeline") return pipe - def _load_cache_dit_options(self, cache_dit: Any, cache_dit_config: Any) -> dict: - """Load cache-dit options from a path or dict.""" - if isinstance(cache_dit_config, str): - config_dict = ServerArgs.load_config_file(cache_dit_config) - elif isinstance(cache_dit_config, dict): - config_dict = cache_dit_config - else: - raise ValueError( - "cache_dit_config must be a file path or a dict, got " - f"{type(cache_dit_config).__name__}" - ) - - if not isinstance(config_dict, dict): - raise ValueError( - "cache_dit_config must resolve to a dict, got " - f"{type(config_dict).__name__}" - ) - - if "cache_config" in config_dict: - cache_options = cache_dit.load_options(config_dict["cache_config"]) - parallelism_config = config_dict.get("parallelism_config") - if parallelism_config is not None: - cache_options["parallelism_config"] = cache_dit.ParallelismConfig( - **parallelism_config - ) - return cache_options - - return cache_dit.load_options(config_dict) - - def _resolve_cache_dit_extra_parallel_modules( - self, - pipe: Any, - parallelism_config: Any, - cache_dit: Any, - ) -> None: - parallel_kwargs = getattr(parallelism_config, "parallel_kwargs", None) - if not isinstance(parallel_kwargs, dict): - return - - extra_modules = parallel_kwargs.get("extra_parallel_modules") - if not extra_modules: - return - - resolved_modules = [] - for module in extra_modules: - if isinstance(module, str): - resolved = self._lookup_parallel_module(pipe, module, cache_dit) - if resolved is None: - logger.warning( - "cache-dit extra_parallel_modules entry '%s' could not be resolved", - module, - ) - continue - resolved_modules.append(resolved) - else: - resolved_modules.append(module) - - parallel_kwargs["extra_parallel_modules"] = resolved_modules - - def _lookup_parallel_module( - self, - pipe: Any, - name: str, - cache_dit: Any, - ) -> Any | None: - if name == "text_encoder": - return self._get_default_text_encoder(pipe, cache_dit) - - if hasattr(pipe, name): - return getattr(pipe, name) - - return None - - def _get_default_text_encoder(self, pipe: Any, cache_dit: Any) -> Any | None: - try: - from cache_dit.serve.utils import get_text_encoder_from_pipe - - encoder, _ = get_text_encoder_from_pipe(pipe) - if encoder is not None: - return encoder - except Exception: - logger.debug( - "cache-dit get_text_encoder_from_pipe failed; falling back to attribute lookup", - exc_info=True, - ) - - for attr in ["text_encoder_2", "text_encoder_3", "text_encoder"]: - if hasattr(pipe, attr): - encoder = getattr(pipe, attr) - if encoder is not None: - return encoder - - return None - def _get_device_map(self, server_args: ServerArgs) -> str 
| None: """ Determine device_map for pipeline loading. diff --git a/python/sglang/multimodal_gen/runtime/server_args.py b/python/sglang/multimodal_gen/runtime/server_args.py index c1d1ca83908b..d2fa12ff6b81 100644 --- a/python/sglang/multimodal_gen/runtime/server_args.py +++ b/python/sglang/multimodal_gen/runtime/server_args.py @@ -235,7 +235,6 @@ class ServerArgs: # Attention attention_backend: str = None - diffusers_attention_backend: str = None # for diffusers backend only cache_dit_config: str | dict[str, Any] | None = ( None # cache-dit config for diffusers ) @@ -455,15 +454,19 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: "--attention-backend", type=str, default=None, - choices=[e.name.lower() for e in AttentionBackendEnum] + ["fa3", "fa4"], - help="The attention backend to use. If not specified, the backend is automatically selected based on hardware and installed packages.", + help=( + "The attention backend to use. For SGLang-native pipelines, use " + "values like fa, torch_sdpa, sage_attn, etc. For diffusers pipelines, " + "use diffusers attention backend names such as flash, _flash_3_hub, " + "sage, or xformers." + ), ) parser.add_argument( "--diffusers-attention-backend", type=str, + dest="attention_backend", default=None, - help="Attention backend for diffusers pipelines (e.g., flash, _flash_3_hub, sage, xformers). " - "See: https://huggingface.co/docs/diffusers/main/en/optimization/attention_backends", + help=argparse.SUPPRESS, ) parser.add_argument( "--cache-dit-config", @@ -1008,7 +1011,7 @@ def check_server_args(self) -> None: raise ValueError("pipeline_config is not set in ServerArgs") self.pipeline_config.check_pipeline_config() - if self.attention_backend is None: + if self.attention_backend is None and self.backend != Backend.DIFFUSERS: self._set_default_attention_backend() # parallelism From 0d164fe70d6b5e953c1c4dac3f6d0df1b7e8e1e8 Mon Sep 17 00:00:00 2001 From: Mick Date: Wed, 21 Jan 2026 21:22:46 +0800 Subject: [PATCH 7/7] remove getattr --- .../runtime/pipelines/diffusers_pipeline.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/python/sglang/multimodal_gen/runtime/pipelines/diffusers_pipeline.py b/python/sglang/multimodal_gen/runtime/pipelines/diffusers_pipeline.py index e628762d5d85..b408790d71f1 100644 --- a/python/sglang/multimodal_gen/runtime/pipelines/diffusers_pipeline.py +++ b/python/sglang/multimodal_gen/runtime/pipelines/diffusers_pipeline.py @@ -411,7 +411,7 @@ def _load_diffusers_pipeline(self, model_path: str, server_args: ServerArgs) -> load_kwargs["device_map"] = device_map # Add quantization config if provided (e.g., BitsAndBytesConfig for 4/8-bit) - config = getattr(server_args, "pipeline_config", None) + config = server_args.pipeline_config if config is not None: quant_config = getattr(config, "quantization_config", None) if quant_config is not None: @@ -479,7 +479,7 @@ def _load_diffusers_pipeline(self, model_path: str, server_args: ServerArgs) -> def _apply_vae_optimizations(self, pipe: Any, server_args: ServerArgs) -> None: """Apply VAE memory optimizations (tiling, slicing) from pipeline config.""" - config = getattr(server_args, "pipeline_config", None) + config = server_args.pipeline_config if config is None: return @@ -503,10 +503,10 @@ def _apply_attention_backend(self, pipe: Any, server_args: ServerArgs) -> None: See: https://huggingface.co/docs/diffusers/main/en/optimization/attention_backends Available backends: flash, _flash_3_hub, sage, xformers, native, etc. 
""" - backend = getattr(server_args, "attention_backend", None) + backend = server_args.attention_backend if backend is None: - config = getattr(server_args, "pipeline_config", None) + config = server_args.pipeline_config if config is not None: backend = getattr(config, "diffusers_attention_backend", None) @@ -545,7 +545,7 @@ def _apply_attention_backend(self, pipe: Any, server_args: ServerArgs) -> None: def _apply_cache_dit(self, pipe: Any, server_args: ServerArgs) -> Any: """Enable cache-dit for diffusers pipeline if configured.""" - cache_dit_config = getattr(server_args, "cache_dit_config", None) + cache_dit_config = server_args.cache_dit_config if not cache_dit_config: return pipe @@ -596,7 +596,7 @@ def _get_dtype(self, server_args: ServerArgs) -> torch.dtype: dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 if hasattr(server_args, "pipeline_config") and server_args.pipeline_config: - dit_precision = getattr(server_args.pipeline_config, "dit_precision", None) + dit_precision = server_args.pipeline_config.dit_precision if dit_precision == "fp16": dtype = torch.float16 elif dit_precision == "bf16":