[Backend] Improve dot support to target FMA #4516
base: main
Conversation
Force-pushed from 68350e9 to 8e620d3
Force-pushed from 8e620d3 to 9d01eab
Force-pushed from 9d01eab to 3033970
Force-pushed from 3033970 to 6907073
This PR:
- Refactors FMA dot implementation
- Supports dot3d in FMA path
- Fixes several issues in operand offset computation
- Enables small dot operands
…compilation time and reduce the number of instructions in assembly; fix bug with wrong order field used for shared mem load size computation
Force-pushed from 35bae87 to fe8d557
First batch of comments; I still need to review SharedToDotOperandFMA.cpp more carefully.
@@ -1471,6 +1471,22 @@ inline bool isLayoutMmaV1(Attribute layout) {
  return isMmaV1;
}

inline SharedMemoryObject
Please add some documentation to this function.
@@ -129,6 +129,16 @@ void dumpHWLayout(RankedTensorType tensorType);
// Return a string representation of the layout of the tensor.
std::string getLayoutStr(RankedTensorType tensorType, bool useHWPointOfView);

template <typename T>
llvm::SmallVector<T> expandMatrixShapeWithBatch(llvm::ArrayRef<T> s) {
  llvm::SmallVector<T> expanded(3 - s.size(), 1);
Assert s.size() <= 3 and directly return if == 3?
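For reference, a minimal sketch of the function with the suggested assert and early return folded in (the append of the original dims is an assumption based on the truncated diff, not necessarily the PR's exact code):

#include <cassert>
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"

// Pad a rank-1/rank-2 shape to rank 3 by prepending batch dims of size 1.
template <typename T>
llvm::SmallVector<T> expandMatrixShapeWithBatch(llvm::ArrayRef<T> s) {
  assert(s.size() <= 3 && "expected a shape of rank at most 3");
  if (s.size() == 3)
    return llvm::SmallVector<T>(s.begin(), s.end());
  llvm::SmallVector<T> expanded(3 - s.size(), 1); // leading batch dims = 1
  expanded.append(s.begin(), s.end());            // then the original dims
  return expanded;
}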
@@ -3205,6 +3202,15 @@ std::string mlir::triton::gpu::getLayoutStr(RankedTensorType tensorType,
  return layoutStr;
}

llvm::SmallVector<unsigned>
mlir::triton::gpu::expandMatrixOrderWithBatch(llvm::ArrayRef<unsigned> o) {
  int oldRank = o.size();
Assert o.size() <= 3 and return directly if == 3?
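A similar sketch for the order variant; the body is an assumption inferred from the truncated diff (shift each existing order index past the prepended batch dim, which then becomes the slowest-varying one):

#include <cassert>
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"

llvm::SmallVector<unsigned> expandMatrixOrderWithBatch(llvm::ArrayRef<unsigned> o) {
  assert(o.size() <= 3 && "expected an order of rank at most 3");
  if (o.size() == 3)
    return llvm::SmallVector<unsigned>(o.begin(), o.end());
  int oldRank = o.size();
  // Positions past oldRank stay 0, so the new batch dim (index 0) lands in
  // the slowest-varying slot of the order array.
  llvm::SmallVector<unsigned> expanded(3, 0);
  for (int i = 0; i < oldRank; ++i)
    expanded[i] = o[i] + (3 - oldRank); // old dim i becomes dim i + (3 - oldRank)
  return expanded;
}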
@@ -15,7 +15,16 @@

def min_dot_size(target: GPUTarget):
    return lambda lhsType, rhsType: (16, 32, 16) if lhsType.is_int8() else (16, 16, 16)

def fma_supported(lhsType, rhsType):
Let's not touch the nvidia side for now?
Ok, but note that changes in common code also affect the nvidia side.
This is just a switch which enables this functionality in the frontend.
third_party/amd/backend/compiler.py (outdated)
def fma_supported(lhsType, rhsType):
    return lhsType == rhsType and (lhsType.is_fp16() or lhsType.is_fp32())

def gfx94_limits(lhsType, rhsType):
get_gfx94_limits
auto dTensorTy = cast<RankedTensorType>(D.getType());
auto dElemTy = dTensorTy.getElementType();
This is dead code?
In this PR, yes. It will be used in later parts to choose data-specific intrinsics instead of simple FMA.
I will remove it here.
unsigned idx[] = {b, m, n};
unsigned linearIdx = 0;
for (auto dim : llvm::reverse(order)) {
  linearIdx = linearIdx * retSize[dim] + idx[dim];
This is non-trivial. Can you add a comment to explain how values are stored in ret and why we compute linearIdx this way?
We have values scattered across multiple dimensions of the tensor, but in LLVM IR Triton stores them in a linear structure.
This part computes the linear index in that structure at which to put the dot result, according to its batch, M, and N coordinates.
I will put this part in a separate function.
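For readers following along, a standalone sketch of that linearization (hypothetical helper name; shape corresponds to retSize in the snippet, and order lists dimensions fastest-varying first):

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"

unsigned linearizeBMN(llvm::ArrayRef<unsigned> idx,     // {b, m, n} coordinates
                      llvm::ArrayRef<unsigned> shape,   // extent of each dim
                      llvm::ArrayRef<unsigned> order) { // fastest dim first
  unsigned linearIdx = 0;
  // Fold from the slowest-varying dimension inward: each step scales the
  // partial index by the extent of the next-faster dimension, so the
  // coordinates act like digits of a mixed-radix number.
  for (unsigned dim : llvm::reverse(order))
    linearIdx = linearIdx * shape[dim] + idx[dim];
  return linearIdx;
}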
SmallVector<Value> aOff(aNumPtr);
for (int i = 0; i < aNumPtr; ++i) {
  aOff[i] = add(mul(offA0, strideA0), mul(offA1, strideA1));
/**
This is not the commonly used style. Can we follow the common style in the codebase to be consistent?
I will change it.
Do we have a code style guide for stuff like this?
Sometimes it is not obvious which option to pick; I've seen a few different comment styles in many places in our code base, and it seems to depend on the author's taste.
@@ -3282,6 +3286,9 @@ def kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, W, stride_wn, strid
    return
# make sure ld/st are vectorized
ptx = pgm.asm['ptx']
is_fma = K < 16 or N < 16 or M < 16
Can we skip testing these small matmuls for nvidia?
auto bTensorTy = cast<MemDescType>(B.getType());
auto bLayout = cast<SharedEncodingAttr>(bTensorTy.getEncoding());
auto bShapePerCTA = getShapePerCTA(bTensorTy);

Value loadFMAOp(Value dotOp, Value llA, BlockedEncodingAttr dLayout,
There is a lot of magic indexing and calculation in this function in general. Can you add more comments inside to make it easier for others to follow?
I don't think comments will make this function easier to understand.
Instead, I am trying to break it into smaller functions, but so far this is not easy.
When addressing comments, please make sure to add new commits rather than squashing into existing ones. Otherwise it's hard to re-review.
Force-pushed from fe8d557 to 4d70a5e
Hmm, let me rework this with merge commits; I did not notice this comment in time. Typically I rebase changes on top of the main branch when I have conflicts, because that way the history looks clean and is easier to review. But I see that you prefer merge updates, so I'll continue doing it that way.
Force-pushed from 4d70a5e to 04678cc
This PR is part of a PR series. The final goal is to improve the efficiency of small dot operations and bypass as many shared memory accesses as possible.
Rough list of PRs: