-
Notifications
You must be signed in to change notification settings - Fork 2.5k
[BACKEND] Fix the divideRight method in Linear Layout when eliminating input and output dimensions
#4530
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
divideRight logic for eliminating input and output dimensionsdivideRight when eliminating input and output dimensions
divideRight when eliminating input and output dimensionsdivideRight method in Linear Layout when eliminating input and output dimensions
divideRight method in Linear Layout when eliminating input and output dimensionsdivideRight method in Linear Layout when eliminating input and output dimensions
Ah, I think you meant |
|
Oh yes, typo fixed! Thank you! |
This seems pretty confusing to me, because now we do not have the invariant that divideRight is the inverse of I haven't thought about this, but instead of removing the empty dimensions from the dividend, would it be impossible to infer the additional dimensions and add them? |
include/triton/Tools/LinearLayout.h
Outdated
| // a' * b = c and a * b' = c. | ||
| // | ||
| // Note that a' and a may not have exactly the same input/output dimensions. | ||
| // a may contain additional empty input dimensions than a'. For example: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
grammar
include/triton/Tools/LinearLayout.h
Outdated
| } | ||
|
|
||
| // divideLeft and divideRight are the inverses of operator*. | ||
| // divideLeft and divideRight are the inverses of operator *. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In C++ it's called operator*, not operator *
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah, it's a GPT bug...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I didn't meant to change this line
lib/Tools/LinearLayout.cpp
Outdated
| template <typename T, typename U> | ||
| void assertCommonDimsSameOrder(T &&outerDims, U &&innerDims) { | ||
| // Check that elements common to both outerDimsRange and innerDimsRange | ||
| // appear in the same relative order. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This comment describes the behavior of the function. Therefore move it outside the function?
lib/Tools/LinearLayout.cpp
Outdated
|
|
||
| if (outerCommonDims != innerCommonDims) { | ||
| llvm::report_fatal_error( | ||
| "Cannot multiply layouts. All in/out dimensions common to both " |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Seems like this error is not correct anymore?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, I think we can just remove "Cannot multiply layouts"
| enqueue(definingOp->getOperand(0), encoding); | ||
| continue; | ||
| } | ||
| if (canFoldIntoConversion(definingOp, encoding)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't understand what this change is doing. Is it related to this PR?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, there was a bug in layout removal. It's contributed by @ThomasRaoux. I should add his commit message also
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it would help to put it in a separate PR. You can use one of these tools https://www.stacking.dev/?utm_source=stack-comment to stack your PRs, or you can use my very hacky tool https://github.com/jlebar/git-pr-chain
lib/Tools/LinearLayout.cpp
Outdated
| // out-dim0 | ||
| // in-dim0 | size 1 | ||
| // in-dim1 | size 1 | ||
| // in-dim2 | size 1 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't really know what you mean with these diagrams. Perhaps we could use something that matches the toString() output of LinearLayout, since at least that has a well-defined meaning?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Updated
It's impossible I think for the reasons stated in the code comments |
|
I think there could be multiple e.g., output e.g., input |
|
Ah, I understand this PR better now. What I think we are saying is, we canonicalize the div output. There may be many values Our correctness property now becomes Is that correct? I think it might help me verify the correctness of this PR if we did two things.
|
include/triton/Tools/LinearLayout.h
Outdated
| // b = L("in2") -> ("out2") | ||
| // | ||
| // c = a * b = a' * b if "in1" is an empty dimension that maps everything | ||
| // to 0. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe say something like the following instead.
Size-zero dimensions are effectively ignored by operator*:
a*b == a*b'if (and only if) b and b' are the same ignoring any size-zero input- and output-dimensions that are present ina. Therefore if we want divLeft to be the inverse of operator*, there are many possible values that we could return for(a*b).divLeft(a)which would satisfya * (a*b).divLeft(a) == a*b.divideLeft and divideRight resolve this ambiguity by always returning the "canonical" quotient, namely the one with the fewest possible size-zero input- and output-dimensions.
| ConversionPatternRewriter &rewriter) const { | ||
| // TODO(jlebar): Implement me. | ||
| return failure(); | ||
| return transferWithinBlockOrGroup(op, srcLayout, dstLayout, adaptor, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove the "implement me" comment above?
| auto srcIdx = dstToSrc->apply({{kRegister, i}}); | ||
| outVals.resize(subLayout.getInDimSize(kRegister)); | ||
| for (int i = 0; i < subLayout.getInDimSize(kRegister); i++) { | ||
| auto srcIdx = subLayout.apply({{kRegister, i}}); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can this change be put into a separate PR? It would make the LinearLayout change here easier to understand in isolation.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It couldn't because the following condition doesn't return true any more:
assert(ArrayRef(to_vector(dstToSrc->getInDimNames())) ==
We must get a subLayout first.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Perhaps worth writing a comment? "You might be tempted to do X, but it doesn't work because Y."?
| enqueue(definingOp->getOperand(0), encoding); | ||
| continue; | ||
| } | ||
| if (canFoldIntoConversion(definingOp, encoding)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it would help to put it in a separate PR. You can use one of these tools https://www.stacking.dev/?utm_source=stack-comment to stack your PRs, or you can use my very hacky tool https://github.com/jlebar/git-pr-chain
lib/Tools/LinearLayout.cpp
Outdated
| if (outerCommonDims != innerCommonDims) { | ||
| llvm::report_fatal_error("All in/out dimensions common to both layouts " | ||
| "must appear in the same relative order, but they " | ||
| "don't.\n"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We had \n here because we used to be outputting the in/out dims. Now we're not outputting them anymore, so we should lose the \n (or output the dims again, I thought that was kind of helpful?)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh, sorry. It got removed accidentally
lib/Tools/LinearLayout.cpp
Outdated
| // Check that elements common to both outerDimsRange and innerDimsRange | ||
| // appear in the same relative order. | ||
| template <typename T, typename U> | ||
| void assertCommonDimsSameOrder(T &&outerDims, U &&innerDims) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am not sure "outerDims" and "innerDims" are the correct names, based on how this function is used?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree it doesn't make sense anymore. Should I just call them dimsA and dimsB?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sgtm
lib/Tools/LinearLayout.cpp
Outdated
| divisor.getInDimSizeLog2(inDim)); | ||
| } | ||
|
|
||
| // Record size 1 out-dims caused by the division. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this loop now does two things.
- Record size-1 out-dims caused by the division.
- Check if newOutDims[outDim] > outDimSize (what the loop used to do).
I think this is confusing. For one thing, there's a comment above the loop which says "Record size 1 out-dims caused by the division." but actually that is misleading, it only mentions one of teh two things done by the loop.
Perhaps we could simply have two loops?
| // | ||
| // If we remove "out1" from o, we get: | ||
| // | ||
| // out-dims(l) = ["out0", "out2", "out3"] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am having trouble following this example.
It seems to be explaining an important edge case in an algorithm, but the overall algorithm -- the thing we're trying to do -- is not clear to me. The example also has some assumptions that I don't understand. For example, I don't understand how o / r returns anything at all if we remove "out1" from o. Isn't that just an infeasible division? (Unless you're assuming out1 is a size-zero dim? But I don't see how I'm supposed to know that?)
I wonder if the following algorithm works.
- Assume c = a*b.
- We are given b and c, and we want to compute c.divRight(b). i.e. we want to find
a(or some a' which is equivalent toaignoring size-zero dims). - Let b' be b but with all size-zero in-dims and out-dims removed. Same for c'.
- Compute our candidate quotient a' = c' / b', same as before.
- Check if a' * b' == c', same as before.
- If it matches, then let a'' be a' but with the minimum number of size-zero in- and out-dims added back to it so that a'' * b' = c as desired.
- Return a''.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
About the new algorithm, I think it would fix the same problem as the existing code in this PR
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For example, I don't understand how o / r returns anything at all if we remove "out1" from o. Isn't that just an infeasible division? (Unless you're assuming out1 is a size-zero dim? But I don't see how I'm supposed to know that?)
Yeah, we assume that "out1" is a size-zero dim. Why did you say "I don't see how I'm supposed to know that"? Is it because of lacking a comment, or you think there's no way to check it? If the former I'll add a comment; I thought it's clear because there's a variable emptyOutDimIndices.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The key point is that we cannot remove arbitrary empty output dimensions from the quotient.
The following code simulates quotient * divisor = result and enumerates the output dimensions of the result from right to left to check which ones can be removed. When we perform the multiplication, the output dimensions of the quotient are always placed before the output dimensions of the divisor in the result. So if this order breaks in the result, we should stop enumerating output dimensions.
Therefore, I believe the new algorithm you described solves the same problem as the existing code.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There are some confusion about size-1 or size-0 dimensions though. I call them sizeOneOutDimIndices since the out-dim maps everything to 0 and still has a size of 1.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why did you say "I don't see how I'm supposed to know that"? Is it because of lacking a comment, or you think there's no way to check it?
Lacking a comment. :)
I thought it's clear because there's a variable emptyOutDimIndices.
...yeah I'm doing my best to understand what's going on, but I did not consider looking at the code below to understand the algorithm above (and anyway I'm not sure it would have helped me).
There are some confusion about size-1 or size-0 dimensions though. I call them sizeOneOutDimIndices since the out-dim maps everything to 0 and still has a size of 1.
Ah yes, "size-1" dims is correct, I was calling it the wrong thing.
Therefore, I believe the new algorithm you described solves the same problem as the existing code.
I think I would have an easier time understanding the new algorithm I proposed, but if you don't think that's the best approach (who knows, I'm not the one implementing it), it would help me if we could explain the algorithm we're using as a comment in the code. I think you're starting to get there with this:
The following code simulates quotient * divisor = result and enumerates the output dimensions of the result from right to left to check which ones can be removed. When we perform the multiplication, the output dimensions of the quotient are always placed before the output dimensions of the divisor in the result. So if this order breaks in the result, we should stop enumerating output dimensions.
I promise I'm doing my best to understand things here and not just being obstinate
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've refactored the algorithm as follows:
- Consider
c = a * band construct a candidate quotienta'first without removing any dimensions. - Check if
a' * b == c. If yes, we start to remove empty dimensions. - First, we remove empty trailing output dimensions from
a'. - Then, we remove empty trailing input dimensions from
a'. - Finally, we return the final quotient
a''.
Also, I found that we actually allow a linear layout with no input dimensions but with empty output dimensions. Therefore, the hacky len(input_dims) == len(output_dims) condition can be removed, allowing us to use the same code to check if data transfer can happen within each thread.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Aren't empty dims in the middle of a' (i.e. not at either end) also removable? They're only not removable if b does not contain the dimension, right?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah I see, I'm wrong, you explain it in the comment. Thanks. :)
Sure, I'll probably just cherry pick the commit out of this PR. |
5ef4326 to
081d226
Compare
081d226 to
238c9b4
Compare
| } | ||
|
|
||
| // Check that elements common to both outerDimsRange and innerDimsRange | ||
| // appear in the same relative order. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Update the comment now that you updated the variable names.
| // | ||
| // If we remove "out1" from o, we get: | ||
| // | ||
| // out-dims(l) = ["out0", "out2", "out3"] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah I see, I'm wrong, you explain it in the comment. Thanks. :)
include/triton/Tools/LinearLayout.h
Outdated
| // input and output dimensions that are present in `b`. Therefore, if we want | ||
| // divideRight to be the inverse of operator*, there are many possible values | ||
| // that we could return for `(a*b).divideRight(b)` which would satisfy | ||
| // `((a*b).divideRight(b))*b == a*b`. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This doesn't match what you do, which is only to remove empty dims at the end of the quotient. :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Comments updated. Thanks for your kind suggestions!
…ing input and output dimensions (triton-lang#4530) Before this patch, if `a * b = c`, `c.divideRight(b)` might return `nullopt` even if `a' * b = c`, where `a'` is the potential result of `divideRight`. This PR addresses the issue by conservatively removing input and output dimensions, ensuring that the division returns a non-nullopt result when a valid solution exists. However, it does not guarantee that `a` and `a'` will have identical input and output dimensions. In addition, this PR also fixes a bug in `TritonGPURemoveLayoutConversionsPass`. The backward slice should be continued when encountering a free conversion. This includes cases where `c.divideRight(b)` results in a layout that only permutes register values within individual threads. --------- Co-authored-by: Thomas Raoux <[email protected]>
Before this patch, if
a * b = c,c.divideRight(b)might returnnullopteven ifa' * b = c, wherea'is the potential result ofdivideRight.This PR addresses the issue by conservatively removing input and output dimensions, ensuring that the division returns a non-nullopt result when a valid solution exists.
However, it does not guarantee that
aanda'will have identical input and output dimensions.In addition, this PR also fixes a bug in
TritonGPURemoveLayoutConversionsPass. The backward slice should be continued when encountering a free conversion. This includes cases wherec.divideRight(b)results in a layout that only permutes register values within individual threads.