-
Notifications
You must be signed in to change notification settings - Fork 19.9k
CUDA: route batch>=4 quantized matmul to MMQ on AMD MFMA hardware #23227
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
+51
−0
Merged
Changes from 1 commit
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change | ||||||||
|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -271,6 +271,43 @@ int get_mmvq_mmid_max_batch(ggml_type type, int cc) { | |||||||||
| return MMVQ_MAX_BATCH_SIZE; | ||||||||||
| } | ||||||||||
|
|
||||||||||
| // Returns the largest batch size for which MMVQ is kept on AMD MFMA hardware | ||||||||||
| // (CDNA) for the given quant type; larger batches yield to the MMQ (MFMA-tiled | ||||||||||
| // GEMM) path. The crossover differs noticeably by quant family because the | ||||||||||
| // per-row GEMV cost is dominated by the dequantisation work, not the | ||||||||||
| // dot-product itself: K-quants pay a heavier per-row decode (super-block scales) | ||||||||||
| // and so MMQ wins sooner; legacy and IQ quants have lean decode and stay ahead | ||||||||||
| // until the batch is wide enough to fully populate an MFMA tile. | ||||||||||
| // | ||||||||||
| // Calibrated on MI250X with Llama-3.2-3B (pp512, ubatch 1..8, 10 reps each), | ||||||||||
| // across all 20 supported quant types. See PR description for the full table. | ||||||||||
| static int64_t mmvq_max_batch_amd_mfma(ggml_type type) { | ||||||||||
| switch (type) { | ||||||||||
| case GGML_TYPE_Q3_K: | ||||||||||
| case GGML_TYPE_Q4_K: | ||||||||||
| case GGML_TYPE_Q5_K: | ||||||||||
| return 3; // keep MMVQ up to batch=3; MMQ wins from batch=4 onward (+5% to +76%) | ||||||||||
| case GGML_TYPE_Q2_K: | ||||||||||
| case GGML_TYPE_Q6_K: | ||||||||||
| return 5; // keep MMVQ up to batch=5; MMQ wins from batch=6 onward (+8% to +35%) | ||||||||||
| default: | ||||||||||
| // Legacy (Q4_0/Q4_1/Q5_0/Q5_1/Q8_0) and IQ quants regress under MMQ | ||||||||||
| // up to batch=7, so keep the global threshold (MMVQ_MAX_BATCH_SIZE) for them. | ||||||||||
| return MMVQ_MAX_BATCH_SIZE; | ||||||||||
| } | ||||||||||
| } | ||||||||||
|
|
||||||||||
| bool ggml_cuda_should_use_mmvq(enum ggml_type type, int cc, int64_t ne11) { | ||||||||||
| static const bool force_mmvq = (getenv("GGML_CUDA_FORCE_MMVQ") != nullptr); | ||||||||||
| if (force_mmvq) { | ||||||||||
| return ne11 <= MMVQ_MAX_BATCH_SIZE; | ||||||||||
| } | ||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||
| if (amd_mfma_available(cc)) { | ||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||
| return ne11 <= mmvq_max_batch_amd_mfma(type); | ||||||||||
| } | ||||||||||
| return ne11 <= MMVQ_MAX_BATCH_SIZE; | ||||||||||
| } | ||||||||||
|
|
||||||||||
| // Device constexpr: returns the max batch size for the current arch+type at compile time. | ||||||||||
| template <ggml_type type> | ||||||||||
| static constexpr __device__ int get_mmvq_mmid_max_batch_for_device() { | ||||||||||
|
|
||||||||||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -2,6 +2,20 @@ | |||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| #define MMVQ_MAX_BATCH_SIZE 8 // Max. batch size for which to use MMVQ kernels. | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| // Returns true if a quantized matmul of shape (..., ne11) on a device with | ||||||||||||||||||||||||||
| // compute capability `cc` should take the MMVQ (per-row GEMV) path. | ||||||||||||||||||||||||||
| // Returning false sends it to the MMQ path (batched GEMM, MFMA-tiled on CDNA). | ||||||||||||||||||||||||||
| // | ||||||||||||||||||||||||||
| // On AMD MFMA hardware (CDNA) the optimal batch threshold is quant-dependent: | ||||||||||||||||||||||||||
| // K-quants have a heavier per-row GEMV (block scales + super-block decode), so | ||||||||||||||||||||||||||
| // MFMA-tiled MMQ overtakes MMVQ at a smaller batch; legacy and IQ quants have | ||||||||||||||||||||||||||
| // lean GEMV kernels that stay ahead until the batch nearly fills an MFMA tile. | ||||||||||||||||||||||||||
| // Thresholds calibrated on MI250X with Llama-3.2-3B (pp512, ubatch 1..8) — see | ||||||||||||||||||||||||||
| // the PR description for the full sweep. | ||||||||||||||||||||||||||
| // | ||||||||||||||||||||||||||
| // Set the GGML_CUDA_FORCE_MMVQ environment variable (any value; only its presence is checked) to restore the original global threshold. | ||||||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||
| bool ggml_cuda_should_use_mmvq(enum ggml_type type, int cc, int64_t ne11); | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| // Returns the maximum batch size for which MMVQ should be used for MUL_MAT_ID, | ||||||||||||||||||||||||||
| // based on the quantization type and GPU architecture (compute capability). | ||||||||||||||||||||||||||
| int get_mmvq_mmid_max_batch(ggml_type type, int cc); | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Inline this function and remove the comments except for "tuned for CDNA2".