
[BACKEND] Small fixes for dot operand properties #4895

Merged
Jokeren merged 21 commits into main from keren/dot-op-fix on Oct 15, 2024

Conversation

@Jokeren (Contributor) commented Oct 11, 2024

Co-authored-by: Mario Lezcano Casado lezcano@openai.com

@Jokeren (Contributor, Author) commented Oct 11, 2024

@lezcano I haven't committed all the fixes yet, since I noticed some differences between my PR and yours. Most changes in this PR are probably consistent with yours, except for `getContigPerThread`.

@Jokeren Jokeren marked this pull request as ready for review October 12, 2024 12:32
std::iota(order.rbegin(), order.rend(), 0);
if (dotOpLayout.getOpIdx() == 0) {
std::swap(order[0], order[1]);
}
Collaborator

This will result in
opIdx=0: [1, 2, 0]
opIdx=1: [2, 1, 0]

And I assume batch is dim0, so maybe you want the following?
opIdx=0: [1, 0, 2]
opIdx=1: [0, 1, 2]

Contributor (Author)

Yep. What I was trying to say is `order[rank - 2]` and `order[rank - 1]`. Thanks for the review.

@Jokeren (Contributor, Author) commented Oct 15, 2024

Because this is swapping the order array, I think the above code is correct.
order[0] refers to the leading dimension.
For opIdx=0, [/*dim0=*/batch, /*dim1=*/m, /*dim2=*/k], the leading dimension should be dim1=m.
For opIdx=1, [/*dim0=*/batch, /*dim1=*/k, /*dim2=*/n], the leading dimension should be dim2=n.

@lezcano (Contributor) left a comment

Thank you for the clean-up!

Here's a list of other bugs I found. Feel free to bundle the fixes into this PR, or we can land another round of fixes after my PR:

  • `getShapePerCTATileForDotOperands` should be `ForOperand`. It should take a `kWidth` as an argument, and should return `kWidth * 2 * 4` in the K dimension (or the equivalent for AMD layouts).
  • `getWarpsPerCTA` should clamp the K dimension to 1, as per:
auto warps = distributedLayout.getWarpsPerCTA();
auto kDim = getOpIdx() == 0 ? 1 : 0;
warps[kDim] = 1;
  • All the `ForOperand` ops should be moved to the `DotOperandLayoutAttr` class, to be able to call class-dependent ops like the `getWarpsPerCTA` defined above (this bit me when `getMMAv2Rep` calls `getWarpsPerCTA`, and in some other place as well).
  • `getThreadOrder` should be modified the same way as `getWarpOrder`.

Comment thread: lib/Dialect/TritonGPU/IR/Dialect.cpp
std::iota(order.rbegin(), order.rend(), 0);
}
return order;
return getOrderForDotOperand(dotLayout.getOpIdx(), rank);
Contributor

This change is correct for consistency, but it will conflict with #4891. We will have to fix it there (and perhaps in the LL lowering) if this one lands after the LL one.

Comment thread: include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
@Jokeren (Contributor, Author) commented Oct 15, 2024

@lezcano ready for another round of review. I also added you as a co-author.

@lezcano (Contributor) left a comment

LGTM. We still need to fix `getWarpsPerCTA`, but we can do that in a different PR, as it may affect the way we lower MMA ops.

@Jokeren Jokeren merged commit f9688ab into main Oct 15, 2024
@Jokeren Jokeren deleted the keren/dot-op-fix branch October 15, 2024 16:16
lezcano added a commit that referenced this pull request Oct 16, 2024
This PR includes #4891 and
#4895. I will rebase once
those have landed.

It includes a number of hacks to work around bugs in
`DotOperandEncodingAttr`. All these are marked as `FIXME [Dot LL]` to be
easy to grep for. @Jokeren is working on a comprehensive revamp of
`DotOperandEncodingAttr` which will get rid of all these.
#4895 is the first step in
this direction.
alexsamardzic pushed a commit to alexsamardzic/triton that referenced this pull request Oct 16, 2024
Luosuu pushed a commit to Luosuu/triton that referenced this pull request Nov 13, 2024
Luosuu pushed a commit to Luosuu/triton that referenced this pull request Nov 13, 2024
guacamoleo pushed a commit to guacamoleo/triton that referenced this pull request Nov 14, 2024
guacamoleo pushed a commit to guacamoleo/triton that referenced this pull request Nov 14, 2024
bertmaher pushed a commit to bertmaher/triton that referenced this pull request Dec 10, 2024
bertmaher pushed a commit to bertmaher/triton that referenced this pull request Dec 10, 2024
liuyunqi20 pushed a commit to flagos-ai/FlagTree that referenced this pull request Oct 21, 2025

3 participants