
[LAYOUTS] Use least squares solution in invertAndCompose #5309

Merged: lezcano merged 5 commits into main from fix_invert, Dec 5, 2024
Conversation

lezcano (Contributor) commented Dec 3, 2024:

In this PR, we remove the need for a few hacks in invertAndCompose,
namely getInjectiveMat, which did not work in cases where the input and
the output had a different number of registers (likewise for a different
number of blocks) and led to further hacks on top of it, like the
gymnastics with getFreeVariable.

We now just compute invertAndCompose as the matrix X that solves the
system AX = B. We add enough asserts to check that this system has at
least one solution (i.e., that A is surjective), and we make explicit the
heuristic we use to minimise data movement: we do not consider dimensions
that are the same in both layouts, and otherwise we incentivise
broadcasting by choosing the solution of minimal norm. For an explanation
of how to solve this system, see https://github.com/triton-lang/triton/pull/5309/files/a9069c73637a6b4735cdc39d1c7f338cfdd17a8f#r1869084111

In the future, it would be better for this function to return the compact
form of the system, where a dimension that is not present is one over
which the conversion is uniform; but for that we need to adapt our
lowering algorithms.
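As a rough illustration of the technique (entirely my own sketch: the name lstsqF2, the 64-bit row packing, and the API are assumptions, not the PR's actual lstsq), here is how one can solve AX = B over F_2 by Gaussian elimination and pick the minimal-norm solution by setting every free variable to zero:

#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

// Each row of the augmented matrix [A | B] is packed into a uint64_t:
// column j of A at bit j, column j of B at bit (numColsA + j).
// Over F_2, "minimal norm" means the fewest nonzero entries in X.
std::vector<uint64_t> lstsqF2(std::vector<uint64_t> mat, int numColsA,
                              int numColsB) {
  assert(numColsA + numColsB <= 64 && "sketch packs a row into 64 bits");
  // Gauss-Jordan on the A block: over F_2, row elimination is just xor.
  int rank = 0;
  for (int col = 0; col < numColsA; ++col) {
    int pivot = -1;
    for (int r = rank; r < (int)mat.size(); ++r)
      if (mat[r] & (1ull << col)) { pivot = r; break; }
    if (pivot < 0)
      continue; // Free column: its row of X stays zero (minimal norm).
    std::swap(mat[rank], mat[pivot]);
    for (int r = 0; r < (int)mat.size(); ++r)
      if (r != rank && (mat[r] & (1ull << col)))
        mat[r] ^= mat[rank];
    ++rank;
  }
  // Any nonzero leftover row has support only in the B block, i.e. the
  // system is inconsistent: Im(B) is not a subset of Im(A).
  for (int r = rank; r < (int)mat.size(); ++r)
    assert(mat[r] == 0 && "AX = B has no solution");
  // Read off X (one uint64_t of B-bits per row of X): pivot rows give
  // A_1^{-1} B; rows for free columns are zero.
  std::vector<uint64_t> X(numColsA, 0);
  int r = 0;
  for (int col = 0; col < numColsA; ++col)
    if (r < rank && (mat[r] & (1ull << col)))
      X[col] = mat[r++] >> numColsA;
  return X;
}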

lezcano requested a review from ptillet as a code owner, December 3, 2024 18:07
lezcano changed the title from "[Layouts] Use least squares solution in invertAndCompose" to "[LAYOUTS] Use least squares solution in invertAndCompose", Dec 3, 2024
lezcano requested a review from Jokeren as a code owner, December 3, 2024 23:01
lezcano (Contributor, Author) commented Dec 3, 2024:

This PR effectively reverts most of #4991 as now everything works out of the box given the algorithm for invertAndCompose.

Jokeren (Contributor) left a comment:

To summarize, the main difference between this version and the previous one is that identity dimensions are removed before solving the equation. It makes a lot of sense to me

// - If a dimension is the same for both layouts, we want to map it as the
//   identity
//   Equivalently, we don't add it to the conversion
// - Otherwise, we just call lstsq (i.e. map all the equivalent elements
Contributor:

I'm not sure why it's called least squares in this case?

Shouldn't A^T be involved, solving AX = B by forming A^T A X = A^T B?

lezcano (Contributor, Author), Dec 4, 2024:

Not really. A^T is a different kind of object (albeit of the same shape). Here you need, in general, what's called the pseudo-inverse. If you've never come across this object, the best description of what's going on in the real case is the one in terms of the SVD.

That being said, you don't need to know about these objects really. Here's an explanation in terms of systems of linear equations:

What's happening here is that A is a wide matrix, so we want to solve AX = B where A is m×k with k >= m, B is m×n, and X is k×n. Then, since A is surjective, we know that it has m independent columns. To avoid tracking indices, let's assume they are the first m (but they could be any subset of size m of the k columns). This means that we can split A as A = [A_1, A_2], where A_1 is m×m and invertible, and A_2 is m×(k-m). So we now write the system as:

[ A_1  A_2 ] [ X_1 ] = A_1 X_1 + A_2 X_2 = B
             [ X_2 ]

Then, since A_1 is invertible, we multiply everything by A_1^{-1} to get:

X_1 = A_1^{-1} B - A_1^{-1} A_2 X_2

where X_2 is any matrix of our choice (the minus here is an xor in our case, which is both the minus and the plus on F_2). We choose X_2 = 0 because of the broadcasting considerations, so we get our solution:

X = [ A_1^{-1} B ]
    [     0      ]

This is exactly what's going on here.
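For concreteness, here is a tiny made-up example over F_2 (mine, not from the PR), written in LaTeX:

% A is 2x3 (wide and surjective), B = I_2; split A = [A_1 | A_2] with A_1 = I_2.
A = \begin{pmatrix} 1 & 0 & 1 \\ 0 & 1 & 1 \end{pmatrix}, \quad
A_1 = \begin{pmatrix} 1 & 0 \\ 0 & 1 \end{pmatrix}, \quad
A_2 = \begin{pmatrix} 1 \\ 1 \end{pmatrix}, \quad
B = I_2.
% Choosing X_2 = 0 gives the minimal-norm solution:
X = \begin{pmatrix} A_1^{-1} B \\ 0 \end{pmatrix}
  = \begin{pmatrix} 1 & 0 \\ 0 & 1 \\ 0 & 0 \end{pmatrix},
\qquad A X = B.
% Any other solution xors the kernel vector (1,1,1)^T into columns of X,
% which can only add nonzero entries.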

Contributor:

Got it now. Thanks for the explanation. Can you add pseudo-inverse into the code comments?

lezcano (Contributor, Author):

I plan to do a proper write-up, but in the meantime I can add a link to this comment in the code.


auto ret = isEmpty ? LinearLayout::empty() : lstsq(AReduced, BReduced);

// TODO(Lezcano): We should return the reduced layout instead of re-adding the
Contributor:

It makes more sense to me to return the layout with the original input dimensions. Maybe we don't need to clean up.

lezcano (Contributor, Author):

But then we need to clean them up to see which conversion we need to do, so it's better to return the reduced form and know that "if a dimension is missing, it's because it's the identity".

Contributor:

> But then we need to clean them up to see which conversion we need to do

It seems fine to me. Removing dimensions instead seems unintuitive, but feel free to document it.

lezcano (Contributor, Author):

Will add a task to the project to keep better track of this.

SmallVector<int32_t> pivotCols;
for (int r = 0; r < numRows; r++) {
  auto row = combinedMat[r];
  if (row == 0) {
Contributor:

Since combinedMat is already in reduced row echelon form, if row[i] is all zeros, shouldn't row[i+1] also be all zeros?

lezcano (Contributor, Author):

Not really. Note that the reduced row echelon form can be something of the form:

1 0 0 a
0 0 0 0
0 0 1 b

This means that the output of the linear map in this basis always has a zero in the second coordinate, meaning that it's not surjective. This can happen because we have removed a few dimensions in invertAndCompose (we have removed the dims that are the same in the input and output).

The row echelon form is one such that (as per the definition on Wikipedia):

[image: Wikipedia's definition of row echelon form]

Contributor:

I get it now. It's the row-reduced echelon form but not the row echelon form. Thanks

    continue;
  }
  int c = __builtin_ctzll(combinedMat[r]);
  assert(c < numColsA && "Precondition broken. Im(A) != Im(B)");
Contributor:

Is it guaranteed that c < numColsA -> Im(A) == Im(B)?

lezcano (Contributor, Author):

If numColsA <= c, it means that we have found a row (say the i-th row) whose A part has been reduced to zero but whose B part has not. That means that, in the transformed basis, A will always output 0 in the i-th dimension, while B will not. In other words, it means that Im(B) is not a subset of Im(A).
To prove that Im(A) is a subset of Im(B), we would need to check that there is no row in which A has a pivot but the transformed B is zero. I will add that check when extracting the solution using the pivots below, and I'll improve the error message here.

lezcano (Contributor, Author):

Updated the error message. Rather than asserting the other case, I just updated the precondition, as we just need Im(B) \subset Im(A) for the algorithm to work (this way A^{-1}B is well-defined as a function)
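For intuition (my gloss, not part of the PR), solvability under this precondition follows column by column:

% If Im(B) \subseteq Im(A), we can solve A X = B one column at a time:
% each column b_j of B lies in Im(A), so some x_j satisfies A x_j = b_j,
% and stacking the x_j as columns gives an X with A X = B.
b_j \in \operatorname{Im}(A) \;\Rightarrow\; \exists\, x_j : A x_j = b_j
\;\Rightarrow\; X = \begin{pmatrix} x_1 & \cdots & x_n \end{pmatrix},\quad A X = B.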

//   Equivalently, we don't add it to the conversion
// - Otherwise, we just call lstsq (i.e. map all the equivalent elements
//   to the same input element) to take advantage of broadcasting in shared
//   memory and avoid saving repeated elements in shared memory
Contributor:

What does "broadcasting in shared memory" mean? Do we really duplicate values in shared memory?

lezcano (Contributor, Author):

I meant the broadcasting in the reads, as we do now. But more generally, this gives us the following behaviour:

Before, when going to shared memory, we always computed shmem^{-1} · distributed; we crafted shmem to be invertible, so the inverse works just fine, and then we use two facts:

  • When going to shmem, if multiple threads write to the same location there'll be a race condition and a bank conflict, but this is logically fine because we know that they all have the same data. This race condition is not great perf-wise, though, as the writes need to be serialised.
  • When reading from shmem, broadcasting happens just as expected.

In the future, we will always want to form the convert layout that goes from the destination (to) layout to the source (from) layout. This means that we will iterate over the elements of the output and ask: "which element from the input do I need to pull this from?". This is how we already do the reads from shmem; for the writes we'll need to change the code, but having the zeros above will give us an approach that takes most advantage of predicated stores and avoids these stalls, which is nice.
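As a rough sketch of that gather-style formulation (all names here are hypothetical, not Triton's API):

#include <cstddef>
#include <cstdint>
#include <functional>
#include <vector>

// Hypothetical sketch: the gather-style view of a layout conversion.
// `convert` maps each destination index to the source index that holds
// the same logical element ("which input element do I pull this from?").
std::vector<uint32_t>
applyConversion(const std::vector<uint32_t> &in,
                const std::function<size_t(size_t)> &convert) {
  std::vector<uint32_t> out(in.size());
  for (size_t dst = 0; dst < out.size(); ++dst)
    out[dst] = in[convert(dst)];
  return out;
}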

Contributor:

> multiple threads write to the same location there'll be a race condition and a bank conflict

I agree that more store instructions will be issued, but I don't think bank conflicts will be caused. Multiple accesses to the same location (shared memory address) by any number of threads within a warp are served simultaneously.

lezcano (Contributor, Author):

This happens when there is a read, but aren't writes to the same memory address serialised?

Contributor:

At least it's not called a bank conflict as far as I know. ncu won't report any bank conflicts if all writes go to the same address

lezcano (Contributor, Author) commented Dec 4, 2024:

> the main difference between this version and the previous one is that identity dimensions are removed before solving the equation.

Not quite. That is one part, but the main part is that we remove the getInjectiveMat hack and instead implement support for arbitrary matrices "natively", following the high-school algorithm of performing row transformations (AKA reduced row echelon form, RREF) and computing from it the solution to AX = B of minimal norm.

For an explanation of the maths behind it, see https://github.com/triton-lang/triton/pull/5309/files/a9069c73637a6b4735cdc39d1c7f338cfdd17a8f#r1869084111

lezcano (Contributor, Author) commented Dec 5, 2024:

Addressed the review

Jokeren (Contributor) left a comment:

LGTM. Thanks!

lezcano merged commit 67ea999 into main, Dec 5, 2024
lezcano deleted the fix_invert branch, December 5, 2024 20:33