[Backend] Implement scaled_dot(mxfp4, fp8)#4904
Conversation
c15d411 to
104200d
Compare
20a64b1 to
33fceb2
Compare
| %a = triton_gpu.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A_DOT> | ||
| %b_ = tt.load %b_ptr, %b_mask, %b_other : tensor<32x128x!tt.ptr<f16>, #BL> | ||
| // CHECK-NEXT: offset = 0, size = 4224 | ||
| // CHECK-NEXT: offset = 0, size = 4352 |
There was a problem hiding this comment.
nb. These changes are coming from the change in lib/Analysis/Allocation.cpp
There was a problem hiding this comment.
It's OK this path was never tested anyway. It will be tested in my next PR.
| // This should be getElemOrder, but we don't have such a method | ||
| // TODO Implement getElemOrder and make sure it's consistent with | ||
| // getContigPerThread | ||
| auto inOrd = gpu::getThreadOrder(srcLayout); |
There was a problem hiding this comment.
I think we assume getElemOrder == getOrder
There was a problem hiding this comment.
getThreadOrder is same as getOrder except for AMD's AMDMfmaEncodingAttr. I haven't taken a deep investigation.
pin @zhanglx13 for expertise maybe
There was a problem hiding this comment.
See that I changed the definition of getThreadOrder in this PR.
There was a problem hiding this comment.
To be specific I was referring to:
SmallVector<unsigned> AMDMfmaEncodingAttr::getThreadOrder() const {
auto order = ::getOrder(*this);
if (getIsTransposed())
std::swap(order[0], order[1]);
return order;
}
I'm not sure if we should use getOrder or getThreadOrder for this encoding
| %a = triton_gpu.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A_DOT> | ||
| %b_ = tt.load %b_ptr, %b_mask, %b_other : tensor<32x128x!tt.ptr<f16>, #BL> | ||
| // CHECK-NEXT: offset = 0, size = 4224 | ||
| // CHECK-NEXT: offset = 0, size = 4352 |
There was a problem hiding this comment.
It's OK this path was never tested anyway. It will be tested in my next PR.
| typeConverter, loc, rewriter, loadedA, repBatch, repM, repK, aTensorTy); | ||
|
|
||
| // FIXME [Dot LL] | ||
| // max(repN / 2, 1) is wrong for repN = 1! |
There was a problem hiding this comment.
Can you elaborate on // max(repN / 2, 1) is wrong for repN = 1!?
Why repN=1 is wrong?
There was a problem hiding this comment.
We are taking this max(repN / 2, 1) here, and then in the loop inside getValuesFromDotOperandLayoutStruct we are packing 4 elements at a time. Rather than that, the correct implementation packs 2 elements inside the function for opIdx=1 and iterates repN times.
This is a tentative PR to check how much breaks if we fix this.
ThomasRaoux
left a comment
There was a problem hiding this comment.
Looks good overall although I didn't look in details at the LL TODOs.
Just added few minor comments
| // FIXME: mma should just return getOrderForDotOperand(0, order.size(), | ||
| // kMajor=false) |
There was a problem hiding this comment.
I'm also confused by this comment.
There was a problem hiding this comment.
Here I just meant that the logic in mma is probably wrong and we just want this function to return what I wrote there. The point here is that, in terms of order, the mma layout is the same as the DotOperandEncoding(opIdx=0)
There was a problem hiding this comment.
I had another go at the comment. Third's a charm
| order = getOrderForDotOperand(dotOpLayout.getOpIdx(), order.size(), | ||
| /*kMajor*/ false); |
There was a problem hiding this comment.
why is kMajor always false here?
There was a problem hiding this comment.
This is getting the warp order but not the element order. So m is the fastest changing dimension in opIdx=0. I think confusion may arise from the variable name kMajor.
There was a problem hiding this comment.
I don't have a suggestion for improvement though. Maybe just add some additional comments.
There was a problem hiding this comment.
Yep, similarly to in wgmma, we want the warps have the exterior dimension (i.e. not K) as their fastest running dimension.
| vType.getShape(), vType.getElementType(), newVEncoding); | ||
| return rewriter.create<ConvertLayoutOp>(v.getLoc(), newVType, v); | ||
| } else { | ||
| auto newVEncoding = DotOperandEncodingAttr::get( |
There was a problem hiding this comment.
nit: assert that this is a fp8 type?
There was a problem hiding this comment.
Done, although it's a bit redundant, as we are already asserting this at the beginning of the function and in semantics.py.
This PR includes triton-lang#4891 and triton-lang#4895. I will rebase once those have landed. It includes a number of hacks to work around bugs in `DotOperandEncodingAttr`. All these are marked as `FIXME [Dot LL]` to be easy to grep for. @Jokeren is working on a comprehensive revamp of `DotOperandEncodingAttr` which will get rid of all these. triton-lang#4895 is the first step in this direction.
This PR includes triton-lang#4891 and triton-lang#4895. I will rebase once those have landed. It includes a number of hacks to work around bugs in `DotOperandEncodingAttr`. All these are marked as `FIXME [Dot LL]` to be easy to grep for. @Jokeren is working on a comprehensive revamp of `DotOperandEncodingAttr` which will get rid of all these. triton-lang#4895 is the first step in this direction.
This PR includes triton-lang#4891 and triton-lang#4895. I will rebase once those have landed. It includes a number of hacks to work around bugs in `DotOperandEncodingAttr`. All these are marked as `FIXME [Dot LL]` to be easy to grep for. @Jokeren is working on a comprehensive revamp of `DotOperandEncodingAttr` which will get rid of all these. triton-lang#4895 is the first step in this direction.
This PR includes triton-lang#4891 and triton-lang#4895. I will rebase once those have landed. It includes a number of hacks to work around bugs in `DotOperandEncodingAttr`. All these are marked as `FIXME [Dot LL]` to be easy to grep for. @Jokeren is working on a comprehensive revamp of `DotOperandEncodingAttr` which will get rid of all these. triton-lang#4895 is the first step in this direction.
This PR includes #4891 and #4895. I will rebase once those have landed.
It includes a number of hacks to work around bugs in
DotOperandEncodingAttr. All these are marked asFIXME [Dot LL]to be easy to grep for. @Jokeren is working on a comprehensive revamp ofDotOperandEncodingAttrwhich will get rid of all these. #4895 is the first step in this direction.