Skip to content

[Enhancement] MoT Fused Kernels: Triton Implementation#2897

Open
timzsu wants to merge 4 commits into
vllm-project:mainfrom
timzsu:enhancement/mot_fused_kernels
Open

[Enhancement] MoT Fused Kernels: Triton Implementation#2897
timzsu wants to merge 4 commits into
vllm-project:mainfrom
timzsu:enhancement/mot_fused_kernels

Conversation

@timzsu
Copy link
Copy Markdown
Contributor

@timzsu timzsu commented Apr 18, 2026

PLEASE FILL IN THE PR DESCRIPTION HERE ENSURING ALL CHECKLIST ITEMS (AT THE BOTTOM) HAVE BEEN CONSIDERED.

Purpose

Takes over PR #1328 by @yzhu802 — high-performance MoT (Mixture-of-Tokens) Triton kernels for BAGEL, rebased onto current main with the following fixes:

  • Fix _forward_sp_gen() crash: The MoT refactor replaced separate *_moe_gen layers with unified MoT layers using routing indices. The SP attention path still referenced deleted layers (qkv_proj_moe_gen, o_proj_moe_gen, q_norm_moe_gen, k_norm_moe_gen), causing AttributeError. Rewrote to use MoT unified API.
  • Parameterize bias in kernel tests (ZJY0516 review): test_mot_qkv_parallel and test_mot_o_proj now test both bias=True and bias=False via @pytest.mark.parametrize.
  • Add configs/ README (ZJY0516 review): Documents the 3-tier config loading mechanism and tuning instructions.
  • Resolve rebase conflicts: bagel_transformer.py (11 conflicts) and pipeline_bagel.py (1 conflict) — merged upstream's quant_config/prefix params and SP support with MoT layer changes.

Note: The original PR also includes a minor change to benchmarks/diffusion/diffusion_benchmark_serving.py:

  • --save-dir option to save generated images for visual inspection
  • --warmup-num-inference-steps default changed from 1→2 (BAGEL runs num_timesteps-1 denoising iterations, so 1 step = 0 actual denoising)

Test Plan

# MoT kernel unit tests (correctness + performance)
CUDA_VISIBLE_DEVICES=0 python -m pytest tests/diffusion/kernels/mot/ -v -s

# CPU tests (verify rebase didn't break existing functionality)
python -m pytest tests/engine/test_arg_utils.py -v

# E2E serving benchmark (2-GPU, BAGEL-7B-MoT)
python -m vllm_omni.entrypoints.cli.main serve ByteDance-Seed/BAGEL-7B-MoT \
  --omni --stage-configs-path tests/e2e/offline_inference/stage_configs/bagel_sharedmemory_2gpu_ci.yaml \
  --port 8099 --trust-remote-code
# Then in another terminal:
python benchmarks/diffusion/diffusion_benchmark_serving.py \
  --base-url http://localhost:8099 --model ByteDance-Seed/BAGEL-7B-MoT \
  --task t2i --dataset vbench --num-prompts 5 --warmup-num-inference-steps 2

Test Result

MoT kernel tests: 21/21 passed on RTX 6000 Ada (no tuned config — conservative defaults):

Test M max_abs max_rel cos_sim
QKV bias=T 1026 3.125e-02 2.597e-02 0.999993
QKV bias=F 1026 3.125e-02 1.563e-02 0.999996
QKV bias=T 8208 3.125e-02 7.813e-03 1.000000
O_proj bias=T 1026 3.125e-02 2.344e-02 0.999994
O_proj bias=F 1026 3.125e-02 1.550e-02 0.999996
RMSNorm layernorm 2048 6.250e-02 7.813e-03 0.999992
RMSNorm head_norm 2048 6.250e-02 7.813e-03 0.999987

CPU tests: 13/13 passed (test_arg_utils)

E2E serving benchmark: 5/5 successful on 2× RTX 6000 Ada (BAGEL-7B-MoT, vbench t2i, 5 prompts, concurrency=1):

Baseline (main) MoT Branch Speedup
Latency Mean 88.23s 83.98s 1.05x
Latency Median 88.93s 85.08s 1.05x
Benchmark Duration 441.13s 419.88s 1.05x
Successful 5/5 5/5 -

Note: Running on RTX 6000 Ada without tuned MoT configs (conservative defaults). The original PR reported ~1.29x e2e speedup on A100 with auto-tuned configs. The speedup is expected to be larger on A100/H100 with tuned configs since:

  1. MoT Triton kernels benefit from higher tensor core throughput on A100/H100
  2. Auto-tuned tile configs optimize for the specific hardware's memory hierarchy

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan. Please provide the test scripts & test commands. Please state the reasons if your codes don't require additional test scripts. For test file guidelines, please check the test style doc
  • The test results. Please paste the results comparison before and after, or the e2e results.
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model. Please run mkdocs serve to sync the documentation editions to ./docs.
  • (Optional) Release notes update. If your change is user-facing, please update the release notes draft.

