Skip to content

Commit e542043

Browse files
authored
[Dlight] Skip GeMV when normalization fails (#16665)
Prior to this PR, the GeMV rule did not skip functions for which normalization fails, which led to an error. This PR fixes the issue by returning early (`None`) when normalization yields no block info, and adds a unit test accordingly.
1 parent e56c5e1 commit e542043

File tree

2 files changed

+35
-0
lines changed

2 files changed

+35
-0
lines changed

python/tvm/dlight/gpu/gemv.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,8 @@ def apply( # pylint: disable=too-many-locals,too-many-branches,too-many-return-
180180
sch = tir.Schedule(func)
181181
block_infos = normalize_prim_func(sch)
182182
block_infos = try_inline_contiguous_spatial(sch, block_infos)
183+
if block_infos is None:
184+
return None
183185
if len(block_infos) == 1:
184186
epilogue = None
185187
elif len(block_infos) == 2:

tests/python/dlight/test_gpu_gemv.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -996,5 +996,38 @@ def expected(x: T.Buffer((1, 4096), "float16"), w: T.Buffer((8, 16384, 4096), "f
996996
tvm.ir.assert_structural_equal(mod["main"], expected)
997997

998998

999+
def test_func_to_skip():
# Regression test for the "skip GeMV when normalization fails" fix described
# in the commit message: applying the GEMV rule to `before` must not raise
# and must leave the function structurally unchanged.
1000+
@T.prim_func
1001+
def before(var_A: T.handle, var_exclusive_scan_thrust: T.handle, seq_len: T.int64):
1002+
data_buf = T.match_buffer(var_A, (seq_len * T.int64(8),), "int32", align=8)
1003+
output_buf = T.match_buffer(
1004+
var_exclusive_scan_thrust, (seq_len * T.int64(8),), "int32", align=8
1005+
)
1006+
with T.block("exclusive_scan_thrust"):
1007+
T.reads()
1008+
T.writes()
1009+
T.call_packed(
1010+
"tvm.contrib.thrust.sum_scan",
1011+
T.tvm_stack_make_array(
1012+
data_buf.data, T.tvm_stack_make_shape(seq_len * T.int64(8)), 0, 1, 0, T.int64(0)
1013+
),
1014+
T.tvm_stack_make_array(
1015+
output_buf.data,
1016+
T.tvm_stack_make_shape(seq_len * T.int64(8)),
1017+
0,
1018+
1,
1019+
0,
1020+
T.int64(0),
1021+
),
1022+
T.bool(False),
1023+
)
1024+
1025+
# This function should be skipped: its only block is a packed call to
# "tvm.contrib.thrust.sum_scan" with empty T.reads()/T.writes() regions.
# The assertion below checks that running the GEMV rule through
# ApplyDefaultSchedule leaves `main` structurally equal to `before`.
1026+
mod = tvm.IRModule({"main": before})
1027+
with Target("metal"):
1028+
mod = dl.ApplyDefaultSchedule(dl.gpu.GEMV())(mod)
1029+
tvm.ir.assert_structural_equal(mod["main"], before)
1030+
1031+
9991032
if __name__ == "__main__":
10001033
tvm.testing.main()

0 commit comments

Comments
 (0)