
[lora/moe] Improve fused MoE‑LoRA kernel indexing and memory access #32770

Merged
vllm-bot merged 137 commits into vllm-project:main from cwazai:pr1/moe-lora-kernel-opt
Jan 26, 2026
Conversation

@cwazai
Contributor

@cwazai cwazai commented Jan 21, 2026

This PR optimizes the Triton kernel's inner-loop indexing and memory-access patterns.

Key changes:

- Use int32/uint32 for indices (cheaper than int64 in PTX).
- Remove the % N modulo for column offsets; instead compute a plain pid_n * BLOCK_SIZE_N and mask with cn_mask.
- Pass MAX_LORAS_TOTAL as a Python scalar (not via tl.num_programs), and let the kernel early-exit on an invalid lora_id using an unsigned comparison (which automatically filters -1).
- Reduce the axis-2 program count from lora_a_stacked[0].shape[0] to lora_ids.numel(), launching programs only for the LoRA adapters actually in use.
- Add USE_B_L2_CACHE=True for B-matrix loads (.ca modifier), improving cache locality for repeated expert-weight reads.
- Use other=-1 for expert_id/token_id loads, simplifying the mask logic.

Performance impact:

- Faster index arithmetic in the hot path.
- Fewer wasted CUDA threads (axis-2 early exit).
- Better L2 hit rate for B (expert weights).
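The unsigned-comparison early exit can be illustrated in plain Python. This is a sketch emulating the kernel's uint32 compare; `lora_id_is_invalid` is a hypothetical name, not a function from the kernel:

```python
def lora_id_is_invalid(lora_id: int, max_loras: int) -> bool:
    # Reinterpret the id as uint32: -1 wraps to 0xFFFFFFFF, so a single
    # unsigned comparison rejects both "no adapter" (-1) and any
    # out-of-range id with one branch instead of two.
    return (lora_id & 0xFFFFFFFF) >= max_loras
```

In the kernel the same effect is obtained by comparing unsigned values, so the sentinel -1 needs no separate check.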

Testing

unit tests (tests/lora/test_fused_moe_lora_kernel.py): All pass
[before]
======================================================================================== 48 passed, 14 warnings in 279.58s (0:04:39) ========================================================================================
[after]
============================================================== 48 passed, 14 warnings in 397.86s (0:06:37) ==============================================================
tests/lora/test_olmoe_tp.py: All pass
=============================================================== 6 passed, 8 warnings in 406.74s (0:06:46) ===============================================================

Benchmarks

Hardware: NVIDIA H800; Software: CUDA driver 13.0
Workload: model Qwen3-30B-A3B, LoRA rank 64

Two LoRAs

[before]

============ Serving Benchmark Result ============
Successful requests:                     800       
Failed requests:                         0         
Maximum request concurrency:             80        
Benchmark duration (s):                  60.10     
Total input tokens:                      819200    
Total generated tokens:                  80000     
Request throughput (req/s):              13.31     
Output token throughput (tok/s):         1331.02   
Peak output token throughput (tok/s):    3125.00   
Peak concurrent requests:                110.00    
Total Token throughput (tok/s):          14960.72  
---------------Time to First Token----------------
Mean TTFT (ms):                          642.25    
Median TTFT (ms):                        518.54    
P99 TTFT (ms):                           2830.93   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          54.02     
Median TPOT (ms):                        54.39     
P99 TPOT (ms):                           66.80     
---------------Inter-token Latency----------------
Mean ITL (ms):                           53.48     
Median ITL (ms):                         26.41     
P99 ITL (ms):                            457.24    
==================================================

[after]

============ Serving Benchmark Result ============
Successful requests:                     800       
Failed requests:                         0         
Maximum request concurrency:             80        
Benchmark duration (s):                  58.53     
Total input tokens:                      819200    
Total generated tokens:                  80000     
Request throughput (req/s):              13.67     
Output token throughput (tok/s):         1366.80   
Peak output token throughput (tok/s):    3280.00   
Peak concurrent requests:                113.00    
Total Token throughput (tok/s):          15362.88  
---------------Time to First Token----------------
Mean TTFT (ms):                          616.34    
Median TTFT (ms):                        480.47    
P99 TTFT (ms):                           3497.42   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          52.60     
Median TPOT (ms):                        52.30     
P99 TPOT (ms):                           62.77     
---------------Inter-token Latency----------------
Mean ITL (ms):                           52.07     
Median ITL (ms):                         25.21     
P99 ITL (ms):                            397.00    
==================================================

@cwazai cwazai requested a review from jeejeelee as a code owner January 21, 2026 09:48
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces several well-implemented optimizations to the fused MoE-LoRA Triton kernel, aiming to enhance performance. The changes, including the use of more efficient integer types, improved memory access patterns via L2 cache and removal of modulo operations, and a reduction in launched CUDA threads, are clear and logical. I have one suggestion to further simplify the kernel by removing a redundant mask, which could slightly improve performance and readability. Overall, this is a high-quality contribution.

@mergify

mergify bot commented Jan 21, 2026

Hi @cwazai, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy or markdownlint failing?
mypy and markdownlint are run differently in CI. If the failure is related to either of these checks, please use the following commands to run them locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10
# For markdownlint
pre-commit run --hook-stage manual markdownlint


@cwazai
Contributor Author

cwazai commented Jan 22, 2026

  1. cheaper bounds checks + more int32/uint32 math
  2. remove % N and use mask for N dimension
  3. optional B-side L2 cache modifier

[before]

============ Serving Benchmark Result ============
Successful requests:                     800       
Failed requests:                         0         
Maximum request concurrency:             80        
Benchmark duration (s):                  60.14     
Total input tokens:                      819200    
Total generated tokens:                  80000     
Request throughput (req/s):              13.30     
Output token throughput (tok/s):         1330.22   
Peak output token throughput (tok/s):    3120.00   
Peak concurrent requests:                111.00    
Total Token throughput (tok/s):          14951.67  
---------------Time to First Token----------------
Mean TTFT (ms):                          615.32    
Median TTFT (ms):                        366.47    
P99 TTFT (ms):                           3908.08   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          54.31     
Median TPOT (ms):                        54.83     
P99 TPOT (ms):                           64.95     
---------------Inter-token Latency----------------
Mean ITL (ms):                           53.77     
Median ITL (ms):                         26.41     
P99 ITL (ms):                            411.98    
==================================================

[1.cheaper bounds checks + more int32/uint32 math]

============ Serving Benchmark Result ============
Successful requests:                     800       
Failed requests:                         0         
Maximum request concurrency:             80        
Benchmark duration (s):                  60.53     
Total input tokens:                      819200    
Total generated tokens:                  80000     
Request throughput (req/s):              13.22     
Output token throughput (tok/s):         1321.65   
Peak output token throughput (tok/s):    3211.00   
Peak concurrent requests:                110.00    
Total Token throughput (tok/s):          14855.29  
---------------Time to First Token----------------
Mean TTFT (ms):                          598.98    
Median TTFT (ms):                        346.95    
P99 TTFT (ms):                           3881.16   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          54.89     
Median TPOT (ms):                        55.97     
P99 TPOT (ms):                           65.15     
---------------Inter-token Latency----------------
Mean ITL (ms):                           54.34     
Median ITL (ms):                         26.08     
P99 ITL (ms):                            437.81    
==================================================

[2.remove % N and use mask for N dimension]

============ Serving Benchmark Result ============
Successful requests:                     800       
Failed requests:                         0         
Maximum request concurrency:             80        
Benchmark duration (s):                  59.04     
Total input tokens:                      819200    
Total generated tokens:                  80000     
Request throughput (req/s):              13.55     
Output token throughput (tok/s):         1355.12   
Peak output token throughput (tok/s):    3200.00   
Peak concurrent requests:                114.00    
Total Token throughput (tok/s):          15231.53  
---------------Time to First Token----------------
Mean TTFT (ms):                          602.63    
Median TTFT (ms):                        332.31    
P99 TTFT (ms):                           3612.90   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          53.32     
Median TPOT (ms):                        53.87     
P99 TPOT (ms):                           60.76     
---------------Inter-token Latency----------------
Mean ITL (ms):                           52.79     
Median ITL (ms):                         25.86     
P99 ITL (ms):                            430.04    
==================================================
[3. optional B-side L2 cache modifier]
============ Serving Benchmark Result ============
Successful requests:                     800       
Failed requests:                         0         
Maximum request concurrency:             80        
Benchmark duration (s):                  59.42     
Total input tokens:                      819200    
Total generated tokens:                  80000     
Request throughput (req/s):              13.46     
Output token throughput (tok/s):         1346.39   
Peak output token throughput (tok/s):    3200.00   
Peak concurrent requests:                115.00    
Total Token throughput (tok/s):          15133.47  
---------------Time to First Token----------------
Mean TTFT (ms):                          632.19    
Median TTFT (ms):                        356.63    
P99 TTFT (ms):                           3627.23   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          53.42     
Median TPOT (ms):                        54.57     
P99 TPOT (ms):                           60.28     
---------------Inter-token Latency----------------
Mean ITL (ms):                           52.89     
Median ITL (ms):                         26.37     
P99 ITL (ms):                            487.52    
==================================================

k_remaining = K - k * (BLOCK_SIZE_K * SPLIT_K)
# GDC wait waits for ALL programs in the prior kernel to complete
# before continuing.
# pre-fetch lora weight
Collaborator

@jeejeelee jeejeelee Jan 22, 2026

Why delete these comments?

Contributor Author

Thank you for pointing this out. This was an oversight on my part—I originally intended to finalize these adjustments in the final step once everything was confirmed ready to merge. I’ll restore the comments right away.

Thanks again for the careful review.

a_intermediate_cache1 = a_intermediate_cache1.view(
-1, a_intermediate_cache1.shape[3]
)
a2d = a_intermediate_cache1.view(-1, a_intermediate_cache1.shape[3])
Collaborator

Why rename a_intermediate_cache1?

accumulator += tl.dot(a, b)
# Advance the ptrs to the next K block.

acc += tl.dot(a, b)
Collaborator

Please don't rename these variables

Contributor

+1

Contributor Author

Thanks for the feedback. I've reverted the variable names as requested.

# Early exit for the no-lora case.
max_loras_u32 = tl.full((), MAX_LORAS_TOTAL, tl.uint32)
num_valid_u32 = tl.full((), num_valid_tokens, tl.uint32)
num_experts_u32 = tl.full((), num_experts, tl.uint32)
Collaborator

@jikunshang Will changing it to uint32 affect XPU?

Contributor

thanks @jeejeelee , i will verify it on XPU.

Contributor Author

The relevant part has been removed.

@cwazai
Contributor Author

cwazai commented Jan 22, 2026

Update: removed the optimizations that were not relevant to this PR.

This PR introduces three targeted optimizations to the LoRA MoE kernel to reduce overhead and improve memory efficiency, while keeping the core algorithm unchanged:

- Reduce invalid programs: use lora_ids.numel() instead of max_loras for grid axis 2, eliminating programs that would immediately exit.
- Improve B-weight memory access: add an optional L2 caching hint (.ca) for B-weight loads to reduce DRAM pressure.
- Remove unnecessary modulo: replace (pid_n*BLOCK_SIZE_N + ...) % N with direct bounds checking in the load mask.
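The grid-axis reduction in the first point can be sketched in plain Python (the sizes and names here are made up for illustration):

```python
# Hypothetical sizes for illustration only.
max_loras = 8             # lora_a_stacked[0].shape[0]: allocated adapter slots
active_lora_ids = [0, 3]  # lora_ids: adapters actually in use this step

# Before: one axis-2 program per allocated slot; most exit immediately.
grid_axis2_before = max_loras
# After: one axis-2 program per active adapter; all do real work.
grid_axis2_after = len(active_lora_ids)
```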

unit tests (tests/lora/test_fused_moe_lora_kernel.py): All pass
[after]
============================================================== 48 passed, 14 warnings in 397.86s (0:06:37) ==============================================================
tests/lora/test_olmoe_tp.py: All pass
=================================================================== 6 passed, 8 warnings in 300.51s (0:05:00) ====================================================================
[before]

============ Serving Benchmark Result ============
Successful requests:                     800       
Failed requests:                         0         
Maximum request concurrency:             80        
Benchmark duration (s):                  57.01     
Total input tokens:                      819200    
Total generated tokens:                  80000     
Request throughput (req/s):              14.03     
Output token throughput (tok/s):         1403.25   
Peak output token throughput (tok/s):    3120.00   
Peak concurrent requests:                112.00    
Total Token throughput (tok/s):          15772.48  
---------------Time to First Token----------------
Mean TTFT (ms):                          581.81    
Median TTFT (ms):                        345.49    
P99 TTFT (ms):                           3510.19   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          51.31     
Median TPOT (ms):                        52.83     
P99 TPOT (ms):                           58.08     
---------------Inter-token Latency----------------
Mean ITL (ms):                           50.80     
Median ITL (ms):                         26.36     
P99 ITL (ms):                            264.29    
==================================================

[after]

============ Serving Benchmark Result ============
Successful requests:                     800       
Failed requests:                         0         
Maximum request concurrency:             80        
Benchmark duration (s):                  55.03     
Total input tokens:                      819200    
Total generated tokens:                  80000     
Request throughput (req/s):              14.54     
Output token throughput (tok/s):         1453.69   
Peak output token throughput (tok/s):    3200.00   
Peak concurrent requests:                119.00    
Total Token throughput (tok/s):          16339.45  
---------------Time to First Token----------------
Mean TTFT (ms):                          580.13    
Median TTFT (ms):                        489.42    
P99 TTFT (ms):                           3385.90   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          49.53     
Median TPOT (ms):                        49.56     
P99 TPOT (ms):                           58.31     
---------------Inter-token Latency----------------
Mean ITL (ms):                           49.03     
Median ITL (ms):                         26.21     
P99 ITL (ms):                            227.43    
==================================================


@dcmaddix
Contributor

(quoted the per-optimization benchmark results from the comment above)
Thanks for testing each optimization separately. It looks like there is no gain from the mod operator change? cc: @xyang16

@@ -62,6 +62,7 @@ def _fused_moe_lora_kernel(
num_experts,
lora_ids,
adapter_enabled,
MAX_LORAS_TOTAL, # <<< PR2: new, used for masks when grid axis-2 != max_loras
Contributor

Why do we need this new variable? Please define it. We have a PR (#32005) updating it to be the number of active LoRAs. cc: @yugong333

Contributor Author

The kernel previously used tl.num_programs(axis=2) to determine the mask bound for sorted_token_ids/expert_ids/num_tokens_post_padded. After changing grid axis‑2 from max_loras to lora_ids.numel() (to avoid launching inactive LoRA programs), we can no longer rely on that. MAX_LORAS_TOTAL is the size of the first dimension of those auxiliary tensors, used solely for bounds‑checking loads.

Definition in Python:

max_loras_total = sorted_token_ids.shape[0] # exactly the required upper bound
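A plain-Python stand-in for how this bound gates the auxiliary loads (the helper name is hypothetical; in the kernel this is expressed as a Triton load mask, not a function):

```python
max_loras_total = 8  # stand-in for sorted_token_ids.shape[0]

def aux_load_in_bounds(lora_idx: int) -> bool:
    # Loads from sorted_token_ids / expert_ids / num_tokens_post_padded
    # are masked against this bound now that the kernel can no longer
    # derive it from tl.num_programs(axis=2).
    return 0 <= lora_idx < max_loras_total
```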

Contributor

Changing the grid axis is also done in PR #32005. Can you please check that there are no conflicts or overlaps? I still recommend renaming this variable, since MAX_LORAS_TOTAL vs MAX_LORAS is not clear from the naming.

Contributor Author

You're correct that MAX_LORAS_TOTAL vs MAX_LORAS is confusing. I've renamed this parameter to max_loras to better indicate it represents the actual maximum number of LoRAs (the dimension of the stacked weight tensors). This makes the kernel signature clearer and more consistent with existing naming conventions.

@@ -133,7 +136,8 @@ def _fused_moe_lora_kernel(
cur_b_ptr = tl.load(b_ptr + slice_id).to(tl.pointer_type(c_ptr.dtype.element_ty))
cur_c_ptr = c_ptr + (slice_id % num_slice_c) * slice_c_size

offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
# remove modulo wrap-around
offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)
Contributor

Nice! This is cleaner and can also be applied to the fused_moe kernel in a follow-up PR. Why do we not see gains from it, though?

Contributor Author

You are right – the primary cost of these GEMM kernels is the tl.dot compute and memory traffic. Index‑arithmetic changes (removing % N, using int32) have negligible impact on end‑to‑end latency. We keep them because:

Removing % N avoids wrap‑around reads at the boundary, making the load‑mask behavior more explicit.
Using int32 for offs_bn, offs_token_id reduces register pressure slightly and is sufficient for the problem sizes we encounter.
These are clean‑up changes that align with typical Triton GEMM patterns; we don’t claim measurable performance improvement from them.
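The behavioral difference between the wrap-around and masked forms can be sketched in plain Python (tile sizes here are made up for illustration):

```python
# Hypothetical tile parameters: N columns, one BLOCK_SIZE_N-wide tile.
N, BLOCK_SIZE_N, pid_n = 10, 8, 1
weights = list(range(N))  # stand-in for one row of the B matrix

# Old: the modulo wraps offsets 10..15 back to 0..5 -- valid addresses
# but wrong data, which a later store mask has to discard.
offs_mod = [(pid_n * BLOCK_SIZE_N + i) % N for i in range(BLOCK_SIZE_N)]

# New: plain offsets plus an explicit bounds mask (cn_mask); out-of-range
# lanes take the load's neutral `other` value instead of wrapping.
offs = [pid_n * BLOCK_SIZE_N + i for i in range(BLOCK_SIZE_N)]
cn_mask = [o < N for o in offs]
tile = [weights[o] if ok else 0.0 for o, ok in zip(offs, cn_mask)]
```

Both forms avoid out-of-bounds reads; the masked one just makes the boundary handling explicit at the load site.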

Contributor

@@ -247,13 +256,14 @@ def _fused_moe_lora_shrink(
}

b_ptr = _get_ptr(lora_a_stacked, device)
max_loras_total = sorted_token_ids.shape[0] # <<< PR2: new
Contributor

Should it be +1? Please check #32277.

Contributor Author

The tensors are allocated with shape[0] = max_loras. Valid lora_ids are [0, max_loras-1]. Adding +1 would allow an out‑of‑bounds index to pass the mask. To be extra safe, we added an early‑return in the kernel:

if lora_id >= max_loras:
    return

This guarantees no out-of-bounds access. No +1 is needed.

Contributor

thanks!

@@ -270,6 +280,7 @@ def _fused_moe_lora_shrink(
num_experts,
lora_ids,
adapter_enabled,
max_loras_total, # new
Contributor

Do we need both this and max_loras? The naming is confusing.

grid = lambda META: (
triton.cdiv(EM, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
split_k * triton.cdiv(EM, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
Contributor

What effect does this split_k change have, and is it needed?

Collaborator

Please revert this change

Contributor Author

You're right. I've removed that change as it wasn't needed.


@cwazai
Contributor Author

cwazai commented Jan 23, 2026

Code changes in this commit:

unit tests (tests/lora/test_fused_moe_lora_kernel.py): All pass
[after]
============================================================== 48 passed, 14 warnings in 397.86s (0:06:37) ==============================================================
[after]
tests/lora/test_olmoe_tp.py:=================================================================== 6 passed, 8 warnings in 300.51s (0:05:00) ====================================================================
[before]

============ Serving Benchmark Result ============
Successful requests:                     800       
Failed requests:                         0         
Maximum request concurrency:             80        
Benchmark duration (s):                  64.73     
Total input tokens:                      819200    
Total generated tokens:                  80000     
Request throughput (req/s):              12.36     
Output token throughput (tok/s):         1235.91   
Peak output token throughput (tok/s):    3120.00   
Peak concurrent requests:                110.00    
Total Token throughput (tok/s):          13891.68  
---------------Time to First Token----------------
Mean TTFT (ms):                          691.20    
Median TTFT (ms):                        520.48    
P99 TTFT (ms):                           3698.53   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          58.19     
Median TPOT (ms):                        60.41     
P99 TPOT (ms):                           70.03     
---------------Inter-token Latency----------------
Mean ITL (ms):                           57.61     
Median ITL (ms):                         26.57     
P99 ITL (ms):                            521.51    
==================================================

[after]

============ Serving Benchmark Result ============
Successful requests:                     800       
Failed requests:                         0         
Maximum request concurrency:             80        
Benchmark duration (s):                  62.61     
Total input tokens:                      819200    
Total generated tokens:                  80000     
Request throughput (req/s):              12.78     
Output token throughput (tok/s):         1277.83   
Peak output token throughput (tok/s):    3280.00   
Peak concurrent requests:                110.00    
Total Token throughput (tok/s):          14362.77  
---------------Time to First Token----------------
Mean TTFT (ms):                          690.31    
Median TTFT (ms):                        509.80    
P99 TTFT (ms):                           3449.49   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          56.06     
Median TPOT (ms):                        56.58     
P99 TPOT (ms):                           66.55     
---------------Inter-token Latency----------------
Mean ITL (ms):                           55.50     
Median ITL (ms):                         25.09     
P99 ITL (ms):                            526.90    
==================================================

@@ -104,7 +106,10 @@ def _fused_moe_lora_kernel(
     if moe_enabled == 0:
         # Early exit for the no moe lora case.
         return
-    max_loras = tl.num_programs(axis=2)
+    max_loras = MAX_LORAS_TOTAL  # was tl.num_programs(axis=2)
Contributor

Why not directly pass MAX_LORAS?

Contributor Author

Good suggestion! I've updated the code to directly pass lora_*_stacked[0].shape[0] from Python to the kernel as the max_loras parameter. This eliminates the intermediate variable and makes the data flow more direct and transparent.
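The launcher side of that suggestion can be sketched in pure Python. The stand-in objects below only model the .shape attribute; the assumed tensor layout is [max_loras, num_experts, out_dim, rank], matching the stacked LoRA-B naming used in this thread:

```python
from types import SimpleNamespace

# Stand-ins for stacked LoRA-B tensors; only .shape matters for this sketch.
# Assumed layout: [max_loras, num_experts, out_dim, rank].
lora_b_stacked = [SimpleNamespace(shape=(8, 64, 1024, 16))]

def kernel_launch_args(lora_b_stacked):
    """Derive the max_loras scalar on the Python side, as suggested in
    review, instead of reading tl.num_programs(axis=2) inside the kernel."""
    return {"max_loras": lora_b_stacked[0].shape[0]}

args = kernel_launch_args(lora_b_stacked)
print(args["max_loras"])  # 8
```

Passing the scalar through the launch arguments keeps the kernel independent of how many programs happen to be scheduled on axis 2.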

     len(lora_b_stacked),
-    lora_b_stacked[0].shape[0],
+    lora_ids.numel(),  # PR2: was lora_b_stacked[0].shape[0]
Contributor

Let's also please revert this change from this PR. I think this will break cudagraph.
It is handled in this PR: #32005.

Contributor Author

Thank you for pointing out the overlap with PR #32005. You're absolutely right: I've reverted the grid axis-2 change (lora_ids.numel() → lora_*_stacked[0].shape[0]) to avoid any conflict.
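The cudagraph concern can be illustrated with a minimal sketch (names and the slot count are assumptions for illustration, not vLLM APIs): a captured CUDA graph replays with the launch dimensions it was recorded with, so a grid axis derived from the live adapter count changes between steps, while an axis sized to the fixed slot capacity does not.

```python
# Illustration of why a data-dependent grid axis is hostile to CUDA graph
# capture: replay bakes in the recorded launch dimensions, so the axis-2
# size must be a fixed upper bound (slot capacity), not the live count.
MAX_LORA_SLOTS = 8  # assumed capacity, i.e. lora_b_stacked[0].shape[0]

def grid_axis2(active_lora_ids, use_dynamic):
    if use_dynamic:
        return len(active_lora_ids)   # varies per step -> unsafe to capture
    return MAX_LORA_SLOTS             # static -> safe to capture once

assert grid_axis2([3, 5], use_dynamic=True) == 2
assert grid_axis2([3, 5, 7], use_dynamic=True) == 3   # changed between steps
assert grid_axis2([3, 5], use_dynamic=False) == MAX_LORA_SLOTS
```

The trade-off is wasted programs for unused slots, which is exactly what the in-kernel early exit cheaply discards.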

@dcmaddix
Contributor

Also, please run all the pytests in the lora folder, including the model ones, not just the kernel ones. Thanks!

@mergify

mergify bot commented Jan 24, 2026

Hi @cwazai, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.


@cwazai
Contributor Author

cwazai commented Jan 24, 2026

unit tests: All pass
test_olmoe_tp.py
========================================================================== 6 passed, 8 warnings in 247.76s (0:04:07) ===========================================================================
test_fused_moe_lora_kernel.py
========================================================================= 48 passed, 14 warnings in 301.46s (0:05:01) ==========================================================================
test_moe_lora_align_sum.py
=============================================================================== 16 passed, 2 warnings in 13.72s ================================================================================
test_qwen3moe_tp.py
================================================================================= 1 passed, 3 warnings in 39.07s ==================================================================================
================================================================================= 1 passed, 3 warnings in 52.36s ==================================================================================
============================================================================ 1 passed, 3 warnings in 69.85s (0:01:09) =============================================================================
test_layers.py
=================================================================== 123 passed, 114 skipped, 2 warnings in 86.69s (0:01:26) ====================================================================
test_llm_with_multi_loras.py
=========================================================================== 2 passed, 4 warnings in 99.70s (0:01:39) ===========================================================================
test_punica_ops.py
========================================================================== 4272 passed, 2 warnings in 713.42s (0:11:53) ===========================================================================

[before]

============ Serving Benchmark Result ============
Successful requests:                     800       
Failed requests:                         0         
Maximum request concurrency:             80        
Benchmark duration (s):                  50.76     
Total input tokens:                      819200    
Total generated tokens:                  80000     
Request throughput (req/s):              15.76     
Output token throughput (tok/s):         1575.96   
Peak output token throughput (tok/s):    3300.00   
Peak concurrent requests:                115.00    
Total Token throughput (tok/s):          17713.78  
---------------Time to First Token----------------
Mean TTFT (ms):                          554.45    
Median TTFT (ms):                        514.38    
P99 TTFT (ms):                           2700.87   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          45.48     
Median TPOT (ms):                        46.16     
P99 TPOT (ms):                           49.01     
---------------Inter-token Latency----------------
Mean ITL (ms):                           45.03     
Median ITL (ms):                         25.58     
P99 ITL (ms):                            175.31    
==================================================

[after]

============ Serving Benchmark Result ============
Successful requests:                     800       
Failed requests:                         0         
Maximum request concurrency:             80        
Benchmark duration (s):                  49.11     
Total input tokens:                      819200    
Total generated tokens:                  80000     
Request throughput (req/s):              16.29     
Output token throughput (tok/s):         1629.03   
Peak output token throughput (tok/s):    3360.00   
Peak concurrent requests:                111.00    
Total Token throughput (tok/s):          18310.29  
---------------Time to First Token----------------
Mean TTFT (ms):                          486.53    
Median TTFT (ms):                        332.43    
P99 TTFT (ms):                           2597.25   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          44.49     
Median TPOT (ms):                        45.49     
P99 TPOT (ms):                           47.46     
---------------Inter-token Latency----------------
Mean ITL (ms):                           44.04     
Median ITL (ms):                         24.79     
P99 ITL (ms):                            167.76    
==================================================

@cwazai
Contributor Author

cwazai commented Jan 24, 2026

Also, please run all the pytests in the lora folder, including the model ones, not just the kernel ones. Thanks!

Hi Reviewer,
Thanks for the guidance. I've updated the code and have also run the pytests for the LoRA folder as suggested.
To make sure I'm focusing on the right scope: for the testing, should we concentrate specifically on the MoE LoRA related tests? Since the directory contains many test files, I want to ensure we're efficiently covering the impacted areas.
I've also added the test data/results in the comment below for your reference.

@dcmaddix
Contributor

Also, please run all the pytests in the lora folder, including the model ones, not just the kernel ones. Thanks!

Hi Reviewer, Thanks for the guidance. I've updated the code and have also run the pytests for the LoRA folder as suggested. To make sure I'm focusing on the right scope: for the testing, should we concentrate specifically on the MoE LoRA related tests? Since the directory contains many test files, I want to ensure we're efficiently covering the impacted areas. I've also added the test data/results in the comment below for your reference.

Yes, just the MoE-LoRA related pytests:

pytest -s -v tests/lora/test_gptoss_tp.py tests/lora/test_fused_moe_lora_kernel.py tests/lora/test_moe_lora_align_sum.py tests/lora/test_deepseekv2_tp.py tests/lora/test_qwen3moe_tp.py tests/lora/test_olmoe_tp.py

@@ -104,7 +106,9 @@ def _fused_moe_lora_kernel(
     if moe_enabled == 0:
         # Early exit for the no moe lora case.
         return
-    max_loras = tl.num_programs(axis=2)
+    if lora_id >= max_loras:
Contributor

Add a comment on why we need this case. Technically, since you are passing it correctly, it shouldn't be reached, but it's fine to keep this check.
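The PR description notes that the unsigned comparison automatically filters the -1 sentinel. A pure-Python sketch of the uint32 reinterpretation (modeled with a bit mask; the max_loras value is an assumption for illustration) shows why one guard rejects both out-of-range ids and the "no adapter" sentinel:

```python
def as_uint32(x: int) -> int:
    # Reinterpret a signed 32-bit value as unsigned, as the kernel's
    # unsigned comparison effectively does: int32 -1 becomes 2**32 - 1.
    return x & 0xFFFFFFFF

max_loras = 8  # assumed slot capacity for illustration
for lora_id in (-1, 0, 7, 8):
    rejected = as_uint32(lora_id) >= max_loras
    print(lora_id, rejected)
# -1 True / 0 False / 7 False / 8 True
```

Because 2**32 - 1 exceeds any realistic max_loras, the single `lora_id >= max_loras` branch covers the sentinel for free, with no extra `lora_id == -1` check in the hot path.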

@dcmaddix
Contributor

LGTM now, thanks! cc: @jeejeelee

@cwazai
Contributor Author

cwazai commented Jan 25, 2026

Yes just the MoE-LORA related pytests:

pytest -s -v tests/lora/test_gptoss_tp.py tests/lora/test_fused_moe_lora_kernel.py tests/lora/test_moe_lora_align_sum.py tests/lora/test_deepseekv2_tp.py tests/lora/test_qwen3moe_tp.py tests/lora/test_olmoe_tp.py

Thank you for the clarification. I have now run these MoE-LoRA tests as requested. Here are the results:
test_fused_moe_lora_kernel.py
========================================================================= 48 passed, 14 warnings in 301.46s (0:05:01) ==========================================================================
test_moe_lora_align_sum.py
=============================================================================== 16 passed, 2 warnings in 13.72s ================================================================================
test_deepseekv2_tp.py
============================================================================ 4 passed, 6 warnings in 160.24s (0:02:40) ============================================================================
test_qwen3moe_tp.py
============================================================================ 3 passed, 5 warnings in 157.05s (0:02:37) ============================================================================
test_olmoe_tp.py
========================================================================== 6 passed, 8 warnings in 247.76s (0:04:07) ===========================================================================

@mergify mergify bot added deepseek Related to DeepSeek models frontend llama Related to Llama models multi-modality Related to multi-modality (#4194) new-model Requests to new models performance Performance-related issues qwen Related to Qwen models gpt-oss Related to GPT-OSS models nvidia labels Jan 25, 2026
@mergify mergify bot added the rocm Related to AMD ROCm label Jan 25, 2026
@mergify mergify bot added the tpu Related to Google TPUs label Jan 25, 2026
@mergify

mergify bot commented Jan 25, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @cwazai.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

Signed-off-by: cwazai <38356712+cwazai@users.noreply.github.com>
@mergify mergify bot removed tpu Related to Google TPUs needs-rebase labels Jan 25, 2026
@github-project-automation github-project-automation bot moved this from To Triage to Ready in gpt-oss Issues & Enhancements Jan 26, 2026
@github-project-automation github-project-automation bot moved this to Ready in NVIDIA Jan 26, 2026
@vllm-bot vllm-bot merged commit e33192b into vllm-project:main Jan 26, 2026
47 of 50 checks passed

Labels

ci/build cpu Related to CPU backends deepseek Related to DeepSeek models documentation Improvements or additions to documentation frontend gpt-oss Related to GPT-OSS models kv-connector llama Related to Llama models multi-modality Related to multi-modality (#4194) nvidia performance Performance-related issues qwen Related to Qwen models ready ONLY add when PR is ready to merge/full CI is needed rocm Related to AMD ROCm speculative-decoding structured-output v1

Projects

Status: Done
