[Enhancement] Introduce High-Performance MoT (Mixture-of-Tokens) Kernels: Triton Implementation & A100 Tuning#1328
Conversation
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 4691d9a97a
Force-pushed d55d4bd to aac5553.
Hi @princepride @hsliuustc0106

Following up on our discussion from the 2026-02-04 Meeting regarding RFC #1146: I have implemented the initial fused MoT kernels (Linear & MLP) and verified their correctness. Preliminary benchmarks on A100 show a ~10x speedup over the PyTorch native implementation.

I noticed that bagel_transformer.py was recently updated to include vLLM's tensor-parallel components (e.g., ColumnParallelLinear, QKVParallelLinear). To avoid merge conflicts and to not disrupt the ongoing TP support, I plan to submit the kernels and benchmarks first (infrastructure only) in my upcoming PR, without modifying the model files yet.

Before I proceed with integration, could you confirm whether anyone else is currently working on Bagel performance optimization? If not, here is my proposed roadmap: Does this plan align with the current roadmap? Thanks!
Yes, someone is working on CFG parallel.
TP has already been merged, and I think it's best to implement the full functionality, because kernels alone are not convenient for e2e testing. I also believe the kernels will not be affected by the current code.
Got it, that clarifies things. I will proceed with implementing the full feature set in this PR. :)
lishunyang12
left a comment
Solid kernel work -- the routing/dispatch design is clean and the benchmark numbers are convincing. A few things worth tightening up before this lands.
from vllm.triton_utils import tl, triton
...
# Default Triton block-size configuration (tuned for A100 / A800)
get_best_config is entirely hardcoded for Bagel exact shapes (3584, 512, 18944). If someone uses MoTLinear with any other dimension they silently get the fallback config which may be far from optimal. Would it make sense to add an autotuning path (even a simple triton.autotune) so this generalizes beyond Bagel, or at least log a warning when hitting the fallback?
Hi @lishunyang12,
Thank you for your review, and I'm sorry I haven't been logging in to check messages much recently. That was entirely my fault; I should have marked the PR as WIP at the time.
Yes, you’re absolutely right. In the current version, I have implemented an auto-tuning script (located at benchmarks/kernels/mot_linear_benchmarks.py). In this version, the user only needs to provide the model’s Hugging Face name and the TP size planned for deployment. The script will then retrieve the current hardware information and automatically search for the optimal configuration parameters.
Nice, the auto-tuning + 3-tier config loading is a big improvement over the hardcoded shapes.
pid_m = cur_pid // num_pid_n
pid_n = cur_pid % num_pid_n
...
# 4. Load indirect indices (indirect indexing for A)
The router initializes to the VAE path first, then overwrites to the text path if pid < num_pid_m_text * num_pid_n. This avoids an else branch in Triton -- clever, but could you add a one-line comment explaining this pattern? Without it the initialize-to-VAE-then-maybe-overwrite flow reads like a bug on first glance.
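For reference, the initialize-to-VAE-then-maybe-overwrite pattern can be emulated in plain Python like this (names are illustrative; the real kernel does this with Triton scalars, not Python objects):

```python
def route_pid(pid, num_pid_m_text, num_pid_n, text_indices, vae_indices):
    num_text_tiles = num_pid_m_text * num_pid_n
    # Initialize to the VAE path unconditionally...
    local_pid = pid - num_text_tiles
    indices = vae_indices
    path = "vae"
    # ...then overwrite with the text path when pid falls in the text range.
    # This mirrors the kernel's init-then-overwrite trick, which avoids a
    # divergent `else` branch in Triton.
    if pid < num_text_tiles:
        local_pid = pid
        indices = text_indices
        path = "text"
    pid_m = local_pid // num_pid_n
    pid_n = local_pid % num_pid_n
    return path, indices[pid_m], pid_n
```

A one-line comment to this effect in the kernel would make the flow much easier to audit.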
K,
# Strides
A.stride(0),
A.stride(1),
Calling .t() on the weight here (and in every invoke_mot_gemm call in mot_linear.py) produces a non-contiguous view each time. The is_weak_contiguous check handles it, but would it be cleaner to store the weights pre-transposed at init time so you do not pay for the transpose view on every forward call?
Yes, you’re absolutely right. In the current implementation, I use subclasses of vLLM’s QKVParallelLinear, ColumnParallelLinear, and similar classes as wrapper layers. During testing, I found that vLLM already ensures the weights are transposed when they are loaded.
Makes sense, vLLM wrapper layers handle the transpose internally.
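For anyone following the thread: the pre-transpose suggestion can be sketched with numpy (standing in for torch; class and variable names are hypothetical). Storing `W^T` contiguously once at init avoids materializing a non-contiguous view on every forward:

```python
import numpy as np

class Linearish:
    """Illustrative layer that stores the weight pre-transposed at init."""

    def __init__(self, weight: np.ndarray):
        # weight has nn.Linear layout (out_features, in_features);
        # store W^T contiguously once instead of transposing per call.
        self.weight_t = np.ascontiguousarray(weight.T)

    def forward(self, x: np.ndarray) -> np.ndarray:
        return x @ self.weight_t

w = np.arange(6, dtype=np.float32).reshape(2, 3)   # (out=2, in=3)
layer = Linearish(w)
x = np.ones((1, 3), dtype=np.float32)
y = layer.forward(x)
# The ad-hoc transpose is a non-contiguous view; the stored copy is not.
assert not w.T.flags["C_CONTIGUOUS"]
assert layer.weight_t.flags["C_CONTIGUOUS"]
```

As noted in the reply, vLLM's wrapper layers already hand over weights in the right layout, so this is moot for the current PR.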
# 0=None, 1=W8A8, 2=W8A16, 3=W4A16
quant_type = 0
ACCUMULATOR_DTYPE = tl.float32
COMPUTE_DTYPE = triton_dtype(A.dtype)
use_int4_w4a16 immediately raises NotImplementedError, and _core_weight_only_gemm has a matching raise for WEIGHT_BITS != 8. If W4 is truly not planned for this PR, maybe just remove the use_int4_w4a16 parameter entirely rather than leaving dead code paths in both the launcher and the kernel?
I plan to handle various quantization scenarios after the framework’s quantization work is completed, ensuring that Bagel can stay aligned with other diffusion models. I’m aware that some INT4 quantization approaches for language models involve special memory layouts, but I’m not sure what kind of quantization scheme would be most suitable for Bagel. Perhaps later I can look into related research or experiment on my own to figure this out.
I’m planning to reserve the interfaces here for potential future implementations. Do you think that would be a bad idea?
Fair enough, can revisit when the framework quantization work lands.
# --- Text Path ---
row_idx = tl.load(text_indices_ptr + pid)
weight_ptr = text_weight_ptr
else:
The sum-of-squares accumulation loops over n_cols in BLOCK_SIZE chunks, and BLOCK_SIZE is capped at 4096 in the launcher. For hidden_dim=3584 that is one pass. If hidden_dim ever exceeds 4096, each extra pass re-reads from global memory. Have you profiled whether that matters, or is the assumption that hidden_dim <= 4096 always holds for MoT models?
I am aware of two models that utilize the MoT (Mixture-of-Transformers) architecture: Chameleon and Bagel. Both satisfy the condition of hidden_dim < 4096. Since the optimization overhead for RMSNorm is relatively small, I don't believe it requires a complex tuning mechanism similar to what we use for linear layers (for instance, the RMSNorm operator in vLLM also employs hard-coded, simple rules).
However, we must ensure that we don't encounter compilation errors or register pressure/overflow on lower-end GPUs. Considering that 4096×3×2/1024=24KB, this already places a strain on some entry-level hardware; larger values could lead to direct compilation failures on certain devices.
For future models where hidden_dim > 4096, the computational logic will remain functionally correct—it will simply transition to a looped execution mode.
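The looped-execution claim can be checked with a small numpy reference (a sketch of the chunked sum-of-squares accumulation, not the actual Triton kernel):

```python
import numpy as np

def rms_norm_chunked(x, weight, block_size=4096, eps=1e-6):
    """Reference sketch of BLOCK_SIZE-chunked sum-of-squares accumulation.

    For n_cols <= block_size this is a single pass; larger hidden dims
    simply loop, re-reading each chunk once. The result is functionally
    identical either way.
    """
    n_cols = x.shape[-1]
    acc = np.zeros(x.shape[:-1], dtype=np.float64)
    for start in range(0, n_cols, block_size):
        chunk = x[..., start:start + block_size].astype(np.float64)
        acc += np.sum(chunk * chunk, axis=-1)
    inv_rms = 1.0 / np.sqrt(acc / n_cols + eps)
    return (x * inv_rms[..., None] * weight).astype(x.dtype)

x = np.random.default_rng(0).normal(size=(2, 5000)).astype(np.float32)
w = np.ones(5000, dtype=np.float32)
# hidden_dim=5000 > 4096 forces the looped mode; results still match
# the single-pass (block_size >= n_cols) execution.
assert np.allclose(rms_norm_chunked(x, w, 4096), rms_norm_chunked(x, w, 8192))
```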
# gen mode – fused MoT Triton kernel
from vllm_omni.diffusion.layers.mot.ops.mot_rms_norm import (
    mot_rms_norm,
In forward_cuda (und mode) you delegate to the vLLM rms_norm, but forward_native uses _rms_norm_native, which casts to fp32 internally. If someone toggles between the native and CUDA paths, the numerical behavior will differ slightly, since the vLLM kernel may not do the same fp32 upcast. Is that intentional?
Sorry, this MoTLinear was only used by me for testing and reporting intermediate results. I have now wrapped it using vLLM’s layers such as ColumnParallelLinear and others.
C = torch.empty(x.size(0), self.out_features, dtype=x.dtype, device=x.device)
...
text_bias = self.text_linear.bias.data if self.text_linear.bias is not None else None
vae_bias = self.vae_linear.bias.data if self.vae_linear.bias is not None else None
In und mode (text_indices is None) you fall through to self.text_linear(x) which is plain nn.Linear. So und mode never benefits from any kernel fusion. Is that fine performance-wise?
Sorry again, this MoTLinear was only used by me for testing and reporting intermediate results. QAQ The current version has been refactored. You may take another look.
Looks good, falling through to super().forward() is clean.
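The fall-through dispatch described here can be sketched in a few lines (class and return values are hypothetical stand-ins; the real layer subclasses vLLM's parallel linear layers):

```python
class BaseLinear:
    """Stand-in for a vLLM parallel linear layer."""

    def forward(self, x):
        return ("vllm_gemm", x)

class MoTLinearSketch(BaseLinear):
    def forward(self, x, text_indices=None, vae_indices=None):
        if text_indices is None:
            # und mode: no text/VAE split, so just use the parent's
            # (vLLM) forward instead of the fused kernel.
            return super().forward(x)
        # gen mode: dispatch to the fused MoT Triton path.
        return ("fused_mot_gemm", x)
```

This keeps the und path on vLLM's already-optimized GEMM backends while the gen path gets the fused routing.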
shared_kwargs = dict(
    bias_text=None,
    bias_vae=None,
    text_indices=text_indices,
forward_cuda does 3 separate invoke_mot_gemm calls (gate, up, down). Have you considered fusing gate+up into a single kernel launch since they share the same input x and indices? I see this is in your roadmap as Deep Fusion, just curious about timeline.
I have attempted to fuse the FFN (Feed-Forward Network) layers; however, I found that because the FFN is inherently and heavily compute-bound, the gains from fusion are marginal. In fact, there is even a performance regression during long-sequence processing—to be clear, this regression is relative to vLLM's MergeColumnParallelLinear combined with its multi-backend GEMM operator dispatching, rather than a standard PyTorch implementation.
As a result, I have decided to retain the original approach, utilizing MergeColumnParallelLinear and RowParallelLinear to implement the FFN. You can review the details in the latest commit. :)
Makes sense, not worth the complexity if gains are marginal/negative.
import copy
...
layer_bf16 = copy.deepcopy(layer)
for p in layer_bf16.parameters():
The correctness tests compare fp32-native vs bf16-triton, so the error threshold is really measuring precision loss from bf16 more than kernel correctness. Would it be worth adding a bf16-native vs bf16-triton test too? That would isolate kernel correctness from precision differences.
The correctness comparison in the new version is between vLLM’s linear layer implementations (such as ColumnParallelLinear) and my corresponding MoT linear layer implementation. You may take another look at the latest commit. :)
Nice improvement, comparing vLLM linear vs MoT linear isolates kernel correctness much better.
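To make the original concern concrete, here is a small numpy illustration (float16 standing in for bf16; this is not the PR's test code). A cross-dtype comparison conflates precision loss with kernel error, while a same-dtype comparison isolates the kernel:

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.normal(size=(64, 64)).astype(np.float32)
b = rng.normal(size=(64, 64)).astype(np.float32)

# "Kernel" path: low-precision matmul (fp16 standing in for bf16-triton).
kernel_out = (a.astype(np.float16) @ b.astype(np.float16)).astype(np.float32)

# Cross-dtype reference (fp32 native): the gap mixes kernel error with
# precision loss from the downcast, so it overstates kernel error.
cross_err = np.abs(a @ b - kernel_out).max()

# Same-dtype reference: identical low-precision math, so any residual
# would be pure kernel error (here trivially zero, since both sides
# run the same computation).
same_ref = (a.astype(np.float16) @ b.astype(np.float16)).astype(np.float32)
same_err = np.abs(same_ref - kernel_out).max()
```

Comparing vLLM's own linear layers against the MoT layers at the same dtype, as the new version does, achieves exactly this isolation.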
def _make_indices(M: int, text_ratio: float = 0.05, device: str = "cuda"):
    """Split M rows into text / vae index tensors."""
    M_text = max(1, int(M * text_ratio))
The benchmark hardcodes text_ratio=0.05 (95% VAE tokens). Would be good to also benchmark with a higher text ratio (e.g. 0.5) to check whether the kernel degrades when the two expert paths are more balanced.
I looked into how Bagel works. Essentially, when generating images, the mixed tokens look like this:
<img_start> ......some image tokens..... <img_end>
<img_start> ......some image tokens..... <img_end>
This means the text_ratio is almost impossible to exceed 0.01. In the framework-provided end-to-end tests, 4096 image tokens correspond to only 2 text tokens.
Good context, the ~0.5-1.5% actual text ratio explains it.
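The arithmetic behind these figures is simple enough to state explicitly (a sketch; `text_ratio` here is an illustrative helper, not code from the PR):

```python
def text_ratio(num_text: int, num_image: int) -> float:
    """Fraction of tokens routed to the text expert in a mixed batch."""
    return num_text / (num_text + num_image)

# With the framework e2e figures quoted above (4096 image tokens per
# 2 text tokens), the realistic ratio is far below the benchmark's
# text_ratio=0.05 default, let alone a balanced 0.5 split.
r = text_ratio(2, 4096)
```

So the 0.05 benchmark default is already pessimistic relative to real Bagel workloads.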
@vllm-omni-reviewer
Force-pushed aac5553 to 8f541f3.
Can you provide more accuracy tests?
Hi @hsliuustc0106, BTW I tested again just now and found that rebasing on the main branch makes my weight loading fail... Let me look into this issue tomorrow. QAQ
Yes, recently we have been considering accuracy tests for diffusion models by collecting some representative examples. We might need to compare them at the pixel level against outputs without any optimization.
Force-pushed 6879e20 to 2882c3b.
DCO resolved and Chinese comments removed.
Thanks for the clarification! Regarding the acc tests and the current progress, I have a few updates to share:

About the previous error: I just fixed the weight loading issue caused by the rebase. The operator now works fine with the main branch.

Potential issue on the main branch: During my testing just now, I noticed some abnormal behavior on the main branch. After the rebase, although the model runs and generates images, it no longer seems to follow prompts (it generates the exact same image regardless of the input). This feels like a bug related to CFG-parallel or KV transfer (I noticed those two commits were merged recently). We might need to verify this issue before proceeding with the pixel-level acc tests as a baseline.

Regarding the acc tests: I totally agree with adding pixel-level comparison tests. However, this sounds more like framework-level testing infrastructure than a test specific to this single operator. The implementation I have in mind is a standalone script (e.g., bash + python) that automatically switches between the PR branch and the main branch, runs the same tests, and compares the pixel-level differences of the saved images. Since this involves developing a general testing pipeline, it might not be appropriate to bundle it within this specific PR. Should we consider opening a separate issue or PR for this testing infrastructure?

After the rebase, the prompt "a person giving a presentation to a room full of colleagues" (and basically all prompts) now generates this:
This result means the KV cache transfer failed.
Oh, OK. What should I do now to get this PR merged?
I think you need to figure out how to solve this error; the original code didn't have it.
Force-pushed 2882c3b to 1eb55b8.
@princepride Hi maintainer, I just rebased on the newest code; can the PR be merged now? QAQ
I will review it later.
    get_best_mot_config,
)
...
pytestmark = [pytest.mark.core_model, pytest.mark.diffusion, pytest.mark.gpu]
@congw729 I added these labels to the kernel tests. What do you think?
If you do not need multiple cards, it is fine.
@ZJY0516 PTAL
    bias=True,
    params_dtype=dcfg.torch_dtype,
    disable_tp=True,
).cuda()
ok, I can move bias to @pytest.mark.parametrize().
logger = logging.getLogger(__name__)
...
# =====================================================================
# MoT GEMM Config Loading (3-tier: env → built-in → default)
OK. The main documentation related to configuration loading for the MoE operator in vLLM is here:
vllm/model_executor/layers/fused_moe/configs/README.md
This document explains what the JSON files in that directory are and how they are generated.
I could design a similar document under the following directory:
vllm_omni/diffusion/layers/mot/configs/
Or would you prefer that I place the documentation under docs/user_guide/diffusion/ instead?
# =================================================================
...
# Core A: Standard GEMM (for BF16/FP16 and W8A8)
Small question: why do we need to implement a standard GEMM?
Strictly speaking, this is not a standard GEMM but rather the second half of one. In a standard GEMM kernel, the prologue computes pid_m and pid_n from pid. However, Bagel uses a specialized MoT routing scheme, so the mapping from pid to (pid_m, pid_n) is different.
If standard GEMM and quantized GEMM can separate the tiling logic from the main compute loop into two kernels, then Bagel could potentially reuse the latter part of their implementation.
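The separation being described can be illustrated with a toy tiled matmul (numpy stand-in; names are hypothetical): the *scheduling* function maps pid to a tile, and the *compute* function is the reusable "second half" that only sees tile coordinates:

```python
import numpy as np

def standard_schedule(pid, num_pid_n):
    """Standard GEMM prologue: map a flat pid to (pid_m, pid_n).

    A MoT-specific scheduler would replace only this function and
    could reuse compute_tile unchanged.
    """
    return pid // num_pid_n, pid % num_pid_n

def compute_tile(A, B, pid_m, pid_n, BM, BN):
    """The tile-compute core: accumulate one (BM, BN) output tile."""
    return A[pid_m * BM:(pid_m + 1) * BM, :] @ B[:, pid_n * BN:(pid_n + 1) * BN]

A = np.arange(16, dtype=np.float32).reshape(4, 4)
B = np.eye(4, dtype=np.float32)
BM = BN = 2
num_pid_n = A.shape[1] // BN
C = np.zeros_like(A)
for pid in range((A.shape[0] // BM) * num_pid_n):
    m, n = standard_schedule(pid, num_pid_n)
    C[m * BM:(m + 1) * BM, n * BN:(n + 1) * BN] = compute_tile(A, B, m, n, BM, BN)
```

If upstream kernels adopted this split, the MoT router would only need to swap the scheduling half.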
from vllm_omni.diffusion.layers.mot.mot_layernorm import MoTRMSNorm
...
pytestmark = [pytest.mark.core_model, pytest.mark.diffusion, pytest.mark.gpu]
How many cards do you need for this test? We highly recommend adding the platform mark @pytest.mark.gpu via the @hardware_tests decorator. To state the computing resources and the number of cards together, please check https://docs.vllm.ai/projects/vllm-omni/en/latest/contributing/ci/tests_markers/
This test only needs 1 GPU. I've tested tp=2 for all-reduce correctness but decided not to add that test case to the PR, since it needs to be run via torchrun --nproc_per_node=2 -m pytest <script_name>, which feels like it would introduce some instability into the framework CI.
@princepride Since this PR has been inactive over the past month, I would like to take over this PR and complete it.
Sure, you can continue it! I think you can cherry-pick the patch.
Rebase MoT fused kernels (PR vllm-project#1328) onto current main and fix issues:
- Rewrite _forward_sp_gen() to use the MoT unified API instead of the deleted *_moe_gen layers (qkv_proj_moe_gen, o_proj_moe_gen, q_norm_moe_gen, k_norm_moe_gen), which caused an AttributeError when SP was active
- Parameterize bias in MoT kernel tests (ZJY0516 review feedback)
- Add configs/ directory with a README documenting the 3-tier config loading mechanism (ZJY0516 review feedback)

Signed-off-by: Zhengyuan Su <su.zhengyuan@u.nus.edu>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Force-pushed 0b0419a to efc7f81.
Signed-off-by: Yufeng Zhu <yzhu802@gatech.edu>
Signed-off-by: 汪志鹏 <wangzhipeng628@gmail.com>
Signed-off-by: 汪志鹏 <wangzhipeng628@gmail.com>
Rebase MoT fused kernels (PR vllm-project#1328) onto current main and fix issues: - Rewrite _forward_sp_gen() to use MoT unified API instead of deleted *_moe_gen layers (qkv_proj_moe_gen, o_proj_moe_gen, q_norm_moe_gen, k_norm_moe_gen), which caused AttributeError when SP was active - Parameterize bias in MoT kernel tests (ZJY0516 review feedback) - Add configs/ directory with README documenting the 3-tier config loading mechanism (ZJY0516 review feedback) Signed-off-by: Zhengyuan Su <su.zhengyuan@u.nus.edu> Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Yufeng Zhu <yzhu802@gatech.edu>
Signed-off-by: Yufeng Zhu <yzhu802@gatech.edu>
Force-pushed efc7f81 to 35f4fed.
timzsu
left a comment
Please provide some sample images to justify the test relaxation.
PIXEL_TOLERANCE = 10
# Relax pixel tolerance to 15 to account for Triton kernel numerical drift
PIXEL_TOLERANCE = 15
@yzhu802 To justify this relaxation, can you post some sample output images to ensure that this deviation does not damage the output quality?
I can do this tomorrow evening. Sorry, I’ve been really busy lately.
@princepride @lishunyang12 @timzsu
The PR updates some T2I and I2I examples, and it is difficult to notice any visible differences with the naked eye. In addition, I observed that Bagel is not particularly good at image-to-image tasks, so it is natural that the pixel difference after operator fusion is larger compared to text-to-image tasks.
@yzhu802 Could you please provide some sample images? Thanks!
Purpose
This PR introduces high-performance Mixture-of-Tokens (MoT) core operator primitives for vLLM-Omni, including MoTLinear and MoTRMSNorm. These operators are developed using Triton, following the proposals in RFC #1146 and #936.
Design
1) Overall Architecture
Adheres to the hierarchical design of vLLM Triton operators: Triton kernels -> Layer encapsulation -> Model invocation. The implementations for both the operator layer and layer encapsulation are centralized within the vllm_omni/diffusion/layers/mot/ directory.
2) Quantization-Friendly Interface Design
Operator-Level Separation: The MoT routing logic and GEMM logic are decoupled into separate kernels. If future frameworks implement quantization logic using Triton, these kernels can be easily swapped out.
Layer Encapsulation Separation: Each layer wrapper implements a fallback method to the native vLLM GEMM backend. If a future CUDA-based quantized GEMM backend outperforms the quantized MoTLinear, this fallback mechanism can be directly utilized.
3) Hardware/Model-Specific Tuning Mechanism
Automated Static Tuning: The script benchmarks/kernels/mot_linear_benchmarks.py provides two modes: benchmarking and tuning. Launching the script with --tune triggers the tuning mode, which generates the required compilation parameters for the current hardware and model. To accelerate the search for optimal parameter combinations, it leverages Ray for parallel computation in multi-GPU environments and implements a pruning strategy specifically designed for NVIDIA GPUs.
Note: This script has not been tested on ROCm architectures or other NPUs. Currently, running it in non-CUDA environments will raise a NotImplementedError().
Automated Tuning Parameter Loading: For CUDA architecture hardware, the framework automatically loads the optimal parameters generated by the tuning script. If the specific configuration file for the current hardware is missing, it will explicitly warn the user in the terminal to perform tuning, otherwise performance will degrade significantly. In this fallback scenario, the Triton operator uses a sub-optimal configuration combination that prevents compilation errors or register spilling on lower-end GPUs (e.g., T4) but yields poor performance on high-end GPUs (e.g., H100/A100).
Note: For non-CUDA architectures, it defaults to falling back to the original vLLM linear layer implementation. This can be replaced by diffusion-specific linear layers in the future.
Test Plan
Unit Testing: Layer encapsulation unit tests are located in vllm-omni/tests/diffusion/kernels/mot. These test the acceleration and accuracy against three baselines across sequence lengths from 1k to 8k:
End-to-End (E2E) Testing: Follows the existing benchmarks/diffusion/diffusion_benchmark_serving.py. A minor modification was added allowing the caller to optionally save the generated images (saving is only triggered if --save_dir is passed, otherwise the original behavior is preserved). This allows for simultaneous evaluation of E2E generation quality and speed.
Test Result
Unit Tests
Correctness for MoTRMSNorm (All Passed):
Correctness for MoTQKVParallel & MoTRowParallel (All Passed):
Performance for MoTRMSNorm:
Since RMSNorm is a memory-bound operation, reducing memory access overhead significantly accelerates computation. Achieved a 5.52x ~ 1.39x speedup across 1k-8k sequence lengths.
Performance for MoTQKVParallel & MoTRowParallel:
These linear layers are compute-bound but have relatively narrow matrix widths.
MoTQKVParallel: Achieved a 1.51x ~ 1.25x speedup.
MoTRowParallel: Achieved a 1.12x ~ 1.52x speedup.
Note: FFN layers are explicitly compute-bound. Fused operators only show improvements at short sequence lengths. For sequences longer than 4096, performance regresses compared to the native CuBLAS implementation. Therefore, this PR does not introduce MoT operators for the FFN component; the model layer retains the original vLLM linear layer implementation.
End-to-End (E2E) Tests
Overall E2E Speedup: ~1.29x
The framework was launched on a single A100 and tested using the following command:
Test Group Results (Average of 3 runs):
Control Group Results (Average of 3 runs):
Generated image for the prompt "a person giving a presentation":
other info mentioned in vLLM-Omni Meeting 2026-03-06 11:30 AM (UTC+8)
MoT kernel PR-presentation.pdf
Supplementary Experimental Notes
In some pixel-level test cases, the maximum RGB value difference reached 12, exceeding the original tolerance of 10.
To ensure that the Triton kernel's numerical drift does not affect end-to-end generation quality, I conducted a rigorous visual fidelity test comparing this branch against the stable baseline.
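The tolerance check itself amounts to the following (a minimal sketch; a real test would load the two saved images, e.g. with PIL, rather than build arrays inline):

```python
import numpy as np

PIXEL_TOLERANCE = 15

def max_pixel_diff(img_a: np.ndarray, img_b: np.ndarray) -> int:
    """Maximum absolute per-channel difference between two uint8 images.

    Cast to int16 first so the subtraction cannot wrap around.
    """
    return int(np.abs(img_a.astype(np.int16) - img_b.astype(np.int16)).max())

base = np.full((4, 4, 3), 120, dtype=np.uint8)
test = base.copy()
test[0, 0, 0] = 132  # simulate the worst observed channel drift of 12
```

A max drift of 12 out of 255 (~4.7%) on isolated channels is well within the relaxed tolerance of 15 and below what is visually perceptible, consistent with the sample images above.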
Environment & Reproducibility
Baseline: v0.18.0
f55ea28 (HEAD -> pure_v0.18.0, tag: v0.18.0, upstream/release/v0.18.0) [Qwen3TTS][Bugfix] Replace vLLM fused layers with HF-compatible numerics in code predictor ([Qwen3TTS][Bugfix] Replace vLLM fused layers with HF-compatible numerics in code predictor #2277)
This PR: my work cleanly rebased on v0.18.0
86fd4e015 (HEAD -> mot_on_v0.18.0) Relax pixel tolerance to 15 to account for Triton kernel numerical drift
fa60bd10f Address review comments (benchmark logs, ids, and design comments)
d11d11419 [Enhancement] Fix SP path for MoT fused kernels, address review feedback
f52814a61 Add pytest markers for core model and diffusion tests
4f83134d9 Add pytest markers for core model and diffusion tests
5f49ddec5 enhancement: add mot fused kernels and benchmarks 01
f55ea28 (tag: v0.18.0, upstream/release/v0.18.0) [Qwen3TTS][Bugfix] Replace vLLM fused layers with HF-compatible numerics in code predictor ([Qwen3TTS][Bugfix] Replace vLLM fused layers with HF-compatible numerics in code predictor #2277)
Test Command:
Test Prompts:
Essential Elements of an Effective PR Description Checklist
supported_models.md and examples for a new model. Please run mkdocs serve to sync the documentation editions to ./docs.