
[diffusion] Enable Cache‑DiT config for diffusers backend #16662

Merged
mickqian merged 7 commits into sgl-project:main from qimcis:cache-dit-diffusers
Jan 22, 2026

Conversation

@qimcis
Contributor

@qimcis qimcis commented Jan 7, 2026

Motivation

Addresses #16642 by enabling Cache‑DiT acceleration for any diffusers pipeline in SGLang: a cache-dit config file can now be passed through the diffusers backend.

Modifications

  • Added a --cache-dit-config server arg in python/sglang/multimodal_gen/runtime/server_args.py.
  • Enabled cache‑dit for diffusers pipelines, including config loading and module resolution, in python/sglang/multimodal_gen/runtime/pipelines/diffusers_pipeline.py (see the sketch after this list).
  • Fixed the CLI --diffusers-kwargs path to avoid SamplingParams init errors in python/sglang/multimodal_gen/configs/sample/sampling_params.py.
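A minimal sketch of how this wiring might look; the helper names are illustrative, not the PR's actual code, and the cache_dit.load_configs call follows the usage shown later in this thread:

# Sketch only: helper names here are illustrative, not the PR's actual code.
import argparse


def add_cache_dit_arg(parser: argparse.ArgumentParser) -> None:
    # New server arg: path to a Cache-DiT YAML/JSON config file.
    parser.add_argument(
        "--cache-dit-config",
        type=str,
        default=None,
        help="Path to a Cache-DiT config (YAML/JSON) applied to the diffusers pipeline.",
    )


def maybe_enable_cache_dit(pipe, config_path=None):
    """Apply Cache-DiT to a loaded diffusers pipeline when a config is given."""
    if config_path is None:
        return pipe
    import cache_dit  # deferred import: only required when the flag is set

    # load_configs parses the cache_config/parallelism_config sections of the file.
    cache_dit.enable_cache(pipe, **cache_dit.load_configs(config_path))
    return pipe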

Accuracy Tests

Benchmarking and Profiling

2x RTX Pro 6000 WS

  • Cache‑DiT config:
cat > /tmp/cache_dit_config.yaml <<'EOF'
cache_config:
  max_warmup_steps: 8
  warmup_interval: 2
  max_cached_steps: -1
  max_continuous_cached_steps: 2
  Fn_compute_blocks: 1
  Bn_compute_blocks: 0
  residual_diff_threshold: 0.12
  enable_taylorseer: true
  taylorseer_order: 1
parallelism_config:
  ulysses_size: 2
  parallel_kwargs:
    attention_backend: native
    extra_parallel_modules: ["text_encoder", "vae"]
EOF
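As a quick sanity check (a sketch assuming PyYAML; not part of the PR), the file above can be parsed and inspected before handing it to the server:

# Sketch: verify the config parses and has the sections Cache-DiT expects.
import yaml

with open("/tmp/cache_dit_config.yaml") as f:
    cfg = yaml.safe_load(f)

assert set(cfg) == {"cache_config", "parallelism_config"}
print(cfg["cache_config"]["residual_diff_threshold"])  # 0.12
print(cfg["parallelism_config"]["ulysses_size"])       # 2, matching the 2-GPU runs below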

Test 1

  • Baseline:
    • 43.79s, peak 67.80 GB
    • Command:
CUDA_VISIBLE_DEVICES=0,1 sglang generate \
  --model-path Qwen/Qwen-Image \
  --backend diffusers \
  --num-gpus 2 \
  --sp-degree 2 \
  --ulysses-degree 2 \
  --tp-size 1 \
  --prompt 'A coffee shop entrance features a chalkboard sign reading "Qwen Coffee $2 per cup," with a neon light beside it displaying "通义千问". Next to it hangs a poster showing a beautiful Chinese woman, and beneath the poster is written "π≈3.1415926-53589793-23846264-33832795-02384197". Ultra HD, 4K, cinematic composition.' \
  --negative-prompt " " \
  --width 1664 \
  --height 928 \
  --num-inference-steps 50 \
  --seed 42 \
  --save-output \
  --output-path outputs \
  --output-file-name baseline-2gpu-a.png
[image: baseline-2gpu-a.png]
  • Cache‑DiT:
    • 14.69s, peak 67.98 GB (~2.98x faster)
    • Command:
CUDA_VISIBLE_DEVICES=0,1 sglang generate \
  --model-path Qwen/Qwen-Image \
  --backend diffusers \
  --num-gpus 2 \
  --sp-degree 2 \
  --ulysses-degree 2 \
  --tp-size 1 \
  --cache-dit-config /tmp/cache_dit_config.yaml \
  --prompt 'A coffee shop entrance features a chalkboard sign reading "Qwen Coffee $2 per cup," with a neon light beside it displaying "通义千问". Next to it hangs a poster showing a beautiful Chinese woman, and beneath the poster is written "π≈3.1415926-53589793-23846264-33832795-02384197". Ultra HD, 4K, cinematic composition.' \
  --negative-prompt " " \
  --width 1664 \
  --height 928 \
  --num-inference-steps 50 \
  --seed 42 \
  --diffusers-kwargs '{"max_sequence_length":512}' \
  --save-output \
  --output-path outputs \
  --output-file-name cachedit-2gpu-a.png
[image: cachedit-2gpu-a.png]

Test 2

  • Baseline:
    • 44.26s, peak 67.80 GB
    • Command:
CUDA_VISIBLE_DEVICES=0,1 sglang generate \
  --model-path Qwen/Qwen-Image \
  --backend diffusers \
  --num-gpus 2 \
  --sp-degree 2 \
  --ulysses-degree 2 \
  --tp-size 1 \
  --prompt 'A modern metro station poster wall with three ads: 1) a neon sign that reads "TOKYO MIDNIGHT - 24/7" in English, 2) a handwritten Chinese banner reading "欢迎来到未来城市", 3) a chalkboard menu listing "Espresso $2.50, Latte $3.75, Matcha $4.00". Include a small QR code in the corner and a timestamp "2025-01-08 18:30". Ultra-detailed, sharp typography, cinematic lighting, 4K.' \
  --negative-prompt " " \
  --width 1664 \
  --height 928 \
  --num-inference-steps 50 \
  --seed 42 \
  --save-output \
  --output-path outputs \
  --output-file-name baseline-2gpu-b.png
[image: baseline-2gpu-b.png]
  • Cache‑DiT:
    • 14.41s, peak 67.98 GB (~3.07x faster)
    • Command:
CUDA_VISIBLE_DEVICES=0,1 sglang generate \
  --model-path Qwen/Qwen-Image \
  --backend diffusers \
  --num-gpus 2 \
  --sp-degree 2 \
  --ulysses-degree 2 \
  --tp-size 1 \
  --cache-dit-config /tmp/cache_dit_config.yaml \
  --prompt 'A modern metro station poster wall with three ads: 1) a neon sign that reads "TOKYO MIDNIGHT - 24/7" in English, 2) a handwritten Chinese banner reading "欢迎来到未来城市", 3) a chalkboard menu listing "Espresso $2.50, Latte $3.75, Matcha $4.00". Include a small QR code in the corner and a timestamp "2025-01-08 18:30". Ultra-detailed, sharp typography, cinematic lighting, 4K.' \
  --negative-prompt " " \
  --width 1664 \
  --height 928 \
  --num-inference-steps 50 \
  --guidance-scale 4.0 \
  --seed 42 \
  --diffusers-kwargs '{"max_sequence_length":512}' \
  --save-output \
  --output-path outputs \
  --output-file-name cachedit-2gpu-b.png
[image: cachedit-2gpu-b.png]

Checklist

  • Format your code according to "Format code with pre-commit".
  • Add unit tests according to "Run and add unit tests".
  • Update documentation according to "Write documentations".
  • Provide accuracy and speed benchmark results according to "Test the accuracy" and "Benchmark the speed".
  • Follow the SGLang code style guidance.

@github-actions bot added the documentation and diffusion labels on Jan 7, 2026
@gemini-code-assist
Contributor

Summary of Changes

Hello @qimcis, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces support for Cache-DiT acceleration within SGLang's Diffusers backend. It allows users to leverage Cache-DiT for optimizing any Diffusers pipeline by providing a configuration file. The changes involve adding a new command-line argument to specify the Cache-DiT configuration, integrating the Cache-DiT enablement logic directly into the Diffusers pipeline loading process, and updating relevant documentation to guide users on this new feature.

Highlights

  • Cache-DiT Integration: Enabled Cache-DiT acceleration for Diffusers pipelines within SGLang, addressing issue [Feature] Run any Diffusers' model with sglang via Cache-DiT #16642.
  • Configuration Option: Introduced a new --cache-dit-config CLI argument to specify Cache-DiT YAML/JSON configuration files.
  • Dynamic Application: Modified the Diffusers pipeline loading process to dynamically apply Cache-DiT based on the provided configuration.
  • Documentation Update: Updated documentation to reflect the new configuration option and clarify Cache-DiT limitations.


@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces support for Cache-DiT acceleration within the diffusers backend. It adds a new CLI argument --cache-dit-config to specify a Cache-DiT configuration file, updates the ServerArgs dataclass to include this configuration, and integrates the Cache-DiT enabling logic into the DiffusersPipeline.

The changes also include minor documentation updates to reflect the new CLI argument and clarify limitations regarding distributed support for Cache-DiT.

Comment on lines +556 to +558

except Exception:
    logger.exception("Failed to enable cache-dit for diffusers pipeline")
    raise
medium

The try...except Exception block is very broad. While logger.exception is helpful for debugging, catching a generic Exception can mask specific issues that might arise from cache_dit.enable_cache. Consider catching more specific exceptions if known, or at least adding a comment explaining why a broad exception is necessary here (e.g., due to the external nature of the cache_dit library and its potential to raise various exceptions).
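One way to narrow the handler, sketched under the assumption that missing or malformed config files are the failure modes worth distinguishing (the names and exception types here are illustrative, not cache_dit's documented behavior):

# Sketch only: the specific exception set is an assumption.
import logging

logger = logging.getLogger(__name__)


def enable_cache_dit(pipe, config_path):
    try:
        import cache_dit

        cache_dit.enable_cache(pipe, **cache_dit.load_configs(config_path))
    except (FileNotFoundError, ValueError) as e:
        # Expected failure modes: missing or malformed config file.
        logger.error("Invalid cache-dit config %r: %s", config_path, e)
        raise
    except Exception:
        # cache_dit is an external library and may raise arbitrary errors;
        # log with traceback, then propagate so startup fails loudly.
        logger.exception("Failed to enable cache-dit for diffusers pipeline")
        raise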

Comment on lines +643 to +644

except Exception:
    pass
medium

Catching a broad Exception and then passing can hide critical errors, making debugging difficult. If get_text_encoder_from_pipe can fail in expected ways, consider catching those specific exceptions. Otherwise, it's better to log the exception at a debug level or re-raise it if it indicates a serious problem, rather than silently ignoring it.
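A sketch of the debug-logging alternative; get_text_encoder_from_pipe is the helper named in this review, and AttributeError as its expected failure mode is an assumption:

# Sketch only: assumes the helper raises AttributeError when absent.
import logging

logger = logging.getLogger(__name__)


def resolve_text_encoder(pipe):
    try:
        return get_text_encoder_from_pipe(pipe)  # helper from the module under review
    except AttributeError as e:
        # Some pipelines simply have no text encoder component; not fatal.
        logger.debug("No text encoder on %s: %s", type(pipe).__name__, e)
        return None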

@qimcis qimcis marked this pull request as ready for review January 8, 2026 19:53
@mickqian
Collaborator

mickqian commented Jan 9, 2026

@DefTruth could you take a look when you're available?

@DefTruth
Contributor

DefTruth commented Jan 9, 2026

> @DefTruth could you take a look when you're available?

cool~ I'll review it this afternoon

@DefTruth
Contributor

DefTruth commented Jan 9, 2026

@qimcis @mickqian Thanks for this great work! I have some suggestions as follows:

    1. It's better to merge this PR after cache-dit 1.2.0 is released (planned for next week). 1.2.0 is a major release that provides many new optimizations, including text encoder parallelism, autoencoder (VAE) parallelism, ControlNet parallelism, Ascend NPU support, and support for many new models (Z-Image, LongCat-Image, etc.). I therefore strongly recommend using cache-dit 1.2.0 in this PR.
    2. I can help design a new load_configs API in cache_dit that properly loads the cache_config and parallelism_config from YAML, so we can keep the code in sglang as simple as possible. load_options will be deprecated in the future because of compatibility conflicts. Usage of the new load_configs API may look like:
     # Quantization will be treated as a third-party tool for cache-dit, so we do not plan to
     # support quant_config for this backend. For quantization-related use cases, we
     # recommend directly using sglang (instead of the cache-dit backend) as it delivers
     # better performance.
     cache_dit.enable_cache(pipe, **cache_dit.load_configs("cache_dit_config.yml"))
    3. The peak memory results for the cache-dit backend in the tested cases appear to be incorrect. If text encoder parallelism and VAE parallelism were indeed enabled successfully, GPU memory usage should have dropped significantly. I suspect this is caused by not using the latest (main) version of cache-dit.
    4. Following up on point 2, I can also fully handle the preparation of extra modules inside cache-dit by allowing it to accept module names as strings, so the configuration can be loaded directly from the YAML file and used without additional modification (see the sketch after this list).
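A sketch of what accepting extra module names as strings could look like (illustrative only; not the actual cache-dit implementation):

# Illustrative sketch only; not the actual cache-dit implementation.
def resolve_extra_parallel_modules(pipe, names):
    """Map names like 'text_encoder' or 'vae' to the pipeline's components."""
    modules = []
    for name in names:
        module = getattr(pipe, name, None)
        if module is None:
            # Skip components this pipeline doesn't have rather than failing.
            continue
        modules.append(module)
    return modules

# e.g. resolve_extra_parallel_modules(pipe, ["text_encoder", "vae"])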

@DefTruth
Contributor

DefTruth commented Jan 9, 2026

@qimcis The docs for cache_dit.load_configs API is here: https://cache-dit.readthedocs.io/en/latest/user_guide/LOAD_CONFIGS/

help="Attention backend for diffusers pipelines (e.g., flash, _flash_3_hub, sage, xformers). "
"See: https://huggingface.co/docs/diffusers/main/en/optimization/attention_backends",
)
parser.add_argument(
Collaborator

@mickqian mickqian Jan 9, 2026


This might work well, but we use env vars for the non-diffusers backends; introducing such a new arg could cause confusion.

@DefTruth
Contributor

🤗 Cache-DiT v1.2.0 Major Release is ready! I recommend that we use this version.

@qimcis
Contributor Author

qimcis commented Jan 18, 2026

Apologies for the delay, have had a busy week

I think I've implemented all the requested changes, lmk if I missed anything

After rerunning the tests on 1.2.0 I'm seeing good improvements! Let me know if these are approximately expected numbers @DefTruth

Item               Original            Bumped
Cache‑DiT version  1.1.8               1.2.0
Test 1 baseline    43.79s / 67.80 GB   47.75s / 67.80 GB
Test 1 cache‑dit   14.69s / 67.98 GB   21.03s / 49.56 GB
Test 1 speedup     2.98x               2.27x
Test 2 baseline    44.26s / 67.80 GB   47.65s / 67.80 GB
Test 2 cache‑dit   14.41s / 67.98 GB   20.69s / 49.56 GB
Test 2 speedup     3.07x               2.30x

@DefTruth
Contributor

> Apologies for the delay, have had a busy week
>
> I think I've implemented all the requested changes, lmk if I missed anything
>
> After rerunning the tests on 1.2.0 I'm seeing good improvements! Let me know if these are approximately expected numbers @DefTruth
>
> Item               Original            Bumped
> Cache‑DiT version  1.1.8               1.2.0
> Test 1 baseline    43.79s / 67.80 GB   47.75s / 67.80 GB
> Test 1 cache‑dit   14.69s / 67.98 GB   21.03s / 49.56 GB
> Test 1 speedup     2.98x               2.27x
> Test 2 baseline    44.26s / 67.80 GB   47.65s / 67.80 GB
> Test 2 cache‑dit   14.41s / 67.98 GB   20.69s / 49.56 GB
> Test 2 speedup     3.07x               2.30x

Hi~ thank you for providing the test data. Overall the results LGTM. Could you please also provide the complete test logs? I can help analyze what causes the difference in the acceleration ratio.

@qimcis
Contributor Author

qimcis commented Jan 18, 2026

> Hi~ thank you for providing the test data. Overall the results LGTM. Could you please also provide the complete test logs? I can help analyze what causes the difference in the acceleration ratio.

Yup! Attached below:

Test 1 Baseline

root@C.30179436:/workspace/sglang$ CUDA_VISIBLE_DEVICES=0,1 sglang generate \
  --model-path Qwen/Qwen-Image \
  --backend diffusers \
  --num-gpus 2 \
  --sp-degree 2 \
  --ulysses-degree 2 \
  --tp-size 1 \
  --prompt 'A coffee shop entrance features a chalkboard sign reading "Qwen Coffee $2 per cup," with a neon light beside it displaying "通义千问". Next to it hangs a poster showing a beautiful Chinese woman, and beneath the poster is written "π≈3.1415926-53589793-23846264-33832795-02384197". Ultra HD, 4K, cinematic composition.'  --output-file-name baseline-2gpu-a.png
[01-18 07:28:24] Disabling some offloading (except dit, text_encoder) for image generation model
[01-18 07:28:24] server_args: {"model_path": "Qwen/Qwen-Image", "backend": "diffusers", "attention_backend": null, "cache_dit_config": null, "nccl_port": null, "trust_remote_code": false, "revision": null, "num_gpus": 2, "tp_size": 1, "sp_degree": 2, "ulysses_degree": 2, "ring_degree": 1, "dp_size": 1, "dp_degree": 1, "enable_cfg_parallel": false, "hsdp_replicate_dim": 1, "hsdp_shard_dim": 2, "dist_timeout": null, "lora_path": null, "lora_nickname": "default", "vae_path": null, "lora_target_modules": null, "dit_cpu_offload": true, "dit_layerwise_offload": null, "text_encoder_cpu_offload": true, "image_encoder_cpu_offload": false, "vae_cpu_offload": false, "use_fsdp_inference": false, "pin_cpu_memory": true, "mask_strategy_file_path": null, "STA_mode": "STA_inference", "skip_time_steps": 15, "enable_torch_compile": false, "warmup": false, "warmup_resolutions": null, "disable_autocast": true, "VSA_sparsity": 0.0, "moba_config_path": null, "moba_config": {}, "master_port": 30054, "host": "127.0.0.1", "port": 30000, "webui": false, "webui_port": 12312, "scheduler_port": 5631, "prompt_file_path": null, "model_paths": {}, "model_loaded": {"transformer": true, "vae": true}, "boundary_ratio": null, "log_level": "info"}
[01-18 07:28:24] Local mode: True
[01-18 07:28:24] Starting server...
[01-18 07:28:33] Scheduler bind at endpoint: tcp://127.0.0.1:5631
[01-18 07:28:33] Initializing distributed environment with world_size=2, device=cuda:0
[01-18 07:28:34] Found nccl from library libnccl.so.2
[01-18 07:28:34] sglang-diffusion is using nccl==2.27.5
[01-18 07:28:35] Found nccl from library libnccl.so.2
[01-18 07:28:35] sglang-diffusion is using nccl==2.27.5
[01-18 07:28:35] Using diffusers backend for model 'Qwen/Qwen-Image' (explicitly requested)
[01-18 07:28:35] Loading diffusers pipeline from Qwen/Qwen-Image
[01-18 07:28:35] Checking for cached model in HF Hub cache for Qwen/Qwen-Image...
[01-18 07:28:35] Found complete model in cache at /workspace/.hf_home/hub/models--Qwen--Qwen-Image/snapshots/75e0b4be04f60ec59a75f475837eced720f823b6
[01-18 07:28:35] Loading diffusers pipeline with dtype=torch.bfloat16, device_map=cuda
Keyword arguments {'trust_remote_code': False} are not expected by QwenImagePipeline and will be ignored.
Loading pipeline components...:   0%|                                                                                                          | 0/5 [00:00<?, ?it/s]Keyword arguments {'trust_remote_code': False} are not expected by QwenImagePipeline and will be ignored.
Loading pipeline components...:  20%|███████████████████▌                                                                              | 1/5 [00:00<00:00,  4.22it/s]The config attributes {'pooled_projection_dim': 768} were passed to QwenImageTransformer2DModel, but are not expected and will be ignored. Please verify your config.json configuration file.
Loading pipeline components...:  40%|███████████████████████████████████████▏                                                          | 2/5 [00:00<00:00,  5.77it/s]`torch_dtype` is deprecated! Use `dtype` instead!
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:03<00:00,  1.28it/s]
Loading pipeline components...:  80%|██████████████████████████████████████████████████████████████████████████████▍                   | 4/5 [00:03<00:01,  1.08s/it]The config attributes {'pooled_projection_dim': 768} were passed to QwenImageTransformer2DModel, but are not expected and will be ignored. Please verify your config.json configuration file.
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:08<00:00,  1.12it/s]
Loading pipeline components...:  40%|███████████████████████████████████████▏                                                          | 2/5 [00:08<00:14,  4.96s/it]`torch_dtype` is deprecated! Use `dtype` instead!
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:03<00:00,  1.30it/s]
Loading pipeline components...: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:11<00:00,  2.35s/it]
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:07<00:00,  1.16it/s]
Loading pipeline components...: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:11<00:00,  2.39s/it]
[01-18 07:28:47] Loaded diffusers pipeline: QwenImagePipeline
[01-18 07:28:47] Detected pipeline type: image
[01-18 07:28:47] Pipeline instantiated
[01-18 07:28:47] Worker 0: Initialized device, model, and distributed environment.
[01-18 07:28:47] Worker 0: Scheduler loop started.
[01-18 07:28:47] Processing prompt 1/1: A coffee shop entrance features a chalkboard sign reading "Qwen Coffee $2 per cup," with a neon ligh
[01-18 07:28:47] [DiffusersExecutionStage] started...
guidance_scale is passed as 4.0, but ignored since the model is not guidance-distilled.
  0%|                                                                                                                                         | 0/50 [00:00<?, ?it/s]guidance_scale is passed as 4.0, but ignored since the model is not guidance-distilled.
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:45<00:00,  1.09it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:45<00:00,  1.09it/s]
[01-18 07:29:34] Extracted output from 'images': shape=torch.Size([1, 3, 928, 1664]), dtype=torch.bfloat16
[01-18 07:29:34] Final output tensor shape: torch.Size([1, 3, 928, 1664])
[01-18 07:29:34] [DiffusersExecutionStage] finished in 47.0062 seconds
[01-18 07:29:34] Peak GPU memory: 66.21 GB, Remaining GPU memory at peak: 29.38 GB. Components that can stay resident: []
[01-18 07:29:35] Output saved to outputs/baseline-2gpu-a.png
[01-18 07:29:35] Pixel data generated successfully in 47.75 seconds
[01-18 07:29:35] Completed batch processing. Generated 1 outputs in 47.75 seconds
[01-18 07:29:35] Memory usage - Max peak: 67800.77 MB, Avg peak: 67800.77 MB

Test 1 Cache-DiT

root@C.30179436:/workspace/sglang$ CUDA_VISIBLE_DEVICES=0,1 sglang generate \
  --model-path Qwen/Qwen-Image \
  --backend diffusers \
  --num-gpus 2 \
  --sp-degree 2 \
  --ulysses-degree 2 \
  --tp-size 1 \
  --cache-dit-config /tmp/cache_dit_config.yaml \
  --prompt 'A coffee shop entrance features a chalkboard sign reading "Qwen Coffee $2 per cup," with a neon light beside it displaying "通义千问". Next to it hangs a poster showing a beautiful Chinese woman, and beneath the poster is written "π≈3.1415926-53589793-23846264-33832795-02384197". Ultra HD, 4K, cinematic composition.' \
  --negative-prompt " " \
  --width 1664 \
  --height 928 \
  --num-inference-steps 50 \
  --diffusers-kwargs '{"max_sequence_length":512}' \
  --output-file-name cachedit-2gpu-a.png
[01-18 07:36:47] Disabling some offloading (except dit, text_encoder) for image generation model
[01-18 07:36:47] server_args: {"model_path": "Qwen/Qwen-Image", "backend": "diffusers", "attention_backend": null, "cache_dit_config": "/tmp/cache_dit_config.yaml", "nccl_port": null, "trust_remote_code": false, "revision": null, "num_gpus": 2, "tp_size": 1, "sp_degree": 2, "ulysses_degree": 2, "ring_degree": 1, "dp_size": 1, "dp_degree": 1, "enable_cfg_parallel": false, "hsdp_replicate_dim": 1, "hsdp_shard_dim": 2, "dist_timeout": null, "lora_path": null, "lora_nickname": "default", "vae_path": null, "lora_target_modules": null, "dit_cpu_offload": true, "dit_layerwise_offload": null, "text_encoder_cpu_offload": true, "image_encoder_cpu_offload": false, "vae_cpu_offload": false, "use_fsdp_inference": false, "pin_cpu_memory": true, "mask_strategy_file_path": null, "STA_mode": "STA_inference", "skip_time_steps": 15, "enable_torch_compile": false, "warmup": false, "warmup_resolutions": null, "disable_autocast": true, "VSA_sparsity": 0.0, "moba_config_path": null, "moba_config": {}, "master_port": 30038, "host": "127.0.0.1", "port": 30000, "webui": false, "webui_port": 12312, "scheduler_port": 5635, "prompt_file_path": null, "model_paths": {}, "model_loaded": {"transformer": true, "vae": true}, "boundary_ratio": null, "log_level": "info"}
[01-18 07:36:47] Parsed diffusers_kwargs: {'max_sequence_length': 512}
[01-18 07:36:47] Local mode: True
[01-18 07:36:47] Starting server...
[01-18 07:36:56] Scheduler bind at endpoint: tcp://127.0.0.1:5635
[01-18 07:36:56] Initializing distributed environment with world_size=2, device=cuda:0
[01-18 07:36:57] Found nccl from library libnccl.so.2
[01-18 07:36:57] sglang-diffusion is using nccl==2.27.5
[01-18 07:36:58] Found nccl from library libnccl.so.2
[01-18 07:36:58] sglang-diffusion is using nccl==2.27.5
[01-18 07:36:58] Using diffusers backend for model 'Qwen/Qwen-Image' (explicitly requested)
[01-18 07:36:58] Loading diffusers pipeline from Qwen/Qwen-Image
[01-18 07:36:58] Checking for cached model in HF Hub cache for Qwen/Qwen-Image...
[01-18 07:36:58] Found complete model in cache at /workspace/.hf_home/hub/models--Qwen--Qwen-Image/snapshots/75e0b4be04f60ec59a75f475837eced720f823b6
[01-18 07:36:58] Loading diffusers pipeline with dtype=torch.bfloat16, device_map=cuda
Keyword arguments {'trust_remote_code': False} are not expected by QwenImagePipeline and will be ignored.
Loading pipeline components...:   0%|                                                                                                          | 0/5 [00:00<?, ?it/s]The config attributes {'pooled_projection_dim': 768} were passed to QwenImageTransformer2DModel, but are not expected and will be ignored. Please verify your config.json configuration file.
Keyword arguments {'trust_remote_code': False} are not expected by QwenImagePipeline and will be ignored.
Loading pipeline components...:   0%|                                                                                                          | 0/5 [00:00<?, ?it/s]`torch_dtype` is deprecated! Use `dtype` instead!
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:03<00:00,  1.28it/s]
Loading pipeline components...:  20%|███████████████████▌                                                                              | 1/5 [00:03<00:12,  3.25s/it]The config attributes {'pooled_projection_dim': 768} were passed to QwenImageTransformer2DModel, but are not expected and will be ignored. Please verify your config.json configuration file.
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:08<00:00,  1.11it/s]
Loading pipeline components...:  20%|███████████████████▌                                                                              | 1/5 [00:08<00:33,  8.35s/it`torch_dtype` is deprecated! Use `dtype` instead!█████████████████████████████████████████▏                                             | 5/9 [00:04<00:03,  1.05it/s]
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:07<00:00,  1.16it/s]
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:03<00:00,  1.29it/s]
Loading pipeline components...: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:11<00:00,  2.37s/it]
Loading pipeline components...: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:11<00:00,  2.37s/it]
INFO 01-18 07:37:10 [config.py:50] Auto selected parallelism backend for transformer: Native_Diffuser
INFO 01-18 07:37:10 [cache_adapter.py:57] QwenImagePipeline is officially supported by cache-dit. Use it's pre-defined BlockAdapter directly!
INFO 01-18 07:37:10 [config.py:50] Auto selected parallelism backend for transformer: Native_Diffuser
INFO 01-18 07:37:10 [cache_adapter.py:57] QwenImagePipeline is officially supported by cache-dit. Use it's pre-defined BlockAdapter directly!
INFO 01-18 07:37:10 [block_adapters.py:220] Auto fill blocks_name: transformer_blocks.
INFO 01-18 07:37:10 [block_adapters.py:220] Auto fill blocks_name: transformer_blocks.
INFO 01-18 07:37:10 [block_adapters.py:504] Match Block Forward Pattern: QwenImageTransformerBlock, ForwardPattern.Pattern_1
INFO 01-18 07:37:10 [block_adapters.py:504] IN:('hidden_states', 'encoder_hidden_states'), OUT:('encoder_hidden_states', 'hidden_states'))
INFO 01-18 07:37:10 [cache_adapter.py:142] Use custom 'enable_separate_cfg' from BlockAdapter: True. Pipeline: QwenImagePipeline.
INFO 01-18 07:37:10 [cache_adapter.py:341] Collected Context Config: DBCache_F1B0_W8I2M0MC2_R0.12_CFG1, Calibrator Config: TaylorSeer_O(1)
INFO 01-18 07:37:10 [block_adapters.py:504] Match Block Forward Pattern: QwenImageTransformerBlock, ForwardPattern.Pattern_1
INFO 01-18 07:37:10 [block_adapters.py:504] IN:('hidden_states', 'encoder_hidden_states'), OUT:('encoder_hidden_states', 'hidden_states'))
INFO 01-18 07:37:10 [cache_adapter.py:142] Use custom 'enable_separate_cfg' from BlockAdapter: True. Pipeline: QwenImagePipeline.
INFO 01-18 07:37:10 [cache_adapter.py:341] Collected Context Config: DBCache_F1B0_W8I2M0MC2_R0.12_CFG1, Calibrator Config: TaylorSeer_O(1)
INFO 01-18 07:37:10 [pattern_base.py:70] Match Blocks: CachedBlocks_Pattern_0_1_2, for transformer_blocks, cache_context: transformer_blocks_140173304933520, context_manager: QwenImagePipeline_140173358179312.
INFO 01-18 07:37:10 [pattern_base.py:70] Match Blocks: CachedBlocks_Pattern_0_1_2, for transformer_blocks, cache_context: transformer_blocks_125143874362528, context_manager: QwenImagePipeline_125143899253760.
INFO 01-18 07:37:10 [block_adapters.py:220] Auto fill blocks_name: transformer_blocks.
INFO 01-18 07:37:10 [block_adapters.py:220] Auto fill blocks_name: transformer_blocks.
INFO 01-18 07:37:10 [_attention_dispatch.py:310] Re-registered NATIVE attention backend to enable context parallelism with attn mask in cache-dit. You can disable this behavior by: export CACHE_DIT_ENABLE_CUSTOM_ATTN_DISPATCH=0.
INFO 01-18 07:37:10 [_attention_dispatch.py:423] Registered new attention backend: _SDPA_CUDNN to enable context parallelism with attn mask in cache-dit. You can disable it by: export CACHE_DIT_ENABLE_CUSTOM_ATTN_DISPATCH=0.
INFO 01-18 07:37:10 [_attention_dispatch.py:478] Re-registered SAGE attention backend to enable context parallelism with FP8 Attention in cache-dit. You can disable this behavior by: export CACHE_DIT_ENABLE_CUSTOM_ATTN_DISPATCH=0.
INFO 01-18 07:37:10 [_attention_dispatch.py:637] Flash Attention 3 not available, skipping _FLASH_3 backend registration.
INFO 01-18 07:37:10 [_attention_dispatch.py:686] Re-registered _NATIVE_NPU attention backend to enable context parallelism You can disable this behavior by: export CACHE_DIT_ENABLE_CUSTOM_ATTN_DISPATCH=0.
INFO 01-18 07:37:10 [_attention_dispatch.py:310] Re-registered NATIVE attention backend to enable context parallelism with attn mask in cache-dit. You can disable this behavior by: export CACHE_DIT_ENABLE_CUSTOM_ATTN_DISPATCH=0.
INFO 01-18 07:37:10 [_attention_dispatch.py:423] Registered new attention backend: _SDPA_CUDNN to enable context parallelism with attn mask in cache-dit. You can disable it by: export CACHE_DIT_ENABLE_CUSTOM_ATTN_DISPATCH=0.
INFO 01-18 07:37:10 [_attention_dispatch.py:478] Re-registered SAGE attention backend to enable context parallelism with FP8 Attention in cache-dit. You can disable this behavior by: export CACHE_DIT_ENABLE_CUSTOM_ATTN_DISPATCH=0.
INFO 01-18 07:37:10 [_attention_dispatch.py:637] Flash Attention 3 not available, skipping _FLASH_3 backend registration.
INFO 01-18 07:37:10 [_attention_dispatch.py:686] Re-registered _NATIVE_NPU attention backend to enable context parallelism You can disable this behavior by: export CACHE_DIT_ENABLE_CUSTOM_ATTN_DISPATCH=0.
`enable_parallelism` is an experimental feature. The API may change in the future and breaking changes may be introduced at any time without warning.
`enable_parallelism` is an experimental feature. The API may change in the future and breaking changes may be introduced at any time without warning.
INFO 01-18 07:37:10 [dispatch.py:75] Parallelize Transformer: QwenImageTransformer2DModel, id:125143874362096, ParallelismConfig(backend=Native_Diffuser, ulysses_size=2)
Attention backends are an experimental feature and the API may be subject to change.
INFO 01-18 07:37:10 [dispatch.py:75] Parallelize Transformer: QwenImageTransformer2DModel, id:140173359331472, ParallelismConfig(backend=Native_Diffuser, ulysses_size=2)
Attention backends are an experimental feature and the API may be subject to change.
INFO 01-18 07:37:10 [dispatch.py:153] Found attention_backend from config, set attention backend of QwenImageTransformer2DModel to: native.
INFO 01-18 07:37:10 [dispatch.py:153] Found attention_backend from config, set attention backend of QwenImageTransformer2DModel to: native.
INFO 01-18 07:37:13 [dispatch.py:36] Parallelize Text Encoder: Qwen2_5_VLForConditionalGeneration, id:140173311464208, ParallelismConfig(backend=Native_PyTorch, tp_size=2)
INFO 01-18 07:37:13 [dispatch.py:36] Parallelize Auto Encoder: AutoencoderKLQwenImage, id:140173305045040, ParallelismConfig(backend=Native_PyTorch, dp_size=2)
INFO 01-18 07:37:13 [dispatch.py:36] Parallelize Text Encoder: Qwen2_5_VLForConditionalGeneration, id:125143874355664, ParallelismConfig(backend=Native_PyTorch, tp_size=2)
INFO 01-18 07:37:13 [dispatch.py:36] Parallelize Auto Encoder: AutoencoderKLQwenImage, id:125143140367392, ParallelismConfig(backend=Native_PyTorch, dp_size=2)
[01-18 07:37:15] Enabled cache-dit for diffusers pipeline
[01-18 07:37:15] Loaded diffusers pipeline: QwenImagePipeline
[01-18 07:37:15] Detected pipeline type: image
[01-18 07:37:15] Pipeline instantiated
[01-18 07:37:15] Worker 0: Initialized device, model, and distributed environment.
[01-18 07:37:15] Worker 0: Scheduler loop started.
[01-18 07:37:15] Processing prompt 1/1: A coffee shop entrance features a chalkboard sign reading "Qwen Coffee $2 per cup," with a neon ligh
[01-18 07:37:15] [DiffusersExecutionStage] started...
guidance_scale is passed as 4.0, but ignored since the model is not guidance-distilled.
guidance_scale is passed as 4.0, but ignored since the model is not guidance-distilled.
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:17<00:00,  2.94it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:17<00:00,  2.94it/s]
[rank1]:[W118 07:37:35.816357827 ProcessGroupNCCL.cpp:3961] Warning: [PG ID 0 PG GUID 0(default_pg) Rank 1] An unbatched P2P op (send/recv) was called on this ProcessGroup with size 2.  In eager initialization mode, unbatched P2P ops are treated as independent collective ops, and are thus serialized with all other ops on this ProcessGroup, including other P2P ops. To avoid serialization, either create additional independent ProcessGroups for the P2P ops or use batched P2P ops. You can squash this warning by setting the environment variable TORCH_NCCL_SHOW_EAGER_INIT_P2P_SERIALIZATION_WARNING to false. (function operator())
[rank0]:[W118 07:37:35.929964787 ProcessGroupNCCL.cpp:3961] Warning: [PG ID 0 PG GUID 0(default_pg) Rank 0] An unbatched P2P op (send/recv) was called on this ProcessGroup with size 2.  In eager initialization mode, unbatched P2P ops are treated as independent collective ops, and are thus serialized with all other ops on this ProcessGroup, including other P2P ops. To avoid serialization, either create additional independent ProcessGroups for the P2P ops or use batched P2P ops. You can squash this warning by setting the environment variable TORCH_NCCL_SHOW_EAGER_INIT_P2P_SERIALIZATION_WARNING to false. (function operator())
[rank1]:[W118 07:37:35.951580790 ProcessGroupNCCL.cpp:3961] Warning: [PG ID 0 PG GUID 0(default_pg) Rank 1] An unbatched P2P op (send/recv) was called on this ProcessGroup with size 2.  In eager initialization mode, unbatched P2P ops are treated as independent collective ops, and are thus serialized with all other ops on this ProcessGroup, including other P2P ops. To avoid serialization, either create additional independent ProcessGroups for the P2P ops or use batched P2P ops. You can squash this warning by setting the environment variable TORCH_NCCL_SHOW_EAGER_INIT_P2P_SERIALIZATION_WARNING to false. (function operator())
[rank0]:[W118 07:37:36.250010324 ProcessGroupNCCL.cpp:3961] Warning: [PG ID 0 PG GUID 0(default_pg) Rank 0] An unbatched P2P op (send/recv) was called on this ProcessGroup with size 2.  In eager initialization mode, unbatched P2P ops are treated as independent collective ops, and are thus serialized with all other ops on this ProcessGroup, including other P2P ops. To avoid serialization, either create additional independent ProcessGroups for the P2P ops or use batched P2P ops. You can squash this warning by setting the environment variable TORCH_NCCL_SHOW_EAGER_INIT_P2P_SERIALIZATION_WARNING to false. (function operator())
[01-18 07:37:36] Extracted output from 'images': shape=torch.Size([1, 3, 928, 1664]), dtype=torch.bfloat16
[01-18 07:37:36] Final output tensor shape: torch.Size([1, 3, 928, 1664])
[01-18 07:37:36] [DiffusersExecutionStage] finished in 20.3875 seconds
[01-18 07:37:36] Peak GPU memory: 48.40 GB, Remaining GPU memory at peak: 47.19 GB. Components that can stay resident: []
[01-18 07:37:36] Output saved to outputs/cachedit-2gpu-a.png
[01-18 07:37:36] Pixel data generated successfully in 21.07 seconds
[01-18 07:37:36] Completed batch processing. Generated 1 outputs in 21.07 seconds
[01-18 07:37:36] Memory usage - Max peak: 49565.17 MB, Avg peak: 49565.17 MB

Test 2 Baseline

root@C.30179436:/workspace/sglang$ CUDA_VISIBLE_DEVICES=0,1 sglang generate \
  --model-path Qwen/Qwen-Image \
  --backend diffusers \
  --num-gpus 2 \
  --sp-degree 2 \
  --ulysses-degree 2 \
  --tp-size 1 \
  --prompt 'A modern metro station poster wall with three ads: 1) a neon sign that reads "TOKYO MIDNIGHT - 24/7" in English, 2) a handwritten Chinese banner reading "欢迎来到未来城市", 3) a chalkboard menu listing "Espresso $2.50, Latte $3.75, Matcha $4.00". Include a small QR code in the corner and a timestamp "2025-01-08 18:30". Ultra-detailed, sharp typography, cinematic lighting, 4K.' \
  --negative-prompt " " \
  --width 1664 \
  --height 928 \
  --num-inference-steps 50 \
  --seed 42 \
  --output-file-name baseline-2gpu-b.png
[01-18 07:39:03] Disabling some offloading (except dit, text_encoder) for image generation model
[01-18 07:39:03] server_args: {"model_path": "Qwen/Qwen-Image", "backend": "diffusers", "attention_backend": null, "cache_dit_config": null, "nccl_port": null, "trust_remote_code": false, "revision": null, "num_gpus": 2, "tp_size": 1, "sp_degree": 2, "ulysses_degree": 2, "ring_degree": 1, "dp_size": 1, "dp_degree": 1, "enable_cfg_parallel": false, "hsdp_replicate_dim": 1, "hsdp_shard_dim": 2, "dist_timeout": null, "lora_path": null, "lora_nickname": "default", "vae_path": null, "lora_target_modules": null, "dit_cpu_offload": true, "dit_layerwise_offload": null, "text_encoder_cpu_offload": true, "image_encoder_cpu_offload": false, "vae_cpu_offload": false, "use_fsdp_inference": false, "pin_cpu_memory": true, "mask_strategy_file_path": null, "STA_mode": "STA_inference", "skip_time_steps": 15, "enable_torch_compile": false, "warmup": false, "warmup_resolutions": null, "disable_autocast": true, "VSA_sparsity": 0.0, "moba_config_path": null, "moba_config": {}, "master_port": 30083, "host": "127.0.0.1", "port": 30000, "webui": false, "webui_port": 12312, "scheduler_port": 5589, "prompt_file_path": null, "model_paths": {}, "model_loaded": {"transformer": true, "vae": true}, "boundary_ratio": null, "log_level": "info"}
[01-18 07:39:03] Local mode: True
[01-18 07:39:03] Starting server...
[01-18 07:39:12] Scheduler bind at endpoint: tcp://127.0.0.1:5589
[01-18 07:39:12] Initializing distributed environment with world_size=2, device=cuda:0
[01-18 07:39:13] Found nccl from library libnccl.so.2
[01-18 07:39:13] sglang-diffusion is using nccl==2.27.5
[01-18 07:39:13] Found nccl from library libnccl.so.2
[01-18 07:39:13] sglang-diffusion is using nccl==2.27.5
[01-18 07:39:14] Using diffusers backend for model 'Qwen/Qwen-Image' (explicitly requested)
[01-18 07:39:14] Loading diffusers pipeline from Qwen/Qwen-Image
[01-18 07:39:14] Checking for cached model in HF Hub cache for Qwen/Qwen-Image...
[01-18 07:39:14] Found complete model in cache at /workspace/.hf_home/hub/models--Qwen--Qwen-Image/snapshots/75e0b4be04f60ec59a75f475837eced720f823b6
[01-18 07:39:14] Loading diffusers pipeline with dtype=torch.bfloat16, device_map=cuda
Keyword arguments {'trust_remote_code': False} are not expected by QwenImagePipeline and will be ignored.
Loading pipeline components...:   0%|                                                                                                          | 0/5 [00:00<?, ?it/s]`torch_dtype` is deprecated! Use `dtype` instead!
Keyword arguments {'trust_remote_code': False} are not expected by QwenImagePipeline and will be ignored.
Loading pipeline components...:  20%|███████████████████▌                                                                              | 1/5 [00:00<00:00,  9.86it/s]The config attributes {'pooled_projection_dim': 768} were passed to QwenImageTransformer2DModel, but are not expected and will be ignored. Please verify your config.json configuration file.
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:03<00:00,  1.27it/s]
Loading pipeline components...:  20%|███████████████████▌                                                                              | 1/5 [00:03<00:13,  3.28s/it]The config attributes {'pooled_projection_dim': 768} were passed to QwenImageTransformer2DModel, but are not expected and will be ignored. Please verify your config.json configuration file.
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:07<00:00,  1.13it/s]
Loading pipeline components...:  80%|██████████████████████████████████████████████████████████████████████████████▍                   | 4/5 [00:08<00:02,  2.15s/it]`torch_dtype` is deprecated! Use `dtype` instead!
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:07<00:00,  1.14it/s]
Loading pipeline components...: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:11<00:00,  2.39s/it]
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:03<00:00,  1.28it/s]
Loading pipeline components...: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:12<00:00,  2.41s/it]
[01-18 07:39:26] Loaded diffusers pipeline: QwenImagePipeline
[01-18 07:39:26] Detected pipeline type: image
[01-18 07:39:26] Pipeline instantiated
[01-18 07:39:26] Worker 0: Initialized device, model, and distributed environment.
[01-18 07:39:26] Worker 0: Scheduler loop started.
[01-18 07:39:26] Processing prompt 1/1: A modern metro station poster wall with three ads: 1) a neon sign that reads "TOKYO MIDNIGHT - 24/7"
[01-18 07:39:26] [DiffusersExecutionStage] started...
guidance_scale is passed as 4.0, but ignored since the model is not guidance-distilled.
  0%|                                                                                                                                         | 0/50 [00:00<?, ?it/s]guidance_scale is passed as 4.0, but ignored since the model is not guidance-distilled.
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:45<00:00,  1.09it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:45<00:00,  1.09it/s]
[01-18 07:40:13] Extracted output from 'images': shape=torch.Size([1, 3, 928, 1664]), dtype=torch.bfloat16
[01-18 07:40:13] Final output tensor shape: torch.Size([1, 3, 928, 1664])
[01-18 07:40:13] [DiffusersExecutionStage] finished in 46.9359 seconds
[01-18 07:40:13] Peak GPU memory: 66.21 GB, Remaining GPU memory at peak: 29.38 GB. Components that can stay resident: []
[01-18 07:40:13] Output saved to outputs/baseline-2gpu-b.png
[01-18 07:40:13] Pixel data generated successfully in 47.65 seconds
[01-18 07:40:13] Completed batch processing. Generated 1 outputs in 47.65 seconds
[01-18 07:40:13] Memory usage - Max peak: 67801.88 MB, Avg peak: 67801.88 MB

Test 2 Cache-DiT

root@C.30179436:/workspace/sglang$ CUDA_VISIBLE_DEVICES=0,1 sglang generate \
  --model-path Qwen/Qwen-Image \
  --backend diffusers \
  --num-gpus 2 \
  --sp-degree 2 \
  --ulysses-degree 2 \
  --tp-size 1 \
  --cache-dit-config /tmp/cache_dit_config.yaml \
  --prompt 'A modern metro station poster wall with three ads: 1) a neon sign that reads "TOKYO MIDNIGHT - 24/7" in English, 2) a handwritten Chinese banner reading "欢迎来到未来城市", 3) a chalkboard menu listing "Espresso $2.50, Latte $3.75, Matcha $4.00". Include a small QR code in the corner and a timestamp "2025-01-08 18:30". Ultra-detailed, sharp typography, cinematic lighting, 4K.' \
  --negative-prompt " " \
  --width 1664 \
  --height 928 \
  --num-inference-steps 50 \
  --diffusers-kwargs '{"max_sequence_length":512}' \
  --output-file-name cachedit-2gpu-b.png
[01-18 07:42:55] Disabling some offloading (except dit, text_encoder) for image generation model
[01-18 07:42:55] server_args: {"model_path": "Qwen/Qwen-Image", "backend": "diffusers", "attention_backend": null, "cache_dit_config": "/tmp/cache_dit_config.yaml", "nccl_port": null, "trust_remote_code": false, "revision": null, "num_gpus": 2, "tp_size": 1, "sp_degree": 2, "ulysses_degree": 2, "ring_degree": 1, "dp_size": 1, "dp_degree": 1, "enable_cfg_parallel": false, "hsdp_replicate_dim": 1, "hsdp_shard_dim": 2, "dist_timeout": null, "lora_path": null, "lora_nickname": "default", "vae_path": null, "lora_target_modules": null, "dit_cpu_offload": true, "dit_layerwise_offload": null, "text_encoder_cpu_offload": true, "image_encoder_cpu_offload": false, "vae_cpu_offload": false, "use_fsdp_inference": false, "pin_cpu_memory": true, "mask_strategy_file_path": null, "STA_mode": "STA_inference", "skip_time_steps": 15, "enable_torch_compile": false, "warmup": false, "warmup_resolutions": null, "disable_autocast": true, "VSA_sparsity": 0.0, "moba_config_path": null, "moba_config": {}, "master_port": 30088, "host": "127.0.0.1", "port": 30000, "webui": false, "webui_port": 12312, "scheduler_port": 5603, "prompt_file_path": null, "model_paths": {}, "model_loaded": {"transformer": true, "vae": true}, "boundary_ratio": null, "log_level": "info"}
[01-18 07:42:55] Parsed diffusers_kwargs: {'max_sequence_length': 512}
[01-18 07:42:55] Local mode: True
[01-18 07:42:55] Starting server...
[01-18 07:43:05] Scheduler bind at endpoint: tcp://127.0.0.1:5603
[01-18 07:43:05] Initializing distributed environment with world_size=2, device=cuda:0
[01-18 07:43:06] Found nccl from library libnccl.so.2
[01-18 07:43:06] sglang-diffusion is using nccl==2.27.5
[01-18 07:43:06] Found nccl from library libnccl.so.2
[01-18 07:43:06] sglang-diffusion is using nccl==2.27.5
[01-18 07:43:06] Using diffusers backend for model 'Qwen/Qwen-Image' (explicitly requested)
[01-18 07:43:06] Loading diffusers pipeline from Qwen/Qwen-Image
[01-18 07:43:06] Checking for cached model in HF Hub cache for Qwen/Qwen-Image...
[01-18 07:43:06] Found complete model in cache at /workspace/.hf_home/hub/models--Qwen--Qwen-Image/snapshots/75e0b4be04f60ec59a75f475837eced720f823b6
[01-18 07:43:06] Loading diffusers pipeline with dtype=torch.bfloat16, device_map=cuda
Keyword arguments {'trust_remote_code': False} are not expected by QwenImagePipeline and will be ignored.
Loading pipeline components...:   0%|                                                                                                          | 0/5 [00:00<?, ?it/s]Keyword arguments {'trust_remote_code': False} are not expected by QwenImagePipeline and will be ignored.
Loading pipeline components...:   0%|                                                                                                          | 0/5 [00:00<?, ?it/s]The config attributes {'pooled_projection_dim': 768} were passed to QwenImageTransformer2DModel, but are not expected and will be ignored. Please verify your config.json configuration file.
Loading pipeline components...:  40%|███████████████████████████████████████▏                                                          | 2/5 [00:00<00:00,  5.84it/s]`torch_dtype` is deprecated! Use `dtype` instead!                                                                                              | 0/9 [00:00<?, ?it/s]
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:03<00:00,  1.29it/s]
Loading pipeline components...:  60%|██████████████████████████████████████████████████████████▊                                       | 3/5 [00:03<00:03,  1.56s/it]The config attributes {'pooled_projection_dim': 768} were passed to QwenImageTransformer2DModel, but are not expected and will be ignored. Please verify your config.json configuration file.
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:08<00:00,  1.11it/s]
Loading pipeline components...:  60%|██████████████████████████████████████████████████████████▊                                       | 3/5 [00:08<00:04,  2.23s/it]`torch_dtype` is deprecated! Use `dtype` instead!
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:07<00:00,  1.16it/s]
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:03<00:00,  1.28it/s]
Loading pipeline components...: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:11<00:00,  2.36s/it]
INFO 01-18 07:43:18 [config.py:50] Auto selected parallelism backend for transformer: Native_Diffuser
INFO 01-18 07:43:18 [cache_adapter.py:57] QwenImagePipeline is officially supported by cache-dit. Use it's pre-defined BlockAdapter directly!
INFO 01-18 07:43:18 [block_adapters.py:220] Auto fill blocks_name: transformer_blocks.
INFO 01-18 07:43:18 [block_adapters.py:504] Match Block Forward Pattern: QwenImageTransformerBlock, ForwardPattern.Pattern_1
INFO 01-18 07:43:18 [block_adapters.py:504] IN:('hidden_states', 'encoder_hidden_states'), OUT:('encoder_hidden_states', 'hidden_states'))
INFO 01-18 07:43:18 [cache_adapter.py:142] Use custom 'enable_separate_cfg' from BlockAdapter: True. Pipeline: QwenImagePipeline.
INFO 01-18 07:43:18 [cache_adapter.py:341] Collected Context Config: DBCache_F1B0_W8I2M0MC2_R0.12_CFG1, Calibrator Config: TaylorSeer_O(1)
INFO 01-18 07:43:18 [pattern_base.py:70] Match Blocks: CachedBlocks_Pattern_0_1_2, for transformer_blocks, cache_context: transformer_blocks_137489847202112, context_manager: QwenImagePipeline_137496387101728.
INFO 01-18 07:43:18 [block_adapters.py:220] Auto fill blocks_name: transformer_blocks.
Loading pipeline components...: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:11<00:00,  2.37s/it]
INFO 01-18 07:43:18 [_attention_dispatch.py:310] Re-registered NATIVE attention backend to enable context parallelism with attn mask in cache-dit. You can disable this behavior by: export CACHE_DIT_ENABLE_CUSTOM_ATTN_DISPATCH=0.
INFO 01-18 07:43:18 [_attention_dispatch.py:423] Registered new attention backend: _SDPA_CUDNN to enable context parallelism with attn mask in cache-dit. You can disable it by: export CACHE_DIT_ENABLE_CUSTOM_ATTN_DISPATCH=0.
INFO 01-18 07:43:18 [_attention_dispatch.py:478] Re-registered SAGE attention backend to enable context parallelism with FP8 Attention in cache-dit. You can disable this behavior by: export CACHE_DIT_ENABLE_CUSTOM_ATTN_DISPATCH=0.
INFO 01-18 07:43:18 [_attention_dispatch.py:637] Flash Attention 3 not available, skipping _FLASH_3 backend registration.
INFO 01-18 07:43:18 [_attention_dispatch.py:686] Re-registered _NATIVE_NPU attention backend to enable context parallelism You can disable this behavior by: export CACHE_DIT_ENABLE_CUSTOM_ATTN_DISPATCH=0.
`enable_parallelism` is an experimental feature. The API may change in the future and breaking changes may be introduced at any time without warning.
INFO 01-18 07:43:18 [config.py:50] Auto selected parallelism backend for transformer: Native_Diffuser
INFO 01-18 07:43:18 [cache_adapter.py:57] QwenImagePipeline is officially supported by cache-dit. Use it's pre-defined BlockAdapter directly!
INFO 01-18 07:43:18 [block_adapters.py:220] Auto fill blocks_name: transformer_blocks.
INFO 01-18 07:43:18 [block_adapters.py:504] Match Block Forward Pattern: QwenImageTransformerBlock, ForwardPattern.Pattern_1
INFO 01-18 07:43:18 [block_adapters.py:504] IN:('hidden_states', 'encoder_hidden_states'), OUT:('encoder_hidden_states', 'hidden_states'))
INFO 01-18 07:43:18 [cache_adapter.py:142] Use custom 'enable_separate_cfg' from BlockAdapter: True. Pipeline: QwenImagePipeline.
INFO 01-18 07:43:18 [cache_adapter.py:341] Collected Context Config: DBCache_F1B0_W8I2M0MC2_R0.12_CFG1, Calibrator Config: TaylorSeer_O(1)
INFO 01-18 07:43:18 [pattern_base.py:70] Match Blocks: CachedBlocks_Pattern_0_1_2, for transformer_blocks, cache_context: transformer_blocks_140101441558656, context_manager: QwenImagePipeline_140101487953392.
INFO 01-18 07:43:19 [block_adapters.py:220] Auto fill blocks_name: transformer_blocks.
INFO 01-18 07:43:19 [_attention_dispatch.py:310] Re-registered NATIVE attention backend to enable context parallelism with attn mask in cache-dit. You can disable this behavior by: export CACHE_DIT_ENABLE_CUSTOM_ATTN_DISPATCH=0.
INFO 01-18 07:43:19 [_attention_dispatch.py:423] Registered new attention backend: _SDPA_CUDNN to enable context parallelism with attn mask in cache-dit. You can disable it by: export CACHE_DIT_ENABLE_CUSTOM_ATTN_DISPATCH=0.
INFO 01-18 07:43:19 [_attention_dispatch.py:478] Re-registered SAGE attention backend to enable context parallelism with FP8 Attention in cache-dit. You can disable this behavior by: export CACHE_DIT_ENABLE_CUSTOM_ATTN_DISPATCH=0.
INFO 01-18 07:43:19 [_attention_dispatch.py:637] Flash Attention 3 not available, skipping _FLASH_3 backend registration.
INFO 01-18 07:43:19 [_attention_dispatch.py:686] Re-registered _NATIVE_NPU attention backend to enable context parallelism You can disable this behavior by: export CACHE_DIT_ENABLE_CUSTOM_ATTN_DISPATCH=0.
`enable_parallelism` is an experimental feature. The API may change in the future and breaking changes may be introduced at any time without warning.
INFO 01-18 07:43:19 [dispatch.py:75] Parallelize Transformer: QwenImageTransformer2DModel, id:137496386144928, ParallelismConfig(backend=Native_Diffuser, ulysses_size=2)
Attention backends are an experimental feature and the API may be subject to change.
INFO 01-18 07:43:19 [dispatch.py:75] Parallelize Transformer: QwenImageTransformer2DModel, id:140101441557456, ParallelismConfig(backend=Native_Diffuser, ulysses_size=2)
Attention backends are an experimental feature and the API may be subject to change.
INFO 01-18 07:43:19 [dispatch.py:153] Found attention_backend from config, set attention backend of QwenImageTransformer2DModel to: native.
INFO 01-18 07:43:19 [dispatch.py:153] Found attention_backend from config, set attention backend of QwenImageTransformer2DModel to: native.
INFO 01-18 07:43:21 [dispatch.py:36] Parallelize Text Encoder: Qwen2_5_VLForConditionalGeneration, id:137496391613328, ParallelismConfig(backend=Native_PyTorch, tp_size=2)
INFO 01-18 07:43:21 [dispatch.py:36] Parallelize Auto Encoder: AutoencoderKLQwenImage, id:137496390222528, ParallelismConfig(backend=Native_PyTorch, dp_size=2)
INFO 01-18 07:43:21 [dispatch.py:36] Parallelize Text Encoder: Qwen2_5_VLForConditionalGeneration, id:140101488564016, ParallelismConfig(backend=Native_PyTorch, tp_size=2)
INFO 01-18 07:43:21 [dispatch.py:36] Parallelize Auto Encoder: AutoencoderKLQwenImage, id:140101199941024, ParallelismConfig(backend=Native_PyTorch, dp_size=2)
[01-18 07:43:24] Enabled cache-dit for diffusers pipeline
[01-18 07:43:24] Loaded diffusers pipeline: QwenImagePipeline
[01-18 07:43:24] Detected pipeline type: image
[01-18 07:43:24] Pipeline instantiated
[01-18 07:43:24] Worker 0: Initialized device, model, and distributed environment.
[01-18 07:43:24] Worker 0: Scheduler loop started.
[01-18 07:43:24] Processing prompt 1/1: A modern metro station poster wall with three ads: 1) a neon sign that reads "TOKYO MIDNIGHT - 24/7"
[01-18 07:43:24] [DiffusersExecutionStage] started...
guidance_scale is passed as 4.0, but ignored since the model is not guidance-distilled.
  0%|                                                                                                                                         | 0/50 [00:00<?, ?it/s]guidance_scale is passed as 4.0, but ignored since the model is not guidance-distilled.
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:16<00:00,  2.99it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:16<00:00,  2.99it/s]
[rank1]:[W118 07:43:44.165188675 ProcessGroupNCCL.cpp:3961] Warning: [PG ID 0 PG GUID 0(default_pg) Rank 1] An unbatched P2P op (send/recv) was called on this ProcessGroup with size 2.  In eager initialization mode, unbatched P2P ops are treated as independent collective ops, and are thus serialized with all other ops on this ProcessGroup, including other P2P ops. To avoid serialization, either create additional independent ProcessGroups for the P2P ops or use batched P2P ops. You can squash this warning by setting the environment variable TORCH_NCCL_SHOW_EAGER_INIT_P2P_SERIALIZATION_WARNING to false. (function operator())
[rank0]:[W118 07:43:44.235758707 ProcessGroupNCCL.cpp:3961] Warning: [PG ID 0 PG GUID 0(default_pg) Rank 0] An unbatched P2P op (send/recv) was called on this ProcessGroup with size 2.  In eager initialization mode, unbatched P2P ops are treated as independent collective ops, and are thus serialized with all other ops on this ProcessGroup, including other P2P ops. To avoid serialization, either create additional independent ProcessGroups for the P2P ops or use batched P2P ops. You can squash this warning by setting the environment variable TORCH_NCCL_SHOW_EAGER_INIT_P2P_SERIALIZATION_WARNING to false. (function operator())
[rank1]:[W118 07:43:44.256721372 ProcessGroupNCCL.cpp:3961] Warning: [PG ID 0 PG GUID 0(default_pg) Rank 1] An unbatched P2P op (send/recv) was called on this ProcessGroup with size 2.  In eager initialization mode, unbatched P2P ops are treated as independent collective ops, and are thus serialized with all other ops on this ProcessGroup, including other P2P ops. To avoid serialization, either create additional independent ProcessGroups for the P2P ops or use batched P2P ops. You can squash this warning by setting the environment variable TORCH_NCCL_SHOW_EAGER_INIT_P2P_SERIALIZATION_WARNING to false. (function operator())
[rank0]:[W118 07:43:44.546492035 ProcessGroupNCCL.cpp:3961] Warning: [PG ID 0 PG GUID 0(default_pg) Rank 0] An unbatched P2P op (send/recv) was called on this ProcessGroup with size 2.  In eager initialization mode, unbatched P2P ops are treated as independent collective ops, and are thus serialized with all other ops on this ProcessGroup, including other P2P ops. To avoid serialization, either create additional independent ProcessGroups for the P2P ops or use batched P2P ops. You can squash this warning by setting the environment variable TORCH_NCCL_SHOW_EAGER_INIT_P2P_SERIALIZATION_WARNING to false. (function operator())
[01-18 07:43:44] Extracted output from 'images': shape=torch.Size([1, 3, 928, 1664]), dtype=torch.bfloat16
[01-18 07:43:44] Final output tensor shape: torch.Size([1, 3, 928, 1664])
[01-18 07:43:44] [DiffusersExecutionStage] finished in 19.9821 seconds
[01-18 07:43:44] Peak GPU memory: 48.40 GB, Remaining GPU memory at peak: 47.19 GB. Components that can stay resident: []
[01-18 07:43:45] Output saved to outputs/cachedit-2gpu-b.png
[01-18 07:43:45] Pixel data generated successfully in 20.69 seconds
[01-18 07:43:45] Completed batch processing. Generated 1 outputs in 20.69 seconds
[01-18 07:43:45] Memory usage - Max peak: 49564.07 MB, Avg peak: 49564.07 MB

@DefTruth
Contributor

Can you also share the test logs based on cache-dit v1.1.8?

@qimcis
Contributor Author

qimcis commented Jan 18, 2026

Can you also share the test logs based on cache-dit v1.1.8?

I don't have the logs from when I originally ran the tests, so I reran them on the previous commit; logs below:
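
For anyone reproducing these runs, the installed cache-dit version can be checked with a one-liner (assuming cache-dit was installed via pip under that distribution name):

python -c "import importlib.metadata as m; print(m.version('cache-dit'))"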

Test 1 Baseline

root@C.30179436:/workspace/sglang$ CUDA_VISIBLE_DEVICES=0,1 sglang generate \
  --model-path Qwen/Qwen-Image \
  --backend diffusers \
  --num-gpus 2 \
  --sp-degree 2 \
  --ulysses-degree 2 \
  --tp-size 1 \
  --prompt 'A coffee shop entrance features a chalkboard sign reading "Qwen Coffee $2 per cup," with a neon light beside it displaying "通义千问". Next to it hangs a poster showing a beautiful Chinese woman, and beneath the poster is written "π≈3.1415926-53589793-23846264-33832795-02384197". Ultra HD, 4K, cinematic composition.' \
  --negative-prompt " " \
  --width 1664 \
  --output-file-name baseline-2gpu-a.png
[01-18 09:04:23] Disabling some offloading (except dit, text_encoder) for image generation model
[01-18 09:04:23] server_args: {"model_path": "Qwen/Qwen-Image", "backend": "diffusers", "attention_backend": null, "diffusers_attention_backend": null, "cache_dit_config": null, "nccl_port": null, "trust_remote_code": false, "revision": null, "num_gpus": 2, "tp_size": 1, "sp_degree": 2, "ulysses_degree": 2, "ring_degree": 1, "dp_size": 1, "dp_degree": 1, "enable_cfg_parallel": false, "hsdp_replicate_dim": 1, "hsdp_shard_dim": 2, "dist_timeout": null, "lora_path": null, "lora_nickname": "default", "vae_path": null, "lora_target_modules": null, "dit_cpu_offload": true, "dit_layerwise_offload": null, "text_encoder_cpu_offload": true, "image_encoder_cpu_offload": false, "vae_cpu_offload": false, "use_fsdp_inference": false, "pin_cpu_memory": true, "mask_strategy_file_path": null, "STA_mode": "STA_inference", "skip_time_steps": 15, "enable_torch_compile": false, "warmup": false, "warmup_resolutions": null, "disable_autocast": true, "VSA_sparsity": 0.0, "moba_config_path": null, "moba_config": {}, "master_port": 30075, "host": "127.0.0.1", "port": 30000, "webui": false, "webui_port": 12312, "scheduler_port": 5636, "prompt_file_path": null, "model_paths": {}, "model_loaded": {"transformer": true, "vae": true}, "boundary_ratio": null, "log_level": "info"}
[01-18 09:04:23] Local mode: True
[01-18 09:04:23] Starting server...
[01-18 09:04:33] Scheduler bind at endpoint: tcp://127.0.0.1:5636
[01-18 09:04:33] Initializing distributed environment with world_size=2, device=cuda:0
[01-18 09:04:34] Found nccl from library libnccl.so.2
[01-18 09:04:34] sglang-diffusion is using nccl==2.27.5
[01-18 09:04:35] Found nccl from library libnccl.so.2
[01-18 09:04:35] sglang-diffusion is using nccl==2.27.5
[01-18 09:04:35] Using diffusers backend for model 'Qwen/Qwen-Image' (explicitly requested)
[01-18 09:04:35] Loading diffusers pipeline from Qwen/Qwen-Image
[01-18 09:04:35] Checking for cached model in HF Hub cache for Qwen/Qwen-Image...
[01-18 09:04:35] Found complete model in cache at /workspace/.hf_home/hub/models--Qwen--Qwen-Image/snapshots/75e0b4be04f60ec59a75f475837eced720f823b6
[01-18 09:04:35] Loading diffusers pipeline with dtype=torch.bfloat16, device_map=cuda
Keyword arguments {'trust_remote_code': False} are not expected by QwenImagePipeline and will be ignored.
Loading pipeline components...:   0%|                                                                                             | 0/5 [00:00<?, ?it/s]`torch_dtype` is deprecated! Use `dtype` instead!
Keyword arguments {'trust_remote_code': False} are not expected by QwenImagePipeline and will be ignored.
Loading pipeline components...:  20%|█████████████████                                                                    | 1/5 [00:00<00:00,  4.30it/s]The config attributes {'pooled_projection_dim': 768} were passed to QwenImageTransformer2DModel, but are not expected and will be ignored. Please verify your config.json configuration file.
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:03<00:00,  1.22it/s]
Loading pipeline components...:  60%|███████████████████████████████████████████████████                                  | 3/5 [00:03<00:01,  1.04it/s]The config attributes {'pooled_projection_dim': 768} were passed to QwenImageTransformer2DModel, but are not expected and will be ignored. Please verify your config.json configuration file.
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:08<00:00,  1.04it/s]
Loading pipeline components...:  40%|██████████████████████████████████                                                   | 2/5 [00:09<00:15,  5.30s/it]`torch_dtype` is deprecated! Use `dtype` instead!
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:08<00:00,  1.10it/s]
Loading pipeline components...: 100%|█████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:12<00:00,  2.48s/it]
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:03<00:00,  1.22it/s]
Loading pipeline components...: 100%|█████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:12<00:00,  2.51s/it]
[01-18 09:04:47] Loaded diffusers pipeline: QwenImagePipeline
[01-18 09:04:47] Detected pipeline type: image
[01-18 09:04:47] Pipeline instantiated
[01-18 09:04:47] Worker 0: Initialized device, model, and distributed environment.
[01-18 09:04:47] Worker 0: Scheduler loop started.
[01-18 09:04:47] Processing prompt 1/1: A coffee shop entrance features a chalkboard sign reading "Qwen Coffee $2 per cup," with a neon ligh
[01-18 09:04:47] [DiffusersExecutionStage] started...
guidance_scale is passed as 4.0, but ignored since the model is not guidance-distilled.
  0%|                                                                                                                            | 0/50 [00:00<?, ?it/s]guidance_scale is passed as 4.0, but ignored since the model is not guidance-distilled.
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:46<00:00,  1.08it/s]
[01-18 09:05:35] Extracted output from 'images': shape=torch.Size([1, 3, 928, 1664]), dtype=torch.bfloat16
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:46<00:00,  1.07it/s]
[01-18 09:05:35] Final output tensor shape: torch.Size([1, 3, 928, 1664])
[01-18 09:05:35] [DiffusersExecutionStage] finished in 47.5436 seconds
[01-18 09:05:35] Peak GPU memory: 66.22 GB, Remaining GPU memory at peak: 29.38 GB. Components that can stay resident: []
[01-18 09:05:36] Output saved to outputs/baseline-2gpu-a.png
[01-18 09:05:36] Pixel data generated successfully in 48.24 seconds
[01-18 09:05:36] Completed batch processing. Generated 1 outputs in 48.24 seconds
[01-18 09:05:36] Memory usage - Max peak: 67804.48 MB, Avg peak: 67804.48 MB

Test 1 Cache-DiT

root@C.30179436:/workspace/sglang$ CUDA_VISIBLE_DEVICES=0,1 sglang generate \
  --model-path Qwen/Qwen-Image \
  --backend diffusers \
  --num-gpus 2 \
  --sp-degree 2 \
  --ulysses-degree 2 \
  --tp-size 1 \
  --cache-dit-config /tmp/cache_dit_config.yaml \
  --prompt 'A coffee shop entrance features a chalkboard sign reading "Qwen Coffee $2 per cup," with a neon light beside it displaying "通义千问". Next to it hangs a poster showing a beautiful Chinese woman, and beneath the poster is written "π≈3.1415926-53589793-23846264-33832795-02384197". Ultra HD, 4K, cinematic composition.' \
  --negative-prompt " " \
  --diffusers-kwargs '{"max_sequence_length":512}' \
  --output-file-name cachedit-2gpu-a.png
[01-18 09:08:05] Disabling some offloading (except dit, text_encoder) for image generation model
[01-18 09:08:05] server_args: {"model_path": "Qwen/Qwen-Image", "backend": "diffusers", "attention_backend": null, "diffusers_attention_backend": null, "cache_dit_config": "/tmp/cache_dit_config.yaml", "nccl_port": null, "trust_remote_code": false, "revision": null, "num_gpus": 2, "tp_size": 1, "sp_degree": 2, "ulysses_degree": 2, "ring_degree": 1, "dp_size": 1, "dp_degree": 1, "enable_cfg_parallel": false, "hsdp_replicate_dim": 1, "hsdp_shard_dim": 2, "dist_timeout": null, "lora_path": null, "lora_nickname": "default", "vae_path": null, "lora_target_modules": null, "dit_cpu_offload": true, "dit_layerwise_offload": null, "text_encoder_cpu_offload": true, "image_encoder_cpu_offload": false, "vae_cpu_offload": false, "use_fsdp_inference": false, "pin_cpu_memory": true, "mask_strategy_file_path": null, "STA_mode": "STA_inference", "skip_time_steps": 15, "enable_torch_compile": false, "warmup": false, "warmup_resolutions": null, "disable_autocast": true, "VSA_sparsity": 0.0, "moba_config_path": null, "moba_config": {}, "master_port": 30060, "host": "127.0.0.1", "port": 30000, "webui": false, "webui_port": 12312, "scheduler_port": 5637, "prompt_file_path": null, "model_paths": {}, "model_loaded": {"transformer": true, "vae": true}, "boundary_ratio": null, "log_level": "info"}
[01-18 09:08:05] Parsed diffusers_kwargs: {'max_sequence_length': 512}
[01-18 09:08:05] Local mode: True
[01-18 09:08:05] Starting server...
[01-18 09:08:14] Scheduler bind at endpoint: tcp://127.0.0.1:5637
[01-18 09:08:14] Initializing distributed environment with world_size=2, device=cuda:0
[01-18 09:08:15] Found nccl from library libnccl.so.2
[01-18 09:08:15] sglang-diffusion is using nccl==2.27.5
[01-18 09:08:16] Found nccl from library libnccl.so.2
[01-18 09:08:16] sglang-diffusion is using nccl==2.27.5
[01-18 09:08:16] Using diffusers backend for model 'Qwen/Qwen-Image' (explicitly requested)
[01-18 09:08:16] Loading diffusers pipeline from Qwen/Qwen-Image
[01-18 09:08:16] Checking for cached model in HF Hub cache for Qwen/Qwen-Image...
[01-18 09:08:16] Found complete model in cache at /workspace/.hf_home/hub/models--Qwen--Qwen-Image/snapshots/75e0b4be04f60ec59a75f475837eced720f823b6
[01-18 09:08:16] Loading diffusers pipeline with dtype=torch.bfloat16, device_map=cuda
Keyword arguments {'trust_remote_code': False} are not expected by QwenImagePipeline and will be ignored.
Loading pipeline components...:   0%|                                                                                             | 0/5 [00:00<?, ?it/s]Keyword arguments {'trust_remote_code': False} are not expected by QwenImagePipeline and will be ignored.
Loading pipeline components...:  20%|█████████████████                                                                    | 1/5 [00:00<00:00,  4.29it/s]`torch_dtype` is deprecated! Use `dtype` instead!
Loading pipeline components...:  40%|██████████████████████████████████                                                   | 2/5 [00:00<00:00,  8.39it/s`torch_dtype` is deprecated! Use `dtype` instead!                                                                                  | 0/4 [00:00<?, ?it/s]
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:03<00:00,  1.31it/s]
Loading pipeline components...:  40%|██████████████████████████████████                                                   | 2/5 [00:03<00:05,  1.97s/it]The config attributes {'pooled_projection_dim': 768} were passed to QwenImageTransformer2DModel, but are not expected and will be ignored. Please verify your config.json configuration file.
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:03<00:00,  1.22it/s]
Loading pipeline components...:  80%|████████████████████████████████████████████████████████████████████                 | 4/5 [00:03<00:01,  1.07s/it]The config attributes {'pooled_projection_dim': 768} were passed to QwenImageTransformer2DModel, but are not expected and will be ignored. Please verify your config.json configuration file.
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:07<00:00,  1.14it/s]
Loading pipeline components...: 100%|█████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:11<00:00,  2.39s/it]
INFO 01-18 09:08:28 [cache_adapter.py:49] QwenImagePipeline is officially supported by cache-dit. Use it's pre-defined BlockAdapter directly!
INFO 01-18 09:08:28 [block_adapters.py:220] Auto fill blocks_name: transformer_blocks.
INFO 01-18 09:08:28 [block_adapters.py:494] Match Block Forward Pattern: QwenImageTransformerBlock, ForwardPattern.Pattern_1
INFO 01-18 09:08:28 [block_adapters.py:494] IN:('hidden_states', 'encoder_hidden_states'), OUT:('encoder_hidden_states', 'hidden_states'))
INFO 01-18 09:08:28 [cache_adapter.py:134] Use custom 'enable_separate_cfg' from BlockAdapter: True. Pipeline: QwenImagePipeline.
INFO 01-18 09:08:28 [cache_adapter.py:332] Collected Context Config: DBCache_F1B0_W8I2M0MC2_R0.12_CFG1, Calibrator Config: TaylorSeer_O(1)
INFO 01-18 09:08:28 [pattern_base.py:70] Match Blocks: CachedBlocks_Pattern_0_1_2, for transformer_blocks, cache_context: transformer_blocks_137794046086992, context_manager: QwenImagePipeline_137800587914240.
INFO 01-18 09:08:28 [block_adapters.py:220] Auto fill blocks_name: transformer_blocks.
WARNING 01-18 09:08:28 [_attention_dispatch.py:276] Re-registered NATIVE attention backend to enable context parallelism with attn mask in cache-dit. You can disable this behavior by: export CACHE_DIT_ENABLE_CUSTOM_ATTN_DISPATCH=0.
INFO 01-18 09:08:28 [_attention_dispatch.py:389] Registered new attention backend: _SDPA_CUDNN to enable context parallelism with attn mask in cache-dit. You can disable it by: export CACHE_DIT_ENABLE_CUSTOM_ATTN_DISPATCH=0.
WARNING 01-18 09:08:28 [_attention_dispatch.py:444] Re-registered SAGE attention backend to enable context parallelism with FP8 Attention in cache-dit. You can disable this behavior by: export CACHE_DIT_ENABLE_CUSTOM_ATTN_DISPATCH=0.
Attention backends are an experimental feature and the API may be subject to change.
INFO 01-18 09:08:28 [__init__.py:85] Found attention_backend from config, set attention backend to: native
`enable_parallelism` is an experimental feature. The API may change in the future and breaking changes may be introduced at any time without warning.
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:08<00:00,  1.10it/s]
Loading pipeline components...: 100%|█████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:12<00:00,  2.48s/it]
INFO 01-18 09:08:28 [cache_adapter.py:49] QwenImagePipeline is officially supported by cache-dit. Use it's pre-defined BlockAdapter directly!
INFO 01-18 09:08:28 [block_adapters.py:220] Auto fill blocks_name: transformer_blocks.
INFO 01-18 09:08:28 [block_adapters.py:494] Match Block Forward Pattern: QwenImageTransformerBlock, ForwardPattern.Pattern_1
INFO 01-18 09:08:28 [block_adapters.py:494] IN:('hidden_states', 'encoder_hidden_states'), OUT:('encoder_hidden_states', 'hidden_states'))
INFO 01-18 09:08:28 [cache_adapter.py:134] Use custom 'enable_separate_cfg' from BlockAdapter: True. Pipeline: QwenImagePipeline.
INFO 01-18 09:08:28 [cache_adapter.py:332] Collected Context Config: DBCache_F1B0_W8I2M0MC2_R0.12_CFG1, Calibrator Config: TaylorSeer_O(1)
INFO 01-18 09:08:28 [pattern_base.py:70] Match Blocks: CachedBlocks_Pattern_0_1_2, for transformer_blocks, cache_context: transformer_blocks_133362747347872, context_manager: QwenImagePipeline_133369519855936.
INFO 01-18 09:08:28 [block_adapters.py:220] Auto fill blocks_name: transformer_blocks.
WARNING 01-18 09:08:28 [_attention_dispatch.py:276] Re-registered NATIVE attention backend to enable context parallelism with attn mask in cache-dit. You can disable this behavior by: export CACHE_DIT_ENABLE_CUSTOM_ATTN_DISPATCH=0.
INFO 01-18 09:08:28 [_attention_dispatch.py:389] Registered new attention backend: _SDPA_CUDNN to enable context parallelism with attn mask in cache-dit. You can disable it by: export CACHE_DIT_ENABLE_CUSTOM_ATTN_DISPATCH=0.
WARNING 01-18 09:08:28 [_attention_dispatch.py:444] Re-registered SAGE attention backend to enable context parallelism with FP8 Attention in cache-dit. You can disable this behavior by: export CACHE_DIT_ENABLE_CUSTOM_ATTN_DISPATCH=0.
Attention backends are an experimental feature and the API may be subject to change.
INFO 01-18 09:08:28 [__init__.py:85] Found attention_backend from config, set attention backend to: native
`enable_parallelism` is an experimental feature. The API may change in the future and breaking changes may be introduced at any time without warning.
INFO 01-18 09:08:28 [parallel_interface.py:48] Enabled parallelism: ParallelismConfig(backend=ParallelismBackend.NATIVE_DIFFUSER, ulysses_size=2, ring_size=None, tp_size=None), transformer id:133362747343744
INFO 01-18 09:08:28 [parallel_interface.py:48] Enabled parallelism: ParallelismConfig(backend=ParallelismBackend.NATIVE_DIFFUSER, ulysses_size=2, ring_size=None, tp_size=None), transformer id:137800583378848
[01-18 09:08:31] Enabled cache-dit for diffusers pipeline
[01-18 09:08:31] Loaded diffusers pipeline: QwenImagePipeline
[01-18 09:08:31] Detected pipeline type: image
[01-18 09:08:31] Pipeline instantiated
[01-18 09:08:31] Worker 0: Initialized device, model, and distributed environment.
[01-18 09:08:31] Worker 0: Scheduler loop started.
[01-18 09:08:31] Processing prompt 1/1: A coffee shop entrance features a chalkboard sign reading "Qwen Coffee $2 per cup," with a neon ligh
[01-18 09:08:31] [DiffusersExecutionStage] started...
guidance_scale is passed as 4.0, but ignored since the model is not guidance-distilled.
  0%|                                                                                                                            | 0/50 [00:00<?, ?it/s]guidance_scale is passed as 4.0, but ignored since the model is not guidance-distilled.
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:17<00:00,  2.93it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:17<00:00,  2.93it/s]
[01-18 09:08:49] Extracted output from 'images': shape=torch.Size([1, 3, 928, 1664]), dtype=torch.bfloat16
[01-18 09:08:49] Final output tensor shape: torch.Size([1, 3, 928, 1664])
[01-18 09:08:49] [DiffusersExecutionStage] finished in 18.3559 seconds
[01-18 09:08:49] Peak GPU memory: 66.38 GB, Remaining GPU memory at peak: 29.21 GB. Components that can stay resident: []
[01-18 09:08:50] Output saved to outputs/cachedit-2gpu-a.png
[01-18 09:08:50] Pixel data generated successfully in 19.05 seconds
[01-18 09:08:50] Completed batch processing. Generated 1 outputs in 19.05 seconds
[01-18 09:08:50] Memory usage - Max peak: 67976.11 MB, Avg peak: 67976.11 MB

Test 2 Baseline

root@C.30179436:/workspace/sglang$ CUDA_VISIBLE_DEVICES=0,1 sglang generate \
  --model-path Qwen/Qwen-Image \
  --backend diffusers \
  --num-gpus 2 \
  --sp-degree 2 \
  --ulysses-degree 2 \
  --tp-size 1 \
  --prompt 'A modern metro station poster wall with three ads: 1) a neon sign that reads "TOKYO MIDNIGHT - 24/7" in English, 2) a handwritten Chinese banner reading "欢迎来到未来城市", 3) a chalkboard menu listing "Espresso $2.50, Latte $3.75, Matcha $4.00". Include a small QR code in the corner and a timestamp "2025-01-08 18:30". Ultra-detailed, sharp typography, cinematic lighting, 4K.' \
  --negative-prompt " " \
  --width 1664 \
  --output-file-name baseline-2gpu-b.png
[01-18 09:10:35] Disabling some offloading (except dit, text_encoder) for image generation model
[01-18 09:10:35] server_args: {"model_path": "Qwen/Qwen-Image", "backend": "diffusers", "attention_backend": null, "diffusers_attention_backend": null, "cache_dit_config": null, "nccl_port": null, "trust_remote_code": false, "revision": null, "num_gpus": 2, "tp_size": 1, "sp_degree": 2, "ulysses_degree": 2, "ring_degree": 1, "dp_size": 1, "dp_degree": 1, "enable_cfg_parallel": false, "hsdp_replicate_dim": 1, "hsdp_shard_dim": 2, "dist_timeout": null, "lora_path": null, "lora_nickname": "default", "vae_path": null, "lora_target_modules": null, "dit_cpu_offload": true, "dit_layerwise_offload": null, "text_encoder_cpu_offload": true, "image_encoder_cpu_offload": false, "vae_cpu_offload": false, "use_fsdp_inference": false, "pin_cpu_memory": true, "mask_strategy_file_path": null, "STA_mode": "STA_inference", "skip_time_steps": 15, "enable_torch_compile": false, "warmup": false, "warmup_resolutions": null, "disable_autocast": true, "VSA_sparsity": 0.0, "moba_config_path": null, "moba_config": {}, "master_port": 30047, "host": "127.0.0.1", "port": 30000, "webui": false, "webui_port": 12312, "scheduler_port": 5641, "prompt_file_path": null, "model_paths": {}, "model_loaded": {"transformer": true, "vae": true}, "boundary_ratio": null, "log_level": "info"}
[01-18 09:10:35] Local mode: True
[01-18 09:10:35] Starting server...
[01-18 09:10:44] Scheduler bind at endpoint: tcp://127.0.0.1:5641
[01-18 09:10:44] Initializing distributed environment with world_size=2, device=cuda:0
[01-18 09:10:45] Found nccl from library libnccl.so.2
[01-18 09:10:45] sglang-diffusion is using nccl==2.27.5
[01-18 09:10:46] Found nccl from library libnccl.so.2
[01-18 09:10:46] sglang-diffusion is using nccl==2.27.5
[01-18 09:10:46] Using diffusers backend for model 'Qwen/Qwen-Image' (explicitly requested)
[01-18 09:10:46] Loading diffusers pipeline from Qwen/Qwen-Image
[01-18 09:10:46] Checking for cached model in HF Hub cache for Qwen/Qwen-Image...
[01-18 09:10:46] Found complete model in cache at /workspace/.hf_home/hub/models--Qwen--Qwen-Image/snapshots/75e0b4be04f60ec59a75f475837eced720f823b6
[01-18 09:10:46] Loading diffusers pipeline with dtype=torch.bfloat16, device_map=cuda
Keyword arguments {'trust_remote_code': False} are not expected by QwenImagePipeline and will be ignored.
Loading pipeline components...:   0%|                                                                                             | 0/5 [00:00<?, ?it/s]`torch_dtype` is deprecated! Use `dtype` instead!
Keyword arguments {'trust_remote_code': False} are not expected by QwenImagePipeline and will be ignored.
Loading pipeline components...:  40%|██████████████████████████████████                                                   | 2/5 [00:00<00:00,  6.10it/s]`torch_dtype` is deprecated! Use `dtype` instead!                                                                                 | 0/4 [00:00<?, ?it/s]
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:03<00:00,  1.33it/s]
Loading pipeline components...:  60%|███████████████████████████████████████████████████                                  | 3/5 [00:03<00:01,  1.12it/s]The config attributes {'pooled_projection_dim': 768} were passed to QwenImageTransformer2DModel, but are not expected and will be ignored. Please verify your config.json configuration file.
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:03<00:00,  1.27it/s]
Loading pipeline components...:  80%|████████████████████████████████████████████████████████████████████                 | 4/5 [00:03<00:01,  1.09s/it]The config attributes {'pooled_projection_dim': 768} were passed to QwenImageTransformer2DModel, but are not expected and will be ignored. Please verify your config.json configuration file.
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:07<00:00,  1.22it/s]
Loading pipeline components...: 100%|█████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:11<00:00,  2.27s/it]
[01-18 09:10:57] Loaded diffusers pipeline: QwenImagePipeline
[01-18 09:10:57] Detected pipeline type: image
[01-18 09:10:57] Pipeline instantiated
[01-18 09:10:57] Worker 0: Initialized device, model, and distributed environment.
[01-18 09:10:57] Worker 0: Scheduler loop started.
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:07<00:00,  1.16it/s]
Loading pipeline components...: 100%|█████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:11<00:00,  2.37s/it]
[01-18 09:10:58] Processing prompt 1/1: A modern metro station poster wall with three ads: 1) a neon sign that reads "TOKYO MIDNIGHT - 24/7"
[01-18 09:10:58] [DiffusersExecutionStage] started...
guidance_scale is passed as 4.0, but ignored since the model is not guidance-distilled.
  0%|                                                                                                                            | 0/50 [00:00<?, ?it/s]guidance_scale is passed as 4.0, but ignored since the model is not guidance-distilled.
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:45<00:00,  1.10it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:45<00:00,  1.09it/s]
[01-18 09:11:45] Extracted output from 'images': shape=torch.Size([1, 3, 928, 1664]), dtype=torch.bfloat16
[01-18 09:11:45] Final output tensor shape: torch.Size([1, 3, 928, 1664])
[01-18 09:11:45] [DiffusersExecutionStage] finished in 46.9141 seconds
[01-18 09:11:45] Peak GPU memory: 66.22 GB, Remaining GPU memory at peak: 29.38 GB. Components that can stay resident: []
[01-18 09:11:45] Output saved to outputs/baseline-2gpu-b.png
[01-18 09:11:45] Pixel data generated successfully in 47.67 seconds
[01-18 09:11:45] Completed batch processing. Generated 1 outputs in 47.67 seconds
[01-18 09:11:45] Memory usage - Max peak: 67804.28 MB, Avg peak: 67804.28 MB

Test 2 Cache-DiT

root@C.30179436:/workspace/sglang$ CUDA_VISIBLE_DEVICES=0,1 sglang generate \
  --model-path Qwen/Qwen-Image \
  --backend diffusers \
  --num-gpus 2 \
  --sp-degree 2 \
  --ulysses-degree 2 \
  --tp-size 1 \
  --cache-dit-config /tmp/cache_dit_config.yaml \
  --prompt 'A modern metro station poster wall with three ads: 1) a neon sign that reads "TOKYO MIDNIGHT - 24/7" in English, 2) a handwritten Chinese banner reading "欢迎来到未来城市", 3) a chalkboard menu listing "Espresso $2.50, Latte $3.75, Matcha $4.00". Include a small QR code in the corner and a timestamp "2025-01-08 18:30". Ultra-detailed, sharp typography, cinematic lighting, 4K.' \
  --negative-prompt " " \
  --diffusers-kwargs '{"max_sequence_length":512}' \
  --output-file-name cachedit-2gpu-b.png
[01-18 09:12:49] Disabling some offloading (except dit, text_encoder) for image generation model
[01-18 09:12:49] server_args: {"model_path": "Qwen/Qwen-Image", "backend": "diffusers", "attention_backend": null, "diffusers_attention_backend": null, "cache_dit_config": "/tmp/cache_dit_config.yaml", "nccl_port": null, "trust_remote_code": false, "revision": null, "num_gpus": 2, "tp_size": 1, "sp_degree": 2, "ulysses_degree": 2, "ring_degree": 1, "dp_size": 1, "dp_degree": 1, "enable_cfg_parallel": false, "hsdp_replicate_dim": 1, "hsdp_shard_dim": 2, "dist_timeout": null, "lora_path": null, "lora_nickname": "default", "vae_path": null, "lora_target_modules": null, "dit_cpu_offload": true, "dit_layerwise_offload": null, "text_encoder_cpu_offload": true, "image_encoder_cpu_offload": false, "vae_cpu_offload": false, "use_fsdp_inference": false, "pin_cpu_memory": true, "mask_strategy_file_path": null, "STA_mode": "STA_inference", "skip_time_steps": 15, "enable_torch_compile": false, "warmup": false, "warmup_resolutions": null, "disable_autocast": true, "VSA_sparsity": 0.0, "moba_config_path": null, "moba_config": {}, "master_port": 30040, "host": "127.0.0.1", "port": 30000, "webui": false, "webui_port": 12312, "scheduler_port": 5644, "prompt_file_path": null, "model_paths": {}, "model_loaded": {"transformer": true, "vae": true}, "boundary_ratio": null, "log_level": "info"}
[01-18 09:12:49] Parsed diffusers_kwargs: {'max_sequence_length': 512}
[01-18 09:12:49] Local mode: True
[01-18 09:12:49] Starting server...
[01-18 09:12:59] Scheduler bind at endpoint: tcp://127.0.0.1:5644
[01-18 09:12:59] Initializing distributed environment with world_size=2, device=cuda:0
[01-18 09:13:01] Found nccl from library libnccl.so.2
[01-18 09:13:01] sglang-diffusion is using nccl==2.27.5
[01-18 09:13:02] Found nccl from library libnccl.so.2
[01-18 09:13:02] sglang-diffusion is using nccl==2.27.5
[01-18 09:13:02] Using diffusers backend for model 'Qwen/Qwen-Image' (explicitly requested)
[01-18 09:13:02] Loading diffusers pipeline from Qwen/Qwen-Image
[01-18 09:13:02] Checking for cached model in HF Hub cache for Qwen/Qwen-Image...
[01-18 09:13:02] Found complete model in cache at /workspace/.hf_home/hub/models--Qwen--Qwen-Image/snapshots/75e0b4be04f60ec59a75f475837eced720f823b6
[01-18 09:13:02] Loading diffusers pipeline with dtype=torch.bfloat16, device_map=cuda
Keyword arguments {'trust_remote_code': False} are not expected by QwenImagePipeline and will be ignored.
Loading pipeline components...:   0%|                                                                                             | 0/5 [00:00<?, ?it/s]Keyword arguments {'trust_remote_code': False} are not expected by QwenImagePipeline and will be ignored.
Loading pipeline components...:   0%|                                                                                             | 0/5 [00:00<?, ?it/s]`torch_dtype` is deprecated! Use `dtype` instead!
`torch_dtype` is deprecated! Use `dtype` instead!
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:02<00:00,  1.40it/s]
Loading pipeline components...:  20%|█████████████████                                                                    | 1/5 [00:02<00:11,  2.98s/it]The config attributes {'pooled_projection_dim': 768} were passed to QwenImageTransformer2DModel, but are not expected and will be ignored. Please verify your config.json configuration file.
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:03<00:00,  1.31it/s]
Loading pipeline components...:  60%|███████████████████████████████████████████████████                                  | 3/5 [00:03<00:02,  1.09s/it]The config attributes {'pooled_projection_dim': 768} were passed to QwenImageTransformer2DModel, but are not expected and will be ignored. Please verify your config.json configuration file.
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:07<00:00,  1.23it/s]
Loading pipeline components...: 100%|█████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:11<00:00,  2.24s/it]
INFO 01-18 09:13:13 [cache_adapter.py:49] QwenImagePipeline is officially supported by cache-dit. Use it's pre-defined BlockAdapter directly!  1.06it/s]
INFO 01-18 09:13:13 [block_adapters.py:220] Auto fill blocks_name: transformer_blocks.
INFO 01-18 09:13:13 [block_adapters.py:494] Match Block Forward Pattern: QwenImageTransformerBlock, ForwardPattern.Pattern_1
INFO 01-18 09:13:13 [block_adapters.py:494] IN:('hidden_states', 'encoder_hidden_states'), OUT:('encoder_hidden_states', 'hidden_states'))
INFO 01-18 09:13:13 [cache_adapter.py:134] Use custom 'enable_separate_cfg' from BlockAdapter: True. Pipeline: QwenImagePipeline.
INFO 01-18 09:13:13 [cache_adapter.py:332] Collected Context Config: DBCache_F1B0_W8I2M0MC2_R0.12_CFG1, Calibrator Config: TaylorSeer_O(1)
INFO 01-18 09:13:13 [pattern_base.py:70] Match Blocks: CachedBlocks_Pattern_0_1_2, for transformer_blocks, cache_context: transformer_blocks_129031992183456, context_manager: QwenImagePipeline_129032127015024.
INFO 01-18 09:13:13 [block_adapters.py:220] Auto fill blocks_name: transformer_blocks.
WARNING 01-18 09:13:13 [_attention_dispatch.py:276] Re-registered NATIVE attention backend to enable context parallelism with attn mask in cache-dit. You can disable this behavior by: export CACHE_DIT_ENABLE_CUSTOM_ATTN_DISPATCH=0.
INFO 01-18 09:13:13 [_attention_dispatch.py:389] Registered new attention backend: _SDPA_CUDNN to enable context parallelism with attn mask in cache-dit. You can disable it by: export CACHE_DIT_ENABLE_CUSTOM_ATTN_DISPATCH=0.
WARNING 01-18 09:13:13 [_attention_dispatch.py:444] Re-registered SAGE attention backend to enable context parallelism with FP8 Attention in cache-dit. You can disable this behavior by: export CACHE_DIT_ENABLE_CUSTOM_ATTN_DISPATCH=0.
Attention backends are an experimental feature and the API may be subject to change.
INFO 01-18 09:13:13 [__init__.py:85] Found attention_backend from config, set attention backend to: native
`enable_parallelism` is an experimental feature. The API may change in the future and breaking changes may be introduced at any time without warning.
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:07<00:00,  1.18it/s]
Loading pipeline components...: 100%|█████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:11<00:00,  2.32s/it]
INFO 01-18 09:13:14 [cache_adapter.py:49] QwenImagePipeline is officially supported by cache-dit. Use it's pre-defined BlockAdapter directly!
INFO 01-18 09:13:14 [block_adapters.py:220] Auto fill blocks_name: transformer_blocks.
INFO 01-18 09:13:14 [block_adapters.py:494] Match Block Forward Pattern: QwenImageTransformerBlock, ForwardPattern.Pattern_1
INFO 01-18 09:13:14 [block_adapters.py:494] IN:('hidden_states', 'encoder_hidden_states'), OUT:('encoder_hidden_states', 'hidden_states'))
INFO 01-18 09:13:14 [cache_adapter.py:134] Use custom 'enable_separate_cfg' from BlockAdapter: True. Pipeline: QwenImagePipeline.
INFO 01-18 09:13:14 [cache_adapter.py:332] Collected Context Config: DBCache_F1B0_W8I2M0MC2_R0.12_CFG1, Calibrator Config: TaylorSeer_O(1)
INFO 01-18 09:13:14 [pattern_base.py:70] Match Blocks: CachedBlocks_Pattern_0_1_2, for transformer_blocks, cache_context: transformer_blocks_129840755869504, context_manager: QwenImagePipeline_129840788642016.
INFO 01-18 09:13:14 [block_adapters.py:220] Auto fill blocks_name: transformer_blocks.
WARNING 01-18 09:13:14 [_attention_dispatch.py:276] Re-registered NATIVE attention backend to enable context parallelism with attn mask in cache-dit. You can disable this behavior by: export CACHE_DIT_ENABLE_CUSTOM_ATTN_DISPATCH=0.
INFO 01-18 09:13:14 [_attention_dispatch.py:389] Registered new attention backend: _SDPA_CUDNN to enable context parallelism with attn mask in cache-dit. You can disable it by: export CACHE_DIT_ENABLE_CUSTOM_ATTN_DISPATCH=0.
WARNING 01-18 09:13:14 [_attention_dispatch.py:444] Re-registered SAGE attention backend to enable context parallelism with FP8 Attention in cache-dit. You can disable this behavior by: export CACHE_DIT_ENABLE_CUSTOM_ATTN_DISPATCH=0.
Attention backends are an experimental feature and the API may be subject to change.
INFO 01-18 09:13:14 [__init__.py:85] Found attention_backend from config, set attention backend to: native
`enable_parallelism` is an experimental feature. The API may change in the future and breaking changes may be introduced at any time without warning.
INFO 01-18 09:13:14 [parallel_interface.py:48] Enabled parallelism: ParallelismConfig(backend=ParallelismBackend.NATIVE_DIFFUSER, ulysses_size=2, ring_size=None, tp_size=None), transformer id:129840758619792
INFO 01-18 09:13:14 [parallel_interface.py:48] Enabled parallelism: ParallelismConfig(backend=ParallelismBackend.NATIVE_DIFFUSER, ulysses_size=2, ring_size=None, tp_size=None), transformer id:129031992282912
[01-18 09:13:16] Enabled cache-dit for diffusers pipeline
[01-18 09:13:16] Loaded diffusers pipeline: QwenImagePipeline
[01-18 09:13:16] Detected pipeline type: image
[01-18 09:13:16] Pipeline instantiated
[01-18 09:13:16] Worker 0: Initialized device, model, and distributed environment.
[01-18 09:13:16] Worker 0: Scheduler loop started.
[01-18 09:13:16] Processing prompt 1/1: A modern metro station poster wall with three ads: 1) a neon sign that reads "TOKYO MIDNIGHT - 24/7"
[01-18 09:13:16] [DiffusersExecutionStage] started...
guidance_scale is passed as 4.0, but ignored since the model is not guidance-distilled.
  0%|                                                                                                                            | 0/50 [00:00<?, ?it/s]guidance_scale is passed as 4.0, but ignored since the model is not guidance-distilled.
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:16<00:00,  2.99it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:16<00:00,  2.99it/s]
[01-18 09:13:34] Extracted output from 'images': shape=torch.Size([1, 3, 928, 1664]), dtype=torch.bfloat16
[01-18 09:13:34] Final output tensor shape: torch.Size([1, 3, 928, 1664])
[01-18 09:13:34] [DiffusersExecutionStage] finished in 17.9508 seconds
[01-18 09:13:34] Peak GPU memory: 66.39 GB, Remaining GPU memory at peak: 29.21 GB. Components that can stay resident: []
[01-18 09:13:35] Output saved to outputs/cachedit-2gpu-b.png
[01-18 09:13:35] Pixel data generated successfully in 18.63 seconds
[01-18 09:13:35] Completed batch processing. Generated 1 outputs in 18.63 seconds
[01-18 09:13:35] Memory usage - Max peak: 67980.80 MB, Avg peak: 67980.80 MB

@DefTruth
Contributor

@qimcis @mickqian Comparing the time spent in the transformer, I believe the results are as expected, since Cache-DiT only affects the transformer modules. What do you think?

cache-dit v1.1.8

# Test 1 Cache-DiT, Transformer ~ 17s
100%|█████████████████████████████████████████████████████████| 50/50 [00:17<00:00,  2.93it/s]
100%|█████████████████████████████████████████████████████████| 50/50 [00:17<00:00,  2.93it/s]
# Test 2 Cache-DiT, Transformer ~ 16s
100%|█████████████████████████████████████████████████████████| 50/50 [00:16<00:00,  2.99it/s]
100%|█████████████████████████████████████████████████████████| 50/50 [00:16<00:00,  2.99it/s]

cache-dit v1.2.0

# Test 1 Cache-DiT, Transformer ~ 17s
100%|█████████████████████████████████████████████████████████| 50/50 [00:17<00:00,  2.94it/s]
100%|█████████████████████████████████████████████████████████| 50/50 [00:17<00:00,  2.94it/s]
# Test 2 Cache-DiT, Transformer ~ 16s
100%|█████████████████████████████████████████████████████████| 50/50 [00:16<00:00,  2.99it/s]
100%|█████████████████████████████████████████████████████████| 50/50 [00:16<00:00,  2.99it/s]
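
(A quick sanity check from the logs above: the denoising loop takes roughly steps / throughput, so Test 1 goes from 50 / 1.08 it/s ≈ 46.3 s at baseline to 50 / 2.93 it/s ≈ 17.1 s with Cache-DiT, about 2.7x in the transformer loop. End to end the same runs go from 48.24 s to 19.05 s, about 2.5x; the small gap is the fixed cost outside the transformer, such as the text encoder, VAE decode, and saving output, which Cache-DiT does not touch.)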

Contributor

@DefTruth DefTruth left a comment


LGTM~

@qimcis
Contributor Author

qimcis commented Jan 18, 2026

@qimcis @mickqian Comparing the time spent in the transformer, I believe the results are as expected, since Cache-DiT only affects the transformer modules. What do you think?

cache-dit v1.1.8

# Test 1 Cache-DiT, Transformer ~ 17s
100%|█████████████████████████████████████████████████████████| 50/50 [00:17<00:00,  2.93it/s]
100%|█████████████████████████████████████████████████████████| 50/50 [00:17<00:00,  2.93it/s]
# Test 2 Cache-DiT, Transformer ~ 16s
100%|█████████████████████████████████████████████████████████| 50/50 [00:16<00:00,  2.99it/s]
100%|█████████████████████████████████████████████████████████| 50/50 [00:16<00:00,  2.99it/s]

cache-dit v1.2.0

# Test 1 Cache-DiT, Transformer ~ 17s
100%|█████████████████████████████████████████████████████████| 50/50 [00:17<00:00,  2.94it/s]
100%|█████████████████████████████████████████████████████████| 50/50 [00:17<00:00,  2.94it/s]
# Test 2 Cache-DiT, Transformer ~ 16s
100%|█████████████████████████████████████████████████████████| 50/50 [00:16<00:00,  2.99it/s]
100%|█████████████████████████████████████████████████████████| 50/50 [00:16<00:00,  2.99it/s]

This looks like it could be the case!

@qimcis qimcis requested a review from mickqian January 18, 2026 13:12
@DefTruth
Contributor

@mickqian do you have time to take a look? thanks~

@DefTruth
Contributor

DefTruth commented Jan 21, 2026

@qimcis Hi, can you also test the config without the parallel text encoder and parallel VAE on this latest commit? Just remove extra_parallel_modules: ["text_encoder", "vae"] from the config. By the way, don't forget to add the warmup flag --warmup=True; I think the performance difference is also related to warmup not being enabled.
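
For reference, the trimmed parallelism section would look something like this (a sketch assuming the same keys used in the runs above; the cache_config section stays unchanged):

parallelism_config:
  ulysses_size: 2
  parallel_kwargs:
    attention_backend: native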

Collaborator

@mickqian mickqian left a comment


Brilliant! Please consider adding tests next.

@mickqian
Collaborator

mickqian commented Jan 21, 2026

/tag-and-rerun-ci

@qimcis
Contributor Author

qimcis commented Jan 21, 2026

@qimcis Hi, can you also test the config without the parallel text encoder and parallel VAE on this latest commit? Just remove extra_parallel_modules: ["text_encoder", "vae"] from the config. By the way, don't forget to add the warmup flag --warmup=True; I think the performance difference is also related to warmup not being enabled.

Tested the most recent commit on cache-dit v1.2.0 without the parallel text encoder and parallel VAE, with the --warmup True flag added:

[01-21 04:49:00] Pixel data generated successfully in 45.62 seconds
[01-21 04:49:00] Completed batch processing. Generated 1 outputs in 45.62 seconds
[01-21 04:49:00] Warmed-up request processed in 41.79 seconds (with warmup excluded)
[01-21 04:49:00] Memory usage - Max peak: 67977.43 MB, Avg peak: 67977.43 MB

Signed-off-by: Chi <chixie.mcisaac@gmail.com>
Signed-off-by: qimcis <chixie.mcisaac@gmail.com>
@mickqian mickqian force-pushed the cache-dit-diffusers branch from e2bec4d to 073f328 Compare January 21, 2026 13:20
@qimcis
Copy link
Copy Markdown
Contributor Author

qimcis commented Jan 21, 2026

Brilliant! Please consider adding tests next.

Should I add tests in this PR, or make a separate PR for them?

@mickqian
Collaborator

@qimcis you could do it later
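
For when the tests land, here is a minimal sketch of a config-parsing test (plain pytest + PyYAML; the file layout follows the YAML used in the benchmark runs above, and nothing here assumes SGLang internals):

import yaml


def test_cache_dit_config_parses(tmp_path):
    # Write a minimal cache-dit config like the one used in the benchmarks.
    cfg_file = tmp_path / "cache_dit_config.yaml"
    cfg_file.write_text(
        "cache_config:\n"
        "  max_warmup_steps: 8\n"
        "  residual_diff_threshold: 0.12\n"
        "parallelism_config:\n"
        "  ulysses_size: 2\n"
    )

    cfg = yaml.safe_load(cfg_file.read_text())

    # Both top-level sections should survive the round trip.
    assert set(cfg) == {"cache_config", "parallelism_config"}
    assert cfg["cache_config"]["residual_diff_threshold"] == 0.12
    assert cfg["parallelism_config"]["ulysses_size"] == 2

An end-to-end test would additionally assert that --cache-dit-config threads the file through to the diffusers pipeline, but that needs the server fixtures.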

@mickqian
Collaborator

Diffusion tests and lint all passed; the failing srt tests are unrelated.

@mickqian mickqian merged commit 71482dd into sgl-project:main Jan 22, 2026
280 of 307 checks passed
Johnsonms pushed a commit to Johnsonms/sglang that referenced this pull request Feb 14, 2026
…nd (sgl-project#16662)

Signed-off-by: Chi <chixie.mcisaac@gmail.com>
Signed-off-by: qimcis <chixie.mcisaac@gmail.com>
Co-authored-by: Mick <mickjagger19@icloud.com>
@qimcis qimcis mentioned this pull request Feb 24, 2026