[Backend] Fix predicates for device assert inside reduction/scan region#5033
| pytest.skip() | ||
|
|
||
| @triton.jit(debug=True) | ||
| @triton.jit |
I think `debug=True` needs to be passed as a kwarg when invoking the Triton kernel. Previously I wasn't seeing any asserts in the TTGIR.
Thanks for spotting this, I've opened #5037 to fix it
| // Predicate to ensure we don't read from invalid memory. | ||
| // definitions: | ||
| // "Lane": the strip of values that are being reduced along. | ||
| // relevant variables: | ||
| // interleave: for two consecutive elements in a lane, the difference | ||
| // between their thread ids is the interleave. | ||
| // numLanesToReduce: how many lanes we're reducing across. | ||
| // totalNumLanes: how many lanes exist in total. If the reduction | ||
| // skips some threads, totalNumLanes might not equal numLanesToReduce. |
@peterbell10 is this accurate? tbh I didn't quite understand what scenario requires a predicate - I verified that this fixes my scenario, but I don't know if it regresses the scenario you were initially targeting.
| Value laneId = | ||
| urem(udiv(threadId, i32_val(interleave)), i32_val(totalNumLanes)); | ||
| Value lanePred = icmp_slt(laneId, i32_val(numLaneToReduce)); |
The definition of lane is the position of a thread within its warp, so this is a bit confusing. Would it work to do this?
| Value laneId = | |
| urem(udiv(threadId, i32_val(interleave)), i32_val(totalNumLanes)); | |
| Value lanePred = icmp_slt(laneId, i32_val(numLaneToReduce)); | |
| Value laneId = urem(threadId, warpSize); | |
| Value lanePred = icmp_slt(laneId, i32_val(totalNumLanes * interleave)); |
@peterbell10 thanks for the suggestion!
Instead I'm using `Value lanePred = icmp_slt(laneId, i32_val(numLaneToReduce * interleave));` since I presume that the reason for predicating is the difference between numLaneToReduce and totalNumLanes?
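To make the difference between the three predicates in this thread concrete, here is a small Python sketch that evaluates each formula for every thread in one warp. It is only a toy model of the arithmetic, not the C++ lowering itself, and the values of `WARP_SIZE`, `INTERLEAVE`, `TOTAL_NUM_LANES`, and `NUM_LANE_TO_REDUCE` are hypothetical numbers picked for illustration.

```python
# Toy comparison of the three lane predicates discussed in this thread.
# All parameter values are hypothetical, chosen only for illustration.
WARP_SIZE = 32
INTERLEAVE = 2
TOTAL_NUM_LANES = 8
NUM_LANE_TO_REDUCE = 4


def original_pred(thread_id):
    # laneId = (threadId / interleave) % totalNumLanes; laneId < numLaneToReduce
    lane_id = (thread_id // INTERLEAVE) % TOTAL_NUM_LANES
    return lane_id < NUM_LANE_TO_REDUCE


def suggested_pred(thread_id):
    # laneId = threadId % warpSize; laneId < totalNumLanes * interleave
    lane_id = thread_id % WARP_SIZE
    return lane_id < TOTAL_NUM_LANES * INTERLEAVE


def adopted_pred(thread_id):
    # laneId = threadId % warpSize; laneId < numLaneToReduce * interleave
    lane_id = thread_id % WARP_SIZE
    return lane_id < NUM_LANE_TO_REDUCE * INTERLEAVE


def active_threads(pred):
    # Thread ids within one warp for which the predicate is true.
    return [t for t in range(WARP_SIZE) if pred(t)]


for name, pred in [
    ("original", original_pred),
    ("suggested", suggested_pred),
    ("adopted", adopted_pred),
]:
    print(f"{name:>9}: {active_threads(pred)}")
```

With these assumed values the three formulas enable three different sets of threads, which is the distinction the question above is about.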
Note: the other test_side_effectful_reduction and side_effectful_scan tests are failing after #5035, but somehow the test_side_effectful_reduction_2d test added by this PR is not.
In upstream triton, triton-lang/triton#4589 introduces overflow checks. However, overflow checks likely add some overhead, and have some correctness bugs at the moment (e.g. triton-lang/triton#5033). Let's set `sanitize_overflow=False` but keep `debug=True` so that we can keep using device_assert but without the additional asserts added by `sanitize_overflow`. Pull Request resolved: #139502 Approved by: https://github.com/bertmaher
Reductions have special handling for side effectful "combine ops" (e.g. "add" for a sum reduction). In the presence of side effects, a predicate is computed to determine whether a thread should participate in the reduction, to ensure that invalid/uninitialized data is not operated on. See [] for more details. Previously, the predicate logic was incorrect for 2D reductions. This PR fixes the logic and adds a python test.
…on (#5033)
Reductions have special handling for side effectful "combine ops" (e.g. "add" for a sum reduction). In the presence of side effects, a predicate is computed to determine whether a thread should participate in the reduction, to ensure that invalid/uninitialized data is not operated on. See #4811 for more details.
~Previously, the predicate logic was incorrect for 2D reductions. This PR fixes the logic and adds a python test.~
Edit: after additional discussion with @peterbell10, we removed the lanePred logic. Here's our thinking on why this is valid:
* lanePred info is computed based entirely on the blocked layout info and properties of the reduction
* the blocked layout won't tell you which threads do or don't have uninitialized data
Instead, it sounds like the motivation for #4811 is based on uninitialized values that can be indicated by the `pred` variable passed into `warpReduce()`.
…can region (triton-lang#5033)" This reverts commit 732aee7.
…riton-lang#5075) This is a follow up to triton-lang#5033 but for scan ops, and also improving the testing as it was clearly insufficient before.
Reductions have special handling for side effectful "combine ops" (e.g. "add" for a sum reduction). In the presence of side effects, a predicate is computed to determine whether a thread should participate in the reduction, to ensure that invalid/uninitialized data is not operated on. See #4811 for more details.
~Previously, the predicate logic was incorrect for 2D reductions. This PR fixes the logic and adds a python test.~
Edit: after additional discussion with @peterbell10, we removed the lanePred logic. Here's our thinking on why this is valid:
* lanePred info is computed based entirely on the blocked layout info and properties of the reduction
* the blocked layout won't tell you which threads do or don't have uninitialized data
Instead, it sounds like the motivation for #4811 is based on uninitialized values that can be indicated by the `pred` variable passed into `warpReduce()`.
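As a rough illustration of why a caller-supplied `pred` matters for a side-effectful combine op, the following Python sketch models a shuffle-xor style reduction over a list of per-thread values. It is a simplified toy, not Triton's actual `warpReduce()` lowering: the function signature, the `checked_add` helper, and the sample data are all assumptions made for this example, and the Python `assert` merely stands in for a device assert.

```python
def checked_add(a, b):
    # Side-effectful combine op: the assert stands in for a device assert.
    assert a >= 0 and b >= 0, "combine op saw uninitialized data"
    return a + b


def warp_reduce(acc, num_lane_to_reduce, interleave, pred, combine):
    """Toy model of a shuffle-xor reduction over per-thread values."""
    acc = list(acc)
    n = num_lane_to_reduce // 2
    while n > 0:
        offset = n * interleave
        # Every thread reads its partner's value, as a shuffle-xor would.
        shfl = [acc[tid ^ offset] for tid in range(len(acc))]
        for tid in range(len(acc)):
            # Only threads whose predicate holds run the combine op.
            if pred[tid]:
                acc[tid] = combine(acc[tid], shfl[tid])
        n //= 2
    return acc


# Hypothetical data: threads 0-3 hold real partial sums, threads 4-7 hold garbage (-1).
values = [1, 2, 3, 4, -1, -1, -1, -1]
pred = [True, True, True, True, False, False, False, False]

reduced = warp_reduce(values, num_lane_to_reduce=4, interleave=1,
                      pred=pred, combine=checked_add)
print(reduced[0])

# Without the predicate, the garbage-holding threads trip the assert.
try:
    warp_reduce(values, 4, 1, [True] * 8, checked_add)
    unpredicated_ok = True
except AssertionError:
    unpredicated_ok = False
print(unpredicated_ok)
```

In this model the information about which threads hold garbage comes only from the `pred` list supplied by the caller; nothing in `num_lane_to_reduce` or `interleave` could reveal it, which mirrors the reasoning in the bullets above.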