mmq.cu: tune mmq/rocblas switching for RDNA #18537
JohannesGaessler merged 5 commits into ggml-org:master
Conversation
recover performance regression for ggml-org#17917
Beinsezii
left a comment
Overall I think this should be satisfactory in all cases.
@IMbackK in #14949 you mentioned doing more comprehensive tests for CDNA in the future. Won't be 1:1 but possibly this could apply there too. At the very least I think the current CDNA code might be over-prioritizing rocblas with the default branch being false but I have no way to test this.
ggml/src/ggml-cuda/mmq.cu
Outdated
```cpp
// High expert counts almost always better on MMQ
// due to a large amount of graph splits
// https://github.com/ggml-org/llama.cpp/pull/18202
if (n_experts >= 64) {
    return true;
}
```
32 for GPT OSS 20B was also sort of a wash, but I'm assuming non-fp4 models will benefit more from the quant cases instead. There aren't a lot of MoEs I can run fully in VRAM, so testing this was limited.
Thinking more: since the default case is true anyway, this might not actually be needed. I think the reason it was added in #18202 is that there the default case is false, which leads to most quants running rocblas by default.
The worst case would be a highly sparse MoE running at Q6_K or Q2_K, so I can probably test this specifically using GPT OSS 20B and Qwen 30B for 32 and 128 experts respectively this weekend.
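Putting the two hunks discussed in this thread together, the overall RDNA heuristic could be sketched as a plain host-side function. The enum and thresholds below mirror the diff snippets in this PR; this is an illustration, not the actual mmq.cu code, and the type tag names are stand-ins for the real ggml enum.

```cpp
#include <cstdint>

// Stand-in for the relevant ggml type tags; values are arbitrary here,
// not the real ggml enum values.
enum ggml_type_sketch {
    GGML_TYPE_Q2_K, GGML_TYPE_Q6_K, GGML_TYPE_IQ2_XS,
    GGML_TYPE_IQ2_S, GGML_TYPE_Q4_K // Q4_K included only to exercise "default"
};

// true -> use MMQ, false -> use rocBLAS. Thresholds mirror the PR diff.
bool use_mmq_sketch(ggml_type_sketch type, int64_t ne11, int64_t n_experts) {
    // High expert counts almost always better on MMQ
    // due to a large amount of graph splits.
    if (n_experts >= 64) {
        return true;
    }
    switch (type) {
        // These quants are really bad on MMQ:
        case GGML_TYPE_Q2_K:
        case GGML_TYPE_Q6_K:
        // These quants are usually worse but not always:
        case GGML_TYPE_IQ2_XS:
        case GGML_TYPE_IQ2_S:
            return ne11 <= 128; // small batches still favor MMQ
        default:
            return true;
    }
}
```

With this shape, a dense Q6_K model at large batch falls through to rocBLAS, while the same quant in a 128-expert MoE stays on MMQ.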
ggml/src/ggml-cuda/mmq.cu
Outdated
```cpp
switch (type) {
    // These quants are really bad on MMQ
    case GGML_TYPE_Q2_K:
    case GGML_TYPE_Q6_K:
    // These quants are usually worse but not always
    case GGML_TYPE_IQ2_XS:
    case GGML_TYPE_IQ2_S:
        return ne11 <= 128;
    default:
        return true;
}
```
The reason for the elaborate dataset is that I was expecting this switch to be a lot larger, but I think most of the problem is just that Q2_K and Q6_K perform unusually poorly on MMQ, typically by >30%.
One particularly fun case: on models with heterogeneous mmq/rocblas layers like https://huggingface.co/Beinsezii/Mistral-Small-3.2-24B-Instruct-2506-Q6F-Q8A-GGUF, this branch is faster than either forced mmq or cublas separately.
Added a compare-llama-bench formatted version of the results per request. I did not include quants outside of CUDA MMQ scope. There's a variety of models demonstrated, but they do not include Llama3-8B specifically; it can be added if needed.
Some things to consider here:
This PR assumes that the results for gfx11 on Tensile are equivalent in performance to gfx12 on TensileLite, which I find unlikely from CDNA experience. CDNA1/2 have well-tuned kernels in Tensile while RDNA GPUs are less well tuned there; AFAIK the TensileLite kernels are reasonably well tuned for gfx11 (only the large-register-file versions like gfx1100) and gfx12. This is why the rocblas path wins more often on CDNA. Since this PR compares mmq against gfx11 on Tensile, which is a known slower path, it might make sense to compare against TensileLite too, or to switch gfx11 to TensileLite like gfx12 and then choose mmq where that is slower. Compared to CDNA, on RDNA you also have to contend with the devices often having very little VRAM, while CDNA can't have less than 32 GB (mmq saves the dequant buffers).
Regardless, this PR is an overall improvement and the above can be investigated at a later time.
Let me be frank here: As of right now I don't have the bandwidth to be micromanaging the AMD kernel selection logic to this degree. The kernel generator selection is something that rocBLAS should be deciding automatically. Alternatively I would like to outsource the corresponding logic in llama.cpp/ggml to someone else who would be available to maintain it long-term.
I already see this bit as something in my area, and I also have a gfx11 (gfx1100) device on the way, so no need for you to do anything aside from merging the reviewed PRs, since I lost write access. If the selection logic gets long, which it might if we also start selecting between Tensile and TensileLite, we can also move this logic to ggml-hip to get it out of your way. Anyhow, this is something for later.
And yeah, the whole situation with the terrible internal kernel selection in rocblas and hipblaslt is quite annoying; the selection of the "correct" kernel generator is only the tip of the iceberg in this regard.
Given the nonlinear quants are typically <+10%, if we really needed to, this PR could be simplified to just:

```cpp
if (amd_wmma_available(cc)) {
    return !(ne11 >= 128 && (type == GGML_TYPE_Q2_K || type == GGML_TYPE_Q6_K));
}
```

as I think switching on these two at the least is mandatory, since all improvements are >+30%.
The logic as it is in this PR is completely fine. My issue specifically has to do with potentially having to juggle multiple different rocBLAS versions with multiple optional environment variables. I think going forward I will review and maintain the kernel selection logic to align with the newest ROCm version at the default settings, and any other setup will need to be maintained by someone else.
That's pretty much what I went for here. ROCm 6.4 had slightly faster PP for me back a few months ago, but I'm not interested in redownloading 50 gigs of outdated libs just to see if there's maybe one or two more cases where rocblas is slightly faster than mmq on the old version. HIP_GRAPHS I found made effectively 0% difference for pp, but for small/sparse models it's about +10% tg, so I always have it on. I think as it stands this should only need revisiting with major revisions to mmq or rocblas, and with significantly less benchmarking needed at that. Since RDNA support in MMQ is pretty fresh, I'm hoping that as more work is done there it eventually just eats up these last 4 cases and we delete this whole block anyway.
For testing, yes, but you can hint rocblas which kernels to use via its C API, which we could use to make it use the TensileLite kernels where they are faster. Sometimes they are a lot faster (50%+). Anyhow, this is something for another time.
On gfx1100 it looks like [benchmark omitted]. Probably I'll just do a re-measure if someone ends up adding the kernel hinting through the C API.
If it's the default now, there is something we have to do to get best performance: unlike Tensile, TensileLite supports V_WMMA_F32_16X16X16_F16. Currently, because Tensile only issues V_WMMA_F16_16X16X16_F16, we accumulate at fp16 and upconvert after, which is stupid from a performance perspective on this hardware and causes extra issues with overflow. Anyhow, as stated, I recently ordered a gfx11 device and will get around to doing some optimization work on it in the near future.
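To make the overflow point concrete, here is a toy host-side simulation (not actual WMMA code; the fp16 emulation below is a crude round-to-11-significand-bits model, and the numbers are made up): an fp32 accumulator holds a sum of 1000 × 100 exactly, while squeezing every partial sum through fp16 overflows to infinity once it passes the fp16 maximum of 65504.

```cpp
#include <cassert>
#include <cmath>
#include <limits>

// Crude emulation of IEEE fp16 for normal-range values: keep 11
// significand bits and overflow to infinity past the fp16 max of 65504.
float fp16_round(float x) {
    if (std::fabs(x) > 65504.0f) {
        return std::copysign(std::numeric_limits<float>::infinity(), x);
    }
    int e;
    float m = std::frexp(x, &e);               // x = m * 2^e, 0.5 <= |m| < 1
    m = std::nearbyint(m * 2048.0f) / 2048.0f; // round to 11 significand bits
    return std::ldexp(m, e);
}

// Accumulate 1000 partial products of value 100: exact in fp32
// (100000 < 2^24), but an fp16 accumulator overflows to +inf.
float accumulate(bool as_fp16) {
    float acc = 0.0f;
    for (int i = 0; i < 1000; ++i) {
        acc += 100.0f;
        if (as_fp16) {
            acc = fp16_round(acc); // simulate fp16 accumulation
        }
    }
    return acc;
}
```

This is exactly why accumulating in fp32 (V_WMMA_F32_16X16X16_F16) avoids both the extra upconvert and the overflow hazard.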
RDNA4 performance
On RDNA4 this PR is consistently making the performance worse.
Probably better to just gate the new block behind an RDNA3 check than make RDNA4 true, IMO. I'm surprised how much worse it is, actually; the RDNA3/RDNA4 mmq implementations looked similar, so I wonder if this is a rocblas problem.
Did
That was my initial assumption too, then I got unexpected reports about performance regressions. But I think those were related to rocBLAS versions and environment variables. On Monday I should be able to check RDNA 3.5 performance to make sure. |
Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
My current status is this: AMD sent me an Asus ROG Flow Z13 machine where my Linux installation is currently broken due to what seems to be hardware-specific issues. I don't know how long it will take me to sort this out, so for now I would simply merge this PR as-is and deal with any issues that arise after the fact.
@JohannesGaessler Thank you so much.
My situation is as follows: on my desktop machine I have installed Manjaro with KDE, which is working without issue. On the ROG Flow Z13 I received from AMD I also installed Manjaro in order to deal with only a single Linux distribution. With Manjaro v25 it was working as intended after installing the
I am also using KDE on a laptop with CachyOS, which is a variant of Arch Linux (like Manjaro, which also derives from Arch). If you provide more details about the bootstrap process, I can look into it, even though graphics drivers aren't my expertise.
Strix Halo performance
I think we can use the configs from #18442 for RDNA4 for now, as that doesn't make the performance worse.
RDNA4 performance for #18442
If you want to make further changes to the default kernel selection logic as it exists on master, make a new PR relative to master and post a benchmark for how this affects the performance with the newest ROCm version and without any environment variables being set. Any logic that is written for a specific ROCm version or for specific environment variables needs a corresponding check in the code.
Continuing from #18442 I applied similar benchmarking as #14949 and #18202 to try and minimize bad cases on RDNA while keeping the logic simple.
TL;DR
Over an average of all models on https://huggingface.co/Beinsezii/mmq_test over a variety of µbatch sizes I have:
Where 100% is a theoretical maximum if it were to optimally choose mmq or rocblas in each case.
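My reading of that metric can be sketched as follows (the struct and function names here are hypothetical, not from the benchmark scripts): sum the per-case throughput of whichever backend the heuristic picked, divided by the throughput of an oracle that always picks the faster of mmq and rocblas.

```cpp
#include <algorithm>
#include <vector>

// Hypothetical per-case record: throughput (e.g. t/s) of each backend
// and which one the heuristic under test chose.
struct Case {
    double t_mmq;
    double t_rocblas;
    bool chose_mmq;
};

// "tuned %": chosen throughput relative to the per-case optimum.
// 100% means the heuristic picked the faster backend in every case.
double tuned_percent(const std::vector<Case> &cases) {
    double got = 0.0, best = 0.0;
    for (const Case &c : cases) {
        got  += c.chose_mmq ? c.t_mmq : c.t_rocblas;
        best += std::max(c.t_mmq, c.t_rocblas);
    }
    return 100.0 * got / best;
}
```

For example, choosing mmq in a case where it wins 100 vs 80 but also in one where it loses 50 vs 60 gives 150/160 = 93.75%.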
Current master is functionally equivalent to mmq for this and all other benchmarks.
Benchmarks
compare-llama-bench affected quants
compare-llama-bench full
mmq/blas/tuned breakdown
Edge Cases
When excluding the 1B model, which is noisy, there are exactly two outliers:
Both are on bs=128, and both quickly flip below 128. If you wanted to fudge this case, you could probably do something like
to get avg tuned% > 99, but that might be considered splitting hairs.
Testing Setup
100% of my testing was on GFX1100, ROCm 7.1.1, compile flags
and forced mmq/cublas as appropriate for measuring.
I do not own any RDNA3.5 or RDNA4 hardware. I'm assuming RDNA3.5 will behave pretty much the same, but since RDNA4 has some implementation differences in MMQ, it may be worth it for someone to re-measure in the future.
Methodology
The raw data for every combination of model / batch / backend can be viewed at measurements.csv which was generated by the scripts on huggingface.
Differently from the other PRs, I've made MMQ the baseline, as it seems to better handle most cases.
In general, I put little weight on the 1B results as they're extremely noisy, even with HIP graphs. For cases that were a wash between µbatch sizes, like Q4_K and Q5_K, I simply preferred MMQ.