
[Diffusion] Refactor CFG parallel for extensibility and performance#2063

Merged
princepride merged 6 commits into vllm-project:main from TKONIY:cfg-parallel-extensibility
Mar 30, 2026

Conversation

@TKONIY
Contributor

@TKONIY TKONIY commented Mar 21, 2026

Summary

Refactor CFGParallelMixin to improve performance, reduce code complexity, and enable extensibility for multi-output models (e.g., world models producing both video and action predictions).

Problem

Current CFGParallelMixin has three issues:

  1. Redundant communication: Every denoising step does all_gather + broadcast. After all_gather, both ranks already have identical positive and negative predictions — the broadcast in scheduler_step_maybe_with_cfg is unnecessary.

  2. Code complexity: predict_noise_maybe_with_cfg returns a valid result only on rank 0 and None on rank 1. scheduler_step_maybe_with_cfg performs a rank-0-only step followed by a broadcast. All downstream code must handle the None case and rank branching.

  3. Limited extensibility: predict_noise_maybe_with_cfg assumes single tensor output. Models with multi-output (e.g., DreamZero world model returning (video_pred, action_pred)) cannot use the mixin without overriding the entire method.

Changes

Change 1: Remove redundant broadcast

  • After all_gather, all ranks compute CFG combine + scheduler step locally
  • Eliminates one broadcast per denoising step
  • predict_noise_maybe_with_cfg returns valid tensor on all ranks (not just rank 0)
  • scheduler_step_maybe_with_cfg simplified: just calls scheduler_step on all ranks
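
The refactored per-step flow can be sketched as follows. This is a simulation, not the actual mixin code: the helper names and the in-process "exchange" standing in for all_gather are assumptions for illustration.

```python
import torch

def cfg_step(latents, predict_noise, scheduler_step, scale):
    # In the real mixin, rank 0 runs the positive branch and rank 1 the
    # negative branch, then all_gather exchanges the two predictions.
    # Here both branches are computed locally to stand in for the
    # gathered results.
    pos = predict_noise(latents, positive=True)
    neg = predict_noise(latents, positive=False)
    # Every rank now combines and steps locally; no broadcast is needed,
    # because both ranks hold identical (pos, neg) after the gather.
    combined = neg + scale * (pos - neg)  # standard CFG combine
    return scheduler_step(latents, combined)
```

Because the combine and the scheduler step are pure arithmetic, running this on two ranks with identical gathered inputs yields identical latents, which is what makes the old broadcast removable.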

Change 2: Tuple output support via combine_cfg_noise override

  • predict_noise() can now return tuple[Tensor, ...] for multi-output models
  • Mixin internally normalizes to tuple via _wrap/_unwrap helpers, processes uniformly
  • combine_cfg_noise() accepts tuple, applies CFG formula per-element by default
  • Multi-output models override combine_cfg_noise() for custom per-element logic:
class MyWorldModelPipeline(nn.Module, CFGParallelMixin):
    def predict_noise(self, **kwargs):
        video_pred, action_pred = self.transformer(**kwargs)
        return (video_pred, action_pred)

    def combine_cfg_noise(self, positive_noise_pred, negative_noise_pred, scale, normalize):
        (video_pos, action_pos) = positive_noise_pred
        (video_neg, action_neg) = negative_noise_pred
        # Pass plain tensors to super(); it returns a plain tensor back.
        video_combined = super().combine_cfg_noise(video_pos, video_neg, scale, normalize)
        return (video_combined, action_pos)  # action: positive only, no CFG
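
The _wrap/_unwrap normalization the mixin relies on can be sketched as below. The behavior is inferred from this PR's description; the exact signatures in vllm_omni may differ, and the normalize branch is omitted.

```python
import torch

def _wrap(pred):
    # Normalize any model output to a tuple so single- and multi-output
    # models are processed uniformly inside the mixin.
    return pred if isinstance(pred, tuple) else (pred,)

def _unwrap(pred):
    # Undo _wrap: a 1-tuple becomes a plain tensor again; larger tuples
    # pass through unchanged.
    return pred[0] if len(pred) == 1 else pred

def combine_cfg_noise(positive, negative, scale, normalize=False):
    # Default per-element CFG combine; accepts plain tensors or tuples.
    # (The normalize handling is omitted in this sketch.)
    pos, neg = _wrap(positive), _wrap(negative)
    combined = tuple(n + scale * (p - n) for p, n in zip(pos, neg))
    return _unwrap(combined)
```

Note that a 1-tuple input comes back as a plain tensor, which is why overrides that delegate single elements to super() should pass plain tensors rather than 1-tuples.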

Correctness Argument

After all_gather, both ranks have bit-identical copies of positive/negative predictions. CFG combine and scheduler step are deterministic (pure arithmetic, no randomness). Same inputs + same ops = same results on both ranks.
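
The argument can be checked directly in a few lines. This is a toy demonstration, not a test from the PR:

```python
import torch

# Both "ranks" start from byte-identical copies of the gathered
# predictions (which is what all_gather guarantees) and apply the same
# deterministic ops, so the results are bit-identical, not just close.
pos = torch.randn(4, generator=torch.Generator().manual_seed(0))
neg = torch.randn(4, generator=torch.Generator().manual_seed(1))
scale = 4.0
rank0 = neg + scale * (pos - neg)
rank1 = neg.clone() + scale * (pos.clone() - neg.clone())
assert torch.equal(rank0, rank1)  # exact bitwise equality
```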

Benchmark

Setup:

  • Model: Qwen/Qwen-Image-2512 (25B)
  • Hardware: NVIDIA RTX PRO 6000 Blackwell (98GB VRAM), PCIe Gen5 x16, PIX topology
  • Parameters: cfg_scale=4.0, steps=20, 1024x1024, seed=142
  • Main branch: cb6d012d (latest origin/main)

Performance

| Test | Branch | CFG Parallel | GPUs | Total Time (ms) | vs main |
|------|----------|--------------|------|-----------------|---------|
| 1 | main | Off | 1 | 9823 | |
| 2 | main | On (size=2) | 2 | 5284 | |
| 3 | refactor | Off | 1 | 9797 | -0.3% |
| 4 | refactor | On (size=2) | 2 | 5205 | -1.5% |

No performance regression. Broadcast cost is negligible on this hardware/model (PCIe Gen5, ~0.1ms vs ~250ms transformer forward per step). The value of this PR is code simplification and extensibility, not raw speed.

Correctness

| Comparison | MD5 Match |
|------------|-----------|
| baseline_no_cfg vs refactor_no_cfg | Identical (972a8c57) |
| baseline_cfg_parallel vs refactor_cfg_parallel | Identical (c218a5aa) |

Bit-identical outputs confirm the refactor produces exactly the same results.

Test plan

  • Benchmark: Qwen-Image-2512 (25B) on 1 GPU and 2 GPU CFG parallel — no regression
  • Correctness: bit-identical images (MD5 match) between main and refactor
  • Unit test: tests/diffusion/distributed/test_cfg_parallel.py now verifies all CFG ranks receive the same combined output and match the sequential CFG baseline
  • Run existing CI tests

Next step

  • After this PR is merged, refactor LTX-2 and LTX-2 I2V (video + audio) onto the refactored CFG parallel framework.

@TKONIY TKONIY force-pushed the cfg-parallel-extensibility branch from 6691f5d to c40e4a8 on March 21, 2026 15:55
@TKONIY TKONIY mentioned this pull request Mar 21, 2026
20 tasks
@TKONIY TKONIY force-pushed the cfg-parallel-extensibility branch 16 times, most recently from 936401e to 7f3b059 on March 22, 2026 21:57
@TKONIY TKONIY marked this pull request as ready for review March 22, 2026 21:57
@TKONIY TKONIY requested a review from hsliuustc0106 as a code owner March 22, 2026 21:57

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 7f3b05931d


@TKONIY TKONIY marked this pull request as draft March 22, 2026 22:05
@TKONIY TKONIY force-pushed the cfg-parallel-extensibility branch 8 times, most recently from 042f6d3 to dfed6ba on March 22, 2026 23:28
@wtomin
Collaborator

wtomin commented Mar 24, 2026

I agree with the philosophy of this refactoring PR, as it brings better extensibility.

CFG combine and scheduler step are deterministic (pure arithmetic, no randomness).

I am afraid some diffusion schedulers are not deterministic, for example the DDPM solver. Although in vLLM-Omni most diffusion models use deterministic schedulers, such as flow matching, that is not guaranteed for every model.

Therefore, I suggest handling the randomness in scheduler_step_maybe_with_cfg explicitly.

@wtomin
Collaborator

wtomin commented Mar 25, 2026

Bagel, GLM-Image, dreamid-omni, nextstep1.1, and LTX-2 do not inherit from CFGParallelMixin but implement CFG parallel from scratch. Please check whether your PR affects them.

@TKONIY
Contributor Author

TKONIY commented Mar 25, 2026

Bagel, GLM-Image, dreamid-omni, nextstep1.1, and LTX-2 do not inherit from CFGParallelMixin but implement CFG parallel from scratch. Please check whether your PR affects them.

Thanks for the thorough review! Here's what's been addressed:

1. Non-deterministic scheduler (generator support)

Done in a5f2de3d: scheduler_step() and scheduler_step_maybe_with_cfg() now accept an optional generator param, passed through to the scheduler's step() call.

2 & 3. Doc updates (d53df5a2)

  • Added generator=None to the composite scheduler example (VideoAudioScheduler.step()) and the diffuse() call site
  • Added a note about explicitly passing generator=torch.Generator(device).manual_seed(seed) for non-deterministic schedulers like DDPM — adopted your suggested wording
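
For non-deterministic schedulers, the key property is that every CFG rank draws the same noise once it is given the same seed. A minimal sketch of that idea, where the seeded_noise helper is hypothetical and not a vllm-omni API:

```python
import torch

def seeded_noise(seed, shape, device="cpu"):
    # Each CFG rank builds its own generator from a shared seed, so a
    # stochastic scheduler step (e.g., DDPM) stays bit-identical across
    # ranks even though every rank now steps locally after the refactor.
    gen = torch.Generator(device=device).manual_seed(seed)
    return torch.randn(shape, generator=gen, device=device)

rank0_noise = seeded_noise(142, (4,))
rank1_noise = seeded_noise(142, (4,))
assert torch.equal(rank0_noise, rank1_noise)  # ranks stay in sync
```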

4. Impact on Bagel / GLM-Image / DreamID-Omni / NextStep 1.1 / LTX-2

Checked all five:

| Model | Inherits Mixin? | Affected? |
|-------|-----------------|-----------|
| Bagel | No | No — fully independent CFG impl |
| GLM-Image | No | No — uses parallel_state directly |
| NextStep 1.1 | No | No — uses parallel_state directly |
| DreamID-Omni | Yes | No — only calls combine_cfg_noise() with plain tensors; _wrap/_unwrap preserves backward compat |
| LTX-2 | Yes | No — same as above; uses its own predict_noise_av_maybe_with_cfg |

DreamID-Omni and LTX-2 inherit CFGParallelMixin but only use combine_cfg_noise(), which now handles both plain Tensor and tuple[Tensor, ...] transparently. Verified with standalone unit tests covering both input types.

btw, I have an ongoing draft PR refactoring LTX-2 on top of the new cfg_parallel.

Collaborator

@princepride princepride left a comment


PTAL.

@TKONIY TKONIY force-pushed the cfg-parallel-extensibility branch from 9d2bdf5 to 5fa1e5f on March 27, 2026 11:03
@TKONIY TKONIY requested a review from princepride March 27, 2026 21:10
@nussejzz
Contributor

Bagel, GLM-Image, dreamid-omni, nextstep1.1, and LTX-2 do not inherit from CFGParallelMixin but implement CFG parallel from scratch. Please check whether your PR affects them.

@wtomin This shouldn't affect bagel, because that's how I implemented it. #1695

TKONIY and others added 6 commits March 28, 2026 19:23
Key changes to CFGParallelMixin:

1. Remove redundant broadcast in scheduler_step_maybe_with_cfg:
   After all_gather, all ranks already have identical positive and
   negative predictions. CFG combine and scheduler step are deterministic,
   so all ranks compute locally with identical results. The broadcast
   in scheduler_step_maybe_with_cfg is eliminated.

2. All ranks return valid results from predict_noise_maybe_with_cfg:
   Previously rank 0 returned the combined prediction and rank 1
   returned None. Now all ranks compute combine locally and return
   valid tensors. This simplifies downstream code.

3. Tuple output support via combine_cfg_noise override:
   predict_noise() can now return tuple[Tensor, ...] for multi-output
   models. The mixin internally normalizes via _wrap/_unwrap helpers.
   combine_cfg_noise() accepts tuple and applies CFG per-element.
   Multi-output models override combine_cfg_noise() for custom logic
   (e.g., CFG on video, positive-only on action).

Backward compatible: all existing callers use default parameters
and see no behavior change. Bit-identical outputs verified on
Qwen-Image-2512 (25B).

Signed-off-by: Yangshen Deng <yangshen.d@outlook.com>
…uper() calls

Fixes the documented pattern where super().combine_cfg_noise((tensor,), ...)
would return a plain tensor via _unwrap, causing tuple unpacking to fail.
Updated examples pass plain tensors to super() and receive plain tensors back.

Signed-off-by: Yangshen Deng <yangshen.d@outlook.com>
…ic scheduler support

Address review feedback from wtomin: after removing the broadcast in
scheduler_step_maybe_with_cfg, non-deterministic schedulers (e.g., DDPM)
would diverge across CFG parallel ranks. Add generator parameter to
scheduler_step() and scheduler_step_maybe_with_cfg() so callers can pass
a seeded generator to keep both ranks in sync.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Yangshen Deng <yangshen.d@outlook.com>
Signed-off-by: Yangshen Deng <yangshen.d@outlook.com>
…rministic scheduler note

Address review feedback: update VideoAudioScheduler.step() signature
to include generator parameter, pass it through to sub-schedulers,
and add a note about using explicit generators for non-deterministic
schedulers (e.g., DDPM) in CFG parallel mode.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Yangshen Deng <yangshen.d@outlook.com>
- Remove outer _unwrap() around combine_cfg_noise() calls in both
  CFG-parallel and sequential paths. combine_cfg_noise() already
  unwraps internally, so the double _unwrap stripped the batch
  dimension when batch_size == 1 for single-output models.
- Only pass generator kwarg to sched.step() when generator is not
  None, avoiding TypeError for schedulers without that parameter.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Yangshen Deng <yangshen.d@outlook.com>
@TKONIY TKONIY force-pushed the cfg-parallel-extensibility branch from ba8f3a3 to c27b8ef on March 28, 2026 19:25
Collaborator

@wtomin wtomin left a comment


LGTM

@princepride princepride enabled auto-merge (squash) March 30, 2026 16:26
Collaborator

@princepride princepride left a comment


LGTM

@princepride princepride added the ready label (trigger buildkite CI) Mar 30, 2026
@princepride princepride merged commit 837679d into vllm-project:main Mar 30, 2026
7 of 8 checks passed
vraiti pushed a commit to vraiti/vllm-omni that referenced this pull request Apr 9, 2026
…llm-project#2063)

Signed-off-by: Yangshen Deng <yangshen.d@outlook.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
