PR #19275: [NVIDIA] Add fixes for supporting determinism expander for high-dimensional scatter operation and a flag to disable it #19296
+1,201
−142
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
PR #19275: [NVIDIA] Add fixes for supporting determinism expander for high-dimensional scatter operation and a flag to disable it
Imported from GitHub PR #19275
This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates.
The change of this PR is on top of #17886, and has fixed issues reported in the reverted PR #18326. The issue was that the changes in #18326 were not able to handle different kinds of complicated but realistic scatter dimension numbers. Specifically, this PR unifies the implementation of 1D and multi-dimensional scatter operation to make the code easier to maintain, adds multiple tests for various scatter dimension numbers, and thoroughly handles all cases of different kinds of dimension numbers.
Moreover, this PR also adds an option
xla_gpu_enable_scatter_determinism_expander
, the default value of which is set to be true. This option could make sure that although unlikely, if anything happens with changes in this PR, the user can easily disable thescatter_determinism_expander
pass without getting blocked.Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit
Bugs resolved: jax-ml/jax#17844
Copybara import of the project:
--
3b7b56a by Chenhao Jiang [email protected]:
PR #18326: [NVIDIA] Complete the optimization of deterministic scatter operations
Imported from GitHub PR #18326
This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates.
The change of this PR is on top of #17886
Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit
Bugs resolved: jax-ml/jax#17844
Copybara import of the project:
--
de647d4 by Chenhao Jiang [email protected]:
Support scatter with non-scalar indices and updates
Merging this change closes #18326
PiperOrigin-RevId: 691023328
--
126c952 by Chenhao Jiang [email protected]:
Add the scatter indices to operand space mapping
and change the offset column-wise permutation
based on scatter_dims_to_operand_dims, so that
they can add together correctly.
--
1ecb608 by Chenhao Jiang [email protected]:
Fix the scatter determinism expander for various dimension numbers
--
985079f by Chenhao Jiang [email protected]:
Add a flag for enabling the scatter_determinism_expander on GPU.
Merging this change closes #19275
FUTURE_COPYBARA_INTEGRATE_REVIEW=#18326 from serach24:chenhao/opt_det_scatter_full de647d4