[BACKEND] Small fixes for dot operand properties (#4895)
Conversation
@lezcano I haven't committed all fixes yet since I noticed some differences between my PR and yours. Most changes in this PR are probably consistent except for:
```cpp
std::iota(order.rbegin(), order.rend(), 0);
if (dotOpLayout.getOpIdx() == 0) {
  std::swap(order[0], order[1]);
}
```
This will result in:

```
opIdx=0: [1, 2, 0]
opIdx=1: [2, 1, 0]
```

And I assume batch is dim0, so maybe you want the following?

```
opIdx=0: [1, 0, 2]
opIdx=1: [0, 1, 2]
```
Yep. What I was trying to say is `order[rank - 2]` and `order[rank - 1]`. Thanks for the review.
Because this is swapping the order array, I think the above code is correct. `order[0]` refers to the leading dimension.

For opIdx=0, `[/*dim0=*/batch, /*dim1=*/m, /*dim2=*/k]`, the leading dimension should be dim1=m.

For opIdx=1, `[/*dim0=*/batch, /*dim1=*/k, /*dim2=*/n]`, the leading dimension should be dim2=n.
lezcano left a comment:
Thank you for the clean-up!
Here's a list of other bugs I found. Feel free to bundle the fixes into this PR, or we can land another round of fixes after my PR:
- `getShapePerCTATileForDotOperands` should be `ForOperand`. It should take a `kWidth` as argument, and should return `kWidth * 2 * 4` in the `K` dimension (or the equivalent for AMD layouts).
- `getWarpsPerCTA` should clamp the `K` dimension to `1`, as per:
  ```cpp
  auto warps = distributedLayout.getWarpsPerCTA();
  auto kDim = getOpIdx() == 0 ? 1 : 0;
  warps[kDim] = 1;
  ```
- All the `ForOperand` ops should be moved to the `DotOperandLayoutAttr` class, to be able to call class-dependent ops, like the `getWarpsPerCTA` defined above (this bit me when `getMMAv2Rep` calls `getWarpsPerCTA`, and in some other place as well).
- `getThreadOrder` should be modified the same way as `getWarpOrder`.
```diff
-  std::iota(order.rbegin(), order.rend(), 0);
-}
-return order;
+return getOrderForDotOperand(dotLayout.getOpIdx(), rank);
```
This change is correct for consistency, but it will conflict with #4891. We will have to fix it there (and perhaps the LL lowering) if this one lands after the LL one.
@lezcano ready for another round of review. Also added you as a co-author.
lezcano left a comment:
LGTM. We still need to fix `getWarpsPerCTA`, but we can do that in a different PR, as it may affect the way we lower MMA ops.
This PR includes #4891 and #4895. I will rebase once those have landed. It includes a number of hacks to work around bugs in `DotOperandEncodingAttr`. All these are marked as `FIXME [Dot LL]` to be easy to grep for. @Jokeren is working on a comprehensive revamp of `DotOperandEncodingAttr` which will get rid of all these. #4895 is the first step in this direction.
Co-authored-by: Mario Lezcano Casado &lt;lezcano@openai.com&gt;