
[Backend] Implement scaled_dot(mxfp4, fp8)#4904

Merged
lezcano merged 7 commits into triton-lang:main from lezcano:mxfp_snd
Oct 16, 2024

Conversation

Contributor

@lezcano commented Oct 14, 2024

This PR includes #4891 and #4895. I will rebase once those have landed.

It includes a number of hacks to work around bugs in DotOperandEncodingAttr. All these are marked as FIXME [Dot LL] to be easy to grep for. @Jokeren is working on a comprehensive revamp of DotOperandEncodingAttr which will get rid of all these. #4895 is the first step in this direction.

@lezcano changed the title from "mxfp snd" to "[Backend] Implement scaled_dot(mxfp4, fp8)" on Oct 14, 2024
@lezcano requested a review from ThomasRaoux on October 14, 2024 17:25
@lezcano marked this pull request as draft on October 14, 2024 17:51
@lezcano force-pushed the mxfp_snd branch 4 times, most recently from c15d411 to 104200d on October 15, 2024 14:23
@lezcano marked this pull request as ready for review on October 15, 2024 14:44
@lezcano force-pushed the mxfp_snd branch 2 times, most recently from 20a64b1 to 33fceb2 on October 15, 2024 16:44
%a = triton_gpu.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A_DOT>
%b_ = tt.load %b_ptr, %b_mask, %b_other : tensor<32x128x!tt.ptr<f16>, #BL>
// CHECK-NEXT: offset = 0, size = 4224
// CHECK-NEXT: offset = 0, size = 4352
Contributor Author

nb. These changes are coming from the change in lib/Analysis/Allocation.cpp

Contributor

It's OK; this path was never tested anyway. It will be tested in my next PR.

// This should be getElemOrder, but we don't have such a method
// TODO Implement getElemOrder and make sure it's consistent with
// getContigPerThread
auto inOrd = gpu::getThreadOrder(srcLayout);
Contributor

I think we assume getElemOrder == getOrder

Contributor

getThreadOrder is the same as getOrder except for AMD's AMDMfmaEncodingAttr. I haven't investigated deeply.
Pinging @zhanglx13 for expertise, maybe.

Contributor Author

See that I changed the definition of getThreadOrder in this PR.

Contributor

To be specific I was referring to:

SmallVector<unsigned> AMDMfmaEncodingAttr::getThreadOrder() const {
  auto order = ::getOrder(*this);
  if (getIsTransposed())
    std::swap(order[0], order[1]);
  return order;
}

I'm not sure if we should use getOrder or getThreadOrder for this encoding

Comment thread lib/Dialect/TritonGPU/IR/Dialect.cpp

typeConverter, loc, rewriter, loadedA, repBatch, repM, repK, aTensorTy);

// FIXME [Dot LL]
// max(repN / 2, 1) is wrong for repN = 1!
Contributor

Can you elaborate on // max(repN / 2, 1) is wrong for repN = 1!?
Why is repN = 1 wrong?

Contributor Author

We are taking this max(repN / 2, 1) here, and then in the loop inside getValuesFromDotOperandLayoutStruct we are packing 4 elements at a time. Rather than that, the correct implementation packs 2 elements inside the function for opIdx=1 and iterates repN times.

Contributor

Got it

@ThomasRaoux (Collaborator) left a comment

Looks good overall, although I didn't look in detail at the LL TODOs.
Just added a few minor comments.

Comment thread lib/Dialect/TritonGPU/IR/Dialect.cpp Outdated
Comment on lines +260 to +261
// FIXME: mma should just return getOrderForDotOperand(0, order.size(),
// kMajor=false)
Collaborator

I'm also confused by this comment.

@lezcano (Contributor Author) Oct 16, 2024

Here I just meant that the logic in mma is probably wrong, and we just want this function to return what I wrote there. The point is that, in terms of order, the mma layout is the same as DotOperandEncoding(opIdx=0).

Contributor Author

I had another go at the comment. Third's a charm

Comment on lines +271 to +272
order = getOrderForDotOperand(dotOpLayout.getOpIdx(), order.size(),
/*kMajor*/ false);
Collaborator

why is kMajor always false here?

Contributor

This is getting the warp order, not the element order, so m is the fastest-changing dimension for opIdx=0. I think the confusion may arise from the variable name kMajor.

Contributor

I don't have a suggestion for improvement though. Maybe just add some additional comments.

Contributor Author

Yep, similarly to wgmma, we want the warps to have the exterior dimension (i.e. not K) as their fastest-running dimension.

vType.getShape(), vType.getElementType(), newVEncoding);
return rewriter.create<ConvertLayoutOp>(v.getLoc(), newVType, v);
} else {
auto newVEncoding = DotOperandEncodingAttr::get(
Collaborator

nit: assert that this is a fp8 type?

Contributor Author

Done, although it's a bit redundant, as we are already asserting this at the beginning of the function and in semantics.py.

@ThomasRaoux (Collaborator) left a comment

LGTM

@lezcano lezcano merged commit 9e90089 into triton-lang:main Oct 16, 2024
@lezcano lezcano deleted the mxfp_snd branch October 16, 2024 15:21
alexsamardzic pushed a commit to alexsamardzic/triton that referenced this pull request Oct 16, 2024
Luosuu pushed a commit to Luosuu/triton that referenced this pull request Nov 13, 2024
guacamoleo pushed a commit to guacamoleo/triton that referenced this pull request Nov 14, 2024
bertmaher pushed a commit to bertmaher/triton that referenced this pull request Dec 10, 2024

3 participants