[Enhancement] MoT Fused Kernels: Triton Implementation#2897
Signed-off-by: Yufeng Zhu <yzhu802@gatech.edu>
Signed-off-by: 汪志鹏 <wangzhipeng628@gmail.com>
Rebase MoT fused kernels (PR vllm-project#1328) onto current main and fix issues:
- Rewrite _forward_sp_gen() to use MoT unified API instead of deleted *_moe_gen layers (qkv_proj_moe_gen, o_proj_moe_gen, q_norm_moe_gen, k_norm_moe_gen), which caused AttributeError when SP was active
- Parameterize bias in MoT kernel tests (ZJY0516 review feedback)
- Add configs/ directory with README documenting the 3-tier config loading mechanism (ZJY0516 review feedback)

Signed-off-by: Zhengyuan Su <su.zhengyuan@u.nus.edu>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@yzhu802 Do you mean that you will continue working on this PR? If so, please let me know so our efforts won't overlap.
Ready for full review when draft status removed. Preliminary scan available on request.
@princepride @hsliuustc0106 The PR is ready for review. Basically rebased PR #1328 onto the latest main and addressed a few remaining comments. |
hsliuustc0106
left a comment
PR #2897 Review — MoT Fused Kernels: Triton Implementation
Author: @timzsu (takes over PR #1328 by @yzhu802)
Size: +4041 / -224 across 16 files (substantial)
BLOCKER scan
| Category | Result |
|---|---|
| Correctness | PASS |
| Reliability/Safety | PASS |
| Breaking Changes | PASS |
| Test Coverage | PASS |
| Documentation | PASS |
| Security | PASS |
OVERALL: NO BLOCKERS
What was validated
1. Triton kernels (mot_gemm.py, mot_rms_norm.py)
- The unified MoT GEMM kernel (`mot_unified_gemm_kernel`) uses a clean 3-part architecture: router (`_get_mot_pointers`), static dispatch to compute cores (`_core_standard_gemm` / `_core_weight_only_gemm`), and a shared store phase. This is well-structured.
- Compile-time `constexpr` branching on `QUANT_TYPE`, `EVEN_K/N`, and stride-is-1 optimizations: no runtime overhead from unnecessary branches.
- The router correctly maps PIDs to text vs VAE regions using the `num_pid_m_text * num_pid_n` boundary, with indirect index loading via `tl.load` for MoT routing.
- Store mask uses `m_mask & n_mask`, which correctly avoids writing to padding rows where `real_row_idxs` may be 0.
- RMS norm kernels (`_mot_rms_norm_kernel`, `_mot_rms_norm_qk_kernel`) properly route text/vae tokens and compute in fp32 with correct downcast.
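The PID routing and store masking noted above can be sketched in plain Python. This is a minimal illustration, not the kernel code: the function names `route_pid` and `store_mask`, the row-major PID ordering inside each region, and the block sizes used are assumptions. Only the `num_pid_m_text * num_pid_n` boundary and the `m_mask & n_mask` store mask come from the review.

```python
def cdiv(a: int, b: int) -> int:
    """Ceiling division, used to count tiles along one dimension."""
    return (a + b - 1) // b


def route_pid(pid, m_text, m_vae, n, block_m, block_n):
    """Map a flat program ID to (region, pid_m, pid_n).

    Assumption: text tiles come first, then VAE tiles, and tiles are
    laid out row-major inside each region.
    """
    num_pid_n = cdiv(n, block_n)
    num_pid_m_text = cdiv(m_text, block_m)
    num_pid_m_vae = cdiv(m_vae, block_m)
    boundary = num_pid_m_text * num_pid_n  # first PID of the VAE region
    total = boundary + num_pid_m_vae * num_pid_n
    if not 0 <= pid < total:
        raise ValueError(f"pid {pid} outside launch grid of size {total}")
    if pid < boundary:
        region, local = "text", pid
    else:
        region, local = "vae", pid - boundary
    return region, local // num_pid_n, local % num_pid_n


def store_mask(pid_m, pid_n, m_region, n, block_m, block_n):
    """2D mask (block_m x block_n): True only for in-bounds rows AND columns."""
    m_mask = [pid_m * block_m + i < m_region for i in range(block_m)]
    n_mask = [pid_n * block_n + j < n for j in range(block_n)]
    return [[rm and cn for cn in n_mask] for rm in m_mask]
```

Padding rows in the last tile of a region get an all-False mask row, so whatever index they loaded (possibly 0) is never written to.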
2. MoT Layer wrappers (mot_qkv_parallel_linear.py, mot_row_parallel_linear.py)
- Clean inheritance from vLLM's `QKVParallelLinear` / `RowParallelLinear` with `gen_exp` submodule for VAE weights.
- `_mot_gemm_dispatch` routes by dtype: unquantized (BF16/FP16), FP8 W8A8, INT8 W8A16, or fallback to gather/scatter.
- Non-CUDA platforms (`not current_platform.is_cuda()`) properly fall back to the gather/scatter path.
- Bias handling correctly accounts for `skip_bias_add` and TP rank (`tp_rank > 0` → no bias for RowParallel).
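For readers unfamiliar with the gather/scatter fallback, its semantics can be written as a dependency-free reference. This is a sketch under assumptions: `mot_linear_reference` and `matvec` are hypothetical names, weights are plain nested lists, and treating `text_indices=None` as "every token uses the text weights" is my reading of the und-mode behavior, not something the PR code confirms.

```python
def matvec(weight, x):
    """Multiply an (out_features x in_features) weight by one token vector."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in weight]


def mot_linear_reference(x, w_text, w_vae, text_indices=None, vae_indices=None):
    """Reference semantics: each token row is projected by its modality's weight.

    x: list of token vectors. text_indices / vae_indices: row positions.
    Assumption: text_indices=None means all tokens take the text path.
    """
    if text_indices is None:
        return [matvec(w_text, row) for row in x]
    out = [None] * len(x)
    for idx in text_indices:          # gather text rows, project, scatter back
        out[idx] = matvec(w_text, x[idx])
    for idx in vae_indices or []:     # same for VAE rows with the gen weights
        out[idx] = matvec(w_vae, x[idx])
    return out
```

A fused kernel would be expected to produce the same output without materializing the gathered copies of `x`, which is where the saved memory traffic comes from.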
3. Model integration (bagel_transformer.py)
- `_forward_sp_gen()` correctly rewritten to use the unified MoT API instead of the deleted `*_moe_gen` layers.
- The `forward()` path for both `"und"` and `"gen"` modes is now clean: `text_indices=None` for und mode, actual indices for gen mode.
- `BagelMLP` updated to use `SiluAndMul()` (fused SiLU+gate) instead of separate `nn.SiLU()` + multiply. This is a minor performance improvement beyond the scope of MoT, but correct.
- `load_weights()` properly remaps checkpoint names (`_moe_gen` suffixes → `.gen_exp` / `.gen_weight`), with ordered matching (most-specific patterns first).
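The "most-specific patterns first" point matters because a generic `_moe_gen` rule would otherwise swallow names that need a different target. A minimal sketch of ordered remapping follows; the rule table and the example parameter names are hypothetical and only illustrate the ordering idea, they are not the PR's actual mapping.

```python
# Hypothetical rules, ordered from most specific to least specific.
REMAP_RULES = [
    ("_norm_moe_gen.weight", "_norm.gen_weight"),  # norm weights: flat parameter
    ("_moe_gen.", ".gen_exp."),                    # linear layers: gen_exp submodule
]


def remap_name(name: str) -> str:
    """Apply the first matching rule; leave unrelated names untouched."""
    for old, new in REMAP_RULES:
        if old in name:
            return name.replace(old, new, 1)
    return name
```

If the two rules were swapped, a norm weight would be rewritten by the generic rule and end up under `.gen_exp.` instead of `.gen_weight`.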
4. Pipeline (pipeline_bagel.py)
- `load_weights()` correctly expands allowed names for the new MoT parameter structure.
5. Test coverage
- 21/21 kernel unit tests passed with high accuracy (cos_sim > 0.999, max_abs < 0.0625).
- Tests cover both `bias=True` and `bias=False` via `@pytest.mark.parametrize`.
- E2E benchmark on 2x RTX 6000 Ada shows 1.05x speedup over baseline (conservative defaults, not auto-tuned).
- CPU tests (13/13) verify no regression to existing functionality.
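The two accuracy metrics quoted above can be computed as follows. This is a dependency-free sketch with hypothetical function names; the actual tests presumably operate on tensors, and only the thresholds (cos_sim > 0.999, max_abs < 0.0625) are taken from the reported results.

```python
import math


def cosine_similarity(a, b):
    """Cosine similarity of two equal-length flat sequences."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        raise ValueError("cosine similarity is undefined for a zero vector")
    return dot / (norm_a * norm_b)


def max_abs_error(a, b):
    """Largest element-wise absolute difference."""
    return max(abs(x - y) for x, y in zip(a, b))


def within_tolerance(actual, expected, min_cos=0.999, max_abs=0.0625):
    """True when both the direction and the worst-case element agree."""
    return (cosine_similarity(actual, expected) > min_cos
            and max_abs_error(actual, expected) < max_abs)
```

Checking both metrics is useful because cosine similarity alone can hide a single badly wrong element in a long vector, while max-abs alone ignores overall direction.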
6. Benchmark infrastructure
- `mot_linear_benchmarks.py` provides auto-tuning with proper search space pruning (SRAM capacity, register pressure, occupancy).
- 3-tier config loading (env → built-in → default) mirrors vLLM's fused_moe pattern.
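The env → built-in → default lookup order can be sketched as below. Everything named here is an assumption made for illustration: the environment variable `MOT_TUNED_CONFIG_DIR`, the JSON file layout, and the default block sizes other than `BLOCK_SIZE_M=16` for `M <= 16`, which is the one value mentioned in this review.

```python
import json
import os
from pathlib import Path

CONFIG_DIR_ENV_VAR = "MOT_TUNED_CONFIG_DIR"  # hypothetical variable name


def default_mot_config(m: int) -> dict:
    """Tier 3: conservative defaults (values other than the small-M tile are made up)."""
    return {
        "BLOCK_SIZE_M": 16 if m <= 16 else 64,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 32,
    }


def load_mot_config(m: int, file_name: str, builtin_dir: str) -> dict:
    """Return the first config found: user directory, then built-in, then default."""
    candidates = []
    env_dir = os.environ.get(CONFIG_DIR_ENV_VAR)
    if env_dir:
        candidates.append(Path(env_dir) / file_name)   # tier 1: user override
    candidates.append(Path(builtin_dir) / file_name)   # tier 2: shipped configs
    for path in candidates:
        if path.is_file():
            with path.open() as f:
                return json.load(f)
    return default_mot_config(m)
```

Keeping the fallback in code rather than in a file means a missing or misnamed config degrades performance but never breaks the kernel launch.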
7. Documentation
- `configs/README` documents the tuning mechanism.
- PR description is thorough with purpose, test plan, results, and checklist.
Non-blocking suggestions
- `_save_generated_outputs` silent exception (`diffusion_benchmark_serving.py:856`): The `except Exception as e` only prints a warning and does not track failures. Consider logging with `logging.warning()` for structured observability, or incrementing a failure counter.
- `warmup-num-inference-steps` default change (1→2): This is a behavior change for all diffusion benchmarks, not just BAGEL. Users running other models may notice slightly longer warmup. Consider adding a comment noting that this was changed specifically for models that run `num_timesteps-1` denoising steps (so 1 warmup step means 0 actual denoising), and that model-specific overrides could be added in the future if this causes issues for other pipelines.
- `gen_exp` as bare `torch.nn.Module()`: In `mot_qkv_parallel_linear.py` and `mot_row_parallel_linear.py`, `self.gen_exp = torch.nn.Module()` creates an empty container module rather than a dedicated subclass, and the `gen_exp.quant_method = self.quant_method` line adds an attribute after construction. This is functional but slightly unconventional. A short comment explaining the design choice would help future readers.
- Benchmark test parametrize `ids`: In `test_mot_linear.py`, `ids=lambda val: ""` generates empty IDs for parametrized tests. Consider using descriptive IDs like `f"img{image_num}"` for clearer pytest output.
- `get_mot_default_config` small-M fallback: When `M <= 16`, `BLOCK_SIZE_M=16` is used. If `M` is smaller than the tile (e.g., `M=3`), the kernel will still allocate a full 16-row tile. This is fine for correctness (masks handle it) but worth a brief comment.
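As an illustration of the first suggestion, a failure counter combined with `logging.warning()` could look like the sketch below. The helper name `save_all` and the `save_fn(index, output)` callback are hypothetical; the real `_save_generated_outputs` signature is not shown in this thread.

```python
import logging

logger = logging.getLogger(__name__)


def save_all(outputs, save_fn):
    """Try to save every output; log and count failures instead of hiding them.

    save_fn is a hypothetical callback taking (index, output).
    Returns the number of outputs that could not be saved.
    """
    failures = 0
    for index, output in enumerate(outputs):
        try:
            save_fn(index, output)
        except Exception as exc:  # broad on purpose: saving must not abort the benchmark
            failures += 1
            logger.warning("Failed to save output %d: %s", index, exc)
    return failures
```

The returned count can then be reported alongside the benchmark summary so that a run with missing images is visible at a glance.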
Verdict
This is a well-executed performance enhancement. The MoT fused Triton kernels eliminate the gather/scatter overhead of running two separate linear projections for text and VAE tokens. The 3-tier config system and auto-tuning infrastructure are production-ready. Test coverage is thorough with both correctness and performance validation.
APPROVE — no blocking issues found.
@@ -0,0 +1,50 @@
This directory contains auto-tuned Triton kernel configurations for the
This should not be placed here.
Is docs/user_guide/diffusion the correct location for this README.md?
Hi @timzsu,
I noticed the CI here is currently blocked by the test_bagel_img2img_shared_memory_connector pixel mismatch. The max deviation is about 12, which slightly exceeds the strict tolerance of 10. Since text2img passed, this looks like the expected numerical drift from the Triton kernels' FMA accumulation over the diffusion steps.
To unblock this, I have already relaxed the tolerance to 15 in my original PR (#1328). I have also addressed all your review comments there (adding benchmark logs, fixing pytest ids, adding design comments, and moving the related README.md to docs/user_guide/diffusion and renaming it) and rebased everything cleanly on top of your changes and upstream/main.
Since my PR is now fully up-to-date, conflict-free, and contains the CI fix + all your commits, would it be possible to just merge my original PR (#1328) instead? I have linked that specific PR on my resume, so having it marked as 'Merged' would be incredibly helpful for me!
Thank you so much for your guidance on getting this framework-ready!
Sure. Let's move the discussion to your PR #1328. I will review your changes :)
Purpose
Takes over PR #1328 by @yzhu802: high-performance MoT (Mixture-of-Tokens) Triton kernels for BAGEL, rebased onto current `main` with the following fixes:

- `_forward_sp_gen()` crash: The MoT refactor replaced separate `*_moe_gen` layers with unified MoT layers using routing indices. The SP attention path still referenced deleted layers (`qkv_proj_moe_gen`, `o_proj_moe_gen`, `q_norm_moe_gen`, `k_norm_moe_gen`), causing `AttributeError`. Rewrote to use MoT unified API.
- `bias` in kernel tests (ZJY0516 review): `test_mot_qkv_parallel` and `test_mot_o_proj` now test both `bias=True` and `bias=False` via `@pytest.mark.parametrize`.
- `configs/` README (ZJY0516 review): Documents the 3-tier config loading mechanism and tuning instructions.
- `bagel_transformer.py` (11 conflicts) and `pipeline_bagel.py` (1 conflict): merged upstream's `quant_config`/`prefix` params and SP support with MoT layer changes.

Note: The original PR also includes a minor change to `benchmarks/diffusion/diffusion_benchmark_serving.py`:

- `--save-dir` option to save generated images for visual inspection
- `--warmup-num-inference-steps` default changed from 1→2 (BAGEL runs `num_timesteps-1` denoising iterations, so 1 step = 0 actual denoising)

Test Plan
Test Result
MoT kernel tests: 21/21 passed on RTX 6000 Ada (no tuned config — conservative defaults):
CPU tests: 13/13 passed (test_arg_utils)
E2E serving benchmark: 5/5 successful on 2× RTX 6000 Ada (BAGEL-7B-MoT, vbench t2i, 5 prompts, concurrency=1):
Note: Running on RTX 6000 Ada without tuned MoT configs (conservative defaults). The original PR reported ~1.29x e2e speedup on A100 with auto-tuned configs. The speedup is expected to be larger on A100/H100 with tuned configs since: