
[LAYOUTS] Generalise HoistLayoutConversion to work with arbitrary layouts and chains of ops#5673

Merged

lezcano merged 2 commits into main from remat_dot on Jan 31, 2025

Conversation

@lezcano (Contributor) commented Jan 22, 2025

We generalise `HoistLayoutConversion` to lift a given `convert_layout` to a `dot_operand` layout above any chain of operations that do not require data movement. We could generalise this further in the future to lift it over other ops; we do it this way as a first step to keep the code reasonably close to the previous implementation.

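As a rough illustration of the rewrite (the encodings, shapes, and the particular elementwise ops below are made up for exposition; they are not taken from the pass or its tests), the idea is to move the `convert_layout` from below a chain of data-movement-free ops to directly above it:

```mlir
// Hypothetical encodings; any blocked/dot_operand pair behaves the same way.
// Before: the elementwise chain runs in #blocked and the conversion to the
// dot_operand encoding happens at the very end.
%a = tt.load %ptr : tensor<32x32x!tt.ptr<f16>, #blocked>
%b = arith.extf %a : tensor<32x32xf16, #blocked> to tensor<32x32xf32, #blocked>
%c = math.exp %b : tensor<32x32xf32, #blocked>
%d = ttg.convert_layout %c : tensor<32x32xf32, #blocked> -> tensor<32x32xf32, #dot>

// After: the convert_layout is hoisted above the whole chain, right after the
// load; the chain is simply re-typed, since none of these ops moves data
// across threads.
%a  = tt.load %ptr : tensor<32x32x!tt.ptr<f16>, #blocked>
%a2 = ttg.convert_layout %a : tensor<32x32xf16, #blocked> -> tensor<32x32xf16, #dot>
%b2 = arith.extf %a2 : tensor<32x32xf16, #dot> to tensor<32x32xf32, #dot>
%d  = math.exp %b2 : tensor<32x32xf32, #dot>
```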
Regarding the previous limitations of `canHoistDotOpEncV2`, I did a bit of archeology:

- The "don't hoist past select" restriction was added in issue #2857. I ran the repro and, with the recent layout fixes, it now passes.
- The skipping of `TruncOp`s comes from #2181. I think this is related to the hack that was removed in #5044, so it should work now.
- The same goes for `UIToFpOp`: this is now supported after #5044.
- The mixed-dtype hack is not necessary either, as everything now works as expected after the `convert_layout` rework.

We also add proper support for `isPure` for `elementwise_inline_asm` ops.

On the location of the code, we just leave it in `RemoveLayoutConversions.cpp` to take advantage of the rather generic implementation of `rewriteSlice`. We could move this pass outside of `remove-layout-conversion`, as it is probably enough to run it once. This code will go through further changes in the near future, so we will assess this then.

@lezcano requested a review from ptillet as a code owner January 22, 2025 17:52
@lezcano marked this pull request as draft January 22, 2025 23:30
@lezcano force-pushed the remat_dot branch 2 times, most recently from 1eed10c to a77a54a on January 24, 2025 15:00
@lezcano changed the title from "[WIP][LAYOUTS] Remove HoistLayoutConversion in favour of backwardsRemat" to "[LAYOUTS] Generalise HoistLayoutConversion to work with arbitrary layouts and chains of ops" on Jan 27, 2025
Comment on lines +3278 to +3421
#blocked = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4]}>
module attributes {"ttg.num-warps" = 4 : i32, "ttg.target" = "cuda:80"} {
tt.func @dot_op_hoisted_to_load_with_unsupported_op_and_initializer_above_slice(
@lezcano (Contributor, Author) commented Jan 27, 2025

This is all code movement, plus adding the test that was proposed but not merged in #5349 (comment), as we now hoist everything as expected.

Comment thread on test/TritonGPU/combine.mlir (outdated)
@lezcano marked this pull request as ready for review January 27, 2025 10:23
@lezcano requested a review from ThomasRaoux January 27, 2025 10:32
@ThomasRaoux (Collaborator) left a comment:

LGTM. One comment: I'm not sure if we need the "speculatively" part, but that's kind of a detail.

Comment thread on lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp (outdated)
// This could be generalised if necessary
if (!loadOp) {
auto op = v.getDefiningOp();
if (isa<arith::ConstantOp>(op) || noDataMovement(op)) {
@ThomasRaoux (Collaborator):

Should `ConstantOp` just be put inside `noDataMovement`?

@lezcano (Contributor, Author):

Probably, yeah, but I'm leaving it as-is for now because it currently works, and we are probably going to refactor this pass in the near future anyway.

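As a tiny IR sketch of why constants fit the no-data-movement bucket (the encodings here are again hypothetical), a layout conversion of a constant never moves data, because the constant can simply be rematerialised in the target encoding:

```mlir
// A convert_layout applied to a constant...
%cst = arith.constant dense<0.0> : tensor<32x32xf32, #blocked>
%cvt = ttg.convert_layout %cst : tensor<32x32xf32, #blocked> -> tensor<32x32xf32, #dot>
// ...folds away: the constant is rematerialised directly in the target encoding.
%cst2 = arith.constant dense<0.0> : tensor<32x32xf32, #dot>
```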
@lezcano enabled auto-merge (squash) January 29, 2025 10:07
We now support all layouts as LLs (linear layouts), and reductions support any layout as input. As such, at least in theory, we should be able to propagate layouts freely, even `DotOperand`s, similar to what we do with other layouts.

This PR is a bit tentative. Let's see if anything interesting breaks.
@lezcano (Contributor, Author) commented Jan 31, 2025

I had to add a fix for `MemDescTransOp` so that it behaves like `TransOp` when it comes to treating LLs and legacy layouts as equal when they are structurally equal. Thanks to @Mogball, who helped me with the rough MLIR parts in #5747, which makes this transition a breeze (inheritance is so nice when it works).

@lezcano (Contributor, Author) commented Jan 31, 2025

WTF, the A100 tests failed:

8 workers [16973 items]
<unknown>:0: error: 'tt.trans' op inferred type(s) 'tensor<16x128xbf16, #ttg.linear<{register = [[1, 0], [0, 8], [8, 0], [0, 32], [0, 64]], lane = [[2, 0], [4, 0], [0, 1], [0, 2], [0, 4]], warp = [[0, 0], [0, 16]], block = []}>>' are incompatible with return type(s) of operation 'tensor<128x16xbf16, #ttg.dot_op<{opIdx = 0, parent = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}>, kWidth = 2}>>'
<unknown>:0: error: 'tt.trans' op failed to infer returned types
<unknown>:0: error: 'tt.trans' op inferred type(s) 'tensor<16x64xbf16, #ttg.linear<{register = [[1, 0], [0, 8], [8, 0], [0, 32]], lane = [[2, 0], [4, 0], [0, 1], [0, 2], [0, 4]], warp = [[0, 0], [0, 16]], block = []}>>' are incompatible with return type(s) of operation 'tensor<64x16xbf16, #ttg.dot_op<{opIdx = 0, parent = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}>, kWidth = 2}>>'
<unknown>:0: error: 'tt.trans' op failed to infer returned types
<unknown>:0: error: 'tt.trans' op inferred type(s) 'tensor<16x128xbf16, #ttg.linear<{register = [[1, 0], [0, 8], [8, 0], [0, 32], [0, 64]], lane = [[2, 0], [4, 0], [0, 1], [0, 2], [0, 4]], warp = [[0, 0], [0, 16]], block = []}>>' are incompatible with return type(s) of operation 'tensor<128x16xbf16, #ttg.dot_op<{opIdx = 0, parent = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}>, kWidth = 2}>>'
<unknown>:0: error: 'tt.trans' op failed to infer returned types
<unknown>:0: error: 'tt.trans' op inferred type(s) 'tensor<16x128xbf16, #ttg.linear<{register = [[1, 0], [0, 8], [8, 0], [0, 32], [0, 64]], lane = [[2, 0], [4, 0], [0, 1], [0, 2], [0, 4]], warp = [[0, 0], [0, 16]], block = []}>>' are incompatible with return type(s) of operation 'tensor<128x16xbf16, #ttg.dot_op<{opIdx = 0, parent = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}>, kWidth = 2}>>'
<unknown>:0: error: 'tt.trans' op failed to infer returned types
<unknown>:0: error: 'tt.trans' op inferred type(s) 'tensor<16x64xbf16, #ttg.linear<{register = [[1, 0], [0, 8], [8, 0], [0, 32]], lane = [[2, 0], [4, 0], [0, 1], [0, 2], [0, 4]], warp = [[0, 0], [0, 16]], block = []}>>' are incompatible with return type(s) of operation 'tensor<64x16xbf16, #ttg.dot_op<{opIdx = 0, parent = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}>, kWidth = 2}>>'
<unknown>:0: error: 'tt.trans' op failed to infer returned types
<unknown>:0: error: 'tt.trans' op inferred type(s) 'tensor<16x128xbf16, #ttg.linear<{register = [[1, 0], [0, 8], [8, 0], [0, 32], [0, 64]], lane = [[2, 0], [4, 0], [0, 1], [0, 2], [0, 4]], warp = [[0, 0], [0, 16]], block = []}>>' are incompatible with return type(s) of operation 'tensor<128x16xbf16, #ttg.dot_op<{opIdx = 0, parent = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}>, kWidth = 2}>>'
<unknown>:0: error: 'tt.trans' op failed to infer returned types
<unknown>:0: error: operation scheduled before its operands

but CI marked it as green?

@lezcano lezcano merged commit b3dcc32 into main Jan 31, 2025
@lezcano lezcano deleted the remat_dot branch January 31, 2025 11:15
peterbell10 added a commit that referenced this pull request Jan 31, 2025
Revert "[LAYOUTS] Generalise HoistLayoutConversion to work with arbitrary layouts and chains of ops (#5673)"

This reverts PR #5673
lezcano pushed a commit that referenced this pull request Jan 31, 2025
Revert "[LAYOUTS] Generalise HoistLayoutConversion to work with arbitrary layouts and chains of ops" (#5776)

This reverts PR #5673

This broke the tests on A100, even though CI was green. The CI issue
will be resolved by #5775
ThomasRaoux pushed a commit that referenced this pull request Jan 31, 2025
I had used `.ONESHELL` to allow `cd` to affect the other commands, but it seems this also prevents the error status from propagating from anything but the last command in a rule.

e.g. see
#5673 (comment)
@lezcano lezcano restored the remat_dot branch February 1, 2025 12:05
@lezcano (Contributor, Author) commented Feb 1, 2025

Reopened at #5788

AlexAUT pushed a commit to AlexAUT/triton that referenced this pull request Feb 6, 2025
[LAYOUTS] Generalise HoistLayoutConversion to work with arbitrary layouts and chains of ops (triton-lang#5673)
AlexAUT pushed a commit to AlexAUT/triton that referenced this pull request Feb 6, 2025
triton-lang#5776)

This reverts PR triton-lang#5673

This broke the tests on A100, even though CI was green. The CI issue
will be resolved by triton-lang#5775
AlexAUT pushed a commit to AlexAUT/triton that referenced this pull request Feb 6, 2025
makslevental pushed a commit to makslevental/triton that referenced this pull request Feb 19, 2025
[LAYOUTS] Generalise HoistLayoutConversion to work with arbitrary layouts and chains of ops (triton-lang#5673)
makslevental pushed a commit to makslevental/triton that referenced this pull request Feb 19, 2025
triton-lang#5776)

This reverts PR triton-lang#5673

This broke the tests on A100, even though CI was green. The CI issue
will be resolved by triton-lang#5775
makslevental pushed a commit to makslevental/triton that referenced this pull request Feb 19, 2025