[Backend] Fix device assert inside reduction/scan region #4811

Merged: peterbell10 merged 1 commit into main from pb/reduction-assert on Sep 26, 2024
peterbell10 (Contributor) commented Sep 26, 2024

Currently the reduction codegen unconditionally executes the combine region. This can create problems: because we only conditionally load from shared memory, threads that skipped the load run the combine region on uninitialized registers.

Generally combine regions should be pure, so this shouldn't be observable, but with the overflow sanitizer the frontend injects assertions into the combine region.

This changes the `accumulate` function to take a predicate, and if the combine region isn't speculatable we only run it on threads where the predicate is true. In the common case, the codegen is unchanged.
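To make the behaviour concrete, here is a rough Python model of the idea, not the actual C++ codegen. The names `checked_add` and `GARBAGE`, and the Python signature of `accumulate`, are hypothetical stand-ins chosen for illustration. Returning the old accumulator on inactive threads is also an arbitrary choice for this sketch, on the assumption that the result is unused there.

```python
INT32_MIN, INT32_MAX = -2**31, 2**31 - 1
GARBAGE = INT32_MAX  # hypothetical stand-in for an uninitialized register


def checked_add(a, b):
    """Sum combine op with an injected overflow check (a side effect)."""
    result = a + b
    if not INT32_MIN <= result <= INT32_MAX:
        # Plays the role of the device assert injected by the sanitizer.
        raise OverflowError("int32 overflow in combine region")
    return result


def accumulate(acc, cur, combine, pred=True, speculatable=False):
    """Model of one thread applying the combine region.

    A speculatable (pure) region runs unconditionally, as before.
    A region with side effects runs only where `pred` is true.
    """
    if speculatable or pred:
        return combine(acc, cur)
    # Inactive thread: skip the region. Keeping `acc` is an assumption
    # made for this sketch only.
    return acc
```

With `pred=False`, `accumulate(1, GARBAGE, checked_add, pred=False)` skips the check, whereas running the same call with `pred=True` raises, which mirrors the spurious assert described above.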

apgoucher (Collaborator) left a comment

Great work!

pawelszczerbuk (Contributor) left a comment

Thanks! Looks great!

peterbell10 merged commit 1b0f9ea into main Sep 26, 2024
peterbell10 deleted the pb/reduction-assert branch September 26, 2024 20:56
peterbell10 pushed a commit that referenced this pull request Nov 5, 2024
…on (#5033)

Reductions have special handling for side effectful "combine ops" (e.g.
"add" for a sum reduction). In the presence of side effects, a predicate
is computed to determine whether a thread should participate in the
reduction, to ensure that invalid/uninitialized data is not operated on.
See #4811 for more details.

~Previously, the predicate logic was incorrect for 2D reductions. This
PR fixes the logic and adds a python test.~

Edit: after additional discussion with @peterbell10, we removed the
lanePred logic. Here's our thinking on why this is valid:
* lanePred info is computed based entirely on the blocked layout info
and properties of the reduction
* the blocked layout won't tell you which threads do or don't have
uninitialized data

Instead, it sounds like the motivation for #4811 is based on
uninitialized values that can be indicated by the `pred` variable passed
into `warpReduce()`.
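The reasoning above can be illustrated with a small Python simulation; it is a sketch under stated assumptions, not Triton's implementation. `load_shared`, `warp_reduce`, the 8-lane warp and the butterfly pattern are all hypothetical, and the sketch assumes the predicate passed to the reduction is the same condition that guarded the shared-memory load.

```python
GARBAGE = 2**31 - 1  # hypothetical stand-in for an uninitialized register


def checked_add(a, b):
    """Combine op with a side effect: raises on int32 overflow."""
    result = a + b
    if not -2**31 <= result < 2**31:
        raise OverflowError("int32 overflow in combine region")
    return result


def load_shared(shared, num_lanes):
    """Predicated load: lane i loads only if i < len(shared)."""
    pred = [lane < len(shared) for lane in range(num_lanes)]
    regs = [shared[lane] if pred[lane] else GARBAGE
            for lane in range(num_lanes)]
    return regs, pred


def warp_reduce(regs, pred, width, combine):
    """Butterfly reduction over groups of `width` lanes.

    The combine op runs only on lanes whose predicate is true; the
    other lanes keep whatever their registers already held.
    """
    regs = list(regs)
    offset = width // 2
    while offset >= 1:
        # Every lane reads its partner's register (a shuffle).
        shuffled = [regs[lane ^ offset] for lane in range(len(regs))]
        regs = [combine(regs[lane], shuffled[lane]) if pred[lane]
                else regs[lane]
                for lane in range(len(regs))]
        offset //= 2
    return regs
```

For example, with `regs, pred = load_shared([1, 2, 3, 4], 8)`, calling `warp_reduce(regs, pred, 4, checked_add)` leaves 10 in lanes 0 to 3 and never touches the garbage in lanes 4 to 7. Passing an all-true predicate instead raises, because lane 4 combines two uninitialized values. Nothing in the lane layout alone would distinguish those cases; only the load predicate does.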
bertmaher pushed a commit that referenced this pull request Nov 5, 2024
…on (#5033)

Luosuu pushed a commit to Luosuu/triton that referenced this pull request Nov 13, 2024
…#4811)

Luosuu pushed a commit to Luosuu/triton that referenced this pull request Nov 13, 2024
…on (triton-lang#5033)

guacamoleo pushed a commit to guacamoleo/triton that referenced this pull request Nov 14, 2024
…on (triton-lang#5033)

bertmaher pushed a commit to bertmaher/triton that referenced this pull request Dec 10, 2024
…#4811)

3 participants