
[Backend] Improve dot support to target FMA #4516

Draft · wants to merge 7 commits into base: main from small_fma_dot
Conversation

Contributor

@binarman binarman commented Aug 14, 2024

This PR:

  • Refactors FMA dot implementation
  • Supports dot3d in FMA path
  • Fixes several issues in operand offset computation
  • Enables small dot operands

This PR is part of a PR series. The final goal is to improve the efficiency of small dot operations and to bypass as many shared memory accesses as possible.

Rough list of PRs:

@binarman binarman force-pushed the small_fma_dot branch 3 times, most recently from 68350e9 to 8e620d3, on August 15, 2024 22:19
@binarman binarman changed the title from "[WIP] Relax dot operand constrains with FMA based dot" to "Relax dot operand constrains with FMA based dot" on Aug 17, 2024
…ompilation time and reduce number of instructions in assembly,
fix bug with wrong order field used for shared mem load size computation
Collaborator

@antiagainst antiagainst left a comment


First batch of comments; I still need to review SharedToDotOperandFMA.cpp more carefully.

@@ -1471,6 +1471,22 @@ inline bool isLayoutMmaV1(Attribute layout) {
return isMmaV1;
}

inline SharedMemoryObject
Collaborator


Please add some documentation to this function.

@@ -129,6 +129,16 @@ void dumpHWLayout(RankedTensorType tensorType);
// Return a string representation of the layout of the tensor.
std::string getLayoutStr(RankedTensorType tensorType, bool useHWPointOfView);

template <typename T>
llvm::SmallVector<T> expandMatrixShapeWithBatch(llvm::ArrayRef<T> s) {
llvm::SmallVector<T> expanded(3 - s.size(), 1);
Collaborator


Assert s.size() <= 3 and directly return if == 3?
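
To make the suggestion concrete, here is a minimal sketch of that shape (an assumption about the intended semantics, not the PR's code): pad the shape with leading size-1 batch dimensions up to rank 3, asserting the input rank and returning early when it is already 3. The same guard would apply to expandMatrixOrderWithBatch below.

#include <cassert>
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"

// Hypothetical sketch, not the PR's implementation: expand a rank-1/2 shape
// to rank 3 by prepending size-1 batch dimensions.
template <typename T>
llvm::SmallVector<T> expandMatrixShapeWithBatch(llvm::ArrayRef<T> s) {
  assert(s.size() <= 3 && "expected a shape of rank 3 at most");
  if (s.size() == 3)
    return llvm::SmallVector<T>(s.begin(), s.end());
  llvm::SmallVector<T> expanded(3 - s.size(), 1); // leading batch dims of size 1
  expanded.append(s.begin(), s.end());
  return expanded;
}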

@@ -3205,6 +3202,15 @@ std::string mlir::triton::gpu::getLayoutStr(RankedTensorType tensorType,
return layoutStr;
}

llvm::SmallVector<unsigned>
mlir::triton::gpu::expandMatrixOrderWithBatch(llvm::ArrayRef<unsigned> o) {
int oldRank = o.size();
Collaborator


Assert o.size <= 3 and return directly if == 3?

@@ -15,7 +15,16 @@


def min_dot_size(target: GPUTarget):
return lambda lhsType, rhsType: (16, 32, 16) if lhsType.is_int8() else (16, 16, 16)

def fma_supported(lhsType, rhsType):
Collaborator


Let's not touch the nvidia side for now?

Contributor Author


Ok, but note that changes in common code also affect the nvidia side.
This is just a switch that enables this functionality in the frontend.

def fma_supported(lhsType, rhsType):
return lhsType == rhsType and (lhsType.is_fp16() or lhsType.is_fp32())

def gfx94_limits(lhsType, rhsType):
Collaborator


get_gfx94_limits

auto dTensorTy = cast<RankedTensorType>(D.getType());
auto dElemTy = dTensorTy.getElementType();
Collaborator


This is dead code?

Contributor Author


In this PR, yes. It will be used in later parts to choose data-specific intrinsics instead of simple FMA.
I will remove it here.

unsigned idx[] = {b, m, n};
unsigned linearIdx = 0;
for (auto dim : llvm::reverse(order)) {
linearIdx = linearIdx * retSize[dim] + idx[dim];
Collaborator


This is non-trivial. Can you add a comment explaining how values are stored in ret and why we compute linearIdx this way?

Contributor Author


We have values scattered across multiple dimensions of the tensor, but in LLVM IR Triton stores them in a linear structure.

This part computes the linear index in that structure at which to put the dot result, according to its batch, M, and N coordinates.

I will put this part in a separate function.
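
As an illustration of the kind of helper this could become, here is a hedged sketch that mirrors the loop above (the name linearizeCoord and its signature are hypothetical, not taken from the PR):

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h" // llvm::reverse

// Hypothetical helper: flatten a multi-dimensional coordinate (e.g. batch, m, n)
// into the index of the linear value structure, where order[0] names the
// fastest-varying dimension.
static unsigned linearizeCoord(llvm::ArrayRef<unsigned> idx,
                               llvm::ArrayRef<unsigned> sizes,
                               llvm::ArrayRef<unsigned> order) {
  unsigned linearIdx = 0;
  // Accumulate from the slowest- to the fastest-varying dimension.
  for (unsigned dim : llvm::reverse(order))
    linearIdx = linearIdx * sizes[dim] + idx[dim];
  return linearIdx;
}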

SmallVector<Value> aOff(aNumPtr);
for (int i = 0; i < aNumPtr; ++i) {
aOff[i] = add(mul(offA0, strideA0), mul(offA1, strideA1));
/**
Collaborator


This is not the commonly used style. Can we follow the common style in the codebase to be consistent?

Contributor Author


I will change it.
Do we have a code style guide for things like this?

Sometimes it is not obvious which option to pick; I've seen a few different comment styles in many places in our codebase, and it seems to depend on the author's taste.
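
For illustration only, assuming the "common style" here refers to the LLVM convention of // line comments and /// Doxygen documentation comments rather than /** ... */ blocks (an assumption, since the thread does not spell this out):

/// Doxygen-style documentation comment on a declaration, the usual LLVM
/// convention; /** ... */ blocks are generally avoided.
inline int addOffsets(int a, int b) {
  // Hypothetical carrier function, only here to show comment placement.
  // Ordinary C++ line comment for explanations inside function bodies.
  return a + b;
}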

@@ -3282,6 +3286,9 @@ def kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, W, stride_wn, strid
return
# make sure ld/st are vectorized
ptx = pgm.asm['ptx']
is_fma = K < 16 or N < 16 or M < 16
Collaborator


Can we skip testing these small matmuls for nvidia?

auto bTensorTy = cast<MemDescType>(B.getType());
auto bLayout = cast<SharedEncodingAttr>(bTensorTy.getEncoding());
auto bShapePerCTA = getShapePerCTA(bTensorTy);
Value loadFMAOp(Value dotOp, Value llA, BlockedEncodingAttr dLayout,
Collaborator


There is a lot of magic indexing and calculation in this function in general. Can you provide more comments inside to make it easier for others to follow?

Contributor Author


I don't think comments will make this function easier to understand.
Instead, I am trying to break it into smaller functions, but this has not been easy so far.

Collaborator

antiagainst commented Sep 23, 2024

When addressing comments, please make sure to add new commits rather than squashing into existing ones. Otherwise it's hard to re-review.

@antiagainst antiagainst changed the title from "Relax dot operand constrains with FMA based dot" to "[Backend] Improve dot support to target FMA" on Sep 23, 2024
Contributor Author

binarman commented Oct 2, 2024

> When addressing comments, please make sure to add new commits rather than squashing into existing ones. Otherwise it's hard to re-review.

Hmm, let me rework this with merge commits; I did not notice this comment in time.

Typically I rebase changes on top of the main branch when I have conflicts, because that way the history looks clean and it is easier to review. But I see that you prefer merge updates, so I'll continue doing it that way.
