
Vulkan Scalar Flash Attention Refactor#19625

Merged
0cc4m merged 46 commits into master from 0cc4m/vulkan-fa-scalar-opt
Feb 24, 2026

Conversation

@0cc4m
Contributor

@0cc4m 0cc4m commented Feb 14, 2026

This started out as an attempt to go through the scalar FA version and add proper float16 support to improve AMD and Intel performance, and went quite a bit further. @jeffbolznv Sorry about the amount of changes, let me know if there's something I can do to make the review easier. Please also let me know if you have architectural concerns. Flash Attention has so many dimensions, and making it work well across so many hardware and model combinations is pretty hard. I had to spend quite a lot of time figuring out and fixing regressions on specific configurations.

AI-generated summary of changes

Scalar Flash Attention Core Optimizations

  • Implemented row splitting within workgroups (row_split = 1 or 4) for better subgroup utilization
  • Added shared memory staging for K and V loads on Nvidia GPUs when head sizes < 256
  • Cached Q values in registers for KQ computation when HSK_per_thread > 16
  • Fused loop for Lf accumulation and Of scaling by eMf
  • Changed to vectorized vec4 stores for output
  • Optimized masksh layout with stride padding (Br + 1) and removed unnecessary barrier
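The fused Lf/Of update mentioned above is the standard online-softmax recurrence used by flash attention. A minimal scalar sketch of that bookkeeping (the names `Lf`, `Of`, `Mf`, `eMf` mirror the shader's, but this is an illustration of the technique, not the shader code):

```cpp
#include <cassert>
#include <cmath>

// Online-softmax accumulation over one row of attention scores.
// Mf = running max, Lf = running sum of exponentials, Of = running
// exp-weighted sum of V. When a new block raises the max, the old
// accumulators must be rescaled by eMf = exp(oldMax - newMax); fusing
// that rescale of Of with the Lf update is what the bullet describes.
struct OnlineSoftmax {
    float Mf = -INFINITY, Lf = 0.0f, Of = 0.0f;

    void add(float score, float v) {
        float newM = std::max(Mf, score);
        float eMf  = std::exp(Mf - newM);   // rescale factor for old accumulators
        float p    = std::exp(score - newM);
        Lf = Lf * eMf + p;                  // fused loop: update Lf...
        Of = Of * eMf + p * v;              // ...and rescale/accumulate Of together
        Mf = newM;
    }

    float result() const { return Of / Lf; } // final normalization
};
```

The end result matches a two-pass softmax-weighted sum, but only one pass over the KV blocks is needed.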

Row Size Tiering

  • Replaced binary small_rows/large_rows with three-tier system: FA_ROWS_1, FA_ROWS_SMALL, FA_ROWS_LARGE
  • Dynamic Br selection based on head sizes, device vendor, and architecture
  • FA_ROWS_1 uses Br=1 for N=1, FA_ROWS_SMALL uses Br=8, FA_ROWS_LARGE uses Br=16
  • Device-specific adjustments: AMD GCN uses smaller Br, Intel uses Br=8 maximum
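The tiering described above can be sketched as a small selection function. This is a hypothetical simplification; the real logic in ggml-vulkan also weighs head sizes and device architecture details:

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical sketch of the three-tier row-size selection.
enum class FaRows { ROWS_1, ROWS_SMALL, ROWS_LARGE };

// N = number of query rows in the batch (1 during token generation).
FaRows select_tier(uint32_t N, bool intel, bool amd_gcn) {
    if (N == 1) return FaRows::ROWS_1;               // Br=1 path for single-token decode
    if (intel || amd_gcn) return FaRows::ROWS_SMALL; // smaller Br cap on these devices
    return FaRows::ROWS_LARGE;
}

uint32_t br_for_tier(FaRows t) {
    switch (t) {
        case FaRows::ROWS_1:     return 1;
        case FaRows::ROWS_SMALL: return 8;
        default:                 return 16;
    }
}
```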

Vendor-Specific Optimizations

  • AMD RDNA: Use wave32 subgroup size for scalar FA when N=1
  • Intel: Added shader core count lookup table for Alchemist and Battlemage GPUs
  • Intel: Disable subgroup operations in favor of shared memory reductions
  • Intel Alchemist: Apply 2x shader core count multiplier for split_k calculation
  • Adjusted workgroup sizes per vendor and head size combinations

split_k Enhancements

  • Relaxed split_k conditions to support non-GQA workloads
  • Fixed dispatch logic to handle both GQA and non-GQA cases correctly
  • Improved split_k calculation based on total workgroup count and shader cores
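The idea behind a workgroup-count-based split_k can be sketched as follows. This is an assumed, simplified heuristic for illustration, not the PR's exact formula: when the dispatch would leave shader cores idle, split the KV dimension so that the total workgroup count roughly fills the GPU.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Illustrative split_k heuristic (assumption, not the actual implementation).
// workgroups  = workgroups the FA dispatch would otherwise launch
// shader_cores = number of shader cores reported for the device
// kv_blocks   = number of KV blocks available to split across
uint32_t pick_split_k(uint32_t workgroups, uint32_t shader_cores, uint32_t kv_blocks) {
    if (workgroups >= shader_cores) return 1;           // already enough parallelism
    uint32_t split = shader_cores / std::max(workgroups, 1u);
    return std::min(split, kv_blocks);                  // can't split finer than KV blocks
}
```

Each of the split_k partial results then needs a reduction pass to combine the per-split softmax accumulators.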

Device Compatibility

  • Added FP32 shader variants (_fp32 suffix) for devices without FP16 support
  • Made FLOAT_TYPE conditional on device capabilities
  • Updated dequantize4 functions to use FLOAT_TYPE instead of hardcoded float

Shared Memory Management

  • Dynamic tmpsh sizing based on row_split and subgroup configuration
  • Added kvsh buffer for K/V staging (size conditional on SHMEM_STAGING flag)
  • Improved Qf buffer stride calculation
  • Fixed tmpsh size calculation for split_k temporaries

Code Path Selection

  • Switch from coopmat1 to scalar when N=1 or rows=FA_ROWS_1
  • Improved shared memory size checks for scalar path fallback
  • Better alignment checking and stride validation

Shader Compilation

  • Made coopmat1/coopmat2 pipeline creation conditional on device FP16 support
  • Added subgroup size configuration per code path and row configuration
  • Removed hardcoded subgroup size assumptions

Benchmarks

AMD Radeon Pro VII
model size params ngl fa test t/s (ROCm) t/s (before) t/s (after) diff
llama 8B Q4_0 4.33 GiB 8.03 B 99 1 pp512 1003.15 ± 0.89 800.28 ± 1.41 827.57 ± 0.74 +3.4%
llama 8B Q4_0 4.33 GiB 8.03 B 99 1 tg128 85.12 ± 1.39 98.55 ± 0.55 97.83 ± 0.47 -0.7%
llama 8B Q4_0 4.33 GiB 8.03 B 99 1 pp512 @ d8192 689.31 ± 0.64 174.36 ± 0.42 388.72 ± 3.37 +122.9%
llama 8B Q4_0 4.33 GiB 8.03 B 99 1 tg128 @ d8192 69.91 ± 0.20 55.97 ± 0.20 72.24 ± 0.34 +29.1%
llama 8B Q4_0 4.33 GiB 8.03 B 99 1 pp512 @ d16384 525.25 ± 1.68 84.33 ± 0.11 247.07 ± 1.51 +193.0%
llama 8B Q4_0 4.33 GiB 8.03 B 99 1 tg128 @ d16384 60.48 ± 0.17 41.46 ± 0.12 57.70 ± 0.57 +39.2%
gpt-oss 20B MXFP4 MoE 11.27 GiB 20.91 B 99 1 pp512 1061.99 ± 7.85 1319.64 ± 7.82 1321.90 ± 6.90 +0.2%
gpt-oss 20B MXFP4 MoE 11.27 GiB 20.91 B 99 1 tg128 110.86 ± 0.97 136.10 ± 0.27 127.75 ± 0.88 -6.1%
gpt-oss 20B MXFP4 MoE 11.27 GiB 20.91 B 99 1 pp512 @ d8192 745.39 ± 1.25 757.62 ± 3.94 740.88 ± 4.66 -2.2%
gpt-oss 20B MXFP4 MoE 11.27 GiB 20.91 B 99 1 tg128 @ d8192 101.64 ± 0.41 116.38 ± 0.17 113.37 ± 0.93 -2.6%
gpt-oss 20B MXFP4 MoE 11.27 GiB 20.91 B 99 1 pp512 @ d16384 577.95 ± 3.32 509.10 ± 3.64 484.85 ± 2.85 -4.8%
gpt-oss 20B MXFP4 MoE 11.27 GiB 20.91 B 99 1 tg128 @ d16384 99.23 ± 0.21 107.31 ± 0.68 102.88 ± 1.13 -4.1%
qwen3moe 30B.A3B Q2_K - Medium 10.48 GiB 30.53 B 99 1 pp512 351.98 ± 3.24 749.40 ± 5.15 759.11 ± 4.74 +1.3%
qwen3moe 30B.A3B Q2_K - Medium 10.48 GiB 30.53 B 99 1 tg128 68.83 ± 0.11 95.12 ± 0.22 93.94 ± 0.45 -1.2%
qwen3moe 30B.A3B Q2_K - Medium 10.48 GiB 30.53 B 99 1 pp512 @ d8192 295.91 ± 3.09 207.63 ± 0.63 312.17 ± 5.34 +50.3%
qwen3moe 30B.A3B Q2_K - Medium 10.48 GiB 30.53 B 99 1 tg128 @ d8192 60.01 ± 0.77 55.87 ± 0.35 73.73 ± 0.68 +32.0%
qwen3moe 30B.A3B Q2_K - Medium 10.48 GiB 30.53 B 99 1 pp512 @ d16384 247.76 ± 0.77 114.90 ± 0.42 191.18 ± 1.32 +66.4%
qwen3moe 30B.A3B Q2_K - Medium 10.48 GiB 30.53 B 99 1 tg128 @ d16384 55.69 ± 0.30 44.11 ± 0.11 61.76 ± 0.63 +40.0%
deepseek2 30B.A3B Q3_K - Small 12.37 GiB 29.94 B 99 1 pp512 641.90 ± 2.66 657.73 ± 3.46 740.63 ± 1.78 +12.6%
deepseek2 30B.A3B Q3_K - Small 12.37 GiB 29.94 B 99 1 tg128 47.72 ± 0.13 64.38 ± 0.19 65.54 ± 0.32 +1.8%
deepseek2 30B.A3B Q3_K - Small 12.37 GiB 29.94 B 99 1 pp512 @ d8192 293.28 ± 0.54 83.15 ± 0.33 129.38 ± 0.69 +55.6%
deepseek2 30B.A3B Q3_K - Small 12.37 GiB 29.94 B 99 1 tg128 @ d8192 38.76 ± 0.07 35.93 ± 0.20 37.94 ± 0.33 +5.6%
deepseek2 30B.A3B Q3_K - Small 12.37 GiB 29.94 B 99 1 pp512 @ d16384 189.33 ± 0.18 41.62 ± 0.24 70.77 ± 0.49 +70.0%
deepseek2 30B.A3B Q3_K - Small 12.37 GiB 29.94 B 99 1 tg128 @ d16384 31.80 ± 0.08 24.39 ± 0.36 26.41 ± 0.22 +8.3%
AMD 8060S
model size params ngl fa test t/s (before) t/s (after) diff
llama 8B Q4_0 4.33 GiB 8.03 B 99 1 pp512 994.34 ± 34.50 947.41 ± 7.78 -4.7%
llama 8B Q4_0 4.33 GiB 8.03 B 99 1 tg128 45.14 ± 0.44 44.86 ± 0.42 -0.6%
llama 8B Q4_0 4.33 GiB 8.03 B 99 1 pp512 @ d8192 418.71 ± 11.10 397.77 ± 8.90 -5.0%
llama 8B Q4_0 4.33 GiB 8.03 B 99 1 tg128 @ d8192 35.83 ± 0.09 35.68 ± 0.08 -0.4%
llama 8B Q4_0 4.33 GiB 8.03 B 99 1 pp512 @ d16384 234.05 ± 5.66 246.05 ± 11.58 +5.1%
llama 8B Q4_0 4.33 GiB 8.03 B 99 1 tg128 @ d16384 30.53 ± 0.08 30.13 ± 0.11 -1.3%
gpt-oss 20B MXFP4 MoE 11.27 GiB 20.91 B 99 1 pp512 1263.73 ± 34.96 1208.77 ± 37.78 -4.3%
gpt-oss 20B MXFP4 MoE 11.27 GiB 20.91 B 99 1 tg128 73.19 ± 0.13 72.68 ± 0.10 -0.7%
gpt-oss 20B MXFP4 MoE 11.27 GiB 20.91 B 99 1 pp512 @ d8192 920.01 ± 4.93 919.00 ± 4.71 -0.1%
gpt-oss 20B MXFP4 MoE 11.27 GiB 20.91 B 99 1 tg128 @ d8192 66.74 ± 0.45 66.42 ± 0.13 -0.5%
gpt-oss 20B MXFP4 MoE 11.27 GiB 20.91 B 99 1 pp512 @ d16384 670.22 ± 4.61 670.46 ± 5.07 +0.0%
gpt-oss 20B MXFP4 MoE 11.27 GiB 20.91 B 99 1 tg128 @ d16384 61.53 ± 0.78 61.78 ± 1.08 +0.4%
qwen3moe 30B.A3B Q2_K - Medium 10.48 GiB 30.53 B 99 1 pp512 945.03 ± 32.97 992.30 ± 11.33 +5.0%
qwen3moe 30B.A3B Q2_K - Medium 10.48 GiB 30.53 B 99 1 tg128 91.76 ± 0.06 91.60 ± 0.53 -0.2%
qwen3moe 30B.A3B Q2_K - Medium 10.48 GiB 30.53 B 99 1 pp512 @ d8192 487.96 ± 2.76 479.56 ± 4.25 -1.7%
qwen3moe 30B.A3B Q2_K - Medium 10.48 GiB 30.53 B 99 1 tg128 @ d8192 66.47 ± 0.33 66.13 ± 0.27 -0.5%
qwen3moe 30B.A3B Q2_K - Medium 10.48 GiB 30.53 B 99 1 pp512 @ d16384 302.07 ± 1.01 286.72 ± 1.03 -5.1%
qwen3moe 30B.A3B Q2_K - Medium 10.48 GiB 30.53 B 99 1 tg128 @ d16384 50.54 ± 0.19 49.64 ± 0.88 -1.8%
deepseek2 30B.A3B Q4_0 16.03 GiB 29.94 B 99 1 pp512 924.97 ± 10.45 923.58 ± 4.06 -0.2%
deepseek2 30B.A3B Q4_0 16.03 GiB 29.94 B 99 1 tg128 61.52 ± 0.34 61.43 ± 0.41 -0.1%
deepseek2 30B.A3B Q4_0 16.03 GiB 29.94 B 99 1 pp512 @ d8192 306.02 ± 0.84 297.15 ± 0.91 -2.9%
deepseek2 30B.A3B Q4_0 16.03 GiB 29.94 B 99 1 tg128 @ d8192 38.31 ± 0.20 39.20 ± 0.17 +2.3%
deepseek2 30B.A3B Q4_0 16.03 GiB 29.94 B 99 1 pp512 @ d16384 192.72 ± 0.35 182.25 ± 0.82 -5.4%
deepseek2 30B.A3B Q4_0 16.03 GiB 29.94 B 99 1 tg128 @ d16384 27.83 ± 0.16 28.83 ± 0.01 +3.6%
AMD 8060S (Without Coopmat)
model size params ngl fa test t/s (before) t/s (after) diff
llama 8B Q4_0 4.33 GiB 8.03 B 99 1 pp512 815.03 ± 7.22 822.68 ± 4.39 +0.9%
llama 8B Q4_0 4.33 GiB 8.03 B 99 1 tg128 44.96 ± 0.22 45.36 ± 0.30 +0.9%
llama 8B Q4_0 4.33 GiB 8.03 B 99 1 pp512 @ d8192 67.06 ± 4.00 190.34 ± 2.98 +183.8%
llama 8B Q4_0 4.33 GiB 8.03 B 99 1 tg128 @ d8192 31.53 ± 0.13 35.31 ± 0.28 +12.0%
llama 8B Q4_0 4.33 GiB 8.03 B 99 1 pp512 @ d16384 28.05 ± 0.85 78.89 ± 4.18 +181.2%
llama 8B Q4_0 4.33 GiB 8.03 B 99 1 tg128 @ d16384 25.53 ± 0.17 29.71 ± 0.08 +16.4%
gpt-oss 20B MXFP4 MoE 11.27 GiB 20.91 B 99 1 pp512 1249.96 ± 37.10 1187.02 ± 15.67 -5.0%
gpt-oss 20B MXFP4 MoE 11.27 GiB 20.91 B 99 1 tg128 73.17 ± 0.06 72.39 ± 0.23 -1.1%
gpt-oss 20B MXFP4 MoE 11.27 GiB 20.91 B 99 1 pp512 @ d8192 681.99 ± 1.44 681.63 ± 2.60 -0.1%
gpt-oss 20B MXFP4 MoE 11.27 GiB 20.91 B 99 1 tg128 @ d8192 66.34 ± 0.35 66.37 ± 0.21 +0.0%
gpt-oss 20B MXFP4 MoE 11.27 GiB 20.91 B 99 1 pp512 @ d16384 438.09 ± 2.70 408.44 ± 7.02 -6.8%
gpt-oss 20B MXFP4 MoE 11.27 GiB 20.91 B 99 1 tg128 @ d16384 61.46 ± 0.62 61.54 ± 0.76 +0.1%
qwen3moe 30B.A3B Q2_K - Medium 10.48 GiB 30.53 B 99 1 pp512 617.33 ± 13.14 614.00 ± 6.22 -0.5%
qwen3moe 30B.A3B Q2_K - Medium 10.48 GiB 30.53 B 99 1 tg128 94.84 ± 0.20 92.14 ± 0.22 -2.8%
qwen3moe 30B.A3B Q2_K - Medium 10.48 GiB 30.53 B 99 1 pp512 @ d8192 179.49 ± 0.92 227.94 ± 1.12 +27.0%
qwen3moe 30B.A3B Q2_K - Medium 10.48 GiB 30.53 B 99 1 tg128 @ d8192 57.91 ± 0.39 67.14 ± 0.11 +15.9%
qwen3moe 30B.A3B Q2_K - Medium 10.48 GiB 30.53 B 99 1 pp512 @ d16384 86.39 ± 0.78 128.04 ± 0.64 +48.2%
qwen3moe 30B.A3B Q2_K - Medium 10.48 GiB 30.53 B 99 1 tg128 @ d16384 43.22 ± 0.18 51.58 ± 0.14 +19.3%
deepseek2 30B.A3B Q4_0 16.03 GiB 29.94 B 99 1 pp512 727.26 ± 4.81 810.87 ± 5.13 +11.5%
deepseek2 30B.A3B Q4_0 16.03 GiB 29.94 B 99 1 tg128 61.59 ± 0.70 61.90 ± 0.12 +0.5%
deepseek2 30B.A3B Q4_0 16.03 GiB 29.94 B 99 1 pp512 @ d8192 105.57 ± 0.50 178.01 ± 0.22 +68.6%
deepseek2 30B.A3B Q4_0 16.03 GiB 29.94 B 99 1 tg128 @ d8192 38.58 ± 0.19 39.50 ± 0.33 +2.4%
deepseek2 30B.A3B Q4_0 16.03 GiB 29.94 B 99 1 pp512 @ d16384 52.56 ± 0.29 94.60 ± 0.41 +80.0%
deepseek2 30B.A3B Q4_0 16.03 GiB 29.94 B 99 1 tg128 @ d16384 28.02 ± 0.18 28.98 ± 0.06 +3.4%
Intel A770
model size params ngl fa test t/s (before) t/s (after) diff
llama 8B Q4_0 4.33 GiB 8.03 B 99 1 pp512 818.22 ± 0.63 812.84 ± 1.85 -0.7%
llama 8B Q4_0 4.33 GiB 8.03 B 99 1 tg128 32.64 ± 0.07 32.45 ± 0.05 -0.6%
llama 8B Q4_0 4.33 GiB 8.03 B 99 1 pp512 @ d2048 97.15 ± 0.05 550.81 ± 1.20 +467.0%
llama 8B Q4_0 4.33 GiB 8.03 B 99 1 tg128 @ d2048 21.67 ± 0.02 27.75 ± 0.02 +28.1%
llama 8B Q4_0 4.33 GiB 8.03 B 99 1 pp512 @ d4096 43.79 ± 2.97 405.21 ± 0.78 +825.3%
llama 8B Q4_0 4.33 GiB 8.03 B 99 1 tg128 @ d4096 17.28 ± 0.00 25.06 ± 0.01 +45.0%
gpt-oss 20B MXFP4 MoE 11.27 GiB 20.91 B 99 1 pp512 930.73 ± 3.24 898.65 ± 3.47 -3.4%
gpt-oss 20B MXFP4 MoE 11.27 GiB 20.91 B 99 1 tg128 41.29 ± 0.07 37.53 ± 0.11 -9.1%
gpt-oss 20B MXFP4 MoE 11.27 GiB 20.91 B 99 1 pp512 @ d2048 701.16 ± 3.52 670.17 ± 4.91 -4.4%
gpt-oss 20B MXFP4 MoE 11.27 GiB 20.91 B 99 1 tg128 @ d2048 31.19 ± 0.06 31.73 ± 0.03 +1.7%
gpt-oss 20B MXFP4 MoE 11.27 GiB 20.91 B 99 1 pp512 @ d4096 545.63 ± 1.16 495.18 ± 0.71 -9.2%
gpt-oss 20B MXFP4 MoE 11.27 GiB 20.91 B 99 1 tg128 @ d4096 28.83 ± 0.09 29.27 ± 0.04 +1.5%
qwen3moe 30B.A3B Q2_K - Medium 10.48 GiB 30.53 B 99 1 pp512 640.10 ± 3.55 657.27 ± 3.54 +2.7%
qwen3moe 30B.A3B Q2_K - Medium 10.48 GiB 30.53 B 99 1 tg128 33.43 ± 0.08 30.04 ± 0.03 -10.1%
qwen3moe 30B.A3B Q2_K - Medium 10.48 GiB 30.53 B 99 1 pp512 @ d2048 60.27 ± 4.78 281.25 ± 1.21 +366.7%
qwen3moe 30B.A3B Q2_K - Medium 10.48 GiB 30.53 B 99 1 tg128 @ d2048 20.16 ± 0.02 22.98 ± 0.03 +14.0%
qwen3moe 30B.A3B Q2_K - Medium 10.48 GiB 30.53 B 99 1 pp512 @ d4096 26.38 ± 0.63 310.19 ± 1.68 +1075.9%
qwen3moe 30B.A3B Q2_K - Medium 10.48 GiB 30.53 B 99 1 tg128 @ d4096 18.27 ± 0.03 23.61 ± 0.08 +29.2%
deepseek2 30B.A3B Q3_K - Small 12.37 GiB 29.94 B 99 1 pp512 167.35 ± 0.17 66.63 ± 0.23 -60.2%
deepseek2 30B.A3B Q3_K - Small 12.37 GiB 29.94 B 99 1 tg128 19.23 ± 0.01 20.38 ± 0.03 +6.0%
deepseek2 30B.A3B Q3_K - Small 12.37 GiB 29.94 B 99 1 pp512 @ d2048 26.23 ± 1.02 25.38 ± 0.01 -3.2%
deepseek2 30B.A3B Q3_K - Small 12.37 GiB 29.94 B 99 1 tg128 @ d2048 5.95 ± 0.00 13.59 ± 0.01 +128.4%
deepseek2 30B.A3B Q3_K - Small 12.37 GiB 29.94 B 99 1 pp512 @ d4096 25.54 ± 0.02 25.29 ± 0.04 -1.0%
deepseek2 30B.A3B Q3_K - Small 12.37 GiB 29.94 B 99 1 tg128 @ d4096 3.64 ± 0.00 10.37 ± 0.00 +184.9%
Nvidia RTX 3090 (Coopmat2)
model size params ngl fa test t/s (before) t/s (after) diff
llama 8B Q4_0 4.33 GiB 8.03 B 99 1 pp512 4666.60 ± 19.46 4721.23 ± 12.32 +1.2%
llama 8B Q4_0 4.33 GiB 8.03 B 99 1 tg128 144.71 ± 1.53 147.49 ± 0.52 +1.9%
llama 8B Q4_0 4.33 GiB 8.03 B 99 1 pp512 @ d8192 3426.64 ± 19.29 3428.98 ± 22.04 +0.1%
llama 8B Q4_0 4.33 GiB 8.03 B 99 1 tg128 @ d8192 114.85 ± 0.97 115.92 ± 0.34 +0.9%
llama 8B Q4_0 4.33 GiB 8.03 B 99 1 pp512 @ d16384 2695.37 ± 16.65 2692.89 ± 16.34 -0.1%
llama 8B Q4_0 4.33 GiB 8.03 B 99 1 tg128 @ d16384 99.65 ± 0.73 99.82 ± 0.29 +0.2%
gpt-oss 20B MXFP4 MoE 11.27 GiB 20.91 B 99 1 pp512 4520.31 ± 33.68 4513.71 ± 30.22 -0.1%
gpt-oss 20B MXFP4 MoE 11.27 GiB 20.91 B 99 1 tg128 177.65 ± 0.75 177.15 ± 0.77 -0.3%
gpt-oss 20B MXFP4 MoE 11.27 GiB 20.91 B 99 1 pp512 @ d8192 4040.47 ± 78.90 4049.94 ± 174.56 +0.2%
gpt-oss 20B MXFP4 MoE 11.27 GiB 20.91 B 99 1 tg128 @ d8192 156.59 ± 1.58 155.91 ± 0.78 -0.4%
gpt-oss 20B MXFP4 MoE 11.27 GiB 20.91 B 99 1 pp512 @ d16384 3546.97 ± 21.35 3529.89 ± 36.63 -0.5%
gpt-oss 20B MXFP4 MoE 11.27 GiB 20.91 B 99 1 tg128 @ d16384 147.96 ± 0.76 145.37 ± 0.48 -1.8%
qwen3moe 30B.A3B Q2_K - Medium 10.48 GiB 30.53 B 99 1 pp512 3469.59 ± 17.36 3465.49 ± 34.45 -0.1%
qwen3moe 30B.A3B Q2_K - Medium 10.48 GiB 30.53 B 99 1 tg128 178.72 ± 0.64 177.48 ± 2.05 -0.7%
qwen3moe 30B.A3B Q2_K - Medium 10.48 GiB 30.53 B 99 1 pp512 @ d8192 2508.75 ± 42.02 2500.37 ± 34.47 -0.3%
qwen3moe 30B.A3B Q2_K - Medium 10.48 GiB 30.53 B 99 1 tg128 @ d8192 141.66 ± 0.54 141.16 ± 0.65 -0.4%
qwen3moe 30B.A3B Q2_K - Medium 10.48 GiB 30.53 B 99 1 pp512 @ d16384 1942.67 ± 15.90 1936.24 ± 20.12 -0.3%
qwen3moe 30B.A3B Q2_K - Medium 10.48 GiB 30.53 B 99 1 tg128 @ d16384 123.39 ± 0.72 123.21 ± 0.29 -0.1%
deepseek2 30B.A3B Q3_K - Small 12.37 GiB 29.94 B 99 1 pp512 2287.89 ± 11.77 2289.12 ± 9.34 +0.1%
deepseek2 30B.A3B Q3_K - Small 12.37 GiB 29.94 B 99 1 tg128 116.47 ± 0.80 114.38 ± 3.56 -1.8%
deepseek2 30B.A3B Q3_K - Small 12.37 GiB 29.94 B 99 1 pp512 @ d8192 1047.29 ± 9.19 1047.12 ± 9.51 -0.0%
deepseek2 30B.A3B Q3_K - Small 12.37 GiB 29.94 B 99 1 tg128 @ d8192 90.74 ± 0.34 90.44 ± 0.37 -0.3%
deepseek2 30B.A3B Q3_K - Small 12.37 GiB 29.94 B 99 1 pp512 @ d16384 647.46 ± 3.70 644.65 ± 3.78 -0.4%
deepseek2 30B.A3B Q3_K - Small 12.37 GiB 29.94 B 99 1 tg128 @ d16384 81.92 ± 0.81 82.07 ± 0.20 +0.2%
Nvidia RTX 3090 (Coopmat1)
model size params ngl fa test t/s (before) t/s (after) diff
llama 8B Q4_0 4.33 GiB 8.03 B 99 1 pp512 4117.11 ± 10.81 4052.19 ± 17.94 -1.6%
llama 8B Q4_0 4.33 GiB 8.03 B 99 1 tg128 145.98 ± 1.84 144.04 ± 0.74 -1.3%
llama 8B Q4_0 4.33 GiB 8.03 B 99 1 pp512 @ d8192 2182.12 ± 11.97 2359.95 ± 10.14 +8.1%
llama 8B Q4_0 4.33 GiB 8.03 B 99 1 tg128 @ d8192 115.72 ± 0.56 116.46 ± 0.62 +0.6%
llama 8B Q4_0 4.33 GiB 8.03 B 99 1 pp512 @ d16384 1486.54 ± 4.89 1671.90 ± 9.35 +12.5%
llama 8B Q4_0 4.33 GiB 8.03 B 99 1 tg128 @ d16384 99.15 ± 0.74 101.36 ± 0.32 +2.2%
gpt-oss 20B MXFP4 MoE 11.27 GiB 20.91 B 99 1 pp512 3062.95 ± 94.07 3090.31 ± 33.32 +0.9%
gpt-oss 20B MXFP4 MoE 11.27 GiB 20.91 B 99 1 tg128 175.29 ± 0.83 175.87 ± 0.88 +0.3%
gpt-oss 20B MXFP4 MoE 11.27 GiB 20.91 B 99 1 pp512 @ d8192 2439.28 ± 32.02 2494.98 ± 47.57 +2.3%
gpt-oss 20B MXFP4 MoE 11.27 GiB 20.91 B 99 1 tg128 @ d8192 148.99 ± 14.70 154.40 ± 2.18 +3.6%
gpt-oss 20B MXFP4 MoE 11.27 GiB 20.91 B 99 1 pp512 @ d16384 1964.74 ± 21.60 2098.26 ± 19.00 +6.8%
gpt-oss 20B MXFP4 MoE 11.27 GiB 20.91 B 99 1 tg128 @ d16384 147.55 ± 0.70 147.66 ± 0.69 +0.1%
qwen3moe 30B.A3B Q2_K - Medium 10.48 GiB 30.53 B 99 1 pp512 2839.27 ± 26.12 2837.32 ± 30.26 -0.1%
qwen3moe 30B.A3B Q2_K - Medium 10.48 GiB 30.53 B 99 1 tg128 174.78 ± 1.25 176.05 ± 1.26 +0.7%
qwen3moe 30B.A3B Q2_K - Medium 10.48 GiB 30.53 B 99 1 pp512 @ d8192 1505.57 ± 14.41 1639.74 ± 14.94 +8.9%
qwen3moe 30B.A3B Q2_K - Medium 10.48 GiB 30.53 B 99 1 tg128 @ d8192 137.34 ± 0.86 139.22 ± 2.10 +1.4%
qwen3moe 30B.A3B Q2_K - Medium 10.48 GiB 30.53 B 99 1 pp512 @ d16384 1010.90 ± 10.49 1146.23 ± 14.19 +13.4%
qwen3moe 30B.A3B Q2_K - Medium 10.48 GiB 30.53 B 99 1 tg128 @ d16384 119.58 ± 0.71 121.95 ± 0.88 +2.0%
deepseek2 30B.A3B Q3_K - Small 12.37 GiB 29.94 B 99 1 pp512 1968.30 ± 10.15 1954.94 ± 33.29 -0.7%
deepseek2 30B.A3B Q3_K - Small 12.37 GiB 29.94 B 99 1 tg128 114.35 ± 0.87 115.05 ± 0.80 +0.6%
deepseek2 30B.A3B Q3_K - Small 12.37 GiB 29.94 B 99 1 pp512 @ d8192 554.73 ± 1.56 555.49 ± 1.82 +0.1%
deepseek2 30B.A3B Q3_K - Small 12.37 GiB 29.94 B 99 1 tg128 @ d8192 62.50 ± 0.51 63.21 ± 0.34 +1.1%
deepseek2 30B.A3B Q3_K - Small 12.37 GiB 29.94 B 99 1 pp512 @ d16384 314.59 ± 0.93 315.91 ± 1.26 +0.4%
deepseek2 30B.A3B Q3_K - Small 12.37 GiB 29.94 B 99 1 tg128 @ d16384 43.01 ± 0.10 43.98 ± 0.15 +2.3%
Nvidia RTX 3090 (Without Coopmat)
model size params ngl fa test t/s (before) t/s (after) diff
llama 8B Q4_0 4.33 GiB 8.03 B 99 1 pp512 2129.81 ± 5.52 2081.00 ± 42.53 -2.3%
llama 8B Q4_0 4.33 GiB 8.03 B 99 1 tg128 145.98 ± 0.24 144.26 ± 0.53 -1.2%
llama 8B Q4_0 4.33 GiB 8.03 B 99 1 pp512 @ d8192 997.77 ± 3.31 1048.43 ± 25.28 +5.1%
llama 8B Q4_0 4.33 GiB 8.03 B 99 1 tg128 @ d8192 110.19 ± 0.54 112.16 ± 0.12 +1.8%
llama 8B Q4_0 4.33 GiB 8.03 B 99 1 pp512 @ d16384 637.54 ± 1.09 701.26 ± 11.14 +10.0%
llama 8B Q4_0 4.33 GiB 8.03 B 99 1 tg128 @ d16384 94.33 ± 0.22 95.27 ± 0.31 +1.0%
gpt-oss 20B MXFP4 MoE 11.27 GiB 20.91 B 99 1 pp512 2410.79 ± 15.88 2331.15 ± 89.00 -3.3%
gpt-oss 20B MXFP4 MoE 11.27 GiB 20.91 B 99 1 tg128 176.60 ± 0.74 173.28 ± 0.72 -1.9%
gpt-oss 20B MXFP4 MoE 11.27 GiB 20.91 B 99 1 pp512 @ d8192 1582.99 ± 17.17 1429.18 ± 11.60 -9.7%
gpt-oss 20B MXFP4 MoE 11.27 GiB 20.91 B 99 1 tg128 @ d8192 153.60 ± 1.60 150.58 ± 0.91 -2.0%
gpt-oss 20B MXFP4 MoE 11.27 GiB 20.91 B 99 1 pp512 @ d16384 1114.36 ± 154.82 1009.61 ± 23.16 -9.4%
gpt-oss 20B MXFP4 MoE 11.27 GiB 20.91 B 99 1 tg128 @ d16384 146.14 ± 0.64 143.19 ± 1.18 -2.0%
qwen3moe 30B.A3B Q2_K - Medium 10.48 GiB 30.53 B 99 1 pp512 1159.21 ± 12.74 1137.29 ± 13.35 -1.9%
qwen3moe 30B.A3B Q2_K - Medium 10.48 GiB 30.53 B 99 1 tg128 177.45 ± 1.07 175.96 ± 1.95 -0.8%
qwen3moe 30B.A3B Q2_K - Medium 10.48 GiB 30.53 B 99 1 pp512 @ d8192 592.47 ± 4.68 620.55 ± 6.11 +4.7%
qwen3moe 30B.A3B Q2_K - Medium 10.48 GiB 30.53 B 99 1 tg128 @ d8192 130.00 ± 0.58 135.84 ± 1.70 +4.5%
qwen3moe 30B.A3B Q2_K - Medium 10.48 GiB 30.53 B 99 1 pp512 @ d16384 387.10 ± 1.89 425.32 ± 0.85 +9.9%
qwen3moe 30B.A3B Q2_K - Medium 10.48 GiB 30.53 B 99 1 tg128 @ d16384 113.49 ± 0.51 117.90 ± 0.71 +3.9%
deepseek2 30B.A3B Q3_K - Small 12.37 GiB 29.94 B 99 1 pp512 1050.83 ± 17.39 1092.14 ± 16.92 +3.9%
deepseek2 30B.A3B Q3_K - Small 12.37 GiB 29.94 B 99 1 tg128 114.66 ± 2.79 115.36 ± 3.33 +0.6%
deepseek2 30B.A3B Q3_K - Small 12.37 GiB 29.94 B 99 1 pp512 @ d8192 281.20 ± 1.84 342.26 ± 2.76 +21.7%
deepseek2 30B.A3B Q3_K - Small 12.37 GiB 29.94 B 99 1 tg128 @ d8192 63.73 ± 0.06 63.90 ± 0.37 +0.3%
deepseek2 30B.A3B Q3_K - Small 12.37 GiB 29.94 B 99 1 pp512 @ d16384 159.38 ± 1.00 202.89 ± 2.03 +27.3%
deepseek2 30B.A3B Q3_K - Small 12.37 GiB 29.94 B 99 1 tg128 @ d16384 43.40 ± 0.05 44.22 ± 0.09 +1.9%

@0cc4m 0cc4m requested a review from jeffbolznv February 14, 2026 11:58
@github-actions github-actions bot added Vulkan Issues specific to the Vulkan backend ggml changes relating to the ggml tensor library for machine learning labels Feb 14, 2026
@netrunnereve
Collaborator

I did some quick runs on my RX 470, the tests are passing and performance seems pretty similar to what it was like before. Nothing crazy at least.

PR

model size params backend ngl fa test t/s
llama 7B Q4_0 3.56 GiB 6.74 B Vulkan 100 1 pp512 197.77 ± 0.04
llama 7B Q4_0 3.56 GiB 6.74 B Vulkan 100 1 tg128 37.70 ± 0.02
llama 1B Q4_0 606.53 MiB 1.10 B Vulkan 100 1 pp512 1269.96 ± 1.45
llama 1B Q4_0 606.53 MiB 1.10 B Vulkan 100 1 tg128 174.82 ± 1.56
llama 1B Q4_0 606.53 MiB 1.10 B Vulkan 100 1 pp512 @ d8192 460.61 ± 64.78
llama 1B Q4_0 606.53 MiB 1.10 B Vulkan 100 1 tg128 @ d8192 126.68 ± 0.25
llama 1B Q4_0 606.53 MiB 1.10 B Vulkan 100 1 pp512 @ d16384 307.06 ± 1.44
llama 1B Q4_0 606.53 MiB 1.10 B Vulkan 100 1 tg128 @ d16384 105.13 ± 0.04

Master

model size params backend ngl fa test t/s
llama 7B Q4_0 3.56 GiB 6.74 B Vulkan 100 1 pp512 194.26 ± 0.24
llama 7B Q4_0 3.56 GiB 6.74 B Vulkan 100 1 tg128 37.45 ± 0.04
llama 1B Q4_0 606.53 MiB 1.10 B Vulkan 100 1 pp512 1272.19 ± 1.79
llama 1B Q4_0 606.53 MiB 1.10 B Vulkan 100 1 tg128 184.46 ± 3.22
llama 1B Q4_0 606.53 MiB 1.10 B Vulkan 100 1 pp512 @ d8192 457.34 ± 95.60
llama 1B Q4_0 606.53 MiB 1.10 B Vulkan 100 1 tg128 @ d8192 124.14 ± 0.29
llama 1B Q4_0 606.53 MiB 1.10 B Vulkan 100 1 pp512 @ d16384 325.59 ± 0.43
llama 1B Q4_0 606.53 MiB 1.10 B Vulkan 100 1 tg128 @ d16384 104.48 ± 0.32

@engrtipusultan

Thank you very much, 0cc4m. Huge improvements at 8k and 16k context. I believe that for any meaningful conversation, pp and tg at depth 0 matter least; communicating requires context, and this PR brings huge improvements there, making FA worthwhile.

@engrtipusultan

engrtipusultan commented Feb 15, 2026

AMD Vega 8 APU.
Build dbddc35e8 (8102) is master merged with PRs (#19597, #19625, #19422, #19509)

Model Name Size Backend Threads ubatch FA dio Generation type Depth Build 079feab (8055) Build dbddc35e8 (8102) Difference
gpt-oss 20B MXFP4 MoE 11.27 GiB 20.91 B Vulkan,BLAS 8 1024 1 1 pp512 d0 153 151.59 -0.93%
gpt-oss 20B MXFP4 MoE 11.27 GiB 20.91 B Vulkan,BLAS 8 1024 1 1 tg128 d0 15.67 16.22 3.39%
gpt-oss 20B MXFP4 MoE 11.27 GiB 20.91 B Vulkan,BLAS 8 1024 1 1 pp512 d1024 138.46 143.46 3.49%
gpt-oss 20B MXFP4 MoE 11.27 GiB 20.91 B Vulkan,BLAS 8 1024 1 1 tg128 d1024 15.03 15.56 3.41%
gpt-oss 20B MXFP4 MoE 11.27 GiB 20.91 B Vulkan,BLAS 8 1024 1 1 pp512 d2048 129.66 133.91 3.17%
gpt-oss 20B MXFP4 MoE 11.27 GiB 20.91 B Vulkan,BLAS 8 1024 1 1 tg128 d2048 15.98 15.58 -2.57%
gpt-oss 20B MXFP4 MoE 11.27 GiB 20.91 B Vulkan,BLAS 8 1024 1 1 pp512 d8096 84.95 101.07 15.95%
gpt-oss 20B MXFP4 MoE 11.27 GiB 20.91 B Vulkan,BLAS 8 1024 1 1 tg128 d8096 14.91 14.87 -0.27%
gpt-oss 20B MXFP4 MoE 11.27 GiB 20.91 B Vulkan,BLAS 8 1024 1 1 pp512 d16192 59.03 76.58 22.92%
gpt-oss 20B MXFP4 MoE 11.27 GiB 20.91 B Vulkan,BLAS 8 1024 1 1 tg128 d16192 13.72 13.52 -1.48%
nemotron_h_moe 31B.A3.5B Q8_0 31.27 GiB 31.58 B Vulkan,BLAS 8 1024 1 1 pp512 d0 104.48 104.25 -0.22%
nemotron_h_moe 31B.A3.5B Q8_0 31.27 GiB 31.58 B Vulkan,BLAS 8 1024 1 1 tg128 d0 12.04 11.65 -3.35%
nemotron_h_moe 31B.A3.5B Q8_0 31.27 GiB 31.58 B Vulkan,BLAS 8 1024 1 1 pp512 d1024 95.05 101.43 6.29%
nemotron_h_moe 31B.A3.5B Q8_0 31.27 GiB 31.58 B Vulkan,BLAS 8 1024 1 1 tg128 d1024 11.98 11.49 -4.26%
nemotron_h_moe 31B.A3.5B Q8_0 31.27 GiB 31.58 B Vulkan,BLAS 8 1024 1 1 pp512 d2048 95.13 96.99 1.92%
nemotron_h_moe 31B.A3.5B Q8_0 31.27 GiB 31.58 B Vulkan,BLAS 8 1024 1 1 tg128 d2048 12.01 12.03 0.17%
nemotron_h_moe 31B.A3.5B Q8_0 31.27 GiB 31.58 B Vulkan,BLAS 8 1024 1 1 pp512 d8096 84.61 94.07 10.06%
nemotron_h_moe 31B.A3.5B Q8_0 31.27 GiB 31.58 B Vulkan,BLAS 8 1024 1 1 tg128 d8096 11.49 11.76 2.30%
nemotron_h_moe 31B.A3.5B Q8_0 31.27 GiB 31.58 B Vulkan,BLAS 8 1024 1 1 pp512 d16192 68.73 83.01 17.20%
nemotron_h_moe 31B.A3.5B Q8_0 31.27 GiB 31.58 B Vulkan,BLAS 8 1024 1 1 tg128 d16192 10.85 11.44 5.16%
deepseek2 30B.A3B Q8_0 29.65 GiB 29.94 B Vulkan,BLAS 8 1024 1 1 pp512 d0 86.94 95.4 8.87%
deepseek2 30B.A3B Q8_0 29.65 GiB 29.94 B Vulkan,BLAS 8 1024 1 1 tg128 d0 11.1 11.23 1.16%
deepseek2 30B.A3B Q8_0 29.65 GiB 29.94 B Vulkan,BLAS 8 1024 1 1 pp512 d1024 53.89 68.33 21.13%
deepseek2 30B.A3B Q8_0 29.65 GiB 29.94 B Vulkan,BLAS 8 1024 1 1 tg128 d1024 9.87 10.41 5.19%
deepseek2 30B.A3B Q8_0 29.65 GiB 29.94 B Vulkan,BLAS 8 1024 1 1 pp512 d2048 38.53 52.2 26.19%
deepseek2 30B.A3B Q8_0 29.65 GiB 29.94 B Vulkan,BLAS 8 1024 1 1 tg128 d2048 8.92 9.67 7.76%
deepseek2 30B.A3B Q8_0 29.65 GiB 29.94 B Vulkan,BLAS 8 1024 1 1 pp512 d8096 14.41 20.5 29.71%
deepseek2 30B.A3B Q8_0 29.65 GiB 29.94 B Vulkan,BLAS 8 1024 1 1 tg128 d8096 5.62 6.61 14.98%
deepseek2 30B.A3B Q8_0 29.65 GiB 29.94 B Vulkan,BLAS 8 1024 1 1 pp512 d16192 7.39 11.34 34.83%
deepseek2 30B.A3B Q8_0 29.65 GiB 29.94 B Vulkan,BLAS 8 1024 1 1 tg128 d16192 3.74 4.54 17.62%
qwen3next 80B.A3B Q4_K - Medium 45.43 GiB 79.67 B Vulkan,BLAS 8 1024 1 1 pp512 d0 61.82 64.82 4.63%
qwen3next 80B.A3B Q4_K - Medium 45.43 GiB 79.67 B Vulkan,BLAS 8 1024 1 1 tg128 d0 8.62 8.7 0.92%
qwen3next 80B.A3B Q4_K - Medium 45.43 GiB 79.67 B Vulkan,BLAS 8 1024 1 1 pp512 d1024 61.46 63.87 3.77%
qwen3next 80B.A3B Q4_K - Medium 45.43 GiB 79.67 B Vulkan,BLAS 8 1024 1 1 tg128 d1024 8.48 8.65 1.97%
qwen3next 80B.A3B Q4_K - Medium 45.43 GiB 79.67 B Vulkan,BLAS 8 1024 1 1 pp512 d2048 58.68 62.14 5.57%
qwen3next 80B.A3B Q4_K - Medium 45.43 GiB 79.67 B Vulkan,BLAS 8 1024 1 1 tg128 d2048 8.36 8.59 2.68%
qwen3next 80B.A3B Q4_K - Medium 45.43 GiB 79.67 B Vulkan,BLAS 8 1024 1 1 pp512 d8096 44.79 46.71 4.11%
qwen3next 80B.A3B Q4_K - Medium 45.43 GiB 79.67 B Vulkan,BLAS 8 1024 1 1 tg128 d8096 7.56 8.33 9.24%
qwen3next 80B.A3B Q4_K - Medium 45.43 GiB 79.67 B Vulkan,BLAS 8 1024 1 1 pp512 d16192 34.42 28.05 -22.71%
qwen3next 80B.A3B Q4_K - Medium 45.43 GiB 79.67 B Vulkan,BLAS 8 1024 1 1 tg128 d16192 6.78 7.98 15.04%

@0cc4m
Contributor Author

0cc4m commented Feb 15, 2026

I see a regression in prompt processing for GPT-OSS 20B and Qwen3MoE on RX 6800 XT, I'll try to fix it.

@engrtipusultan

@0cc4m, can you also kindly look at qwen3next? This architecture also has a pp regression, at least in my setup. The Qwen team already uses this architecture in three models, and I believe they will keep using it going forward, possibly in other models as well.

The benchmarks I shared previously also included some other unmerged PRs related to Vulkan and Qwen3Next. To isolate this one, I merged only this PR into master and removed all the other PRs I had been merging. The regression is still there.

Master Branch

bash  ./llama-bench -m /home/tipu/AI/models/bartowski/Qwen3-Coder-Next/Qwen_Qwen3-Coder-Next-Q4_K_L.gguf -ngl 100 --ubatch-size 1024 --batch-size 2048 --mmap 0 -fa 1 -d 8096,16192 -r 3 -dio 1
ggml_vulkan: Found 1 Vulkan devices:
ggml_vulkan: 0 = AMD Radeon Graphics (RADV RENOIR) (radv) | uma: 1 | fp16: 1 | bf16: 0 | warp size: 64 | shared memory: 65536 | int dot: 0 | matrix cores: none

model size params backend threads n_ubatch fa dio test t/s
qwen3next 80B.A3B Q4_K - Medium 45.43 GiB 79.67 B Vulkan,BLAS 8 1024 1 1 pp512 @ d8096 46.87 ± 1.36
qwen3next 80B.A3B Q4_K - Medium 45.43 GiB 79.67 B Vulkan,BLAS 8 1024 1 1 tg128 @ d8096 7.60 ± 0.04
qwen3next 80B.A3B Q4_K - Medium 45.43 GiB 79.67 B Vulkan,BLAS 8 1024 1 1 pp512 @ d16192 36.35 ± 1.00
qwen3next 80B.A3B Q4_K - Medium 45.43 GiB 79.67 B Vulkan,BLAS 8 1024 1 1 tg128 @ d16192 6.88 ± 0.00

build: 684b361 (8057)

Master merged with only this PR.

bash  ./llama-bench -m /home/tipu/AI/models/bartowski/Qwen3-Coder-Next/Qwen_Qwen3-Coder-Next-Q4_K_L.gguf -ngl 100 --ubatch-size 1024 --batch-size 2048 --mmap 0 -fa 1 -d 8096,16192 -r 3 -dio 1
ggml_vulkan: Found 1 Vulkan devices:
ggml_vulkan: 0 = AMD Radeon Graphics (RADV RENOIR) (radv) | uma: 1 | fp16: 1 | bf16: 0 | warp size: 64 | shared memory: 65536 | int dot: 0 | matrix cores: none

model size params backend threads n_ubatch fa dio test t/s
qwen3next 80B.A3B Q4_K - Medium 45.43 GiB 79.67 B Vulkan,BLAS 8 1024 1 1 pp512 @ d8096 45.98 ± 0.12
qwen3next 80B.A3B Q4_K - Medium 45.43 GiB 79.67 B Vulkan,BLAS 8 1024 1 1 tg128 @ d8096 8.28 ± 0.00
qwen3next 80B.A3B Q4_K - Medium 45.43 GiB 79.67 B Vulkan,BLAS 8 1024 1 1 pp512 @ d16192 27.98 ± 0.07
qwen3next 80B.A3B Q4_K - Medium 45.43 GiB 79.67 B Vulkan,BLAS 8 1024 1 1 tg128 @ d16192 7.98 ± 0.00

build: 6cdddc6e0 (8089)

@0cc4m 0cc4m marked this pull request as draft February 15, 2026 14:30
@0cc4m
Contributor Author

0cc4m commented Feb 15, 2026

Back to draft while I improve parameter selection to make tuning easier.

@0cc4m 0cc4m force-pushed the 0cc4m/vulkan-fa-scalar-opt branch 2 times, most recently from d0cf725 to c6ee63e Compare February 18, 2026 10:29
@0cc4m
Contributor Author

0cc4m commented Feb 18, 2026

I fixed all the regressions I could find. @jeffbolznv I refactored the way FA parameters are set, so that they are now set in only one place. I tried to port the Coopmat2 behaviour correctly and only slightly tweaked the Coopmat1 parameters. Let me know if you agree with this approach or have a better idea.

Benchmarks

Radeon Pro VII:

model size params ngl fa test t/s (before) t/s (after) diff
llama 8B Q4_0 4.33 GiB 8.03 B 99 1 pp512 @ d16384 88.61 ± 0.00 257.10 ± 0.00 +190.1%
llama 8B Q4_0 4.33 GiB 8.03 B 99 1 tg128 @ d16384 43.02 ± 0.00 63.79 ± 0.00 +48.3%
gpt-oss 20B MXFP4 MoE 11.27 GiB 20.91 B 99 1 pp512 @ d16384 533.10 ± 0.00 503.08 ± 0.00 -5.6%
gpt-oss 20B MXFP4 MoE 11.27 GiB 20.91 B 99 1 tg128 @ d16384 110.54 ± 0.00 110.01 ± 0.00 -0.5%
qwen3moe 30B.A3B Q2_K - Medium 10.48 GiB 30.53 B 99 1 pp512 @ d16384 119.16 ± 0.00 201.54 ± 0.00 +69.1%
qwen3moe 30B.A3B Q2_K - Medium 10.48 GiB 30.53 B 99 1 tg128 @ d16384 44.77 ± 0.00 64.09 ± 0.00 +43.2%
gemma2 9B Q4_K - Small 5.10 GiB 9.24 B 99 1 pp512 @ d16384 33.72 ± 0.00 229.42 ± 0.00 +580.4%
gemma2 9B Q4_K - Small 5.10 GiB 9.24 B 99 1 tg128 @ d16384 35.71 ± 0.00 37.23 ± 0.00 +4.3%
deepseek2 30B.A3B Q3_K - Small 12.37 GiB 29.94 B 99 1 pp512 @ d16384 43.91 ± 0.00 76.28 ± 0.00 +73.7%
deepseek2 30B.A3B Q3_K - Small 12.37 GiB 29.94 B 99 1 tg128 @ d16384 25.65 ± 0.00 28.14 ± 0.00 +9.7%

AMD RX 6800 XT:

model size params ngl fa test t/s (before) t/s (after) diff
llama 8B Q4_0 4.33 GiB 8.03 B 99 1 pp512 @ d16384 235.61 ± 0.29 520.53 ± 0.00 +120.9%
llama 8B Q4_0 4.33 GiB 8.03 B 99 1 tg128 @ d16384 57.09 ± 0.03 66.27 ± 0.00 +16.1%
gpt-oss 20B MXFP4 MoE 11.27 GiB 20.91 B 99 1 pp512 @ d16384 1237.53 ± 9.61 1050.05 ± 0.00 -15.1%
gpt-oss 20B MXFP4 MoE 11.27 GiB 20.91 B 99 1 tg128 @ d16384 129.86 ± 0.19 130.28 ± 0.00 +0.3%
qwen3moe 30B.A3B Q2_K - Medium 10.48 GiB 30.53 B 99 1 pp512 @ d16384 248.69 ± 0.29 381.32 ± 0.00 +53.3%
qwen3moe 30B.A3B Q2_K - Medium 10.48 GiB 30.53 B 99 1 tg128 @ d16384 85.90 ± 0.05 98.23 ± 0.00 +14.4%
gemma2 9B Q4_K - Small 5.10 GiB 9.24 B 99 1 pp512 @ d16384 91.75 ± 0.28 264.58 ± 0.00 +188.4%
gemma2 9B Q4_K - Small 5.10 GiB 9.24 B 99 1 tg128 @ d16384 44.90 ± 0.01 44.61 ± 0.00 -0.6%
deepseek2 30B.A3B Q3_K - Small 12.37 GiB 29.94 B 99 1 pp512 @ d16384 112.77 ± 0.15 155.29 ± 0.00 +37.7%
deepseek2 30B.A3B Q3_K - Small 12.37 GiB 29.94 B 99 1 tg128 @ d16384 53.91 ± 0.03 58.33 ± 0.00 +8.2%

AMD RX 8060S (without coopmat):

model size params ngl fa test t/s (before) t/s (after) diff
llama 8B Q4_0 4.33 GiB 8.03 B 99 1 pp512 @ d16384 26.43 ± 0.00 42.34 ± 0.00 +60.2%
llama 8B Q4_0 4.33 GiB 8.03 B 99 1 tg128 @ d16384 25.61 ± 0.00 30.43 ± 0.00 +18.8%
gpt-oss 20B MXFP4 MoE 11.27 GiB 20.91 B 99 1 pp512 @ d16384 445.30 ± 0.00 436.39 ± 0.00 -2.0%
gpt-oss 20B MXFP4 MoE 11.27 GiB 20.91 B 99 1 tg128 @ d16384 61.07 ± 0.00 61.71 ± 0.00 +1.0%
qwen3moe 30B.A3B Q2_K - Medium 10.48 GiB 30.53 B 99 1 pp512 @ d16384 83.85 ± 0.00 137.94 ± 0.00 +64.5%
qwen3moe 30B.A3B Q2_K - Medium 10.48 GiB 30.53 B 99 1 tg128 @ d16384 43.01 ± 0.00 50.06 ± 0.00 +16.4%
deepseek2 30B.A3B Q4_0 16.03 GiB 29.94 B 99 1 pp512 @ d16384 53.20 ± 0.00 99.00 ± 0.00 +86.1%
deepseek2 30B.A3B Q4_0 16.03 GiB 29.94 B 99 1 tg128 @ d16384 28.23 ± 0.00 29.47 ± 0.00 +4.4%

AMD RX 8060S (with coopmat):

| model | size | params | ngl | fa | test | t/s (before) | t/s (after) | diff |
| --- | ---: | ---: | --: | -: | ---: | ---: | ---: | ---: |
| llama 8B Q4_0 | 4.33 GiB | 8.03 B | 99 | 1 | pp512 @ d16384 | 229.34 ± 0.00 | 240.92 ± 0.00 | +5.0% |
| llama 8B Q4_0 | 4.33 GiB | 8.03 B | 99 | 1 | tg128 @ d16384 | 30.55 ± 0.00 | 30.28 ± 0.00 | -0.9% |
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | 99 | 1 | pp512 @ d16384 | 684.15 ± 0.00 | 683.27 ± 0.00 | -0.1% |
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | 99 | 1 | tg128 @ d16384 | 61.56 ± 0.00 | 62.96 ± 0.00 | +2.3% |
| qwen3moe 30B.A3B Q2_K - Medium | 10.48 GiB | 30.53 B | 99 | 1 | pp512 @ d16384 | 308.56 ± 0.00 | 301.50 ± 0.00 | -2.3% |
| qwen3moe 30B.A3B Q2_K - Medium | 10.48 GiB | 30.53 B | 99 | 1 | tg128 @ d16384 | 51.17 ± 0.00 | 50.98 ± 0.00 | -0.4% |
| gemma2 9B Q4_K - Small | 5.10 GiB | 9.24 B | 99 | 1 | pp512 @ d16384 | 405.06 ± 0.00 | 414.63 ± 0.00 | +2.4% |
| gemma2 9B Q4_K - Small | 5.10 GiB | 9.24 B | 99 | 1 | tg128 @ d16384 | 18.16 ± 0.00 | 17.94 ± 0.00 | -1.2% |
| deepseek2 30B.A3B Q4_0 | 16.03 GiB | 29.94 B | 99 | 1 | pp512 @ d16384 | 199.41 ± 0.00 | 188.99 ± 0.00 | -5.2% |
| deepseek2 30B.A3B Q4_0 | 16.03 GiB | 29.94 B | 99 | 1 | tg128 @ d16384 | 28.65 ± 0.00 | 29.55 ± 0.00 | +3.1% |

Intel A770:

| model | size | params | ngl | fa | test | t/s (before) | t/s (after) | diff |
| --- | ---: | ---: | --: | -: | ---: | ---: | ---: | ---: |
| llama 8B Q4_0 | 4.33 GiB | 8.03 B | 99 | 1 | pp512 @ d4096 | 54.15 ± 0.00 | 400.07 ± 0.00 | +638.8% |
| llama 8B Q4_0 | 4.33 GiB | 8.03 B | 99 | 1 | tg128 @ d4096 | 17.25 ± 0.00 | 26.74 ± 0.00 | +55.0% |
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | 99 | 1 | pp512 @ d4096 | 558.12 ± 0.00 | 611.40 ± 0.00 | +9.5% |
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | 99 | 1 | tg128 @ d4096 | 29.34 ± 0.00 | 28.86 ± 0.00 | -1.6% |
| qwen3moe 30B.A3B Q2_K - Medium | 10.48 GiB | 30.53 B | 99 | 1 | pp512 @ d4096 | 36.87 ± 0.00 | 315.53 ± 0.00 | +755.8% |
| qwen3moe 30B.A3B Q2_K - Medium | 10.48 GiB | 30.53 B | 99 | 1 | tg128 @ d4096 | 18.24 ± 0.00 | 25.57 ± 0.00 | +40.2% |
| gemma2 9B Q4_K - Small | 5.10 GiB | 9.24 B | 99 | 1 | pp512 @ d4096 | 258.96 ± 0.00 | 221.82 ± 0.00 | -14.3% |
| gemma2 9B Q4_K - Small | 5.10 GiB | 9.24 B | 99 | 1 | tg128 @ d4096 | 14.73 ± 0.00 | 17.50 ± 0.00 | +18.8% |
| deepseek2 30B.A3B Q3_K - Small | 12.37 GiB | 29.94 B | 99 | 1 | pp512 @ d4096 | 25.52 ± 0.00 | 25.26 ± 0.00 | -1.0% |
| deepseek2 30B.A3B Q3_K - Small | 12.37 GiB | 29.94 B | 99 | 1 | tg128 @ d4096 | 3.62 ± 0.00 | 10.61 ± 0.00 | +193.1% |

Nvidia RTX 3090 (without coopmat):

| model | size | params | ngl | fa | test | t/s (before) | t/s (after) | diff |
| --- | ---: | ---: | --: | -: | ---: | ---: | ---: | ---: |
| llama 8B Q4_0 | 4.33 GiB | 8.03 B | 99 | 1 | pp512 @ d16384 | 651.54 ± 1.29 | 748.99 ± 1.15 | +15.0% |
| llama 8B Q4_0 | 4.33 GiB | 8.03 B | 99 | 1 | tg128 @ d16384 | 88.68 ± 17.10 | 95.77 ± 0.21 | +8.0% |
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | 99 | 1 | pp512 @ d16384 | 1218.53 ± 10.73 | 1025.35 ± 5.76 | -15.9% |
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | 99 | 1 | tg128 @ d16384 | 136.34 ± 25.87 | 132.75 ± 24.28 | -2.6% |
| qwen3moe 30B.A3B Q2_K - Medium | 10.48 GiB | 30.53 B | 99 | 1 | pp512 @ d16384 | 407.04 ± 1.25 | 459.86 ± 1.55 | +13.0% |
| qwen3moe 30B.A3B Q2_K - Medium | 10.48 GiB | 30.53 B | 99 | 1 | tg128 @ d16384 | 114.69 ± 0.37 | 119.93 ± 0.20 | +4.6% |
| gemma2 9B Q4_K - Small | 5.10 GiB | 9.24 B | 99 | 1 | pp512 @ d16384 | 429.81 ± 2.28 | 585.10 ± 3.84 | +36.1% |
| gemma2 9B Q4_K - Small | 5.10 GiB | 9.24 B | 99 | 1 | tg128 @ d16384 | 68.11 ± 0.06 | 67.87 ± 0.12 | -0.4% |
| deepseek2 30B.A3B Q3_K - Small | 12.37 GiB | 29.94 B | 99 | 1 | pp512 @ d16384 | 168.55 ± 0.59 | 216.48 ± 0.63 | +28.4% |
| deepseek2 30B.A3B Q3_K - Small | 12.37 GiB | 29.94 B | 99 | 1 | tg128 @ d16384 | 44.17 ± 0.06 | 45.10 ± 0.13 | +2.1% |

Nvidia RTX 3090 (coopmat1):

| model | size | params | ngl | fa | test | t/s (before) | t/s (after) | diff |
| --- | ---: | ---: | --: | -: | ---: | ---: | ---: | ---: |
| llama 8B Q4_0 | 4.33 GiB | 8.03 B | 99 | 1 | pp512 @ d16384 | 1526.53 ± 6.58 | 1720.57 ± 4.45 | +12.7% |
| llama 8B Q4_0 | 4.33 GiB | 8.03 B | 99 | 1 | tg128 @ d16384 | 101.16 ± 0.14 | 103.92 ± 0.23 | +2.7% |
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | 99 | 1 | pp512 @ d16384 | 2034.12 ± 20.53 | 2161.28 ± 18.58 | +6.3% |
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | 99 | 1 | tg128 @ d16384 | 145.81 ± 15.49 | 154.12 ± 2.51 | +5.7% |
| qwen3moe 30B.A3B Q2_K - Medium | 10.48 GiB | 30.53 B | 99 | 1 | pp512 @ d16384 | 1053.16 ± 7.36 | 1184.50 ± 8.26 | +12.5% |
| qwen3moe 30B.A3B Q2_K - Medium | 10.48 GiB | 30.53 B | 99 | 1 | tg128 @ d16384 | 123.91 ± 0.33 | 127.12 ± 0.29 | +2.6% |
| gemma2 9B Q4_K - Small | 5.10 GiB | 9.24 B | 99 | 1 | pp512 @ d16384 | 1122.62 ± 15.29 | 1111.88 ± 13.51 | -1.0% |
| gemma2 9B Q4_K - Small | 5.10 GiB | 9.24 B | 99 | 1 | tg128 @ d16384 | 57.18 ± 0.27 | 57.14 ± 0.05 | -0.1% |
| deepseek2 30B.A3B Q3_K - Small | 12.37 GiB | 29.94 B | 99 | 1 | pp512 @ d16384 | 329.25 ± 0.98 | 332.16 ± 1.01 | +0.9% |
| deepseek2 30B.A3B Q3_K - Small | 12.37 GiB | 29.94 B | 99 | 1 | tg128 @ d16384 | 44.78 ± 0.02 | 45.64 ± 0.03 | +1.9% |

Nvidia RTX 3090 (coopmat2):

| model | size | params | ngl | fa | test | t/s (before) | t/s (after) | diff |
| --- | ---: | ---: | --: | -: | ---: | ---: | ---: | ---: |
| llama 8B Q4_0 | 4.33 GiB | 8.03 B | 99 | 1 | pp512 @ d16384 | 2737.42 ± 14.92 | 2731.76 ± 22.52 | -0.2% |
| llama 8B Q4_0 | 4.33 GiB | 8.03 B | 99 | 1 | tg128 @ d16384 | 101.18 ± 0.46 | 100.96 ± 0.15 | -0.2% |
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | 99 | 1 | pp512 @ d16384 | 3594.56 ± 37.84 | 3584.04 ± 37.04 | -0.3% |
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | 99 | 1 | tg128 @ d16384 | 145.09 ± 2.51 | 146.02 ± 3.07 | +0.6% |
| qwen3moe 30B.A3B Q2_K - Medium | 10.48 GiB | 30.53 B | 99 | 1 | pp512 @ d16384 | 1986.18 ± 25.31 | 1977.68 ± 20.39 | -0.4% |
| qwen3moe 30B.A3B Q2_K - Medium | 10.48 GiB | 30.53 B | 99 | 1 | tg128 @ d16384 | 125.14 ± 0.45 | 124.66 ± 0.24 | -0.4% |
| gemma2 9B Q4_K - Small | 5.10 GiB | 9.24 B | 99 | 1 | pp512 @ d16384 | 2277.10 ± 33.85 | 2280.22 ± 43.93 | +0.1% |
| gemma2 9B Q4_K - Small | 5.10 GiB | 9.24 B | 99 | 1 | tg128 @ d16384 | 62.92 ± 0.08 | 62.34 ± 0.05 | -0.9% |
| deepseek2 30B.A3B Q3_K - Small | 12.37 GiB | 29.94 B | 99 | 1 | pp512 @ d16384 | 659.69 ± 4.27 | 658.85 ± 3.66 | -0.1% |
| deepseek2 30B.A3B Q3_K - Small | 12.37 GiB | 29.94 B | 99 | 1 | tg128 @ d16384 | 85.16 ± 1.03 | 85.09 ± 0.31 | -0.1% |

@0cc4m 0cc4m marked this pull request as ready for review February 18, 2026 10:46
@engrtipusultan

Thank you :) The regression is gone and further improvements are observed.

```shell
./llama-bench -m /home/tipu/AI/models/bartowski/Qwen3-Coder-Next/Qwen_Qwen3-Coder-Next-Q4_K_L.gguf -ngl 100 --mmap 0 -fa 1 -d 8096,16192 -r 3 -dio 1
```

ggml_vulkan: Found 1 Vulkan devices:
ggml_vulkan: 0 = AMD Radeon Graphics (RADV RENOIR) (radv) | uma: 1 | fp16: 1 | bf16: 0 | warp size: 64 | shared memory: 65536 | int dot: 0 | matrix cores: none

| model | size | params | backend | threads | fa | dio | test | t/s |
| --- | ---: | ---: | --- | --: | -: | -: | ---: | ---: |
| qwen3next 80B.A3B Q4_K - Medium | 45.43 GiB | 79.67 B | Vulkan,BLAS | 8 | 1 | 1 | pp512 @ d8096 | 55.22 ± 0.06 |
| qwen3next 80B.A3B Q4_K - Medium | 45.43 GiB | 79.67 B | Vulkan,BLAS | 8 | 1 | 1 | tg128 @ d8096 | 8.24 ± 0.05 |
| qwen3next 80B.A3B Q4_K - Medium | 45.43 GiB | 79.67 B | Vulkan,BLAS | 8 | 1 | 1 | pp512 @ d16192 | 47.73 ± 0.27 |
| qwen3next 80B.A3B Q4_K - Medium | 45.43 GiB | 79.67 B | Vulkan,BLAS | 8 | 1 | 1 | tg128 @ d16192 | 7.95 ± 0.03 |

build: f91542c05 (8128)

```diff
     barrier();

-    vec4 Of[Br][HSV_per_thread / 4];
+    ACC_TYPEV4 Of[rows_per_thread][HSV_per_thread / 4];
```
Contributor

FWIW I found recently that the output can be stored in fp16 even when using GGML_PREC_F32, and this can help a lot with register usage for large head sizes (e.g. deepseek/GLM-Flash/etc).

Contributor Author

I should add that to scalar and coopmat1, yes.
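For illustration only (plain Python, not the shader code): the point of the fp16-output trick is that the accumulation stays in full precision and only the finished value is rounded to IEEE half for storage, so the error is bounded by fp16's ~3 significant digits while the per-thread footprint of `Of` halves. The loop length and value distributions below are arbitrary stand-ins.

```python
import random
import struct

def to_f16(x: float) -> float:
    """Round-trip a value through IEEE half precision (the storage format)."""
    return struct.unpack('e', struct.pack('e', x))[0]

random.seed(0)

# Accumulate one output element in full precision (GGML_PREC_F32-style)...
acc = 0.0
for _ in range(4096):
    acc += random.random() * random.gauss(0.0, 1.0)  # stand-ins for p_i * v_i

# ...and round only the finished value down to fp16 for storage.
stored = to_f16(acc)
rel_err = abs(stored - acc) / abs(acc)
assert rel_err < 1e-3  # half precision keeps about 3 significant digits
```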

```glsl
[[unroll]] for (uint32_t idx = 0; idx < Bc * HSK / 4; idx += gl_WorkGroupSize.x) {
    uint32_t d = (idx + tid) % (HSK / 4);
    uint32_t c = (idx + tid) / (HSK / 4);
    if (c < Bc) {
```
Contributor

Maybe change to:

Suggested change:

```diff
-if (c < Bc) {
+if (idx + gl_WorkGroupSize.x <= Bc * HSK / 4 || c < Bc) {
```

This would allow the compiler to optimize out the branch on all except the last iteration.

Contributor Author

Good idea.
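The rewrite can be sanity-checked on the host. This illustrative Python loop (mirroring the shader's `idx`/`tid`/`c` arithmetic, with arbitrary tile sizes chosen so the last iteration is partial) confirms that both guards admit exactly the same work items, and that the added clause is uniformly true on every full iteration, which is what lets the compiler drop the per-thread branch there:

```python
# Host-side check of the suggested guard rewrite; Bc, HSK, and the workgroup
# size are arbitrary illustrative values, not tuned shader constants.
Bc, HSK, workgroup_size = 32, 72, 128
total = Bc * HSK // 4  # number of vec4 elements to load

for idx in range(0, total, workgroup_size):
    for tid in range(workgroup_size):
        c = (idx + tid) // (HSK // 4)
        old_guard = c < Bc
        new_guard = (idx + workgroup_size <= total) or (c < Bc)
        assert old_guard == new_guard  # same threads do work either way
        if idx + workgroup_size <= total:
            assert new_guard  # uniformly true on full iterations -> branch-free
```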


```diff
 // Zero-initialize shared memory for Q/K when HSK is not a multiple of 16 (HSK_pad > HSK).
-if ((HSK % 16) != 0) {
+if ((HSK % 16) != 0 || (HSV % 16) != 0) {
```
Contributor

If HSK and HSV are both not aligned and are different values, seems like the smaller one will fetch stale values from the larger one. I think that would be very uncommon, maybe just disable SHMEM_STAGING for those cases.

Contributor Author

My thought here was that since I stage both K and V, I should check both for whether zero-initialization is required, so that it also happens if only HSV is off. Looking at it again, it might not be necessary to zero-pad at the beginning at all, since both SHMEM_STAGING loads run with HS*_pad and should therefore also write zero values when parts of the tile are out of bounds.

Contributor

Yeah, looks like the load for K/V will pad with zero. I think filling Q with zero is still necessary to avoid inf*0=nan, but for that just the HSK check should be enough.

Contributor Author

I removed the HSV check and the kvsh loop.
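For reference, the inf*0 hazard discussed above is plain IEEE arithmetic and can be reproduced outside the shader:

```python
import math

# If the padded tail of Q were left uninitialized, a stray inf there
# multiplied by a zero-padded K value would yield NaN and poison the row:
assert math.isnan(float("inf") * 0.0)

# A properly zero-filled Q pad contributes exactly nothing instead:
assert 0.0 * 123.456 == 0.0
```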

```diff
-coopMatLoad(QMat, ksh, v_tile_offset, vsh_stride, gl_CooperativeMatrixLayoutRowMajor);
+{
+    const uint v_tile_offset = bc_chunk * MatBr * v_cols + gl_SubgroupID * (MatBc / 4);
+    coopMatLoad(QMat, kvsh, v_tile_offset, vsh_stride, gl_CooperativeMatrixLayoutRowMajor);
```
Contributor

Maybe worth a comment to say that we load V from shared memory even when SHMEM_STAGING is 0.

Contributor Author

As with K, if SHMEM_STAGING is not set, we only stage through shmem if the V data type or bounds checks don't allow loading directly from global memory. But in those cases we load a much smaller tile than with SHMEM_STAGING. I don't see how V is handled any differently from how K was already handled previously.

Contributor

I just found the logic here kind of confusing - seeing V get staged even when SHMEM_STAGING is false. I just thought it was worth a comment to clarify.

Contributor Author

You're right, I was just wondering if something about V was special here. I added a comment to K and to V to explain the behaviour.

@jeffbolznv
Contributor

I did perf testing for the cm2 and scalar paths. cm2 was roughly unchanged. For scalar I didn't see any real improvements, and gpt-oss got slower on my system. I don't think this is a big deal since my system isn't really the target, but here are the results anyway:

before

Z:\github\jeffbolznv\llama.cpp\build\bin\RelWithDebInfo>llama-bench.exe -fa 1 -p 512 -n 128 -d 0-32768+8192 -m c:\models\gpt-oss-20b-mxfp4.gguf
ggml_vulkan: Found 1 Vulkan devices:
ggml_vulkan: 0 = NVIDIA GeForce RTX 5090 (NVIDIA) | uma: 0 | fp16: 1 | bf16: 1 | warp size: 32 | shared memory: 49152 | int dot: 1 | matrix cores: none
load_backend: loaded Vulkan backend from Z:\github\jeffbolznv\llama.cpp\build\bin\RelWithDebInfo\ggml-vulkan.dll
load_backend: loaded CPU backend from Z:\github\jeffbolznv\llama.cpp\build\bin\RelWithDebInfo\ggml-cpu.dll
| model                          |       size |     params | backend    | ngl | fa |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -: | --------------: | -------------------: |
| gpt-oss 20B MXFP4 MoE          |  11.27 GiB |    20.91 B | Vulkan     |  99 |  1 |           pp512 |      6187.85 ± 40.90 |
| gpt-oss 20B MXFP4 MoE          |  11.27 GiB |    20.91 B | Vulkan     |  99 |  1 |           tg128 |        366.43 ± 0.38 |
| gpt-oss 20B MXFP4 MoE          |  11.27 GiB |    20.91 B | Vulkan     |  99 |  1 |   pp512 @ d8192 |      3431.53 ± 23.84 |
| gpt-oss 20B MXFP4 MoE          |  11.27 GiB |    20.91 B | Vulkan     |  99 |  1 |   tg128 @ d8192 |        314.81 ± 0.96 |
| gpt-oss 20B MXFP4 MoE          |  11.27 GiB |    20.91 B | Vulkan     |  99 |  1 |  pp512 @ d16384 |       2381.07 ± 9.49 |
| gpt-oss 20B MXFP4 MoE          |  11.27 GiB |    20.91 B | Vulkan     |  99 |  1 |  tg128 @ d16384 |        292.72 ± 0.83 |
| gpt-oss 20B MXFP4 MoE          |  11.27 GiB |    20.91 B | Vulkan     |  99 |  1 |  pp512 @ d24576 |      1809.95 ± 11.72 |
| gpt-oss 20B MXFP4 MoE          |  11.27 GiB |    20.91 B | Vulkan     |  99 |  1 |  tg128 @ d24576 |        268.06 ± 0.43 |
| gpt-oss 20B MXFP4 MoE          |  11.27 GiB |    20.91 B | Vulkan     |  99 |  1 |  pp512 @ d32768 |       1463.89 ± 1.86 |
| gpt-oss 20B MXFP4 MoE          |  11.27 GiB |    20.91 B | Vulkan     |  99 |  1 |  tg128 @ d32768 |        248.55 ± 0.18 |

build: ea003229d (8090)

after

Z:\github\jeffbolznv\llama.cpp\build\bin\RelWithDebInfo>llama-bench.exe -fa 1 -p 512 -n 128 -d 0-32768+8192 -m c:\models\gpt-oss-20b-mxfp4.gguf
ggml_vulkan: Found 1 Vulkan devices:
ggml_vulkan: 0 = NVIDIA GeForce RTX 5090 (NVIDIA) | uma: 0 | fp16: 1 | bf16: 1 | warp size: 32 | shared memory: 49152 | int dot: 1 | matrix cores: none
load_backend: loaded Vulkan backend from Z:\github\jeffbolznv\llama.cpp\build\bin\RelWithDebInfo\ggml-vulkan.dll
load_backend: loaded CPU backend from Z:\github\jeffbolznv\llama.cpp\build\bin\RelWithDebInfo\ggml-cpu.dll
| model                          |       size |     params | backend    | ngl | fa |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -: | --------------: | -------------------: |
| gpt-oss 20B MXFP4 MoE          |  11.27 GiB |    20.91 B | Vulkan     |  99 |  1 |           pp512 |      6113.21 ± 31.73 |
| gpt-oss 20B MXFP4 MoE          |  11.27 GiB |    20.91 B | Vulkan     |  99 |  1 |           tg128 |        362.19 ± 0.98 |
| gpt-oss 20B MXFP4 MoE          |  11.27 GiB |    20.91 B | Vulkan     |  99 |  1 |   pp512 @ d8192 |      2783.32 ± 18.60 |
| gpt-oss 20B MXFP4 MoE          |  11.27 GiB |    20.91 B | Vulkan     |  99 |  1 |   tg128 @ d8192 |        305.55 ± 0.97 |
| gpt-oss 20B MXFP4 MoE          |  11.27 GiB |    20.91 B | Vulkan     |  99 |  1 |  pp512 @ d16384 |       1791.49 ± 5.89 |
| gpt-oss 20B MXFP4 MoE          |  11.27 GiB |    20.91 B | Vulkan     |  99 |  1 |  tg128 @ d16384 |        281.82 ± 0.59 |
| gpt-oss 20B MXFP4 MoE          |  11.27 GiB |    20.91 B | Vulkan     |  99 |  1 |  pp512 @ d24576 |       1317.98 ± 6.33 |
| gpt-oss 20B MXFP4 MoE          |  11.27 GiB |    20.91 B | Vulkan     |  99 |  1 |  tg128 @ d24576 |        254.87 ± 0.56 |
| gpt-oss 20B MXFP4 MoE          |  11.27 GiB |    20.91 B | Vulkan     |  99 |  1 |  pp512 @ d32768 |       1041.12 ± 1.23 |
| gpt-oss 20B MXFP4 MoE          |  11.27 GiB |    20.91 B | Vulkan     |  99 |  1 |  tg128 @ d32768 |        234.47 ± 0.57 |

build: c6ee63e0f (8127)

@0cc4m
Contributor Author

0cc4m commented Feb 18, 2026

I've seen gpt-oss take a hit on some other tests as well. The old shader worked very well for small head sizes; I haven't nailed down exactly why the new version is slightly worse there.

@engrtipusultan

For some reason I do not see a regression compared to my old build. Maybe the behavior differs across hardware.

| model | size | params | backend | threads | ubatch | fa | dio | test | depth | build 079feab (8055) | build f91542c05 (8128) | diff |
| --- | ---: | ---: | --- | --: | --: | -: | -: | ---: | ---: | ---: | ---: | ---: |
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | Vulkan,BLAS | 8 | 1024 | 1 | 1 | pp512 | d0 | 153 | 155.41 | 1.55% |
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | Vulkan,BLAS | 8 | 1024 | 1 | 1 | tg128 | d0 | 15.67 | 16.61 | 5.65% |
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | Vulkan,BLAS | 8 | 1024 | 1 | 1 | pp512 | d1024 | 138.46 | 141.83 | 2.37% |
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | Vulkan,BLAS | 8 | 1024 | 1 | 1 | tg128 | d1024 | 15.03 | 16.19 | 7.16% |
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | Vulkan,BLAS | 8 | 1024 | 1 | 1 | pp512 | d2048 | 129.66 | 133.86 | 3.13% |
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | Vulkan,BLAS | 8 | 1024 | 1 | 1 | tg128 | d2048 | 15.98 | 16.11 | 0.80% |
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | Vulkan,BLAS | 8 | 1024 | 1 | 1 | pp512 | d8096 | 84.95 | 97.57 | 12.93% |
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | Vulkan,BLAS | 8 | 1024 | 1 | 1 | tg128 | d8096 | 14.91 | 15.09 | 1.19% |
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | Vulkan,BLAS | 8 | 1024 | 1 | 1 | pp512 | d16192 | 59.03 | 70.87 | 16.70% |
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | Vulkan,BLAS | 8 | 1024 | 1 | 1 | tg128 | d16192 | 13.72 | 13.99 | 1.92% |

```glsl
shared f16vec4 kvsh[(kvsh_stride >= vsh_stride) ? (Bc * kvsh_stride) : (Bc * vsh_stride)];

const uint32_t osh_stride = row_split * MatBc / 4;
shared f16vec4 pvsh[MatBc * osh_stride];
```
Contributor

I don't think this is accounted for in the shared memory calculation. Maybe it could share with kvsh?

Contributor Author

I forgot to update that, yes. I want them to share, but if Of is float16_t and ACC_TYPE is f32 they are different types. GLSL doesn't let me reuse shared memory with a different type.

@0cc4m
Contributor Author

0cc4m commented Feb 20, 2026

Using float16 Of caused another nasty regression on AMD RDNA for smaller head sizes, while improving larger ones. I'll try to find a way to keep good speeds for both cases.

visorcraft pushed a commit to visorcraft/llama.cpp that referenced this pull request Feb 20, 2026
@0cc4m 0cc4m requested a review from jeffbolznv February 21, 2026 19:34
@0cc4m
Contributor Author

0cc4m commented Feb 21, 2026

The issue on RDNA was that the register reduction from Of in float16 improved occupancy to a point where enough subgroups ran at once to thrash the cache. Performance is restored when occupancy is reduced again, so I forced this with a large unused shmem buffer. This is hacky, but I didn't find a better way. Let me know if you have concerns or suggestions.
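The mechanism behind the unused shmem buffer can be sketched with simple arithmetic. The numbers below are hypothetical illustrations, not RDNA specs: per-CU shared memory divided by a workgroup's request bounds how many workgroups are resident at once, so inflating the request throttles occupancy.

```python
# Hypothetical numbers for illustration only; real limits depend on the GPU
# and also on register usage, which fp16 Of had just reduced.
def shmem_limited_occupancy(shmem_per_cu: int, shmem_per_workgroup: int) -> int:
    """Max workgroups resident per compute unit, shared-memory limit only."""
    return shmem_per_cu // shmem_per_workgroup

SHMEM_PER_CU = 64 * 1024   # hypothetical per-CU shared memory
REQUEST = 16 * 1024        # hypothetical per-workgroup usage
PAD = 16 * 1024            # unused buffer added purely as a throttle

assert shmem_limited_occupancy(SHMEM_PER_CU, REQUEST) == 4
assert shmem_limited_occupancy(SHMEM_PER_CU, REQUEST + PAD) == 2  # fewer resident workgroups
```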

@0cc4m
Contributor Author

0cc4m commented Feb 21, 2026

I'm not sure if this CM2 CI issue is an error I introduced or if I triggered the Turing coopmat bug on CM2 as well now.

@jeffbolznv
Contributor

I think it's likely to be the same Turing bug. I don't know what to do about it other than to disable the coopmat2 flash attention path on Turing.

@0cc4m 0cc4m force-pushed the 0cc4m/vulkan-fa-scalar-opt branch 2 times, most recently from ea7dfdf to cf28133 Compare February 22, 2026 08:53
@0cc4m 0cc4m force-pushed the 0cc4m/vulkan-fa-scalar-opt branch from cf28133 to 1482e30 Compare February 23, 2026 06:37
@0cc4m 0cc4m force-pushed the 0cc4m/vulkan-fa-scalar-opt branch from 1482e30 to ae849d3 Compare February 23, 2026 09:41
@0cc4m
Contributor Author

0cc4m commented Feb 23, 2026

@masamaru-san I disabled fp16 FA on GCN with the proprietary driver, similar to what you did. Can you try it?

@masamaru-san

@0cc4m Sorry for the slow reply.
It seems good since it passed all the tests. 👌

@0cc4m 0cc4m merged commit aa6f918 into master Feb 24, 2026
77 of 78 checks passed
@0cc4m 0cc4m deleted the 0cc4m/vulkan-fa-scalar-opt branch February 24, 2026 07:35
bartowski1182 pushed a commit to bartowski1182/llama.cpp that referenced this pull request Mar 2, 2026
* vulkan: allow using fp16 in scalar flash attention shader

* split rows inside of subgroups for faster synchronization

* use row_split when Br >= 4, change reductions to use shared memory if row_split == 1

* use f32 scalar FA if f16 is not supported by device

* fix amd workgroup size issue

* optimize masksh use

* add medium rows FA shader Br size

* fixes

* add padding to mask shmem buffer

* cache q values into registers for KQ

* fuse lf accumulation, pf and v accumulation into a loop

* stage K loads through shmem

* stage V loads through shmem

* only stage through shmem on Nvidia

* default to Bc 32

* also stage V through shmem when this is done for K

* dynamic subgroups for intel

* use vectorized stores

* use float_type for dequantize4 functions

* use smaller scalar rows size for smaller rows count

* relax flash attention split_k condition to allow non-gqa use

* use minimal subgroup size on Intel

* fix shmem support function

* fix rebase issues

* fixes

* Bc 4 for scalar FA is not a valid configuration

* Use wave32 on AMD RDNA for scalar FA

* add Intel shader core count lookup-table

* fix regressions

* device tuning

* tmpsh size fix

* fix editorconfig

* refactor fa tuning logic into a single place

* fix gqa opt logic

* fix block_rows with small n_rows

* amd tuning

* fix hsk=72/80 issue

* tuning

* allow condition skipping for column check

* use float16 for Of if available

* address feedback

* fix bad RDNA performance on head size <= 128 by limiting occupancy

* allow printing pipeline stats

* cleanup and fixes

* limit occupancy for GCN for small batch FA with large HSK

* disable f16 FA for GCN AMD GPUs on the proprietary driver
ArberSephirotheca pushed a commit to ArberSephirotheca/llama.cpp that referenced this pull request Mar 3, 2026
aldehir pushed a commit to aldehir/llama.cpp that referenced this pull request Mar 6, 2026

Labels

ggml (changes relating to the ggml tensor library for machine learning), Vulkan (issues specific to the Vulkan backend)

6 participants