# [BACKEND] FP64 on Hopper: add support for m16n8k4 path for sm90+ #10313
mwichro wants to merge 1 commit
Conversation
Originally, I was very supportive of the FP64 path, but this update now seems much larger than I expected. Adding too many if/else branches would make the code harder to maintain, especially when different architectures choose different instruction shapes and FP64 is relatively low priority. Can you share with us your plan for the FP64 functionality and what else still needs to be pushed upstream? We may want to reach an agreement on what essential features should be there before working on more PRs.
cc @lezcano
The point of this PR is to gain as much performance as possible while keeping the changes reasonably small. At the same time, I wanted to make it easy to extend the implementation to larger tiles. Adding support for more instructions did not give me any further performance gains. Fully closing the gap would probably require deeper changes, so I opened this PR as at least a reasonable checkpoint. I am not able to run the profiler on the university H100/H200 servers, so I can't even confirm that my reasoning about the performance gap is valid.
That is a good question. For me, those changes seem to be enough; the gap to cuBLAS is reasonably small. On the other hand, it is still 5%, so there is definitely something to gain (and this PR feels like unfinished work to me). So let me return the question: where is the sweet spot for FP64 performance? For my applications, the Mx8x4 tile is perfect.
I believe FP64 support is mainly for the community's benefit. I think getting 100% or higher perf would be beneficial, but I haven't looked into the details. If you are interested in further exploring what steps have to be taken to get there, a brief report would be better than submitting separate PRs. We can leave this PR open for now without merging and come back later if you find that we don't have a good way to achieve 100% perf without significant updates.
Sounds like a plan. I already investigated why there is a gap before I opened this PR; that is why I decided to post it before going further. Would you mind providing some comments on what I think is causing the gap?

My reasoning (from reading the PTX, with some help from Claude): the B operand at … This means just adding …

Also attached (collapsed): PTX dump, cuBLAS, TTGIR.
Thanks for the preliminary investigation. What I actually want is a comparison of Triton FP64 and cuBLAS FP64 NCU profiling results and assembly code. I don't think the analysis from Claude is useful without checking what cuBLAS is doing.
Summary
Triton's FP64 dot lowering previously emitted only `m8n8k4.f64`, leaving ~35% of the H100 performance on the table. This PR adds the `m16n8k4.f64` MMA shape for sm_90+; `m8n8k4.f64` is kept as the sm_80 fallback.
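As a rough illustration of what the change does (hypothetical helper name and simplified signature; the actual lowering lives in the C++ backend and is more involved), the per-architecture shape choice could be sketched as:

```cpp
#include <array>

// Sketch only: FP64 MMA instruction-shape selection as described in this PR.
// pickFp64InstrShape is a made-up name; the real code paths differ.
std::array<unsigned, 3> pickFp64InstrShape(int computeCapability) {
  if (computeCapability >= 90)
    return {16, 8, 4}; // m16n8k4.f64, the new sm_90+ path added here
  return {8, 8, 4};    // m8n8k4.f64, kept as the sm_80 fallback
}
```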
Performance
Performance was measured with the script provided in #10060.
The autotune settings need to be updated:
`num_warps=2`

`m16n8k8` and `m16n8k16`
I tried fully closing the gap by implementing
`m16n8k8` and `m16n8k16` support, but there was no performance gain, so I am not including those changes. The code should, however, be extensible to add support for them:

- `pickFp64MmaK` is the single point to extend with operand-K-aware dispatch (and an env-var override); see the sketch after this list.
- `instrShape` already carries K explicitly, so adding K=8/16 shapes is additive on the encoding side.
- `getMmaTypeDot` dispatches on `instrShape[M]`; one would further branch on `instrShape.back()`.
- `callMmaAmpereFp64M16K4` is a single-K helper; a TODO at its definition lists the regs/thread shape for k=8 (A=4, B=2) and k=16 (A=8, B=4), so the generalization is a localized refactor.
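A minimal sketch of that extension point, assuming a free-function form and an env-var name made up for illustration (`TRITON_FP64_MMA_K`); the actual helper in the backend may look different:

```cpp
#include <cstdlib>

// Sketch only: operand-K-aware choice of the FP64 MMA K dimension with an
// environment-variable override, as the list above suggests. The function
// name pickFp64MmaK comes from the PR; everything else here is illustrative.
unsigned pickFp64MmaK(unsigned operandK, int computeCapability) {
  // Hypothetical override knob, not an existing Triton env var.
  if (const char *env = std::getenv("TRITON_FP64_MMA_K"))
    return static_cast<unsigned>(std::atoi(env));
  if (computeCapability < 90)
    return 4; // m8n8k4 fallback path
  // Prefer the largest K that divides the operand's K extent.
  // Per-thread register counts grow with K: k=8 -> A=4/B=2, k=16 -> A=8/B=4.
  for (unsigned k : {16u, 8u, 4u})
    if (operandK % k == 0)
      return k;
  return 4;
}
```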
Remaining gap to cuBLAS
I've spent some time trying to close the gap, but it turns out it needs a deeper change.
The B operand for `m16n8k{8,16}` cannot be vectorized with the current shared-encoding choice (N-contiguous shared layout vs. K-adjacent register fragment), so the BLOCK_K ≥ 32 autotune configs regress and the autotuner sticks with BLOCK_K = 16 (= m16n8k4).
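To make the mismatch concrete, here is a tiny stand-alone illustration (the [K x N] tile shape and the N stride are assumptions for the example, not values from the PR): two K-adjacent fp64 B elements owned by one thread sit a full row apart in an N-contiguous shared tile, so they cannot be combined into a single vectorized load.

```cpp
#include <cstdio>

// Illustration of the layout mismatch described above: with an N-contiguous
// shared tile, K-adjacent B-fragment elements are far apart in memory.
int main() {
  const int N = 32;  // assumed shared-tile width along N (elements)
  int n = 3, k = 0;  // B element (k, n) held by some thread
  int bytes = (int)sizeof(double);
  int off0 = (k * N + n) * bytes;        // element (k,   n)
  int off1 = ((k + 1) * N + n) * bytes;  // element (k+1, n)
  // The gap is N * 8 = 256 bytes rather than 8, so the two elements need two
  // separate shared-memory loads instead of one 16-byte vectorized access.
  printf("offsets %d and %d bytes, gap %d bytes\n", off0, off1, off1 - off0);
  return 0;
}
```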
Declaration

- I have written a PR description following these rules.
- I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.
- Tests: `/test` for `lit` tests, `/unittest` for C++ tests, `/python/test` for end-to-end tests.
- The `lit` tests I have added follow these best practices.