[WIP][Kernel] Generalized LL GEMMs with PDL #39897

Draft
LopezCastroRoberto wants to merge 34 commits into vllm-project:main from LopezCastroRoberto:feature/ll_gemm_pdl

Contributor

@LopezCastroRoberto LopezCastroRoberto commented Apr 15, 2026

This is still WIP, but feedback is welcome.

Motivation

#38772

[Two screenshots, taken 2026-05-12 at 16:14:18 and 16:14:03]

Benchmarks

LL BF16 kernels (nvidia/DeepSeek-V3.2-NVFP4)

vllm serve nvidia/DeepSeek-V3.2-NVFP4 -tp 4 --port 8001 --kv-cache-dtype fp8
vllm bench serve --backend vllm --model nvidia/DeepSeek-V3.2-NVFP4 --input-len 128 --output-len 2048 --num-prompts 16 --max-concurrency 8

MAIN:

============ Serving Benchmark Result ============
Successful requests:                     16        
Failed requests:                         0         
Maximum request concurrency:             8         
Benchmark duration (s):                  50.28     
Total input tokens:                      10599     
Total generated tokens:                  32768     
Request throughput (req/s):              0.32      
Output token throughput (tok/s):         651.71    
Peak output token throughput (tok/s):    680.00    
Peak concurrent requests:                16.00     
Total token throughput (tok/s):          862.52    
---------------Time to First Token----------------
Mean TTFT (ms):                          96.97     
Median TTFT (ms):                        91.43     
P99 TTFT (ms):                           139.45    
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          12.23     
Median TPOT (ms):                        12.24     
P99 TPOT (ms):                           12.28     
---------------Inter-token Latency----------------
Mean ITL (ms):                           12.23     
Median ITL (ms):                         12.18     
P99 ITL (ms):                            12.92     
==================================================

PR:

============ Serving Benchmark Result ============
Successful requests:                     16        
Failed requests:                         0         
Maximum request concurrency:             8         
Benchmark duration (s):                  46.50     
Total input tokens:                      10599     
Total generated tokens:                  32768     
Request throughput (req/s):              0.34      
Output token throughput (tok/s):         704.64    
Peak output token throughput (tok/s):    736.00    
Peak concurrent requests:                16.00     
Total token throughput (tok/s):          932.56    
---------------Time to First Token----------------
Mean TTFT (ms):                          83.26     
Median TTFT (ms):                        63.63     
P99 TTFT (ms):                           112.33    
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          11.32     
Median TPOT (ms):                        11.32     
P99 TPOT (ms):                           11.33     
---------------Inter-token Latency----------------
Mean ITL (ms):                           11.32     
Median ITL (ms):                         11.13     
P99 ITL (ms):                            11.92     
==================================================

TPOT improves ~8% (mean TPOT 12.23 ms → 11.32 ms, a 1.08x speedup)

LL FP8 kernels (mistralai/Mistral-Medium-3.5-128B) -- Per-tensor scaling

vllm serve mistralai/Mistral-Medium-3.5-128B -tp 8
vllm bench serve --backend vllm --model mistralai/Mistral-Medium-3.5-128B --input-len 128 --output-len 2048 --num-prompts 16 --max-concurrency 8

MAIN:

============ Serving Benchmark Result ============
Successful requests:                     16        
Failed requests:                         0         
Maximum request concurrency:             8         
Benchmark duration (s):                  30.00     
Total input tokens:                      2049      
Total generated tokens:                  32768     
Request throughput (req/s):              0.53      
Output token throughput (tok/s):         1092.14   
Peak output token throughput (tok/s):    1104.00   
Peak concurrent requests:                15.00     
Total token throughput (tok/s):          1160.43   
---------------Time to First Token----------------
Mean TTFT (ms):                          39.33     
Median TTFT (ms):                        40.08     
P99 TTFT (ms):                           44.58     
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          7.31      
Median TPOT (ms):                        7.31      
P99 TPOT (ms):                           7.31      
---------------Inter-token Latency----------------
Mean ITL (ms):                           7.31      
Median ITL (ms):                         7.30      
P99 ITL (ms):                            7.49      
==================================================

PR:

============ Serving Benchmark Result ============
Successful requests:                     16        
Failed requests:                         0         
Maximum request concurrency:             8         
Benchmark duration (s):                  28.66     
Total input tokens:                      2049      
Total generated tokens:                  32768     
Request throughput (req/s):              0.56      
Output token throughput (tok/s):         1143.34   
Peak output token throughput (tok/s):    1160.00   
Peak concurrent requests:                16.00     
Total token throughput (tok/s):          1214.83   
---------------Time to First Token----------------
Mean TTFT (ms):                          33.16     
Median TTFT (ms):                        33.31     
P99 TTFT (ms):                           38.25     
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          6.98      
Median TPOT (ms):                        6.98      
P99 TPOT (ms):                           6.99      
---------------Inter-token Latency----------------
Mean ITL (ms):                           6.98      
Median ITL (ms):                         6.98      
P99 ITL (ms):                            7.14      
==================================================

TPOT improves ~5% (mean TPOT 7.31 ms → 6.98 ms, a 1.05x speedup)
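
For readers who want to re-derive the percentages, here is a small sketch of the arithmetic using only the mean TPOT values from the tables above. The helper names are illustrative and not part of the PR. It reports the change both as a latency reduction (relative to MAIN) and as a speedup (relative to the PR), since the two conventions give slightly different figures.

```python
def latency_reduction(main_ms: float, pr_ms: float) -> float:
    """Fraction of the baseline latency that was removed."""
    return (main_ms - pr_ms) / main_ms


def speedup(main_ms: float, pr_ms: float) -> float:
    """Fractional speedup, i.e. how much faster the PR is than MAIN."""
    return main_ms / pr_ms - 1.0


# Mean TPOT in ms, copied from the benchmark tables: (MAIN, PR).
mean_tpot = {
    "LL BF16, DeepSeek-V3.2-NVFP4": (12.23, 11.32),
    "LL FP8, Mistral-Medium-3.5-128B": (7.31, 6.98),
}

for name, (main_ms, pr_ms) in mean_tpot.items():
    print(
        f"{name}: "
        f"{latency_reduction(main_ms, pr_ms):.1%} lower latency, "
        f"{speedup(main_ms, pr_ms):.1%} speedup"
    )
```

The "~8%" and "~5%" figures quoted above match the speedup convention; measured as latency reduction the same numbers come out slightly lower.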

  Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
@LopezCastroRoberto LopezCastroRoberto marked this pull request as draft April 15, 2026 11:48
@mergify mergify Bot added the performance Performance-related issues label Apr 15, 2026
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request introduces a low-latency (LL) router GEMM implementation using cuteDSL (CUTLASS Python) to optimize MoE gate linear layers on SM90+ hardware for small batch sizes (M <= 16). The changes include new bf16 and fp8 kernels, a JIT-compiled Python wrapper, and integration into the GateLinear dispatch logic as a high-priority tier. Review feedback identifies two critical issues: the fp8 implementation will fail if the hidden dimension K is odd due to invalid int16 casting, and the kernels' use of 32-byte alignment for vectorized loads requires K to be a multiple of 16 to prevent incorrect data access or crashes during expert/token indexing.

Comment on lines +94 to +96
if is_fp8:
    a_flat = hidden_states.view(torch.int16).reshape(-1)
    b_flat = router_weight.view(torch.int16).reshape(-1)
Contributor

high

The fp8 implementation will crash with a RuntimeError if K is odd, because viewing a one-byte fp8 tensor as torch.int16 with Tensor.view requires the last dimension to be a multiple of 2. Even if that view were made to succeed, K_eff = K // 2 on line 108 would cause the kernel to ignore the last element of each row, leading to incorrect results. The dispatch logic in gate_linear.py should be updated to ensure K is even for fp8 inputs, or the implementation should be adjusted to handle the remainder.
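
A minimal sketch of the kind of guard described here, in plain Python so it runs without a GPU. The function names `can_use_ll_fp8` and `packed_k` are hypothetical and do not come from the PR; the only assumptions are that fp8 elements are one byte wide and are reinterpreted in pairs as int16.

```python
def can_use_ll_fp8(k: int) -> bool:
    """Hypothetical dispatch check: take the fp8 path only when K is even."""
    return k % 2 == 0


def packed_k(k: int) -> int:
    """Number of int16 words holding one row of k one-byte fp8 values."""
    if not can_use_ll_fp8(k):
        raise ValueError(f"K={k} is odd; an fp8 row cannot be viewed as int16")
    return k // 2


# With an odd K, plain integer division silently loses one element per row.
k_odd = 7169
dropped = k_odd - 2 * (k_odd // 2)
```

Rejecting odd K at dispatch time keeps the kernel simple; handling the remainder inside the kernel would be the alternative mentioned in the comment.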

Comment on lines +100 to +105
bp = (gB.iterator + (n_idx * K_dim + kb)).align(32)
bt = cute.make_tensor(bp, cute.make_layout((VPT,)))
br = cute.make_rmem_tensor((VPT,), elem)
cute.autovec_copy(bt, br)
for m in cutlass.range_constexpr(M):
    ap = (gA.iterator + (m * K_dim + kb)).align(32)
Contributor

high

The kernel uses .align(32) for bf16 loads, which requires the hidden dimension K to be a multiple of 16 (32 bytes) so that every expert's row in the weight matrix gB and every token's row in gA is properly aligned. If K is not a multiple of 16, n_idx * K_dim (for n_idx > 0) or m * K_dim (for m > 0) will not be 32-byte aligned, leading to incorrect data being loaded or a crash. Similar alignment assumptions exist in the fp8 kernel. The dispatch logic should enforce K % 16 == 0 or the kernel should be made truly alignment-agnostic.
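
The alignment argument can be checked with simple byte arithmetic. The sketch below is illustrative only: `misaligned_rows` is a hypothetical helper, and it assumes a contiguous row-major matrix whose base pointer is itself 32-byte aligned.

```python
def misaligned_rows(k: int, elem_bytes: int, rows: int, align: int = 32) -> list[int]:
    """Return the row indices whose starting byte offset is not a multiple of `align`."""
    return [r for r in range(rows) if (r * k * elem_bytes) % align != 0]


BF16_BYTES = 2

# K = 7168 is a multiple of 16, so every bf16 row starts on a 32-byte boundary.
aligned_case = misaligned_rows(7168, BF16_BYTES, rows=4)

# K = 7176 is a multiple of 8 but not of 16: each row is 14352 bytes,
# so odd-numbered rows start 16 bytes past a 32-byte boundary.
unaligned_case = misaligned_rows(7176, BF16_BYTES, rows=4)
```

A dispatch-side check such as `K % 16 == 0` for bf16 is the cheap way to rule out the second case before the kernel is launched.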

@LucasWilkinson
Collaborator

should we be using TVM FFI? see: Dao-AILab/flash-attention#2042

Comment thread vllm/model_executor/layers/fused_moe/router/_ll_router_gemm_kernels.py Outdated
Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
@LopezCastroRoberto LopezCastroRoberto changed the title [Kernel] Generalized LL GEMMs with PDL [WIP][Kernel] Generalized LL GEMMs with PDL Apr 17, 2026
LopezCastroRoberto and others added 13 commits April 17, 2026 11:51
Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
@mergify mergify Bot added the deepseek Related to DeepSeek models label Apr 28, 2026
Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
@LopezCastroRoberto LopezCastroRoberto marked this pull request as ready for review April 28, 2026 15:56

@claude claude Bot left a comment

Claude Code Review

This pull request is from a fork — automated review is disabled. A repository maintainer can comment @claude review to run a one-time review.

Contributor

mergify Bot commented Apr 28, 2026

Hi @LopezCastroRoberto, the pre-commit checks have failed. Please run:

uv pip install "pre-commit>=4.5.1"
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10

@LopezCastroRoberto LopezCastroRoberto marked this pull request as draft April 28, 2026 16:01
Comment on lines +90 to +119
@dsl_user_op
def fused_fp8_mma_2n(
    c0,
    c1,
    c2,
    c3,
    c4,
    c5,
    c6,
    c7,
    a0_lo,
    a0_hi,
    a1_lo,
    a1_hi,
    a2_lo,
    a2_hi,
    a3_lo,
    a3_hi,
    b0_lo,
    b0_hi,
    b1_lo,
    b1_hi,
    b2_lo,
    b2_hi,
    b3_lo,
    b3_hi,
    *,
    loc=None,
    ip=None,
):
Collaborator

can we use cute primitives instead?

Contributor

mergify Bot commented Apr 30, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @LopezCastroRoberto.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label Apr 30, 2026
Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
@LopezCastroRoberto LopezCastroRoberto marked this pull request as ready for review May 12, 2026 12:48
Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
@LopezCastroRoberto LopezCastroRoberto marked this pull request as draft May 13, 2026 20:25

Labels

deepseek (Related to DeepSeek models)
needs-rebase
performance (Performance-related issues)

2 participants