Skip to content

CUDA: route batch>=4 quantized matmul to MMQ on AMD MFMA hardware#23227

Merged
JohannesGaessler merged 3 commits into
ggml-org:masterfrom
jadenmach2:cdna-mmq-batch4
May 28, 2026
Merged

CUDA: route batch>=4 quantized matmul to MMQ on AMD MFMA hardware#23227
JohannesGaessler merged 3 commits into
ggml-org:masterfrom
jadenmach2:cdna-mmq-batch4

Conversation

@jadenmach2
Copy link
Copy Markdown
Contributor

@jadenmach2 jadenmach2 commented May 17, 2026

The dispatcher unconditionally prefers mul_mat_vec_q (per-row GEMV) over mul_mat_q (MFMA-tiled GEMM) for any quantized matmul with batch <= 8. On AMD CDNA the MFMA path is materially faster once the verify batch reaches 4, which is exactly where every form of speculative decoding lives (ngram, draft-model, native MTP).

This patch adds an AMD-MFMA-specific cap of 3, gated on amd_mfma_available(cc), with a GGML_CUDA_FORCE_MMVQ env-var escape hatch. Non-AMD-MFMA paths (NVIDIA, RDNA, CDNA1 without MFMA, CPU) are byte-identical.

Measured on MI250X (gfx90a, ROCm 7.2.1) with Qwen3.6-27B Q4_K_M:

llama-batched-bench, token-gen tok/s by batch:
B=1: 34.28 -> 34.14 (noise; same dispatcher path)
B=2: 50.97 -> 50.83 (noise)
B=3: 58.18 -> 58.30 (noise; last B where MMVQ stays)
B=4: 63.01 -> 64.76 (+3 %)
B=5: 63.94 -> 76.93 (+20 %)
B=6: 63.53 -> 88.29 (+39 %)
B=8: 66.84 -> 107.62 (+61 %)

llama-server, MTP --spec-draft-n-max sweep, end-to-end tok/s
on a 256-token deterministic completion:
n_max=2 (verify B=3): 47.24 -> 46.90 (noise)
n_max=3 (verify B=4): 42.87 -> 50.12 (+17 %)
n_max=4 (verify B=5): 39.68 -> 47.79 (+20 %)
n_max=5 (verify B=6): 36.43 -> 50.98 (+40 %)

llama-bench plain workloads (no-regression check):
pp512: 794.78 -> 795.01 t/s (+0.03 %)
tg128: 34.44 -> 34.02 t/s (-1.2 %, day-to-day noise)

rocprofv3 kernel trace of one MTP n=3 completion:
mul_mat_q (MFMA-tiled GEMM): 0 -> 8 955 calls
mul_mat_vec_q (per-row GEMV): 11 400 -> 2 445 calls

The baseline column above is reproduced from the same binary by exporting GGML_CUDA_FORCE_MMVQ=1.

