[AMD] Fix "keep Q tensor in VGPRS" optimization #4901

Merged
antiagainst merged 2 commits into triton-lang:main from oplavsic:hoist_q_fix_upstream on Oct 14, 2024

Conversation

@oplavsic
Contributor

Adjust the placement of LDS writes and reads so that they immediately follow the
definition of their operands in the case where the LDS write is inside the loop but its operand is not. This is a heuristic for optimizing fused attention by hoisting the Q tensor's LDS read/write operations out of the loop, since Q is loop invariant and can be loaded once before entering the loop.

In the previous implementation, the heuristic incorrectly assumed that the operand of the LDS write had to be a load operation, which is unnecessary. Additionally, there was no explicit check that the LDS write was inside the loop while its defining operand was not. This PR addresses both issues.
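To make the reordering condition concrete, here is a minimal C++ sketch against the MLIR/Triton APIs. hoistLoopInvariantLocalAllocs is a hypothetical helper, not the actual code in ReorderInstructions.cpp, and it only shows the LDS-write (local_alloc) side of the reordering:

#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/BuiltinOps.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"

using namespace mlir;

// Hypothetical sketch: move an LDS write (triton_gpu.local_alloc with a
// source operand) right after the definition of that operand, but only when
// the write sits inside an scf.for while the operand is defined outside of
// it -- the loop-invariant Q tensor case.
static void hoistLoopInvariantLocalAllocs(ModuleOp mod) {
  mod.walk([](triton::gpu::LocalAllocOp alloc) {
    Value src = alloc.getSrc();
    if (!src)
      return; // alloc without a source operand: nothing to reorder
    Operation *defOp = src.getDefiningOp();
    if (!defOp)
      return; // operand is a block argument; leave it alone
    auto forOp = alloc->getParentOfType<scf::ForOp>();
    // The explicit condition this PR adds: the LDS write is inside a loop
    // while its defining operand is not. Note that defOp does not have to be
    // a tt.load; any loop-invariant producer qualifies.
    if (!forOp || forOp->isAncestor(defOp))
      return;
    alloc->moveAfter(defOp);
  });
}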

oplavsic force-pushed the hoist_q_fix_upstream branch from abae40e to 6bb2a0b on October 13, 2024 23:14
Member

@antiagainst left a comment

Thanks for the fix! A few comments inlined. I'm fine to land this to fix the regression quickly, but we need to restructure the pass next, as it has quite outgrown its original scope.

#shared = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0], hasLeadingOffset = false}>
#mfma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [32, 32], isTransposed = true}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, triton_gpu.target = "hip:gfx90a", "triton_gpu.threads-per-warp" = 64 : i32} {
tt.func public @hoist_q_out_of_the_loop(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg3: f32, %arg4: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg5: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: i32 {tt.divisibility = 16 : i32}, %arg10: i32 {tt.divisibility = 16 : i32}, %arg11: i32 {tt.divisibility = 16 : i32}, %arg12: i32 {tt.divisibility = 16 : i32}, %arg13: i32 {tt.divisibility = 16 : i32}, %arg14: i32 {tt.divisibility = 16 : i32}, %arg15: i32 {tt.divisibility = 16 : i32}, %arg16: i32 {tt.divisibility = 16 : i32}, %arg17: i32 {tt.divisibility = 16 : i32}, %arg18: i32, %arg19: i32 {tt.divisibility = 16 : i32}, %arg20: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
Member

Can you simplify the test here? I don't think we need all these details to verify what we want to check.

// CHECK-NEXT: %[[ALLOC:.+]] = triton_gpu.local_alloc %[[LOAD]]

// CHECK-LABEL: hoist_q_out_of_the_loop
// CHECK: %[[TRUNCF:.+]] = arith.truncf
Member

You also want to check the relative position of scf.for here, to make the intent clear?

Comment thread on test/TritonGPU/amd/amd-reorder-instructions.mlir (outdated)
#shared = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>
// CHECK-LABEL: no_hoist_q_type_reordering
// CHECK: tt.load
// CHECK-NEXT: arith.constant
Member

Also please check the relative positioning of scf.for here, to be clear.

Five additional comment threads on third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp (outdated)
@oplavsic
Contributor Author

> Thanks for the fix! A few comments inlined. I'm fine to land this to fix the regression quickly, but we need to restructure the pass next, as it has quite outgrown its original scope.

Yes, I agree. I plan to do this as part of the dot slicing work, since dot slicing relies heavily on the instruction order of the sliced parts.

oplavsic force-pushed the hoist_q_fix_upstream branch 2 times, most recently from eff48b0 to c8cf15b on October 14, 2024 16:01
oplavsic force-pushed the hoist_q_fix_upstream branch from c8cf15b to f0633c4 on October 14, 2024 16:02
antiagainst marked this pull request as ready for review on October 14, 2024 16:21
@oplavsic
Contributor Author

@antiagainst Addressed your comments, thanks for the review! @zhanglx13, before we merge, please let me know if this is what you had in mind.

Collaborator

@zhanglx13 left a comment

LGTM!
Also checked that the regression is gone.

@antiagainst antiagainst merged commit 037728b into triton-lang:main Oct 14, 2024
ptillet added a commit that referenced this pull request Oct 16, 2024
ptillet pushed a commit that referenced this pull request Oct 16, 2024
antiagainst pushed a commit that referenced this pull request Oct 16, 2024
alexsamardzic pushed a commit to alexsamardzic/triton that referenced this pull request Oct 16, 2024
jtang10 pushed a commit to ROCm/triton that referenced this pull request Oct 21, 2024
Luosuu pushed a commit to Luosuu/triton that referenced this pull request Nov 13, 2024
Luosuu pushed a commit to Luosuu/triton that referenced this pull request Nov 13, 2024
guacamoleo pushed a commit to guacamoleo/triton that referenced this pull request Nov 14, 2024
guacamoleo pushed a commit to guacamoleo/triton that referenced this pull request Nov 14, 2024
bertmaher pushed a commit to bertmaher/triton that referenced this pull request Dec 10, 2024
bertmaher pushed a commit to bertmaher/triton that referenced this pull request Dec 10, 2024
zhanglx13 added a commit to ROCm/triton that referenced this pull request Feb 25, 2025
zhanglx13 added a commit to ROCm/triton that referenced this pull request Feb 25, 2025
zhanglx13 added a commit that referenced this pull request Feb 26, 2025
This PR adds a new AMD-specific pass that hoists convert_layout to the dotOperand
layout for the Q tensor out of the loop. Therefore, the Q tensor is kept in
registers instead of being loaded at every iteration of the loop.

This PR achieves the same thing as #4901. However, #4901 does not hoist the
local_load for Q in the epilogue, so the Q tensor lives in shared memory
all the time.
This PR, on the other hand, does the trick before the stream-pipeline pass,
so the liveness of the Q tensor in shared memory is limited to the prologue.
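As a rough illustration of that approach, here is a hypothetical C++ sketch, not the actual ROCm pass, of hoisting a loop-invariant convert_layout into a dot-operand layout out of the enclosing scf.for so that the converted Q tensor stays in registers:

#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/BuiltinOps.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"

using namespace mlir;

// Hypothetical sketch: hoist a convert_layout whose result uses a
// dot-operand encoding out of the loop when its source is defined outside
// the loop, so the converted value is materialized once in registers.
static void hoistDotOperandConversions(ModuleOp mod) {
  mod.walk([](triton::gpu::ConvertLayoutOp cvt) {
    auto resultTy = cast<RankedTensorType>(cvt.getType());
    if (!isa<triton::gpu::DotOperandEncodingAttr>(resultTy.getEncoding()))
      return; // only conversions into a dot-operand layout are of interest
    auto forOp = cvt->getParentOfType<scf::ForOp>();
    Operation *defOp = cvt.getSrc().getDefiningOp();
    // Hoist only when the conversion is loop invariant, i.e. its source is
    // defined outside of the enclosing scf.for.
    if (!forOp || !defOp || forOp->isAncestor(defOp))
      return;
    cvt->moveAfter(defOp);
  });
}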