Skip to content

Commit 38b85c9

Browse files
[Fix][dlight] add an explicit reduction loop check in Reduce (#17711)
* Added an explicit check to verify that the block has a reduction loop, since later stages of the rule assume one exists.
* Added a unit test to verify that the Reduction schedule is not applied to prim funcs without a reduction loop.
1 parent 59d4077 commit 38b85c9

File tree

2 files changed

+32
-1
lines changed

2 files changed

+32
-1
lines changed

python/tvm/dlight/gpu/reduction.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
# KIND, either express or implied. See the License for the
1515
# specific language governing permissions and limitations
1616
# under the License.
17-
"""A rule for reduction. """
17+
"""A rule for reduction."""
1818
# TODO: combine reduction rule and general reduction rule into one file.
1919
from typing import List, Mapping, Optional, Tuple, Union
2020

@@ -47,6 +47,10 @@ def _get_reduction_expr(block: tir.Block) -> Optional[tir.PrimExpr]:
4747
return buffer_store.value.b
4848

4949

50+
def _has_reduction_loop(block_info) -> bool:
    """Check whether a block has at least one reduction iterator.

    Parameters
    ----------
    block_info
        Block description exposing an ``iters`` iterable whose elements
        carry a ``kind`` attribute.

    Returns
    -------
    bool
        True if any element of ``block_info.iters`` has ``kind == "R"``,
        False otherwise (including when ``iters`` is empty).
    """
    # A generator lets `any` stop at the first reduction iterator instead of
    # materialising a full list of booleans first.
    return any(info.kind == "R" for info in block_info.iters)
52+
53+
5054
class Reduction(GPUScheduleRule):
5155
"""A rule for Reduction."""
5256

@@ -79,6 +83,7 @@ def apply( # pylint: disable=too-many-locals,too-many-branches,too-many-return-
7983
# Step 1. Check reduction block
8084
if (
8185
(not block_info.is_reduction())
86+
or (not _has_reduction_loop(block_info))
8287
or len(block_stmt.writes) != 1
8388
or _get_reduction_expr(block_stmt) is None
8489
):

tests/python/dlight/test_gpu_reduction.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1152,5 +1152,31 @@ def main(A: T.Buffer((T.int64(1), T.int64(2048)), "float16"), weight: T.Buffer((
11521152
assert_structural_equal(mod, Expected)
11531153

11541154

1155+
def test_no_reduction_loop_check():
1156+
# The normalized prim func will not contain a reduction loop since its extent is one.
1157+
# This checks that the Reduction schedule is correctly not applied in this case
1158+
# fmt: off
1159+
@I.ir_module
1160+
class Before:
1161+
@T.prim_func(private=True)
1162+
def matmul(lv43: T.Buffer((T.int64(1), T.int64(32), T.int64(1)), "float16"), lv44: T.Buffer((T.int64(1), T.int64(1), T.int64(1)), "float16"), matmul: T.Buffer((T.int64(1), T.int64(32), T.int64(1)), "float16")):
1163+
T.func_attr({"op_pattern": 4, "tir.noalias": T.bool(True)})
1164+
# with T.block("root"):
1165+
for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), T.int64(1), T.int64(1)):
1166+
with T.block("matmul"):
1167+
v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
1168+
T.reads(lv43[v_i0, v_i1, v_k], lv44[v_i0, v_k, v_i2])
1169+
T.writes(matmul[v_i0, v_i1, v_i2])
1170+
with T.init():
1171+
matmul[v_i0, v_i1, v_i2] = T.float16(0.0)
1172+
matmul[v_i0, v_i1, v_i2] = matmul[v_i0, v_i1, v_i2] + lv43[v_i0, v_i1, v_k] * lv44[v_i0, v_k, v_i2]
1173+
# fmt: on
1174+
1175+
target = Target("nvidia/geforce-rtx-3090-ti")
1176+
with target:
1177+
mod = dl.ApplyDefaultSchedule(dl.gpu.Reduction())(Before) # pylint: disable=not-callable
1178+
assert_structural_equal(mod, Before)
1179+
1180+
11551181
if __name__ == "__main__":
11561182
tvm.testing.main()

0 commit comments

Comments
 (0)