diff --git a/.gitignore b/.gitignore index 97d2b1cc210d..93947f3ddcdc 100644 --- a/.gitignore +++ b/.gitignore @@ -248,3 +248,4 @@ lmms-eval **/.claude/ **/.serena/ ctags/ +outputs/ diff --git a/python/pyproject.toml b/python/pyproject.toml index 751584a49f9c..56a2226b9d53 100755 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -95,6 +95,7 @@ diffusion = [ "vsa==0.0.4", "yunchang==0.6.3.post1", "runai_model_streamer", + "cache-dit==1.1.6" ] [tool.uv.extra-build-dependencies] diff --git a/python/sglang/multimodal_gen/docs/cache_dit.md b/python/sglang/multimodal_gen/docs/cache_dit.md new file mode 100644 index 000000000000..fa4a948d3a51 --- /dev/null +++ b/python/sglang/multimodal_gen/docs/cache_dit.md @@ -0,0 +1,174 @@ +# Cache-DiT Acceleration + +SGLang integrates [Cache-DiT](https://github.com/vipshop/cache-dit), a caching acceleration engine for Diffusion +Transformers (DiT), to achieve up to **7.4x inference speedup** with minimal quality loss. + +## Overview + +**Cache-DiT** uses intelligent caching strategies to skip redundant computation in the denoising loop: + +- **DBCache (Dual Block Cache)**: Dynamically decides when to cache transformer blocks based on residual differences +- **TaylorSeer**: Uses Taylor expansion for calibration to optimize caching decisions +- **SCM (Step Computation Masking)**: Step-level caching control for additional speedup + +## Basic Usage + +Enable Cache-DiT by exporting the environment variable and using `sglang generate` or `sglang serve` : + +```bash +SGLANG_CACHE_DIT_ENABLED=true \ +sglang generate --model-path Qwen/Qwen-Image \ + --prompt "A beautiful sunset over the mountains" +``` + +## Advanced Configuration + +### DBCache Parameters + +DBCache controls block-level caching behavior: + +| Parameter | Env Variable | Default | Description | +|-----------|---------------------------|---------|------------------------------------------| +| Fn | `SGLANG_CACHE_DIT_FN` | 1 | Number of first blocks to always compute | +| 
Bn | `SGLANG_CACHE_DIT_BN` | 0 | Number of last blocks to always compute | +| W | `SGLANG_CACHE_DIT_WARMUP` | 4 | Warmup steps before caching starts | +| R | `SGLANG_CACHE_DIT_RDT` | 0.24 | Residual difference threshold | +| MC | `SGLANG_CACHE_DIT_MC` | 3 | Maximum continuous cached steps | + +### TaylorSeer Configuration + +TaylorSeer improves caching accuracy using Taylor expansion: + +| Parameter | Env Variable | Default | Description | +|-----------|-------------------------------|---------|---------------------------------| +| Enable | `SGLANG_CACHE_DIT_TAYLORSEER` | false | Enable TaylorSeer calibrator | +| Order | `SGLANG_CACHE_DIT_TS_ORDER` | 1 | Taylor expansion order (1 or 2) | + +### Combined Configuration Example + +DBCache and TaylorSeer are complementary strategies that work together, so you can configure both sets of parameters +simultaneously: + +```bash +SGLANG_CACHE_DIT_ENABLED=true \ +SGLANG_CACHE_DIT_FN=2 \ +SGLANG_CACHE_DIT_BN=1 \ +SGLANG_CACHE_DIT_WARMUP=4 \ +SGLANG_CACHE_DIT_RDT=0.4 \ +SGLANG_CACHE_DIT_MC=4 \ +SGLANG_CACHE_DIT_TAYLORSEER=true \ +SGLANG_CACHE_DIT_TS_ORDER=2 \ +sglang generate --model-path black-forest-labs/FLUX.1-dev \ + --prompt "A curious raccoon in a forest" +``` + +### SCM (Step Computation Masking) + +SCM provides step-level caching control for additional speedup. It decides which denoising steps to compute fully and +which to use cached results. 
+ +#### SCM Presets + +SCM is configured with presets: + +| Preset | Compute Ratio | Speed | Quality | +|----------|---------------|----------|------------| +| `none` | 100% | Baseline | Best | +| `slow` | ~75% | ~1.3x | High | +| `medium` | ~50% | ~2x | Good | +| `fast` | ~35% | ~3x | Acceptable | +| `ultra` | ~25% | ~4x | Lower | + +##### Usage + +```bash +SGLANG_CACHE_DIT_ENABLED=true \ +SGLANG_CACHE_DIT_SCM_PRESET=medium \ +sglang generate --model-path Qwen/Qwen-Image \ + --prompt "A futuristic cityscape at sunset" +``` + +#### Custom SCM Bins + +For fine-grained control over which steps to compute vs cache, set both bin variables together (if only one is set, it is ignored with a warning and the preset is used instead): + +```bash +SGLANG_CACHE_DIT_ENABLED=true \ +SGLANG_CACHE_DIT_SCM_COMPUTE_BINS="8,3,3,2,2" \ +SGLANG_CACHE_DIT_SCM_CACHE_BINS="1,2,2,2,3" \ +sglang generate --model-path Qwen/Qwen-Image \ + --prompt "A futuristic cityscape at sunset" +``` + +#### SCM Policy + +| Policy | Env Variable | Description | +|-----------|---------------------------------------|---------------------------------------------| +| `dynamic` | `SGLANG_CACHE_DIT_SCM_POLICY=dynamic` | Adaptive caching based on content (default) | +| `static` | `SGLANG_CACHE_DIT_SCM_POLICY=static` | Fixed caching pattern | + +## Environment Variables + +The primary Cache-DiT parameters can be set via the following environment variables. For dual-transformer models (e.g., Wan2.2), the `FN`, `BN`, `WARMUP`, `RDT`, `MC`, `TAYLORSEER`, and `TS_ORDER` variables also have `SGLANG_CACHE_DIT_SECONDARY_*` counterparts (e.g., `SGLANG_CACHE_DIT_SECONDARY_RDT`) that configure the second transformer and default to the corresponding primary value: + +| Environment Variable | Default | Description | +|-------------------------------------|---------|------------------------------------------| +| `SGLANG_CACHE_DIT_ENABLED` | false | Enable Cache-DiT acceleration | +| `SGLANG_CACHE_DIT_FN` | 1 | First N blocks to always compute | +| `SGLANG_CACHE_DIT_BN` | 0 | Last N blocks to always compute | +| `SGLANG_CACHE_DIT_WARMUP` | 4 | Warmup steps before caching | +| `SGLANG_CACHE_DIT_RDT` | 0.24 | Residual difference threshold | +| `SGLANG_CACHE_DIT_MC` | 3 | Max continuous cached steps | +| `SGLANG_CACHE_DIT_TAYLORSEER` | false | Enable TaylorSeer calibrator | +| `SGLANG_CACHE_DIT_TS_ORDER` | 1 | TaylorSeer order (1 or 2) | +| 
`SGLANG_CACHE_DIT_SCM_PRESET` | none | SCM preset (none/slow/medium/fast/ultra) | +| `SGLANG_CACHE_DIT_SCM_POLICY` | dynamic | SCM caching policy | +| `SGLANG_CACHE_DIT_SCM_COMPUTE_BINS` | not set | Custom SCM compute bins | +| `SGLANG_CACHE_DIT_SCM_CACHE_BINS` | not set | Custom SCM cache bins | + +## Supported Models + +SGLang Diffusion x Cache-DiT supports almost all models originally supported in SGLang Diffusion: + +| Model Family | Example Models | +|--------------|-----------------------------| +| Wan | Wan2.1, Wan2.2 | +| Flux | FLUX.1-dev, FLUX.2-dev | +| Z-Image | Z-Image-Turbo | +| Qwen | Qwen-Image, Qwen-Image-Edit | +| Hunyuan | HunyuanVideo | + +## Performance Tips + +1. **Start with defaults**: The default parameters work well for most models +2. **Use TaylorSeer**: It typically improves both speed and quality, but it is disabled by default because it is not suitable for few-step distilled models +3. **Tune the R threshold** (`SGLANG_CACHE_DIT_RDT`): Lower values = better quality, higher values = faster +4. **SCM for extra speed**: Use `medium` preset for good speed/quality balance +5. **Warmup matters**: Higher warmup = more stable caching decisions + +## Limitations + +- **Single GPU only**: Distributed support (TP/SP) is not yet validated; Cache-DiT will be automatically disabled when + `world_size > 1` +- **SCM minimum steps**: SCM requires >= 8 inference steps to be effective +- **Model support**: Only models registered in Cache-DiT's BlockAdapterRegister are supported + +## Troubleshooting + +### Distributed environment warning + +``` +WARNING: cache-dit is disabled in distributed environment (world_size=N) +``` + +This is expected behavior. Cache-DiT currently only supports single-GPU inference. + +### SCM disabled for low step count + +For models with < 8 inference steps (e.g., DMD distilled models), SCM will be automatically disabled. DBCache +acceleration still works. 
+ +## References + +- [Cache-DiT](https://github.com/vipshop/cache-dit) +- [SGLang Diffusion](../README.md) diff --git a/python/sglang/multimodal_gen/docs/cli.md b/python/sglang/multimodal_gen/docs/cli.md index 1c054e466b32..3c4264a98dd9 100644 --- a/python/sglang/multimodal_gen/docs/cli.md +++ b/python/sglang/multimodal_gen/docs/cli.md @@ -177,6 +177,9 @@ SAMPLING_ARGS=( ) sglang generate "${SERVER_ARGS[@]}" "${SAMPLING_ARGS[@]}" + +# Alternatively, set the `SGLANG_CACHE_DIT_ENABLED` environment variable to `true` to enable Cache-DiT acceleration +SGLANG_CACHE_DIT_ENABLED=true sglang generate "${SERVER_ARGS[@]}" "${SAMPLING_ARGS[@]}" ``` Once the generation task has finished, the server will shut down automatically. diff --git a/python/sglang/multimodal_gen/docs/environment_variables.md b/python/sglang/multimodal_gen/docs/environment_variables.md new file mode 100644 index 000000000000..465e88056832 --- /dev/null +++ b/python/sglang/multimodal_gen/docs/environment_variables.md @@ -0,0 +1,19 @@ +## Cache-DiT Acceleration + +These variables configure Cache-DiT caching acceleration for Diffusion Transformer (DiT) models. +See the [Cache-DiT documentation](cache_dit.md) for details. 
+ +| Environment Variable | Default | Description | +|-------------------------------------|---------|------------------------------------------| +| `SGLANG_CACHE_DIT_ENABLED` | false | Enable Cache-DiT acceleration | +| `SGLANG_CACHE_DIT_FN` | 1 | First N blocks to always compute | +| `SGLANG_CACHE_DIT_BN` | 0 | Last N blocks to always compute | +| `SGLANG_CACHE_DIT_WARMUP` | 4 | Warmup steps before caching | +| `SGLANG_CACHE_DIT_RDT` | 0.24 | Residual difference threshold | +| `SGLANG_CACHE_DIT_MC` | 3 | Max continuous cached steps | +| `SGLANG_CACHE_DIT_TAYLORSEER` | false | Enable TaylorSeer calibrator | +| `SGLANG_CACHE_DIT_TS_ORDER` | 1 | TaylorSeer order (1 or 2) | +| `SGLANG_CACHE_DIT_SCM_PRESET` | none | SCM preset (none/slow/medium/fast/ultra) | +| `SGLANG_CACHE_DIT_SCM_POLICY` | dynamic | SCM caching policy | +| `SGLANG_CACHE_DIT_SCM_COMPUTE_BINS` | not set | Custom SCM compute bins | +| `SGLANG_CACHE_DIT_SCM_CACHE_BINS` | not set | Custom SCM cache bins | diff --git a/python/sglang/multimodal_gen/envs.py b/python/sglang/multimodal_gen/envs.py index 56418e72d3e7..8e4c150dcb18 100644 --- a/python/sglang/multimodal_gen/envs.py +++ b/python/sglang/multimodal_gen/envs.py @@ -37,6 +37,27 @@ VERBOSE: bool = False SGLANG_DIFFUSION_SERVER_DEV_MODE: bool = False SGLANG_DIFFUSION_STAGE_LOGGING: bool = False + # cache-dit env vars (primary transformer) + SGLANG_CACHE_DIT_ENABLED: bool = False + SGLANG_CACHE_DIT_FN: int = 1 + SGLANG_CACHE_DIT_BN: int = 0 + SGLANG_CACHE_DIT_WARMUP: int = 4 + SGLANG_CACHE_DIT_RDT: float = 0.24 + SGLANG_CACHE_DIT_MC: int = 3 + SGLANG_CACHE_DIT_TAYLORSEER: bool = False + SGLANG_CACHE_DIT_TS_ORDER: int = 1 + SGLANG_CACHE_DIT_SCM_PRESET: str = "none" + SGLANG_CACHE_DIT_SCM_COMPUTE_BINS: str | None = None + SGLANG_CACHE_DIT_SCM_CACHE_BINS: str | None = None + SGLANG_CACHE_DIT_SCM_POLICY: str = "dynamic" + # cache-dit env vars (secondary transformer, e.g., Wan2.2 low-noise expert) + SGLANG_CACHE_DIT_SECONDARY_FN: int = 1 + 
SGLANG_CACHE_DIT_SECONDARY_BN: int = 0 + SGLANG_CACHE_DIT_SECONDARY_WARMUP: int = 4 + SGLANG_CACHE_DIT_SECONDARY_RDT: float = 0.24 + SGLANG_CACHE_DIT_SECONDARY_MC: int = 3 + SGLANG_CACHE_DIT_SECONDARY_TAYLORSEER: bool = False + SGLANG_CACHE_DIT_SECONDARY_TS_ORDER: int = 1 def _is_hip(): @@ -287,6 +308,90 @@ def maybe_convert_int(value: str | None) -> int | None: "SGLANG_DIFFUSION_STAGE_LOGGING": lambda: get_bool_env_var( "SGLANG_DIFFUSION_STAGE_LOGGING" ), + # ================== cache-dit Env Vars ================== + # Enable cache-dit acceleration for DiT inference + "SGLANG_CACHE_DIT_ENABLED": lambda: get_bool_env_var("SGLANG_CACHE_DIT_ENABLED"), + # Number of first blocks to always compute (DBCache F parameter) + "SGLANG_CACHE_DIT_FN": lambda: int(os.getenv("SGLANG_CACHE_DIT_FN", "1")), + # Number of last blocks to always compute (DBCache B parameter) + "SGLANG_CACHE_DIT_BN": lambda: int(os.getenv("SGLANG_CACHE_DIT_BN", "0")), + # Warmup steps before caching (DBCache W parameter) + "SGLANG_CACHE_DIT_WARMUP": lambda: int(os.getenv("SGLANG_CACHE_DIT_WARMUP", "4")), + # Residual difference threshold (DBCache R parameter) + "SGLANG_CACHE_DIT_RDT": lambda: float(os.getenv("SGLANG_CACHE_DIT_RDT", "0.24")), + # Maximum continuous cached steps (DBCache MC parameter) + "SGLANG_CACHE_DIT_MC": lambda: int(os.getenv("SGLANG_CACHE_DIT_MC", "3")), + # Enable TaylorSeer calibrator + "SGLANG_CACHE_DIT_TAYLORSEER": lambda: get_bool_env_var( + "SGLANG_CACHE_DIT_TAYLORSEER", default="false" + ), + # TaylorSeer order (1 or 2) + "SGLANG_CACHE_DIT_TS_ORDER": lambda: int( + os.getenv("SGLANG_CACHE_DIT_TS_ORDER", "1") + ), + # SCM preset: none, slow, medium, fast, ultra + "SGLANG_CACHE_DIT_SCM_PRESET": lambda: os.getenv( + "SGLANG_CACHE_DIT_SCM_PRESET", "none" + ), + # SCM custom compute bins (e.g., "8,3,3,2,2") + "SGLANG_CACHE_DIT_SCM_COMPUTE_BINS": lambda: os.getenv( + "SGLANG_CACHE_DIT_SCM_COMPUTE_BINS", None + ), + # SCM custom cache bins (e.g., "1,2,2,2,3") + 
"SGLANG_CACHE_DIT_SCM_CACHE_BINS": lambda: os.getenv( + "SGLANG_CACHE_DIT_SCM_CACHE_BINS", None + ), + # SCM policy: dynamic or static + "SGLANG_CACHE_DIT_SCM_POLICY": lambda: os.getenv( + "SGLANG_CACHE_DIT_SCM_POLICY", "dynamic" + ), + # ================== cache-dit Secondary Transformer Env Vars ================== + # For dual-transformer models like Wan2.2 (high-noise + low-noise experts) + # These parameters configure the secondary transformer (transformer_2) + # If not set, they inherit from the primary transformer settings + # Number of first blocks to always compute for secondary transformer + "SGLANG_CACHE_DIT_SECONDARY_FN": lambda: int( + os.getenv( + "SGLANG_CACHE_DIT_SECONDARY_FN", os.getenv("SGLANG_CACHE_DIT_FN", "1") + ) + ), + # Number of last blocks to always compute for secondary transformer + "SGLANG_CACHE_DIT_SECONDARY_BN": lambda: int( + os.getenv( + "SGLANG_CACHE_DIT_SECONDARY_BN", os.getenv("SGLANG_CACHE_DIT_BN", "0") + ) + ), + # Warmup steps before caching for secondary transformer + "SGLANG_CACHE_DIT_SECONDARY_WARMUP": lambda: int( + os.getenv( + "SGLANG_CACHE_DIT_SECONDARY_WARMUP", + os.getenv("SGLANG_CACHE_DIT_WARMUP", "4"), + ) + ), + # Residual difference threshold for secondary transformer + "SGLANG_CACHE_DIT_SECONDARY_RDT": lambda: float( + os.getenv( + "SGLANG_CACHE_DIT_SECONDARY_RDT", os.getenv("SGLANG_CACHE_DIT_RDT", "0.24") + ) + ), + # Maximum continuous cached steps for secondary transformer + "SGLANG_CACHE_DIT_SECONDARY_MC": lambda: int( + os.getenv( + "SGLANG_CACHE_DIT_SECONDARY_MC", os.getenv("SGLANG_CACHE_DIT_MC", "3") + ) + ), + # Enable TaylorSeer for secondary transformer + "SGLANG_CACHE_DIT_SECONDARY_TAYLORSEER": lambda: get_bool_env_var( + "SGLANG_CACHE_DIT_SECONDARY_TAYLORSEER", + default=os.getenv("SGLANG_CACHE_DIT_TAYLORSEER", "false"), + ), + # TaylorSeer order for secondary transformer + "SGLANG_CACHE_DIT_SECONDARY_TS_ORDER": lambda: int( + os.getenv( + "SGLANG_CACHE_DIT_SECONDARY_TS_ORDER", + 
os.getenv("SGLANG_CACHE_DIT_TS_ORDER", "1"), + ) + ), } diff --git a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising.py b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising.py index aae4af5033a8..a268184a6607 100644 --- a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising.py +++ b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising.py @@ -19,6 +19,7 @@ from einops import rearrange from tqdm.auto import tqdm +from sglang.multimodal_gen import envs from sglang.multimodal_gen.configs.pipeline_configs.base import ModelTaskType, STA_Mode from sglang.multimodal_gen.configs.pipeline_configs.wan import Wan2_2_TI2V_5B_Config from sglang.multimodal_gen.runtime.distributed import ( @@ -152,6 +153,153 @@ def __init__( # misc self.profiler = None + # cache-dit state (for delayed mounting and idempotent control) + self._cache_dit_enabled = False + self._cached_num_steps = None + + def _maybe_enable_cache_dit(self, num_inference_steps: int) -> None: + """Enable cache-dit on the transformers if configured (idempotent). + + This method should be called after the transformer is fully loaded + and before torch.compile is applied. + + For dual-transformer models (e.g., Wan2.2), this enables cache-dit on both + transformers with (potentially) different configurations. + + """ + if self._cache_dit_enabled: + if self._cached_num_steps != num_inference_steps: + logger.warning( + "num_inference_steps changed from %d to %d after cache-dit was enabled. 
" + "Continuing with initial configuration (steps=%d).", + self._cached_num_steps, + num_inference_steps, + self._cached_num_steps, + ) + return + # check if cache-dit is enabled in config + if not envs.SGLANG_CACHE_DIT_ENABLED: + return + + from sglang.multimodal_gen.runtime.distributed import get_world_size + from sglang.multimodal_gen.runtime.utils.cache_dit_integration import ( + CacheDitConfig, + enable_cache_on_dual_transformer, + enable_cache_on_transformer, + get_scm_mask, + ) + + if get_world_size() > 1: + logger.warning( + "cache-dit is disabled in distributed environment (world_size=%d). " + "Distributed support will be added in a future version.", + get_world_size(), + ) + return + # === Parse SCM configuration from envs === + # SCM is shared between primary and secondary transformers + scm_preset = envs.SGLANG_CACHE_DIT_SCM_PRESET + scm_compute_bins_str = envs.SGLANG_CACHE_DIT_SCM_COMPUTE_BINS + scm_cache_bins_str = envs.SGLANG_CACHE_DIT_SCM_CACHE_BINS + scm_policy = envs.SGLANG_CACHE_DIT_SCM_POLICY + + # parse custom bins if provided (both must be set together) + scm_compute_bins = None + scm_cache_bins = None + if scm_compute_bins_str and scm_cache_bins_str: + try: + scm_compute_bins = [ + int(x.strip()) for x in scm_compute_bins_str.split(",") + ] + scm_cache_bins = [int(x.strip()) for x in scm_cache_bins_str.split(",")] + except ValueError as e: + logger.warning("Failed to parse SCM bins: %s. SCM disabled.", e) + scm_preset = "none" + elif scm_compute_bins_str or scm_cache_bins_str: + # Only one of the bins was provided - warn user + logger.warning( + "SCM custom bins require both compute_bins and cache_bins. " + "Only one was provided (compute=%s, cache=%s). 
Falling back to preset '%s'.", + scm_compute_bins_str, + scm_cache_bins_str, + scm_preset, + ) + + # generate SCM mask using cache-dit's steps_mask() + # cache-dit handles step count validation and scaling internally + steps_computation_mask = get_scm_mask( + preset=scm_preset, + num_inference_steps=num_inference_steps, + compute_bins=scm_compute_bins, + cache_bins=scm_cache_bins, + ) + + # build config for primary transformer (high-noise expert) + primary_config = CacheDitConfig( + enabled=True, + Fn_compute_blocks=envs.SGLANG_CACHE_DIT_FN, + Bn_compute_blocks=envs.SGLANG_CACHE_DIT_BN, + max_warmup_steps=envs.SGLANG_CACHE_DIT_WARMUP, + residual_diff_threshold=envs.SGLANG_CACHE_DIT_RDT, + max_continuous_cached_steps=envs.SGLANG_CACHE_DIT_MC, + enable_taylorseer=envs.SGLANG_CACHE_DIT_TAYLORSEER, + taylorseer_order=envs.SGLANG_CACHE_DIT_TS_ORDER, + num_inference_steps=num_inference_steps, + # SCM fields + steps_computation_mask=steps_computation_mask, + steps_computation_policy=scm_policy, + ) + + if self.transformer_2 is not None: + # dual transformer + # build config for secondary transformer (low-noise expert) + # uses secondary parameters which inherit from primary if not explicitly set + secondary_config = CacheDitConfig( + enabled=True, + Fn_compute_blocks=envs.SGLANG_CACHE_DIT_SECONDARY_FN, + Bn_compute_blocks=envs.SGLANG_CACHE_DIT_SECONDARY_BN, + max_warmup_steps=envs.SGLANG_CACHE_DIT_SECONDARY_WARMUP, + residual_diff_threshold=envs.SGLANG_CACHE_DIT_SECONDARY_RDT, + max_continuous_cached_steps=envs.SGLANG_CACHE_DIT_SECONDARY_MC, + enable_taylorseer=envs.SGLANG_CACHE_DIT_SECONDARY_TAYLORSEER, + taylorseer_order=envs.SGLANG_CACHE_DIT_SECONDARY_TS_ORDER, + num_inference_steps=num_inference_steps, + # SCM fields - shared with primary + steps_computation_mask=steps_computation_mask, + steps_computation_policy=scm_policy, + ) + + # for dual transformers, must use BlockAdapter to enable cache on both simultaneously. 
+ # Don't call enable_cache separately on each transformer. + self.transformer, self.transformer_2 = enable_cache_on_dual_transformer( + self.transformer, + self.transformer_2, + primary_config, + secondary_config, + model_name="wan2.2", + ) + logger.info( + "cache-dit enabled on dual transformers (steps=%d)", + num_inference_steps, + ) + else: + # single transformer + self.transformer = enable_cache_on_transformer( + self.transformer, + primary_config, + model_name="transformer", + ) + logger.info( + "cache-dit enabled on transformer (steps=%d, Fn=%d, Bn=%d, rdt=%.3f)", + num_inference_steps, + envs.SGLANG_CACHE_DIT_FN, + envs.SGLANG_CACHE_DIT_BN, + envs.SGLANG_CACHE_DIT_RDT, + ) + + self._cache_dit_enabled = True + self._cached_num_steps = num_inference_steps + @lru_cache(maxsize=8) def _build_guidance(self, batch_size, target_dtype, device, guidance_val): """Builds a guidance tensor. This method is cached.""" @@ -314,6 +462,10 @@ def _prepare_denoising_loop(self, batch: Req, server_args: ServerArgs): self.transformer = loader.load( server_args.model_paths["transformer"], server_args ) + + # enable cache-dit before torch.compile (delayed mounting) + self._maybe_enable_cache_dit(batch.num_inference_steps) + if self.server_args.enable_torch_compile: self.transformer = torch.compile( self.transformer, mode="max-autotune", fullgraph=True @@ -321,6 +473,8 @@ def _prepare_denoising_loop(self, batch: Req, server_args: ServerArgs): if pipeline: pipeline.add_module("transformer", self.transformer) server_args.model_loaded["transformer"] = True + else: + self._maybe_enable_cache_dit(batch.num_inference_steps) # Prepare extra step kwargs for scheduler extra_step_kwargs = self.prepare_extra_func_kwargs( @@ -933,16 +1087,22 @@ def prepare_extra_func_kwargs(self, func, kwargs) -> dict[str, Any]: Args: func: The function to prepare kwargs for. kwargs: The kwargs to prepare. - - Returns: - The prepared kwargs. 
""" - extra_step_kwargs = {} - for k, v in kwargs.items(): - accepts = k in set(inspect.signature(func).parameters.keys()) - if accepts: - extra_step_kwargs[k] = v - return extra_step_kwargs + import functools + + # Handle cache-dit's partial wrapping logic. + # Cache-dit wraps the forward method with functools.partial where args[0] is the instance. + # We access `_original_forward` if available to inspect the underlying signature. + # See: https://github.com/vipshop/cache-dit + if isinstance(func, functools.partial) and func.args: + func = getattr(func.args[0], "_original_forward", func) + + # Unwrap any decorators (e.g. functools.wraps) + target_func = inspect.unwrap(func) + + # Filter kwargs based on the signature + params = inspect.signature(target_func).parameters + return {k: v for k, v in kwargs.items() if k in params} def progress_bar( self, iterable: Iterable | None = None, total: int | None = None diff --git a/python/sglang/multimodal_gen/runtime/server_args.py b/python/sglang/multimodal_gen/runtime/server_args.py index c95199b50d6e..5912f51e216e 100644 --- a/python/sglang/multimodal_gen/runtime/server_args.py +++ b/python/sglang/multimodal_gen/runtime/server_args.py @@ -543,7 +543,6 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: help="Use torch.compile to speed up DiT inference." + "However, will likely cause precision drifts. See (https://github.com/pytorch/pytorch/issues/145213)", ) - parser.add_argument( "--dit-cpu-offload", action=StoreBoolean, diff --git a/python/sglang/multimodal_gen/runtime/utils/cache_dit_integration.py b/python/sglang/multimodal_gen/runtime/utils/cache_dit_integration.py new file mode 100644 index 000000000000..0f0abd00d2ff --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/utils/cache_dit_integration.py @@ -0,0 +1,354 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +cache-dit integration module for SGLang DiT pipelines. 
+ +This module provides helper functions to enable cache-dit acceleration +on transformer modules in SGLang's modular pipeline architecture. +""" + +from dataclasses import dataclass +from typing import List, Optional + +import torch + +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) + +import cache_dit +from cache_dit import ( + BlockAdapter, + DBCacheConfig, + ForwardPattern, + ParamsModifier, + TaylorSeerCalibratorConfig, + steps_mask, +) +from cache_dit.caching.block_adapters import BlockAdapterRegister + + +def get_scm_mask( + preset: str, + num_inference_steps: int, + compute_bins: Optional[List[int]] = None, + cache_bins: Optional[List[int]] = None, +) -> Optional[List[int]]: + """ + Get SCM mask using cache-dit's steps_mask(). + + This is a thin wrapper that delegates to cache-dit's built-in + steps_mask() function which handles all presets and scaling logic. + + Args: + preset: Preset name ("none", "slow", "medium", "fast", "ultra"). + compute_bins: Custom compute bins (overrides preset). + cache_bins: Custom cache bins (overrides preset). + + Returns: + SCM mask list (1=compute, 0=cache), or None if disabled. + """ + if preset == "none" and not (compute_bins and cache_bins): + return None + + # Use cache-dit's steps_mask() directly + mask = steps_mask( + compute_bins=compute_bins, + cache_bins=cache_bins, + total_steps=num_inference_steps, + mask_policy=preset if preset != "none" else "medium", + ) + + compute_count = sum(mask) + cache_count = len(mask) - compute_count + logger.info( + "SCM: generated mask with %d compute steps, %d cache steps (preset=%s)", + compute_count, + cache_count, + preset, + ) + + return mask + + +@dataclass +class CacheDitConfig: + """Configuration for cache-dit integration. + + Attributes: + enabled: Whether to enable cache-dit acceleration. + Fn_compute_blocks: Number of first blocks to always compute (DBCache F). 
+ Bn_compute_blocks: Number of last blocks to always compute (DBCache B). + max_warmup_steps: Number of warmup steps before caching starts (DBCache W). + residual_diff_threshold: Threshold for residual difference (DBCache R). + max_continuous_cached_steps: Maximum consecutive cached steps (DBCache MC). + enable_taylorseer: Whether to enable TaylorSeer calibrator. + taylorseer_order: Order of Taylor expansion (1 or 2). + num_inference_steps: Total number of inference steps (required for transformer-only mode). + steps_computation_mask: Binary mask for step-level caching (1=compute, 0=cache). + Generated by get_scm_mask() (wrapper around cache_dit.steps_mask()). + steps_computation_policy: Caching policy for SCM ("dynamic" or "static"). + """ + + enabled: bool = False + Fn_compute_blocks: int = 1 + Bn_compute_blocks: int = 0 + # Use 4 as default warmup steps instead of 8 in cache-dit, thus making + # DBCache work for few steps distilled models, e.g., Z-Image w/ 8-steps. + max_warmup_steps: int = 4 + # Use a relatively higher residual diff threshold (namely, 0.24) as default + # to allow more aggressive caching due to we have already applied max continuous + # cached steps limit, otherwise, we should use a lower threshold here like 0.12. + residual_diff_threshold: float = 0.24 + max_continuous_cached_steps: int = 3 + # TaylorSeer is not suitable for few steps distilled models, so, we choose + # to disable it by default. 
Reference: + # - From Reusing to Forecasting: Accelerating Diffusion Models with TaylorSeers, + # https://arxiv.org/pdf/2503.06923 + # - FoCa: Forecast then Calibrate: Feature Caching as ODE for Efficient + # Diffusion Transformers, https://arxiv.org/pdf/2508.16211 + enable_taylorseer: bool = False + taylorseer_order: int = 1 + num_inference_steps: Optional[int] = None + # SCM fields (generated by _maybe_enable_cache_dit from env configuration) + steps_computation_mask: Optional[List[int]] = None + steps_computation_policy: str = "dynamic" + + +def enable_cache_on_transformer( + transformer: torch.nn.Module, + config: CacheDitConfig, + model_name: str = "transformer", +) -> torch.nn.Module: + """Enable cache-dit on a transformer module, by wrapping the module with cache-dit + + This function enables cache-dit acceleration using the BlockAdapterRegister + for pre-registered models + + Args: + model_name: Name of the model for logging purposes. + + """ + if not config.enabled: + return transformer + + if config.num_inference_steps is None: + raise ValueError( + "num_inference_steps is required for transformer-only mode. " + "Please provide it in CacheDitConfig." + ) + + # Check if the transformer is pre-registered in cache-dit + if not BlockAdapterRegister.is_supported(transformer): + transformer_cls_name = transformer.__class__.__name__ + raise ValueError( + f"{transformer_cls_name} is not officially supported by cache-dit. " + "Supported cache-dit DiT families include Flux, QwenImage, HunyuanDiT, " + "HunyuanVideo, Wan, CogVideoX, Mochi, and others. " + "Please ensure your transformer belongs to one of these families or " + "define a custom BlockAdapter." 
+ ) + + # Build cache config (including SCM fields if provided) + cache_config = DBCacheConfig( + num_inference_steps=config.num_inference_steps, + Fn_compute_blocks=config.Fn_compute_blocks, + Bn_compute_blocks=config.Bn_compute_blocks, + max_warmup_steps=config.max_warmup_steps, + residual_diff_threshold=config.residual_diff_threshold, + max_continuous_cached_steps=config.max_continuous_cached_steps, + # SCM fields + steps_computation_mask=config.steps_computation_mask, + steps_computation_policy=config.steps_computation_policy, + ) + + # Build calibrator config if TaylorSeer is enabled + calibrator_config = None + if config.enable_taylorseer: + calibrator_config = TaylorSeerCalibratorConfig( + taylorseer_order=config.taylorseer_order, + ) + + # Enable cache-dit on the transformer + logger.info( + "Enabling cache-dit on %s with config: Fn=%d, Bn=%d, W=%d, R=%.2f, MC=%d, " + "TaylorSeer=%s (order=%d), steps=%d", + model_name, + config.Fn_compute_blocks, + config.Bn_compute_blocks, + config.max_warmup_steps, + config.residual_diff_threshold, + config.max_continuous_cached_steps, + config.enable_taylorseer, + config.taylorseer_order, + config.num_inference_steps, + ) + + # Log SCM configuration if enabled + if config.steps_computation_mask: + compute_steps = sum(config.steps_computation_mask) + cache_steps = len(config.steps_computation_mask) - compute_steps + logger.info( + "SCM enabled: %d compute steps, %d cache steps, policy=%s", + compute_steps, + cache_steps, + config.steps_computation_policy, + ) + + cache_dit.enable_cache( + transformer, + cache_config=cache_config, + calibrator_config=calibrator_config, + ) + + return transformer + + +def enable_cache_on_dual_transformer( + transformer: torch.nn.Module, + transformer_2: torch.nn.Module, + primary_config: CacheDitConfig, + secondary_config: CacheDitConfig, + model_name: str = "wan2.2", +) -> tuple[torch.nn.Module, torch.nn.Module]: + """Enable cache-dit on dual transformers using BlockAdapter. 
+ + For models with two transformers (high-noise expert and low-noise expert), + cache-dit requires enabling cache on both simultaneously via BlockAdapter. + This cannot be done by calling enable_cache separately on each transformer. + + Args: + primary_config: CacheDitConfig for primary transformer. + secondary_config: CacheDitConfig for secondary transformer. + """ + _supported_dual_transformer_models = [ + "wan2.2", # Currently, only Wan2.2 will run into dual-transformer case + ] + if model_name not in _supported_dual_transformer_models: + raise ValueError( + f"Dual-transformer cache-dit is only supported for " + f"{_supported_dual_transformer_models}, got {model_name}." + ) + + if not primary_config.enabled: + return transformer, transformer_2 + + if primary_config.num_inference_steps is None: + raise ValueError( + "num_inference_steps is required for dual-transformer mode. " + "Please provide it in CacheDitConfig." + ) + + # Build DBCacheConfig for primary transformer + primary_cache_config = DBCacheConfig( + num_inference_steps=primary_config.num_inference_steps, + Fn_compute_blocks=primary_config.Fn_compute_blocks, + Bn_compute_blocks=primary_config.Bn_compute_blocks, + max_warmup_steps=primary_config.max_warmup_steps, + residual_diff_threshold=primary_config.residual_diff_threshold, + max_continuous_cached_steps=primary_config.max_continuous_cached_steps, + steps_computation_mask=primary_config.steps_computation_mask, + steps_computation_policy=primary_config.steps_computation_policy, + ) + + # Build DBCacheConfig for secondary transformer + secondary_cache_config = DBCacheConfig( + num_inference_steps=secondary_config.num_inference_steps, + Fn_compute_blocks=secondary_config.Fn_compute_blocks, + Bn_compute_blocks=secondary_config.Bn_compute_blocks, + max_warmup_steps=secondary_config.max_warmup_steps, + residual_diff_threshold=secondary_config.residual_diff_threshold, + max_continuous_cached_steps=secondary_config.max_continuous_cached_steps, + 
steps_computation_mask=secondary_config.steps_computation_mask, + steps_computation_policy=secondary_config.steps_computation_policy, + ) + + # Build calibrator configs if TaylorSeer is enabled + primary_calibrator = None + if primary_config.enable_taylorseer: + primary_calibrator = TaylorSeerCalibratorConfig( + taylorseer_order=primary_config.taylorseer_order, + ) + + secondary_calibrator = None + if secondary_config.enable_taylorseer: + secondary_calibrator = TaylorSeerCalibratorConfig( + taylorseer_order=secondary_config.taylorseer_order, + ) + + # Build ParamsModifier for each transformer + primary_modifier = ParamsModifier( + cache_config=primary_cache_config, + calibrator_config=primary_calibrator, + ) + secondary_modifier = ParamsModifier( + cache_config=secondary_cache_config, + calibrator_config=secondary_calibrator, + ) + + # Log configuration + logger.info( + "Enabling cache-dit on %s dual transformers with BlockAdapter", + model_name, + ) + logger.info( + " Primary (transformer): Fn=%d, Bn=%d, W=%d, R=%.2f, MC=%d, TaylorSeer=%s", + primary_config.Fn_compute_blocks, + primary_config.Bn_compute_blocks, + primary_config.max_warmup_steps, + primary_config.residual_diff_threshold, + primary_config.max_continuous_cached_steps, + primary_config.enable_taylorseer, + ) + logger.info( + " Secondary (transformer_2): Fn=%d, Bn=%d, W=%d, R=%.2f, MC=%d, TaylorSeer=%s", + secondary_config.Fn_compute_blocks, + secondary_config.Bn_compute_blocks, + secondary_config.max_warmup_steps, + secondary_config.residual_diff_threshold, + secondary_config.max_continuous_cached_steps, + secondary_config.enable_taylorseer, + ) + + # Log SCM configuration if enabled + if primary_config.steps_computation_mask: + compute_steps = sum(primary_config.steps_computation_mask) + cache_steps = len(primary_config.steps_computation_mask) - compute_steps + logger.info( + " SCM enabled: %d compute steps, %d cache steps, policy=%s", + compute_steps, + cache_steps, + 
primary_config.steps_computation_policy, + ) + + # Get blocks attribute - Wan transformers use 'blocks' attribute + transformer_blocks = getattr(transformer, "blocks", None) + transformer_2_blocks = getattr(transformer_2, "blocks", None) + + if transformer_blocks is None or transformer_2_blocks is None: + raise ValueError( + "Dual transformers must have 'blocks' attribute for cache-dit. " + f"transformer has blocks: {transformer_blocks is not None}, " + f"transformer_2 has blocks: {transformer_2_blocks is not None}" + ) + + # Enable cache-dit using BlockAdapter for both transformers simultaneously + # This is required for Wan2.2 and similar dual-transformer architectures + if model_name == "wan2.2": + # Use Pattern_2 for Wan2.2 dual-transformer. We should check `model_name` + # to ensure we only apply this for supported models. Different models + # may require different ForwardPattern. + cache_dit.enable_cache( + BlockAdapter( + transformer=[transformer, transformer_2], + blocks=[transformer_blocks, transformer_2_blocks], + forward_pattern=[ForwardPattern.Pattern_2, ForwardPattern.Pattern_2], + params_modifiers=[primary_modifier, secondary_modifier], + has_separate_cfg=True, + ), + ) + else: + raise ValueError( + f"Dual-transformer is not implemented for model {model_name} yet." + ) + + return transformer, transformer_2