[Backend] Fix device assert inside reduction/scan region#4811
Merged
Conversation
Currently the reduction codegen unconditionally executes the combine region which can create problems because we conditionally load from shared memory, so this uses uninitialized registers. Generally combine regions should be pure, so this shouldn't be observable but with the overflow sanitizer the frontend injects assertions into the combine region. This changes the `accumulate` function to take a predicate and if the combine region isn't speculateble we only run it on threads where the predicate is true. In the common case, the codegen is unchanged.
pawelszczerbuk
approved these changes
Sep 26, 2024
Contributor
pawelszczerbuk
left a comment
There was a problem hiding this comment.
Thanks! Looks great!
peterbell10
pushed a commit
that referenced
this pull request
Nov 5, 2024
…on (#5033) Reductions have special handling for side effectful "combine ops" (e.g. "add" for a sum reduction). In the presence of side effects, a predicate is computed to determine whether a thread should participate in the reduction, to ensure that invalid/uninitialized data is not operated on. See #4811 for more details. ~Previously, the predicate logic was incorrect for 2D reductions. This PR fixes the logic and adds a python test.~ Edit: after additional discussion with @peterbell10, we removed the lanePred logic. Here's our thinking on why this is valid: * lanePred info is computed based entirely on the blocked layout info and properties of the reduction * the blocked layout won't tell you which threads do or don't have uninitialized data Instead, it sounds like the motivation for #4811 is based on uninitialized values that can be indicated by the `pred` variable passed into `warpReduce()`.
bertmaher
pushed a commit
that referenced
this pull request
Nov 5, 2024
…on (#5033) Reductions have special handling for side effectful "combine ops" (e.g. "add" for a sum reduction). In the presence of side effects, a predicate is computed to determine whether a thread should participate in the reduction, to ensure that invalid/uninitialized data is not operated on. See #4811 for more details. ~Previously, the predicate logic was incorrect for 2D reductions. This PR fixes the logic and adds a python test.~ Edit: after additional discussion with @peterbell10, we removed the lanePred logic. Here's our thinking on why this is valid: * lanePred info is computed based entirely on the blocked layout info and properties of the reduction * the blocked layout won't tell you which threads do or don't have uninitialized data Instead, it sounds like the motivation for #4811 is based on uninitialized values that can be indicated by the `pred` variable passed into `warpReduce()`.
Luosuu
pushed a commit
to Luosuu/triton
that referenced
this pull request
Nov 13, 2024
…#4811) Currently the reduction codegen unconditionally executes the combine region which can create problems because we conditionally load from shared memory, so this uses uninitialized registers. Generally combine regions should be pure, so this shouldn't be observable but with the overflow sanitizer the frontend injects assertions into the combine region. This changes the `accumulate` function to take a predicate and if the combine region isn't speculateble we only run it on threads where the predicate is true. In the common case, the codegen is unchanged.
Luosuu
pushed a commit
to Luosuu/triton
that referenced
this pull request
Nov 13, 2024
…on (triton-lang#5033) Reductions have special handling for side effectful "combine ops" (e.g. "add" for a sum reduction). In the presence of side effects, a predicate is computed to determine whether a thread should participate in the reduction, to ensure that invalid/uninitialized data is not operated on. See triton-lang#4811 for more details. ~Previously, the predicate logic was incorrect for 2D reductions. This PR fixes the logic and adds a python test.~ Edit: after additional discussion with @peterbell10, we removed the lanePred logic. Here's our thinking on why this is valid: * lanePred info is computed based entirely on the blocked layout info and properties of the reduction * the blocked layout won't tell you which threads do or don't have uninitialized data Instead, it sounds like the motivation for triton-lang#4811 is based on uninitialized values that can be indicated by the `pred` variable passed into `warpReduce()`.
guacamoleo
pushed a commit
to guacamoleo/triton
that referenced
this pull request
Nov 14, 2024
…on (triton-lang#5033) Reductions have special handling for side effectful "combine ops" (e.g. "add" for a sum reduction). In the presence of side effects, a predicate is computed to determine whether a thread should participate in the reduction, to ensure that invalid/uninitialized data is not operated on. See triton-lang#4811 for more details. ~Previously, the predicate logic was incorrect for 2D reductions. This PR fixes the logic and adds a python test.~ Edit: after additional discussion with @peterbell10, we removed the lanePred logic. Here's our thinking on why this is valid: * lanePred info is computed based entirely on the blocked layout info and properties of the reduction * the blocked layout won't tell you which threads do or don't have uninitialized data Instead, it sounds like the motivation for triton-lang#4811 is based on uninitialized values that can be indicated by the `pred` variable passed into `warpReduce()`.
bertmaher
pushed a commit
to bertmaher/triton
that referenced
this pull request
Dec 10, 2024
…#4811) Currently the reduction codegen unconditionally executes the combine region which can create problems because we conditionally load from shared memory, so this uses uninitialized registers. Generally combine regions should be pure, so this shouldn't be observable but with the overflow sanitizer the frontend injects assertions into the combine region. This changes the `accumulate` function to take a predicate and if the combine region isn't speculateble we only run it on threads where the predicate is true. In the common case, the codegen is unchanged.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Currently the reduction codegen unconditionally executes the combine region which can create problems because we conditionally load from shared memory, so this uses uninitialized registers.
Generally combine regions should be pure, so this shouldn't be observable but with the overflow sanitizer the frontend injects assertions into the combine region.
This changes the
accumulatefunction to take a predicate and if the combine region isn't speculateble we only run it on threads where the predicate is true. In the common case, the codegen is unchanged.