
[BACKEND] Support fp64 simt fma for common use and fp64 mma for SM80/SM90. #7310

Merged: lezcano merged 18 commits into triton-lang:main from kzwrime:dot-fp64-mma on Jul 8, 2025

Conversation

@kzwrime (Contributor) commented Jun 25, 2025

This PR adds support for fp64 SIMT FMA for common use and fp64 MMA for SM80/SM90.

  • Modify numElemsPerVec's calculation from 32 / bitwidth to std::max(32 / bitwidth, 1) in many files, to support fp64
  • Support F64 MMA as a new kind of MMAv2 (a sketch of the numVecK arithmetic follows below)
    • Common MMAv2 (e.g. m16n8k16.row.col.f32.f16.f16.f32): numVecK = 2
      • (tileBitWidthK = 2x4x32 = 256, i.e. k16 * bitWidth(f16))
    • SM80 FP64 MMA (m8n8k4.row.col.f64.f64.f64.f64): numVecK = 1
      • (tileBitWidthK = 1x4x64 = 256, i.e. k4 * bitWidth(f64))
      • Dispatch (2,1,4)=8 calls at once to simulate m16n8k16, similar to Turing's approach.
    • SM90 FP64 MMA (m16n8k16.row.col.f64.f64.f64.f64): numVecK = 4
      • (tileBitWidthK = 4x4x64 = 1024, i.e. k16 * bitWidth(f64))
  • Check support for F64 MMA in the MMA selection flow.
  • Add lit tests and a Python end-to-end test.

Implements #5483.
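The numVecK values listed above all follow one pattern: divide tileBitWidthK by four registers' worth of bits, where sub-32-bit elements pack into 32-bit registers and an fp64 element occupies a full 64-bit register. A minimal sketch of that arithmetic (hypothetical helpers, not the actual backend code):

```cpp
#include <algorithm>
#include <cassert>

// Hypothetical sketch of the per-vector element count: 32 / bitwidth packs
// sub-32-bit elements into a 32-bit slot, and the std::max clamp keeps fp64
// (bitwidth 64) from rounding down to zero elements.
int numElemsPerVec(int bitwidth) { return std::max(32 / bitwidth, 1); }

// Hypothetical sketch of numVecK: one K-vector spans 4 registers, each
// 32 bits wide for packed types and 64 bits wide for fp64.
int numVecK(int tileBitWidthK, int bitwidth) {
  return tileBitWidthK / (4 * std::max(bitwidth, 32));
}

int main() {
  assert(numElemsPerVec(16) == 2);      // two f16 per 32-bit register
  assert(numElemsPerVec(64) == 1);      // one f64 per register
  assert(numVecK(2 * 4 * 32, 16) == 2); // common MMAv2 (m16n8k16.f16)
  assert(numVecK(1 * 4 * 64, 64) == 1); // SM80 fp64 (m8n8k4)
  assert(numVecK(4 * 4 * 64, 64) == 4); // SM90 fp64 (m16n8k16)
}
```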

Performance on A100:

```bash
python3 python/tutorials/03-matrix-multiplication_fp64.py
......
✅ Triton and Torch match
matmul-performance-fp16:
         M       N       K     cuBLAS     Triton
0    256.0   256.0   256.0   1.820444   1.638400
1    384.0   384.0   384.0   6.505412   4.253539
2    512.0   512.0   512.0   9.362286   7.489828
3    640.0   640.0   640.0  11.377778  11.906977
4    768.0   768.0   768.0  13.611323  10.532572
5    896.0   896.0   896.0  13.380267  14.483794
6   1024.0  1024.0  1024.0  15.087425  12.052598
7   1152.0  1152.0  1152.0  15.798857  14.782099
8   1280.0  1280.0  1280.0  14.628571  15.937743
9   1408.0  1408.0  1408.0  15.711170  13.494495
10  1536.0  1536.0  1536.0  15.353337  14.593584
11  1664.0  1664.0  1664.0  15.226585  13.717854
12  1792.0  1792.0  1792.0  14.372665  15.438769
13  1920.0  1920.0  1920.0  14.475393  13.768925
14  2048.0  2048.0  2048.0  15.006454  15.087425
15  2176.0  2176.0  2176.0  13.897547  15.073893
16  2304.0  2304.0  2304.0  16.239206  16.429073
17  2432.0  2432.0  2432.0  15.721580  15.153432
18  2560.0  2560.0  2560.0  14.246956  15.442035
19  2688.0  2688.0  2688.0  15.106753  14.657286
20  2816.0  2816.0  2816.0  15.255058  15.568163
21  2944.0  2944.0  2944.0  15.226408  16.206840
22  3072.0  3072.0  3072.0  15.095469  15.315960
23  3200.0  3200.0  3200.0  14.899313  16.004001
24  3328.0  3328.0  3328.0  15.007566  15.432218
25  3456.0  3456.0  3456.0  15.376991  16.040901
26  3584.0  3584.0  3584.0  14.894052  15.296936
27  3712.0  3712.0  3712.0  14.992847  16.001497
28  3840.0  3840.0  3840.0  14.634379  15.624753
29  3968.0  3968.0  3968.0  14.701679  16.282885
30  4096.0  4096.0  4096.0  15.184719  15.823830
```

New contributor declaration

- [x] I am not making a trivial change, such as fixing a typo in a comment.

- [x] I have written a PR description following these rules.

- [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.

- Select one of the following.
  - [x] I have added tests.
    - `/test` for `lit` tests
    - `/unittest` for C++ tests
    - `/python/test` for end-to-end tests
  - [ ] This PR does not need a test because `FILL THIS IN`.

- Select one of the following.
  - [ ] I have not added any `lit` tests.
  - [x] The `lit` tests I have added follow these best practices, including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.)

Modify numElemsPerVec from `32 / bitwidth` to `bitwidth < 32 ? 32 / bitwidth : 1`.
Support `TensorCoreType::FP64_FP64_FP64_FP64` in MMAv2.cpp.
Check `supportF64mma` in the MMA selection flow.
kzwrime added 4 commits June 26, 2025 17:23:

- …2 / bitwidth, 1)
- Calculate vars like numElementPerReg in a clearer way.
- Restore supportMMA and getMMAVersionSafe to their original states.
- Move fp64 MMA check to BlockedToMMA::matchAndRewrite.
@lezcano (Contributor) commented Jun 27, 2025

Just #7310 (comment) is missing

1. Support Hopper f64 mma m16n8k16 as mmav2.2.
2. Use a numVecK loop to unify mmav2 logic across different tileBitWidthK values.
3. Remove the previous restriction on f64 dot3d.
4. Add more test cases; use a generic method to detect shared_mem_avail.
@kzwrime changed the title from "[BACKEND] Support fp64 simt fma for common use and fp64 mma for SM80." to "[BACKEND] Support fp64 simt fma for common use and fp64 mma for SM80/SM90." on Jun 28, 2025
@lezcano (Contributor) left a comment:

Oh, I see you went on to use the Hopper-specific mma intrinsics. Thomas and I meant that you could simply use the m8n8k4 variant in Hopper as well, but this is an even better effort.

I left a few comments on post-Hopper architectures.

@ThomasRaoux (Collaborator) left a comment:

LGTM
Let's wait for @lezcano to do one more round of review as well

```cpp
  tileSize.push_back(1);
}
// warpSizeK * (warpRepK * VecBitWidth)
auto tileBitWidthK = (isAmpere() && bitwidth == 64) ? (4 * 256) : (4 * 64);
```
A contributor commented:

I'm a bit confused about this last change. What's the context for the Ampere tile to be larger? I thought it was the smaller one.

@kzwrime (Contributor, Author) replied on Jul 8, 2025:

It used to be `auto tileBitWidthK = isHopperF64() ? (4 * 256) : (4 * 64)`, but the design was later changed to repeat the SM80 fp64 m8n8k4 instruction (2,1,4)=8 times, making it match the SM90 fp64 mma m16n8k16 and thereby unifying the mmav2 fp64 processing logic.

As previously discussed:

  • SM90 fp64 m16n8k16 belongs to mmav2, with tileBitWidthK being 4*256.
  • SM80 fp64 m8n8k4 also belongs to mmav2, with tileBitWidthK being 4*64 (the same as other mma instructions).

We needed to distinguish the SM90 fp64 mma, so initially I designed it with versionMajor=2, versionMinor=2 and used isHopperF64() to identify it, as #7310 (comment) shows.

@ThomasRaoux thought introducing a new versionMinor would be more confusing (#7310 (comment)), so instead I repeat the SM80 fp64 m8n8k4 (2,1,4)=8 times, making it the same as the SM90 fp64 mma m16n8k16 and thereby achieving a unified mmav2 fp64 processing logic (#7310 (comment)). A sketch of this repetition follows below.
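The repetition scheme reads roughly like this; a hedged sketch, with callMmaM8N8K4 as a hypothetical stand-in for the actual PTX emission helper:

```cpp
// Illustrative only: one logical m16n8k16 fp64 MMA built from eight native
// SM80 m8n8k4 instructions, looping (repM, repN, repK) = (2, 1, 4).
for (int m = 0; m < 2; ++m)     // 2 x m8 -> m16
  for (int n = 0; n < 1; ++n)   // 1 x n8 -> n8
    for (int k = 0; k < 4; ++k) // 4 x k4 -> k16
      callMmaM8N8K4(m, n, k);   // hypothetical emitter for one m8n8k4 tile
```

This is what lets SM80 and SM90 share a single mmav2 fp64 path without introducing a new versionMinor.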

@lezcano (Contributor) left a comment:

Sounds good. Thank you for the great work!

@lezcano merged commit 1c948e8 into triton-lang:main on Jul 8, 2025
9 checks passed
@vwbaker (Collaborator) commented Jul 25, 2025

Is there a reason behind this assert? Previously, we were running such dots and didn't run into any issues.

```cpp
assert((bitwidth != 64 || largeK == false) &&
       "Currently fp64 don't support largeK MMA");
```

@kzwrime (Contributor, Author) commented Jul 31, 2025

> Is there a reason behind this assert? Previously, we were running such dots and didn't run into any issues.

@vwbaker MMAv2 has special handling logic for largeK, shown below, which I did not adapt fp64 to. (I figure relying solely on repK should suffice.)

```cpp
if (largeK) {
  // For layouts with a large K dimension, the original register layout needs
  // to be divided into multiple MMAs, where each MMA has contiguous 32 bits
  // along the K dimension per thread.
  // ...
```

tie-pilot-qxw pushed a commit to tie-pilot-qxw/triton that referenced this pull request Aug 30, 2025: …SM90. (triton-lang#7310)
Jokeren pushed a commit that referenced this pull request Apr 20, 2026: …ons. (#10060)
The fp64 MMA path now operates at native `m8n8k4` granularity, supporting any shape that is a multiple of 8×8×4, including the minimal 8×8×4 case.

This is an extension of #7310 (the implementation was based on that PR).

Tests passed on A100. A sketch of the new tiling follows below.
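A hedged sketch of that decomposition (illustrative only; emitMmaM8N8K4 is a hypothetical stand-in for the emitter in MMAv2.cpp):

```cpp
// Illustrative only: with native m8n8k4 granularity, any fp64 dot whose
// shape is a multiple of (8, 8, 4) lowers to independent m8n8k4 MMAs.
void lowerFp64Dot(int M, int N, int K) {
  for (int m = 0; m < M / 8; ++m)
    for (int n = 0; n < N / 8; ++n)
      for (int k = 0; k < K / 4; ++k)
        emitMmaM8N8K4(m, n, k); // one native fp64 MMA per 8x8x4 tile
}
```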

## Files Changed

### `lib/Dialect/TritonGPU/IR/Dialect.cpp`
- `getRepForOperand`: Changed `tileBitWidthK` from `2 * 256` to `1 * 256` for fp64 (K-tile = 4). Changed `tileSize[M]` from hardcoded `16` to `8` for fp64.

### `lib/Dialect/TritonGPU/Transforms/Utility.cpp`
- `mmaVersionToInstrShape`: Returns `instrShape[M] = 8` for fp64 (was always 16).

### `lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp`
- `nvidiaDotToLinearLayout`: Uses `instrShape` from the MMA encoding for tile shape computation. K tile multiplier is 4 (not 8) when `instrM == 8`.

### `third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp`
- `getMmaRetType`: fp64 returns `struct{f64, f64}` (2 elements) instead of `struct{f64, f64, f64, f64}` (4 elements).
- `callMmaAmpereFp64`: Extended; it can now also emit a single `m8n8k4` instruction per call (single retArgs(2), aArgs(1), bArgs(1), cArgs(2)).
- `numRegisters`: `{1, 1, 1}` for fp64 (was effectively `{2, 1, 2}`).
- `numMmaRets`: 2 for fp64 (was 4).
- `numCPackedElem`: 1 for fp64 (was incorrectly computed).
- fc indexing formula: Uses `numMmaRets * numCPackedElem` instead of a hardcoded `4` (see the sketch after this list).

### `third_party/nvidia/backend/compiler.py`
- `min_dot_size`: Added `elif lhs_bitwidth == 64: return (1, 1, 4)` to allow K=4 for fp64.

### `python/test/unit/language/test_core.py`
- Added small fp64 test cases: `(8,8,4)`, `(8,8,8)`, `(16,8,4)`, `(8,8,16)` with `num_warps=1`.

### `test/Conversion/tritongpu_to_llvm.mlir`
- Updated `f64_mma_cvt` test to use `instrShape = [8, 8]`, matching the new fp64 encoding.
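For the fc indexing item above, the generic base-index pattern presumably looks something like the following (the exact formula is in MMAv2.cpp; this helper's shape is my assumption, not the PR's code):

```cpp
// Illustrative only: each MMA call owns numMmaRets * numCPackedElem
// consecutive slots of the accumulator vector ("fc"). For fp64 that is
// 2 * 1 = 2 slots per m8n8k4 call, instead of the previously hardcoded 4.
int fcBaseIndex(int mmaId, int numMmaRets, int numCPackedElem) {
  return mmaId * numMmaRets * numCPackedElem;
}
```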

----


# New contributor declaration
- [X] I am not making a trivial change, such as fixing a typo in a comment.

- [X] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how).

- [X] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.

- Select one of the following.
  - [X] I have added tests.
    - `/test` for `lit` tests
    - `/unittest` for C++ tests
    - `/python/test` for end-to-end tests
  - [ ] This PR does not need a test because `FILL THIS IN`.

- Select one of the following.
  - [X] I have not added any `lit` tests.
  - [ ] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices), including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.)
bingyizh233 pushed a commit to bingyizh233/triton that referenced this pull request Apr 20, 2026: …ons. (triton-lang#10060)