
[Feature] Adding decode vae patch parallel supports for LTX-2 #2135

Open

erfgss wants to merge 17 commits into vllm-project:main from erfgss:videoGen_vae

Conversation

@erfgss (Contributor) commented Mar 24, 2026:


Purpose

This PR adds support for VAE patch parallelism in the LTX-2 text-to-video and image-to-video pipelines.

By enabling distributed VAE decoding when --vae-patch-parallel-size > 1, this change improves multi-GPU utilization and reduces VAE decode latency for video generation workloads.
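For readers new to the technique, here is a minimal single-process sketch of what patch-parallel VAE decoding means. This is my illustration, not this PR's implementation: `decode_fn` stands in for the VAE decoder and is assumed to preserve spatial size (a real VAE also upsamples, which scales the seam arithmetic by the upsampling factor). The latent is split along width with a small overlap, each patch is decoded independently (on a different GPU in the real feature), and the overlapping seam is blended:

```python
import torch

def patch_parallel_decode_2way(decode_fn, z: torch.Tensor, overlap: int = 4) -> torch.Tensor:
    """Split latents along width, decode each half (plus an overlap margin)
    independently, then linearly blend the overlapping seam."""
    w = z.shape[-1]
    mid = w // 2
    left = decode_fn(z[..., : mid + overlap])    # rank 0's share in the real feature
    right = decode_fn(z[..., mid - overlap :])   # rank 1's share
    ramp = torch.linspace(0.0, 1.0, steps=2 * overlap)  # 0 -> 1 across the seam
    seam = left[..., mid - overlap :] * (1 - ramp) + right[..., : 2 * overlap] * ramp
    return torch.cat([left[..., : mid - overlap], seam, right[..., 2 * overlap :]], dim=-1)

# Toy check with an identity "decoder":
z = torch.randn(1, 4, 3, 8, 16)   # (B, C, T, H, W) video latents
assert torch.allclose(patch_parallel_decode_2way(lambda t: t, z), z)
```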

Test Plan

  • Run LTX-2 text-to-video and image-to-video inference
  • Set --tensor-parallel-size=2
  • Enable --vae-use-tiling
  • Run with:
    • --vae-patch-parallel-size=1
    • --vae-patch-parallel-size=2
  • Compare VAE decode time reported in logs

Test Result

| Model | Task           | Tensor Parallel Size | VAE Patch Parallel Size | VAE Decode (ms) |
|-------|----------------|----------------------|-------------------------|-----------------|
| LTX-2 | text-to-video  | 2                    | 1                       | 11474.19        |
| LTX-2 | text-to-video  | 2                    | 2                       | 293.67          |
| LTX-2 | image-to-video | 2                    | 1                       | 730.73          |
| LTX-2 | image-to-video | 2                    | 2                       | 397.07          |
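(At VAE patch parallel size 2, this is roughly a 39× reduction in VAE decode time for text-to-video and about 1.8× for image-to-video.)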

text-to-video VAE Patch Parallel Size=1

ltx2_t2v_diffvae1.mp4

text-to-video VAE Patch Parallel Size=2

ltx2_t2v_diffvae2.mp4

image-to-video VAE Patch Parallel Size=1

ltx2_i2v_vae1.mp4

image-to-video VAE Patch Parallel Size=2

ltx2_i2v_vae2.mp4

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan. Please provide the test scripts & test commands. Please state the reasons if your codes don't require additional test scripts. For test file guidelines, please check the test style doc
  • The test results. Please paste the results comparison before and after, or the e2e results.
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model. Please run mkdocs serve to sync the documentation editions to ./docs.
  • (Optional) Release notes update. If your change is user-facing, please update the release notes draft.


Signed-off-by: Chen Yang <2082464740@qq.com>
@erfgss erfgss requested a review from hsliuustc0106 as a code owner March 24, 2026 12:07
@erfgss (Contributor Author) commented Mar 24, 2026:

@david6666666

@david6666666 (Collaborator) commented:

Add a unit test and an output video comparison.

Signed-off-by: Chen Yang <2082464740@qq.com>
@chatgpt-codex-connector (Bot) left a comment:

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: c98eb84932


Comment on lines +122 to +124:

    timestep: torch.Tensor | None = None,
    return_dict: bool = True,
    *args: Any,

P2: Preserve base decode argument order for the causal flag

AutoencoderKLLTX2Video.decode takes (z, temb=None, causal=None, return_dict=True), but this override changes the positional order to (z, timestep=None, return_dict=True, ...). Any caller that passes causal positionally (for example decode(z, temb, False)) will now set return_dict=False instead, silently changing the return type and leaving causal unset. This is a behavioral regression in the public method contract and can break wrappers that rely on the original positional API.
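A minimal repro of the binding hazard the bot is pointing at (signatures paraphrased from this comment, not copied from the diff):

```python
# Hypothetical sketch: how reordering positional parameters silently
# rebinds a caller's arguments.

class Base:
    def decode(self, z, temb=None, causal=None, return_dict=True): ...

class Override(Base):
    def decode(self, z, timestep=None, return_dict=True, *args): ...

# A caller written against the base signature, e.g. decode(z, temb, False),
# intends causal=False; against Override it instead binds False to
# return_dict, leaving the causal flag unset and changing the return type.
```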


@david6666666 david6666666 changed the title [Feature]Adding vae patch parallel supports for VideoGen [Feature]Adding vae patch parallel supports for LTX-2 Mar 24, 2026
@david6666666 (Collaborator) commented:

Update LTX-2 image-to-video as well, and update vllm-omni/docs/user_guide/diffusion_acceleration.md.

@erfgss (Contributor Author) commented Mar 24, 2026:

> Update LTX-2 image-to-video as well, and update vllm-omni/docs/user_guide/diffusion_acceleration.md.

I will update these.

erfgss added 5 commits March 25, 2026 08:37
Signed-off-by: erfgss <97771661+erfgss@users.noreply.github.com>
Signed-off-by: erfgss <97771661+erfgss@users.noreply.github.com>
@david6666666 (Collaborator) commented:

@Bounty-hunter PTAL, thanks.

Comment on code:

    ) -> torch.Tensor:
        """Decode a single latent tile into video space."""
        tile = task.tensor
        if hasattr(self, "clear_cache"):
A Contributor left a review comment on this code; @erfgss (Contributor Author) replied on Mar 30, 2026.

Comment on code:

        dec = torch.clamp(dec, min=-1.0, max=1.0)
        return dec

    def patch_split(self, z: torch.Tensor) -> tuple[list[TileTask], GridSpec]:
A Contributor commented:

Have you evaluated the performance gain from patch splitting? When the height and width are small, the split size (plus the blend overlap) is almost equal to the total size, so the performance improvement may be limited. In that scenario, temporal tiled decode parallelism might be a better choice: https://github.com/huggingface/diffusers/blob/f2be8bd6b3dc4035bd989dc467f15d86bf3c9c12/src/diffusers/models/autoencoders/autoencoder_kl_ltx2.py#L1497

@erfgss (Contributor Author) replied:

> Have you evaluated the performance gain from patch splitting? When the height and width are small, the split size (plus the blend overlap) is almost equal to the total size, so the performance improvement may be limited. In that scenario, temporal tiled decode parallelism might be a better choice: https://github.com/huggingface/diffusers/blob/f2be8bd6b3dc4035bd989dc467f15d86bf3c9c12/src/diffusers/models/autoencoders/autoencoder_kl_ltx2.py#L1497

When generating a 24-frame video, temporal tiled decoding does not bring obvious gains; instead, it adds overhead.
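For intuition on the reviewer's point, a back-of-envelope sketch (my own arithmetic, not from this PR): with a 2-way width split and a blend margin of `b` latent columns, each rank decodes `(W/2 + b)/W` of the frame, so the margin dominates when `W` is small:

```python
def per_rank_fraction(width: int, blend: int, parts: int = 2) -> float:
    """Fraction of the full frame each rank decodes after an even split
    with a `blend` overlap margin added to each patch."""
    return (width / parts + blend) / width

for w in (32, 64, 128, 256):
    print(f"W={w:4d}: {per_rank_fraction(w, blend=16):.2f} of the full decode per rank")
# W=32 gives 1.00 (no speedup at all); W=256 gives 0.56, close to the ideal 0.50.
```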

@wtomin (Collaborator) commented Mar 30, 2026:

A recent PR changed the diffusion features docs structure. Please take a look at #1928.

erfgss and others added 3 commits March 31, 2026 10:13
Signed-off-by: Chen Yang <2082464740@qq.com>
Signed-off-by: Chen Yang <2082464740@qq.com>
@wtomin (Collaborator) commented Apr 2, 2026:

@erfgss Can you help create an L4 e2e test for the LTX-2 model, covering the existing supported diffusion features (see #1217)? For how to create an L4 e2e test, please refer to #1832.

Please update the document docs/user_guide/diffusion_features.md

@erfgss (Contributor Author) commented Apr 2, 2026:

> @erfgss Can you help create an L4 e2e test for the LTX-2 model, covering the existing supported diffusion features (see #1217)? For how to create an L4 e2e test, please refer to #1832.

OK, I can do this.

erfgss and others added 3 commits April 3, 2026 11:24
Signed-off-by: Chen Yang <2082464740@qq.com>
Signed-off-by: Chen Yang <2082464740@qq.com>
@wtomin (Collaborator) left a comment:

LGTM. Please resolve the conflicts.

@david6666666 (Collaborator) commented:

A follow-up PR can refer to #2368.

erfgss added 2 commits April 8, 2026 22:09
Signed-off-by: erfgss <97771661+erfgss@users.noreply.github.com>
@erfgss erfgss changed the title [Feature]Adding vae patch parallel supports for LTX-2 [Feature]Adding decode vae patch parallel supports for LTX-2 Apr 9, 2026
erfgss added 2 commits April 13, 2026 09:20
Signed-off-by: erfgss <97771661+erfgss@users.noreply.github.com>
