[BACKEND] Support fp64 simt fma for common use and fp64 mma for SM80/SM90 (#7310)
Conversation
Modify numElemsPerVec from `32 / bitwidth` to `bitwidth < 32 ? 32 / bitwidth : 1`. Support `TensorCoreType::FP64_FP64_FP64_FP64` in MMAv2.cpp. Check `supportF64mma` in the MMA selection flow.
…2 / bitwidth, 1) Calculate vars like numElementPerReg in a clearer way.
Restore supportMMA and getMMAVersionSafe to their original states. Move fp64 MMA check to BlockedToMMA::matchAndRewrite.
…in of current device
Just #7310 (comment) is missing
1. Support Hopper f64 mma m16n8k16 as mmav2.2.
2. Use a numVecK loop to unify some mmav2 logic across different tileBitWidthK.
3. Remove the previous restriction on f64 dot3d.
4. Add more test cases; use a generic method to detect shared_mem_avail.
lezcano
left a comment
Oh, I see you went on to use the Hopper specific mma intrinsics. Thomas and I meant that you could simply use the m8n8k4 variant in Hopper as well, but this is an even better effort.
I left a few comments on post-Hopper architectures.
…width == 64, may fix the gluon tutorial test failure.
ThomasRaoux
left a comment
LGTM
Let's wait for @lezcano to do one more round of review as well
```c++
  tileSize.push_back(1);
}
// warpSizeK * (warpRepK * VecBitWidth)
auto tileBitWidthK = (isAmpere() && bitwidth == 64) ? (4 * 256) : (4 * 64);
```
I'm a bit confused about this last change. What's the context for the Ampere tile to be larger? I thought it was the smaller one.
It used to be `auto tileBitWidthK = isHopperF64() ? (4 * 256) : (4 * 64)`, but it was later redesigned to repeat the SM80 fp64 m8n8k4 (2,1,4)=8 times, making it the same as the SM90 fp64 mma m16n8k16 and thereby achieving a unified mmav2 fp64 processing logic.
As previously discussed:
- SM90 fp64 m16n8k16 belongs to mmav2, with tileBitWidthK being 4*256
- SM80 fp64 m8n8k4 also belongs to mmav2, with tileBitWidthK being 4*64 (same as other mma instructions).
We needed to distinguish SM90 fp64 mma, so initially I designed it with versionMajor=2, versionMinor=2 and used isHopperF64() to identify this information, as #7310 (comment) shows.
@ThomasRaoux thought introducing a new versionMinor would be more confusing (#7310 (comment)), so I repeated the SM80 fp64 m8n8k4 (2,1,4)=8 times, making it the same as the SM90 fp64 mma m16n8k16 and thereby achieving a unified mmav2 fp64 processing logic (#7310 (comment)).
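For intuition, the (2,1,4)=8 decomposition discussed here is just the tile-shape ratio between the two instructions. A quick arithmetic sketch (illustrative only, not Triton code):

```python
# One SM90-style m16n8k16 fp64 tile decomposed into native SM80 m8n8k4 tiles.
m, n, k = 16, 8, 16        # simulated (unified) tile shape
tm, tn, tk = 8, 8, 4       # native SM80 fp64 mma tile shape

reps = (m // tm, n // tn, k // tk)
print(reps)                          # (2, 1, 4)
print(reps[0] * reps[1] * reps[2])   # 8 m8n8k4 issues per simulated m16n8k16
```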
lezcano
left a comment
Sounds good. Thank you for the great work!
Is there a reason behind this assert? Previously, we were running such dots and didn't run into any issues.

`assert((bitwidth != 64 || largeK == false) &&`
@vwbaker MMAv2 has special handling logic for largeK, as the following shows; I didn't adapt to it. (I figure relying solely on repK should suffice.)
…SM90. (triton-lang#7310)

This PR adds support for fp64 simt fma for common use and fp64 mma for SM80/SM90.

- Modify `numElemsPerVec`'s calculation from `32 / bitwidth` to `std::max(32 / bitwidth, 1)` to support fp64 in many files.
- Support F64 MMA as a new kind of MMAv2:
  - Common MMAv2 (e.g. `m16n8k16.row.col.f32.f16.f16.f32`): numVecK = 2 (`tileBitWidthK = 2x4x32 = 256`, eq k16 x bitWidth(f16)).
  - SM80 FP64 MMA (`m8n8k4.row.col.f64.f64.f64.f64`): numVecK = 1 (`tileBitWidthK = 1x4x64 = 256`, eq k4 x bitWidth(f64)). Dispatch (2,1,4)=8 calls at once to simulate m16n8k16, similar to Turing's approach.
  - SM90 FP64 MMA (`m16n8k16.row.col.f64.f64.f64.f64`): numVecK = 4 (`tileBitWidthK = 4x4x64 = 1024`, eq k16 x bitWidth(f64)).
- Check support for F64 MMA in the MMA selection flow.
- Add lit tests and a Python end-to-end test.

Implements triton-lang#5483.

Performance on A100:

```bash
python3 python/tutorials/03-matrix-multiplication_fp64.py
......
✅ Triton and Torch match
matmul-performance-fp16:
         M       N       K     cuBLAS     Triton
0    256.0   256.0   256.0   1.820444   1.638400
1    384.0   384.0   384.0   6.505412   4.253539
2    512.0   512.0   512.0   9.362286   7.489828
3    640.0   640.0   640.0  11.377778  11.906977
4    768.0   768.0   768.0  13.611323  10.532572
5    896.0   896.0   896.0  13.380267  14.483794
6   1024.0  1024.0  1024.0  15.087425  12.052598
7   1152.0  1152.0  1152.0  15.798857  14.782099
8   1280.0  1280.0  1280.0  14.628571  15.937743
9   1408.0  1408.0  1408.0  15.711170  13.494495
10  1536.0  1536.0  1536.0  15.353337  14.593584
11  1664.0  1664.0  1664.0  15.226585  13.717854
12  1792.0  1792.0  1792.0  14.372665  15.438769
13  1920.0  1920.0  1920.0  14.475393  13.768925
14  2048.0  2048.0  2048.0  15.006454  15.087425
15  2176.0  2176.0  2176.0  13.897547  15.073893
16  2304.0  2304.0  2304.0  16.239206  16.429073
17  2432.0  2432.0  2432.0  15.721580  15.153432
18  2560.0  2560.0  2560.0  14.246956  15.442035
19  2688.0  2688.0  2688.0  15.106753  14.657286
20  2816.0  2816.0  2816.0  15.255058  15.568163
21  2944.0  2944.0  2944.0  15.226408  16.206840
22  3072.0  3072.0  3072.0  15.095469  15.315960
23  3200.0  3200.0  3200.0  14.899313  16.004001
24  3328.0  3328.0  3328.0  15.007566  15.432218
25  3456.0  3456.0  3456.0  15.376991  16.040901
26  3584.0  3584.0  3584.0  14.894052  15.296936
27  3712.0  3712.0  3712.0  14.992847  16.001497
28  3840.0  3840.0  3840.0  14.634379  15.624753
29  3968.0  3968.0  3968.0  14.701679  16.282885
30  4096.0  4096.0  4096.0  15.184719  15.823830
```

# New contributor declaration
- [x] I am not making a trivial change, such as fixing a typo in a comment.
- [x] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how).
- [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.
- Select one of the following.
  - [x] I have added tests.
    - `/test` for `lit` tests
    - `/unittest` for C++ tests
    - `/python/test` for end-to-end tests
  - [ ] This PR does not need a test because `FILL THIS IN`.
- Select one of the following.
  - [ ] I have not added any `lit` tests.
  - [x] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices), including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.)

---------

Co-authored-by: kunzh <zhikun.wu@outlook.com>
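The numVecK values in the bullets above all follow one relation. The helpers below are a hypothetical restatement of that relation inferred from the description (not the actual Triton internals): a K vector is 4 lanes wide, each lane holding max(bitwidth, 32) bits, and `numElemsPerVec` uses the PR's `max(32 / bitwidth, 1)` fix so fp64 no longer rounds down to 0.

```python
def num_elems_per_vec(bitwidth: int) -> int:
    # The PR's change: max(32 / bitwidth, 1), so 64-bit elements
    # yield 1 element per vector instead of 32 // 64 == 0.
    return max(32 // bitwidth, 1)

def num_vec_k(tile_bitwidth_k: int, bitwidth: int) -> int:
    # Assumed relation: tileBitWidthK = numVecK * 4 * max(bitwidth, 32).
    return tile_bitwidth_k // (4 * max(bitwidth, 32))

print(num_elems_per_vec(16), num_elems_per_vec(64))  # 2 1
print(num_vec_k(256, 16))    # common MMAv2 (f16 m16n8k16): 2
print(num_vec_k(256, 64))    # SM80 fp64 m8n8k4: 1
print(num_vec_k(1024, 64))   # SM90 fp64 m16n8k16: 4
```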
…ons. (#10060)

The fp64 MMA path now operates at native `m8n8k4` granularity, supporting any shape that is a multiple of 8×8×4, including the minimal 8×8×4 case. This is an extension of #7310 (the implementation was based on that PR). Tests passed on A100.

## Files Changed

### `lib/Dialect/TritonGPU/IR/Dialect.cpp`
- `getRepForOperand`: Changed `tileBitWidthK` from `2 * 256` to `1 * 256` for fp64 (K-tile = 4). Changed `tileSize[M]` from hardcoded `16` to `8` for fp64.

### `lib/Dialect/TritonGPU/Transforms/Utility.cpp`
- `mmaVersionToInstrShape`: Returns `instrShape[M] = 8` for fp64 (was always 16).

### `lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp`
- `nvidiaDotToLinearLayout`: Uses `instrShape` from the MMA encoding for tile shape computation. K tile multiplier is 4 (not 8) when `instrM == 8`.

### `third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp`
- `getMmaRetType`: fp64 returns `struct{f64, f64}` (2 elements) instead of `struct{f64, f64, f64, f64}` (4 elements).
- `callMmaAmpereFp64`: Extended; it can now also emit a single `m8n8k4` instruction per call (single retArgs(2), aArgs(1), bArgs(1), cArgs(2)).
- `numRegisters`: `{1, 1, 1}` for fp64 (was effectively `{2, 1, 2}`).
- `numMmaRets`: 2 for fp64 (was 4).
- `numCPackedElem`: 1 for fp64 (was incorrectly computed).
- fc indexing formula: Uses `numMmaRets * numCPackedElem` instead of hardcoded `4`.

### `third_party/nvidia/backend/compiler.py`
- `min_dot_size`: Added `elif lhs_bitwidth == 64: return (1, 1, 4)` to allow K=4 for fp64.

### `python/test/unit/language/test_core.py`
- Added small fp64 test cases: `(8,8,4)`, `(8,8,8)`, `(16,8,4)`, `(8,8,16)` with `num_warps=1`.

### `test/Conversion/tritongpu_to_llvm.mlir`
- Updated `f64_mma_cvt` test to use `instrShape = [8, 8]` matching the new fp64 encoding.

----

# New contributor declaration
- [X] I am not making a trivial change, such as fixing a typo in a comment.
- [X] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how).
- [X] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.
- Select one of the following.
  - [X] I have added tests.
    - `/test` for `lit` tests
    - `/unittest` for C++ tests
    - `/python/test` for end-to-end tests
  - [ ] This PR does not need a test because `FILL THIS IN`.
- Select one of the following.
  - [X] I have not added any `lit` tests.
  - [ ] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices), including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.)
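The 8×8×4 granularity claim above can be restated as a simple shape-legality check (a hypothetical helper, not the actual `min_dot_size` code; the real constraint comes from the native m8n8k4 tile):

```python
def fp64_mma_shape_ok(m: int, n: int, k: int) -> bool:
    # fp64 MMA now runs at native m8n8k4 granularity, so any (M, N, K)
    # that is a multiple of (8, 8, 4) is supported.
    return m % 8 == 0 and n % 8 == 0 and k % 4 == 0

# The small test cases added in test_core.py all satisfy this.
for shape in [(8, 8, 4), (8, 8, 8), (16, 8, 4), (8, 8, 16)]:
    print(shape, fp64_mma_shape_ok(*shape))  # all True
```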