[lora/moe] Improve fused MoE-LoRA kernel indexing and memory access #32770
vllm-bot merged 137 commits into vllm-project:main
Conversation
Code Review
This pull request introduces several well-implemented optimizations to the fused MoE-LoRA Triton kernel, aiming to enhance performance. The changes, including the use of more efficient integer types, improved memory access patterns via L2 cache and removal of modulo operations, and a reduction in launched CUDA threads, are clear and logical. I have one suggestion to further simplify the kernel by removing a redundant mask, which could slightly improve performance and readability. Overall, this is a high-quality contribution.
Hi @cwazai, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.
[before]
============ Serving Benchmark Result ============
Successful requests: 800
Failed requests: 0
Maximum request concurrency: 80
Benchmark duration (s): 60.14
Total input tokens: 819200
Total generated tokens: 80000
Request throughput (req/s): 13.30
Output token throughput (tok/s): 1330.22
Peak output token throughput (tok/s): 3120.00
Peak concurrent requests: 111.00
Total Token throughput (tok/s): 14951.67
---------------Time to First Token----------------
Mean TTFT (ms): 615.32
Median TTFT (ms): 366.47
P99 TTFT (ms): 3908.08
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): 54.31
Median TPOT (ms): 54.83
P99 TPOT (ms): 64.95
---------------Inter-token Latency----------------
Mean ITL (ms): 53.77
Median ITL (ms): 26.41
P99 ITL (ms): 411.98
==================================================

[1. cheaper bounds checks + more int32/uint32 math]
============ Serving Benchmark Result ============
Successful requests: 800
Failed requests: 0
Maximum request concurrency: 80
Benchmark duration (s): 60.53
Total input tokens: 819200
Total generated tokens: 80000
Request throughput (req/s): 13.22
Output token throughput (tok/s): 1321.65
Peak output token throughput (tok/s): 3211.00
Peak concurrent requests: 110.00
Total Token throughput (tok/s): 14855.29
---------------Time to First Token----------------
Mean TTFT (ms): 598.98
Median TTFT (ms): 346.95
P99 TTFT (ms): 3881.16
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): 54.89
Median TPOT (ms): 55.97
P99 TPOT (ms): 65.15
---------------Inter-token Latency----------------
Mean ITL (ms): 54.34
Median ITL (ms): 26.08
P99 ITL (ms): 437.81
==================================================

[2. remove % N and use mask for N dimension]
============ Serving Benchmark Result ============
Successful requests: 800
Failed requests: 0
Maximum request concurrency: 80
Benchmark duration (s): 59.04
Total input tokens: 819200
Total generated tokens: 80000
Request throughput (req/s): 13.55
Output token throughput (tok/s): 1355.12
Peak output token throughput (tok/s): 3200.00
Peak concurrent requests: 114.00
Total Token throughput (tok/s): 15231.53
---------------Time to First Token----------------
Mean TTFT (ms): 602.63
Median TTFT (ms): 332.31
P99 TTFT (ms): 3612.90
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): 53.32
Median TPOT (ms): 53.87
P99 TPOT (ms): 60.76
---------------Inter-token Latency----------------
Mean ITL (ms): 52.79
Median ITL (ms): 25.86
P99 ITL (ms): 430.04
==================================================
[3. optional B-side L2 cache modifier]
============ Serving Benchmark Result ============
Successful requests: 800
Failed requests: 0
Maximum request concurrency: 80
Benchmark duration (s): 59.42
Total input tokens: 819200
Total generated tokens: 80000
Request throughput (req/s): 13.46
Output token throughput (tok/s): 1346.39
Peak output token throughput (tok/s): 3200.00
Peak concurrent requests: 115.00
Total Token throughput (tok/s): 15133.47
---------------Time to First Token----------------
Mean TTFT (ms): 632.19
Median TTFT (ms): 356.63
P99 TTFT (ms): 3627.23
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): 53.42
Median TPOT (ms): 54.57
P99 TPOT (ms): 60.28
---------------Inter-token Latency----------------
Mean ITL (ms): 52.89
Median ITL (ms): 26.37
P99 ITL (ms): 487.52
==================================================
     k_remaining = K - k * (BLOCK_SIZE_K * SPLIT_K)
-    # GDC wait waits for ALL programs in the prior kernel to complete
-    # before continuing.
-    # pre-fetch lora weight
Why delete these comments?
Thank you for pointing this out. This was an oversight on my part—I originally intended to finalize these adjustments in the final step once everything was confirmed ready to merge. I’ll restore the comments right away.
Thanks again for the careful review.
-    a_intermediate_cache1 = a_intermediate_cache1.view(
-        -1, a_intermediate_cache1.shape[3]
-    )
+    a2d = a_intermediate_cache1.view(-1, a_intermediate_cache1.shape[3])
Why rename a_intermediate_cache1?
-    accumulator += tl.dot(a, b)
-    # Advance the ptrs to the next K block.
+    acc += tl.dot(a, b)
Please don't rename these variables.
Thanks for the feedback. I've reverted the variable names as requested.
     # Early exit for the no-lora case.
+    max_loras_u32 = tl.full((), MAX_LORAS_TOTAL, tl.uint32)
+    num_valid_u32 = tl.full((), num_valid_tokens, tl.uint32)
+    num_experts_u32 = tl.full((), num_experts, tl.uint32)
@jikunshang Will changing it to uint32 affect XPU?
The relevant part has been removed.
Modifications to remove irrelevant optimizations

This PR introduces three targeted optimizations to the LoRA MoE kernel to reduce overhead and improve memory efficiency, while keeping the core algorithm unchanged:

Reduce invalid programs: Use lora_ids.numel() instead of max_loras for grid axis 2, eliminating programs that would immediately exit.

unit tests (tests/lora/test_fused_moe_lora_kernel.py): All pass

============ Serving Benchmark Result ============
Successful requests: 800
Failed requests: 0
Maximum request concurrency: 80
Benchmark duration (s): 57.01
Total input tokens: 819200
Total generated tokens: 80000
Request throughput (req/s): 14.03
Output token throughput (tok/s): 1403.25
Peak output token throughput (tok/s): 3120.00
Peak concurrent requests: 112.00
Total Token throughput (tok/s): 15772.48
---------------Time to First Token----------------
Mean TTFT (ms): 581.81
Median TTFT (ms): 345.49
P99 TTFT (ms): 3510.19
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): 51.31
Median TPOT (ms): 52.83
P99 TPOT (ms): 58.08
---------------Inter-token Latency----------------
Mean ITL (ms): 50.80
Median ITL (ms): 26.36
P99 ITL (ms): 264.29
==================================================

[after]
============ Serving Benchmark Result ============
Successful requests: 800
Failed requests: 0
Maximum request concurrency: 80
Benchmark duration (s): 55.03
Total input tokens: 819200
Total generated tokens: 80000
Request throughput (req/s): 14.54
Output token throughput (tok/s): 1453.69
Peak output token throughput (tok/s): 3200.00
Peak concurrent requests: 119.00
Total Token throughput (tok/s): 16339.45
---------------Time to First Token----------------
Mean TTFT (ms): 580.13
Median TTFT (ms): 489.42
P99 TTFT (ms): 3385.90
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): 49.53
Median TPOT (ms): 49.56
P99 TPOT (ms): 58.31
---------------Inter-token Latency----------------
Mean ITL (ms): 49.03
Median ITL (ms): 26.21
P99 ITL (ms): 227.43
==================================================
@@ -62,6 +62,7 @@ def _fused_moe_lora_kernel(
     num_experts,
     lora_ids,
     adapter_enabled,
+    MAX_LORAS_TOTAL,  # <<< PR2: new, used for masks when grid axis-2 != max_loras
Why do we need this new variable? Please define it. We have a PR (#32005) updating this to be the number of active LoRAs. cc: @yugong333
The kernel previously used tl.num_programs(axis=2) to determine the mask bound for sorted_token_ids/expert_ids/num_tokens_post_padded. After changing grid axis‑2 from max_loras to lora_ids.numel() (to avoid launching inactive LoRA programs), we can no longer rely on that. MAX_LORAS_TOTAL is the size of the first dimension of those auxiliary tensors, used solely for bounds‑checking loads.
Definition in Python:
max_loras_total = sorted_token_ids.shape[0] # exactly the required upper bound
Changing the grid axis is also done in PR #32005. Can you please check that there are no conflicts or overlaps? I still recommend renaming this variable, since MAX_LORAS_TOTAL vs MAX_LORAS is not clear from the naming.
You're correct that MAX_LORAS_TOTAL vs MAX_LORAS is confusing. I've renamed this parameter to max_loras to better indicate it represents the actual maximum number of LoRAs (the dimension of the stacked weight tensors). This makes the kernel signature clearer and more consistent with existing naming conventions.
@@ -133,7 +136,8 @@ def _fused_moe_lora_kernel(
     cur_b_ptr = tl.load(b_ptr + slice_id).to(tl.pointer_type(c_ptr.dtype.element_ty))
     cur_c_ptr = c_ptr + (slice_id % num_slice_c) * slice_c_size

-    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
+    # remove modulo wrap-around
+    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)
Nice! This is cleaner and can also be applied to the fused_moe kernel in a follow-up PR. Why do we not see gains from it, though?
You are right: the primary cost of these GEMM kernels is the tl.dot compute and memory traffic. Index-arithmetic changes (removing % N, using int32) have a negligible impact on end-to-end latency. We keep them because:
- Removing % N avoids wrap-around reads at the boundary, making the load-mask behavior more explicit.
- Using int32 for offs_bn and offs_token_id slightly reduces register pressure and is sufficient for the problem sizes we encounter.
These are clean-up changes that align with typical Triton GEMM patterns; we don't claim a measurable performance improvement from them.
Yes, this is a nice cleanup, and we should apply it to the fused_moe kernel too: https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/fused_moe/fused_moe.py#L200
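To illustrate the indexing change being discussed, here is a pure-Python sketch (not Triton, and with purely illustrative block and matrix sizes) contrasting the old `% N` wrap-around scheme with the new raw-offset-plus-mask scheme for the column dimension:

```python
# Pure-Python sketch of the two column-indexing schemes.
# BLOCK_SIZE_N and N below are hypothetical, chosen so the last block is partial.
BLOCK_SIZE_N = 8
N = 20  # not a multiple of BLOCK_SIZE_N

def offsets_with_modulo(pid_n):
    # Old scheme: out-of-range columns wrap back to the start via % N,
    # so the last block silently re-reads columns 0, 1, 2, ...
    return [(pid_n * BLOCK_SIZE_N + i) % N for i in range(BLOCK_SIZE_N)]

def offsets_with_mask(pid_n):
    # New scheme: keep the raw offsets and compute an explicit validity mask,
    # mirroring a masked load like tl.load(..., mask=offs_bn < N, other=0.0).
    offs = [pid_n * BLOCK_SIZE_N + i for i in range(BLOCK_SIZE_N)]
    mask = [o < N for o in offs]
    return offs, mask

# Last block (pid_n = 2) covers columns 16..23; only 16..19 are in bounds.
wrapped = offsets_with_modulo(2)   # [16, 17, 18, 19, 0, 1, 2, 3]
raw, mask = offsets_with_mask(2)   # raw = [16..23], mask True only for 16..19
valid = [o for o, m in zip(raw, mask) if m]
print(wrapped, valid)
```

For full interior blocks the two schemes produce identical offsets; the difference is only at the boundary, where the mask makes the invalid lanes explicit instead of wrapping them.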
@@ -247,13 +256,14 @@ def _fused_moe_lora_shrink(
     }

     b_ptr = _get_ptr(lora_a_stacked, device)
+    max_loras_total = sorted_token_ids.shape[0]  # <<< PR2: new
The tensors are allocated with shape[0] = max_loras. Valid lora_ids are [0, max_loras-1]. Adding +1 would allow an out‑of‑bounds index to pass the mask. To be extra safe, we added an early‑return in the kernel:
if lora_id >= max_loras:
return
This guarantees no out‑of‑bounds access. No +1 is needed.
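The guard above can be modeled in plain Python. This is a sketch only (the real check runs in Triton with uint32 operands); it shows why an unsigned comparison handles both the out-of-bounds case and a hypothetical disabled-adapter sentinel of -1 in one test:

```python
# Pure-Python model of the kernel's early-exit guard. In the kernel the
# comparison happens on uint32 values, so a sentinel lora_id of -1
# reinterprets as 2**32 - 1 and always fails the bound check.
def should_run(lora_id: int, max_loras: int) -> bool:
    lora_id_u32 = lora_id & 0xFFFFFFFF   # reinterpret as unsigned 32-bit
    if lora_id_u32 >= max_loras:         # mirrors `if lora_id >= max_loras: return`
        return False                     # program exits early, no loads issued
    return True

print(should_run(3, 4))    # valid adapter index
print(should_run(4, 4))    # out of bounds: filtered
print(should_run(-1, 4))   # sentinel: filtered by the same comparison
```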
@@ -270,6 +280,7 @@ def _fused_moe_lora_shrink(
     num_experts,
     lora_ids,
     adapter_enabled,
+    max_loras_total,  # new
Do we need both this and max_loras? The naming is confusing.
     grid = lambda META: (
-        triton.cdiv(EM, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
+        split_k * triton.cdiv(EM, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
What effect does this split_k change have, and is it needed?

Please revert this change.

You're right. I've removed that change as it wasn't needed.
Code changes in this commit

unit tests (tests/lora/test_fused_moe_lora_kernel.py): All pass

[before]
============ Serving Benchmark Result ============
Successful requests: 800
Failed requests: 0
Maximum request concurrency: 80
Benchmark duration (s): 64.73
Total input tokens: 819200
Total generated tokens: 80000
Request throughput (req/s): 12.36
Output token throughput (tok/s): 1235.91
Peak output token throughput (tok/s): 3120.00
Peak concurrent requests: 110.00
Total Token throughput (tok/s): 13891.68
---------------Time to First Token----------------
Mean TTFT (ms): 691.20
Median TTFT (ms): 520.48
P99 TTFT (ms): 3698.53
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): 58.19
Median TPOT (ms): 60.41
P99 TPOT (ms): 70.03
---------------Inter-token Latency----------------
Mean ITL (ms): 57.61
Median ITL (ms): 26.57
P99 ITL (ms): 521.51
==================================================

[after]
============ Serving Benchmark Result ============
Successful requests: 800
Failed requests: 0
Maximum request concurrency: 80
Benchmark duration (s): 62.61
Total input tokens: 819200
Total generated tokens: 80000
Request throughput (req/s): 12.78
Output token throughput (tok/s): 1277.83
Peak output token throughput (tok/s): 3280.00
Peak concurrent requests: 110.00
Total Token throughput (tok/s): 14362.77
---------------Time to First Token----------------
Mean TTFT (ms): 690.31
Median TTFT (ms): 509.80
P99 TTFT (ms): 3449.49
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): 56.06
Median TPOT (ms): 56.58
P99 TPOT (ms): 66.55
---------------Inter-token Latency----------------
Mean ITL (ms): 55.50
Median ITL (ms): 25.09
P99 ITL (ms): 526.90
==================================================
@@ -104,7 +106,10 @@ def _fused_moe_lora_kernel(
     if moe_enabled == 0:
         # Early exit for the no moe lora case.
         return
-    max_loras = tl.num_programs(axis=2)
+    max_loras = MAX_LORAS_TOTAL  # <<< : was tl.num_programs(axis=2)
Why not directly pass MAX_LORAS?
Good suggestion! I've updated the code to directly pass lora_*_stacked[0].shape[0] from Python to the kernel as the max_loras parameter. This eliminates the intermediate variable and makes the data flow more direct and transparent.
     len(lora_b_stacked),
-    lora_b_stacked[0].shape[0],
+    lora_ids.numel(),  # <<< PR2: was lora_b_stacked[0].shape[0]
Let's also please revert this change from this PR. I think this will break cudagraph.
It is handled in this PR: #32005.
Thank you for pointing out the overlap with PR#32005. You're absolutely right - I've reverted the grid axis-2 changes (lora_ids.numel() → lora_*_stacked[0].shape[0]) to avoid any conflict.
Also, please run all the pytests in the lora folder, including the model ones, not just the kernel ones. Thanks!
unit tests: All pass

[before]
============ Serving Benchmark Result ============
Successful requests: 800
Failed requests: 0
Maximum request concurrency: 80
Benchmark duration (s): 50.76
Total input tokens: 819200
Total generated tokens: 80000
Request throughput (req/s): 15.76
Output token throughput (tok/s): 1575.96
Peak output token throughput (tok/s): 3300.00
Peak concurrent requests: 115.00
Total Token throughput (tok/s): 17713.78
---------------Time to First Token----------------
Mean TTFT (ms): 554.45
Median TTFT (ms): 514.38
P99 TTFT (ms): 2700.87
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): 45.48
Median TPOT (ms): 46.16
P99 TPOT (ms): 49.01
---------------Inter-token Latency----------------
Mean ITL (ms): 45.03
Median ITL (ms): 25.58
P99 ITL (ms): 175.31
==================================================

[after]
============ Serving Benchmark Result ============
Successful requests: 800
Failed requests: 0
Maximum request concurrency: 80
Benchmark duration (s): 49.11
Total input tokens: 819200
Total generated tokens: 80000
Request throughput (req/s): 16.29
Output token throughput (tok/s): 1629.03
Peak output token throughput (tok/s): 3360.00
Peak concurrent requests: 111.00
Total Token throughput (tok/s): 18310.29
---------------Time to First Token----------------
Mean TTFT (ms): 486.53
Median TTFT (ms): 332.43
P99 TTFT (ms): 2597.25
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): 44.49
Median TPOT (ms): 45.49
P99 TPOT (ms): 47.46
---------------Inter-token Latency----------------
Mean ITL (ms): 44.04
Median ITL (ms): 24.79
P99 ITL (ms): 167.76
==================================================

Hi Reviewer,

Yes, just the MoE-LoRA related pytests:
@@ -104,7 +106,9 @@ def _fused_moe_lora_kernel(
     if moe_enabled == 0:
         # Early exit for the no moe lora case.
         return
-    max_loras = tl.num_programs(axis=2)
+    if lora_id >= max_loras:
Add a comment on why we need this case. Technically, since you are passing it correctly, it shouldn't be hit, but it's fine to keep the check.
LGTM now, thanks! cc: @jeejeelee

Thank you for the clarification. I have now run these MoE-LoRA tests as requested. Here are the results:
This pull request has merge conflicts that must be resolved before it can be merged.
Signed-off-by: cwazai <38356712+cwazai@users.noreply.github.com>
This PR optimizes the Triton kernel's inner-loop indexing and memory-access patterns.
Key changes:
- Use int32/uint32 for indices (cheaper than int64 in PTX).
- Remove the % N modulo for column offsets; instead compute a plain pid_n * BLOCK_SIZE_N and mask with cn_mask.
- Pass MAX_LORAS_TOTAL as a Python scalar (not via tl.num_programs), and let the kernel early-exit on an invalid lora_id using an unsigned comparison (which automatically filters -1).
- Reduce the axis-2 program count from lora_a_stacked[0].shape[0] to lora_ids.numel(), launching only for actually-used LoRA adapters.
- Add USE_B_L2_CACHE=True for B-matrix loads (.ca modifier), improving cache locality for repeated expert-weight reads.
- Use other=-1 for expert_id/token_id loads, simplifying mask logic.
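The axis-2 reduction can be sketched in plain Python. This is illustrative only (the function name, sizes, and adapter ids below are hypothetical, not the actual vLLM API); it shows how sizing the grid by the number of active adapter ids, rather than by the stacked-weight capacity, shrinks the launch:

```python
import math

def grid_size(EM, N, block_m, block_n, axis2):
    # One flattened M*N tile axis, a placeholder axis 1, and the LoRA axis.
    return (math.ceil(EM / block_m) * math.ceil(N / block_n), 1, axis2)

max_loras = 8             # capacity: lora_a_stacked[0].shape[0]
active_lora_ids = [0, 3]  # only two adapters appear in this batch

before = grid_size(512, 1024, 64, 64, max_loras)          # axis 2 = 8
after = grid_size(512, 1024, 64, 64, len(active_lora_ids))  # axis 2 = 2
print(before, after)  # axis 2 shrinks from capacity to the active count
```

The programs eliminated here are exactly the ones that would have hit the early-exit path immediately, so skipping their launch saves scheduling overhead without changing results.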
Performance impact:
- Faster index arithmetic in the hot path.
- Fewer wasted CUDA threads (axis-2 early exit).
- Better L2 hit rate for B (expert weights).
Testing

unit tests (tests/lora/test_fused_moe_lora_kernel.py): All pass

[before]
======================================================================================== 48 passed, 14 warnings in 279.58s (0:04:39) ========================================================================================
[after]
============================================================== 48 passed, 14 warnings in 397.86s (0:06:37) ==============================================================
tests/lora/test_olmoe_tp.py: All pass
=============================================================== 6 passed, 8 warnings in 406.74s (0:06:46) ===============================================================
Benchmarks
Hardware: NVIDIA H800; Software: CUDA driver 13.0
Workload: model Qwen3-30B-A3B, LoRA rank 64, two LoRA adapters
[before]
[after]