BEFORE SUBMITTING, PLEASE READ https://github.com/vllm-project/vllm-omni/blob/main/CONTRIBUTING.md (anything written below this line will be removed by GitHub Actions)

Yufeng Zhu and others added 4 commits April 18, 2026 12:52
Signed-off-by: Yufeng Zhu <yzhu802@gatech.edu>
Signed-off-by: 汪志鹏 <wangzhipeng628@gmail.com>
Signed-off-by: 汪志鹏 <wangzhipeng628@gmail.com>
Rebase MoT fused kernels (PR vllm-project#1328) onto current main and fix issues:

- Rewrite _forward_sp_gen() to use MoT unified API instead of deleted
  *_moe_gen layers (qkv_proj_moe_gen, o_proj_moe_gen, q_norm_moe_gen,
  k_norm_moe_gen), which caused AttributeError when SP was active
- Parameterize bias in MoT kernel tests (ZJY0516 review feedback)
- Add configs/ directory with README documenting the 3-tier config
  loading mechanism (ZJY0516 review feedback)

Signed-off-by: Zhengyuan Su <su.zhengyuan@u.nus.edu>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@yzhu802
Copy link
Copy Markdown

yzhu802 commented Apr 18, 2026 via email

@timzsu
Copy link
Copy Markdown
Contributor Author

timzsu commented Apr 18, 2026

@yzhu802 Do you mean that you will continue working on this PR? If so, please let me know so our efforts won't overlap.

@hsliuustc0106
Copy link
Copy Markdown
Collaborator

Ready for full review when draft status removed. Preliminary scan available on request.

@timzsu timzsu marked this pull request as ready for review April 18, 2026 09:32
@timzsu timzsu requested a review from hsliuustc0106 as a code owner April 18, 2026 09:32
@chatgpt-codex-connector
Copy link
Copy Markdown

Codex usage limits have been reached for code reviews. Please check with the admins of this repo to increase the limits by adding credits.
Credits must be used to enable repository wide code reviews.

@timzsu
Copy link
Copy Markdown
Contributor Author

timzsu commented Apr 18, 2026

@princepride @hsliuustc0106 The PR is ready for review. Basically rebased PR #1328 onto the latest main and addressed a few remaining comments.

Copy link
Copy Markdown
Collaborator

@hsliuustc0106 hsliuustc0106 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PR #2897 Review — MoT Fused Kernels: Triton Implementation

Author: @timzsu (takes over PR #1328 by @yzhu802)
Size: +4041 / -224 across 16 files (substantial)


BLOCKER scan

Category Result
Correctness PASS
Reliability/Safety PASS
Breaking Changes PASS
Test Coverage PASS
Documentation PASS
Security PASS

OVERALL: NO BLOCKERS


What was validated

1. Triton kernels (mot_gemm.py, mot_rms_norm.py)

  • The unified MoT GEMM kernel (mot_unified_gemm_kernel) uses a clean 3-part architecture: router (_get_mot_pointers), static dispatch to compute cores (_core_standard_gemm / _core_weight_only_gemm), and a shared store phase. This is well-structured.
  • Compile-time constexpr branching on QUANT_TYPE, EVEN_K/N, stride-is-1 optimizations — no runtime overhead from unnecessary branches.
  • The router correctly maps PIDs to text vs VAE regions using num_pid_m_text * num_pid_n boundary, with indirect index loading via tl.load for MoT routing.
  • Store mask uses m_mask & n_mask — correctly avoids writing to padding rows where real_row_idxs may be 0.
  • RMS norm kernels (_mot_rms_norm_kernel, _mot_rms_norm_qk_kernel) properly route text/vae tokens and compute in fp32 with correct downcast.

2. MoT Layer wrappers (mot_qkv_parallel_linear.py, mot_row_parallel_linear.py)

  • Clean inheritance from vLLM's QKVParallelLinear / RowParallelLinear with gen_exp submodule for VAE weights.
  • _mot_gemm_dispatch routes by dtype: unquantized (BF16/FP16), FP8 W8A8, INT8 W8A16, or fallback to gather/scatter.
  • Non-CUDA platforms (not current_platform.is_cuda()) properly fall back to gather/scatter path.
  • Bias handling correctly accounts for skip_bias_add and TP rank (tp_rank > 0 → no bias for RowParallel).

3. Model integration (bagel_transformer.py)

  • _forward_sp_gen() correctly rewritten to use unified MoT API instead of deleted *_moe_gen layers.
  • The forward() path for both "und" and "gen" modes is now clean: text_indices=None for und mode, actual indices for gen mode.
  • BagelMLP updated to use SiluAndMul() (fused SiLU+gate) instead of separate nn.SiLU() + multiply — this is a minor performance improvement beyond the scope of MoT, but correct.
  • load_weights() properly remaps checkpoint names (_moe_gen suffixes → .gen_exp / .gen_weight), with ordered matching (most-specific patterns first).

4. Pipeline (pipeline_bagel.py)

  • load_weights() correctly expands allowed names for the new MoT parameter structure.

5. Test coverage

  • 21/21 kernel unit tests passed with high accuracy (cos_sim > 0.999, max_abs < 0.0625).
  • Tests cover both bias=True and bias=False via @pytest.mark.parametrize.
  • E2E benchmark on 2x RTX 6000 Ada shows 1.05x speedup over baseline (conservative defaults, not auto-tuned).
  • CPU tests (13/13) verify no regression to existing functionality.

6. Benchmark infrastructure

  • mot_linear_benchmarks.py provides auto-tuning with proper search space pruning (SRAM capacity, register pressure, occupancy).
  • 3-tier config loading (env → built-in → default) mirrors vLLM's fused_moe pattern.

7. Documentation

  • configs/README documents the tuning mechanism.
  • PR description is thorough with purpose, test plan, results, and checklist.

Non-blocking suggestions

  1. _save_generated_outputs silent exception (diffusion_benchmark_serving.py:856): The except Exception as e only prints a warning but does not track failures. Consider logging with logging.warning() for structured observability, or incrementing a failure counter.

  2. warmup-num-inference-steps default change (1→2): This is a behavior change for all diffusion benchmarks, not just BAGEL. Users running other models may notice slightly longer warmup. Consider adding a comment noting this was changed specifically for models where num_timesteps-1 == 0 denoising steps, and suggest that model-specific overrides could be added in the future if this causes issues for other pipelines.

  3. gen_exp as bare torch.nn.Module(): In mot_qkv_parallel_linear.py and mot_row_parallel_linear.py, self.gen_exp = torch.nn.Module() creates a module without __init__ call. While this works for creating a weight container, the gen_exp.quant_method = self.quant_method line adds an attribute after construction. This is functional but slightly unconventional. A minor comment explaining the design choice would help future readers.

  4. Benchmark test parametrize ids: In test_mot_linear.py, ids=lambda val: "" generates empty IDs for parametrized tests. Consider using descriptive IDs like f"img{image_num}" for clearer pytest output.

  5. get_mot_default_config small-M fallback: When M <= 16, BLOCK_SIZE_M=16 is used. If M is not a power of 2 (e.g., M=3), the kernel will still allocate a full 16-row tile. This is fine for correctness (masks handle it) but worth a brief comment.


Verdict

This is a well-executed performance enhancement. The MoT fused Triton kernels eliminate the gather/scatter overhead of running two separate linear projections for text and VAE tokens. The 3-tier config system and auto-tuning infrastructure are production-ready. Test coverage is thorough with both correctness and performance validation.

APPROVE — no blocking issues found.

@princepride princepride added the merge-test label to trigger buildkite merge test CI label Apr 18, 2026
@@ -0,0 +1,50 @@
This directory contains auto-tuned Triton kernel configurations for the
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this shoudl not be placed here

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is docs/user_guide/diffusion the correct location for this README.md?

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @timzsu,

I noticed the CI here is currently blocked by the test_bagel_img2img_shared_memory_connector pixel mismatch. The max deviation is around ~12, which slightly exceeds the strict tolerance of 10. Since text2img passed, this is just the expected numerical drift from the Triton kernels' FMA accumulation over the diffusion steps.

To unblock this, I have already relaxed the tolerance to 15 in my original PR (#1328), along with addressing all your review comments (adding benchmark logs, fixing pytest ids, adding design comments, move related README.md to docs/user_guide/diffusion and renaming it) and rebasing everything cleanly on top of your changes and upstream/main.

Since my PR is now fully up-to-date, conflict-free, and contains the CI fix + all your commits, would it be possible to just merge my original PR (#1328) instead? I have linked that specific PR on my resume, so having it marked as 'Merged' would be incredibly helpful for me!

Thank you so much for your guidance on getting this framework-ready!

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure. Let's move the discussion to your PR #1328. I will review your changes :)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

merge-test label to trigger buildkite merge test CI

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants