[CUDA] Increase number of output elements per-thread block if the K-dimension is small#20635

Open
gaugarg-nv wants to merge 3 commits into ggml-org:master from gaugarg-nv:small_k_optimization

Conversation

@gaugarg-nv
Contributor

@gaugarg-nv gaugarg-nv commented Mar 16, 2026

The K-dimension (inner dot product dimension) of the FFN-down matrices can be quite small, especially for MoEs. For example, Qwen3-30B-A3B has a K-dimension of 768, and Qwen3-235B-A22B has a K-dimension of 1536. The current heuristic uses a group of 4 warps irrespective of the K-dimension size, leaving some of the threads idle and resulting in poor performance for these matrices.

This change increases the number of output elements per block for such matrices.

This change is also helpful for Tensor parallelism (PR #19378), where FFN-down is split along the K dimension.
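The condition can be sketched roughly as follows. This is a minimal host-side sketch of the idea, not the actual ggml-cuda code; the constants `vals_per_quant_block` and `blocks_per_warp_iter` are illustrative assumptions:

```cpp
#include <cassert>

// Illustrative constants (assumed, not the real ggml-cuda values):
constexpr int warp_size           = 32;
constexpr int vals_per_quant_block = 32; // values per quantization block
constexpr int blocks_per_warp_iter = warp_size; // quant blocks one warp covers per loop iteration

// A row of K values is split into K / vals_per_quant_block quant blocks.
// If the whole thread block covers all of them in a single loop iteration,
// some warps have nothing to do, so we switch to the small-K path instead.
bool use_small_k(int K, int nwarps) {
    const int blocks_per_row = K / vals_per_quant_block;
    return nwarps > 1 && blocks_per_row < nwarps * blocks_per_warp_iter;
}
```

With these assumed numbers, a K of 768 (24 quant blocks) is far below what 4 warps cover per iteration (128 quant blocks), so the small-K path triggers.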

Single GPU Performance on 1x RTX Pro 6000 Blackwell
| model_type | n_ubatch | n_prompt | master avg_ts | PR avg_ts | speed-up |
| --- | --- | --- | --- | --- | --- |
| qwen3moe 30B.A3B Q4_K - Medium | 1 | 512 | 231.4418 | 239.4359 | 1.03 |
| qwen3moe 30B.A3B Q4_K - Medium | 2 | 512 | 336.3564 | 353.3403 | 1.05 |
| qwen3moe 30B.A3B Q4_K - Medium | 4 | 512 | 498.7951 | 544.9048 | 1.09 |
| qwen3moe 30B.A3B Q4_K - Medium | 8 | 512 | 579.7136 | 580.2928 | 1.00 |
| qwen3moe 30B.A3B Q4_K - Medium | 16 | 512 | 936.1984 | 934.2313 | 1.00 |
| qwen3moe 30B.A3B Q4_K - Medium | 32 | 512 | 1456.243 | 1453.281 | 1.00 |
| qwen3moe 30B.A3B Q4_K - Medium | 64 | 512 | 2185.851 | 2185.245 | 1.00 |
| qwen3moe 30B.A3B Q4_K - Medium | 128 | 512 | 2970.54 | 2969.02 | 1.00 |
| qwen3moe 30B.A3B Q4_K - Medium | 256 | 512 | 4774.641 | 4779.619 | 1.00 |
| qwen3moe 30B.A3B Q4_K - Medium | 512 | 512 | 6587.268 | 6592.251 | 1.00 |
| qwen3moe 30B.A3B Q8_0 | 1 | 512 | 188.6321 | 189.3348 | 1.00 |
| qwen3moe 30B.A3B Q8_0 | 2 | 512 | 296.4038 | 304.8155 | 1.03 |
| qwen3moe 30B.A3B Q8_0 | 4 | 512 | 446.3545 | 480.4061 | 1.08 |
| qwen3moe 30B.A3B Q8_0 | 8 | 512 | 513.8571 | 513.5698 | 1.00 |
| qwen3moe 30B.A3B Q8_0 | 16 | 512 | 814.9273 | 809.3003 | 0.99 |
| qwen3moe 30B.A3B Q8_0 | 32 | 512 | 1309.532 | 1310.682 | 1.00 |
| qwen3moe 30B.A3B Q8_0 | 64 | 512 | 2145.738 | 2147.491 | 1.00 |
| qwen3moe 30B.A3B Q8_0 | 128 | 512 | 3039.336 | 3040.037 | 1.00 |
| qwen3moe 30B.A3B Q8_0 | 256 | 512 | 4908.882 | 4912.358 | 1.00 |
| qwen3moe 30B.A3B Q8_0 | 512 | 512 | 6795.054 | 6800.975 | 1.00 |
| qwen3 4B Q4_K - Medium | 1 | 512 | 270.4391 | 270.4142 | 1.00 |
| qwen3 4B Q4_K - Medium | 2 | 512 | 522.5462 | 523.2189 | 1.00 |
| qwen3 4B Q4_K - Medium | 4 | 512 | 888.7895 | 891.6788 | 1.00 |
| qwen3 4B Q4_K - Medium | 8 | 512 | 1331.554 | 1333.544 | 1.00 |
| qwen3 4B Q4_K - Medium | 16 | 512 | 2609.212 | 2613.457 | 1.00 |
| qwen3 4B Q4_K - Medium | 32 | 512 | 4131.247 | 4153.166 | 1.01 |
| qwen3 4B Q4_K - Medium | 64 | 512 | 6010.69 | 6040.168 | 1.00 |
| qwen3 4B Q4_K - Medium | 128 | 512 | 8336.18 | 8368.532 | 1.00 |
| qwen3 4B Q4_K - Medium | 256 | 512 | 12653.47 | 12680.27 | 1.00 |
| qwen3 4B Q4_K - Medium | 512 | 512 | 16933.91 | 16990.33 | 1.00 |
| gpt-oss 20B MXFP4 MoE | 1 | 512 | 327.1843 | 327.2503 | 1.00 |
| gpt-oss 20B MXFP4 MoE | 2 | 512 | 487.6076 | 487.2249 | 1.00 |
| gpt-oss 20B MXFP4 MoE | 4 | 512 | 722.2551 | 722.1628 | 1.00 |
| gpt-oss 20B MXFP4 MoE | 8 | 512 | 909.277 | 911.6954 | 1.00 |
| gpt-oss 20B MXFP4 MoE | 16 | 512 | 1475.936 | 1474.678 | 1.00 |
| gpt-oss 20B MXFP4 MoE | 32 | 512 | 2448.124 | 2449.26 | 1.00 |
| gpt-oss 20B MXFP4 MoE | 64 | 512 | 4019.604 | 4021.089 | 1.00 |
| gpt-oss 20B MXFP4 MoE | 128 | 512 | 5825.155 | 5820.645 | 1.00 |
| gpt-oss 20B MXFP4 MoE | 256 | 512 | 8901.978 | 8885.761 | 1.00 |
| gpt-oss 20B MXFP4 MoE | 512 | 512 | 11634.01 | 11628.95 | 1.00 |
| llama 8B Q4_K - Medium | 1 | 512 | 218.6202 | 218.5992 | 1.00 |
| llama 8B Q4_K - Medium | 2 | 512 | 426.0842 | 425.9328 | 1.00 |
| llama 8B Q4_K - Medium | 4 | 512 | 753.1047 | 753.2821 | 1.00 |
| llama 8B Q4_K - Medium | 8 | 512 | 1043.164 | 1042.673 | 1.00 |
| llama 8B Q4_K - Medium | 16 | 512 | 2306.093 | 2301.84 | 1.00 |
| llama 8B Q4_K - Medium | 32 | 512 | 3720.924 | 3730.606 | 1.00 |
| llama 8B Q4_K - Medium | 64 | 512 | 5444.508 | 5457.328 | 1.00 |
| llama 8B Q4_K - Medium | 128 | 512 | 7452.762 | 7408.557 | 0.99 |
| llama 8B Q4_K - Medium | 256 | 512 | 10174.56 | 10179.98 | 1.00 |
| llama 8B Q4_K - Medium | 512 | 512 | 12917.97 | 12923.66 | 1.00 |
| llama 8B Q4_0 | 1 | 512 | 232.4301 | 232.74 | 1.00 |
| llama 8B Q4_0 | 2 | 512 | 461.9919 | 461.8752 | 1.00 |
| llama 8B Q4_0 | 4 | 512 | 889.2508 | 889.4003 | 1.00 |
| llama 8B Q4_0 | 8 | 512 | 1377.003 | 1377.244 | 1.00 |
| llama 8B Q4_0 | 16 | 512 | 2338.211 | 2335.362 | 1.00 |
| llama 8B Q4_0 | 32 | 512 | 3822.713 | 3822.771 | 1.00 |
| llama 8B Q4_0 | 64 | 512 | 5891.381 | 5883.02 | 1.00 |
| llama 8B Q4_0 | 128 | 512 | 7699.878 | 7715.334 | 1.00 |
| llama 8B Q4_0 | 256 | 512 | 10874.01 | 10842.19 | 1.00 |
| llama 8B Q4_0 | 512 | 512 | 14034.76 | 14027.07 | 1.00 |
Tensor Parallelism Performance on 2x and 4x RTX Pro 6000 Blackwell with PR #19378

| config | model | test | ae0334f | PR | speed-up |
| --- | --- | --- | --- | --- | --- |
| 2x RTX 6000 Pro BW | Qwen3-235B-A22B-Q4_0 | pp512 | 2165.37 | 2167.83 | 1.00 |
| 2x RTX 6000 Pro BW | Qwen3-235B-A22B-Q4_0 | tg128 | 71.51 | 75.37 | 1.05 |
| 2x RTX 6000 Pro BW | Qwen3-30B-A3B-Q4_0 | pp512 | 8357.29 | 8359.93 | 1.00 |
| 2x RTX 6000 Pro BW | Qwen3-30B-A3B-Q4_0 | tg128 | 182.1 | 194.26 | 1.07 |
| 4x RTX 6000 Pro BW | Qwen3-235B-A22B-Q4_0 | pp512 | 2367.91 | 2342.61 | 0.99 |
| 4x RTX 6000 Pro BW | Qwen3-235B-A22B-Q4_0 | tg128 | 66.05 | 71.5 | 1.08 |
| 4x RTX 6000 Pro BW | Qwen3-30B-A3B-Q4_0 | pp512 | 8408.73 | 8415.25 | 1.00 |
| 4x RTX 6000 Pro BW | Qwen3-30B-A3B-Q4_0 | tg128 | 155.57 | 162.79 | 1.05 |

@gaugarg-nv gaugarg-nv requested a review from a team as a code owner March 16, 2026 11:26
@github-actions github-actions bot added Nvidia GPU Issues specific to Nvidia GPUs ggml changes relating to the ggml tensor library for machine learning labels Mar 16, 2026
@am17an
Contributor

am17an commented Mar 16, 2026

This also helps in Qwen3.5 which has a down shape of 512

@gaugarg-nv
Contributor Author

> This also helps in Qwen3.5 which has a down shape of 512

Adding Qwen3.5-35B-A3B data on RTX Pro 6000 BW:

| model_type | n_ubatch | n_prompt | master avg_ts | PR avg_ts | speed-up |
| --- | --- | --- | --- | --- | --- |
| qwen35moe 35B.A3B Q4_K - Medium | 1 | 512 | 201.1068 | 207.3763 | 1.03 |
| qwen35moe 35B.A3B Q4_K - Medium | 2 | 512 | 287.9365 | 300.4817 | 1.04 |
| qwen35moe 35B.A3B Q4_K - Medium | 4 | 512 | 476.962 | 513.3622 | 1.08 |
| qwen35moe 35B.A3B Q4_K - Medium | 8 | 512 | 566.0659 | 564.7415 | 1.00 |
| qwen35moe 35B.A3B Q4_K - Medium | 16 | 512 | 795.9801 | 797.4574 | 1.00 |
| qwen35moe 35B.A3B Q4_K - Medium | 32 | 512 | 1291.454 | 1291.694 | 1.00 |
| qwen35moe 35B.A3B Q4_K - Medium | 64 | 512 | 1979.488 | 1979.35 | 1.00 |
| qwen35moe 35B.A3B Q4_K - Medium | 128 | 512 | 2660.565 | 2659.532 | 1.00 |
| qwen35moe 35B.A3B Q4_K - Medium | 256 | 512 | 4538.735 | 4538.854 | 1.00 |
| qwen35moe 35B.A3B Q4_K - Medium | 512 | 512 | 6729.192 | 6747.521 | 1.00 |

```cpp
    { \
        constexpr int c_ncols_dst = C_NCOLS_DST; \
        const int nwarps = calc_nwarps(type, c_ncols_dst, table_id); \
        const bool use_small_k = nwarps > 1 && blocks_per_row_x < nwarps * blocks_per_iter_1warp; \
```
Contributor
Can this be done inside the CUDA kernel using multiplications by re-ordering this expression? It would simplify the code quite a bit.

Contributor Author
To be clear, are you suggesting to remove small_k template parameter from the kernel? I think that should be doable.

But we will still need this code on the host as we modify rows_per_block for small_k, which in turn modifies the grid dimensions.
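The dependency between `rows_per_block` and the grid dimensions can be sketched as follows. This is a hypothetical host-side sketch (the function name and the baseline of one row per block are assumptions for illustration, not the actual ggml-cuda code):

```cpp
#include <cassert>

// Hypothetical sketch: with the small-K path enabled, rows_per_block grows to
// match nwarps, so the grid shrinks by the same factor and each thread block
// produces more output elements. Assumes a baseline of one row per block.
int blocks_in_grid(int nrows, int nwarps, bool use_small_k) {
    const int rows_per_block = use_small_k ? nwarps : 1;
    return (nrows + rows_per_block - 1) / rows_per_block; // ceil-div over rows
}
```

This is why some host-side branching survives even if the kernel itself no longer specializes on `small_k`: the launch configuration still differs between the two paths.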

Contributor
Yes, exactly. On the host side we can create a function to return the correct dims instead of doing the if/else and the macro.

Contributor Author
I will try to simplify the host code.

But I think removing small_k as a template parameter from the kernel would mean rows_per_cuda_block can no longer be constexpr, and some of the local register and shared-memory allocations depend on this value being constexpr.

Contributor
I see. Then it makes sense to leave the template parameter as it is. My only worry was the compile times/binary sizes. Do you notice any difference in them? If you're using a Ninja build it should be pretty easy to see via .ninja_log.

Contributor Author
@am17an Sorry for the late follow-up on this. I was busy with some other work.

Regarding build times, I am seeing an increase of 12 seconds and an increase of 2MB in libggml-cuda.so.
IMO, given that most SOTA models are MOE, it is worth taking this hit.

Regarding host-side code simplification, I think we cannot avoid the if/else, as small_k is a template parameter.
Let me know if there are any specific ideas you have regarding code simplification.

Contributor
Are you compiling only for Blackwell? I think this change will slow down the CI. We can limit this change to ncols_dst = 1, since that is 99% of the use case.

```
228.646s  ggml/src/ggml-cuda/CMakeFiles/ggml-cuda.dir/mmvq.cu.o
169.613s  ggml/src/ggml-cuda/CMakeFiles/ggml-cuda.dir/template-instances/mmq-instance-q2_k.cu.o
163.399s  ggml/src/ggml-cuda/CMakeFiles/ggml-cuda.dir/template-instances/fattn-tile-instance-dkq256-dv256.cu.o
153.896s  ggml/src/ggml-cuda/CMakeFiles/ggml-cuda.dir/template-instances/mmq-instance-q6_k.cu.o
147.442s  ggml/src/ggml-cuda/CMakeFiles/ggml-cuda.dir/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_1.cu.o
144.375s  ggml/src/ggml-cuda/CMakeFiles/ggml-cuda.dir/template-instances/mmq-instance-q3_k.cu.o
144.037s  ggml/src/ggml-cuda/CMakeFiles/ggml-cuda.dir/template-instances/mmq-instance-iq2_s.cu.o
139.735s  ggml/src/ggml-cuda/CMakeFiles/ggml-cuda.dir/template-instances/mmq-instance-q5_0.cu.o
138.991s  ggml/src/ggml-cuda/CMakeFiles/ggml-cuda.dir/template-instances/mmq-instance-iq2_xs.cu.o
132.618s  ggml/src/ggml-cuda/CMakeFiles/ggml-cuda.dir/template-instances/fattn-mma-f16-instance-ncols1_64-ncols2_1.cu.o
```

Contributor Author
Ok, I have limited the change to ncols_dst = 1 for now. This would obviously mean no scaling for BS > 1. I will spend more time on this kernel later and see how we can add specialization for small-k without adding too much compilation time and library size.

Contributor

@am17an am17an Mar 18, 2026
The typical way is to separate ncols into separate template files. But I think it might be overkill for now

Comment on lines +482 to +484
```cpp
// When K is small, increase rows_per_block to match nwarps so each warp
// processes a different row. This amortizes y-vector reads and reduces block count.
// Trigger when the full thread block covers all K blocks in a single loop iteration.
```
Contributor
If you increase rows_per_block that will not result in different warps working on different src0 rows. Each thread will still work on every row/column assigned to a CUDA block, all that changes is that the inner loop is over more rows so the compiler should in principle be able to recognize that the same data is being loaded multiple times. To avoid wasted work I would suggest you reduce the number of warps per CUDA block instead.

Regardless of the above, it may very well be that increasing rows_per_block is beneficial on Blackwell (in general, I did not test this). Did you test the impact of your small_k config for larger matrices?

Generally speaking, MMVQ is of comparatively poor code quality for historical reasons. It's among the first kernels that I wrote so I was less experienced (llama.cpp/ggml was my first contact with CUDA) and I didn't yet find the time to circle back to it to look into how it could be improved. It may be worthwhile for you to look for optimization opportunities beyond applications for TP.

Contributor Author
> If you increase rows_per_block that will not result in different warps working on different src0 rows. Each thread will still work on every row/column assigned to a CUDA block, all that changes is that the inner loop is over more rows so the compiler should in principle be able to recognize that the same data is being loaded multiple times.

Thanks @JohannesGaessler for the pointers. You're right about the kernel. I will update the comments.

> To avoid wasted work I would suggest you reduce the number of warps per CUDA block instead.

I tried setting n_warps to 1 or 2 for small_k without modifying rows_per_block. Neither showed any perf improvement. In general, I think very small CTAs are not very efficient.

> Regardless of the above, it may very well be that increasing rows_per_block is beneficial on Blackwell (in general, I did not test this). Did you test the impact of your small_k config for larger matrices.

Yes, I tried setting small_k to true for all cases, which would increase rows_per_block in all cases. This gave a regression of 4-6% for a few models in the BS=1 case, so I rejected the idea.

> It may be worthwhile for you to look for optimization opportunities beyond applications for TP.

Yes, the main motivation of this work was Tensor parallelism. I will explore more ideas.
In general, for BS=1, the performance of kernels looks good except for some corner cases like small-k. I think one idea worth pursuing is using 1 warp per element for small-K, so that we avoid using shared memory and use light-weight warpReduce instead of block-level reduce. I can try experimenting with this idea.
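The warp-level reduction mentioned above can be sketched on the host. This is not the ggml warpReduce code; it just emulates, lane by lane, the butterfly pattern that a CUDA `__shfl_xor_sync`-based reduction performs, to show why no shared memory is needed:

```cpp
#include <cassert>

constexpr int warp_size = 32;

// Host-side emulation of a warp-wide sum via XOR-shuffle butterfly steps:
// at each step, lane i adds the value held by lane i ^ mask, halving the
// mask each time. After log2(warp_size) steps every lane holds the full
// sum, exchanging data only through registers (no shared memory, no
// block-level synchronization).
void warp_reduce_sum(float v[warp_size]) {
    for (int mask = warp_size / 2; mask > 0; mask >>= 1) {
        float tmp[warp_size];
        for (int i = 0; i < warp_size; ++i) tmp[i] = v[i ^ mask]; // the "shuffle"
        for (int i = 0; i < warp_size; ++i) v[i] += tmp[i];
    }
}
```

In an actual kernel the inner loops disappear: each lane holds one element and the exchange is a single shuffle instruction per step, which is what makes the one-warp-per-element idea attractive for small K.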

Would you like to proceed with this PR first, or should we do more exploration around these ideas first?

Contributor
I would say that if this PR empirically improves performance for some cases we should keep it, just change the rationale since it is misleading.

Contributor Author
@JohannesGaessler updated the comments and PR description.

With tensor parallelism, the K-dimension of the FFN-down matrices is split, which makes it quite small, especially for MoEs. For example, Qwen3-30B-A3B has a K-dimension of 768, and Qwen3-235B-A22B has a K-dimension of 1536.
The current heuristic uses a group of 4 warps irrespective of the K-dimension size, leaving some of the threads idle and resulting in poor performance for these matrices.

This change increases the number of output elements per block for such cases.
@gaugarg-nv gaugarg-nv force-pushed the small_k_optimization branch from 4f20a44 to cfbbfb2 Compare March 18, 2026 11:27
@gaugarg-nv gaugarg-nv changed the title [CUDA] Use a single warp per element instead of a single block per element if the K-dimension is small [CUDA] Increase number of output elements per-thread block if the K-dimension is small Mar 18, 2026