CUDA: route batch>=4 quantized matmul to MMQ on AMD MFMA hardware#23227
|
@JohannesGaessler this touches CUDA matmul dispatch on AMD MFMA hardware. Threshold rationale and full bench numbers are in the PR body. Happy to narrow the gate, change the threshold, or rename the constant |
|
For consistency with the rest of the code (e.g.
export mn=llama_3-8b
for q in q4_0 q4_1 q5_0 q5_1 q8_0 q2_k q3_k_s q4_k_s q5_k_s q6_k iq1_s iq2_xxs iq2_xs iq2_s iq3_xxs iq3_xs iq3_s iq3_m iq4_nl iq4_xs; do echo $q; ./build/bin/llama-bench --model models/opt/${mn}-${q}.gguf -r 10 -fa 1 -n 0 -ub "1-8" -o sql|sqlite3 llama-bench.sqlite; sleep 10; done
The above command is executed for the last master branch commit where the PR branches off as well as for the PR itself; basically any small model with a few billion parameters will be fine. Afterwards a table for performance comparison can be created like this:
python scripts/compare-llama-bench.py -s gpu_info,model_type,n_ubatch -i llama-bench.sqlite
I will require data like this to be the basis for the kernel selection logic in |
|
I forgot: if the data in the comparison is too noisy, more benchmark runs can be added to the database until it evens out. |
The dispatcher uses a single global threshold (MMVQ_MAX_BATCH_SIZE = 8)
to choose between mul_mat_vec_q (per-row GEMV) and mul_mat_q (MFMA-tiled
GEMM) for quantized matmul. On AMD CDNA, the optimal crossover differs
substantially by quant family because the per-row GEMV cost is dominated
by dequantisation, not the dot-product itself: K-quants pay a heavier
super-block decode and so MMQ wins sooner; legacy and IQ quants have
lean decode and stay ahead until the batch fully populates an MFMA tile.
This patch introduces ggml_cuda_should_use_mmvq(type, cc, ne11) -> bool,
mirroring the existing ggml_cuda_should_use_mmq, and gates per-quant
thresholds on amd_mfma_available(cc):
Q3_K, Q4_K, Q5_K : MMVQ <= 3 (MMQ wins from batch=4: +5% .. +76%)
Q2_K, Q6_K : MMVQ <= 5 (MMQ wins from batch=6: +8% .. +35%)
others : MMVQ <= 8 (legacy & IQ regress under MMQ; unchanged)
Non-AMD-MFMA paths (NVIDIA, RDNA, CDNA1 without MFMA) are byte-identical
to master. GGML_CUDA_FORCE_MMVQ=1 restores the original global threshold
for A/B testing.
Measured on MI250X (gfx90a, ROCm 7.2.1) with Llama-3.2-3B-Instruct,
llama-bench pp512 across all 20 supported quants, ubatch 1..8, 10 reps.
Full table in PR description.
Selected pp512 throughput (tok/s, ub=8):
Q4_K_S: 559 -> 940 (+68%)
Q5_K_S: 503 -> 884 (+76%)
Q3_K_S: 629 -> 879 (+40%)
Q2_K : 615 -> 809 (+32%)
Q6_K : 582 -> 776 (+33%)
Selected pp512 throughput (tok/s, ub=4):
Q4_K_S: 444 -> 480 (+ 8%)
Q4_0 : 682 -> 685 (+ 0%) (no regression - retains MMVQ)
IQ4_XS: 706 -> 698 (- 1%) (no regression - retains MMVQ)
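The per-quant gate described in this commit message can be sketched as below. This is a minimal, self-contained illustration under stated assumptions, not the code that was merged: `QuantType`, `mmvq_max_batch` and `should_use_mmvq` are hypothetical stand-ins for ggml's `ggml_type` and `ggml_cuda_should_use_mmvq`, and the boolean `amd_mfma` replaces the real `amd_mfma_available(cc)` check. The thresholds are the ones listed above.

```cpp
#include <cstdint>

// Hypothetical stand-in for the relevant ggml_type values (illustration only).
enum class QuantType { Q4_0, Q8_0, Q2_K, Q3_K, Q4_K, Q5_K, Q6_K, IQ4_XS };

// Global threshold named in the commit message.
constexpr int64_t MMVQ_MAX_BATCH_SIZE = 8;

// Largest batch size (ne11) for which the MMVQ path is still preferred.
constexpr int64_t mmvq_max_batch(QuantType type, bool amd_mfma) {
    if (!amd_mfma) {
        return MMVQ_MAX_BATCH_SIZE; // other hardware keeps the global threshold
    }
    switch (type) {
        case QuantType::Q3_K:
        case QuantType::Q4_K:
        case QuantType::Q5_K:
            return 3;
        case QuantType::Q2_K:
        case QuantType::Q6_K:
            return 5;
        default:
            return MMVQ_MAX_BATCH_SIZE; // legacy and IQ quants: unchanged
    }
}

// True if the matmul should take the MMVQ (per-row GEMV) path.
constexpr bool should_use_mmvq(QuantType type, bool amd_mfma, int64_t ne11) {
    return ne11 <= mmvq_max_batch(type, amd_mfma);
}
```

With these assumptions, `should_use_mmvq(QuantType::Q4_K, true, 4)` is false (the batch leaves the MMVQ path), while the same call with `QuantType::Q4_0` stays true up to `ne11 == 8`.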
199dc40 to 98c6212
|
@JohannesGaessler addressed both points: refactored to |
static int64_t mmvq_max_batch_amd_mfma(ggml_type type) {
    switch (type) {
        case GGML_TYPE_Q3_K:
        case GGML_TYPE_Q4_K:
        case GGML_TYPE_Q5_K:
            return 3; // MMQ wins from batch=4 onward (+5% to +76%)
        case GGML_TYPE_Q2_K:
        case GGML_TYPE_Q6_K:
            return 5; // MMQ wins from batch=6 onward (+8% to +35%)
        default:
            // Legacy (Q4_0/Q4_1/Q5_0/Q5_1/Q8_0) and IQ quants regress under MMQ
            // up to batch=7, so keep the global threshold for them.
            return MMVQ_MAX_BATCH_SIZE;
    }
}
Inline this function and remove the comments except for "tuned for CDNA2".
static const bool force_mmvq = (getenv("GGML_CUDA_FORCE_MMVQ") != nullptr);
if (force_mmvq) {
    return ne11 <= MMVQ_MAX_BATCH_SIZE;
}
Suggested change: remove the lines above.
// Returns true if a quantized matmul of shape (..., ne11) on a device with
// compute capability `cc` should take the MMVQ (per-row GEMV) path.
// Returning false sends it to the MMQ path (batched GEMM, MFMA-tiled on CDNA).
//
// On AMD MFMA hardware (CDNA) the optimal batch threshold is quant-dependent:
// K-quants have a heavier per-row GEMV (block scales + super-block decode), so
// MFMA-tiled MMQ overtakes MMVQ at a smaller batch; legacy and IQ quants have
// lean GEMV kernels that stay ahead until the batch nearly fills an MFMA tile.
// Thresholds calibrated on MI250X with Llama-3.2-3B (pp512, ubatch 1..8); see
// the PR description for the full sweep.
//
// Set GGML_CUDA_FORCE_MMVQ=1 to restore the original global threshold.
Suggested change: remove the comment block above.
if (force_mmvq) {
    return ne11 <= MMVQ_MAX_BATCH_SIZE;
}
if (amd_mfma_available(cc)) {
Suggested change:
- if (amd_mfma_available(cc)) {
+ if (GGML_CUDA_CC_IS_CDNA(cc)) {
|
@JohannesGaessler, addressed the comments from your review. |
|
I benchmarked the performance of MMVQ vs. MMQ on CDNA1: Performance table
I pushed the corresponding kernel selection for CDNA1 to this PR; for CDNA2/3 the logic is based on your numbers. |
|
@ggml-org/maintainers can I please get a second approval? |
…ml-org#23227)
* CUDA: per-quant MMVQ/MMQ batch threshold on AMD MFMA hardware
The dispatcher uses a single global threshold (MMVQ_MAX_BATCH_SIZE = 8)
to choose between mul_mat_vec_q (per-row GEMV) and mul_mat_q (MFMA-tiled
GEMM) for quantized matmul. On AMD CDNA, the optimal crossover differs
substantially by quant family because the per-row GEMV cost is dominated
by dequantisation, not the dot-product itself: K-quants pay a heavier
super-block decode and so MMQ wins sooner; legacy and IQ quants have
lean decode and stay ahead until the batch fully populates an MFMA tile.
This patch introduces ggml_cuda_should_use_mmvq(type, cc, ne11) -> bool,
mirroring the existing ggml_cuda_should_use_mmq, and gates per-quant
thresholds on amd_mfma_available(cc):
Q3_K, Q4_K, Q5_K : MMVQ <= 3 (MMQ wins from batch=4: +5% .. +76%)
Q2_K, Q6_K : MMVQ <= 5 (MMQ wins from batch=6: +8% .. +35%)
others : MMVQ <= 8 (legacy & IQ regress under MMQ; unchanged)
Non-AMD-MFMA paths (NVIDIA, RDNA, CDNA1 without MFMA) are byte-identical
to master. GGML_CUDA_FORCE_MMVQ=1 restores the original global threshold
for A/B testing.
Measured on MI250X (gfx90a, ROCm 7.2.1) with Llama-3.2-3B-Instruct,
llama-bench pp512 across all 20 supported quants, ubatch 1..8, 10 reps.
Full table in PR description.
Selected pp512 throughput (tok/s, ub=8):
Q4_K_S: 559 -> 940 (+68%)
Q5_K_S: 503 -> 884 (+76%)
Q3_K_S: 629 -> 879 (+40%)
Q2_K : 615 -> 809 (+32%)
Q6_K : 582 -> 776 (+33%)
Selected pp512 throughput (tok/s, ub=4):
Q4_K_S: 444 -> 480 (+ 8%)
Q4_0 : 682 -> 685 (+ 0%) (no regression - retains MMVQ)
IQ4_XS: 706 -> 698 (- 1%) (no regression - retains MMVQ)
* CUDA: address review: inline MMVQ batch table, drop env hatch & doc block
* tune kernel selection logic for CDNA1
---------
Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
* origin/master: (32 commits)
hexagon: basic/generic op fusion support and RMS_NORM+MUL fusion (ggml-org#23835)
mtmd-debug: add color and rainbow mode (ggml-org#23829)
mtmd: fix gemma 4 projector pre_norm (ggml-org#23822)
opencl: move backend info printing into its own function (ggml-org#23702)
ci : run ui publish on ubuntu-slim (ggml-org#23818)
ui: fix audio and video modality detection (ggml-org#23756)
ci : releases use Github-hosted builds for the UI (ggml-org#23823)
app : improve help output (ggml-org#23805)
mtmd: n_head_kv defaults to n_head (ggml-org#23782)
mtmd: fix gemma 4 audio rms norm eps (ggml-org#23815)
ci : change Vulkan builds to Release to reduce ccache (ggml-org#23820)
arg: Add LLAMA_ARG_API_KEY_FILE environment variable for --api-key-file (ggml-org#23167)
test-llama-archs: fix table format [no release] (ggml-org#23810)
ggml: auto apply iGPU flag CUDA/HIP if integrated device (ggml-org#23007)
mmvq Optim: add MMVQ_PARAMETERS_TURING(mmvq_parameter_table_id) for … (ggml-org#23729)
CUDA: route batch>=4 quantized matmul to MMQ on AMD MFMA hardware (ggml-org#23227)
server: minor tweaks to use more cpp features (ggml-org#23785)
hexagon: minor refresh for HMX FA and MM (ggml-org#23796)
vulkan: fast path for walsh-hadamard transform (ggml-org#23687)
chat : add Granite 4.1 chat template (ggml-org#23518)
...
The dispatcher unconditionally prefers mul_mat_vec_q (per-row GEMV) over mul_mat_q (MFMA-tiled GEMM) for any quantized matmul with batch <= 8. On AMD CDNA the MFMA path is materially faster once the verify batch reaches 4, which is exactly where every form of speculative decoding lives (ngram, draft-model, native MTP).
This patch adds an AMD-MFMA-specific cap of 3, gated on amd_mfma_available(cc), with a GGML_CUDA_FORCE_MMVQ env-var escape hatch. Non-AMD-MFMA paths (NVIDIA, RDNA, CDNA1 without MFMA, CPU) are byte-identical.
Measured on MI250X (gfx90a, ROCm 7.2.1) with Qwen3.6-27B Q4_K_M:
llama-batched-bench, token-gen tok/s by batch:
B=1: 34.28 -> 34.14 (noise; same dispatcher path)
B=2: 50.97 -> 50.83 (noise)
B=3: 58.18 -> 58.30 (noise; last B where MMVQ stays)
B=4: 63.01 -> 64.76 (+3 %)
B=5: 63.94 -> 76.93 (+20 %)
B=6: 63.53 -> 88.29 (+39 %)
B=8: 66.84 -> 107.62 (+61 %)
llama-server, MTP --spec-draft-n-max sweep, end-to-end tok/s
on a 256-token deterministic completion:
n_max=2 (verify B=3): 47.24 -> 46.90 (noise)
n_max=3 (verify B=4): 42.87 -> 50.12 (+17 %)
n_max=4 (verify B=5): 39.68 -> 47.79 (+20 %)
n_max=5 (verify B=6): 36.43 -> 50.98 (+40 %)
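The relationship between the draft length and the kernel path in the sweep above can be sketched as follows. This is only an illustration: `verify_batch` and `verify_leaves_mmvq` are hypothetical helper names, and the rule B = n_max + 1 is read off the labels in the sweep rather than taken from the server code. The cap of 3 is the value this description states for AMD MFMA hardware.

```cpp
// AMD-MFMA-specific MMVQ cap stated in this description.
constexpr int MMVQ_MAX_BATCH_SIZE_AMD_MFMA = 3;

// Verify batch for a draft of n_max tokens, as labelled in the sweep
// (n_max=2 -> B=3, n_max=3 -> B=4, ...). Assumption: B = n_max + 1.
constexpr int verify_batch(int n_max) {
    return n_max + 1;
}

// True if the verify batch exceeds the cap, i.e. it no longer takes
// the MMVQ path on AMD MFMA hardware.
constexpr bool verify_leaves_mmvq(int n_max) {
    return verify_batch(n_max) > MMVQ_MAX_BATCH_SIZE_AMD_MFMA;
}
```

Under these assumptions `verify_leaves_mmvq(2)` is false, which matches the n_max=2 row showing only noise, while n_max of 3 and above crosses the cap, which is where the sweep starts to show gains.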
llama-bench plain workloads (no-regression check):
pp512: 794.78 -> 795.01 t/s (+0.03 %)
tg128: 34.44 -> 34.02 t/s (-1.2 %, day-to-day noise)
rocprofv3 kernel trace of one MTP n=3 completion:
mul_mat_q (MFMA-tiled GEMM): 0 -> 8 955 calls
mul_mat_vec_q (per-row GEMV): 11 400 -> 2 445 calls
The baseline column above is reproduced from the same binary by exporting GGML_CUDA_FORCE_MMVQ=1.
Threshold value: empirically tuned at MMVQ_MAX_BATCH_SIZE_AMD_MFMA=3 on MI250X with Q4_K_M. Crossover sits between B=3 (where MMVQ wins ~13 %) and B=4 (where MMQ wins ~8 %). MI300X (CDNA3) has wider MFMA tiles so the crossover almost certainly moves earlier; threshold of 3 is therefore conservative-correct (we'd leave a few percent on the table at B=3 in the worst case, never regress below baseline).
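As a worked check of the percentages in the llama-batched-bench table, the sketch below recomputes the gain per batch and finds the first batch whose gain exceeds a chosen noise margin. The throughput values are copied from the table; `Sample`, `gain_percent`, `first_batch_above` and the margin values are hypothetical choices for illustration, not part of the PR.

```cpp
#include <cstddef>

// One row of the llama-batched-bench table: tok/s before and after the patch.
struct Sample {
    int    batch;
    double baseline;
    double patched;
};

constexpr Sample kSamples[] = {
    {1, 34.28, 34.14}, {2, 50.97, 50.83}, {3, 58.18, 58.30},
    {4, 63.01, 64.76}, {5, 63.94, 76.93}, {6, 63.53, 88.29},
    {8, 66.84, 107.62},
};

// Relative change of patched vs. baseline throughput, in percent.
constexpr double gain_percent(const Sample & s) {
    return (s.patched / s.baseline - 1.0) * 100.0;
}

// First batch size whose gain exceeds margin_percent, or -1 if none does.
constexpr int first_batch_above(double margin_percent) {
    for (std::size_t i = 0; i < sizeof(kSamples) / sizeof(kSamples[0]); ++i) {
        if (gain_percent(kSamples[i]) > margin_percent) {
            return kSamples[i].batch;
        }
    }
    return -1;
}
```

With an assumed 2 % noise margin, `first_batch_above(2.0)` returns 4, consistent with B=3 being the last batch that stays on MMVQ; with a stricter 5 % margin the first clear win is B=5.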
Per-quant batch threshold sweep — MI250X / gfx90a
Setup: Llama-3.2-3B-Instruct, llama-bench -fa 1 -n 0 -ub 1..8 -r 10, single GCD, ROCm 7.2.1, GGML_HIP_GRAPHS=ON on both builds.
Why per-quant?
A first attempt used a single global threshold of 3 (MMVQ → MMQ at batch ≥ 4) on
AMD MFMA hardware. The 20-quant sweep showed this regresses legacy and IQ quants
substantially while only K-quants benefit:
Cell = patched / baseline throughput. Bold = ≥ 5% deviation from parity.
Values < 1.00x are regressions vs master; > 1.00x are speedups from this PR.
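The cell convention described here (patched / baseline, with a 5% band around parity) can be sketched as below. `Verdict`, `ratio` and `classify` are hypothetical names used only for illustration; the 0.05 tolerance is the 5% deviation mentioned above.

```cpp
// Outcome of comparing patched against baseline throughput.
enum class Verdict { Regression, Parity, Speedup };

// Cell value: patched / baseline throughput.
constexpr double ratio(double patched, double baseline) {
    return patched / baseline;
}

// Classify a cell; deviations of at least `tol` from 1.00x count as
// a regression (below) or a speedup (above), everything else is parity.
constexpr Verdict classify(double r, double tol = 0.05) {
    return r <= 1.0 - tol ? Verdict::Regression
         : r >= 1.0 + tol ? Verdict::Speedup
                          : Verdict::Parity;
}
```

Applied to the throughput figures quoted in the commit message, Q4_K_S at ub=8 (559 -> 940) classifies as a speedup, while IQ4_XS at ub=4 (706 -> 698) stays within parity.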
Per-quant thresholds chosen
The PR ships these thresholds (only on amd_mfma_available(cc)):
Verification of refactored patch
After reshaping the dispatcher into per-quant thresholds, spot-checked four
representative quants (one per family) at the previously-pathological ubatches.
Numbers are tok/s, single GCD, MI250X.
K-quant wins are preserved; legacy and IQ regressions are eliminated.
Reproducing
Use GGML_CUDA_FORCE_MMVQ=1 to force the legacy global threshold on the patched build for A/B testing.