
LU decomposition runtime is dominated by lu_pivots_to_permutation on GPU #5880

Closed
ahoenselaar opened this issue Feb 28, 2021 · 6 comments · Fixed by #6337
@ahoenselaar (Contributor) commented Feb 28, 2021

The actual decomposition is handled by cusolver. However, the pivots returned by getrf are transformed into a permutation via lu_pivots_to_permutation. For an (n, n) matrix this transformation is implemented as a loop with n iterations. Even though the amount of data involved is small (an int32 vector of length n), this loop performs poorly and its runtime far exceeds that of getrf.
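For reference, here is a minimal sketch of what this transformation does (illustrative only; the function name and signature are simplified and are not the actual JAX internals). Each pivot entry from getrf triggers one swap of two int32 entries, and the swaps must be applied in order, so the work is inherently sequential:

```python
import jax.numpy as jnp
from jax import lax


def pivots_to_permutation(pivots, permutation_size):
    # Sketch: getrf-style pivots mean "row i was exchanged with row pivots[i]".
    # Applying those exchanges in order to [0, 1, ..., permutation_size - 1]
    # yields the permutation; each step touches only two int32 entries.
    permutation = jnp.arange(permutation_size, dtype=jnp.int32)

    def swap(i, permutation):
        j = pivots[i]
        pi, pj = permutation[i], permutation[j]
        return permutation.at[i].set(pj).at[j].set(pi)

    return lax.fori_loop(0, pivots.shape[0], swap, permutation)
```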

Trace for a 75x75 double-precision matrix on a V100:

Batching does not improve the situation much. Trace for a batch of 64 double-precision matrices with shape (75, 75):

ahoenselaar added the enhancement label Feb 28, 2021
@mattjj (Collaborator) commented Mar 1, 2021

Thanks for the report!

@mattjj (Collaborator) commented Mar 2, 2021

Based on some comments from @hawkinsp, I think the issue is that on GPU each iteration of the loop gets turned into a separate kernel launch, which is a current limitation of XLA:GPU. Our options for improving this include:

  1. help XLA:GPU folks make loops faster (by gaining the ability to generate a single kernel launch for some loops);
  2. write a custom GPU kernel (like we do for PRNG sampling, because there is a bad compile time / execution time tradeoff on GPU);
  3. partially unroll this loop at the JAX Python level.

The first fix seems like the right one for the long term, but I don't know when it'll happen. The last one is more of a mitigation, but it seems like quite a good one, especially since we can write the fori_loop as a scan and use scan's built-in unroll parameter (a rough sketch follows below).
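To make option 3 concrete, a hedged sketch (illustrative names, not the actual JAX linalg code) of the same swap loop written as a scan, with unroll controlling how many iterations get inlined into each step of the compiled while loop:

```python
import jax.numpy as jnp
from jax import lax


def pivots_to_permutation_scan(pivots, permutation_size, unroll=8):
    # Sketch: the same swaps as the fori_loop version, expressed as a scan over
    # the pivot indices; `unroll` iterations are emitted per while-loop step in
    # the compiled HLO, so fewer loop trips (and kernel launches) are needed.
    permutation = jnp.arange(permutation_size, dtype=jnp.int32)
    indices = jnp.arange(pivots.shape[0], dtype=jnp.int32)

    def swap(permutation, ij):
        i, j = ij
        pi, pj = permutation[i], permutation[j]
        return permutation.at[i].set(pj).at[j].set(pi), None

    permutation, _ = lax.scan(swap, permutation, (indices, pivots), unroll=unroll)
    return permutation
```

Unrolling reduces the number of while-loop trips (and hence kernel launches) at the cost of larger HLO and longer compile time, but it does not remove the sequential dependence between swaps.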

mattjj self-assigned this and unassigned hawkinsp Mar 2, 2021
@ahoenselaar (Contributor, Author) commented

Rewriting as a scan with unroll=16 reduces the runtime by only 20%. The profile shows a large number of same-device memory transfer operations of size 300 (=75 * 4 bytes, i.e. the full permutation vector in this example), likely related to copy ops in the optimized HLO. These copies seem unnecessary and likely hurt performance.


1614762065673455.module_0000.before_optimizations.txt
1614762065673455.module_0000.after_optimizations.txt
1614762065673455.module_0000.after_optimizations-buffer-assignment.txt

@mattjj (Collaborator) commented Mar 4, 2021

Wow interesting. Thanks so much for investigating the unrolling option.

@hawkinsp any thoughts on what XLA:GPU could do?

@hawkinsp (Collaborator) commented Mar 4, 2021

I think we're going to have to give in and hand-write a kernel for this operation until such time as XLA/GPU improves. It's sort of a worst case for XLA on GPU: loopy code, very low compute intensity in the loop.

@ahoenselaar (Contributor, Author) commented

TF has implemented a kernel for this very purpose.
