Skip to content

Conversation

@Jokeren
Copy link
Contributor

@Jokeren Jokeren commented Aug 17, 2024

Before this patch, if a * b = c, c.divideRight(b) might return nullopt even if a' * b = c, where a' is the potential result of divideRight.
This PR addresses the issue by conservatively removing input and output dimensions, ensuring that the division returns a non-nullopt result when a valid solution exists.
However, it does not guarantee that a and a' will have identical input and output dimensions.

In addition, this PR also fixes a bug in TritonGPURemoveLayoutConversionsPass. The backward slice should be continued when encountering a free conversion. This includes cases where c.divideRight(b) results in a layout that only permutes register values within individual threads.

@Jokeren Jokeren changed the title [BACKEND][DRAFT] Fix Linear Layout's divideRight logic for eliminating input and output dimensions [BACKEND][DRAFT] Fix Linear Layout's divideRight when eliminating input and output dimensions Aug 17, 2024
@Jokeren Jokeren marked this pull request as ready for review August 23, 2024 14:13
@Jokeren Jokeren requested a review from ptillet as a code owner August 23, 2024 14:13
@Jokeren Jokeren requested review from ThomasRaoux and jlebar August 23, 2024 14:14
@Jokeren Jokeren changed the title [BACKEND][DRAFT] Fix Linear Layout's divideRight when eliminating input and output dimensions [BACKEND] Fix divideRight method in Linear Layout when eliminating input and output dimensions Aug 23, 2024
@Jokeren Jokeren changed the title [BACKEND] Fix divideRight method in Linear Layout when eliminating input and output dimensions [BACKEND] Fix the divideRight method in Linear Layout when eliminating input and output dimensions Aug 23, 2024
@Jokeren Jokeren requested a review from zahimoud August 23, 2024 14:17
@jlebar
Copy link
Contributor

jlebar commented Aug 23, 2024

Before this patch, if a * b = c, c.divideRight(a) might return nullopt even if a' * b = c, where a' is the potential result of divideRight.

I'm confused... c.divideRight(a) should equal b, right? How does a' fit into it?

Ah, I think you meant c.divideRight(b).

@Jokeren
Copy link
Contributor Author

Jokeren commented Aug 23, 2024

Oh yes, typo fixed! Thank you!

@jlebar
Copy link
Contributor

jlebar commented Aug 23, 2024

However, it does not guarantee that a and a' will have identical input and output dimensions.

This seems pretty confusing to me, because now we do not have the invariant that divideRight is the inverse of *.

I haven't thought about this, but instead of removing the empty dimensions from the dividend, would it be impossible to infer the additional dimensions and add them?

// a' * b = c and a * b' = c.
//
// Note that a' and a may not have exactly the same input/output dimensions.
// a may contain additional empty input dimensions than a'. For example:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

grammar

}

// divideLeft and divideRight are the inverses of operator*.
// divideLeft and divideRight are the inverses of operator *.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In C++ it's called operator*, not operator *

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, it's a GPT bug...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't meant to change this line