Threshold value: empirically tuned at MMVQ_MAX_BATCH_SIZE_AMD_MFMA=3 on MI250X with Q4_K_M. Crossover sits between B=3 (where MMVQ wins ~13 %) and B=4 (where MMQ wins ~8 %). MI300X (CDNA3) has wider MFMA tiles so the crossover almost certainly moves earlier; threshold of 3 is therefore conservative-correct (we'd leave a few percent on the table at B=3 in the worst case, never regress below baseline).

Per-quant batch threshold sweep — MI250X / gfx90a

Setup: Llama-3.2-3B-Instruct, llama-bench -fa 1 -n 0 -ub 1..8 -r 10, single GCD,
ROCm 7.2.1, GGML_HIP_GRAPHS=ON on both builds.

Why per-quant?

A first attempt used a single global threshold of 3 (MMVQ → MMQ at batch ≥ 4) on
AMD MFMA hardware. The 20-quant sweep showed this regresses legacy and IQ quants
substantially while only K-quants benefit:

quant ub=1 ub=2 ub=3 ub=4 ub=5 ub=6 ub=7 ub=8
Q4_0 1.00x 1.00x 1.00x 0.68x 0.73x 0.81x 0.88x 0.99x
Q4_1 1.01x 1.04x 1.03x 0.70x 0.77x 0.84x 0.91x 1.01x
Q5_0 1.02x 1.01x 1.01x 0.66x 0.71x 0.76x 0.84x 0.90x
Q5_1 1.02x 1.03x 1.02x 0.70x 0.75x 0.83x 0.90x 0.98x
Q8_0 0.96x 1.00x 0.97x 0.66x 0.73x 0.79x 0.85x 0.95x
Q2_K 0.99x 0.98x 1.00x 0.92x 0.98x 1.16x 1.23x 1.31x
Q3_K_S 1.02x 1.00x 1.00x 1.08x 0.96x 1.39x 1.33x 1.40x
Q4_K_S 0.99x 0.99x 0.99x 1.07x 1.20x 1.35x 1.51x 1.67x
Q5_K_S 0.99x 0.98x 0.98x 1.05x 1.17x 1.30x 1.46x 1.76x
Q6_K 0.97x 0.99x 0.99x 0.74x 0.97x 1.08x 1.23x 1.34x
IQ1_S 1.00x 0.99x 1.00x 0.70x 0.80x 0.86x 0.94x 1.02x
IQ2_XXS 0.99x 0.99x 1.00x 0.79x 0.89x 0.93x 1.00x 1.07x
IQ2_XS 0.99x 0.99x 0.99x 0.73x 0.83x 0.87x 0.94x 1.00x
IQ2_S 0.99x 1.02x 0.99x 0.73x 0.82x 0.86x 0.93x 0.96x
IQ3_XXS 1.00x 1.01x 1.00x 0.74x 0.84x 0.87x 0.93x 0.98x
IQ3_XS 1.00x 1.01x 1.00x 0.77x 0.86x 0.88x 0.97x 0.99x
IQ3_S 0.98x 1.00x 1.00x 0.78x 0.88x 0.88x 0.97x 0.99x
IQ3_M 1.01x 1.04x 1.02x 0.79x 0.89x 0.93x 1.02x 1.06x
IQ4_NL 0.97x 0.97x 0.98x 0.75x 0.82x 0.88x 0.98x 1.07x
IQ4_XS 0.98x 0.99x 1.00x 0.73x 0.81x 0.88x 0.98x 1.07x

Cell = patched / baseline throughput. Bold = ≥ 5% deviation from parity.
Values < 1.00x are regressions vs master; > 1.00x are speedups from this PR.

Per-quant thresholds chosen

The PR ships these thresholds (only on amd_mfma_available(cc)):

Quant family MMVQ max batch Justification
Q3_K, Q4_K, Q5_K 3 MMQ wins cleanly from batch ≥ 4 (+5% to +76%)
Q2_K, Q6_K 5 MMQ wins from batch ≥ 6 (+8% to +35%)
Q4_0/Q4_1/Q5_0/Q5_1/Q8_0 8 (unchanged) MMVQ stays ahead until batch=8
All IQ quants 8 (unchanged) MMVQ stays ahead until batch=8

Verification of refactored patch

After reshaping the dispatcher into per-quant thresholds, spot-checked four
representative quants (one per family) at the previously-pathological ubatches.
Numbers are tok/s, single GCD, MI250X.

quant ub=1 base / patched ub=4 base / patched ub=8 base / patched
Q4_0 256 / 259 (+1%) 676 / 685 (+1%) 927 / 938 (+1%)
Q4_K_S 229 / 224 (-2%) 452 / 480 (+6%) 559 / 940 (+68%)
Q6_K 204 / 207 (+1%) 530 / 532 (+0%) 586 / 775 (+32%)
IQ4_XS 274 / 262 (-4%) 706 / 698 (-1%) 939 / 953 (+1%)

K-quant wins are preserved; legacy and IQ regressions are eliminated.

Reproducing

# 20-quant cross sweep:
for Q in q4_0 q4_1 q5_0 q5_1 q8_0 q2_k q3_k_s q4_k_s q5_k_s q6_k \
         iq1_s iq2_xxs iq2_xs iq2_s iq3_xxs iq3_xs iq3_s iq3_m iq4_nl iq4_xs; do
  llama-bench --model /path/to/Llama-3.2-3B-${Q}.gguf -r 10 -fa 1 -n 0 \
              -ub 1,2,3,4,5,6,7,8 -o sql > sql/${BUILD}_${Q}.sql
done
sqlite3 llama-bench.sqlite < sql/*.sql
scripts/compare-llama-bench.py -i llama-bench.sqlite

Use GGML_CUDA_FORCE_MMVQ=1 to force the legacy global threshold on the patched
build for A/B testing.

Overview

Additional information

Requirements

@jadenmach2 jadenmach2 requested a review from a team as a code owner May 17, 2026 18:28
@github-actions github-actions Bot added Nvidia GPU Issues specific to Nvidia GPUs ggml changes relating to the ggml tensor library for machine learning labels May 17, 2026
@jadenmach2
Copy link
Copy Markdown
Contributor Author

@JohannesGaessler this touches CUDA matmul dispatch on AMD MFMA hardware. Threshold rationale and full bench numbers are in the PR body. Happy to narrow the gate, change the threshold, or rename the constant

@JohannesGaessler
Copy link
Copy Markdown
Contributor

JohannesGaessler commented May 18, 2026

For consistency with the rest of the code (e.g. ggml_cuda_should_use_mmq), please add a function ggml_cuda_should_use_mmvq that returns a boolean. For performance testing, a sweep across quantization formats is needed. I have a MI100 on which I will test it like this:

export mn=llama_3-8b
for q in q4_0 q4_1 q5_0 q5_1 q8_0 q2_k q3_k_s q4_k_s q5_k_s q6_k iq1_s iq2_xxs iq2_xs iq2_s iq3_xxs iq3_xs iq3_s iq3_m iq4_nl iq4_xs; do echo $q; ./build/bin/llama-bench --model models/opt/${mn}-${q}.gguf -r 10 -fa 1 -n 0 -ub "1-8" -o sql|sqlite3 llama-bench.sqlite; sleep 10; done

The above command is executed for the last master branch commit where the PR branches off as well as the PR itself, basically any small model with a few billion parameters will be fine. Afterwards a table for performance comparison can be created like this:

python scripts/compare-llama-bench.py -s gpu_info,model_type,n_ubatch -i llama-bench.sqlite

I will require data like this to be the basis for the kernel selection logic in ggml_cuda_should_use_mmvq. If you could provide an equivalent table for CDNA2 it would be appreciated, otherwise the kernel selection logic for CDNA2 will be based on CDNA1 data.

@JohannesGaessler
Copy link
Copy Markdown
Contributor

I forgot: if the data in the comparison is too noisy, more benchmark runs can be added to the database until it evens out.

The dispatcher uses a single global threshold (MMVQ_MAX_BATCH_SIZE = 8)
to choose between mul_mat_vec_q (per-row GEMV) and mul_mat_q (MFMA-tiled
GEMM) for quantized matmul. On AMD CDNA, the optimal crossover differs
substantially by quant family because the per-row GEMV cost is dominated
by dequantisation, not the dot-product itself: K-quants pay a heavier
super-block decode and so MMQ wins sooner; legacy and IQ quants have
lean decode and stay ahead until the batch fully populates an MFMA tile.

This patch introduces ggml_cuda_should_use_mmvq(type, cc, ne11) -> bool,
mirroring the existing ggml_cuda_should_use_mmq, and gates per-quant
thresholds on amd_mfma_available(cc):

  Q3_K, Q4_K, Q5_K  : MMVQ <= 3   (MMQ wins from batch=4: +5% .. +76%)
  Q2_K, Q6_K        : MMVQ <= 5   (MMQ wins from batch=6: +8% .. +35%)
  others            : MMVQ <= 8   (legacy & IQ regress under MMQ; unchanged)

Non-AMD-MFMA paths (NVIDIA, RDNA, CDNA1 without MFMA) are byte-identical
to master. GGML_CUDA_FORCE_MMVQ=1 restores the original global threshold
for A/B testing.

Measured on MI250X (gfx90a, ROCm 7.2.1) with Llama-3.2-3B-Instruct,
llama-bench pp512 across all 20 supported quants, ubatch 1..8, 10 reps.
Full table in PR description.

  Selected pp512 throughput (tok/s, ub=8):
    Q4_K_S:  559 -> 940  (+68%)
    Q5_K_S:  503 -> 884  (+76%)
    Q3_K_S:  629 -> 879  (+40%)
    Q2_K  :  615 -> 809  (+32%)
    Q6_K  :  582 -> 776  (+33%)

  Selected pp512 throughput (tok/s, ub=4):
    Q4_K_S:  444 -> 480  (+ 8%)
    Q4_0  :  682 -> 685  (+ 0%)   (no regression - retains MMVQ)
    IQ4_XS:  706 -> 698  (- 1%)   (no regression - retains MMVQ)
@jadenmach2
Copy link
Copy Markdown
Contributor Author

jadenmach2 commented May 19, 2026

@JohannesGaessler addressed both points: refactored to ggml_cuda_should_use_mmvq(type, cc, ne11) -> bool
(mirrors ggml_cuda_should_use_mmq), and ran the full 20-quant llama-bench sweep on MI250X. Single threshold was
indeed wrong — only K-quants benefit, so the v2 patch uses per-quant thresholds (Q3/Q4/Q5_K ≤ 3, Q2/Q6_K ≤ 5,
legacy/IQ unchanged at 8). PR description has the benchmarking results

Comment thread ggml/src/ggml-cuda/mmvq.cu Outdated
Comment on lines +284 to +298
static int64_t mmvq_max_batch_amd_mfma(ggml_type type) {
switch (type) {
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
return 3; // MMQ wins from batch=4 onward (+5% to +76%)
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q6_K:
return 5; // MMQ wins from batch=6 onward (+8% to +35%)
default:
// Legacy (Q4_0/Q4_1/Q5_0/Q5_1/Q8_0) and IQ quants regress under MMQ
// up to batch=7, so keep the global threshold for them.
return MMVQ_MAX_BATCH_SIZE;
}
}
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Inline this function and remove the comments except for "tuned for CDNA2".

Comment thread ggml/src/ggml-cuda/mmvq.cu Outdated
Comment on lines +301 to +304
static const bool force_mmvq = (getenv("GGML_CUDA_FORCE_MMVQ") != nullptr);
if (force_mmvq) {
return ne11 <= MMVQ_MAX_BATCH_SIZE;
}
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
static const bool force_mmvq = (getenv("GGML_CUDA_FORCE_MMVQ") != nullptr);
if (force_mmvq) {
return ne11 <= MMVQ_MAX_BATCH_SIZE;
}

Comment thread ggml/src/ggml-cuda/mmvq.cuh Outdated
Comment on lines +5 to +16
// Returns true if a quantized matmul of shape (..., ne11) on a device with
// compute capability `cc` should take the MMVQ (per-row GEMV) path.
// Returning false sends it to the MMQ path (batched GEMM, MFMA-tiled on CDNA).
//
// On AMD MFMA hardware (CDNA) the optimal batch threshold is quant-dependent:
// K-quants have a heavier per-row GEMV (block scales + super-block decode), so
// MFMA-tiled MMQ overtakes MMVQ at a smaller batch; legacy and IQ quants have
// lean GEMV kernels that stay ahead until the batch nearly fills an MFMA tile.
// Thresholds calibrated on MI250X with Llama-3.2-3B (pp512, ubatch 1..8) — see
// the PR description for the full sweep.
//
// Set GGML_CUDA_FORCE_MMVQ=1 to restore the original global threshold.
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// Returns true if a quantized matmul of shape (..., ne11) on a device with
// compute capability `cc` should take the MMVQ (per-row GEMV) path.
// Returning false sends it to the MMQ path (batched GEMM, MFMA-tiled on CDNA).
//
// On AMD MFMA hardware (CDNA) the optimal batch threshold is quant-dependent:
// K-quants have a heavier per-row GEMV (block scales + super-block decode), so
// MFMA-tiled MMQ overtakes MMVQ at a smaller batch; legacy and IQ quants have
// lean GEMV kernels that stay ahead until the batch nearly fills an MFMA tile.
// Thresholds calibrated on MI250X with Llama-3.2-3B (pp512, ubatch 1..8) — see
// the PR description for the full sweep.
//
// Set GGML_CUDA_FORCE_MMVQ=1 to restore the original global threshold.

Comment thread ggml/src/ggml-cuda/mmvq.cu Outdated
if (force_mmvq) {
return ne11 <= MMVQ_MAX_BATCH_SIZE;
}
if (amd_mfma_available(cc)) {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if (amd_mfma_available(cc)) {
if (GGML_CUDA_CC_IS_CDNA(cc)) {

@jadenmach2
Copy link
Copy Markdown
Contributor Author

@JohannesGaessler , addressed the comments from your review

@JohannesGaessler
Copy link
Copy Markdown
Contributor

I benchmarked the performance of MMVQ vs. MMQ on CDNA1:

Performance table
GPU Model Microbatch size Test t/s MMVQ t/s MMQ Speedup
MI100 llama 8B IQ1_S - 1.5625 bpw 1 pp512 126.69 98.05 0.77
MI100 llama 8B IQ1_S - 1.5625 bpw 2 pp512 174.32 120.82 0.69
MI100 llama 8B IQ1_S - 1.5625 bpw 3 pp512 236.78 176.39 0.74
MI100 llama 8B IQ1_S - 1.5625 bpw 4 pp512 265.04 232.86 0.88
MI100 llama 8B IQ1_S - 1.5625 bpw 5 pp512 306.83 281.47 0.92
MI100 llama 8B IQ1_S - 1.5625 bpw 6 pp512 328.00 338.19 1.03
MI100 llama 8B IQ1_S - 1.5625 bpw 7 pp512 336.46 392.99 1.17
MI100 llama 8B IQ1_S - 1.5625 bpw 8 pp512 347.86 450.97 1.30
MI100 llama 8B IQ2_S - 2.5 bpw 1 pp512 83.43 67.59 0.81
MI100 llama 8B IQ2_S - 2.5 bpw 2 pp512 145.57 82.37 0.57
MI100 llama 8B IQ2_S - 2.5 bpw 3 pp512 195.73 121.67 0.62
MI100 llama 8B IQ2_S - 2.5 bpw 4 pp512 230.78 161.06 0.70
MI100 llama 8B IQ2_S - 2.5 bpw 5 pp512 256.19 196.42 0.77
MI100 llama 8B IQ2_S - 2.5 bpw 6 pp512 277.11 235.33 0.85
MI100 llama 8B IQ2_S - 2.5 bpw 7 pp512 301.52 274.00 0.91
MI100 llama 8B IQ2_S - 2.5 bpw 8 pp512 317.24 314.15 0.99
MI100 llama 8B IQ2_XS - 2.3125 bpw 1 pp512 86.78 69.53 0.80
MI100 llama 8B IQ2_XS - 2.3125 bpw 2 pp512 147.12 84.60 0.58
MI100 llama 8B IQ2_XS - 2.3125 bpw 3 pp512 190.69 125.03 0.66
MI100 llama 8B IQ2_XS - 2.3125 bpw 4 pp512 229.57 165.55 0.72
MI100 llama 8B IQ2_XS - 2.3125 bpw 5 pp512 264.19 201.22 0.76
MI100 llama 8B IQ2_XS - 2.3125 bpw 6 pp512 288.72 241.20 0.84
MI100 llama 8B IQ2_XS - 2.3125 bpw 7 pp512 305.14 280.69 0.92
MI100 llama 8B IQ2_XS - 2.3125 bpw 8 pp512 321.64 321.15 1.00
MI100 llama 8B IQ2_XXS - 2.0625 bpw 1 pp512 89.40 72.58 0.81
MI100 llama 8B IQ2_XXS - 2.0625 bpw 2 pp512 150.80 97.32 0.65
MI100 llama 8B IQ2_XXS - 2.0625 bpw 3 pp512 202.14 144.70 0.72
MI100 llama 8B IQ2_XXS - 2.0625 bpw 4 pp512 242.75 191.48 0.79
MI100 llama 8B IQ2_XXS - 2.0625 bpw 5 pp512 263.52 230.20 0.87
MI100 llama 8B IQ2_XXS - 2.0625 bpw 6 pp512 297.40 277.01 0.93
MI100 llama 8B IQ2_XXS - 2.0625 bpw 7 pp512 317.72 321.44 1.01
MI100 llama 8B IQ2_XXS - 2.0625 bpw 8 pp512 336.68 368.99 1.10
MI100 llama 8B IQ3_S - 3.4375 bpw 1 pp512 77.31 63.64 0.82
MI100 llama 8B IQ3_S - 3.4375 bpw 2 pp512 138.70 89.44 0.64
MI100 llama 8B IQ3_S - 3.4375 bpw 3 pp512 189.23 131.98 0.70
MI100 llama 8B IQ3_S - 3.4375 bpw 4 pp512 216.65 174.75 0.81
MI100 llama 8B IQ3_S - 3.4375 bpw 5 pp512 248.13 209.84 0.85
MI100 llama 8B IQ3_S - 3.4375 bpw 6 pp512 269.74 252.90 0.94
MI100 llama 8B IQ3_S - 3.4375 bpw 7 pp512 294.79 293.47 1.00
MI100 llama 8B IQ3_S - 3.4375 bpw 8 pp512 317.12 336.90 1.06
MI100 llama 8B IQ3_S mix - 3.66 bpw 1 pp512 80.86 65.85 0.81
MI100 llama 8B IQ3_S mix - 3.66 bpw 2 pp512 140.92 92.39 0.66
MI100 llama 8B IQ3_S mix - 3.66 bpw 3 pp512 189.30 135.72 0.72
MI100 llama 8B IQ3_S mix - 3.66 bpw 4 pp512 214.65 180.41 0.84
MI100 llama 8B IQ3_S mix - 3.66 bpw 5 pp512 234.70 216.35 0.92
MI100 llama 8B IQ3_S mix - 3.66 bpw 6 pp512 254.17 259.77 1.02
MI100 llama 8B IQ3_S mix - 3.66 bpw 7 pp512 272.59 301.85 1.11
MI100 llama 8B IQ3_S mix - 3.66 bpw 8 pp512 292.80 345.90 1.18
MI100 llama 8B IQ3_XS - 3.3 bpw 1 pp512 78.81 64.88 0.82
MI100 llama 8B IQ3_XS - 3.3 bpw 2 pp512 140.20 84.91 0.61
MI100 llama 8B IQ3_XS - 3.3 bpw 3 pp512 191.17 124.69 0.65
MI100 llama 8B IQ3_XS - 3.3 bpw 4 pp512 223.65 166.51 0.74
MI100 llama 8B IQ3_XS - 3.3 bpw 5 pp512 249.14 202.41 0.81
MI100 llama 8B IQ3_XS - 3.3 bpw 6 pp512 279.43 240.91 0.86
MI100 llama 8B IQ3_XS - 3.3 bpw 7 pp512 298.40 280.33 0.94
MI100 llama 8B IQ3_XS - 3.3 bpw 8 pp512 322.25 322.09 1.00
MI100 llama 8B IQ3_XXS - 3.0625 bpw 1 pp512 80.96 64.25 0.79
MI100 llama 8B IQ3_XXS - 3.0625 bpw 2 pp512 139.08 83.08 0.60
MI100 llama 8B IQ3_XXS - 3.0625 bpw 3 pp512 185.86 122.14 0.66
MI100 llama 8B IQ3_XXS - 3.0625 bpw 4 pp512 219.91 164.42 0.75
MI100 llama 8B IQ3_XXS - 3.0625 bpw 5 pp512 244.69 199.25 0.81
MI100 llama 8B IQ3_XXS - 3.0625 bpw 6 pp512 274.99 237.94 0.87
MI100 llama 8B IQ3_XXS - 3.0625 bpw 7 pp512 299.18 276.11 0.92
MI100 llama 8B IQ3_XXS - 3.0625 bpw 8 pp512 317.67 315.75 0.99
MI100 llama 8B IQ4_NL - 4.5 bpw 1 pp512 134.16 93.07 0.69
MI100 llama 8B IQ4_NL - 4.5 bpw 2 pp512 215.58 105.15 0.49
MI100 llama 8B IQ4_NL - 4.5 bpw 3 pp512 261.57 156.28 0.60
MI100 llama 8B IQ4_NL - 4.5 bpw 4 pp512 303.27 207.68 0.68
MI100 llama 8B IQ4_NL - 4.5 bpw 5 pp512 335.19 245.22 0.73
MI100 llama 8B IQ4_NL - 4.5 bpw 6 pp512 361.99 294.18 0.81
MI100 llama 8B IQ4_NL - 4.5 bpw 7 pp512 378.93 342.10 0.90
MI100 llama 8B IQ4_NL - 4.5 bpw 8 pp512 385.01 391.52 1.02
MI100 llama 8B IQ4_XS - 4.25 bpw 1 pp512 143.71 103.55 0.72
MI100 llama 8B IQ4_XS - 4.25 bpw 2 pp512 220.19 130.89 0.59
MI100 llama 8B IQ4_XS - 4.25 bpw 3 pp512 276.85 193.60 0.70
MI100 llama 8B IQ4_XS - 4.25 bpw 4 pp512 326.47 256.26 0.78
MI100 llama 8B IQ4_XS - 4.25 bpw 5 pp512 348.94 302.55 0.87
MI100 llama 8B IQ4_XS - 4.25 bpw 6 pp512 371.58 362.74 0.98
MI100 llama 8B IQ4_XS - 4.25 bpw 7 pp512 386.52 420.88 1.09
MI100 llama 8B IQ4_XS - 4.25 bpw 8 pp512 387.31 482.80 1.25
MI100 llama 8B Q2_K_M 1 pp512 110.07 85.79 0.78
MI100 llama 8B Q2_K_M 2 pp512 155.49 105.81 0.68
MI100 llama 8B Q2_K_M 3 pp512 167.16 155.24 0.93
MI100 llama 8B Q2_K_M 4 pp512 209.58 206.36 0.98
MI100 llama 8B Q2_K_M 5 pp512 224.09 250.73 1.12
MI100 llama 8B Q2_K_M 6 pp512 235.48 297.88 1.26
MI100 llama 8B Q2_K_M 7 pp512 184.13 345.65 1.88
MI100 llama 8B Q2_K_M 8 pp512 236.82 395.27 1.67
MI100 llama 8B Q3_K_S 1 pp512 98.21 81.33 0.83
MI100 llama 8B Q3_K_S 2 pp512 147.51 112.57 0.76
MI100 llama 8B Q3_K_S 3 pp512 140.09 166.83 1.19
MI100 llama 8B Q3_K_S 4 pp512 208.02 221.23 1.06
MI100 llama 8B Q3_K_S 5 pp512 235.05 263.67 1.12
MI100 llama 8B Q3_K_S 6 pp512 251.56 317.93 1.26
MI100 llama 8B Q3_K_S 7 pp512 148.00 368.29 2.49
MI100 llama 8B Q3_K_S 8 pp512 249.83 424.61 1.70
MI100 llama 8B Q4_0 1 pp512 146.21 108.74 0.74
MI100 llama 8B Q4_0 2 pp512 230.11 132.15 0.57
MI100 llama 8B Q4_0 3 pp512 305.91 190.64 0.62
MI100 llama 8B Q4_0 4 pp512 355.23 250.83 0.71
MI100 llama 8B Q4_0 5 pp512 382.81 295.41 0.77
MI100 llama 8B Q4_0 6 pp512 404.42 352.38 0.87
MI100 llama 8B Q4_0 7 pp512 428.71 406.84 0.95
MI100 llama 8B Q4_0 8 pp512 441.91 466.47 1.06
MI100 llama 8B Q4_1 1 pp512 141.58 108.57 0.77
MI100 llama 8B Q4_1 2 pp512 223.47 128.70 0.58
MI100 llama 8B Q4_1 3 pp512 299.83 189.24 0.63
MI100 llama 8B Q4_1 4 pp512 349.14 250.76 0.72
MI100 llama 8B Q4_1 5 pp512 373.53 294.44 0.79
MI100 llama 8B Q4_1 6 pp512 405.26 353.85 0.87
MI100 llama 8B Q4_1 7 pp512 418.15 409.82 0.98
MI100 llama 8B Q4_1 8 pp512 430.25 468.30 1.09
MI100 llama 8B Q4_K_S 1 pp512 123.83 98.07 0.79
MI100 llama 8B Q4_K_S 2 pp512 159.52 124.62 0.78
MI100 llama 8B Q4_K_S 3 pp512 169.37 184.23 1.09
MI100 llama 8B Q4_K_S 4 pp512 183.70 244.41 1.33
MI100 llama 8B Q4_K_S 5 pp512 165.43 288.41 1.74
MI100 llama 8B Q4_K_S 6 pp512 176.98 345.97 1.95
MI100 llama 8B Q4_K_S 7 pp512 183.34 401.58 2.19
MI100 llama 8B Q4_K_S 8 pp512 183.99 460.81 2.50
MI100 llama 8B Q5_0 1 pp512 117.24 85.83 0.73
MI100 llama 8B Q5_0 2 pp512 191.20 94.23 0.49
MI100 llama 8B Q5_0 3 pp512 250.67 137.16 0.55
MI100 llama 8B Q5_0 4 pp512 291.97 181.39 0.62
MI100 llama 8B Q5_0 5 pp512 318.24 219.40 0.69
MI100 llama 8B Q5_0 6 pp512 344.90 263.50 0.76
MI100 llama 8B Q5_0 7 pp512 360.57 307.03 0.85
MI100 llama 8B Q5_0 8 pp512 383.48 351.20 0.92
MI100 llama 8B Q5_1 1 pp512 122.97 92.10 0.75
MI100 llama 8B Q5_1 2 pp512 193.68 108.40 0.56
MI100 llama 8B Q5_1 3 pp512 255.93 158.92 0.62
MI100 llama 8B Q5_1 4 pp512 303.26 209.54 0.69
MI100 llama 8B Q5_1 5 pp512 326.67 250.46 0.77
MI100 llama 8B Q5_1 6 pp512 354.82 302.28 0.85
MI100 llama 8B Q5_1 7 pp512 370.15 349.95 0.95
MI100 llama 8B Q5_1 8 pp512 385.90 402.41 1.04
MI100 llama 8B Q5_K_S 1 pp512 107.97 82.96 0.77
MI100 llama 8B Q5_K_S 2 pp512 148.06 98.37 0.66
MI100 llama 8B Q5_K_S 3 pp512 159.93 146.95 0.92
MI100 llama 8B Q5_K_S 4 pp512 150.67 195.17 1.30
MI100 llama 8B Q5_K_S 5 pp512 162.18 231.97 1.43
MI100 llama 8B Q5_K_S 6 pp512 171.36 278.36 1.62
MI100 llama 8B Q5_K_S 7 pp512 177.17 324.39 1.83
MI100 llama 8B Q5_K_S 8 pp512 128.16 371.27 2.90
MI100 llama 8B Q6_K 1 pp512 92.34 74.68 0.81
MI100 llama 8B Q6_K 2 pp512 134.65 91.04 0.68
MI100 llama 8B Q6_K 3 pp512 174.85 134.19 0.77
MI100 llama 8B Q6_K 4 pp512 197.06 177.20 0.90
MI100 llama 8B Q6_K 5 pp512 207.26 212.10 1.02
MI100 llama 8B Q6_K 6 pp512 225.25 254.43 1.13
MI100 llama 8B Q6_K 7 pp512 237.94 296.10 1.24
MI100 llama 8B Q6_K 8 pp512 221.74 339.92 1.53
MI100 llama 8B Q8_0 1 pp512 94.62 78.65 0.83
MI100 llama 8B Q8_0 2 pp512 150.97 106.46 0.71
MI100 llama 8B Q8_0 3 pp512 204.20 155.83 0.76
MI100 llama 8B Q8_0 4 pp512 248.00 205.23 0.83
MI100 llama 8B Q8_0 5 pp512 278.15 245.74 0.88
MI100 llama 8B Q8_0 6 pp512 307.11 295.78 0.96
MI100 llama 8B Q8_0 7 pp512 313.37 340.89 1.09
MI100 llama 8B Q8_0 8 pp512 331.12 394.32 1.19

I pushed the corresponding kernel selection for CDNA1 to this PR, for CDNA2/3 we the logic is based on your numbers.

@JohannesGaessler JohannesGaessler requested review from a team and IMbackK May 21, 2026 10:48
@JohannesGaessler
Copy link
Copy Markdown
Contributor

@ggml-org/maintainers can I please get a second approval?

@JohannesGaessler JohannesGaessler merged commit bc81d47 into ggml-org:master May 28, 2026
50 checks passed
adrianhoehne pushed a commit to adrianhoehne/llama.cpp that referenced this pull request May 28, 2026
…ml-org#23227)

* CUDA: per-quant MMVQ/MMQ batch threshold on AMD MFMA hardware

The dispatcher uses a single global threshold (MMVQ_MAX_BATCH_SIZE = 8)
to choose between mul_mat_vec_q (per-row GEMV) and mul_mat_q (MFMA-tiled
GEMM) for quantized matmul. On AMD CDNA, the optimal crossover differs
substantially by quant family because the per-row GEMV cost is dominated
by dequantisation, not the dot-product itself: K-quants pay a heavier
super-block decode and so MMQ wins sooner; legacy and IQ quants have
lean decode and stay ahead until the batch fully populates an MFMA tile.

This patch introduces ggml_cuda_should_use_mmvq(type, cc, ne11) -> bool,
mirroring the existing ggml_cuda_should_use_mmq, and gates per-quant
thresholds on amd_mfma_available(cc):

  Q3_K, Q4_K, Q5_K  : MMVQ <= 3   (MMQ wins from batch=4: +5% .. +76%)
  Q2_K, Q6_K        : MMVQ <= 5   (MMQ wins from batch=6: +8% .. +35%)
  others            : MMVQ <= 8   (legacy & IQ regress under MMQ; unchanged)

Non-AMD-MFMA paths (NVIDIA, RDNA, CDNA1 without MFMA) are byte-identical
to master. GGML_CUDA_FORCE_MMVQ=1 restores the original global threshold
for A/B testing.

Measured on MI250X (gfx90a, ROCm 7.2.1) with Llama-3.2-3B-Instruct,
llama-bench pp512 across all 20 supported quants, ubatch 1..8, 10 reps.
Full table in PR description.

  Selected pp512 throughput (tok/s, ub=8):
    Q4_K_S:  559 -> 940  (+68%)
    Q5_K_S:  503 -> 884  (+76%)
    Q3_K_S:  629 -> 879  (+40%)
    Q2_K  :  615 -> 809  (+32%)
    Q6_K  :  582 -> 776  (+33%)

  Selected pp512 throughput (tok/s, ub=4):
    Q4_K_S:  444 -> 480  (+ 8%)
    Q4_0  :  682 -> 685  (+ 0%)   (no regression - retains MMVQ)
    IQ4_XS:  706 -> 698  (- 1%)   (no regression - retains MMVQ)

* CUDA: address review — inline MMVQ batch table, drop env hatch & doc block

* tune kernel selection logic for CDNA1

---------

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
gabe-l-hart added a commit to gabe-l-hart/llama.cpp that referenced this pull request May 28, 2026
* origin/master: (32 commits)
hexagon: basic/generic op fusion support and RMS_NORM+MUL fusion (ggml-org#23835)
mtmd-debug: add color and rainbow mode (ggml-org#23829)
mtmd: fix gemma 4 projector pre_norm (ggml-org#23822)
opencl: move backend info printing into its own function (ggml-org#23702)
ci : run ui publish on ubuntu-slim (ggml-org#23818)
ui: fix audio and video modality detection (ggml-org#23756)
ci : releases use Github-hosted builds for the UI (ggml-org#23823)
app : improve help output (ggml-org#23805)
mtmd: n_head_kv defaults to n_head (ggml-org#23782)
mtmd: fix gemma 4 audio rms norm eps (ggml-org#23815)
ci : change Vulkan builds to Release to reduce ccache (ggml-org#23820)
arg: Add LLAMA_ARG_API_KEY_FILE environment variable for --api-key-file (ggml-org#23167)
test-llama-archs: fix table format [no release] (ggml-org#23810)
ggml: auto apply iGPU flag CUDA/HIP if integrated device (ggml-org#23007)
mmvq Optim: add MMVQ_PARAMETERS_TURING(mmvq_parameter_table_id) for … (ggml-org#23729)
CUDA: route batch>=4 quantized matmul to MMQ on AMD MFMA hardware (ggml-org#23227)
server: minor tweaks to use more cpp features (ggml-org#23785)
hexagon: minor refresh for HMX FA and MM (ggml-org#23796)
vulkan: fast path for walsh-hadamard transform (ggml-org#23687)
chat : add Granite 4.1 chat template (ggml-org#23518)
...
fewtarius pushed a commit to fewtarius/llama.cpp that referenced this pull request May 30, 2026
…ml-org#23227)

* CUDA: per-quant MMVQ/MMQ batch threshold on AMD MFMA hardware

The dispatcher uses a single global threshold (MMVQ_MAX_BATCH_SIZE = 8)
to choose between mul_mat_vec_q (per-row GEMV) and mul_mat_q (MFMA-tiled
GEMM) for quantized matmul. On AMD CDNA, the optimal crossover differs
substantially by quant family because the per-row GEMV cost is dominated
by dequantisation, not the dot-product itself: K-quants pay a heavier
super-block decode and so MMQ wins sooner; legacy and IQ quants have
lean decode and stay ahead until the batch fully populates an MFMA tile.

This patch introduces ggml_cuda_should_use_mmvq(type, cc, ne11) -> bool,
mirroring the existing ggml_cuda_should_use_mmq, and gates per-quant
thresholds on amd_mfma_available(cc):

  Q3_K, Q4_K, Q5_K  : MMVQ <= 3   (MMQ wins from batch=4: +5% .. +76%)
  Q2_K, Q6_K        : MMVQ <= 5   (MMQ wins from batch=6: +8% .. +35%)
  others            : MMVQ <= 8   (legacy & IQ regress under MMQ; unchanged)

Non-AMD-MFMA paths (NVIDIA, RDNA, CDNA1 without MFMA) are byte-identical
to master. GGML_CUDA_FORCE_MMVQ=1 restores the original global threshold
for A/B testing.

Measured on MI250X (gfx90a, ROCm 7.2.1) with Llama-3.2-3B-Instruct,
llama-bench pp512 across all 20 supported quants, ubatch 1..8, 10 reps.
Full table in PR description.

  Selected pp512 throughput (tok/s, ub=8):
    Q4_K_S:  559 -> 940  (+68%)
    Q5_K_S:  503 -> 884  (+76%)
    Q3_K_S:  629 -> 879  (+40%)
    Q2_K  :  615 -> 809  (+32%)
    Q6_K  :  582 -> 776  (+33%)

  Selected pp512 throughput (tok/s, ub=4):
    Q4_K_S:  444 -> 480  (+ 8%)
    Q4_0  :  682 -> 685  (+ 0%)   (no regression - retains MMVQ)
    IQ4_XS:  706 -> 698  (- 1%)   (no regression - retains MMVQ)

* CUDA: address review — inline MMVQ batch table, drop env hatch & doc block

* tune kernel selection logic for CDNA1

---------

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
turbo-tan pushed a commit to turbo-tan/llama.cpp-tq3 that referenced this pull request Jun 2, 2026
…ml-org#23227)

* CUDA: per-quant MMVQ/MMQ batch threshold on AMD MFMA hardware

The dispatcher uses a single global threshold (MMVQ_MAX_BATCH_SIZE = 8)
to choose between mul_mat_vec_q (per-row GEMV) and mul_mat_q (MFMA-tiled
GEMM) for quantized matmul. On AMD CDNA, the optimal crossover differs
substantially by quant family because the per-row GEMV cost is dominated
by dequantisation, not the dot-product itself: K-quants pay a heavier
super-block decode and so MMQ wins sooner; legacy and IQ quants have
lean decode and stay ahead until the batch fully populates an MFMA tile.

This patch introduces ggml_cuda_should_use_mmvq(type, cc, ne11) -> bool,
mirroring the existing ggml_cuda_should_use_mmq, and gates per-quant
thresholds on amd_mfma_available(cc):

  Q3_K, Q4_K, Q5_K  : MMVQ <= 3   (MMQ wins from batch=4: +5% .. +76%)
  Q2_K, Q6_K        : MMVQ <= 5   (MMQ wins from batch=6: +8% .. +35%)
  others            : MMVQ <= 8   (legacy & IQ regress under MMQ; unchanged)

Non-AMD-MFMA paths (NVIDIA, RDNA, CDNA1 without MFMA) are byte-identical
to master. GGML_CUDA_FORCE_MMVQ=1 restores the original global threshold
for A/B testing.

Measured on MI250X (gfx90a, ROCm 7.2.1) with Llama-3.2-3B-Instruct,
llama-bench pp512 across all 20 supported quants, ubatch 1..8, 10 reps.
Full table in PR description.

  Selected pp512 throughput (tok/s, ub=8):
    Q4_K_S:  559 -> 940  (+68%)
    Q5_K_S:  503 -> 884  (+76%)
    Q3_K_S:  629 -> 879  (+40%)
    Q2_K  :  615 -> 809  (+32%)
    Q6_K  :  582 -> 776  (+33%)

  Selected pp512 throughput (tok/s, ub=4):
    Q4_K_S:  444 -> 480  (+ 8%)
    Q4_0  :  682 -> 685  (+ 0%)   (no regression - retains MMVQ)
    IQ4_XS:  706 -> 698  (- 1%)   (no regression - retains MMVQ)

* CUDA: address review — inline MMVQ batch table, drop env hatch & doc block

* tune kernel selection logic for CDNA1

---------

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ggml changes relating to the ggml tensor library for machine learning Nvidia GPU Issues specific to Nvidia GPUs

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants