Skip to content

Conversation

@jataylo
Copy link

@jataylo jataylo commented Mar 27, 2025

For PT usage

joviliast and others added 19 commits March 27, 2025 06:37
In the case of 16 bit floats operands for tt::AtomicRMWOp, construct
only one LLVM::AtomicRMWOp but use vector of elements.
Such approach allows to generate packed intrinsics and process 2
elements at once.
Added a lit test for f16 vectorized case.

(cherry picked from commit 78c8054)
(cherry picked from commit 4d70942)
(cherry picked from commit 2f8aacc)
(cherry picked from commit 86a2ac7)
(cherry picked from commit 4c7d56e)
(cherry picked from commit 0529343)
This commit adds support for warp-level reduction
with DPP instructions, which can improve performance.

See https://gpuopen.com/learn/amd-gcn-assembly-cross-lane-operations/

(cherry picked from commit 21119e3)
(cherry picked from commit d0142d3)
(cherry picked from commit 9f2b69b)
TritonAMDGPUTransforms now depends on it.

(cherry picked from commit 0b443ce)
(cherry picked from commit 37cec47)
(cherry picked from commit 1ab334d)
This commit adds support for warp-level reduction
with DPP instructions, which can improve performance.

See https://gpuopen.com/learn/amd-gcn-assembly-cross-lane-operations/

(cherry picked from commit 21119e3)
(cherry picked from commit ca8842c)
(cherry picked from commit 3a3902d)
…4935)

This PR adds more restrictions about when should we apply
the sched-load optimizations and un-revert
triton-lang#4823.

We will only apply the optimization when all of the following is
satisfied:
1. pureMatmulProblem, i.e. 1 `tt.dot` in the main loop
2. two `tt.load`s in the main loop
3. 2nd `tt.load` is ahead of the `tt.dot`
4. 1st user of 2nd `tt.load` is after the `tt.dot`
5. tile size is large enough, i.e. nonKDim >= 128 and kDim >= 64

(cherry picked from commit 4f6f768)
(cherry picked from commit f6053a3)
(cherry picked from commit 72d1575)
…-lang#5009)

Allows for upcasting in DotOp encoding in RF.
This lowering path is not currently in use; pending
triton-lang#5003

(cherry picked from commit cfddb09)
(cherry picked from commit f8c2c30)
(cherry picked from commit 73ef337)
This commit adds initial support for scaled_dot with
mxfp8 LHS and fp8 RHS. It supports both mfma32
and mfma16 intrinsic variants.

Right now we are missing software emulation for
`Float8E4M3FN` type, so this only enables for
`Float8E5M2`.

(cherry picked from commit 3549db8)
(cherry picked from commit efe0ec4)
(cherry picked from commit 010fe45)
…lang#4996)

In the passing we also improve a few other things:
- Now `scaled_dot` accepts both uint8/uint16 fp8/bf16 as inputs (before
you had to cast it to uint8, which was weird when extending it to bf16).
- Add `scaled_dot` to the docs and improve the docs overall (have not
render them, might need a few further tweaks)

(cherry picked from commit 23c9ec1)
(cherry picked from commit 675758b)
(cherry picked from commit 4e04af0)
…n-lang#4991)

Specifically, it fixes problems when `srcLayout` and `dstLayout` have
different number of registers but the same number of not free registers.
We solved the problem by padding free registers to either `srcLayout` or
`dstLayout`, but this can be improved by fixing the `invertAndCompose`
function.

(cherry picked from commit 15c5e55)
(cherry picked from commit 6537eb6)
(cherry picked from commit 4ca5013)
…triton-lang#4951)

This PR removes the legacy `isMmaToDotShortcut` and its associated shortcut conversion.

(cherry picked from commit 1d5fdfe)
(cherry picked from commit fc6d96b)
(cherry picked from commit 9f67c54)
)

We also clean a bit `TritonGPU/IR/Dialect.cpp` using some auxiliary
functions to make the intentions a bit clearer.

We add a few asserts in the `LinearLayoutConversion` to make sure it's
clear why we do certain things here and there.

We also kill `getCvtOrder`, as it was not used anywhere

(cherry picked from commit 56584c4)
(cherry picked from commit 276d182)
(cherry picked from commit 72651c2)
…riton-lang#5055)

We use `getOrder` very liberally throughout the codebase, when we really
meant to use `getThreadOrder`. This is an issue with the input layout is
an
`DotOperand(mma(opIdx=1))`, where the thread order and the matrix order
are opposite.

Found this to be an issue when a PR changed the `getOrder` of
`DotOperand(Hopper)` to an incorrect one and CI still passed! The issue
here is that the LLVM lowering for wgmma and the LinearLayout does not
use `getOrder`, but there are many other subsystems do, and many
heuristics would be getting an incorrect order, and potentially be
disabled.

This is particularly problematic for `DotOperand(opIdx=1)` in nvidia
hardware, as `getThreadOrder` and `getOrder` are different!

While doing so we:
- Audit most (all?) the calls to `getOrder(dotOperand)`. It turns out
that most of them really meant `getThreadOrder`
- Fix the ordering methods of `SliceEncodingAttr` to be consistent
- Move the implementation of `getWarpOrder` to the Attr classes, because
of OOP

The test strategy was to add `llvm::report_fatal_error("Testing");`
within `getOrder(nvidiaMma)` and `getOrder(DotOperand(nvidiaMma))` and
triaging all errors that were raised in CI.

(cherry picked from commit 38a11b8)
(cherry picked from commit 8412154)
(cherry picked from commit a569c3e)
This commit relands triton-lang#4819
with the following fixes:

* Changed to a better way to mark opIdx for loads
* Replaced temlate-based `rewindUnaryOps` to use regular
  for-loops. The new way is more robust and can handle other
  unary ops automatically.
* Replaced `instr.sched.barriers` using the ones from
  `rocdl` dialect from the MLIR upstream
* Extended lit tests

(cherry picked from commit ee5876c)
(cherry picked from commit 8dd9226)
(cherry picked from commit aed3efc)
(cherry picked from commit f062540)
(cherry picked from commit ca75b5f)
(cherry picked from commit 98149dd)
This commit adds support for mxfp4 typed A tensor
for sacled dot in the AMD backend.

We moved the `convertMxfp4x2ToBf16x2` impl
from NVIDIA side to a common path to reuse.

(cherry picked from commit edc5c5c)
(cherry picked from commit ac9f0d0)
(cherry picked from commit c0710dc)
Two bugfixes following triton-lang#5009.

- When `BLOCK_M=64` and `num_warps > 4`, the order of warps for
DotOpEncoded tensor should be M-major instead of N-major, since WGMMA
expects the 4 warps in each warp group to be stacked along the M
dimension.
- Should use `mmaBitwidth` instead of `bitwidth` when calculating
`numRep` in `SharedToDotOperandMMAv2OrV3`. This was missed in a bad
rebase.

@lezcano I encountered these bugs when attempting to locally test the
[DotOp hoisting PR](triton-lang#5003)
after rebasing (they normally would be caught by `test_core.py` but that
path was not yet enabled in the last PR). With these fixes added, I was
able to successfully validate against pytorch.

(cherry picked from commit e82dfd9)
(cherry picked from commit 5287a68)
(cherry picked from commit 8d70247)
(cherry picked from commit 302de9d)
- Removed functions related to unpacking and packing I32 values.
- Updated utilities to handle conversion of mxfp4 values without
packing/unpacking I32.
- Move the register value ordering logic from the element-wise operation
lowering to the dot operation lowering.
- Use linear layout to handle conversions between almost all distributed
layouts.
- Clean up data loading and mma computation involving `repN`, `repK`,
and `repM`.

(cherry picked from commit 1cf7b1b)
(cherry picked from commit 376fe7e)
(cherry picked from commit 2141a4e)
(cherry picked from commit d0e4abc)
… for general and NVIDIA layouts (triton-lang#5089)

This partially reverts commit 38a11b8.
Supersedes triton-lang#5085

It also documents that we are implicitly choosing a way to tile a
full tensor depending on the layout. See
triton-lang#5085 (comment)

(cherry picked from commit 57643b3)
(cherry picked from commit ffb2032)
(cherry picked from commit a11c6be)
(cherry picked from commit 9d080b4)
@jataylo jataylo merged commit a34a79d into ROCm:release/internal/3.2.x Mar 27, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

9 participants