[Feature] Support vae tiling parallel encode#2368

Merged
gcanlin merged 15 commits into
vllm-project:mainfrom
gcanlin:vae-encode-parallel
Apr 6, 2026
Conversation

@gcanlin
Collaborator

@gcanlin gcanlin commented Mar 31, 2026

Purpose

Add support for parallel VAE tiled encode, complementing the existing VAE decode parallel path.

Test Plan

Run Wan2.2 I2V end-to-end on 8 NPUs with Ulysses enabled:

python image_to_video.py \
  --model Wan-AI/Wan2.2-I2V-A14B-Diffusers \
  --image cherry_blossom.jpg \
  --prompt "Cherry blossoms swaying gently in the breeze, petals falling, smooth motion" \
  --negative-prompt "<optional quality filter>" \
  --height 512 \
  --width 768 \
  --num-frames 49 \
  --guidance-scale 4.0 \
  --num-inference-steps 20 \
  --flow-shift 12.0 \
  --fps 16 \
  --output i2v_output.mp4 \
  --enable-layerwise-offload \
  --ulysses-degree 8 \
  --vae-patch-parallel-size 8 \
  --vae-use-tiling

Test Results

| Config | Time |
| --- | --- |
| without vae parallel | 52s |
| only vae decode parallel | 47s |
| vae decode parallel + vae encode parallel | 43s |

Main:

wan22_main.mp4

This PR:

wan22_vae_encode_parallel.mp4

Signed-off-by: gcanlin <canlinguosdu@gmail.com>
@gcanlin gcanlin requested a review from hsliuustc0106 as a code owner March 31, 2026 09:37

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 5694af1067


Comment on lines +116 to +117
for i in range(0, height, self.tile_sample_stride_height):
for j in range(0, width, self.tile_sample_stride_width):


P2: Scale tiled-encode split coordinates after patchify

When config.patch_size is set, _encode_distributed patchifies x before dispatching tile work, but encode_tile_split still iterates with self.tile_sample_stride_* in pre-patch coordinates. On patchified inputs this makes tiles much larger than intended (and far fewer), so patch-parallel encode loses most of its parallelism and uses incorrect overlap geometry for seam blending on patchified Wan checkpoints. The decode path already applies patch-size-aware scaling, so encode should mirror that coordinate system.
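The fix the review asks for can be sketched as follows: when the input has been patchified, the spatial dimensions shrink by `patch_size`, so the tile strides used by the encode split must shrink by the same factor, exactly as the decode path already does. This is a minimal illustrative sketch, not the actual vllm-omni API; `encode_tile_grid` and its parameters are hypothetical names.

```python
def encode_tile_grid(height, width, stride_h, stride_w, patch_size=None):
    """Yield tile origins in the coordinate system of the (possibly patchified) input.

    After patchify, spatial dims shrink by `patch_size`, so strides must be
    scaled down too; otherwise each tile covers patch_size x more area than
    intended and the split degenerates to far fewer tiles.
    """
    if patch_size is not None:
        stride_h //= patch_size
        stride_w //= patch_size
    return [(i, j)
            for i in range(0, height, stride_h)
            for j in range(0, width, stride_w)]

# Unpatchified 512x768 with stride 192 -> 3x4 = 12 tiles.
# Patchified by 2 the input is 256x384; scaling the stride to 96 keeps 12 tiles.
```

Without the scaling, the 256x384 patchified input iterated with the original stride of 192 would yield only 2x2 tiles, losing most of the parallelism.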


Collaborator Author


Done.

@hsliuustc0106
Collaborator

any accuracy test?

gcanlin added 3 commits April 1, 2026 09:38
Signed-off-by: gcanlin <canlinguosdu@gmail.com>
Signed-off-by: gcanlin <canlinguosdu@gmail.com>
@gcanlin
Collaborator Author

gcanlin commented Apr 1, 2026

@Bounty-hunter @wtomin Please help review. Thanks!

@wtomin
Collaborator

wtomin commented Apr 1, 2026

  1. Please update .\docs\design\feature\vae_parallel.md and .\docs\user_guide\diffusion\parallelism\vae_patch_parallel.md (if applicable);
  2. Can you verify that distributed encode also works for image-to-image models? Please give some examples.

gcanlin added 2 commits April 1, 2026 15:11
Signed-off-by: gcanlin <canlinguosdu@gmail.com>
Signed-off-by: gcanlin <canlinguosdu@gmail.com>
@gcanlin
Collaborator Author

gcanlin commented Apr 1, 2026

  1. Please update .\docs\design\feature\vae_parallel.md and .\docs\user_guide\diffusion\parallelism\vae_patch_parallel.md (if applicable);
  2. Can you verify that distributed encode also works for image-to-image models? Please give some examples.

For 1, I have added the doc for vae encode parallel.

For 2, I think it's better to implement it in a follow-up PR. I will request help from the community.

@gcanlin
Collaborator Author

gcanlin commented Apr 1, 2026

any accuracy test?

I will enable the nightly-test label to cover the GPU accuracy test. For NPU, I will run the same test locally and paste the result.

@gcanlin gcanlin added the nightly-test label to trigger buildkite nightly test CI label Apr 1, 2026
@gcanlin
Collaborator Author

gcanlin commented Apr 2, 2026

>       assert ssim_score >= SSIM_THRESHOLD, (
            f"SSIM below threshold: got {ssim_score:.6f}, expected >= {SSIM_THRESHOLD:.6f}. "
            f"online={online_path} offline={offline_path}"
        )
E       AssertionError: SSIM below threshold: got 0.922084, expected >= 0.940000. online=/workspace/build/buildkite/tests/e2e/accuracy/wan22_i2v/result/rabbit-cf925a4c/online.mp4 offline=/workspace/build/buildkite/tests/e2e/accuracy/wan22_i2v/result/rabbit-cf925a4c/offline.mp4
E       assert 0.922084 >= 0.94
tests/e2e/accuracy/wan22_i2v/test_wan22_i2v_video_similarity.py:668: AssertionError

This PR currently causes an accuracy regression. I will try to fix it.
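For context, the nightly gate compares the online and offline videos with an SSIM threshold. A minimal single-window SSIM sketch (no sliding Gaussian window, unlike scikit-image's implementation) shows what the `ssim_score` in the failing assertion measures; `global_ssim` is a hypothetical helper, not the test suite's actual function.

```python
import numpy as np

def global_ssim(x: np.ndarray, y: np.ndarray, data_range: float = 1.0) -> float:
    """SSIM computed over the whole frame as one window.

    ssim = ((2*mu_x*mu_y + c1) * (2*cov + c2))
         / ((mu_x^2 + mu_y^2 + c1) * (var_x + var_y + c2))
    """
    c1 = (0.01 * data_range) ** 2
    c2 = (0.03 * data_range) ** 2
    mx, my = x.mean(), y.mean()
    vx, vy = x.var(), y.var()
    cov = ((x - mx) * (y - my)).mean()
    return float(((2 * mx * my + c1) * (2 * cov + c2))
                 / ((mx ** 2 + my ** 2 + c1) * (vx + vy + c2)))

rng = np.random.default_rng(0)
frame = rng.random((64, 64))
print(global_ssim(frame, frame))  # identical frames -> 1.0
```

A per-video score like the 0.922 above would average such frame scores over the decoded frames of the two videos.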

@wtomin
Collaborator

wtomin commented Apr 2, 2026

This PR currently causes an accuracy regression. I will try to fix it.

I think it is expected. Previously the VAE did not use tiled_encode; now you are using tiled_encode in parallel mode, is that right?
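The numerical difference wtomin points to comes from tiling itself: tiled encode blends overlapping borders of adjacent tiles, so its output cannot bit-match a whole-frame encode. A minimal sketch of the linear seam cross-fade (hypothetical `blend_h`, mirroring the diffusers-style horizontal blend, not the repo's actual code):

```python
import numpy as np

def blend_h(left: np.ndarray, right: np.ndarray, overlap: int) -> np.ndarray:
    """Cross-fade `right`'s first `overlap` columns with `left`'s last ones."""
    w = np.arange(overlap) / overlap  # 0 -> 1 ramp across the seam
    blended = right.copy()
    blended[..., :overlap] = left[..., -overlap:] * (1 - w) + right[..., :overlap] * w
    return blended

left = np.ones((4, 8))
right = np.zeros((4, 8))
out = blend_h(left, right, overlap=4)
# seam columns ramp from left toward right: 1.0, 0.75, 0.5, 0.25
```

This also shows why the overlap geometry in the encode split matters: if tile coordinates are wrong, the ramp is applied over the wrong columns and seams become visible.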

@hsliuustc0106 hsliuustc0106 added the ready label to trigger buildkite CI label Apr 2, 2026
Collaborator

@lishunyang12 lishunyang12 left a comment


left a few comments, mostly minor

@david6666666
Collaborator

>       assert ssim_score >= SSIM_THRESHOLD, (
            f"SSIM below threshold: got {ssim_score:.6f}, expected >= {SSIM_THRESHOLD:.6f}. "
            f"online={online_path} offline={offline_path}"
        )
E       AssertionError: SSIM below threshold: got 0.922084, expected >= 0.940000. online=/workspace/build/buildkite/tests/e2e/accuracy/wan22_i2v/result/rabbit-cf925a4c/online.mp4 offline=/workspace/build/buildkite/tests/e2e/accuracy/wan22_i2v/result/rabbit-cf925a4c/offline.mp4
E       assert 0.922084 >= 0.94
tests/e2e/accuracy/wan22_i2v/test_wan22_i2v_video_similarity.py:668: AssertionError

This PR currently causes an accuracy regression. I will try to fix it.

Has this accuracy issue been resolved?

@gcanlin
Collaborator Author

gcanlin commented Apr 3, 2026

Has this accuracy issue been resolved?

I found that the current nightly accuracy test only compares offline and online inference, not main vs. this PR, so this failure is weird. It looks more like I didn't align the offline and online parallel configs in the test.

@david6666666
Collaborator

Maybe we can add a `buildkite-agent artifact upload` step so the generated videos can be inspected and judged by a human.

Collaborator

@wtomin wtomin left a comment


LGTM.

gcanlin added 5 commits April 6, 2026 13:23
Signed-off-by: gcanlin <canlinguosdu@gmail.com>
Signed-off-by: gcanlin <canlinguosdu@gmail.com>
Signed-off-by: gcanlin <canlinguosdu@gmail.com>
Signed-off-by: gcanlin <canlinguosdu@gmail.com>
@gcanlin gcanlin enabled auto-merge (squash) April 6, 2026 13:39
@gcanlin
Collaborator Author

gcanlin commented Apr 6, 2026

will this help other models?

I think yes, but other models will also need some minor adaptation.

Signed-off-by: gcanlin <canlinguosdu@gmail.com>
@gcanlin gcanlin merged commit e771842 into vllm-project:main Apr 6, 2026
7 of 8 checks passed
david6666666 pushed a commit to david6666666/vllm-omni that referenced this pull request Apr 7, 2026
Signed-off-by: gcanlin <canlinguosdu@gmail.com>
Co-authored-by: Hongsheng Liu <liuhongsheng4@huawei.com>
(cherry picked from commit e771842)
Signed-off-by: David Chen <530634352@qq.com>
david6666666 added a commit that referenced this pull request Apr 7, 2026

Signed-off-by: gcanlin <canlinguosdu@gmail.com>
Signed-off-by: David Chen <530634352@qq.com>
Co-authored-by: Canlin Guo <canlinguosdu@gmail.com>
Co-authored-by: Hongsheng Liu <liuhongsheng4@huawei.com>
skf-1999 pushed a commit to Semmer2/vllm-omni that referenced this pull request Apr 7, 2026
Signed-off-by: gcanlin <canlinguosdu@gmail.com>
Co-authored-by: Hongsheng Liu <liuhongsheng4@huawei.com>
vraiti pushed a commit to vraiti/vllm-omni that referenced this pull request Apr 9, 2026
Signed-off-by: gcanlin <canlinguosdu@gmail.com>
Co-authored-by: Hongsheng Liu <liuhongsheng4@huawei.com>
daixinning pushed a commit to daixinning/vllm-omni that referenced this pull request Apr 13, 2026
- Add DistributedAutoencoderKLHunyuanVideo with overlapping tile split/exec/merge
  for both encode and decode paths
- encode: encode_tile_split/exec/merge + _encode override with broadcast_result=True
- decode: tile_split/exec/merge + decode override with broadcast_result=False
- Fix distributed_vae_executor: use self.rank < pp_size instead of pp_size <= self.world_size
- Wire DistributedAutoencoderKLHunyuanVideo into T2V and I2V pipelines

Encode parallel follows the pattern from vllm-project#2368 (Wan VAE encode parallel).

Signed-off-by: daixinning <daixinning@163.com>
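The tile split/exec/merge pattern referenced in these follow-on commits distributes tiles across the VAE parallel group; ranks outside the group idle, which is what the `self.rank < pp_size` guard enforces. A hypothetical sketch of round-robin tile dispatch (names are illustrative, not the vllm-omni executor's API):

```python
def tiles_for_rank(num_tiles: int, rank: int, pp_size: int) -> list[int]:
    """Round-robin tile assignment for one rank.

    Ranks at or above pp_size are outside the VAE parallel group and do no
    VAE work (the `self.rank < pp_size` guard from the executor fix).
    """
    if rank >= pp_size:
        return []
    return [t for t in range(num_tiles) if t % pp_size == rank]

# 12 tiles over pp_size=8: ranks 0-3 get two tiles each, ranks 4-7 get one.
```

After each rank encodes its tiles, a merge step gathers the results and applies the seam blending; with `broadcast_result=True` the merged latent is then broadcast to every rank.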
daixinning pushed a commit to daixinning/vllm-omni that referenced this pull request Apr 13, 2026
- Add DistributedAutoencoderKLHunyuanVideo with overlapping tile split/exec/merge
  for distributed decode
- decode: tile_split/exec/merge + decode override with broadcast_result=False
- Fix distributed_vae_executor: use self.rank < pp_size instead of pp_size <= self.world_size
- Wire DistributedAutoencoderKLHunyuanVideo into T2V and I2V pipelines

Decode parallel follows the pattern from vllm-project#2368 (Wan VAE encode parallel).

Signed-off-by: daixinning <daixinning@163.com>
daixinning pushed a commit to daixinning/vllm-omni that referenced this pull request Apr 13, 2026
- Add DistributedAutoencoderKLHunyuanVideo with overlapping tile split/exec/merge
  for distributed decode
- decode: tile_split/exec/merge always used (small latents included, consistent
  with base class behavior; base class patch_split is incompatible with 5D latents)
- Fix distributed_vae_executor: use self.rank < pp_size instead of pp_size <= self.world_size
- Wire DistributedAutoencoderKLHunyuanVideo into T2V and I2V pipelines
- Add unit tests for tile_split, tile_merge, and blend logic

Decode parallel follows the pattern from vllm-project#2368 (Wan VAE encode parallel).

Signed-off-by: daixinning <daixinning@163.com>
bob-021206 pushed a commit to jasonlee-1024/vllm-omni that referenced this pull request Apr 21, 2026
Signed-off-by: gcanlin <canlinguosdu@gmail.com>
Co-authored-by: Hongsheng Liu <liuhongsheng4@huawei.com>
Signed-off-by: bob-021206 <binyan_github@163.com>
lengrongfu pushed a commit to lengrongfu/vllm-omni that referenced this pull request May 1, 2026
Signed-off-by: gcanlin <canlinguosdu@gmail.com>
Co-authored-by: Hongsheng Liu <liuhongsheng4@huawei.com>
clodaghwalsh17 pushed a commit to clodaghwalsh17/nm-vllm-omni-ent that referenced this pull request May 12, 2026
Signed-off-by: gcanlin <canlinguosdu@gmail.com>
Co-authored-by: Hongsheng Liu <liuhongsheng4@huawei.com>