template <typename T, typename U>
void assertCommonDimsSameOrder(T &&outerDims, U &&innerDims) {
// Check that elements common to both outerDimsRange and innerDimsRange
// appear in the same relative order.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment describes the behavior of the function. Therefore move it outside the function?


if (outerCommonDims != innerCommonDims) {
llvm::report_fatal_error(
"Cannot multiply layouts. All in/out dimensions common to both "
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like this error is not correct anymore?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I think we can just remove "Cannot multiply layouts"

enqueue(definingOp->getOperand(0), encoding);
continue;
}
if (canFoldIntoConversion(definingOp, encoding))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand what this change is doing. Is it related to this PR?

Copy link
Contributor Author

@Jokeren Jokeren Aug 23, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, there was a bug in layout removal. It's contributed by @ThomasRaoux. I should add his commit message also

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it would help to put it in a separate PR. You can use one of these tools https://www.stacking.dev/?utm_source=stack-comment to stack your PRs, or you can use my very hacky tool https://github.com/jlebar/git-pr-chain

// out-dim0
// in-dim0 | size 1
// in-dim1 | size 1
// in-dim2 | size 1
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't really know what you mean with these diagrams. Perhaps we could use something that matches the toString() output of LinearLayout, since at least that has a well-defined meaning?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated

@Jokeren
Copy link
Contributor Author

Jokeren commented Aug 23, 2024

I haven't thought about this, but instead of removing the empty dimensions from the dividend, would it be impossible to infer the additional dimensions and add them?

It's impossible I think for the reasons stated in the code comments

@Jokeren
Copy link
Contributor Author

Jokeren commented Aug 23, 2024

I think there could be multiple divideRight answers due to either empty output or empty input dimensions.

e.g., output
["out0", "out1", "out2", "out3"] * ["out1", "out3"] = ["out0", "out1", "out2"] * ["out1", "out3"], if out3 is an empty dimension

e.g., input
["in0", "in1"] * ["in2"] = ["in0", "in1"] * ["in1", "in2"], if in1 is an empty dimension

@jlebar
Copy link
Contributor

jlebar commented Aug 24, 2024

Ah, I understand this PR better now.

What I think we are saying is, we canonicalize the div output. There may be many values a for which a * b = c. We have to choose one. We choose the one with as few size-0 input- and output-dimensions as possible.

Our correctness property now becomes canonicalize(b) == (a*b).divRight(b) for all a and b.

Is that correct?

I think it might help me verify the correctness of this PR if we did two things.

  1. Rewrite the PR description and comments in the code to talk about a canonicalization and this new correctness property.
  2. Split out Thomas's change so that I don't have to figure out which test changes relate to that change versus which changes relate to this change.
  3. If possible, I wonder if the new logic in LinearLayout::divRight can be simplified? If all we're doing is a canonicalization, could we use the same logic we had before, and then simply canonicalize the result?

// b = L("in2") -> ("out2")
//
// c = a * b = a' * b if "in1" is an empty dimension that maps everything
// to 0.
Copy link
Contributor

@jlebar jlebar Aug 24, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe say something like the following instead.

Size-zero dimensions are effectively ignored by operator*: a*b == a*b' if (and only if) b and b' are the same ignoring any size-zero input- and output-dimensions that are present in a. Therefore if we want divLeft to be the inverse of operator*, there are many possible values that we could return for (a*b).divLeft(a) which would satisfy a * (a*b).divLeft(a) == a*b.

divideLeft and divideRight resolve this ambiguity by always returning the "canonical" quotient, namely the one with the fewest possible size-zero input- and output-dimensions.

ConversionPatternRewriter &rewriter) const {
// TODO(jlebar): Implement me.
return failure();
return transferWithinBlockOrGroup(op, srcLayout, dstLayout, adaptor,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove the "implement me" comment above?

auto srcIdx = dstToSrc->apply({{kRegister, i}});
outVals.resize(subLayout.getInDimSize(kRegister));
for (int i = 0; i < subLayout.getInDimSize(kRegister); i++) {
auto srcIdx = subLayout.apply({{kRegister, i}});
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can this change be put into a separate PR? It would make the LinearLayout change here easier to understand in isolation.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It couldn't because the following condition doesn't return true any more:

assert(ArrayRef(to_vector(dstToSrc->getInDimNames())) ==

We must get a subLayout first.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps worth writing a comment? "You might be tempted to do X, but it doesn't work because Y."?

enqueue(definingOp->getOperand(0), encoding);
continue;
}
if (canFoldIntoConversion(definingOp, encoding))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it would help to put it in a separate PR. You can use one of these tools https://www.stacking.dev/?utm_source=stack-comment to stack your PRs, or you can use my very hacky tool https://github.com/jlebar/git-pr-chain

if (outerCommonDims != innerCommonDims) {
llvm::report_fatal_error("All in/out dimensions common to both layouts "
"must appear in the same relative order, but they "
"don't.\n");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We had \n here because we used to be outputting the in/out dims. Now we're not outputting them anymore, so we should lose the \n (or output the dims again, I thought that was kind of helpful?)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, sorry. It got removed accidentally

// Check that elements common to both outerDimsRange and innerDimsRange
// appear in the same relative order.
template <typename T, typename U>
void assertCommonDimsSameOrder(T &&outerDims, U &&innerDims) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure "outerDims" and "innerDims" are the correct names, based on how this function is used?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree it doesn't make sense anymore. Should I just call them dimsA and dimsB?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sgtm

divisor.getInDimSizeLog2(inDim));
}

// Record size 1 out-dims caused by the division.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this loop now does two things.

  1. Record size-1 out-dims caused by the division.
  2. Check if newOutDims[outDim] > outDimSize (what the loop used to do).

I think this is confusing. For one thing, there's a comment above the loop which says "Record size 1 out-dims caused by the division." but actually that is misleading, it only mentions one of teh two things done by the loop.

Perhaps we could simply have two loops?

//
// If we remove "out1" from o, we get:
//
// out-dims(l) = ["out0", "out2", "out3"]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am having trouble following this example.

It seems to be explaining an important edge case in an algorithm, but the overall algorithm -- the thing we're trying to do -- is not clear to me. The example also has some assumptions that I don't understand. For example, I don't understand how o / r returns anything at all if we remove "out1" from o. Isn't that just an infeasible division? (Unless you're assuming out1 is a size-zero dim? But I don't see how I'm supposed to know that?)

I wonder if the following algorithm works.

  1. Assume c = a*b.
  2. We are given b and c, and we want to compute c.divRight(b). i.e. we want to find a (or some a' which is equivalent to a ignoring size-zero dims).
  3. Let b' be b but with all size-zero in-dims and out-dims removed. Same for c'.
  4. Compute our candidate quotient a' = c' / b', same as before.
  5. Check if a' * b' == c', same as before.
  6. If it matches, then let a'' be a' but with the minimum number of size-zero in- and out-dims added back to it so that a'' * b' = c as desired.
  7. Return a''.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

About the new algorithm, I think it would fix the same problem as the existing code in this PR

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For example, I don't understand how o / r returns anything at all if we remove "out1" from o. Isn't that just an infeasible division? (Unless you're assuming out1 is a size-zero dim? But I don't see how I'm supposed to know that?)

Yeah, we assume that "out1" is a size-zero dim. Why did you say "I don't see how I'm supposed to know that"? Is it because of lacking a comment, or you think there's no way to check it? If the former I'll add a comment; I thought it's clear because there's a variable emptyOutDimIndices.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The key point is that we cannot remove arbitrary empty output dimensions from the quotient.

The following code simulates quotient * divisor = result and enumerates the output dimensions of the result from right to left to check which ones can be removed. When we perform the multiplication, the output dimensions of the quotient are always placed before the output dimensions of the divisor in the result. So if this order breaks in the result, we should stop enumerating output dimensions.

Therefore, I believe the new algorithm you described solves the same problem as the existing code.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are some confusion about size-1 or size-0 dimensions though. I call them sizeOneOutDimIndices since the out-dim maps everything to 0 and still has a size of 1.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why did you say "I don't see how I'm supposed to know that"? Is it because of lacking a comment, or you think there's no way to check it?

Lacking a comment. :)

I thought it's clear because there's a variable emptyOutDimIndices.

...yeah I'm doing my best to understand what's going on, but I did not consider looking at the code below to understand the algorithm above (and anyway I'm not sure it would have helped me).

There are some confusion about size-1 or size-0 dimensions though. I call them sizeOneOutDimIndices since the out-dim maps everything to 0 and still has a size of 1.

Ah yes, "size-1" dims is correct, I was calling it the wrong thing.

Therefore, I believe the new algorithm you described solves the same problem as the existing code.

I think I would have an easier time understanding the new algorithm I proposed, but if you don't think that's the best approach (who knows, I'm not the one implementing it), it would help me if we could explain the algorithm we're using as a comment in the code. I think you're starting to get there with this:

The following code simulates quotient * divisor = result and enumerates the output dimensions of the result from right to left to check which ones can be removed. When we perform the multiplication, the output dimensions of the quotient are always placed before the output dimensions of the divisor in the result. So if this order breaks in the result, we should stop enumerating output dimensions.

I promise I'm doing my best to understand things here and not just being obstinate

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've refactored the algorithm as follows:

  1. Consider c = a * b and construct a candidate quotient a' first without removing any dimensions.
  2. Check if a' * b == c. If yes, we start to remove empty dimensions.
  3. First, we remove empty trailing output dimensions from a'.
  4. Then, we remove empty trailing input dimensions from a'.
  5. Finally, we return the final quotient a''.

Also, I found that we actually allow a linear layout with no input dimensions but with empty output dimensions. Therefore, the hacky len(input_dims) == len(output_dims) condition can be removed, allowing us to use the same code to check if data transfer can happen within each thread.

Copy link
Contributor

@jlebar jlebar Aug 25, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Aren't empty dims in the middle of a' (i.e. not at either end) also removable? They're only not removable if b does not contain the dimension, right?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah I see, I'm wrong, you explain it in the comment. Thanks. :)

@Jokeren
Copy link
Contributor Author

Jokeren commented Aug 24, 2024

I think it would help to put it in a separate PR. You can use one of these tools https://www.stacking.dev/?utm_source=stack-comment to stack your PRs, or you can use my very hacky tool https://github.com/jlebar/git-pr-chain

Sure, I'll probably just cherry pick the commit out of this PR.

@Jokeren Jokeren force-pushed the keren/eliminate-dims branch from 081d226 to 238c9b4 Compare August 24, 2024 12:25
}

// Check that elements common to both outerDimsRange and innerDimsRange
// appear in the same relative order.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Update the comment now that you updated the variable names.

//
// If we remove "out1" from o, we get:
//
// out-dims(l) = ["out0", "out2", "out3"]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah I see, I'm wrong, you explain it in the comment. Thanks. :)

// input and output dimensions that are present in `b`. Therefore, if we want
// divideRight to be the inverse of operator*, there are many possible values
// that we could return for `(a*b).divideRight(b)` which would satisfy
// `((a*b).divideRight(b))*b == a*b`.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doesn't match what you do, which is only to remove empty dims at the end of the quotient. :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comments updated. Thanks for your kind suggestions!

@Jokeren Jokeren merged commit 381ff67 into main Aug 26, 2024
@Jokeren Jokeren deleted the keren/eliminate-dims branch August 26, 2024 19:16
bertmaher pushed a commit to bertmaher/triton that referenced this pull request Dec 10, 2024
…ing input and output dimensions (triton-lang#4530)

Before this patch, if `a * b = c`, `c.divideRight(b)` might return
`nullopt` even if `a' * b = c`, where `a'` is the potential result of
`divideRight`.
This PR addresses the issue by conservatively removing input and output
dimensions, ensuring that the division returns a non-nullopt result when
a valid solution exists.
However, it does not guarantee that `a` and `a'` will have identical
input and output dimensions.

In addition, this PR also fixes a bug in
`TritonGPURemoveLayoutConversionsPass`. The backward slice should be
continued when encountering a free conversion. This includes cases where
`c.divideRight(b)` results in a layout that only permutes register
values within individual threads.

---------

Co-authored-by: Thomas Raoux <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants