[Enhancement] Introduce High-Performance MoT (Mixture-of-Tokens) Kernels: Triton Implementation & A100 Tuning#1328

Open
yzhu802 wants to merge 7 commits into vllm-project:main from yzhu802:enhancement/mot_fused_kernels

Conversation

@yzhu802

@yzhu802 yzhu802 commented Feb 11, 2026

Purpose

This PR introduces high-performance Mixture-of-Tokens (MoT) core operator primitives for vLLM-Omni, including MoTLinear and MoTRMSNorm. These operators are developed using Triton, following the proposals in RFC #1146 and #936.

Design

1) Overall Architecture
Adheres to the hierarchical design of vLLM Triton operators: Triton kernels -> Layer encapsulation -> Model invocation. The implementations for both the operator layer and layer encapsulation are centralized within the vllm_omni/diffusion/layers/mot/ directory.

2) Quantization-Friendly Interface Design

Operator-Level Separation: The MoT routing logic and GEMM logic are decoupled into separate kernels. If future frameworks implement quantization logic using Triton, these kernels can be easily swapped out.

Layer Encapsulation Separation: Each layer wrapper implements a fallback method to the native vLLM GEMM backend. If a future CUDA-based quantized GEMM backend outperforms the quantized MoTLinear, this fallback mechanism can be directly utilized.

3) Hardware/Model-Specific Tuning Mechanism
Automated Static Tuning: The script benchmarks/kernels/mot_linear_benchmarks.py provides two modes: benchmarking and tuning. Launching the script with --tune triggers the tuning mode, which generates the required compilation parameters for the current hardware and model. To accelerate the search for optimal parameter combinations, it leverages Ray for parallel computation in multi-GPU environments and implements a pruning strategy specifically designed for NVIDIA GPUs.

Note: This script has not been tested on ROCm architectures or other NPUs. Currently, running it in non-CUDA environments will raise a NotImplementedError().
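
For reference, a tuning invocation could look like the following (only the --tune flag comes from the description above; the --model and --tp-size flag names are assumed for illustration):

python3 benchmarks/kernels/mot_linear_benchmarks.py \
  --tune \
  --model ByteDance-Seed/BAGEL-7B-MoT \
  --tp-size 1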

Automated Tuning Parameter Loading: For CUDA-architecture hardware, the framework automatically loads the optimal parameters generated by the tuning script. If the configuration file for the current hardware is missing, it explicitly warns the user in the terminal to run tuning, since performance otherwise degrades significantly. In this fallback scenario, the Triton operator uses a sub-optimal configuration combination that avoids compilation errors and register spilling on lower-end GPUs (e.g., T4) but yields poor performance on high-end GPUs (e.g., H100/A100).

Note: For non-CUDA architectures, it defaults to falling back to the original vLLM linear layer implementation. This can be replaced by diffusion-specific linear layers in the future.

Test Plan

Unit Testing: Layer encapsulation unit tests are located in vllm-omni/tests/diffusion/kernels/mot. They test acceleration and accuracy against three baselines across sequence lengths from 1k to 8k (a sketch of the baseline pattern follows this list):

        MoTRMSNorm vs. vLLM's RMSNorm + index_select + index_scatter.

        MoTQKVParallelLinear vs. QKVParallelLinear + index_select + index_scatter.

        MoTRowParallelLinear vs. RowParallelLinear + index_select + index_scatter.
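
As a rough, runnable sketch of this baseline pattern (plain PyTorch only; this is not the tests' actual code, and the weights and shapes are illustrative):

import torch

def rms_norm(x, weight, eps=1e-6):
    # Minimal reference RMSNorm with fp32 accumulation, standing in for vLLM's RMSNorm.
    var = x.float().pow(2).mean(dim=-1, keepdim=True)
    return (x.float() * torch.rsqrt(var + eps)).to(x.dtype) * weight

def baseline_mot_rms_norm(x, text_idx, vae_idx, text_w, vae_w):
    # Baseline the unit tests compare against: gather each modality's rows
    # (index_select), normalize with that modality's weights, then scatter the
    # results back into a single output buffer (index_copy_).
    out = torch.empty_like(x)
    out.index_copy_(0, text_idx, rms_norm(x.index_select(0, text_idx), text_w))
    out.index_copy_(0, vae_idx, rms_norm(x.index_select(0, vae_idx), vae_w))
    return out

# Example: 4096 VAE tokens plus 2 text tokens, hidden size 3584 (Bagel-like shapes).
x = torch.randn(4098, 3584, dtype=torch.bfloat16)
text_idx, vae_idx = torch.tensor([0, 4097]), torch.arange(1, 4097)
w = torch.ones(3584, dtype=torch.bfloat16)
y = baseline_mot_rms_norm(x, text_idx, vae_idx, w, w)

MoTRMSNorm replaces these gather/normalize/scatter round trips with a single fused kernel launch.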

End-to-End (E2E) Testing: Follows the existing benchmarks/diffusion/diffusion_benchmark_serving.py. A minor modification was added allowing the caller to optionally save the generated images (saving is only triggered if --save_dir is passed, otherwise the original behavior is preserved). This allows for simultaneous evaluation of E2E generation quality and speed.

Test Result

Unit Tests

Correctness for MoTRMSNorm (All Passed):

For baseline values <1: Threshold is atol=5e-2 and cos_sim > 0.99.

For baseline values >1: Threshold is rtol=5e-2 and cos_sim > 0.99.

Correctness for MoTQKVParallel & MoTRowParallel (All Passed):

For baseline values <1: Threshold is atol=1e-1 and cos_sim > 0.98.

For baseline values >1: Threshold is rtol=1e-1 and cos_sim > 0.98.
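
The thresholds above amount to a mixed absolute/relative check plus a cosine-similarity gate; a small illustrative helper (not the tests' actual code) could look like this:

import torch

def check_close(out, ref, atol, rtol, min_cos_sim):
    # Absolute tolerance where the baseline magnitude is <1, relative tolerance
    # where it is >=1, plus an overall cosine-similarity gate.
    out32, ref32 = out.float(), ref.float()
    small = ref32.abs() < 1
    abs_ok = (out32[small] - ref32[small]).abs().max() <= atol if small.any() else True
    rel_err = (out32[~small] - ref32[~small]).abs() / ref32[~small].abs()
    rel_ok = rel_err.max() <= rtol if (~small).any() else True
    cos = torch.nn.functional.cosine_similarity(out32.flatten(), ref32.flatten(), dim=0)
    return bool(abs_ok) and bool(rel_ok) and cos.item() > min_cos_sim

# e.g. check_close(mot_out, baseline_out, atol=5e-2, rtol=5e-2, min_cos_sim=0.99)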

Performance for MoTRMSNorm:
Since RMSNorm is a memory-bound operation, reducing memory access overhead significantly accelerates computation. Achieved a 5.52x ~ 1.39x speedup across 1k-8k sequence lengths.

Performance for MoTQKVParallel & MoTRowParallel:
These linear layers are compute-bound but have relatively narrow matrix widths.

MoTQKVParallel: Achieved a 1.51x ~ 1.25x speedup.

MoTRowParallel: Achieved a 1.12x ~ 1.52x speedup.

Note: FFN layers are explicitly compute-bound. Fused operators only show improvements at short sequence lengths. For sequences longer than 4096, performance regresses compared to the native CuBLAS implementation. Therefore, this PR does not introduce MoT operators for the FFN component; the model layer retains the original vLLM linear layer implementation.

End-to-End (E2E) Tests

Overall E2E Speedup: ~1.29x

The framework was launched on a single A100 and tested using the following command:

python3 benchmarks/diffusion/diffusion_benchmark_serving.py \
  --base-url http://localhost:8099 \
  --model ByteDance-Seed/BAGEL-7B-MoT \
  --task t2i \
  --dataset vbench \
  --num-prompts 5 \
  --warmup-num-inference-steps 2

Test Group Results (Average of 3 runs):

================= Serving Benchmark Result =================
Backend:                                 vllm-omni      
Model:                                   ByteDance-Seed/BAGEL-7B-MoT
Dataset:                                 vbench         
Task:                                    t2i            
--------------------------------------------------
Benchmark duration (s):                  129.81 ± 2.28         
Request rate:                            inf            
Max request concurrency:                 1              
Successful requests:                     5/5              
--------------------------------------------------
Request throughput (req/s):              0.04 ± 0.00              
============================================================

Control Group Results (Average of 3 runs):

================= Serving Benchmark Result =================
Backend:                                 vllm-omni      
Model:                                   ByteDance-Seed/BAGEL-7B-MoT
Dataset:                                 vbench         
Task:                                    t2i            
--------------------------------------------------
Benchmark duration (s):                  166.97 ± 4.43         
Request rate:                            inf            
Max request concurrency:                 1              
Successful requests:                     5/5              
--------------------------------------------------
Request throughput (req/s):              0.03 ± 0.00          
============================================================

Generated image for the prompt "a person giving a presentation":
image

Other info was presented at the vLLM-Omni Meeting on 2026-03-06, 11:30 AM (UTC+8):
MoT kernel PR-presentation.pdf

Supplementary Experimental Notes

In some pixel-level test cases, the maximum RGB value difference reached 12, exceeding the tolerance of 10.
To ensure that the Triton kernel's numerical drift does not affect end-to-end generation quality, I conducted a rigorous visual fidelity test comparing this branch against the stable baseline.

Environment & Reproducibility

Test Command:

python3 vllm-omni/examples/offline_inference/bagel/end2end.py \
--modality text2img  \
--model /root/data/hf_cache/ByteDance-Seed/BAGEL-7B-MoT \
--txt-prompts /root/data/prompts.txt \
--output /root/data/triton_e2e_result/0423-main_branch 

Test Prompts:

A cute dog
A cyberpunk city at night
A watercolor mountain landscape
A chef is cooking
A woman is giving a presentation in a meeting
Baseline | This PR
image image
image image
image image
image image
image image

Test Command (img2img):

python3 vllm-omni/examples/offline_inference/bagel/end2end.py \
--modality img2img \
--model /root/data/hf_cache/ByteDance-Seed/BAGEL-7B-MoT \
--image-path /root/data/triton_e2e_result/tree.jpeg \
--prompts \
    "make it oil painting style" \
    "make it into a sketch style" \
--output /root/data/triton_e2e_result/0423-main_branch/i2i 
Baseline | This PR
image image
image image
Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model. Please run mkdocs serve to sync the documentation editions to ./docs.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft.



@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 4691d9a97a

Comment thread vllm_omni/diffusion/models/bagel/bagel_transformer.py Outdated
Comment thread vllm_omni/diffusion/models/bagel/bagel_transformer.py Outdated
Comment thread vllm_omni/diffusion/models/bagel/bagel_transformer.py Outdated
Comment thread vllm_omni/diffusion/models/bagel/bagel_transformer.py Outdated
@yzhu802 yzhu802 force-pushed the enhancement/mot_fused_kernels branch from d55d4bd to aac5553 Compare February 11, 2026 08:01
@yzhu802
Author

yzhu802 commented Feb 11, 2026

Hi @princepride @hsliuustc0106

Following up on our discussion from the 2026-02-04 Meeting regarding RFC #1146, I have implemented the initial fused MoT kernels (Linear & MLP) and verified their correctness. Preliminary benchmarks on A100 show a ~10x speedup compared to the PyTorch native implementation.

I noticed that bagel_transformer.py was recently updated to include vLLM's tensor parallel components (e.g., ColumnParallelLinear, QKVParallelLinear).

To avoid merge conflicts and ensure I don't disrupt the ongoing TP support, I plan to submit the kernels and benchmarks first (Infrastructure only) in my upcoming PR, without modifying the model files yet.

Before I proceed with integration, could you confirm if anyone else is currently working on Bagel performance optimization?

If not, here is my proposed roadmap:

[Current PR] Kernel Primitives: Submit Triton kernels with static tuning for A100.

Deep Fusion: Implement QKV fusion and FFN fusion patterns to further reduce kernel launch overhead.

Quantization Support: Align with the framework’s quantization interfaces for end-to-end testing.

Mainstream Hardware Tuning: Add config support for H100/4090 and refine the tuning scripts (referencing vLLM's MoE design).

Model Integration: Inherit from ColumnParallelLinear / RowParallelLinear to create MoT-aware parallel layers and fully replace the implementation in bagel_transformer.py.

Does this plan align with the current roadmap?

Thanks!

@princepride
Collaborator

Yes, someone is working on CFG parallel.

@princepride
Collaborator

TP has already been merged, and I think it's best to implement the full functionality, because kernels alone are not convenient for e2e testing, and I also believe the kernels will not be affected by the current code.

@yzhu802
Author

yzhu802 commented Feb 11, 2026

TP has already been merged, and I think it's best to implement the full functionality, because kernels alone are not convenient for e2e testing, and I also believe the kernels will not be affected by the current code.

Got it. That clarifies things; I will proceed with implementing the full feature set in this PR. :)

Collaborator

@lishunyang12 lishunyang12 left a comment

Solid kernel work -- the routing/dispatch design is clean and the benchmark numbers are convincing. A few things worth tightening up before this lands.

from vllm.triton_utils import tl, triton


# Default Triton block-size configuration (tuned for A100/A800)
Collaborator

get_best_config is entirely hardcoded for Bagel's exact shapes (3584, 512, 18944). If someone uses MoTLinear with any other dimension, they silently get the fallback config, which may be far from optimal. Would it make sense to add an autotuning path (even a simple triton.autotune) so this generalizes beyond Bagel, or at least log a warning when hitting the fallback?

Author

hi, @lishunyang12
Thank you for your review, and I’m sorry I haven’t been logging in to check messages much recently. That was entirely my fault — I should have marked the PR as WIP at the time.

Yes, you’re absolutely right. In the current version, I have implemented an auto-tuning script (located at benchmarks/kernels/mot_linear_benchmarks.py). In this version, the user only needs to provide the model’s Hugging Face name and the TP size planned for deployment. The script will then retrieve the current hardware information and automatically search for the optimal configuration parameters.

Collaborator

Nice, the auto-tuning + 3-tier config loading is a big improvement over the hardcoded shapes.

pid_m = cur_pid // num_pid_n
pid_n = cur_pid % num_pid_n

# 4. Load indirect indices (Indirect Indexing for A)
Collaborator

The router initializes to the VAE path first, then overwrites to the text path if pid < num_pid_m_text * num_pid_n. This avoids an else branch in Triton -- clever, but could you add a one-line comment explaining this pattern? Without it the initialize-to-VAE-then-maybe-overwrite flow reads like a bug on first glance.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, I can do that. :)

Collaborator

Thanks.

K,
# Strides
A.stride(0),
A.stride(1),
Collaborator

Calling .t() on the weight here (and in every invoke_mot_gemm call in mot_linear.py) produces a non-contiguous view each time. The is_weak_contiguous check handles it, but would it be cleaner to store the weights pre-transposed at init time so you do not pay for the transpose view on every forward call?

Author

Yes, you’re absolutely right. In the current implementation, I use subclasses of vLLM’s QKVParallelLinear, ColumnParallelLinear, and similar classes as wrapper layers. During testing, I found that vLLM already ensures the weights are transposed when they are loaded.

Collaborator

Makes sense, vLLM wrapper layers handle the transpose internally.

# 0=None, 1=W8A8, 2=W8A16, 3=W4A16
quant_type = 0
ACCUMULATOR_DTYPE = tl.float32
COMPUTE_DTYPE = triton_dtype(A.dtype)
Collaborator

use_int4_w4a16 immediately raises NotImplementedError, and _core_weight_only_gemm has a matching raise for WEIGHT_BITS != 8. If W4 is truly not planned for this PR, maybe just remove the use_int4_w4a16 parameter entirely rather than leaving dead code paths in both the launcher and the kernel?

Author

I plan to handle various quantization scenarios after the framework’s quantization work is completed, ensuring that Bagel can stay aligned with other diffusion models. I’m aware that some INT4 quantization approaches for language models involve special memory layouts, but I’m not sure what kind of quantization scheme would be most suitable for Bagel. Perhaps later I can look into related research or experiment on my own to figure this out.

I’m planning to reserve the interfaces here for potential future implementations. Do you think that would be a bad idea?

Collaborator

Fair enough, can revisit when the framework quantization work lands.

# --- Text Path ---
row_idx = tl.load(text_indices_ptr + pid)
weight_ptr = text_weight_ptr
else:
Collaborator

The sum-of-squares accumulation loops over n_cols in BLOCK_SIZE chunks, and BLOCK_SIZE is capped at 4096 in the launcher. For hidden_dim=3584 that is one pass. If hidden_dim ever exceeds 4096, each extra pass re-reads from global memory. Have you profiled whether that matters, or is the assumption that hidden_dim <= 4096 always holds for MoT models?

Author

I am aware of two models that utilize the MoT (Mixture-of-Transformers) architecture: Chameleon and Bagel. Both satisfy the condition of hidden_dim < 4096. Since the optimization overhead for RMSNorm is relatively small, I don't believe it requires a complex tuning mechanism similar to what we use for linear layers (for instance, the RMSNorm operator in vLLM also employs hard-coded, simple rules).

However, we must ensure that we don't encounter compilation errors or register pressure/overflow on lower-end GPUs. Considering that 4096×3×2/1024=24KB, this already places a strain on some entry-level hardware; larger values could lead to direct compilation failures on certain devices.

For future models where hidden_dim > 4096, the computational logic will remain functionally correct—it will simply transition to a looped execution mode.
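
As a concrete illustration of that looped execution mode, a chunked sum-of-squares accumulation in Triton might look like the sketch below (standalone and simplified; the PR's kernel additionally applies the per-path weights and MoT routing):

import triton
import triton.language as tl

@triton.jit
def rms_norm_chunked_kernel(x_ptr, out_ptr, n_cols, eps, BLOCK_SIZE: tl.constexpr):
    # One program per row of a contiguous (rows, n_cols) tensor. When n_cols
    # exceeds BLOCK_SIZE, the loops simply take extra passes (extra global
    # memory reads), which keeps the kernel correct, just slower.
    row_start = tl.program_id(0) * n_cols
    acc = 0.0
    for off in range(0, n_cols, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        mask = cols < n_cols
        x = tl.load(x_ptr + row_start + cols, mask=mask, other=0.0).to(tl.float32)
        acc += tl.sum(x * x, axis=0)
    rstd = 1.0 / tl.sqrt(acc / n_cols + eps)
    for off in range(0, n_cols, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        mask = cols < n_cols
        x = tl.load(x_ptr + row_start + cols, mask=mask, other=0.0).to(tl.float32)
        tl.store(out_ptr + row_start + cols, x * rstd, mask=mask)

# Host-side launch, one program per row, e.g.:
#   rms_norm_chunked_kernel[(x.shape[0],)](x, out, x.shape[1], 1e-6, BLOCK_SIZE=4096)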

Collaborator

Makes sense.


# gen mode – fused MoT Triton kernel
from vllm_omni.diffusion.layers.mot.ops.mot_rms_norm import (
mot_rms_norm,
Collaborator

In forward_cuda, und mode delegates to vLLM's rms_norm, but forward_native uses _rms_norm_native, which casts to fp32 internally. If someone toggles between the native and cuda paths, the numerical behavior will differ slightly since the vLLM kernel may not do the same fp32 upcast. Is that intentional?

Author

Sorry, this MoTLinear was only used by me for testing and reporting intermediate results. I have now wrapped it using vLLM’s layers such as ColumnParallelLinear and others.

Collaborator

Good.

C = torch.empty(x.size(0), self.out_features, dtype=x.dtype, device=x.device)

text_bias = self.text_linear.bias.data if self.text_linear.bias is not None else None
vae_bias = self.vae_linear.bias.data if self.vae_linear.bias is not None else None
Collaborator

In und mode (text_indices is None) you fall through to self.text_linear(x) which is plain nn.Linear. So und mode never benefits from any kernel fusion. Is that fine performance-wise?

Author

Sorry again, this MoTLinear was only used by me for testing and reporting intermediate results. QAQ The current version has been refactored. You may take another look.

Collaborator

Looks good, falling through to super().forward() is clean.

shared_kwargs = dict(
bias_text=None,
bias_vae=None,
text_indices=text_indices,
Collaborator

forward_cuda does 3 separate invoke_mot_gemm calls (gate, up, down). Have you considered fusing gate+up into a single kernel launch since they share the same input x and indices? I see this is in your roadmap as Deep Fusion, just curious about timeline.

Author

I have attempted to fuse the FFN (Feed-Forward Network) layers; however, I found that because the FFN is inherently and heavily compute-bound, the gains from fusion are marginal. In fact, there is even a performance regression during long-sequence processing. To be clear, this regression is relative to vLLM's MergedColumnParallelLinear combined with its multi-backend GEMM operator dispatching, rather than a standard PyTorch implementation.

As a result, I have decided to retain the original approach, utilizing MergedColumnParallelLinear and RowParallelLinear to implement the FFN. You can review the details in the latest commit. :)

Collaborator

Makes sense, not worth the complexity if gains are marginal/negative.

Comment thread tests/kernels/test_mot_gemm.py Outdated
import copy

layer_bf16 = copy.deepcopy(layer)
for p in layer_bf16.parameters():
Collaborator

The correctness tests compare fp32-native vs bf16-triton, so the error threshold is really measuring precision loss from bf16 more than kernel correctness. Would it be worth adding a bf16-native vs bf16-triton test too? That would isolate kernel correctness from precision differences.

Author

The correctness comparison in the new version is between vLLM’s linear layer implementations (such as ColumnParallelLinear) and my corresponding MoT linear layer implementation. You may take another look at the latest commit. :)

Collaborator

Nice improvement, comparing vLLM linear vs MoT linear isolates kernel correctness much better.

Comment thread benchmarks/kernels/mot_benchmarks.py Outdated

def _make_indices(M: int, text_ratio: float = 0.05, device: str = "cuda"):
"""Split M rows into text / vae index tensors."""
M_text = max(1, int(M * text_ratio))
Collaborator

The benchmark hardcodes text_ratio=0.05 (95% VAE tokens). Would be good to also benchmark with a higher text ratio (e.g. 0.5) to check whether the kernel degrades when the two expert paths are more balanced.

Author

I looked into how Bagel works. Essentially, when generating images, the mixed tokens look like this:
<img_start> ......some image tokens..... <img_end>
<img_start> ......some image tokens..... <img_end>
This means the text_ratio is almost impossible to exceed 0.01. In the framework-provided end-to-end tests, 4096 image tokens correspond to only 2 text tokens.

Collaborator

Good context, the ~0.5-1.5% actual text ratio explains it.

@hsliuustc0106
Collaborator

@vllm-omni-reviewer

@yzhu802 yzhu802 force-pushed the enhancement/mot_fused_kernels branch from aac5553 to 8f541f3 Compare March 4, 2026 09:55
@hsliuustc0106
Collaborator

can you provide more acc tests?

@yzhu802
Author

yzhu802 commented Mar 4, 2026

can you provide more acc tests?

Hi @hsliuustc0106
By "more acc test" you mean tests cases with more strict tolerance or more diverse test cases? or both?

BTW I tested again just now and found rebase on the main branch make my weight loading fails.... Let me look into this issue tomorrow. QAQ

Collaborator

@lishunyang12 lishunyang12 left a comment

Please resolve the DCO check and replace the Chinese comments in the code.

@hsliuustc0106
Collaborator

hsliuustc0106 commented Mar 4, 2026

can you provide more acc tests?

Hi @hsliuustc0106 By "more acc tests" do you mean test cases with stricter tolerances or more diverse test cases, or both?

BTW, I tested again just now and found that rebasing on the main branch makes my weight loading fail... Let me look into this issue tomorrow. QAQ

Yes, we have recently been considering accuracy tests for diffusion models by collecting some representative examples. We might need to compare them at the pixel level against runs without any optimization.

cc @wtomin @princepride

@yzhu802 yzhu802 force-pushed the enhancement/mot_fused_kernels branch from 6879e20 to 2882c3b Compare March 5, 2026 04:07
@yzhu802
Author

yzhu802 commented Mar 5, 2026

Please resolve the DCO check and replace the Chinese comments in the code.

DCO resolved and Chinese comments removed.

@yzhu802
Author

yzhu802 commented Mar 5, 2026

can you provide more acc tests?

Hi @hsliuustc0106 By "more acc tests" do you mean test cases with stricter tolerances or more diverse test cases, or both?
BTW, I tested again just now and found that rebasing on the main branch makes my weight loading fail... Let me look into this issue tomorrow. QAQ

Yes, we have recently been considering accuracy tests for diffusion models by collecting some representative examples. We might need to compare them at the pixel level against runs without any optimization.

cc @wtomin @princepride

Thanks for the clarification! Regarding the acc tests and the current progress, I have a few updates to share:

About the previous error: I just fixed the weight loading issue caused by the rebase. The operator now works fine with the main branch.

Potential issue on the main branch: During my testing just now, I noticed some abnormal behavior on the main branch. After the rebase, although the model runs and generates images, it no longer seems to follow prompts (it generates the exact same image regardless of the input). This feels like a bug related to CFG-parallel or KV transfer (I noticed those two commits were merged recently). We might need to verify this issue before proceeding with the pixel-level acc tests as a baseline.

Regarding the acc tests: I totally agree with adding pixel-level comparison tests. However, this sounds more like a framework-level testing infrastructure rather than a test specific to this single operator. The implementation I have in mind is a standalone script (e.g., bash + python) that automatically switches between the PR branch and the main branch, runs the same tests, and compares the pixel-level differences of the saved images. Since this involves developing a general testing pipeline, it might not be appropriate to bundle it within this specific PR. Should we consider opening a separate Issue or PR for this testing infrastructure?

After rebase, prompt "a person giving a presentation to a room full of colleagues" (and basically all prompts) now generate this:
image

@princepride
Collaborator

This result means the kv cache transfer failed.

@yzhu802
Author

yzhu802 commented Mar 6, 2026

This result means the kv cache transfer failed.

Oh, OK. What should I do now to get this PR merged?

@princepride
Collaborator

I think you need to figure out how to solve this error; the original code doesn't have it.

@yzhu802 yzhu802 force-pushed the enhancement/mot_fused_kernels branch from 2882c3b to 1eb55b8 Compare March 10, 2026 11:01
@yzhu802
Author

yzhu802 commented Mar 10, 2026

@princepride Hi maintainer, I just rebased on the newest code; can the PR be merged now? QAQ

@princepride
Collaborator

I will review it later

get_best_mot_config,
)

pytestmark = [pytest.mark.core_model, pytest.mark.diffusion, pytest.mark.gpu]
Collaborator

@congw729 I added these labels to the kernel tests. What do you think?

Collaborator

If you do not need multiple cards, it is fine.

@princepride
Collaborator

@ZJY0516 PTAL

@princepride princepride requested a review from ZJY0516 March 10, 2026 15:25
bias=True,
params_dtype=dcfg.torch_dtype,
disable_tp=True,
).cuda()
Member

We'd better avoid hard-coding this.

Author

ok, I can move bias to @pytest.mark.parametrize().

logger = logging.getLogger(__name__)

# =====================================================================
# MoT GEMM Config Loading (3-tier: env → built-in → default)
Member

We need a doc for this

Author

OK. The main documentation related to configuration loading for the MoE operator in vLLM is here:
vllm/model_executor/layers/fused_moe/configs/README.md
This document explains what the JSON files in that directory are and how they are generated.

I could design a similar document under the following directory:
vllm_omni/diffusion/layers/mot/configs/

Or would you prefer that I place the documentation under docs/user_guide/diffusion/ instead?
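
For the doc, a minimal sketch of the 3-tier lookup could look like this (the env-var name, file layout, and default values below are placeholders; only the env -> built-in -> default order is taken from the code above):

import json
import os

# Conservative defaults: avoid register spilling on low-end GPUs, but performance
# on A100/H100 will be poor until tuning is run (the loader should warn here).
DEFAULT_CONFIG = {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32, "num_warps": 4}

def load_mot_gemm_config(config_file_name, builtin_dir):
    # 1) User-supplied directory via an environment variable (name is hypothetical).
    # 2) Built-in configs shipped with the package (generated by the tuning script).
    env_dir = os.environ.get("VLLM_OMNI_MOT_CONFIG_DIR")
    for directory in (env_dir, builtin_dir):
        if directory:
            path = os.path.join(directory, config_file_name)
            if os.path.isfile(path):
                with open(path) as f:
                    return json.load(f)
    # 3) Fallback defaults.
    return DEFAULT_CONFIG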

# =================================================================


# Core A: Standard GEMM (for BF16/FP16 and W8A8)
Member

Little question: why do we need to implement a standard GEMM?

Author

@yzhu802 yzhu802 Mar 11, 2026

Strictly speaking, this is not a standard GEMM, but rather the second half of one. In a standard GEMM kernel, the prologue typically computes pid_m and pid_n from pid. However, Bagel uses a specialized MoT routing scheme, so the mapping from pid to (pid_m, pid_n) is different.

If standard GEMM and quantized GEMM can separate the tiling logic from the main compute loop into two kernels, then Bagel could potentially reuse the latter part of their implementation.
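
A plain-Python illustration of that routing (not the kernel code itself; in a standard GEMM the decomposition below would be applied to the whole grid without the path branch):

def route_pid(pid, num_pid_m_text, num_pid_n):
    # Text tiles occupy the first num_pid_m_text * num_pid_n program ids,
    # VAE tiles occupy the rest; each path then decomposes its local pid
    # into (pid_m, pid_n) exactly like a standard GEMM would.
    text_grid = num_pid_m_text * num_pid_n
    if pid < text_grid:
        path, local_pid = "text", pid
    else:
        path, local_pid = "vae", pid - text_grid
    return path, local_pid // num_pid_n, local_pid % num_pid_n

# e.g. route_pid(7, num_pid_m_text=1, num_pid_n=4) -> ("vae", 0, 3)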

Comment thread vllm_omni/diffusion/layers/mot/ops/mot_rms_norm.py

from vllm_omni.diffusion.layers.mot.mot_layernorm import MoTRMSNorm

pytestmark = [pytest.mark.core_model, pytest.mark.diffusion, pytest.mark.gpu]
Collaborator

@congw729 congw729 Mar 11, 2026

How many cards do you need for this test? We highly recommend adding the platform mark @pytest.mark.gpu together with the decorator @hardware_tests. To state the computing resources and the number of cards together, please check https://docs.vllm.ai/projects/vllm-omni/en/latest/contributing/ci/tests_markers/

Author

How many cards do you need for this test? We highly recommend adding the platform mark @pytest.mark.gpu together with the decorator @hardware_tests. To state the computing resources and the number of cards together, please check https://docs.vllm.ai/projects/vllm-omni/en/latest/contributing/ci/tests_markers/

This test only needs 1 GPU. I've tested tp=2 for all-reduce correctness but decided not to add that test case to the PR, since it needs to be run with torchrun --nproc_per_node=2 -m pytest <script_name>, which feels like it would introduce some instability into the framework CI.

@Gaohan123 Gaohan123 added this to the v0.18.0 milestone Mar 17, 2026
@Gaohan123 Gaohan123 modified the milestones: v0.18.0, v0.20.0 Apr 14, 2026
@timzsu
Contributor

timzsu commented Apr 18, 2026

@princepride Since this PR has been inactive over the past month, I would like to take on this PR and complete it.

@princepride
Collaborator

@princepride Since this PR has been inactive over the past month, I would like to take on this PR and complete it.

Sure, you can continue with it! I think you can cherry-pick the patch.

timzsu added a commit to timzsu/vllm-omni that referenced this pull request Apr 18, 2026
Rebase MoT fused kernels (PR vllm-project#1328) onto current main and fix issues:

- Rewrite _forward_sp_gen() to use MoT unified API instead of deleted
  *_moe_gen layers (qkv_proj_moe_gen, o_proj_moe_gen, q_norm_moe_gen,
  k_norm_moe_gen), which caused AttributeError when SP was active
- Parameterize bias in MoT kernel tests (ZJY0516 review feedback)
- Add configs/ directory with README documenting the 3-tier config
  loading mechanism (ZJY0516 review feedback)

Signed-off-by: Zhengyuan Su <su.zhengyuan@u.nus.edu>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@yzhu802 yzhu802 force-pushed the enhancement/mot_fused_kernels branch from 0b0419a to efc7f81 Compare April 19, 2026 05:25
Yufeng Zhu and others added 6 commits April 19, 2026 13:45
Signed-off-by: Yufeng Zhu <yzhu802@gatech.edu>
Signed-off-by: 汪志鹏 <wangzhipeng628@gmail.com>
Signed-off-by: 汪志鹏 <wangzhipeng628@gmail.com>
Rebase MoT fused kernels (PR vllm-project#1328) onto current main and fix issues:

- Rewrite _forward_sp_gen() to use MoT unified API instead of deleted
  *_moe_gen layers (qkv_proj_moe_gen, o_proj_moe_gen, q_norm_moe_gen,
  k_norm_moe_gen), which caused AttributeError when SP was active
- Parameterize bias in MoT kernel tests (ZJY0516 review feedback)
- Add configs/ directory with README documenting the 3-tier config
  loading mechanism (ZJY0516 review feedback)

Signed-off-by: Zhengyuan Su <su.zhengyuan@u.nus.edu>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Yufeng Zhu <yzhu802@gatech.edu>
Signed-off-by: Yufeng Zhu <yzhu802@gatech.edu>
@yzhu802 yzhu802 force-pushed the enhancement/mot_fused_kernels branch from efc7f81 to 35f4fed Compare April 19, 2026 05:47
Contributor

@timzsu timzsu left a comment

Please provide some sample images to justify the test relaxation.


PIXEL_TOLERANCE = 10
# Relax pixel tolerance to 15 to account for Triton kernel numerical drift
PIXEL_TOLERANCE = 15
Contributor

@yzhu802 To justify this relaxation, can you post some sample output images to ensure that this deviation does not damage the output quality?

Author

I can do this tomorrow evening. Sorry, I’ve been really busy lately.

Author

@princepride @lishunyang12 @timzsu
The PR updates some T2I and I2I examples, and it is difficult to notice any visible differences with the naked eye. In addition, I observed that Bagel is not particularly good at image-to-image tasks, so it is natural that the pixel difference after operator fusion is larger compared to text-to-image tasks.

@Gaohan123
Collaborator

@yzhu802 Could you please provide some sample images? Thanks!

@Gaohan123 Gaohan123 added the ready (trigger buildkite CI) and diffusion-x2iat-test (trigger buildkite x2image + x2audio + x2text series of diffusion model tests in nightly CI) labels Apr 30, 2026
@Gaohan123 Gaohan123 modified the milestones: v0.20.0, v0.22.0 May 9, 2026

Labels

diffusion-x2iat-test: label to trigger the buildkite x2image + x2audio + x2text series of diffusion model tests in nightly CI
ready: label to trigger buildkite CI

Projects

None yet

Development

Successfully merging this pull request may close these issues.

8 participants