[Linear Layouts] Implement LL conversion for DotOperand(version=2)#4891

Merged
lezcano merged 4 commits into triton-lang:main from lezcano:ll_dot
Oct 15, 2024

@lezcano
Contributor

@lezcano lezcano commented Oct 11, 2024

Note that the current implementation uses `DotOperandEncodingAttr::getWarpsPerCTA`, which was buggy for cases where the warps are not of the form `[numWarps, 1]` or `[1, numWarps]`. This PR bundles a fix for this issue.

We will activate its use for a subset of `DotOperandEncoding`s in a PR coming soon.

@lezcano lezcano requested a review from ptillet as a code owner October 11, 2024 15:20
@lezcano lezcano changed the title Implement LL [Linear Layouts] Implement LL conversion for DotOperand Oct 11, 2024
@lezcano lezcano requested a review from Jokeren October 11, 2024 15:21
@lezcano lezcano changed the title [Linear Layouts] Implement LL conversion for DotOperand [Linear Layouts] Implement LL conversion for DotOperand(version=2) Oct 11, 2024
@ThomasRaoux
Collaborator

Looks good overall, but I'll look in more detail after the lit tests are fixed.

Comment thread lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp Outdated
Comment thread unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp Outdated
Comment on lines 1040 to +1045
 SmallVector<unsigned> DotOperandEncodingAttr::getWarpsPerCTA() const {
-  auto parentLayout = getParent();
-  assert(parentLayout && "DotOperandEncodingAttr must have a parent");
-  if (auto distributedLayout =
-          mlir::dyn_cast<DistributedEncodingTrait>(parentLayout)) {
-    return distributedLayout.getWarpsPerCTA();
-  } else {
-    llvm::report_fatal_error(
-        "DotOperandEncodingAttr non-DistributedEncodingAttr parent not "
-        "supported yet");
-  }
+  auto distributedLayout = mlir::cast<DistributedEncodingTrait>(getParent());
+  auto warps = distributedLayout.getWarpsPerCTA();
+  auto rank = warps.size();
+  auto kDim = getOpIdx() == 0 ? rank - 1 : rank - 2;
+  warps[kDim] = 1;
+  return warps;
Contributor Author

@lezcano lezcano Oct 14, 2024

@Jokeren I added this fix because I needed it for the layout conversion to be correct and to pass the newly added tests.
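
To make the effect of the fix concrete, here is a minimal standalone sketch of the new computation, written against plain `std::vector` instead of the MLIR attribute types. The helper name `dotOperandWarpsPerCTA` and its parameters are hypothetical and exist only for illustration; the logic mirrors the added lines of the diff above (take the parent's warps per CTA, then set the entry along the K dimension to 1).

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper, not part of Triton: mirrors the new body of
// DotOperandEncodingAttr::getWarpsPerCTA shown in the diff.
//   parentWarps: warps per CTA of the parent layout (taken by value so the
//                caller's copy is left untouched)
//   opIdx:       0 for the first dot operand, 1 for the second
std::vector<unsigned> dotOperandWarpsPerCTA(std::vector<unsigned> parentWarps,
                                            unsigned opIdx) {
  assert(parentWarps.size() >= 2 && "expected a layout of rank >= 2");
  assert(opIdx < 2 && "a dot has exactly two operands");
  const std::size_t rank = parentWarps.size();
  // Same index selection as the diff: the last dimension for operand 0,
  // the second-to-last dimension for operand 1.
  const std::size_t kDim = (opIdx == 0) ? rank - 1 : rank - 2;
  // The entry along the selected dimension is overwritten with 1.
  parentWarps[kDim] = 1;
  return parentWarps;
}
```

Under these assumptions, a parent with warps `[2, 2]` yields `[2, 1]` for operand 0 and `[1, 2]` for operand 1, whereas the removed code returned the parent's `[2, 2]` unchanged in both cases.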

@lezcano
Copy link
Copy Markdown
Contributor Author

lezcano commented Oct 14, 2024

Addressed the reviews


LinearLayout ampereDotToLinearLayout(ArrayRef<int64_t> shape,
                                     DotOperandEncodingAttr dot) {
  // TODO,BE. Implement ampereMMA in terms of this one
Contributor

What does "BE" mean? Backend?

Contributor Author

Better Engineering

Comment thread lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp Outdated
@lezcano lezcano merged commit ec0bd4a into triton-lang:main Oct 15, 2024
lezcano added a commit that referenced this pull request Oct 16, 2024
This PR includes #4891 and
#4895. I will rebase once
those have landed.

It includes a number of hacks to work around bugs in
`DotOperandEncodingAttr`. All these are marked as `FIXME [Dot LL]` to be
easy to grep for. @Jokeren is working on a comprehensive revamp of
`DotOperandEncodingAttr` which will get rid of all these.
#4895 is the first step in
this direction.