mmq.cu: tune mmq/rocblas switching for RDNA#18537

Merged
JohannesGaessler merged 5 commits into ggml-org:master from Beinsezii:beinsezii/rocm_mmq_tune
Jan 6, 2026
Conversation


@Beinsezii Beinsezii commented Jan 2, 2026

Continuing from #18442, I applied benchmarking similar to #14949 and #18202 to try to minimize bad cases on RDNA while keeping the logic simple.

TL;DR

Averaged over all models on https://huggingface.co/Beinsezii/mmq_test across a variety of µbatch sizes, I get

mmq% blas% tuned%
95.9 85.7 98.8

Here 100% is the theoretical maximum achieved by optimally choosing MMQ or rocBLAS in each case.

Current master is functionally equivalent to mmq for this and all other benchmarks.
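To make the metric concrete, here is a small Python sketch of how such percentages can be computed from per-case throughputs. This is an illustration of the metric only, not the actual benchmark script; the example numbers are taken from the 14B Q6_K row at µbatch 256 in the breakdown table below.

```python
# Illustrative sketch of the mmq%/blas%/tuned% metric (an assumption about
# how the percentages are defined, not the actual benchmark script).
def score(mmq_ts: float, blas_ts: float, tuned_ts: float):
    # 100% = the throughput of whichever of MMQ and rocBLAS is faster
    # for this particular (model, µbatch) case.
    best = max(mmq_ts, blas_ts)
    return tuple(round(100.0 * ts / best, 2) for ts in (mmq_ts, blas_ts, tuned_ts))

# 14B Q6_K at n_ubatch=256: MMQ loses badly and the tuned logic picks rocBLAS.
mmq_pct, blas_pct, tuned_pct = score(987.74, 1412.01, 1414.90)
print(mmq_pct, blas_pct, tuned_pct)  # 69.95 100.0 100.2
```

Averaging the tuned% column over all rows is what produces the summary figures above.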

Benchmarks

compare-llama-bench affected quants
GPU Model Microbatch size Test t/s b7600 t/s beinsezii/rocm_mmq_tune Speedup
RX 7900 XTX llama 1B IQ2_S - 2.5 bpw 256 pp2048 11710.08 14162.79 1.21
RX 7900 XTX llama 1B IQ2_S - 2.5 bpw 512 pp2048 13389.07 13443.75 1.00
RX 7900 XTX llama 1B IQ2_S - 2.5 bpw 1024 pp2048 15232.29 16217.56 1.06
RX 7900 XTX llama 1B IQ2_XS - 2.3125 bpw 256 pp2048 11419.26 14492.09 1.27
RX 7900 XTX llama 1B IQ2_XS - 2.3125 bpw 512 pp2048 12980.69 13579.75 1.05
RX 7900 XTX llama 1B IQ2_XS - 2.3125 bpw 1024 pp2048 14666.17 16202.63 1.10
RX 7900 XTX llama 1B Q2_K_M 256 pp2048 10308.52 13559.73 1.32
RX 7900 XTX llama 1B Q2_K_M 512 pp2048 11137.57 12940.25 1.16
RX 7900 XTX llama 1B Q2_K_M 1024 pp2048 12807.04 15554.49 1.21
RX 7900 XTX llama 1B Q6_K 256 pp2048 8985.03 14162.37 1.58
RX 7900 XTX llama 1B Q6_K 512 pp2048 10344.84 13421.27 1.30
RX 7900 XTX llama 1B Q6_K 1024 pp2048 11870.41 15930.02 1.34
RX 7900 XTX qwen3 14B IQ2_S - 2.5 bpw 256 pp2048 1324.46 1455.10 1.10
RX 7900 XTX qwen3 14B IQ2_S - 2.5 bpw 512 pp2048 1391.35 1503.48 1.08
RX 7900 XTX qwen3 14B IQ2_S - 2.5 bpw 1024 pp2048 1498.16 1591.97 1.06
RX 7900 XTX qwen3 14B IQ2_XS - 2.3125 bpw 256 pp2048 1282.21 1465.35 1.14
RX 7900 XTX qwen3 14B IQ2_XS - 2.3125 bpw 512 pp2048 1363.45 1476.22 1.08
RX 7900 XTX qwen3 14B IQ2_XS - 2.3125 bpw 1024 pp2048 1458.91 1562.08 1.07
RX 7900 XTX qwen3 14B Q2_K_M 256 pp2048 1060.96 1421.48 1.34
RX 7900 XTX qwen3 14B Q2_K_M 512 pp2048 1112.72 1447.63 1.30
RX 7900 XTX qwen3 14B Q2_K_M 1024 pp2048 1198.96 1623.59 1.35
RX 7900 XTX qwen3 14B Q6_K 256 pp2048 987.74 1414.90 1.43
RX 7900 XTX qwen3 14B Q6_K 512 pp2048 1036.21 1530.71 1.48
RX 7900 XTX qwen3 14B Q6_K 1024 pp2048 1114.71 1577.11 1.41
compare-llama-bench full
GPU Model Microbatch size Test t/s b7600 t/s beinsezii/rocm_mmq_tune Speedup
RX 7900 XTX gpt-oss 20B MXFP4 MoE 64 pp2048 1526.87 1533.15 1.00
RX 7900 XTX gpt-oss 20B MXFP4 MoE 128 pp2048 1736.72 1744.08 1.00
RX 7900 XTX gpt-oss 20B MXFP4 MoE 256 pp2048 2670.39 2723.63 1.02
RX 7900 XTX gpt-oss 20B MXFP4 MoE 512 pp2048 3844.00 3857.77 1.00
RX 7900 XTX gpt-oss 20B MXFP4 MoE 1024 pp2048 4767.12 4776.70 1.00
RX 7900 XTX granitehybrid 1B Q4_K_M 64 pp2048 1996.18 2051.94 1.03
RX 7900 XTX granitehybrid 1B Q4_K_M 128 pp2048 2363.13 2458.62 1.04
RX 7900 XTX granitehybrid 1B Q4_K_M 256 pp2048 3752.94 3848.70 1.03
RX 7900 XTX granitehybrid 1B Q4_K_M 512 pp2048 5144.10 5239.24 1.02
RX 7900 XTX granitehybrid 1B Q4_K_M 1024 pp2048 5951.86 6029.21 1.01
RX 7900 XTX llama 1B IQ1_S - 1.5625 bpw 64 pp2048 7027.68 6666.87 0.95
RX 7900 XTX llama 1B IQ1_S - 1.5625 bpw 128 pp2048 9170.36 9228.60 1.01
RX 7900 XTX llama 1B IQ1_S - 1.5625 bpw 256 pp2048 13441.39 13776.57 1.02
RX 7900 XTX llama 1B IQ1_S - 1.5625 bpw 512 pp2048 15024.81 15560.55 1.04
RX 7900 XTX llama 1B IQ1_S - 1.5625 bpw 1024 pp2048 17155.51 17523.91 1.02
RX 7900 XTX llama 1B IQ2_S - 2.5 bpw 64 pp2048 6637.23 6414.86 0.97
RX 7900 XTX llama 1B IQ2_S - 2.5 bpw 128 pp2048 8048.83 8034.94 1.00
RX 7900 XTX llama 1B IQ2_S - 2.5 bpw 256 pp2048 11710.08 14162.79 1.21
RX 7900 XTX llama 1B IQ2_S - 2.5 bpw 512 pp2048 13389.07 13443.75 1.00
RX 7900 XTX llama 1B IQ2_S - 2.5 bpw 1024 pp2048 15232.29 16217.56 1.06
RX 7900 XTX llama 1B IQ2_XS - 2.3125 bpw 64 pp2048 6507.77 6242.87 0.96
RX 7900 XTX llama 1B IQ2_XS - 2.3125 bpw 128 pp2048 7801.62 7795.86 1.00
RX 7900 XTX llama 1B IQ2_XS - 2.3125 bpw 256 pp2048 11419.26 14492.09 1.27
RX 7900 XTX llama 1B IQ2_XS - 2.3125 bpw 512 pp2048 12980.69 13579.75 1.05
RX 7900 XTX llama 1B IQ2_XS - 2.3125 bpw 1024 pp2048 14666.17 16202.63 1.10
RX 7900 XTX llama 1B IQ2_XXS - 2.0625 bpw 64 pp2048 6667.97 6304.72 0.95
RX 7900 XTX llama 1B IQ2_XXS - 2.0625 bpw 128 pp2048 9311.87 9311.16 1.00
RX 7900 XTX llama 1B IQ2_XXS - 2.0625 bpw 256 pp2048 13530.44 13821.29 1.02
RX 7900 XTX llama 1B IQ2_XXS - 2.0625 bpw 512 pp2048 14918.22 15340.57 1.03
RX 7900 XTX llama 1B IQ2_XXS - 2.0625 bpw 1024 pp2048 16953.24 17495.25 1.03
RX 7900 XTX llama 1B IQ3_S - 3.4375 bpw 64 pp2048 6758.54 6809.36 1.01
RX 7900 XTX llama 1B IQ3_S - 3.4375 bpw 128 pp2048 9395.78 9383.74 1.00
RX 7900 XTX llama 1B IQ3_S - 3.4375 bpw 256 pp2048 13616.08 13611.88 1.00
RX 7900 XTX llama 1B IQ3_S - 3.4375 bpw 512 pp2048 15014.23 15088.34 1.00
RX 7900 XTX llama 1B IQ3_S - 3.4375 bpw 1024 pp2048 16841.61 17049.62 1.01
RX 7900 XTX llama 1B IQ3_XXS - 3.0625 bpw 64 pp2048 6989.48 6992.81 1.00
RX 7900 XTX llama 1B IQ3_XXS - 3.0625 bpw 128 pp2048 9329.01 9360.47 1.00
RX 7900 XTX llama 1B IQ3_XXS - 3.0625 bpw 256 pp2048 13667.47 14401.52 1.05
RX 7900 XTX llama 1B IQ3_XXS - 3.0625 bpw 512 pp2048 15279.04 15682.87 1.03
RX 7900 XTX llama 1B IQ3_XXS - 3.0625 bpw 1024 pp2048 17511.66 17846.33 1.02
RX 7900 XTX llama 1B IQ4_NL - 4.5 bpw 64 pp2048 7739.35 7772.91 1.00
RX 7900 XTX llama 1B IQ4_NL - 4.5 bpw 128 pp2048 10303.88 10303.47 1.00
RX 7900 XTX llama 1B IQ4_NL - 4.5 bpw 256 pp2048 14758.83 14790.00 1.00
RX 7900 XTX llama 1B IQ4_NL - 4.5 bpw 512 pp2048 16346.91 16396.66 1.00
RX 7900 XTX llama 1B IQ4_NL - 4.5 bpw 1024 pp2048 18466.96 18591.02 1.01
RX 7900 XTX llama 1B IQ4_XS - 4.25 bpw 64 pp2048 7811.50 7863.27 1.01
RX 7900 XTX llama 1B IQ4_XS - 4.25 bpw 128 pp2048 10306.37 10308.72 1.00
RX 7900 XTX llama 1B IQ4_XS - 4.25 bpw 256 pp2048 14751.24 14769.14 1.00
RX 7900 XTX llama 1B IQ4_XS - 4.25 bpw 512 pp2048 16341.84 16349.54 1.00
RX 7900 XTX llama 1B IQ4_XS - 4.25 bpw 1024 pp2048 18228.46 18714.99 1.03
RX 7900 XTX llama 1B Q2_K_M 64 pp2048 5192.98 4812.87 0.93
RX 7900 XTX llama 1B Q2_K_M 128 pp2048 7344.04 7361.62 1.00
RX 7900 XTX llama 1B Q2_K_M 256 pp2048 10308.52 13559.73 1.32
RX 7900 XTX llama 1B Q2_K_M 512 pp2048 11137.57 12940.25 1.16
RX 7900 XTX llama 1B Q2_K_M 1024 pp2048 12807.04 15554.49 1.21
RX 7900 XTX llama 1B Q3_K_M 64 pp2048 6953.68 6973.70 1.00
RX 7900 XTX llama 1B Q3_K_M 128 pp2048 8895.70 8907.37 1.00
RX 7900 XTX llama 1B Q3_K_M 256 pp2048 12954.66 13047.81 1.01
RX 7900 XTX llama 1B Q3_K_M 512 pp2048 14688.59 14692.35 1.00
RX 7900 XTX llama 1B Q3_K_M 1024 pp2048 16542.45 16624.59 1.00
RX 7900 XTX llama 1B Q4_0 64 pp2048 7562.55 7610.92 1.01
RX 7900 XTX llama 1B Q4_0 128 pp2048 10106.39 10096.22 1.00
RX 7900 XTX llama 1B Q4_0 256 pp2048 14566.00 14613.30 1.00
RX 7900 XTX llama 1B Q4_0 512 pp2048 16083.62 16129.93 1.00
RX 7900 XTX llama 1B Q4_0 1024 pp2048 18178.06 18300.13 1.01
RX 7900 XTX llama 1B Q4_1 64 pp2048 7509.82 6766.76 0.90
RX 7900 XTX llama 1B Q4_1 128 pp2048 8963.73 9026.38 1.01
RX 7900 XTX llama 1B Q4_1 256 pp2048 13117.91 13168.37 1.00
RX 7900 XTX llama 1B Q4_1 512 pp2048 14950.75 15028.10 1.01
RX 7900 XTX llama 1B Q4_1 1024 pp2048 16740.56 16966.23 1.01
RX 7900 XTX llama 1B Q4_K_M 64 pp2048 6909.62 6912.13 1.00
RX 7900 XTX llama 1B Q4_K_M 128 pp2048 8315.25 8360.44 1.01
RX 7900 XTX llama 1B Q4_K_M 256 pp2048 12427.88 13770.19 1.11
RX 7900 XTX llama 1B Q4_K_M 512 pp2048 14203.03 15579.25 1.10
RX 7900 XTX llama 1B Q4_K_M 1024 pp2048 16391.75 17443.34 1.06
RX 7900 XTX llama 1B Q5_0 64 pp2048 7118.02 7161.32 1.01
RX 7900 XTX llama 1B Q5_0 128 pp2048 9771.66 9755.90 1.00
RX 7900 XTX llama 1B Q5_0 256 pp2048 14112.98 14134.09 1.00
RX 7900 XTX llama 1B Q5_0 512 pp2048 15561.64 15640.29 1.01
RX 7900 XTX llama 1B Q5_0 1024 pp2048 17557.71 17678.22 1.01
RX 7900 XTX llama 1B Q5_1 64 pp2048 6777.00 6776.62 1.00
RX 7900 XTX llama 1B Q5_1 128 pp2048 8387.54 8441.55 1.01
RX 7900 XTX llama 1B Q5_1 256 pp2048 12443.72 12487.28 1.00
RX 7900 XTX llama 1B Q5_1 512 pp2048 14397.02 14466.45 1.00
RX 7900 XTX llama 1B Q5_1 1024 pp2048 16301.91 16438.11 1.01
RX 7900 XTX llama 1B Q5_K_M 64 pp2048 6959.61 6974.76 1.00
RX 7900 XTX llama 1B Q5_K_M 128 pp2048 7993.31 8064.42 1.01
RX 7900 XTX llama 1B Q5_K_M 256 pp2048 11951.11 13219.55 1.11
RX 7900 XTX llama 1B Q5_K_M 512 pp2048 13692.80 14974.54 1.09
RX 7900 XTX llama 1B Q5_K_M 1024 pp2048 15581.72 16945.35 1.09
RX 7900 XTX llama 1B Q6_K 64 pp2048 5576.08 5604.39 1.01
RX 7900 XTX llama 1B Q6_K 128 pp2048 5925.98 5955.94 1.01
RX 7900 XTX llama 1B Q6_K 256 pp2048 8985.03 14162.37 1.58
RX 7900 XTX llama 1B Q6_K 512 pp2048 10344.84 13421.27 1.30
RX 7900 XTX llama 1B Q6_K 1024 pp2048 11870.41 15930.02 1.34
RX 7900 XTX llama 1B Q8_0 64 pp2048 7481.56 7506.52 1.00
RX 7900 XTX llama 1B Q8_0 128 pp2048 10159.60 10164.84 1.00
RX 7900 XTX llama 1B Q8_0 256 pp2048 14538.35 14473.59 1.00
RX 7900 XTX llama 1B Q8_0 512 pp2048 16438.17 16492.02 1.00
RX 7900 XTX llama 1B Q8_0 1024 pp2048 18724.04 18843.78 1.01
RX 7900 XTX qwen3 14B IQ1_S - 1.5625 bpw 64 pp2048 1140.61 1141.17 1.00
RX 7900 XTX qwen3 14B IQ1_S - 1.5625 bpw 128 pp2048 1359.67 1352.72 0.99
RX 7900 XTX qwen3 14B IQ1_S - 1.5625 bpw 256 pp2048 1563.62 1597.59 1.02
RX 7900 XTX qwen3 14B IQ1_S - 1.5625 bpw 512 pp2048 1655.34 1685.87 1.02
RX 7900 XTX qwen3 14B IQ1_S - 1.5625 bpw 1024 pp2048 1753.76 1786.72 1.02
RX 7900 XTX qwen3 14B IQ2_S - 2.5 bpw 64 pp2048 1030.98 1029.35 1.00
RX 7900 XTX qwen3 14B IQ2_S - 2.5 bpw 128 pp2048 1135.86 1132.83 1.00
RX 7900 XTX qwen3 14B IQ2_S - 2.5 bpw 256 pp2048 1324.46 1455.10 1.10
RX 7900 XTX qwen3 14B IQ2_S - 2.5 bpw 512 pp2048 1391.35 1503.48 1.08
RX 7900 XTX qwen3 14B IQ2_S - 2.5 bpw 1024 pp2048 1498.16 1591.97 1.06
RX 7900 XTX qwen3 14B IQ2_XS - 2.3125 bpw 64 pp2048 1001.96 1000.14 1.00
RX 7900 XTX qwen3 14B IQ2_XS - 2.3125 bpw 128 pp2048 1099.44 1094.66 1.00
RX 7900 XTX qwen3 14B IQ2_XS - 2.3125 bpw 256 pp2048 1282.21 1465.35 1.14
RX 7900 XTX qwen3 14B IQ2_XS - 2.3125 bpw 512 pp2048 1363.45 1476.22 1.08
RX 7900 XTX qwen3 14B IQ2_XS - 2.3125 bpw 1024 pp2048 1458.91 1562.08 1.07
RX 7900 XTX qwen3 14B IQ2_XXS - 2.0625 bpw 64 pp2048 1032.14 1034.30 1.00
RX 7900 XTX qwen3 14B IQ2_XXS - 2.0625 bpw 128 pp2048 1366.24 1361.42 1.00
RX 7900 XTX qwen3 14B IQ2_XXS - 2.0625 bpw 256 pp2048 1544.25 1580.51 1.02
RX 7900 XTX qwen3 14B IQ2_XXS - 2.0625 bpw 512 pp2048 1632.63 1667.30 1.02
RX 7900 XTX qwen3 14B IQ2_XXS - 2.0625 bpw 1024 pp2048 1748.30 1779.36 1.02
RX 7900 XTX qwen3 14B IQ3_S - 3.4375 bpw 64 pp2048 1045.70 1044.43 1.00
RX 7900 XTX qwen3 14B IQ3_S - 3.4375 bpw 128 pp2048 1377.53 1374.77 1.00
RX 7900 XTX qwen3 14B IQ3_S - 3.4375 bpw 256 pp2048 1563.80 1558.34 1.00
RX 7900 XTX qwen3 14B IQ3_S - 3.4375 bpw 512 pp2048 1640.04 1648.42 1.01
RX 7900 XTX qwen3 14B IQ3_S - 3.4375 bpw 1024 pp2048 1762.73 1737.97 0.99
RX 7900 XTX qwen3 14B IQ3_XXS - 3.0625 bpw 64 pp2048 1096.26 1093.64 1.00
RX 7900 XTX qwen3 14B IQ3_XXS - 3.0625 bpw 128 pp2048 1394.21 1390.97 1.00
RX 7900 XTX qwen3 14B IQ3_XXS - 3.0625 bpw 256 pp2048 1600.20 1649.71 1.03
RX 7900 XTX qwen3 14B IQ3_XXS - 3.0625 bpw 512 pp2048 1688.29 1721.06 1.02
RX 7900 XTX qwen3 14B IQ3_XXS - 3.0625 bpw 1024 pp2048 1801.51 1794.65 1.00
RX 7900 XTX qwen3 14B IQ4_NL - 4.5 bpw 64 pp2048 1247.63 1243.49 1.00
RX 7900 XTX qwen3 14B IQ4_NL - 4.5 bpw 128 pp2048 1523.49 1522.58 1.00
RX 7900 XTX qwen3 14B IQ4_NL - 4.5 bpw 256 pp2048 1755.56 1753.71 1.00
RX 7900 XTX qwen3 14B IQ4_NL - 4.5 bpw 512 pp2048 1874.83 1872.05 1.00
RX 7900 XTX qwen3 14B IQ4_NL - 4.5 bpw 1024 pp2048 1961.00 2002.52 1.02
RX 7900 XTX qwen3 14B IQ4_XS - 4.25 bpw 64 pp2048 1262.09 1263.23 1.00
RX 7900 XTX qwen3 14B IQ4_XS - 4.25 bpw 128 pp2048 1524.40 1523.44 1.00
RX 7900 XTX qwen3 14B IQ4_XS - 4.25 bpw 256 pp2048 1751.32 1752.22 1.00
RX 7900 XTX qwen3 14B IQ4_XS - 4.25 bpw 512 pp2048 1872.45 1870.28 1.00
RX 7900 XTX qwen3 14B IQ4_XS - 4.25 bpw 1024 pp2048 1984.84 1963.28 0.99
RX 7900 XTX qwen3 14B Q2_K_M 64 pp2048 718.31 718.29 1.00
RX 7900 XTX qwen3 14B Q2_K_M 128 pp2048 952.23 951.22 1.00
RX 7900 XTX qwen3 14B Q2_K_M 256 pp2048 1060.96 1421.48 1.34
RX 7900 XTX qwen3 14B Q2_K_M 512 pp2048 1112.72 1447.63 1.30
RX 7900 XTX qwen3 14B Q2_K_M 1024 pp2048 1198.96 1623.59 1.35
RX 7900 XTX qwen3 14B Q3_K_M 64 pp2048 1079.69 1079.12 1.00
RX 7900 XTX qwen3 14B Q3_K_M 128 pp2048 1296.56 1295.70 1.00
RX 7900 XTX qwen3 14B Q3_K_M 256 pp2048 1511.63 1510.58 1.00
RX 7900 XTX qwen3 14B Q3_K_M 512 pp2048 1594.04 1583.39 0.99
RX 7900 XTX qwen3 14B Q3_K_M 1024 pp2048 1711.89 1711.06 1.00
RX 7900 XTX qwen3 14B Q4_0 64 pp2048 1201.50 1200.02 1.00
RX 7900 XTX qwen3 14B Q4_0 128 pp2048 1496.31 1495.05 1.00
RX 7900 XTX qwen3 14B Q4_0 256 pp2048 1710.85 1707.70 1.00
RX 7900 XTX qwen3 14B Q4_0 512 pp2048 1814.45 1813.08 1.00
RX 7900 XTX qwen3 14B Q4_0 1024 pp2048 1942.44 1924.26 0.99
RX 7900 XTX qwen3 14B Q4_1 64 pp2048 1211.52 1210.78 1.00
RX 7900 XTX qwen3 14B Q4_1 128 pp2048 1321.09 1319.58 1.00
RX 7900 XTX qwen3 14B Q4_1 256 pp2048 1559.01 1557.82 1.00
RX 7900 XTX qwen3 14B Q4_1 512 pp2048 1646.61 1645.66 1.00
RX 7900 XTX qwen3 14B Q4_1 1024 pp2048 1759.53 1758.46 1.00
RX 7900 XTX qwen3 14B Q4_K_M 64 pp2048 1091.35 1090.71 1.00
RX 7900 XTX qwen3 14B Q4_K_M 128 pp2048 1223.08 1222.06 1.00
RX 7900 XTX qwen3 14B Q4_K_M 256 pp2048 1463.26 1594.43 1.09
RX 7900 XTX qwen3 14B Q4_K_M 512 pp2048 1551.95 1693.88 1.09
RX 7900 XTX qwen3 14B Q4_K_M 1024 pp2048 1663.48 1776.79 1.07
RX 7900 XTX qwen3 14B Q5_0 64 pp2048 1125.17 1124.47 1.00
RX 7900 XTX qwen3 14B Q5_0 128 pp2048 1431.73 1430.83 1.00
RX 7900 XTX qwen3 14B Q5_0 256 pp2048 1637.76 1638.62 1.00
RX 7900 XTX qwen3 14B Q5_0 512 pp2048 1736.37 1737.31 1.00
RX 7900 XTX qwen3 14B Q5_0 1024 pp2048 1862.31 1860.42 1.00
RX 7900 XTX qwen3 14B Q5_1 64 pp2048 1151.87 1151.59 1.00
RX 7900 XTX qwen3 14B Q5_1 128 pp2048 1269.13 1269.21 1.00
RX 7900 XTX qwen3 14B Q5_1 256 pp2048 1492.63 1492.59 1.00
RX 7900 XTX qwen3 14B Q5_1 512 pp2048 1573.52 1574.66 1.00
RX 7900 XTX qwen3 14B Q5_1 1024 pp2048 1691.33 1692.26 1.00
RX 7900 XTX qwen3 14B Q5_K_M 64 pp2048 1095.75 1095.21 1.00
RX 7900 XTX qwen3 14B Q5_K_M 128 pp2048 1178.49 1177.27 1.00
RX 7900 XTX qwen3 14B Q5_K_M 256 pp2048 1399.05 1516.18 1.08
RX 7900 XTX qwen3 14B Q5_K_M 512 pp2048 1478.38 1608.14 1.09
RX 7900 XTX qwen3 14B Q5_K_M 1024 pp2048 1594.73 1684.33 1.06
RX 7900 XTX qwen3 14B Q6_K 64 pp2048 827.01 826.16 1.00
RX 7900 XTX qwen3 14B Q6_K 128 pp2048 829.64 829.41 1.00
RX 7900 XTX qwen3 14B Q6_K 256 pp2048 987.74 1414.90 1.43
RX 7900 XTX qwen3 14B Q6_K 512 pp2048 1036.21 1530.71 1.48
RX 7900 XTX qwen3 14B Q6_K 1024 pp2048 1114.71 1577.11 1.41
RX 7900 XTX qwen3 14B Q8_0 64 pp2048 1192.57 1196.07 1.00
RX 7900 XTX qwen3 14B Q8_0 128 pp2048 1490.71 1495.53 1.00
RX 7900 XTX qwen3 14B Q8_0 256 pp2048 1748.60 1753.73 1.00
RX 7900 XTX qwen3 14B Q8_0 512 pp2048 1854.19 1859.18 1.00
RX 7900 XTX qwen3 14B Q8_0 1024 pp2048 1974.66 1983.18 1.00
RX 7900 XTX qwen3moe 30B.A3B Q4_K_M 64 pp2048 1095.47 1100.05 1.00
RX 7900 XTX qwen3moe 30B.A3B Q4_K_M 128 pp2048 1152.03 1158.15 1.01
RX 7900 XTX qwen3moe 30B.A3B Q4_K_M 256 pp2048 1801.08 1823.55 1.01
RX 7900 XTX qwen3moe 30B.A3B Q4_K_M 512 pp2048 2523.29 2555.23 1.01
RX 7900 XTX qwen3moe 30B.A3B Q4_K_M 1024 pp2048 3331.65 3374.67 1.01
mmq/blas/tuned breakdown
model_filename model_n_params n_ubatch n_prompt mmq_ts blas_ts tuned_ts mmq/blas mmq% blas% tuned%
14B/IQ1_S.gguf 14768307200 64 2048 1140.61 557.03 1141.17 104.76 100.0 48.84 100.05
14B/IQ1_S.gguf 14768307200 128 2048 1359.67 988.76 1352.72 37.51 100.0 72.72 99.49
14B/IQ1_S.gguf 14768307200 256 2048 1563.62 1513.45 1597.59 3.31 100.0 96.79 102.17
14B/IQ1_S.gguf 14768307200 512 2048 1655.34 1485.16 1685.87 11.46 100.0 89.72 101.84
14B/IQ1_S.gguf 14768307200 1024 2048 1753.76 1566.91 1786.72 11.92 100.0 89.35 101.88
14B/IQ2_S.gguf 14768307200 64 2048 1030.98 549.34 1029.35 87.68 100.0 53.28 99.84
14B/IQ2_S.gguf 14768307200 128 2048 1135.86 977.31 1132.83 16.22 100.0 86.04 99.73
14B/IQ2_S.gguf 14768307200 256 2048 1324.46 1486.75 1455.1 -10.92 89.08 100.0 97.87
14B/IQ2_S.gguf 14768307200 512 2048 1391.35 1477.13 1503.48 -5.81 94.19 100.0 101.78
14B/IQ2_S.gguf 14768307200 1024 2048 1498.16 1552.33 1591.97 -3.49 96.51 100.0 102.55
14B/IQ2_XS.gguf 14768307200 64 2048 1001.96 550.12 1000.14 82.13 100.0 54.9 99.82
14B/IQ2_XS.gguf 14768307200 128 2048 1099.44 976.01 1094.66 12.65 100.0 88.77 99.56
14B/IQ2_XS.gguf 14768307200 256 2048 1282.21 1489.22 1465.35 -13.9 86.1 100.0 98.4
14B/IQ2_XS.gguf 14768307200 512 2048 1363.45 1492.71 1476.22 -8.66 91.34 100.0 98.9
14B/IQ2_XS.gguf 14768307200 1024 2048 1458.91 1555.26 1562.08 -6.2 93.8 100.0 100.44
14B/IQ2_XXS.gguf 14768307200 64 2048 1032.14 551.64 1034.3 87.11 100.0 53.45 100.21
14B/IQ2_XXS.gguf 14768307200 128 2048 1366.24 987.0 1361.42 38.42 100.0 72.24 99.65
14B/IQ2_XXS.gguf 14768307200 256 2048 1544.25 1501.45 1580.51 2.85 100.0 97.23 102.35
14B/IQ2_XXS.gguf 14768307200 512 2048 1632.63 1496.13 1667.3 9.12 100.0 91.64 102.12
14B/IQ2_XXS.gguf 14768307200 1024 2048 1748.3 1566.41 1779.36 11.61 100.0 89.6 101.78
14B/IQ3_S.gguf 14768307200 64 2048 1045.7 546.83 1044.43 91.23 100.0 52.29 99.88
14B/IQ3_S.gguf 14768307200 128 2048 1377.53 981.26 1374.77 40.38 100.0 71.23 99.8
14B/IQ3_S.gguf 14768307200 256 2048 1563.8 1474.86 1558.34 6.03 100.0 94.31 99.65
14B/IQ3_S.gguf 14768307200 512 2048 1640.04 1500.24 1648.42 9.32 100.0 91.48 100.51
14B/IQ3_S.gguf 14768307200 1024 2048 1762.73 1554.7 1737.97 13.38 100.0 88.2 98.59
14B/IQ3_XXS.gguf 14768307200 64 2048 1096.26 544.75 1093.64 101.24 100.0 49.69 99.76
14B/IQ3_XXS.gguf 14768307200 128 2048 1394.21 980.82 1390.97 42.15 100.0 70.35 99.77
14B/IQ3_XXS.gguf 14768307200 256 2048 1600.2 1439.52 1649.71 11.16 100.0 89.96 103.09
14B/IQ3_XXS.gguf 14768307200 512 2048 1688.29 1507.53 1721.06 11.99 100.0 89.29 101.94
14B/IQ3_XXS.gguf 14768307200 1024 2048 1801.51 1550.74 1794.65 16.17 100.0 86.08 99.62
14B/IQ4_NL.gguf 14768307200 64 2048 1247.63 542.82 1243.49 129.84 100.0 43.51 99.67
14B/IQ4_NL.gguf 14768307200 128 2048 1523.49 984.15 1522.58 54.8 100.0 64.6 99.94
14B/IQ4_NL.gguf 14768307200 256 2048 1755.56 1464.65 1753.71 19.86 100.0 83.43 99.89
14B/IQ4_NL.gguf 14768307200 512 2048 1874.83 1505.38 1872.05 24.54 100.0 80.29 99.85
14B/IQ4_NL.gguf 14768307200 1024 2048 1961.0 1558.94 2002.52 25.79 100.0 79.5 102.12
14B/IQ4_XS.gguf 14768307200 64 2048 1262.09 540.68 1263.23 133.43 100.0 42.84 100.09
14B/IQ4_XS.gguf 14768307200 128 2048 1524.4 985.87 1523.44 54.63 100.0 64.67 99.94
14B/IQ4_XS.gguf 14768307200 256 2048 1751.32 1468.94 1752.22 19.22 100.0 83.88 100.05
14B/IQ4_XS.gguf 14768307200 512 2048 1872.45 1495.29 1870.28 25.22 100.0 79.86 99.88
14B/IQ4_XS.gguf 14768307200 1024 2048 1984.84 1560.43 1963.28 27.2 100.0 78.62 98.91
14B/Q2_K.gguf 14768307200 64 2048 718.31 547.06 718.29 31.3 100.0 76.16 100.0
14B/Q2_K.gguf 14768307200 128 2048 952.23 971.48 951.22 -1.98 98.02 100.0 97.91
14B/Q2_K.gguf 14768307200 256 2048 1060.96 1445.4 1421.48 -26.6 73.4 100.0 98.35
14B/Q2_K.gguf 14768307200 512 2048 1112.72 1483.19 1447.63 -24.98 75.02 100.0 97.6
14B/Q2_K.gguf 14768307200 1024 2048 1198.96 1553.28 1623.59 -22.81 77.19 100.0 104.53
14B/Q3_K.gguf 14768307200 64 2048 1079.69 573.68 1079.12 88.2 100.0 53.13 99.95
14B/Q3_K.gguf 14768307200 128 2048 1296.56 1005.21 1295.7 28.98 100.0 77.53 99.93
14B/Q3_K.gguf 14768307200 256 2048 1511.63 1386.17 1510.58 9.05 100.0 91.7 99.93
14B/Q3_K.gguf 14768307200 512 2048 1594.04 1521.6 1583.39 4.76 100.0 95.46 99.33
14B/Q3_K.gguf 14768307200 1024 2048 1711.89 1566.33 1711.06 9.29 100.0 91.5 99.95
14B/Q4_0.gguf 14768307200 64 2048 1201.5 548.27 1200.02 119.15 100.0 45.63 99.88
14B/Q4_0.gguf 14768307200 128 2048 1496.31 995.61 1495.05 50.29 100.0 66.54 99.92
14B/Q4_0.gguf 14768307200 256 2048 1710.85 1475.79 1707.7 15.93 100.0 86.26 99.82
14B/Q4_0.gguf 14768307200 512 2048 1814.45 1523.29 1813.08 19.11 100.0 83.95 99.92
14B/Q4_0.gguf 14768307200 1024 2048 1942.44 1565.44 1924.26 24.08 100.0 80.59 99.06
14B/Q4_1.gguf 14768307200 64 2048 1211.52 545.63 1210.78 122.04 100.0 45.04 99.94
14B/Q4_1.gguf 14768307200 128 2048 1321.09 994.21 1319.58 32.88 100.0 75.26 99.89
14B/Q4_1.gguf 14768307200 256 2048 1559.01 1466.78 1557.82 6.29 100.0 94.08 99.92
14B/Q4_1.gguf 14768307200 512 2048 1646.61 1527.81 1645.66 7.78 100.0 92.79 99.94
14B/Q4_1.gguf 14768307200 1024 2048 1759.53 1564.76 1758.46 12.45 100.0 88.93 99.94
14B/Q4_K.gguf 14768307200 64 2048 1091.35 547.94 1090.71 99.17 100.0 50.21 99.94
14B/Q4_K.gguf 14768307200 128 2048 1223.08 990.97 1222.06 23.42 100.0 81.02 99.92
14B/Q4_K.gguf 14768307200 256 2048 1463.26 1443.11 1594.43 1.4 100.0 98.62 108.96
14B/Q4_K.gguf 14768307200 512 2048 1551.95 1532.37 1693.88 1.28 100.0 98.74 109.15
14B/Q4_K.gguf 14768307200 1024 2048 1663.48 1567.9 1776.79 6.1 100.0 94.25 106.81
14B/Q5_0.gguf 14768307200 64 2048 1125.17 509.11 1124.47 121.01 100.0 45.25 99.94
14B/Q5_0.gguf 14768307200 128 2048 1431.73 890.35 1430.83 60.81 100.0 62.19 99.94
14B/Q5_0.gguf 14768307200 256 2048 1637.76 1255.97 1638.62 30.4 100.0 76.69 100.05
14B/Q5_0.gguf 14768307200 512 2048 1736.37 1462.34 1737.31 18.74 100.0 84.22 100.05
14B/Q5_0.gguf 14768307200 1024 2048 1862.31 1524.64 1860.42 22.15 100.0 81.87 99.9
14B/Q5_1.gguf 14768307200 64 2048 1151.87 505.03 1151.59 128.08 100.0 43.84 99.98
14B/Q5_1.gguf 14768307200 128 2048 1269.13 891.1 1269.21 42.42 100.0 70.21 100.01
14B/Q5_1.gguf 14768307200 256 2048 1492.63 1260.43 1492.59 18.42 100.0 84.44 100.0
14B/Q5_1.gguf 14768307200 512 2048 1573.52 1460.43 1574.66 7.74 100.0 92.81 100.07
14B/Q5_1.gguf 14768307200 1024 2048 1691.33 1509.13 1692.26 12.07 100.0 89.23 100.06
14B/Q5_K.gguf 14768307200 64 2048 1095.75 533.91 1095.21 105.23 100.0 48.73 99.95
14B/Q5_K.gguf 14768307200 128 2048 1178.49 964.74 1177.27 22.16 100.0 81.86 99.9
14B/Q5_K.gguf 14768307200 256 2048 1399.05 1376.83 1516.18 1.61 100.0 98.41 108.37
14B/Q5_K.gguf 14768307200 512 2048 1478.38 1517.64 1608.14 -2.59 97.41 100.0 105.96
14B/Q5_K.gguf 14768307200 1024 2048 1594.73 1558.64 1684.33 2.32 100.0 97.74 105.62
14B/Q6_K.gguf 14768307200 64 2048 827.01 526.74 826.16 57.01 100.0 63.69 99.9
14B/Q6_K.gguf 14768307200 128 2048 829.64 972.01 829.41 -14.65 85.35 100.0 85.33
14B/Q6_K.gguf 14768307200 256 2048 987.74 1412.01 1414.9 -30.05 69.95 100.0 100.2
14B/Q6_K.gguf 14768307200 512 2048 1036.21 1527.21 1530.71 -32.15 67.85 100.0 100.23
14B/Q6_K.gguf 14768307200 1024 2048 1114.71 1556.72 1577.11 -28.39 71.61 100.0 101.31
14B/Q8_0.gguf 14768307200 64 2048 1192.57 503.44 1196.07 136.89 100.0 42.21 100.29
14B/Q8_0.gguf 14768307200 128 2048 1490.71 933.45 1495.53 59.7 100.0 62.62 100.32
14B/Q8_0.gguf 14768307200 256 2048 1748.6 1384.78 1753.73 26.27 100.0 79.19 100.29
14B/Q8_0.gguf 14768307200 512 2048 1854.19 1485.42 1859.18 24.83 100.0 80.11 100.27
14B/Q8_0.gguf 14768307200 1024 2048 1974.66 1558.26 1983.18 26.72 100.0 78.91 100.43
MOE/128e.gguf 30532122624 64 2048 1095.47 388.54 1100.05 181.95 100.0 35.47 100.42
MOE/128e.gguf 30532122624 128 2048 1152.03 608.09 1158.15 89.45 100.0 52.78 100.53
MOE/128e.gguf 30532122624 256 2048 1801.08 952.55 1823.55 89.08 100.0 52.89 101.25
MOE/128e.gguf 30532122624 512 2048 2523.29 1342.15 2555.23 88.0 100.0 53.19 101.27
MOE/128e.gguf 30532122624 1024 2048 3331.65 1724.47 3374.67 93.2 100.0 51.76 101.29
MOE/32e.gguf 20914757184 64 2048 1526.87 1262.36 1533.15 20.95 100.0 82.68 100.41
MOE/32e.gguf 20914757184 128 2048 1736.72 1913.88 1744.08 -9.26 90.74 100.0 91.13
MOE/32e.gguf 20914757184 256 2048 2670.39 2701.05 2723.63 -1.14 98.86 100.0 100.84
MOE/32e.gguf 20914757184 512 2048 3844.0 3490.39 3857.77 10.13 100.0 90.8 100.36
MOE/32e.gguf 20914757184 1024 2048 4767.12 4317.11 4776.7 10.42 100.0 90.56 100.2
MOE/64e.gguf 6939037248 64 2048 1996.18 660.28 2051.94 202.32 100.0 33.08 102.79
MOE/64e.gguf 6939037248 128 2048 2363.13 1104.21 2458.62 114.01 100.0 46.73 104.04
MOE/64e.gguf 6939037248 256 2048 3752.94 1663.58 3848.7 125.59 100.0 44.33 102.55
MOE/64e.gguf 6939037248 512 2048 5144.1 2498.93 5239.24 105.85 100.0 48.58 101.85
MOE/64e.gguf 6939037248 1024 2048 5951.86 3331.43 6029.21 78.66 100.0 55.97 101.3

Edge Cases

When excluding the 1B model, which is noisy, there are exactly two outliers:

model_filename model_n_params n_ubatch n_prompt mmq_ts blas_ts tuned_ts mmq/blas mmq% blas% tuned%
MOE/32e.gguf 20914757184 128 2048 1736.72 1913.88 1744.08 -9.26 90.74 100.0 91.13
14B/Q6_K.gguf 14768307200 128 2048 829.64 972.01 829.41 -14.65 85.35 100.0 85.33

Both occur at bs=128, and both quickly flip below 128. If you wanted to fudge this case, you could probably do something like

diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu
index ccb9ebed5..b7c1b7dc2 100644
--- a/ggml/src/ggml-cuda/mmq.cu
+++ b/ggml/src/ggml-cuda/mmq.cu
@@ -344,6 +344,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t
             // These quants are really bad on MMQ
             case GGML_TYPE_Q2_K:
             case GGML_TYPE_Q6_K:
+                return ne11 < 128;  // ==128 specifically is much better on rocblas
             // These quants are usually worse but not always
             case GGML_TYPE_IQ2_XS:
             case GGML_TYPE_IQ2_S:

to get an average tuned% > 99, but that might be considered splitting hairs.

Testing Setup

100% of my testing was on GFX1100, ROCm 7.1.1, with compile flags

GGML_CUDA_FA_ALL_QUANTS=ON
GGML_HIP=ON
GGML_HIP_GRAPHS=ON

and MMQ or cuBLAS forced as appropriate for measuring.

I do not own any RDNA3.5 or RDNA4 hardware. I'm assuming RDNA3.5 will behave much the same, but since RDNA4 has some implementation differences in MMQ, it may be worth re-measuring there in the future.
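For reference, a build configuration along these lines should reproduce the setup. This is a sketch: the exact invocation, and the GGML_CUDA_FORCE_MMQ / GGML_CUDA_FORCE_CUBLAS options used to pin one backend per measurement run, are my assumptions rather than quoted from the PR.

```shell
# Sketch: configure llama.cpp with the flags from the description above.
cmake -B build \
    -DGGML_HIP=ON \
    -DGGML_HIP_GRAPHS=ON \
    -DGGML_CUDA_FA_ALL_QUANTS=ON
cmake --build build --config Release -j

# Assumed way to force one backend when measuring mmq_ts vs blas_ts:
cmake -B build-mmq  -DGGML_HIP=ON -DGGML_CUDA_FORCE_MMQ=ON
cmake -B build-blas -DGGML_HIP=ON -DGGML_CUDA_FORCE_CUBLAS=ON
```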

Methodology

The raw data for every combination of model / batch / backend can be viewed in measurements.csv, which was generated by the scripts on Hugging Face.
Unlike the other PRs, I made MMQ the baseline, as it seems to handle most cases better.
In general, I put little weight on the 1B results, as they are extremely noisy even with HIP graphs. For cases that were a wash across µbatch sizes, like Q4_K and Q5_K, I simply preferred MMQ.

@Beinsezii Beinsezii left a comment

Overall I think this should be satisfactory in all cases.

@IMbackK, in #14949 you mentioned doing more comprehensive tests for CDNA in the future. It won't be 1:1, but this could possibly apply there too. At the very least, I think the current CDNA code might be over-prioritizing rocBLAS with the default branch being false, but I have no way to test this.

Comment on lines +336 to +341
// High expert counts almost always better on MMQ
// due to a large amount of graph splits
// https://github.com/ggml-org/llama.cpp/pull/18202
if (n_experts >= 64) {
    return true;
}
@Beinsezii Beinsezii Jan 2, 2026

32 experts for GPT-OSS 20B was also sort of a wash, but I'm assuming non-FP4 models will benefit more from the quant cases instead. There aren't many MoEs I can run fully in VRAM, so testing this was limited.

@Beinsezii Beinsezii

Thinking more: since the default case here is true anyway, this branch might not actually be needed. I think the reason it was added in #18202 is that the default case there is false, which leads to most quants running rocBLAS by default.

The worst case would be a highly sparse MoE running at Q6_K or Q2_K, so I can probably test this specifically this weekend using GPT-OSS 20B and Qwen 30B for 32 and 128 experts respectively.

Comment on lines +343 to +353
switch (type) {
    // These quants are really bad on MMQ
    case GGML_TYPE_Q2_K:
    case GGML_TYPE_Q6_K:
    // These quants are usually worse but not always
    case GGML_TYPE_IQ2_XS:
    case GGML_TYPE_IQ2_S:
        return ne11 <= 128;
    default:
        return true;
}
@Beinsezii Beinsezii

The reason for the elaborate dataset is that I was expecting this switch to be a lot larger, but I think most of the problem is just that Q2_K and Q6_K perform unusually poorly on MMQ, typically by >30%.

github-actions bot added the labels "Nvidia GPU" (issues specific to Nvidia GPUs) and "ggml" (changes relating to the ggml tensor library for machine learning) on Jan 2, 2026
@Beinsezii Beinsezii

One particularly fun case: on models with heterogeneous mmq/rocblas layers like https://huggingface.co/Beinsezii/Mistral-Small-3.2-24B-Instruct-2506-Q6F-Q8A-GGUF, this branch is faster than either forced mmq or cublas on its own.

@Beinsezii Beinsezii changed the title mmq.cu: tune mmq/wmma switching for RDNA mmq.cu: tune mmq/rocblas switching for RDNA Jan 2, 2026
@Beinsezii Beinsezii

Added a compare-llama-bench-formatted version of the results per request.

I did not include quants outside the CUDA MMQ scope. A variety of models is demonstrated, but they do not include Llama3-8B specifically; it can be added if needed.


IMbackK commented Jan 2, 2026

Some things to consider here:
AMD has two kernel generators for GEMM, TensileLite and Tensile.

  1. For low-precision types TensileLite usually wins on all supported platforms except gfx908 and gfx90a
  2. For fp32 Tensile usually wins
  3. rocBLAS uses Tensile unless you choose TensileLite via its API or via an envvar
  4. For gfx12 TensileLite is used regardless

This PR assumes that gfx11 on Tensile has performance equivalent to gfx12 on TensileLite, which I find unlikely from CDNA experience.

CDNA1/2 have well-tuned kernels in Tensile, while RDNA GPUs are less well tuned there. AFAIK the TensileLite kernels are reasonably well tuned for gfx11 (only the large-register-file versions like gfx1100) and gfx12. This is why the rocBLAS path wins more often on CDNA.

Since this PR compares MMQ against gfx11 on Tensile, which is a known slower path, it might make sense to compare against TensileLite too, or to switch gfx11 to TensileLite like gfx12 and then choose MMQ where that is slower.

Compared to CDNA, on RDNA you also have to contend with devices that often have very little VRAM, while CDNA can't have less than 32 GB (MMQ saves the dequant buffers).


IMbackK commented Jan 2, 2026

Regardless, this PR is an overall improvement, and the above can be investigated at a later time.

@JohannesGaessler

Let me be frank here: As of right now I don't have the bandwidth to be micromanaging the AMD kernel selection logic to this degree. The kernel generator selection is something that rocBLAS should be deciding automatically. Alternatively I would like to outsource the corresponding logic in llama.cpp/ggml to someone else who would be available to maintain it long-term.


IMbackK commented Jan 2, 2026

I already see this bit as something in my area. I also have a gfx11 (gfx1100) device on the way, so there's no need for you to do anything aside from merging the reviewed PRs, since I lost write access. If the selection logic gets long, which it might if we also start selecting between Tensile and TensileLite, we can also move it to ggml-hip to get it out of your way.

Anyhow, this is something for later.


IMbackK commented Jan 2, 2026

And yeah, the whole situation with the terrible internal kernel selection in rocBLAS and hipBLASLt is quite annoying; the selection of the "correct" kernel generator is only the tip of the iceberg in this regard.

@Beinsezii
Contributor Author

Beinsezii commented Jan 2, 2026

Let me be frank here: As of right now I don't have the bandwidth to be micromanaging the AMD kernel selection logic to this degree.

Given the nonlinear quants are typically <+10%, if we really needed to this PR could be simplified to just

    if (amd_wmma_available(cc)) {
        return !(ne11 >= 128 && (type == GGML_TYPE_Q2_K || type == GGML_TYPE_Q6_K));
    }

as I think switching on at least these two is mandatory, since all improvements there are >+30%

@Beinsezii
Contributor Author

  1. Rocblas uses tensile unless you choose tensilelite via its api or via envvar

ROCBLAS_USE_HIPBLASLT=1? Last time I tried a few months ago this crashed for gfx1100, but I'm not opposed to re-running the bench if it works again. Since it's at runtime it should be easy.

@JohannesGaessler
Contributor

The logic as it is in this PR is completely fine. My issue specifically has to do with potentially having to juggle multiple different rocBLAS versions with multiple optional environment variables. I think going forward I will review and maintain the kernel selection logic to align with the newest ROCm version at the default settings, and any other setup will need to be maintained by someone else.

@Beinsezii
Contributor Author

I think going forward I will review and maintain the kernel selection logic to align with the newest ROCm version at the default settings

That's pretty much what I went for here. ROCm 6.4 had slightly faster PP for me back a few months ago but I'm not interested in redownloading 50 gigs of outdated libs just to see if there's maybe one or two more cases where rocblas is slightly faster than mmq on the old version.

HIP_GRAPHS I found made effectively 0% difference for pp, but for small/sparse models it's about +10% tg, so I always have it on.

I think as it stands this should only need revisiting with major revisions to mmq or rocblas, and with significantly less benchmarking needed at that point. Since RDNA support in MMQ is pretty fresh, I'm hoping that as more work is done there it eventually just eats up these last 4 cases and we delete this whole block anyway.

@IMbackK
Collaborator

IMbackK commented Jan 2, 2026

  1. Rocblas uses tensile unless you choose tensilelite via its api or via envvar

ROCBLAS_USE_HIPBLASLT=1? Last time I tried a few months ago this crashed for gfx1100, but I'm not opposed to re-running the bench if it works again. Since it's at runtime it should be easy.

For testing, yes, but you can hint rocblas which kernels to use via its C API, which we could use to make it use the tensilelite kernels where they are faster. Sometimes they are a lot faster (like 50%+). Anyhow, this is something for another time.

@Beinsezii
Contributor Author

Beinsezii commented Jan 2, 2026

For testing, yes

On gfx1100 it looks like ROCBLAS_USE_HIPBLASLT=1 makes exactly zero difference vs unset. I wonder if they just disabled that path on gfx11XX to stop the faults? Either that, or it's already the default now in 7.1.

Probably I'll just do a re-measure if someone ends up adding the kernel hinting through C

@IMbackK
Collaborator

IMbackK commented Jan 2, 2026

For testing, yes

On gfx1100 it looks like ROCBLAS_USE_HIPBLASLT=1 makes exactly zero difference vs unset. I wonder if they just disabled that path on gfx11XX to stop the faults? Either that, or it's already the default now in 7.1.

Probably I'll just do a re-measure if someone ends up adding the kernel hinting through C

If it's the default now, there is something we have to do to get the best performance: unlike tensile, tensilelite supports V_WMMA_F32_16X16X16_F16. Currently, because tensile only issues V_WMMA_F16_16X16X16_F16, we accumulate at fp16 and upconvert after, which is stupid from a performance perspective on this hardware and causes extra issues with overflow.

Anyhow, as stated, I recently ordered a gfx11 device and will get around to doing some optimization work on it in the near future.

@JohannesGaessler
Contributor

RDNA4 performance
GPU Model Microbatch size Test t/s be47fb9 t/s a435c77 Speedup
RX 9060 XT llama 8B IQ1_S - 1.5625 bpw 1 pp2048 68.69 68.82 1.00
RX 9060 XT llama 8B IQ1_S - 1.5625 bpw 2 pp2048 120.09 120.13 1.00
RX 9060 XT llama 8B IQ1_S - 1.5625 bpw 4 pp2048 184.47 184.68 1.00
RX 9060 XT llama 8B IQ1_S - 1.5625 bpw 8 pp2048 205.58 205.62 1.00
RX 9060 XT llama 8B IQ1_S - 1.5625 bpw 16 pp2048 569.29 564.61 0.99
RX 9060 XT llama 8B IQ1_S - 1.5625 bpw 32 pp2048 747.23 744.42 1.00
RX 9060 XT llama 8B IQ1_S - 1.5625 bpw 64 pp2048 1253.65 1254.23 1.00
RX 9060 XT llama 8B IQ1_S - 1.5625 bpw 128 pp2048 1703.77 1677.91 0.98
RX 9060 XT llama 8B IQ1_S - 1.5625 bpw 256 pp2048 1858.69 1615.75 0.87
RX 9060 XT llama 8B IQ1_S - 1.5625 bpw 512 pp2048 1888.65 1653.46 0.88
RX 9060 XT llama 8B IQ1_S - 1.5625 bpw 1024 pp2048 1921.45 1679.51 0.87
RX 9060 XT llama 8B IQ1_S - 1.5625 bpw 2048 pp2048 1897.64 1644.51 0.87
RX 9060 XT llama 8B IQ2_S - 2.5 bpw 1 pp2048 57.19 57.14 1.00
RX 9060 XT llama 8B IQ2_S - 2.5 bpw 2 pp2048 98.20 98.17 1.00
RX 9060 XT llama 8B IQ2_S - 2.5 bpw 4 pp2048 150.44 150.49 1.00
RX 9060 XT llama 8B IQ2_S - 2.5 bpw 8 pp2048 204.61 204.58 1.00
RX 9060 XT llama 8B IQ2_S - 2.5 bpw 16 pp2048 333.69 334.65 1.00
RX 9060 XT llama 8B IQ2_S - 2.5 bpw 32 pp2048 640.28 640.19 1.00
RX 9060 XT llama 8B IQ2_S - 2.5 bpw 64 pp2048 1120.56 1121.55 1.00
RX 9060 XT llama 8B IQ2_S - 2.5 bpw 128 pp2048 1467.23 1466.47 1.00
RX 9060 XT llama 8B IQ2_S - 2.5 bpw 256 pp2048 1587.87 373.67 0.24
RX 9060 XT llama 8B IQ2_S - 2.5 bpw 512 pp2048 1623.16 385.20 0.24
RX 9060 XT llama 8B IQ2_S - 2.5 bpw 1024 pp2048 1647.19 393.93 0.24
RX 9060 XT llama 8B IQ2_S - 2.5 bpw 2048 pp2048 1621.33 394.44 0.24
RX 9060 XT llama 8B IQ2_XS - 2.3125 bpw 1 pp2048 59.29 59.16 1.00
RX 9060 XT llama 8B IQ2_XS - 2.3125 bpw 2 pp2048 101.31 101.17 1.00
RX 9060 XT llama 8B IQ2_XS - 2.3125 bpw 4 pp2048 151.40 151.28 1.00
RX 9060 XT llama 8B IQ2_XS - 2.3125 bpw 8 pp2048 205.95 205.68 1.00
RX 9060 XT llama 8B IQ2_XS - 2.3125 bpw 16 pp2048 328.24 327.06 1.00
RX 9060 XT llama 8B IQ2_XS - 2.3125 bpw 32 pp2048 634.16 633.63 1.00
RX 9060 XT llama 8B IQ2_XS - 2.3125 bpw 64 pp2048 1099.59 1098.50 1.00
RX 9060 XT llama 8B IQ2_XS - 2.3125 bpw 128 pp2048 1407.26 1406.07 1.00
RX 9060 XT llama 8B IQ2_XS - 2.3125 bpw 256 pp2048 1527.93 338.65 0.22
RX 9060 XT llama 8B IQ2_XS - 2.3125 bpw 512 pp2048 1558.01 350.58 0.23
RX 9060 XT llama 8B IQ2_XS - 2.3125 bpw 1024 pp2048 1588.71 356.74 0.22
RX 9060 XT llama 8B IQ2_XS - 2.3125 bpw 2048 pp2048 1575.42 358.88 0.23
RX 9060 XT llama 8B IQ2_XXS - 2.0625 bpw 1 pp2048 49.50 49.52 1.00
RX 9060 XT llama 8B IQ2_XXS - 2.0625 bpw 2 pp2048 88.52 88.46 1.00
RX 9060 XT llama 8B IQ2_XXS - 2.0625 bpw 4 pp2048 148.21 148.44 1.00
RX 9060 XT llama 8B IQ2_XXS - 2.0625 bpw 8 pp2048 177.61 177.97 1.00
RX 9060 XT llama 8B IQ2_XXS - 2.0625 bpw 16 pp2048 463.33 461.38 1.00
RX 9060 XT llama 8B IQ2_XXS - 2.0625 bpw 32 pp2048 635.31 635.33 1.00
RX 9060 XT llama 8B IQ2_XXS - 2.0625 bpw 64 pp2048 1178.84 1187.44 1.01
RX 9060 XT llama 8B IQ2_XXS - 2.0625 bpw 128 pp2048 1710.56 1705.17 1.00
RX 9060 XT llama 8B IQ2_XXS - 2.0625 bpw 256 pp2048 1865.63 1604.23 0.86
RX 9060 XT llama 8B IQ2_XXS - 2.0625 bpw 512 pp2048 1897.50 1642.77 0.87
RX 9060 XT llama 8B IQ2_XXS - 2.0625 bpw 1024 pp2048 1936.72 1675.52 0.87
RX 9060 XT llama 8B IQ2_XXS - 2.0625 bpw 2048 pp2048 1916.88 1661.18 0.87
RX 9060 XT llama 8B IQ3_S - 3.4375 bpw 1 pp2048 46.07 46.03 1.00
RX 9060 XT llama 8B IQ3_S - 3.4375 bpw 2 pp2048 85.77 85.63 1.00
RX 9060 XT llama 8B IQ3_S - 3.4375 bpw 4 pp2048 146.27 145.75 1.00
RX 9060 XT llama 8B IQ3_S - 3.4375 bpw 8 pp2048 178.68 178.46 1.00
RX 9060 XT llama 8B IQ3_S - 3.4375 bpw 16 pp2048 447.56 448.15 1.00
RX 9060 XT llama 8B IQ3_S - 3.4375 bpw 32 pp2048 628.32 631.06 1.00
RX 9060 XT llama 8B IQ3_S - 3.4375 bpw 64 pp2048 1187.46 1177.83 0.99
RX 9060 XT llama 8B IQ3_S - 3.4375 bpw 128 pp2048 1762.23 1757.89 1.00
RX 9060 XT llama 8B IQ3_S - 3.4375 bpw 256 pp2048 1900.47 1899.32 1.00
RX 9060 XT llama 8B IQ3_S - 3.4375 bpw 512 pp2048 1938.78 1935.11 1.00
RX 9060 XT llama 8B IQ3_S - 3.4375 bpw 1024 pp2048 1970.33 1963.70 1.00
RX 9060 XT llama 8B IQ3_S - 3.4375 bpw 2048 pp2048 1932.46 1930.39 1.00
RX 9060 XT llama 8B IQ3_S mix - 3.66 bpw 1 pp2048 45.90 46.01 1.00
RX 9060 XT llama 8B IQ3_S mix - 3.66 bpw 2 pp2048 84.59 85.00 1.00
RX 9060 XT llama 8B IQ3_S mix - 3.66 bpw 4 pp2048 143.48 143.85 1.00
RX 9060 XT llama 8B IQ3_S mix - 3.66 bpw 8 pp2048 174.65 175.03 1.00
RX 9060 XT llama 8B IQ3_S mix - 3.66 bpw 16 pp2048 453.40 454.11 1.00
RX 9060 XT llama 8B IQ3_S mix - 3.66 bpw 32 pp2048 643.74 644.38 1.00
RX 9060 XT llama 8B IQ3_S mix - 3.66 bpw 64 pp2048 1196.28 1199.68 1.00
RX 9060 XT llama 8B IQ3_S mix - 3.66 bpw 128 pp2048 1751.81 1756.49 1.00
RX 9060 XT llama 8B IQ3_S mix - 3.66 bpw 256 pp2048 1892.27 1899.63 1.00
RX 9060 XT llama 8B IQ3_S mix - 3.66 bpw 512 pp2048 1931.79 1936.89 1.00
RX 9060 XT llama 8B IQ3_S mix - 3.66 bpw 1024 pp2048 1963.12 1967.22 1.00
RX 9060 XT llama 8B IQ3_S mix - 3.66 bpw 2048 pp2048 1921.56 1932.94 1.01
RX 9060 XT llama 8B IQ3_XS - 3.3 bpw 1 pp2048 51.76 51.55 1.00
RX 9060 XT llama 8B IQ3_XS - 3.3 bpw 2 pp2048 93.53 93.13 1.00
RX 9060 XT llama 8B IQ3_XS - 3.3 bpw 4 pp2048 149.03 148.21 0.99
RX 9060 XT llama 8B IQ3_XS - 3.3 bpw 8 pp2048 180.44 180.01 1.00
RX 9060 XT llama 8B IQ3_XS - 3.3 bpw 16 pp2048 474.93 474.19 1.00
RX 9060 XT llama 8B IQ3_XS - 3.3 bpw 32 pp2048 658.92 657.57 1.00
RX 9060 XT llama 8B IQ3_XS - 3.3 bpw 64 pp2048 1218.19 1216.75 1.00
RX 9060 XT llama 8B IQ3_XS - 3.3 bpw 128 pp2048 1810.12 1797.37 0.99
RX 9060 XT llama 8B IQ3_XS - 3.3 bpw 256 pp2048 1959.46 1947.81 0.99
RX 9060 XT llama 8B IQ3_XS - 3.3 bpw 512 pp2048 2004.47 1992.24 0.99
RX 9060 XT llama 8B IQ3_XS - 3.3 bpw 1024 pp2048 2038.21 2026.73 0.99
RX 9060 XT llama 8B IQ3_XS - 3.3 bpw 2048 pp2048 1998.51 1988.25 0.99
RX 9060 XT llama 8B IQ3_XXS - 3.0625 bpw 1 pp2048 54.13 54.40 1.01
RX 9060 XT llama 8B IQ3_XXS - 3.0625 bpw 2 pp2048 94.62 95.23 1.01
RX 9060 XT llama 8B IQ3_XXS - 3.0625 bpw 4 pp2048 146.78 147.84 1.01
RX 9060 XT llama 8B IQ3_XXS - 3.0625 bpw 8 pp2048 182.08 182.64 1.00
RX 9060 XT llama 8B IQ3_XXS - 3.0625 bpw 16 pp2048 461.53 462.13 1.00
RX 9060 XT llama 8B IQ3_XXS - 3.0625 bpw 32 pp2048 673.79 675.92 1.00
RX 9060 XT llama 8B IQ3_XXS - 3.0625 bpw 64 pp2048 1226.37 1256.54 1.02
RX 9060 XT llama 8B IQ3_XXS - 3.0625 bpw 128 pp2048 1780.26 1793.39 1.01
RX 9060 XT llama 8B IQ3_XXS - 3.0625 bpw 256 pp2048 1937.10 1378.50 0.71
RX 9060 XT llama 8B IQ3_XXS - 3.0625 bpw 512 pp2048 1984.06 1410.04 0.71
RX 9060 XT llama 8B IQ3_XXS - 3.0625 bpw 1024 pp2048 2017.33 1441.88 0.71
RX 9060 XT llama 8B IQ3_XXS - 3.0625 bpw 2048 pp2048 1982.60 1431.47 0.72
RX 9060 XT llama 8B IQ4_NL - 4.5 bpw 1 pp2048 48.54 48.63 1.00
RX 9060 XT llama 8B IQ4_NL - 4.5 bpw 2 pp2048 92.22 92.32 1.00
RX 9060 XT llama 8B IQ4_NL - 4.5 bpw 4 pp2048 168.12 168.41 1.00
RX 9060 XT llama 8B IQ4_NL - 4.5 bpw 8 pp2048 193.58 193.93 1.00
RX 9060 XT llama 8B IQ4_NL - 4.5 bpw 16 pp2048 561.51 561.91 1.00
RX 9060 XT llama 8B IQ4_NL - 4.5 bpw 32 pp2048 782.72 783.87 1.00
RX 9060 XT llama 8B IQ4_NL - 4.5 bpw 64 pp2048 1389.64 1375.05 0.99
RX 9060 XT llama 8B IQ4_NL - 4.5 bpw 128 pp2048 1956.00 1994.64 1.02
RX 9060 XT llama 8B IQ4_NL - 4.5 bpw 256 pp2048 2169.17 2172.63 1.00
RX 9060 XT llama 8B IQ4_NL - 4.5 bpw 512 pp2048 2215.35 2217.34 1.00
RX 9060 XT llama 8B IQ4_NL - 4.5 bpw 1024 pp2048 2258.26 2262.33 1.00
RX 9060 XT llama 8B IQ4_NL - 4.5 bpw 2048 pp2048 2216.37 2219.76 1.00
RX 9060 XT llama 8B IQ4_XS - 4.25 bpw 1 pp2048 50.52 50.47 1.00
RX 9060 XT llama 8B IQ4_XS - 4.25 bpw 2 pp2048 96.60 96.69 1.00
RX 9060 XT llama 8B IQ4_XS - 4.25 bpw 4 pp2048 179.98 179.72 1.00
RX 9060 XT llama 8B IQ4_XS - 4.25 bpw 8 pp2048 208.33 207.92 1.00
RX 9060 XT llama 8B IQ4_XS - 4.25 bpw 16 pp2048 595.02 594.96 1.00
RX 9060 XT llama 8B IQ4_XS - 4.25 bpw 32 pp2048 797.95 797.78 1.00
RX 9060 XT llama 8B IQ4_XS - 4.25 bpw 64 pp2048 1437.76 1436.09 1.00
RX 9060 XT llama 8B IQ4_XS - 4.25 bpw 128 pp2048 2035.66 2041.45 1.00
RX 9060 XT llama 8B IQ4_XS - 4.25 bpw 256 pp2048 2238.64 2228.77 1.00
RX 9060 XT llama 8B IQ4_XS - 4.25 bpw 512 pp2048 2281.43 2271.05 1.00
RX 9060 XT llama 8B IQ4_XS - 4.25 bpw 1024 pp2048 2322.52 2316.75 1.00
RX 9060 XT llama 8B IQ4_XS - 4.25 bpw 2048 pp2048 2285.98 2276.98 1.00
RX 9060 XT llama 8B Q2_K_S 1 pp2048 67.29 67.27 1.00
RX 9060 XT llama 8B Q2_K_S 2 pp2048 105.20 104.06 0.99
RX 9060 XT llama 8B Q2_K_S 4 pp2048 128.24 126.94 0.99
RX 9060 XT llama 8B Q2_K_S 8 pp2048 142.94 141.15 0.99
RX 9060 XT llama 8B Q2_K_S 16 pp2048 347.71 345.56 0.99
RX 9060 XT llama 8B Q2_K_S 32 pp2048 462.14 459.29 0.99
RX 9060 XT llama 8B Q2_K_S 64 pp2048 740.70 738.53 1.00
RX 9060 XT llama 8B Q2_K_S 128 pp2048 905.54 899.19 0.99
RX 9060 XT llama 8B Q2_K_S 256 pp2048 978.28 348.91 0.36
RX 9060 XT llama 8B Q2_K_S 512 pp2048 964.15 361.71 0.38
RX 9060 XT llama 8B Q2_K_S 1024 pp2048 1022.04 369.67 0.36
RX 9060 XT llama 8B Q2_K_S 2048 pp2048 1015.14 371.65 0.37
RX 9060 XT llama 8B Q3_K_S 1 pp2048 50.23 50.03 1.00
RX 9060 XT llama 8B Q3_K_S 2 pp2048 83.84 83.80 1.00
RX 9060 XT llama 8B Q3_K_S 4 pp2048 119.53 119.01 1.00
RX 9060 XT llama 8B Q3_K_S 8 pp2048 143.45 142.72 0.99
RX 9060 XT llama 8B Q3_K_S 16 pp2048 456.60 455.65 1.00
RX 9060 XT llama 8B Q3_K_S 32 pp2048 666.93 662.43 0.99
RX 9060 XT llama 8B Q3_K_S 64 pp2048 1142.25 1141.32 1.00
RX 9060 XT llama 8B Q3_K_S 128 pp2048 1617.95 1608.82 0.99
RX 9060 XT llama 8B Q3_K_S 256 pp2048 1740.89 1714.79 0.99
RX 9060 XT llama 8B Q3_K_S 512 pp2048 1775.65 1766.64 0.99
RX 9060 XT llama 8B Q3_K_S 1024 pp2048 1809.70 1800.47 0.99
RX 9060 XT llama 8B Q3_K_S 2048 pp2048 1796.37 1784.52 0.99
RX 9060 XT llama 8B Q4_0 1 pp2048 48.80 48.83 1.00
RX 9060 XT llama 8B Q4_0 2 pp2048 92.57 92.57 1.00
RX 9060 XT llama 8B Q4_0 4 pp2048 169.50 169.56 1.00
RX 9060 XT llama 8B Q4_0 8 pp2048 200.24 200.41 1.00
RX 9060 XT llama 8B Q4_0 16 pp2048 554.36 552.12 1.00
RX 9060 XT llama 8B Q4_0 32 pp2048 758.38 758.83 1.00
RX 9060 XT llama 8B Q4_0 64 pp2048 1355.11 1349.07 1.00
RX 9060 XT llama 8B Q4_0 128 pp2048 1973.03 1974.87 1.00
RX 9060 XT llama 8B Q4_0 256 pp2048 2142.05 2145.10 1.00
RX 9060 XT llama 8B Q4_0 512 pp2048 2187.72 2188.37 1.00
RX 9060 XT llama 8B Q4_0 1024 pp2048 2237.77 2240.40 1.00
RX 9060 XT llama 8B Q4_0 2048 pp2048 2210.63 2211.25 1.00
RX 9060 XT llama 8B Q4_1 1 pp2048 46.14 46.13 1.00
RX 9060 XT llama 8B Q4_1 2 pp2048 87.36 87.37 1.00
RX 9060 XT llama 8B Q4_1 4 pp2048 161.56 161.47 1.00
RX 9060 XT llama 8B Q4_1 8 pp2048 212.05 211.77 1.00
RX 9060 XT llama 8B Q4_1 16 pp2048 554.58 551.61 0.99
RX 9060 XT llama 8B Q4_1 32 pp2048 763.75 763.82 1.00
RX 9060 XT llama 8B Q4_1 64 pp2048 1316.07 1314.47 1.00
RX 9060 XT llama 8B Q4_1 128 pp2048 1709.67 1704.02 1.00
RX 9060 XT llama 8B Q4_1 256 pp2048 1840.83 1836.41 1.00
RX 9060 XT llama 8B Q4_1 512 pp2048 1875.75 1850.53 0.99
RX 9060 XT llama 8B Q4_1 1024 pp2048 1909.05 1903.01 1.00
RX 9060 XT llama 8B Q4_1 2048 pp2048 1886.32 1882.86 1.00
RX 9060 XT llama 8B Q4_K_S 1 pp2048 48.96 49.02 1.00
RX 9060 XT llama 8B Q4_K_S 2 pp2048 89.77 89.66 1.00
RX 9060 XT llama 8B Q4_K_S 4 pp2048 132.20 131.90 1.00
RX 9060 XT llama 8B Q4_K_S 8 pp2048 148.52 148.43 1.00
RX 9060 XT llama 8B Q4_K_S 16 pp2048 522.11 522.04 1.00
RX 9060 XT llama 8B Q4_K_S 32 pp2048 736.98 737.38 1.00
RX 9060 XT llama 8B Q4_K_S 64 pp2048 1269.63 1270.78 1.00
RX 9060 XT llama 8B Q4_K_S 128 pp2048 1735.84 1734.80 1.00
RX 9060 XT llama 8B Q4_K_S 256 pp2048 1872.27 1846.87 0.99
RX 9060 XT llama 8B Q4_K_S 512 pp2048 1903.24 1901.62 1.00
RX 9060 XT llama 8B Q4_K_S 1024 pp2048 1938.90 1937.89 1.00
RX 9060 XT llama 8B Q4_K_S 2048 pp2048 1918.55 1916.50 1.00
RX 9060 XT llama 8B Q5_0 1 pp2048 43.53 43.58 1.00
RX 9060 XT llama 8B Q5_0 2 pp2048 82.54 82.59 1.00
RX 9060 XT llama 8B Q5_0 4 pp2048 150.46 150.69 1.00
RX 9060 XT llama 8B Q5_0 8 pp2048 195.36 195.40 1.00
RX 9060 XT llama 8B Q5_0 16 pp2048 484.35 484.93 1.00
RX 9060 XT llama 8B Q5_0 32 pp2048 683.42 684.76 1.00
RX 9060 XT llama 8B Q5_0 64 pp2048 1218.36 1222.12 1.00
RX 9060 XT llama 8B Q5_0 128 pp2048 1844.04 1846.53 1.00
RX 9060 XT llama 8B Q5_0 256 pp2048 2001.99 1997.67 1.00
RX 9060 XT llama 8B Q5_0 512 pp2048 2035.08 2032.12 1.00
RX 9060 XT llama 8B Q5_0 1024 pp2048 2073.51 2071.38 1.00
RX 9060 XT llama 8B Q5_0 2048 pp2048 2048.50 2042.04 1.00
RX 9060 XT llama 8B Q5_1 1 pp2048 41.40 41.50 1.00
RX 9060 XT llama 8B Q5_1 2 pp2048 77.02 77.00 1.00
RX 9060 XT llama 8B Q5_1 4 pp2048 144.96 144.90 1.00
RX 9060 XT llama 8B Q5_1 8 pp2048 229.51 229.47 1.00
RX 9060 XT llama 8B Q5_1 16 pp2048 406.97 407.04 1.00
RX 9060 XT llama 8B Q5_1 32 pp2048 604.87 604.56 1.00
RX 9060 XT llama 8B Q5_1 64 pp2048 1120.09 1119.52 1.00
RX 9060 XT llama 8B Q5_1 128 pp2048 1611.29 1610.55 1.00
RX 9060 XT llama 8B Q5_1 256 pp2048 1753.69 1752.39 1.00
RX 9060 XT llama 8B Q5_1 512 pp2048 1786.09 1796.75 1.01
RX 9060 XT llama 8B Q5_1 1024 pp2048 1821.42 1832.46 1.01
RX 9060 XT llama 8B Q5_1 2048 pp2048 1804.08 1814.60 1.01
RX 9060 XT llama 8B Q5_K_S 1 pp2048 43.83 43.88 1.00
RX 9060 XT llama 8B Q5_K_S 2 pp2048 80.23 80.47 1.00
RX 9060 XT llama 8B Q5_K_S 4 pp2048 127.81 128.12 1.00
RX 9060 XT llama 8B Q5_K_S 8 pp2048 145.77 146.13 1.00
RX 9060 XT llama 8B Q5_K_S 16 pp2048 526.08 526.38 1.00
RX 9060 XT llama 8B Q5_K_S 32 pp2048 734.17 734.97 1.00
RX 9060 XT llama 8B Q5_K_S 64 pp2048 1268.87 1272.27 1.00
RX 9060 XT llama 8B Q5_K_S 128 pp2048 1701.48 1706.23 1.00
RX 9060 XT llama 8B Q5_K_S 256 pp2048 1836.02 1839.31 1.00
RX 9060 XT llama 8B Q5_K_S 512 pp2048 1872.70 1875.40 1.00
RX 9060 XT llama 8B Q5_K_S 1024 pp2048 1908.40 1908.81 1.00
RX 9060 XT llama 8B Q5_K_S 2048 pp2048 1887.29 1888.04 1.00
RX 9060 XT llama 8B Q6_K 1 pp2048 38.88 38.99 1.00
RX 9060 XT llama 8B Q6_K 2 pp2048 73.73 73.87 1.00
RX 9060 XT llama 8B Q6_K 4 pp2048 125.21 125.35 1.00
RX 9060 XT llama 8B Q6_K 8 pp2048 155.97 156.21 1.00
RX 9060 XT llama 8B Q6_K 16 pp2048 423.32 424.10 1.00
RX 9060 XT llama 8B Q6_K 32 pp2048 572.10 572.61 1.00
RX 9060 XT llama 8B Q6_K 64 pp2048 890.37 883.60 0.99
RX 9060 XT llama 8B Q6_K 128 pp2048 1092.09 1089.40 1.00
RX 9060 XT llama 8B Q6_K 256 pp2048 1178.14 327.99 0.28
RX 9060 XT llama 8B Q6_K 512 pp2048 1191.51 342.57 0.29
RX 9060 XT llama 8B Q6_K 1024 pp2048 1210.27 350.61 0.29
RX 9060 XT llama 8B Q6_K 2048 pp2048 1200.71 352.17 0.29
RX 9060 XT llama 8B Q8_0 1 pp2048 32.86 32.88 1.00
RX 9060 XT llama 8B Q8_0 2 pp2048 61.02 61.09 1.00
RX 9060 XT llama 8B Q8_0 4 pp2048 115.13 115.26 1.00
RX 9060 XT llama 8B Q8_0 8 pp2048 182.72 182.61 1.00
RX 9060 XT llama 8B Q8_0 16 pp2048 428.24 427.08 1.00
RX 9060 XT llama 8B Q8_0 32 pp2048 664.97 663.10 1.00
RX 9060 XT llama 8B Q8_0 64 pp2048 1202.56 1201.90 1.00
RX 9060 XT llama 8B Q8_0 128 pp2048 1869.15 1862.62 1.00
RX 9060 XT llama 8B Q8_0 256 pp2048 2049.97 2045.47 1.00
RX 9060 XT llama 8B Q8_0 512 pp2048 2108.77 2099.31 1.00
RX 9060 XT llama 8B Q8_0 1024 pp2048 2177.10 2161.62 0.99
RX 9060 XT llama 8B Q8_0 2048 pp2048 2155.69 2142.31 0.99

On RDNA4 this PR is consistently making the performance worse so ggml_cuda_should_use_mmq should return true for that architecture.

@Beinsezii
Contributor Author

On RDNA4 this PR is consistently making the performance worse so ggml_cuda_should_use_mmq should return true for that architecture.

Probably better to just gate the new block behind an RDNA3 check than make RDNA4 return true, imo. I'm surprised how much worse it is, actually; the RDNA3/RDNA4 mmq impls looked similar, so I wonder if this is a rocblas problem.

@Beinsezii
Contributor Author

Did GGML_CUDA_CC_IS_RDNA3 instead of GGML_CUDA_CC_IS_RDNA3_0 because I believe 3_0 and 3_5 have pretty much equal wmma intrinsics. Given how severe the RDNA4 regression was, this still needs a test before merge though.

@JohannesGaessler
Contributor

Did GGML_CUDA_CC_IS_RDNA3 instead of GGML_CUDA_CC_IS_RDNA3_0 because I believe 3_0 and 3_5 have pretty much equal wmma intrinsics.

That was my initial assumption too, then I got unexpected reports about performance regressions. But I think those were related to rocBLAS versions and environment variables. On Monday I should be able to check RDNA 3.5 performance to make sure.

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
@JohannesGaessler
Contributor

My current status is this: AMD sent me an Asus ROG Flow Z13 machine where my Linux installation is currently broken due to what seems to be hardware-specific issues. I don't know how long it will take me to sort this out so for now I would simply merge this PR as-is and deal with any issues that arise after the fact.

@JohannesGaessler JohannesGaessler merged commit 9689295 into ggml-org:master Jan 6, 2026
71 checks passed
@elfarolab

My current status is this: AMD sent me an Asus ROG Flow Z13 machine where my Linux installation is currently broken due to what seems to be hardware-specific issues. I don't know how long it will take me to sort this out so for now I would simply merge this PR as-is and deal with any issues that arise after the fact.

@JohannesGaessler
Hello Johannes, I'm an embedded Linux specialist. I don't want to be intrusive, and I'm not sure if you need help, but the work of the developers on this open-source project is incredibly important and so much appreciated. If you do need assistance with your Linux machine, I'd be happy to help with any issue.

Thank you so much.

@JohannesGaessler
Contributor

My situation is as follows: on my desktop machine I have installed Manjaro with KDE, which is working without issue. On the ROG Flow Z13 I received from AMD I also installed Manjaro in order to deal with only a single Linux distribution. With Manjaro v25 it was working as intended after installing the linux-flowx13 package from the AUR, which provides a kernel with ROG-specific fixes. After upgrading to Manjaro v26 the screen would freeze upon login, while SSH and tty3 work correctly. Using clean installs of Manjaro, the v25 KDE variant and the v26 XFCE variant seem to be working correctly; it's specifically KDE that is causing problems with Manjaro v26. The problem is somewhat inconsistent and sometimes the login works for no apparent reason. If I had to guess, it's some issue with the display driver.

@elfarolab

elfarolab commented Jan 6, 2026

I am also using KDE on a laptop with CachyOS, which is a variant of Arch Linux (like Manjaro, which also derives from Arch).
CachyOS prioritizes gaming support. I use it because it heavily optimizes Arch Linux packages and provides good GPU driver support. Maybe try CachyOS as an alternative to fix your display driver.
https://www.phoronix.com/news/CachyOS-May-2025-Release

If you provide more details about the bootstrap process, I can look into it, even though graphics drivers aren't my expertise.

@Beinsezii Beinsezii deleted the beinsezii/rocm_mmq_tune branch January 6, 2026 17:27
@jiachengjason
Contributor

jiachengjason commented Jan 7, 2026

Strix Halo performance
GPU Model Microbatch size Test t/s 3d26a09 t/s 9689295 Speedup
Strix Halo llama 8B IQ1_S - 1.5625 bpw 1 pp2048 60.00 59.94 1.00
Strix Halo llama 8B IQ1_S - 1.5625 bpw 2 pp2048 104.40 104.37 1.00
Strix Halo llama 8B IQ1_S - 1.5625 bpw 4 pp2048 200.04 199.89 1.00
Strix Halo llama 8B IQ1_S - 1.5625 bpw 8 pp2048 254.22 254.27 1.00
Strix Halo llama 8B IQ1_S - 1.5625 bpw 16 pp2048 490.52 490.75 1.00
Strix Halo llama 8B IQ1_S - 1.5625 bpw 32 pp2048 635.83 639.11 1.01
Strix Halo llama 8B IQ1_S - 1.5625 bpw 64 pp2048 778.03 778.56 1.00
Strix Halo llama 8B IQ1_S - 1.5625 bpw 128 pp2048 869.27 870.24 1.00
Strix Halo llama 8B IQ1_S - 1.5625 bpw 256 pp2048 934.24 932.46 1.00
Strix Halo llama 8B IQ1_S - 1.5625 bpw 512 pp2048 908.61 918.09 1.01
Strix Halo llama 8B IQ1_S - 1.5625 bpw 1024 pp2048 901.55 912.45 1.01
Strix Halo llama 8B IQ1_S - 1.5625 bpw 2048 pp2048 821.20 829.03 1.01
Strix Halo llama 8B IQ2_S - 2.5 bpw 1 pp2048 47.68 47.54 1.00
Strix Halo llama 8B IQ2_S - 2.5 bpw 2 pp2048 85.43 85.42 1.00
Strix Halo llama 8B IQ2_S - 2.5 bpw 4 pp2048 163.02 163.27 1.00
Strix Halo llama 8B IQ2_S - 2.5 bpw 8 pp2048 205.91 205.93 1.00
Strix Halo llama 8B IQ2_S - 2.5 bpw 16 pp2048 301.37 300.87 1.00
Strix Halo llama 8B IQ2_S - 2.5 bpw 32 pp2048 568.26 567.77 1.00
Strix Halo llama 8B IQ2_S - 2.5 bpw 64 pp2048 721.31 720.95 1.00
Strix Halo llama 8B IQ2_S - 2.5 bpw 128 pp2048 760.99 762.41 1.00
Strix Halo llama 8B IQ2_S - 2.5 bpw 256 pp2048 811.27 620.74 0.77
Strix Halo llama 8B IQ2_S - 2.5 bpw 512 pp2048 796.69 752.79 0.94
Strix Halo llama 8B IQ2_S - 2.5 bpw 1024 pp2048 796.68 807.61 1.01
Strix Halo llama 8B IQ2_S - 2.5 bpw 2048 pp2048 734.43 778.80 1.06
Strix Halo llama 8B IQ2_XS - 2.3125 bpw 1 pp2048 49.60 49.61 1.00
Strix Halo llama 8B IQ2_XS - 2.3125 bpw 2 pp2048 87.77 87.81 1.00
Strix Halo llama 8B IQ2_XS - 2.3125 bpw 4 pp2048 165.07 165.29 1.00
Strix Halo llama 8B IQ2_XS - 2.3125 bpw 8 pp2048 202.39 202.56 1.00
Strix Halo llama 8B IQ2_XS - 2.3125 bpw 16 pp2048 297.78 297.13 1.00
Strix Halo llama 8B IQ2_XS - 2.3125 bpw 32 pp2048 574.37 574.98 1.00
Strix Halo llama 8B IQ2_XS - 2.3125 bpw 64 pp2048 712.62 713.62 1.00
Strix Halo llama 8B IQ2_XS - 2.3125 bpw 128 pp2048 741.74 741.56 1.00
Strix Halo llama 8B IQ2_XS - 2.3125 bpw 256 pp2048 790.92 600.27 0.76
Strix Halo llama 8B IQ2_XS - 2.3125 bpw 512 pp2048 779.03 726.41 0.93
Strix Halo llama 8B IQ2_XS - 2.3125 bpw 1024 pp2048 784.19 800.73 1.02
Strix Halo llama 8B IQ2_XS - 2.3125 bpw 2048 pp2048 722.31 778.12 1.08
Strix Halo llama 8B IQ2_XXS - 2.0625 bpw 1 pp2048 41.39 41.43 1.00
Strix Halo llama 8B IQ2_XXS - 2.0625 bpw 2 pp2048 74.65 74.84 1.00
Strix Halo llama 8B IQ2_XXS - 2.0625 bpw 4 pp2048 143.67 144.28 1.00
Strix Halo llama 8B IQ2_XXS - 2.0625 bpw 8 pp2048 183.92 183.40 1.00
Strix Halo llama 8B IQ2_XXS - 2.0625 bpw 16 pp2048 363.01 363.32 1.00
Strix Halo llama 8B IQ2_XXS - 2.0625 bpw 32 pp2048 518.92 519.00 1.00
Strix Halo llama 8B IQ2_XXS - 2.0625 bpw 64 pp2048 723.17 709.32 0.98
Strix Halo llama 8B IQ2_XXS - 2.0625 bpw 128 pp2048 854.89 853.62 1.00
Strix Halo llama 8B IQ2_XXS - 2.0625 bpw 256 pp2048 914.79 914.03 1.00
Strix Halo llama 8B IQ2_XXS - 2.0625 bpw 512 pp2048 894.60 902.83 1.01
Strix Halo llama 8B IQ2_XXS - 2.0625 bpw 1024 pp2048 893.34 898.97 1.01
Strix Halo llama 8B IQ2_XXS - 2.0625 bpw 2048 pp2048 814.62 821.62 1.01
Strix Halo llama 8B IQ3_S - 3.4375 bpw 1 pp2048 38.99 38.89 1.00
Strix Halo llama 8B IQ3_S - 3.4375 bpw 2 pp2048 71.71 71.52 1.00
Strix Halo llama 8B IQ3_S - 3.4375 bpw 4 pp2048 140.79 140.20 1.00
Strix Halo llama 8B IQ3_S - 3.4375 bpw 8 pp2048 198.66 197.92 1.00
Strix Halo llama 8B IQ3_S - 3.4375 bpw 16 pp2048 334.82 333.54 1.00
Strix Halo llama 8B IQ3_S - 3.4375 bpw 32 pp2048 484.37 483.97 1.00
Strix Halo llama 8B IQ3_S - 3.4375 bpw 64 pp2048 718.15 716.37 1.00
Strix Halo llama 8B IQ3_S - 3.4375 bpw 128 pp2048 874.09 871.81 1.00
Strix Halo llama 8B IQ3_S - 3.4375 bpw 256 pp2048 930.54 928.49 1.00
Strix Halo llama 8B IQ3_S - 3.4375 bpw 512 pp2048 901.05 900.92 1.00
Strix Halo llama 8B IQ3_S - 3.4375 bpw 1024 pp2048 894.92 895.33 1.00
Strix Halo llama 8B IQ3_S - 3.4375 bpw 2048 pp2048 816.34 816.32 1.00
Strix Halo llama 8B IQ3_S mix - 3.66 bpw 1 pp2048 38.84 38.80 1.00
Strix Halo llama 8B IQ3_S mix - 3.66 bpw 2 pp2048 70.25 70.07 1.00
Strix Halo llama 8B IQ3_S mix - 3.66 bpw 4 pp2048 135.46 134.67 0.99
Strix Halo llama 8B IQ3_S mix - 3.66 bpw 8 pp2048 187.04 186.06 0.99
Strix Halo llama 8B IQ3_S mix - 3.66 bpw 16 pp2048 348.51 346.77 1.00
Strix Halo llama 8B IQ3_S mix - 3.66 bpw 32 pp2048 498.17 497.03 1.00
Strix Halo llama 8B IQ3_S mix - 3.66 bpw 64 pp2048 726.32 724.69 1.00
Strix Halo llama 8B IQ3_S mix - 3.66 bpw 128 pp2048 875.05 873.04 1.00
Strix Halo llama 8B IQ3_S mix - 3.66 bpw 256 pp2048 929.04 927.61 1.00
Strix Halo llama 8B IQ3_S mix - 3.66 bpw 512 pp2048 903.64 902.45 1.00
Strix Halo llama 8B IQ3_S mix - 3.66 bpw 1024 pp2048 895.40 894.62 1.00
Strix Halo llama 8B IQ3_S mix - 3.66 bpw 2048 pp2048 817.66 817.35 1.00
Strix Halo llama 8B IQ3_XS - 3.3 bpw 1 pp2048 42.35 42.32 1.00
Strix Halo llama 8B IQ3_XS - 3.3 bpw 2 pp2048 77.96 77.90 1.00
Strix Halo llama 8B IQ3_XS - 3.3 bpw 4 pp2048 151.78 151.95 1.00
Strix Halo llama 8B IQ3_XS - 3.3 bpw 8 pp2048 203.23 203.11 1.00
Strix Halo llama 8B IQ3_XS - 3.3 bpw 16 pp2048 372.42 372.75 1.00
Strix Halo llama 8B IQ3_XS - 3.3 bpw 32 pp2048 518.34 518.58 1.00
Strix Halo llama 8B IQ3_XS - 3.3 bpw 64 pp2048 741.16 741.36 1.00
Strix Halo llama 8B IQ3_XS - 3.3 bpw 128 pp2048 890.30 889.72 1.00
Strix Halo llama 8B IQ3_XS - 3.3 bpw 256 pp2048 948.19 946.75 1.00
Strix Halo llama 8B IQ3_XS - 3.3 bpw 512 pp2048 924.39 924.19 1.00
Strix Halo llama 8B IQ3_XS - 3.3 bpw 1024 pp2048 912.97 913.45 1.00
Strix Halo llama 8B IQ3_XS - 3.3 bpw 2048 pp2048 831.98 831.88 1.00
Strix Halo llama 8B IQ3_XXS - 3.0625 bpw 1 pp2048 45.23 45.14 1.00
Strix Halo llama 8B IQ3_XXS - 3.0625 bpw 2 pp2048 82.08 81.99 1.00
Strix Halo llama 8B IQ3_XXS - 3.0625 bpw 4 pp2048 159.80 159.20 1.00
Strix Halo llama 8B IQ3_XXS - 3.0625 bpw 8 pp2048 203.22 203.02 1.00
Strix Halo llama 8B IQ3_XXS - 3.0625 bpw 16 pp2048 386.16 385.33 1.00
Strix Halo llama 8B IQ3_XXS - 3.0625 bpw 32 pp2048 552.93 551.49 1.00
Strix Halo llama 8B IQ3_XXS - 3.0625 bpw 64 pp2048 753.64 753.58 1.00
Strix Halo llama 8B IQ3_XXS - 3.0625 bpw 128 pp2048 891.51 891.46 1.00
Strix Halo llama 8B IQ3_XXS - 3.0625 bpw 256 pp2048 952.46 924.09 0.97
Strix Halo llama 8B IQ3_XXS - 3.0625 bpw 512 pp2048 929.04 909.88 0.98
Strix Halo llama 8B IQ3_XXS - 3.0625 bpw 1024 pp2048 919.63 908.94 0.99
Strix Halo llama 8B IQ3_XXS - 3.0625 bpw 2048 pp2048 837.08 838.19 1.00
Strix Halo llama 8B IQ4_NL - 4.5 bpw 1 pp2048 39.44 39.45 1.00
Strix Halo llama 8B IQ4_NL - 4.5 bpw 2 pp2048 75.17 75.19 1.00
Strix Halo llama 8B IQ4_NL - 4.5 bpw 4 pp2048 160.98 161.22 1.00
Strix Halo llama 8B IQ4_NL - 4.5 bpw 8 pp2048 262.05 262.33 1.00
Strix Halo llama 8B IQ4_NL - 4.5 bpw 16 pp2048 464.46 463.35 1.00
Strix Halo llama 8B IQ4_NL - 4.5 bpw 32 pp2048 516.00 516.21 1.00
Strix Halo llama 8B IQ4_NL - 4.5 bpw 64 pp2048 838.71 839.75 1.00
Strix Halo llama 8B IQ4_NL - 4.5 bpw 128 pp2048 944.58 944.39 1.00
Strix Halo llama 8B IQ4_NL - 4.5 bpw 256 pp2048 1017.32 1017.27 1.00
Strix Halo llama 8B IQ4_NL - 4.5 bpw 512 pp2048 996.56 991.89 1.00
Strix Halo llama 8B IQ4_NL - 4.5 bpw 1024 pp2048 977.05 974.35 1.00
Strix Halo llama 8B IQ4_NL - 4.5 bpw 2048 pp2048 879.54 878.44 1.00
Strix Halo llama 8B IQ4_XS - 4.25 bpw 1 pp2048 41.45 41.46 1.00
Strix Halo llama 8B IQ4_XS - 4.25 bpw 2 pp2048 78.87 78.86 1.00
Strix Halo llama 8B IQ4_XS - 4.25 bpw 4 pp2048 170.80 170.66 1.00
Strix Halo llama 8B IQ4_XS - 4.25 bpw 8 pp2048 271.98 271.95 1.00
Strix Halo llama 8B IQ4_XS - 4.25 bpw 16 pp2048 501.50 502.34 1.00
Strix Halo llama 8B IQ4_XS - 4.25 bpw 32 pp2048 391.49 391.20 1.00
Strix Halo llama 8B IQ4_XS - 4.25 bpw 64 pp2048 839.69 841.13 1.00
Strix Halo llama 8B IQ4_XS - 4.25 bpw 128 pp2048 953.07 954.80 1.00
Strix Halo llama 8B IQ4_XS - 4.25 bpw 256 pp2048 1027.69 1026.93 1.00
Strix Halo llama 8B IQ4_XS - 4.25 bpw 512 pp2048 999.99 1001.47 1.00
Strix Halo llama 8B IQ4_XS - 4.25 bpw 1024 pp2048 984.34 985.86 1.00
Strix Halo llama 8B IQ4_XS - 4.25 bpw 2048 pp2048 881.64 882.92 1.00
Strix Halo llama 8B Q2_K_S 1 pp2048 51.52 51.58 1.00
Strix Halo llama 8B Q2_K_S 2 pp2048 82.62 82.71 1.00
Strix Halo llama 8B Q2_K_S 4 pp2048 122.15 122.36 1.00
Strix Halo llama 8B Q2_K_S 8 pp2048 125.03 125.65 1.00
Strix Halo llama 8B Q2_K_S 16 pp2048 234.47 234.48 1.00
Strix Halo llama 8B Q2_K_S 32 pp2048 371.96 371.93 1.00
Strix Halo llama 8B Q2_K_S 64 pp2048 503.07 503.42 1.00
Strix Halo llama 8B Q2_K_S 128 pp2048 553.24 553.59 1.00
Strix Halo llama 8B Q2_K_S 256 pp2048 594.02 606.08 1.02
Strix Halo llama 8B Q2_K_S 512 pp2048 606.43 733.53 1.21
Strix Halo llama 8B Q2_K_S 1024 pp2048 628.48 811.93 1.29
Strix Halo llama 8B Q2_K_S 2048 pp2048 593.93 783.17 1.32
Strix Halo llama 8B Q3_K_S 1 pp2048 42.04 42.32 1.01
Strix Halo llama 8B Q3_K_S 2 pp2048 71.87 71.99 1.00
Strix Halo llama 8B Q3_K_S 4 pp2048 117.24 117.54 1.00
Strix Halo llama 8B Q3_K_S 8 pp2048 124.25 124.82 1.00
Strix Halo llama 8B Q3_K_S 16 pp2048 339.95 339.12 1.00
Strix Halo llama 8B Q3_K_S 32 pp2048 608.45 607.23 1.00
Strix Halo llama 8B Q3_K_S 64 pp2048 746.05 746.40 1.00
Strix Halo llama 8B Q3_K_S 128 pp2048 862.17 861.57 1.00
Strix Halo llama 8B Q3_K_S 256 pp2048 922.09 923.97 1.00
Strix Halo llama 8B Q3_K_S 512 pp2048 919.01 920.03 1.00
Strix Halo llama 8B Q3_K_S 1024 pp2048 896.72 897.46 1.00
Strix Halo llama 8B Q3_K_S 2048 pp2048 812.64 813.46 1.00
Strix Halo llama 8B Q4_0 1 pp2048 39.75 39.68 1.00
Strix Halo llama 8B Q4_0 2 pp2048 76.22 76.07 1.00
Strix Halo llama 8B Q4_0 4 pp2048 164.51 164.14 1.00
Strix Halo llama 8B Q4_0 8 pp2048 266.42 266.96 1.00
Strix Halo llama 8B Q4_0 16 pp2048 457.63 458.41 1.00
Strix Halo llama 8B Q4_0 32 pp2048 364.17 364.41 1.00
Strix Halo llama 8B Q4_0 64 pp2048 815.83 814.72 1.00
Strix Halo llama 8B Q4_0 128 pp2048 934.80 936.56 1.00
Strix Halo llama 8B Q4_0 256 pp2048 1001.43 1001.15 1.00
Strix Halo llama 8B Q4_0 512 pp2048 984.89 980.83 1.00
Strix Halo llama 8B Q4_0 1024 pp2048 964.78 964.76 1.00
Strix Halo llama 8B Q4_0 2048 pp2048 871.49 870.55 1.00
Strix Halo llama 8B Q4_1 1 pp2048 36.52 36.49 1.00
Strix Halo llama 8B Q4_1 2 pp2048 71.08 71.11 1.00
Strix Halo llama 8B Q4_1 4 pp2048 154.39 154.24 1.00
Strix Halo llama 8B Q4_1 8 pp2048 258.69 257.81 1.00
Strix Halo llama 8B Q4_1 16 pp2048 437.17 437.12 1.00
Strix Halo llama 8B Q4_1 32 pp2048 636.33 637.25 1.00
Strix Halo llama 8B Q4_1 64 pp2048 794.82 795.97 1.00
Strix Halo llama 8B Q4_1 128 pp2048 862.78 862.10 1.00
Strix Halo llama 8B Q4_1 256 pp2048 922.96 923.27 1.00
Strix Halo llama 8B Q4_1 512 pp2048 910.37 908.30 1.00
Strix Halo llama 8B Q4_1 1024 pp2048 906.03 905.01 1.00
Strix Halo llama 8B Q4_1 2048 pp2048 826.83 827.71 1.00
Strix Halo llama 8B Q4_K_S 1 pp2048 36.89 36.85 1.00
Strix Halo llama 8B Q4_K_S 2 pp2048 62.30 62.27 1.00
Strix Halo llama 8B Q4_K_S 4 pp2048 105.34 105.39 1.00
Strix Halo llama 8B Q4_K_S 8 pp2048 129.01 129.53 1.00
Strix Halo llama 8B Q4_K_S 16 pp2048 462.27 462.06 1.00
Strix Halo llama 8B Q4_K_S 32 pp2048 636.53 637.05 1.00
Strix Halo llama 8B Q4_K_S 64 pp2048 785.38 785.88 1.00
Strix Halo llama 8B Q4_K_S 128 pp2048 894.14 894.09 1.00
Strix Halo llama 8B Q4_K_S 256 pp2048 946.44 947.34 1.00
Strix Halo llama 8B Q4_K_S 512 pp2048 920.63 922.99 1.00
Strix Halo llama 8B Q4_K_S 1024 pp2048 916.42 916.68 1.00
Strix Halo llama 8B Q4_K_S 2048 pp2048 834.05 833.39 1.00
Strix Halo llama 8B Q5_1 1 pp2048 30.50 30.46 1.00
Strix Halo llama 8B Q5_1 2 pp2048 59.49 59.48 1.00
Strix Halo llama 8B Q5_1 4 pp2048 127.90 127.76 1.00
Strix Halo llama 8B Q5_1 8 pp2048 221.95 221.87 1.00
Strix Halo llama 8B Q5_1 16 pp2048 311.47 311.37 1.00
Strix Halo llama 8B Q5_1 32 pp2048 512.37 512.43 1.00
Strix Halo llama 8B Q5_1 64 pp2048 703.73 703.12 1.00
Strix Halo llama 8B Q5_1 128 pp2048 813.20 814.13 1.00
Strix Halo llama 8B Q5_1 256 pp2048 884.18 885.60 1.00
Strix Halo llama 8B Q5_1 512 pp2048 887.84 885.76 1.00
Strix Halo llama 8B Q5_1 1024 pp2048 877.35 876.09 1.00
Strix Halo llama 8B Q5_1 2048 pp2048 806.32 805.07 1.00
Strix Halo llama 8B Q5_K_S 1 pp2048 33.32 33.23 1.00
Strix Halo llama 8B Q5_K_S 2 pp2048 57.71 57.74 1.00
Strix Halo llama 8B Q5_K_S 4 pp2048 100.06 100.20 1.00
Strix Halo llama 8B Q5_K_S 8 pp2048 124.09 124.42 1.00
Strix Halo llama 8B Q5_K_S 16 pp2048 453.39 452.85 1.00
Strix Halo llama 8B Q5_K_S 32 pp2048 644.45 645.21 1.00
Strix Halo llama 8B Q5_K_S 64 pp2048 792.90 793.44 1.00
Strix Halo llama 8B Q5_K_S 128 pp2048 871.21 872.29 1.00
Strix Halo llama 8B Q5_K_S 256 pp2048 928.28 927.06 1.00
Strix Halo llama 8B Q5_K_S 512 pp2048 911.07 912.51 1.00
Strix Halo llama 8B Q5_K_S 1024 pp2048 905.92 908.61 1.00
Strix Halo llama 8B Q5_K_S 2048 pp2048 825.17 824.82 1.00
Strix Halo llama 8B Q6_K 1 pp2048 29.50 29.50 1.00
Strix Halo llama 8B Q6_K 2 pp2048 56.27 56.25 1.00
Strix Halo llama 8B Q6_K 4 pp2048 116.68 116.58 1.00
Strix Halo llama 8B Q6_K 8 pp2048 156.63 156.95 1.00
Strix Halo llama 8B Q6_K 16 pp2048 348.43 348.58 1.00
Strix Halo llama 8B Q6_K 32 pp2048 493.41 493.25 1.00
Strix Halo llama 8B Q6_K 64 pp2048 605.81 606.71 1.00
Strix Halo llama 8B Q6_K 128 pp2048 621.70 622.28 1.00
Strix Halo llama 8B Q6_K 256 pp2048 658.97 583.94 0.89
Strix Halo llama 8B Q6_K 512 pp2048 656.23 718.64 1.10
Strix Halo llama 8B Q6_K 1024 pp2048 664.13 798.61 1.20
Strix Halo llama 8B Q6_K 2048 pp2048 625.33 772.19 1.23
Strix Halo llama 8B Q8_0 1 pp2048 25.03 25.02 1.00
Strix Halo llama 8B Q8_0 2 pp2048 48.67 48.63 1.00
Strix Halo llama 8B Q8_0 4 pp2048 103.08 102.91 1.00
Strix Halo llama 8B Q8_0 8 pp2048 190.36 189.99 1.00
Strix Halo llama 8B Q8_0 16 pp2048 331.84 331.46 1.00
Strix Halo llama 8B Q8_0 32 pp2048 389.04 389.96 1.00
Strix Halo llama 8B Q8_0 64 pp2048 742.91 741.76 1.00
Strix Halo llama 8B Q8_0 128 pp2048 886.60 886.82 1.00
Strix Halo llama 8B Q8_0 256 pp2048 953.86 952.18 1.00
Strix Halo llama 8B Q8_0 512 pp2048 937.57 937.62 1.00
Strix Halo llama 8B Q8_0 1024 pp2048 930.12 930.16 1.00
Strix Halo llama 8B Q8_0 2048 pp2048 843.21 843.89 1.00

@jiachengjason (Contributor) commented Jan 7, 2026

RDNA4 performance
GPU Model Microbatch size Test t/s be47fb9 t/s a435c77 Speedup
RX 9060 XT llama 8B IQ1_S - 1.5625 bpw 1 pp2048 68.69 68.82 1.00
RX 9060 XT llama 8B IQ1_S - 1.5625 bpw 2 pp2048 120.09 120.13 1.00
RX 9060 XT llama 8B IQ1_S - 1.5625 bpw 4 pp2048 184.47 184.68 1.00
RX 9060 XT llama 8B IQ1_S - 1.5625 bpw 8 pp2048 205.58 205.62 1.00
RX 9060 XT llama 8B IQ1_S - 1.5625 bpw 16 pp2048 569.29 564.61 0.99
RX 9060 XT llama 8B IQ1_S - 1.5625 bpw 32 pp2048 747.23 744.42 1.00
RX 9060 XT llama 8B IQ1_S - 1.5625 bpw 64 pp2048 1253.65 1254.23 1.00
RX 9060 XT llama 8B IQ1_S - 1.5625 bpw 128 pp2048 1703.77 1677.91 0.98
RX 9060 XT llama 8B IQ1_S - 1.5625 bpw 256 pp2048 1858.69 1615.75 0.87
RX 9060 XT llama 8B IQ1_S - 1.5625 bpw 512 pp2048 1888.65 1653.46 0.88
RX 9060 XT llama 8B IQ1_S - 1.5625 bpw 1024 pp2048 1921.45 1679.51 0.87
RX 9060 XT llama 8B IQ1_S - 1.5625 bpw 2048 pp2048 1897.64 1644.51 0.87
RX 9060 XT llama 8B IQ2_S - 2.5 bpw 1 pp2048 57.19 57.14 1.00
RX 9060 XT llama 8B IQ2_S - 2.5 bpw 2 pp2048 98.20 98.17 1.00
RX 9060 XT llama 8B IQ2_S - 2.5 bpw 4 pp2048 150.44 150.49 1.00
RX 9060 XT llama 8B IQ2_S - 2.5 bpw 8 pp2048 204.61 204.58 1.00
RX 9060 XT llama 8B IQ2_S - 2.5 bpw 16 pp2048 333.69 334.65 1.00
RX 9060 XT llama 8B IQ2_S - 2.5 bpw 32 pp2048 640.28 640.19 1.00
RX 9060 XT llama 8B IQ2_S - 2.5 bpw 64 pp2048 1120.56 1121.55 1.00
RX 9060 XT llama 8B IQ2_S - 2.5 bpw 128 pp2048 1467.23 1466.47 1.00
RX 9060 XT llama 8B IQ2_S - 2.5 bpw 256 pp2048 1587.87 373.67 0.24
RX 9060 XT llama 8B IQ2_S - 2.5 bpw 512 pp2048 1623.16 385.20 0.24
RX 9060 XT llama 8B IQ2_S - 2.5 bpw 1024 pp2048 1647.19 393.93 0.24
RX 9060 XT llama 8B IQ2_S - 2.5 bpw 2048 pp2048 1621.33 394.44 0.24
RX 9060 XT llama 8B IQ2_XS - 2.3125 bpw 1 pp2048 59.29 59.16 1.00
RX 9060 XT llama 8B IQ2_XS - 2.3125 bpw 2 pp2048 101.31 101.17 1.00
RX 9060 XT llama 8B IQ2_XS - 2.3125 bpw 4 pp2048 151.40 151.28 1.00
RX 9060 XT llama 8B IQ2_XS - 2.3125 bpw 8 pp2048 205.95 205.68 1.00
RX 9060 XT llama 8B IQ2_XS - 2.3125 bpw 16 pp2048 328.24 327.06 1.00
RX 9060 XT llama 8B IQ2_XS - 2.3125 bpw 32 pp2048 634.16 633.63 1.00
RX 9060 XT llama 8B IQ2_XS - 2.3125 bpw 64 pp2048 1099.59 1098.50 1.00
RX 9060 XT llama 8B IQ2_XS - 2.3125 bpw 128 pp2048 1407.26 1406.07 1.00
RX 9060 XT llama 8B IQ2_XS - 2.3125 bpw 256 pp2048 1527.93 338.65 0.22
RX 9060 XT llama 8B IQ2_XS - 2.3125 bpw 512 pp2048 1558.01 350.58 0.23
RX 9060 XT llama 8B IQ2_XS - 2.3125 bpw 1024 pp2048 1588.71 356.74 0.22
RX 9060 XT llama 8B IQ2_XS - 2.3125 bpw 2048 pp2048 1575.42 358.88 0.23
RX 9060 XT llama 8B IQ2_XXS - 2.0625 bpw 1 pp2048 49.50 49.52 1.00
RX 9060 XT llama 8B IQ2_XXS - 2.0625 bpw 2 pp2048 88.52 88.46 1.00
RX 9060 XT llama 8B IQ2_XXS - 2.0625 bpw 4 pp2048 148.21 148.44 1.00
RX 9060 XT llama 8B IQ2_XXS - 2.0625 bpw 8 pp2048 177.61 177.97 1.00
RX 9060 XT llama 8B IQ2_XXS - 2.0625 bpw 16 pp2048 463.33 461.38 1.00
RX 9060 XT llama 8B IQ2_XXS - 2.0625 bpw 32 pp2048 635.31 635.33 1.00
RX 9060 XT llama 8B IQ2_XXS - 2.0625 bpw 64 pp2048 1178.84 1187.44 1.01
RX 9060 XT llama 8B IQ2_XXS - 2.0625 bpw 128 pp2048 1710.56 1705.17 1.00
RX 9060 XT llama 8B IQ2_XXS - 2.0625 bpw 256 pp2048 1865.63 1604.23 0.86
RX 9060 XT llama 8B IQ2_XXS - 2.0625 bpw 512 pp2048 1897.50 1642.77 0.87
RX 9060 XT llama 8B IQ2_XXS - 2.0625 bpw 1024 pp2048 1936.72 1675.52 0.87
RX 9060 XT llama 8B IQ2_XXS - 2.0625 bpw 2048 pp2048 1916.88 1661.18 0.87
RX 9060 XT llama 8B IQ3_S - 3.4375 bpw 1 pp2048 46.07 46.03 1.00
RX 9060 XT llama 8B IQ3_S - 3.4375 bpw 2 pp2048 85.77 85.63 1.00
RX 9060 XT llama 8B IQ3_S - 3.4375 bpw 4 pp2048 146.27 145.75 1.00
RX 9060 XT llama 8B IQ3_S - 3.4375 bpw 8 pp2048 178.68 178.46 1.00
RX 9060 XT llama 8B IQ3_S - 3.4375 bpw 16 pp2048 447.56 448.15 1.00
RX 9060 XT llama 8B IQ3_S - 3.4375 bpw 32 pp2048 628.32 631.06 1.00
RX 9060 XT llama 8B IQ3_S - 3.4375 bpw 64 pp2048 1187.46 1177.83 0.99
RX 9060 XT llama 8B IQ3_S - 3.4375 bpw 128 pp2048 1762.23 1757.89 1.00
RX 9060 XT llama 8B IQ3_S - 3.4375 bpw 256 pp2048 1900.47 1899.32 1.00
RX 9060 XT llama 8B IQ3_S - 3.4375 bpw 512 pp2048 1938.78 1935.11 1.00
RX 9060 XT llama 8B IQ3_S - 3.4375 bpw 1024 pp2048 1970.33 1963.70 1.00
RX 9060 XT llama 8B IQ3_S - 3.4375 bpw 2048 pp2048 1932.46 1930.39 1.00
RX 9060 XT llama 8B IQ3_S mix - 3.66 bpw 1 pp2048 45.90 46.01 1.00
RX 9060 XT llama 8B IQ3_S mix - 3.66 bpw 2 pp2048 84.59 85.00 1.00
RX 9060 XT llama 8B IQ3_S mix - 3.66 bpw 4 pp2048 143.48 143.85 1.00
RX 9060 XT llama 8B IQ3_S mix - 3.66 bpw 8 pp2048 174.65 175.03 1.00
RX 9060 XT llama 8B IQ3_S mix - 3.66 bpw 16 pp2048 453.40 454.11 1.00
RX 9060 XT llama 8B IQ3_S mix - 3.66 bpw 32 pp2048 643.74 644.38 1.00
RX 9060 XT llama 8B IQ3_S mix - 3.66 bpw 64 pp2048 1196.28 1199.68 1.00
RX 9060 XT llama 8B IQ3_S mix - 3.66 bpw 128 pp2048 1751.81 1756.49 1.00
RX 9060 XT llama 8B IQ3_S mix - 3.66 bpw 256 pp2048 1892.27 1899.63 1.00
RX 9060 XT llama 8B IQ3_S mix - 3.66 bpw 512 pp2048 1931.79 1936.89 1.00
RX 9060 XT llama 8B IQ3_S mix - 3.66 bpw 1024 pp2048 1963.12 1967.22 1.00
RX 9060 XT llama 8B IQ3_S mix - 3.66 bpw 2048 pp2048 1921.56 1932.94 1.01
RX 9060 XT llama 8B IQ3_XS - 3.3 bpw 1 pp2048 51.76 51.55 1.00
RX 9060 XT llama 8B IQ3_XS - 3.3 bpw 2 pp2048 93.53 93.13 1.00
RX 9060 XT llama 8B IQ3_XS - 3.3 bpw 4 pp2048 149.03 148.21 0.99
RX 9060 XT llama 8B IQ3_XS - 3.3 bpw 8 pp2048 180.44 180.01 1.00
RX 9060 XT llama 8B IQ3_XS - 3.3 bpw 16 pp2048 474.93 474.19 1.00
RX 9060 XT llama 8B IQ3_XS - 3.3 bpw 32 pp2048 658.92 657.57 1.00
RX 9060 XT llama 8B IQ3_XS - 3.3 bpw 64 pp2048 1218.19 1216.75 1.00
RX 9060 XT llama 8B IQ3_XS - 3.3 bpw 128 pp2048 1810.12 1797.37 0.99
RX 9060 XT llama 8B IQ3_XS - 3.3 bpw 256 pp2048 1959.46 1947.81 0.99
RX 9060 XT llama 8B IQ3_XS - 3.3 bpw 512 pp2048 2004.47 1992.24 0.99
RX 9060 XT llama 8B IQ3_XS - 3.3 bpw 1024 pp2048 2038.21 2026.73 0.99
RX 9060 XT llama 8B IQ3_XS - 3.3 bpw 2048 pp2048 1998.51 1988.25 0.99
RX 9060 XT llama 8B IQ3_XXS - 3.0625 bpw 1 pp2048 54.13 54.40 1.01
RX 9060 XT llama 8B IQ3_XXS - 3.0625 bpw 2 pp2048 94.62 95.23 1.01
RX 9060 XT llama 8B IQ3_XXS - 3.0625 bpw 4 pp2048 146.78 147.84 1.01
RX 9060 XT llama 8B IQ3_XXS - 3.0625 bpw 8 pp2048 182.08 182.64 1.00
RX 9060 XT llama 8B IQ3_XXS - 3.0625 bpw 16 pp2048 461.53 462.13 1.00
RX 9060 XT llama 8B IQ3_XXS - 3.0625 bpw 32 pp2048 673.79 675.92 1.00
RX 9060 XT llama 8B IQ3_XXS - 3.0625 bpw 64 pp2048 1226.37 1256.54 1.02
RX 9060 XT llama 8B IQ3_XXS - 3.0625 bpw 128 pp2048 1780.26 1793.39 1.01
RX 9060 XT llama 8B IQ3_XXS - 3.0625 bpw 256 pp2048 1937.10 1378.50 0.71
RX 9060 XT llama 8B IQ3_XXS - 3.0625 bpw 512 pp2048 1984.06 1410.04 0.71
RX 9060 XT llama 8B IQ3_XXS - 3.0625 bpw 1024 pp2048 2017.33 1441.88 0.71
RX 9060 XT llama 8B IQ3_XXS - 3.0625 bpw 2048 pp2048 1982.60 1431.47 0.72
RX 9060 XT llama 8B IQ4_NL - 4.5 bpw 1 pp2048 48.54 48.63 1.00
RX 9060 XT llama 8B IQ4_NL - 4.5 bpw 2 pp2048 92.22 92.32 1.00
RX 9060 XT llama 8B IQ4_NL - 4.5 bpw 4 pp2048 168.12 168.41 1.00
RX 9060 XT llama 8B IQ4_NL - 4.5 bpw 8 pp2048 193.58 193.93 1.00
RX 9060 XT llama 8B IQ4_NL - 4.5 bpw 16 pp2048 561.51 561.91 1.00
RX 9060 XT llama 8B IQ4_NL - 4.5 bpw 32 pp2048 782.72 783.87 1.00
RX 9060 XT llama 8B IQ4_NL - 4.5 bpw 64 pp2048 1389.64 1375.05 0.99
RX 9060 XT llama 8B IQ4_NL - 4.5 bpw 128 pp2048 1956.00 1994.64 1.02
RX 9060 XT llama 8B IQ4_NL - 4.5 bpw 256 pp2048 2169.17 2172.63 1.00
RX 9060 XT llama 8B IQ4_NL - 4.5 bpw 512 pp2048 2215.35 2217.34 1.00
RX 9060 XT llama 8B IQ4_NL - 4.5 bpw 1024 pp2048 2258.26 2262.33 1.00
RX 9060 XT llama 8B IQ4_NL - 4.5 bpw 2048 pp2048 2216.37 2219.76 1.00
RX 9060 XT llama 8B IQ4_XS - 4.25 bpw 1 pp2048 50.52 50.47 1.00
RX 9060 XT llama 8B IQ4_XS - 4.25 bpw 2 pp2048 96.60 96.69 1.00
RX 9060 XT llama 8B IQ4_XS - 4.25 bpw 4 pp2048 179.98 179.72 1.00
RX 9060 XT llama 8B IQ4_XS - 4.25 bpw 8 pp2048 208.33 207.92 1.00
RX 9060 XT llama 8B IQ4_XS - 4.25 bpw 16 pp2048 595.02 594.96 1.00
RX 9060 XT llama 8B IQ4_XS - 4.25 bpw 32 pp2048 797.95 797.78 1.00
RX 9060 XT llama 8B IQ4_XS - 4.25 bpw 64 pp2048 1437.76 1436.09 1.00
RX 9060 XT llama 8B IQ4_XS - 4.25 bpw 128 pp2048 2035.66 2041.45 1.00
RX 9060 XT llama 8B IQ4_XS - 4.25 bpw 256 pp2048 2238.64 2228.77 1.00
RX 9060 XT llama 8B IQ4_XS - 4.25 bpw 512 pp2048 2281.43 2271.05 1.00
RX 9060 XT llama 8B IQ4_XS - 4.25 bpw 1024 pp2048 2322.52 2316.75 1.00
RX 9060 XT llama 8B IQ4_XS - 4.25 bpw 2048 pp2048 2285.98 2276.98 1.00
RX 9060 XT llama 8B Q2_K_S 1 pp2048 67.29 67.27 1.00
RX 9060 XT llama 8B Q2_K_S 2 pp2048 105.20 104.06 0.99
RX 9060 XT llama 8B Q2_K_S 4 pp2048 128.24 126.94 0.99
RX 9060 XT llama 8B Q2_K_S 8 pp2048 142.94 141.15 0.99
RX 9060 XT llama 8B Q2_K_S 16 pp2048 347.71 345.56 0.99
RX 9060 XT llama 8B Q2_K_S 32 pp2048 462.14 459.29 0.99
RX 9060 XT llama 8B Q2_K_S 64 pp2048 740.70 738.53 1.00
RX 9060 XT llama 8B Q2_K_S 128 pp2048 905.54 899.19 0.99
RX 9060 XT llama 8B Q2_K_S 256 pp2048 978.28 348.91 0.36
RX 9060 XT llama 8B Q2_K_S 512 pp2048 964.15 361.71 0.38
RX 9060 XT llama 8B Q2_K_S 1024 pp2048 1022.04 369.67 0.36
RX 9060 XT llama 8B Q2_K_S 2048 pp2048 1015.14 371.65 0.37
RX 9060 XT llama 8B Q3_K_S 1 pp2048 50.23 50.03 1.00
RX 9060 XT llama 8B Q3_K_S 2 pp2048 83.84 83.80 1.00
RX 9060 XT llama 8B Q3_K_S 4 pp2048 119.53 119.01 1.00
RX 9060 XT llama 8B Q3_K_S 8 pp2048 143.45 142.72 0.99
RX 9060 XT llama 8B Q3_K_S 16 pp2048 456.60 455.65 1.00
RX 9060 XT llama 8B Q3_K_S 32 pp2048 666.93 662.43 0.99
RX 9060 XT llama 8B Q3_K_S 64 pp2048 1142.25 1141.32 1.00
RX 9060 XT llama 8B Q3_K_S 128 pp2048 1617.95 1608.82 0.99
RX 9060 XT llama 8B Q3_K_S 256 pp2048 1740.89 1714.79 0.99
RX 9060 XT llama 8B Q3_K_S 512 pp2048 1775.65 1766.64 0.99
RX 9060 XT llama 8B Q3_K_S 1024 pp2048 1809.70 1800.47 0.99
RX 9060 XT llama 8B Q3_K_S 2048 pp2048 1796.37 1784.52 0.99
RX 9060 XT llama 8B Q4_0 1 pp2048 48.80 48.83 1.00
RX 9060 XT llama 8B Q4_0 2 pp2048 92.57 92.57 1.00
RX 9060 XT llama 8B Q4_0 4 pp2048 169.50 169.56 1.00
RX 9060 XT llama 8B Q4_0 8 pp2048 200.24 200.41 1.00
RX 9060 XT llama 8B Q4_0 16 pp2048 554.36 552.12 1.00
RX 9060 XT llama 8B Q4_0 32 pp2048 758.38 758.83 1.00
RX 9060 XT llama 8B Q4_0 64 pp2048 1355.11 1349.07 1.00
RX 9060 XT llama 8B Q4_0 128 pp2048 1973.03 1974.87 1.00
RX 9060 XT llama 8B Q4_0 256 pp2048 2142.05 2145.10 1.00
RX 9060 XT llama 8B Q4_0 512 pp2048 2187.72 2188.37 1.00
RX 9060 XT llama 8B Q4_0 1024 pp2048 2237.77 2240.40 1.00
RX 9060 XT llama 8B Q4_0 2048 pp2048 2210.63 2211.25 1.00
RX 9060 XT llama 8B Q4_1 1 pp2048 46.14 46.13 1.00
RX 9060 XT llama 8B Q4_1 2 pp2048 87.36 87.37 1.00
RX 9060 XT llama 8B Q4_1 4 pp2048 161.56 161.47 1.00
RX 9060 XT llama 8B Q4_1 8 pp2048 212.05 211.77 1.00
RX 9060 XT llama 8B Q4_1 16 pp2048 554.58 551.61 0.99
RX 9060 XT llama 8B Q4_1 32 pp2048 763.75 763.82 1.00
RX 9060 XT llama 8B Q4_1 64 pp2048 1316.07 1314.47 1.00
RX 9060 XT llama 8B Q4_1 128 pp2048 1709.67 1704.02 1.00
RX 9060 XT llama 8B Q4_1 256 pp2048 1840.83 1836.41 1.00
RX 9060 XT llama 8B Q4_1 512 pp2048 1875.75 1850.53 0.99
RX 9060 XT llama 8B Q4_1 1024 pp2048 1909.05 1903.01 1.00
RX 9060 XT llama 8B Q4_1 2048 pp2048 1886.32 1882.86 1.00
RX 9060 XT llama 8B Q4_K_S 1 pp2048 48.96 49.02 1.00
RX 9060 XT llama 8B Q4_K_S 2 pp2048 89.77 89.66 1.00
RX 9060 XT llama 8B Q4_K_S 4 pp2048 132.20 131.90 1.00
RX 9060 XT llama 8B Q4_K_S 8 pp2048 148.52 148.43 1.00
RX 9060 XT llama 8B Q4_K_S 16 pp2048 522.11 522.04 1.00
RX 9060 XT llama 8B Q4_K_S 32 pp2048 736.98 737.38 1.00
RX 9060 XT llama 8B Q4_K_S 64 pp2048 1269.63 1270.78 1.00
RX 9060 XT llama 8B Q4_K_S 128 pp2048 1735.84 1734.80 1.00
RX 9060 XT llama 8B Q4_K_S 256 pp2048 1872.27 1846.87 0.99
RX 9060 XT llama 8B Q4_K_S 512 pp2048 1903.24 1901.62 1.00
RX 9060 XT llama 8B Q4_K_S 1024 pp2048 1938.90 1937.89 1.00
RX 9060 XT llama 8B Q4_K_S 2048 pp2048 1918.55 1916.50 1.00
RX 9060 XT llama 8B Q5_0 1 pp2048 43.53 43.58 1.00
RX 9060 XT llama 8B Q5_0 2 pp2048 82.54 82.59 1.00
RX 9060 XT llama 8B Q5_0 4 pp2048 150.46 150.69 1.00
RX 9060 XT llama 8B Q5_0 8 pp2048 195.36 195.40 1.00
RX 9060 XT llama 8B Q5_0 16 pp2048 484.35 484.93 1.00
RX 9060 XT llama 8B Q5_0 32 pp2048 683.42 684.76 1.00
RX 9060 XT llama 8B Q5_0 64 pp2048 1218.36 1222.12 1.00
RX 9060 XT llama 8B Q5_0 128 pp2048 1844.04 1846.53 1.00
RX 9060 XT llama 8B Q5_0 256 pp2048 2001.99 1997.67 1.00
RX 9060 XT llama 8B Q5_0 512 pp2048 2035.08 2032.12 1.00
RX 9060 XT llama 8B Q5_0 1024 pp2048 2073.51 2071.38 1.00
RX 9060 XT llama 8B Q5_0 2048 pp2048 2048.50 2042.04 1.00
RX 9060 XT llama 8B Q5_1 1 pp2048 41.40 41.50 1.00
RX 9060 XT llama 8B Q5_1 2 pp2048 77.02 77.00 1.00
RX 9060 XT llama 8B Q5_1 4 pp2048 144.96 144.90 1.00
RX 9060 XT llama 8B Q5_1 8 pp2048 229.51 229.47 1.00
RX 9060 XT llama 8B Q5_1 16 pp2048 406.97 407.04 1.00
RX 9060 XT llama 8B Q5_1 32 pp2048 604.87 604.56 1.00
RX 9060 XT llama 8B Q5_1 64 pp2048 1120.09 1119.52 1.00
RX 9060 XT llama 8B Q5_1 128 pp2048 1611.29 1610.55 1.00
RX 9060 XT llama 8B Q5_1 256 pp2048 1753.69 1752.39 1.00
RX 9060 XT llama 8B Q5_1 512 pp2048 1786.09 1796.75 1.01
RX 9060 XT llama 8B Q5_1 1024 pp2048 1821.42 1832.46 1.01
RX 9060 XT llama 8B Q5_1 2048 pp2048 1804.08 1814.60 1.01
RX 9060 XT llama 8B Q5_K_S 1 pp2048 43.83 43.88 1.00
RX 9060 XT llama 8B Q5_K_S 2 pp2048 80.23 80.47 1.00
RX 9060 XT llama 8B Q5_K_S 4 pp2048 127.81 128.12 1.00
RX 9060 XT llama 8B Q5_K_S 8 pp2048 145.77 146.13 1.00
RX 9060 XT llama 8B Q5_K_S 16 pp2048 526.08 526.38 1.00
RX 9060 XT llama 8B Q5_K_S 32 pp2048 734.17 734.97 1.00
RX 9060 XT llama 8B Q5_K_S 64 pp2048 1268.87 1272.27 1.00
RX 9060 XT llama 8B Q5_K_S 128 pp2048 1701.48 1706.23 1.00
RX 9060 XT llama 8B Q5_K_S 256 pp2048 1836.02 1839.31 1.00
RX 9060 XT llama 8B Q5_K_S 512 pp2048 1872.70 1875.40 1.00
RX 9060 XT llama 8B Q5_K_S 1024 pp2048 1908.40 1908.81 1.00
RX 9060 XT llama 8B Q5_K_S 2048 pp2048 1887.29 1888.04 1.00
RX 9060 XT llama 8B Q6_K 1 pp2048 38.88 38.99 1.00
RX 9060 XT llama 8B Q6_K 2 pp2048 73.73 73.87 1.00
RX 9060 XT llama 8B Q6_K 4 pp2048 125.21 125.35 1.00
RX 9060 XT llama 8B Q6_K 8 pp2048 155.97 156.21 1.00
RX 9060 XT llama 8B Q6_K 16 pp2048 423.32 424.10 1.00
RX 9060 XT llama 8B Q6_K 32 pp2048 572.10 572.61 1.00
RX 9060 XT llama 8B Q6_K 64 pp2048 890.37 883.60 0.99
RX 9060 XT llama 8B Q6_K 128 pp2048 1092.09 1089.40 1.00
RX 9060 XT llama 8B Q6_K 256 pp2048 1178.14 327.99 0.28
RX 9060 XT llama 8B Q6_K 512 pp2048 1191.51 342.57 0.29
RX 9060 XT llama 8B Q6_K 1024 pp2048 1210.27 350.61 0.29
RX 9060 XT llama 8B Q6_K 2048 pp2048 1200.71 352.17 0.29
RX 9060 XT llama 8B Q8_0 1 pp2048 32.86 32.88 1.00
RX 9060 XT llama 8B Q8_0 2 pp2048 61.02 61.09 1.00
RX 9060 XT llama 8B Q8_0 4 pp2048 115.13 115.26 1.00
RX 9060 XT llama 8B Q8_0 8 pp2048 182.72 182.61 1.00
RX 9060 XT llama 8B Q8_0 16 pp2048 428.24 427.08 1.00
RX 9060 XT llama 8B Q8_0 32 pp2048 664.97 663.10 1.00
RX 9060 XT llama 8B Q8_0 64 pp2048 1202.56 1201.90 1.00
RX 9060 XT llama 8B Q8_0 128 pp2048 1869.15 1862.62 1.00
RX 9060 XT llama 8B Q8_0 256 pp2048 2049.97 2045.47 1.00
RX 9060 XT llama 8B Q8_0 512 pp2048 2108.77 2099.31 1.00
RX 9060 XT llama 8B Q8_0 1024 pp2048 2177.10 2161.62 0.99
RX 9060 XT llama 8B Q8_0 2048 pp2048 2155.69 2142.31 0.99
On RDNA4 this PR consistently makes performance worse, so ggml_cuda_should_use_mmq should return true for that architecture.

I think we can use the configs from #18442 for RDNA4 for now, since those don't regress performance.

RDNA4 performance for #18442
GPU Model Microbatch size Test t/s 3d26a09 t/s master Speedup
AI PRO R9700 llama 8B IQ1_S - 1.5625 bpw 1 pp2048 117.13 117.22 1.00
AI PRO R9700 llama 8B IQ1_S - 1.5625 bpw 2 pp2048 185.31 185.74 1.00
AI PRO R9700 llama 8B IQ1_S - 1.5625 bpw 4 pp2048 283.10 283.71 1.00
AI PRO R9700 llama 8B IQ1_S - 1.5625 bpw 8 pp2048 422.37 422.63 1.00
AI PRO R9700 llama 8B IQ1_S - 1.5625 bpw 16 pp2048 871.29 873.23 1.00
AI PRO R9700 llama 8B IQ1_S - 1.5625 bpw 32 pp2048 46.22 47.71 1.03
AI PRO R9700 llama 8B IQ1_S - 1.5625 bpw 64 pp2048 91.78 95.72 1.04
AI PRO R9700 llama 8B IQ1_S - 1.5625 bpw 128 pp2048 183.88 189.12 1.03
AI PRO R9700 llama 8B IQ1_S - 1.5625 bpw 256 pp2048 362.44 373.27 1.03
AI PRO R9700 llama 8B IQ1_S - 1.5625 bpw 512 pp2048 708.48 725.04 1.02
AI PRO R9700 llama 8B IQ1_S - 1.5625 bpw 1024 pp2048 1300.47 1329.14 1.02
AI PRO R9700 llama 8B IQ1_S - 1.5625 bpw 2048 pp2048 1889.42 1943.09 1.03
AI PRO R9700 llama 8B IQ2_S - 2.5 bpw 1 pp2048 91.92 91.88 1.00
AI PRO R9700 llama 8B IQ2_S - 2.5 bpw 2 pp2048 148.07 147.92 1.00
AI PRO R9700 llama 8B IQ2_S - 2.5 bpw 4 pp2048 233.27 233.02 1.00
AI PRO R9700 llama 8B IQ2_S - 2.5 bpw 8 pp2048 394.81 393.95 1.00
AI PRO R9700 llama 8B IQ2_S - 2.5 bpw 16 pp2048 519.61 520.39 1.00
AI PRO R9700 llama 8B IQ2_S - 2.5 bpw 32 pp2048 47.16 47.86 1.02
AI PRO R9700 llama 8B IQ2_S - 2.5 bpw 64 pp2048 93.42 95.34 1.02
AI PRO R9700 llama 8B IQ2_S - 2.5 bpw 128 pp2048 186.38 190.26 1.02
AI PRO R9700 llama 8B IQ2_S - 2.5 bpw 256 pp2048 370.48 367.87 0.99
AI PRO R9700 llama 8B IQ2_S - 2.5 bpw 512 pp2048 715.88 729.76 1.02
AI PRO R9700 llama 8B IQ2_S - 2.5 bpw 1024 pp2048 1318.86 1323.85 1.00
AI PRO R9700 llama 8B IQ2_S - 2.5 bpw 2048 pp2048 1718.81 1872.13 1.09
AI PRO R9700 llama 8B IQ2_XS - 2.3125 bpw 1 pp2048 94.72 94.91 1.00
AI PRO R9700 llama 8B IQ2_XS - 2.3125 bpw 2 pp2048 150.32 150.34 1.00
AI PRO R9700 llama 8B IQ2_XS - 2.3125 bpw 4 pp2048 233.39 233.48 1.00
AI PRO R9700 llama 8B IQ2_XS - 2.3125 bpw 8 pp2048 394.72 394.73 1.00
AI PRO R9700 llama 8B IQ2_XS - 2.3125 bpw 16 pp2048 502.09 502.86 1.00
AI PRO R9700 llama 8B IQ2_XS - 2.3125 bpw 32 pp2048 47.12 46.59 0.99
AI PRO R9700 llama 8B IQ2_XS - 2.3125 bpw 64 pp2048 94.61 95.44 1.01
AI PRO R9700 llama 8B IQ2_XS - 2.3125 bpw 128 pp2048 188.99 190.11 1.01
AI PRO R9700 llama 8B IQ2_XS - 2.3125 bpw 256 pp2048 363.15 376.11 1.04
AI PRO R9700 llama 8B IQ2_XS - 2.3125 bpw 512 pp2048 722.25 737.82 1.02
AI PRO R9700 llama 8B IQ2_XS - 2.3125 bpw 1024 pp2048 1323.17 1322.69 1.00
AI PRO R9700 llama 8B IQ2_XS - 2.3125 bpw 2048 pp2048 1717.85 1891.75 1.10
AI PRO R9700 llama 8B IQ2_XXS - 2.0625 bpw 1 pp2048 80.72 80.83 1.00
AI PRO R9700 llama 8B IQ2_XXS - 2.0625 bpw 2 pp2048 136.07 136.22 1.00
AI PRO R9700 llama 8B IQ2_XXS - 2.0625 bpw 4 pp2048 227.75 228.08 1.00
AI PRO R9700 llama 8B IQ2_XXS - 2.0625 bpw 8 pp2048 346.25 346.98 1.00
AI PRO R9700 llama 8B IQ2_XXS - 2.0625 bpw 16 pp2048 732.49 733.20 1.00
AI PRO R9700 llama 8B IQ2_XXS - 2.0625 bpw 32 pp2048 46.63 47.72 1.02
AI PRO R9700 llama 8B IQ2_XXS - 2.0625 bpw 64 pp2048 93.04 95.69 1.03
AI PRO R9700 llama 8B IQ2_XXS - 2.0625 bpw 128 pp2048 185.70 191.14 1.03
AI PRO R9700 llama 8B IQ2_XXS - 2.0625 bpw 256 pp2048 365.50 375.29 1.03
AI PRO R9700 llama 8B IQ2_XXS - 2.0625 bpw 512 pp2048 711.39 726.73 1.02
AI PRO R9700 llama 8B IQ2_XXS - 2.0625 bpw 1024 pp2048 1313.40 1306.95 1.00
AI PRO R9700 llama 8B IQ2_XXS - 2.0625 bpw 2048 pp2048 1933.84 1991.03 1.03
AI PRO R9700 llama 8B IQ3_S - 3.4375 bpw 1 pp2048 76.34 76.25 1.00
AI PRO R9700 llama 8B IQ3_S - 3.4375 bpw 2 pp2048 132.08 132.14 1.00
AI PRO R9700 llama 8B IQ3_S - 3.4375 bpw 4 pp2048 224.23 224.37 1.00
AI PRO R9700 llama 8B IQ3_S - 3.4375 bpw 8 pp2048 347.83 347.80 1.00
AI PRO R9700 llama 8B IQ3_S - 3.4375 bpw 16 pp2048 701.95 701.32 1.00
AI PRO R9700 llama 8B IQ3_S - 3.4375 bpw 32 pp2048 46.91 47.89 1.02
AI PRO R9700 llama 8B IQ3_S - 3.4375 bpw 64 pp2048 93.46 95.57 1.02
AI PRO R9700 llama 8B IQ3_S - 3.4375 bpw 128 pp2048 186.47 189.85 1.02
AI PRO R9700 llama 8B IQ3_S - 3.4375 bpw 256 pp2048 368.08 375.31 1.02
AI PRO R9700 llama 8B IQ3_S - 3.4375 bpw 512 pp2048 710.79 722.67 1.02
AI PRO R9700 llama 8B IQ3_S - 3.4375 bpw 1024 pp2048 1318.61 1279.79 0.97
AI PRO R9700 llama 8B IQ3_S - 3.4375 bpw 2048 pp2048 1906.58 1951.80 1.02
AI PRO R9700 llama 8B IQ3_S mix - 3.66 bpw 1 pp2048 76.74 76.79 1.00
AI PRO R9700 llama 8B IQ3_S mix - 3.66 bpw 2 pp2048 132.51 132.75 1.00
AI PRO R9700 llama 8B IQ3_S mix - 3.66 bpw 4 pp2048 221.05 221.09 1.00
AI PRO R9700 llama 8B IQ3_S mix - 3.66 bpw 8 pp2048 335.05 335.69 1.00
AI PRO R9700 llama 8B IQ3_S mix - 3.66 bpw 16 pp2048 715.97 716.38 1.00
AI PRO R9700 llama 8B IQ3_S mix - 3.66 bpw 32 pp2048 46.85 46.18 0.99
AI PRO R9700 llama 8B IQ3_S mix - 3.66 bpw 64 pp2048 93.11 87.72 0.94
AI PRO R9700 llama 8B IQ3_S mix - 3.66 bpw 128 pp2048 187.11 179.12 0.96
AI PRO R9700 llama 8B IQ3_S mix - 3.66 bpw 256 pp2048 369.52 344.78 0.93
AI PRO R9700 llama 8B IQ3_S mix - 3.66 bpw 512 pp2048 709.89 667.98 0.94
AI PRO R9700 llama 8B IQ3_S mix - 3.66 bpw 1024 pp2048 1326.72 1233.21 0.93
AI PRO R9700 llama 8B IQ3_S mix - 3.66 bpw 2048 pp2048 1907.75 1872.50 0.98
AI PRO R9700 llama 8B IQ3_XS - 3.3 bpw 1 pp2048 83.16 82.94 1.00
AI PRO R9700 llama 8B IQ3_XS - 3.3 bpw 2 pp2048 138.52 138.50 1.00
AI PRO R9700 llama 8B IQ3_XS - 3.3 bpw 4 pp2048 227.93 227.71 1.00
AI PRO R9700 llama 8B IQ3_XS - 3.3 bpw 8 pp2048 354.66 354.86 1.00
AI PRO R9700 llama 8B IQ3_XS - 3.3 bpw 16 pp2048 745.73 746.07 1.00
AI PRO R9700 llama 8B IQ3_XS - 3.3 bpw 32 pp2048 47.23 45.70 0.97
AI PRO R9700 llama 8B IQ3_XS - 3.3 bpw 64 pp2048 94.32 89.24 0.95
AI PRO R9700 llama 8B IQ3_XS - 3.3 bpw 128 pp2048 187.85 176.97 0.94
AI PRO R9700 llama 8B IQ3_XS - 3.3 bpw 256 pp2048 373.90 350.98 0.94
AI PRO R9700 llama 8B IQ3_XS - 3.3 bpw 512 pp2048 716.79 660.41 0.92
AI PRO R9700 llama 8B IQ3_XS - 3.3 bpw 1024 pp2048 1317.91 1171.80 0.89
AI PRO R9700 llama 8B IQ3_XS - 3.3 bpw 2048 pp2048 1966.78 1888.51 0.96
AI PRO R9700 llama 8B IQ3_XXS - 3.0625 bpw 1 pp2048 89.57 89.86 1.00
AI PRO R9700 llama 8B IQ3_XXS - 3.0625 bpw 2 pp2048 144.31 143.95 1.00
AI PRO R9700 llama 8B IQ3_XXS - 3.0625 bpw 4 pp2048 230.10 230.12 1.00
AI PRO R9700 llama 8B IQ3_XXS - 3.0625 bpw 8 pp2048 358.63 358.55 1.00
AI PRO R9700 llama 8B IQ3_XXS - 3.0625 bpw 16 pp2048 695.21 695.59 1.00
AI PRO R9700 llama 8B IQ3_XXS - 3.0625 bpw 32 pp2048 46.98 47.66 1.01
AI PRO R9700 llama 8B IQ3_XXS - 3.0625 bpw 64 pp2048 94.44 95.18 1.01
AI PRO R9700 llama 8B IQ3_XXS - 3.0625 bpw 128 pp2048 187.91 190.09 1.01
AI PRO R9700 llama 8B IQ3_XXS - 3.0625 bpw 256 pp2048 369.12 372.49 1.01
AI PRO R9700 llama 8B IQ3_XXS - 3.0625 bpw 512 pp2048 711.76 720.63 1.01
AI PRO R9700 llama 8B IQ3_XXS - 3.0625 bpw 1024 pp2048 1316.19 1329.32 1.01
AI PRO R9700 llama 8B IQ3_XXS - 3.0625 bpw 2048 pp2048 1931.99 1977.33 1.02
AI PRO R9700 llama 8B IQ4_NL - 4.5 bpw 1 pp2048 89.44 89.63 1.00
AI PRO R9700 llama 8B IQ4_NL - 4.5 bpw 2 pp2048 155.17 155.13 1.00
AI PRO R9700 llama 8B IQ4_NL - 4.5 bpw 4 pp2048 270.99 271.03 1.00
AI PRO R9700 llama 8B IQ4_NL - 4.5 bpw 8 pp2048 409.77 409.51 1.00
AI PRO R9700 llama 8B IQ4_NL - 4.5 bpw 16 pp2048 925.84 925.63 1.00
AI PRO R9700 llama 8B IQ4_NL - 4.5 bpw 32 pp2048 46.78 47.10 1.01
AI PRO R9700 llama 8B IQ4_NL - 4.5 bpw 64 pp2048 93.82 94.64 1.01
AI PRO R9700 llama 8B IQ4_NL - 4.5 bpw 128 pp2048 187.41 188.39 1.01
AI PRO R9700 llama 8B IQ4_NL - 4.5 bpw 256 pp2048 366.99 372.49 1.02
AI PRO R9700 llama 8B IQ4_NL - 4.5 bpw 512 pp2048 714.44 723.02 1.01
AI PRO R9700 llama 8B IQ4_NL - 4.5 bpw 1024 pp2048 1318.86 1324.90 1.00
AI PRO R9700 llama 8B IQ4_NL - 4.5 bpw 2048 pp2048 2078.02 2056.59 0.99
AI PRO R9700 llama 8B IQ4_XS - 4.25 bpw 1 pp2048 92.87 92.64 1.00
AI PRO R9700 llama 8B IQ4_XS - 4.25 bpw 2 pp2048 158.84 158.77 1.00
AI PRO R9700 llama 8B IQ4_XS - 4.25 bpw 4 pp2048 281.55 281.59 1.00
AI PRO R9700 llama 8B IQ4_XS - 4.25 bpw 8 pp2048 446.83 447.23 1.00
AI PRO R9700 llama 8B IQ4_XS - 4.25 bpw 16 pp2048 967.83 969.20 1.00
AI PRO R9700 llama 8B IQ4_XS - 4.25 bpw 32 pp2048 47.04 47.96 1.02
AI PRO R9700 llama 8B IQ4_XS - 4.25 bpw 64 pp2048 94.15 96.21 1.02
AI PRO R9700 llama 8B IQ4_XS - 4.25 bpw 128 pp2048 188.06 191.48 1.02
AI PRO R9700 llama 8B IQ4_XS - 4.25 bpw 256 pp2048 368.28 375.42 1.02
AI PRO R9700 llama 8B IQ4_XS - 4.25 bpw 512 pp2048 712.03 731.48 1.03
AI PRO R9700 llama 8B IQ4_XS - 4.25 bpw 1024 pp2048 1330.28 1352.03 1.02
AI PRO R9700 llama 8B IQ4_XS - 4.25 bpw 2048 pp2048 2079.70 2079.96 1.00
AI PRO R9700 llama 8B Q2_K_S 1 pp2048 108.42 108.62 1.00
AI PRO R9700 llama 8B Q2_K_S 2 pp2048 154.59 154.42 1.00
AI PRO R9700 llama 8B Q2_K_S 4 pp2048 205.33 204.77 1.00
AI PRO R9700 llama 8B Q2_K_S 8 pp2048 263.23 261.85 0.99
AI PRO R9700 llama 8B Q2_K_S 16 pp2048 518.48 518.71 1.00
AI PRO R9700 llama 8B Q2_K_S 32 pp2048 46.73 47.71 1.02
AI PRO R9700 llama 8B Q2_K_S 64 pp2048 93.58 94.57 1.01
AI PRO R9700 llama 8B Q2_K_S 128 pp2048 186.26 190.68 1.02
AI PRO R9700 llama 8B Q2_K_S 256 pp2048 369.34 375.64 1.02
AI PRO R9700 llama 8B Q2_K_S 512 pp2048 710.69 718.36 1.01
AI PRO R9700 llama 8B Q2_K_S 1024 pp2048 1015.05 1161.67 1.14
AI PRO R9700 llama 8B Q2_K_S 2048 pp2048 1277.44 1666.11 1.30
AI PRO R9700 llama 8B Q3_K_S 1 pp2048 82.67 82.77 1.00
AI PRO R9700 llama 8B Q3_K_S 2 pp2048 132.90 133.13 1.00
AI PRO R9700 llama 8B Q3_K_S 4 pp2048 192.23 192.61 1.00
AI PRO R9700 llama 8B Q3_K_S 8 pp2048 264.89 265.16 1.00
AI PRO R9700 llama 8B Q3_K_S 16 pp2048 730.58 731.93 1.00
AI PRO R9700 llama 8B Q3_K_S 32 pp2048 47.31 48.05 1.02
AI PRO R9700 llama 8B Q3_K_S 64 pp2048 93.55 96.01 1.03
AI PRO R9700 llama 8B Q3_K_S 128 pp2048 188.23 191.37 1.02
AI PRO R9700 llama 8B Q3_K_S 256 pp2048 367.76 373.72 1.02
AI PRO R9700 llama 8B Q3_K_S 512 pp2048 720.35 727.09 1.01
AI PRO R9700 llama 8B Q3_K_S 1024 pp2048 1353.73 1364.06 1.01
AI PRO R9700 llama 8B Q3_K_S 2048 pp2048 1854.98 1967.88 1.06
AI PRO R9700 llama 8B Q4_0 1 pp2048 90.50 90.34 1.00
AI PRO R9700 llama 8B Q4_0 2 pp2048 155.41 155.24 1.00
AI PRO R9700 llama 8B Q4_0 4 pp2048 277.24 277.09 1.00
AI PRO R9700 llama 8B Q4_0 8 pp2048 425.29 425.22 1.00
AI PRO R9700 llama 8B Q4_0 16 pp2048 900.34 897.94 1.00
AI PRO R9700 llama 8B Q4_0 32 pp2048 47.05 46.77 0.99
AI PRO R9700 llama 8B Q4_0 64 pp2048 94.49 88.12 0.93
AI PRO R9700 llama 8B Q4_0 128 pp2048 187.28 175.56 0.94
AI PRO R9700 llama 8B Q4_0 256 pp2048 372.06 348.99 0.94
AI PRO R9700 llama 8B Q4_0 512 pp2048 724.72 677.10 0.93
AI PRO R9700 llama 8B Q4_0 1024 pp2048 1340.10 1265.08 0.94
AI PRO R9700 llama 8B Q4_0 2048 pp2048 2065.00 2007.58 0.97
AI PRO R9700 llama 8B Q4_1 1 pp2048 85.31 85.47 1.00
AI PRO R9700 llama 8B Q4_1 2 pp2048 149.33 148.79 1.00
AI PRO R9700 llama 8B Q4_1 4 pp2048 265.37 265.74 1.00
AI PRO R9700 llama 8B Q4_1 8 pp2048 449.38 449.31 1.00
AI PRO R9700 llama 8B Q4_1 16 pp2048 916.65 913.81 1.00
AI PRO R9700 llama 8B Q4_1 32 pp2048 47.10 48.03 1.02
AI PRO R9700 llama 8B Q4_1 64 pp2048 94.01 96.22 1.02
AI PRO R9700 llama 8B Q4_1 128 pp2048 188.70 190.80 1.01
AI PRO R9700 llama 8B Q4_1 256 pp2048 369.84 375.95 1.02
AI PRO R9700 llama 8B Q4_1 512 pp2048 724.45 732.01 1.01
AI PRO R9700 llama 8B Q4_1 1024 pp2048 1350.11 1330.57 0.99
AI PRO R9700 llama 8B Q4_1 2048 pp2048 1903.40 1910.75 1.00
AI PRO R9700 llama 8B Q4_K_S 1 pp2048 87.48 87.46 1.00
AI PRO R9700 llama 8B Q4_K_S 2 pp2048 144.48 144.22 1.00
AI PRO R9700 llama 8B Q4_K_S 4 pp2048 201.93 202.05 1.00
AI PRO R9700 llama 8B Q4_K_S 8 pp2048 269.03 269.59 1.00
AI PRO R9700 llama 8B Q4_K_S 16 pp2048 871.46 871.87 1.00
AI PRO R9700 llama 8B Q4_K_S 32 pp2048 47.20 47.95 1.02
AI PRO R9700 llama 8B Q4_K_S 64 pp2048 94.88 95.88 1.01
AI PRO R9700 llama 8B Q4_K_S 128 pp2048 188.73 190.45 1.01
AI PRO R9700 llama 8B Q4_K_S 256 pp2048 369.93 373.97 1.01
AI PRO R9700 llama 8B Q4_K_S 512 pp2048 720.35 732.51 1.02
AI PRO R9700 llama 8B Q4_K_S 1024 pp2048 1340.10 1348.95 1.01
AI PRO R9700 llama 8B Q4_K_S 2048 pp2048 1939.16 1973.47 1.02
AI PRO R9700 llama 8B Q5_1 1 pp2048 78.28 78.50 1.00
AI PRO R9700 llama 8B Q5_1 2 pp2048 134.97 135.73 1.01
AI PRO R9700 llama 8B Q5_1 4 pp2048 246.45 246.91 1.00
AI PRO R9700 llama 8B Q5_1 8 pp2048 471.96 473.76 1.00
AI PRO R9700 llama 8B Q5_1 16 pp2048 687.72 688.26 1.00
AI PRO R9700 llama 8B Q5_1 32 pp2048 47.13 46.83 0.99
AI PRO R9700 llama 8B Q5_1 64 pp2048 94.69 95.81 1.01
AI PRO R9700 llama 8B Q5_1 128 pp2048 188.05 190.66 1.01
AI PRO R9700 llama 8B Q5_1 256 pp2048 367.66 373.77 1.02
AI PRO R9700 llama 8B Q5_1 512 pp2048 718.57 728.89 1.01
AI PRO R9700 llama 8B Q5_1 1024 pp2048 1334.04 1360.73 1.02
AI PRO R9700 llama 8B Q5_1 2048 pp2048 1876.55 1880.14 1.00
AI PRO R9700 llama 8B Q5_K_S 1 pp2048 78.64 78.68 1.00
AI PRO R9700 llama 8B Q5_K_S 2 pp2048 133.93 134.41 1.00
AI PRO R9700 llama 8B Q5_K_S 4 pp2048 195.97 196.24 1.00
AI PRO R9700 llama 8B Q5_K_S 8 pp2048 263.84 263.91 1.00
AI PRO R9700 llama 8B Q5_K_S 16 pp2048 866.82 868.42 1.00
AI PRO R9700 llama 8B Q5_K_S 32 pp2048 47.07 47.50 1.01
AI PRO R9700 llama 8B Q5_K_S 64 pp2048 94.34 95.64 1.01
AI PRO R9700 llama 8B Q5_K_S 128 pp2048 188.76 190.76 1.01
AI PRO R9700 llama 8B Q5_K_S 256 pp2048 369.03 375.59 1.02
AI PRO R9700 llama 8B Q5_K_S 512 pp2048 725.97 733.59 1.01
AI PRO R9700 llama 8B Q5_K_S 1024 pp2048 1333.84 1361.75 1.02
AI PRO R9700 llama 8B Q5_K_S 2048 pp2048 1906.69 1977.18 1.04
AI PRO R9700 llama 8B Q6_K 1 pp2048 71.86 71.83 1.00
AI PRO R9700 llama 8B Q6_K 2 pp2048 124.49 124.48 1.00
AI PRO R9700 llama 8B Q6_K 4 pp2048 199.36 199.23 1.00
AI PRO R9700 llama 8B Q6_K 8 pp2048 295.95 295.70 1.00
AI PRO R9700 llama 8B Q6_K 16 pp2048 660.70 660.41 1.00
AI PRO R9700 llama 8B Q6_K 32 pp2048 47.02 48.29 1.03
AI PRO R9700 llama 8B Q6_K 64 pp2048 94.05 96.72 1.03
AI PRO R9700 llama 8B Q6_K 128 pp2048 188.36 192.51 1.02
AI PRO R9700 llama 8B Q6_K 256 pp2048 370.51 378.06 1.02
AI PRO R9700 llama 8B Q6_K 512 pp2048 733.15 744.20 1.02
AI PRO R9700 llama 8B Q6_K 1024 pp2048 1135.83 1245.92 1.10
AI PRO R9700 llama 8B Q6_K 2048 pp2048 1439.37 1747.67 1.21
AI PRO R9700 llama 8B Q8_0 1 pp2048 62.77 62.73 1.00
AI PRO R9700 llama 8B Q8_0 2 pp2048 113.93 113.64 1.00
AI PRO R9700 llama 8B Q8_0 4 pp2048 209.72 209.28 1.00
AI PRO R9700 llama 8B Q8_0 8 pp2048 391.31 391.90 1.00
AI PRO R9700 llama 8B Q8_0 16 pp2048 744.83 745.90 1.00
AI PRO R9700 llama 8B Q8_0 32 pp2048 47.08 45.94 0.98
AI PRO R9700 llama 8B Q8_0 64 pp2048 93.55 91.95 0.98
AI PRO R9700 llama 8B Q8_0 128 pp2048 187.51 182.84 0.98
AI PRO R9700 llama 8B Q8_0 256 pp2048 369.21 361.73 0.98
AI PRO R9700 llama 8B Q8_0 512 pp2048 716.55 703.59 0.98
AI PRO R9700 llama 8B Q8_0 1024 pp2048 1319.34 1299.29 0.98
AI PRO R9700 llama 8B Q8_0 2048 pp2048 2068.12 2034.75 0.98

@JohannesGaessler
Contributor

If you want to make further changes to the default kernel selection logic as it exists on master, open a new PR relative to master and post a benchmark showing how the change affects performance with the newest ROCm version and with no environment variables set. Any logic written for a specific ROCm version or for specific environment variables needs a corresponding check in the code.


Labels: ggml (changes relating to the ggml tensor library for machine learning), Nvidia GPU (issues specific to Nvidia GPUs)
