Skip to content

Commit e5f85c0

Browse files
authored
[DLIGHT][ADRENO] Fix for opencl adreno matmul schedule (#17259)
Fixed the matmul schedule for the case of epilogue blocks
1 parent 6ae2961 commit e5f85c0

File tree

2 files changed

+85
-54
lines changed

2 files changed

+85
-54
lines changed

python/tvm/dlight/gpu/matmul.py

Lines changed: 37 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -941,7 +941,7 @@ def get_configs(self, target: Target) -> Config:
941941
inner_x=False,
942942
)
943943
elif target.kind.name == "opencl" and (
944-
("android" in str(target.host)) or ("windows" in str(target.host))
944+
("android" in str(target.host)) or ("adreno" in str(target.attrs))
945945
):
946946
return Matmul.Config(
947947
block_size_x=32,
@@ -991,7 +991,10 @@ def is_inner_reduction(block_stmt, iter_infos):
991991
end_it = block_stmt.reads[-1].region[-1].min
992992
return {it.var: it.kind for it in iter_infos}.get(end_it, "O") == "R"
993993

994-
if target.kind.name == "opencl" and not is_inner_reduction(block_stmt, iter_infos):
994+
if (
995+
target.kind.name == "opencl"
996+
and (("android" in str(target.host)) or ("adreno" in str(target.attrs)))
997+
) and not is_inner_reduction(block_stmt, iter_infos):
995998
ret = self.sch_outer_reduction(sch, config, main_block, blocks)
996999
if ret is not None:
9971000
return ret
@@ -1122,6 +1125,16 @@ def sch_outer_reduction(
11221125
reduction_block: tir.schedule.BlockRV,
11231126
blocks: List[tir.schedule.BlockRV],
11241127
) -> Optional[tir.Schedule]:
1128+
1129+
"""Get vectorization factor"""
1130+
1131+
def get_max_factor(n, factors):
1132+
factors = sorted(factors, reverse=True)
1133+
for factor in factors:
1134+
if n % factor == 0:
1135+
return factor
1136+
return 1
1137+
11251138
reduction_loops = sch.get_loops(reduction_block)
11261139
if not len(reduction_loops) == 4:
11271140
return None
@@ -1140,13 +1153,17 @@ def sch_outer_reduction(
11401153
config.vector_size,
11411154
config.unroll,
11421155
)
1143-
1144-
is_dequant_block = len(blocks) > 1
1145-
if is_dequant_block:
1146-
compute_block, dequant_block, matmul_block = blocks
1147-
sch.compute_inline(compute_block)
1148-
else:
1149-
(matmul_block,) = blocks
1156+
VecSize = min(get_max_factor(sch.get(n).extent // Threads_X, [1, 2, 4, 8]), VecSize)
1157+
dequant_block = None
1158+
matmul_block = reduction_block
1159+
epilogue_block = None
1160+
if blocks[-1] is not matmul_block:
1161+
epilogue_block = blocks[-1]
1162+
for blk in blocks[:-1]:
1163+
if "dequantize" in sch.get(blk).name_hint:
1164+
dequant_block = blk
1165+
elif blk is not matmul_block:
1166+
sch.compute_inline(blk)
11501167

11511168
m = sch.fuse(mb, ms)
11521169

@@ -1162,20 +1179,21 @@ def sch_outer_reduction(
11621179
sch.reorder(no, mo, ni, mi, k0, k1, k2, k3, mu, nv)
11631180

11641181
sch.compute_at(rmat_block, k0)
1165-
if is_dequant_block:
1182+
if dequant_block is not None:
11661183
sch.compute_at(dequant_block, k3)
11671184
sch.reverse_compute_at(wmat_block, mi)
11681185
sch.set_scope(rmat_block, 0, "shared")
11691186
sch.set_scope(matmul_block, 0, "local")
1170-
if is_dequant_block:
1187+
1188+
if dequant_block is not None:
11711189
sch.set_scope(dequant_block, 0, "local")
11721190

11731191
sch.bind(mo, "blockIdx.y")
11741192
sch.bind(no, "blockIdx.x")
11751193
sch.bind(mi, "threadIdx.y")
11761194
sch.bind(ni, "threadIdx.x")
11771195
sch.vectorize(sch.get_loops(matmul_block)[-1])
1178-
if is_dequant_block:
1196+
if dequant_block is not None:
11791197
sch.vectorize(sch.get_loops(dequant_block)[-1])
11801198

11811199
# Co-operative Memory Fetch
@@ -1187,7 +1205,7 @@ def sch_outer_reduction(
11871205
sch.vectorize(wv)
11881206

11891207
# Scale and Quant Cache
1190-
if is_dequant_block:
1208+
if dequant_block is not None:
11911209
qb = sch.cache_read(dequant_block, 0, "local")
11921210
sb = sch.cache_read(dequant_block, 1, "local")
11931211
sch.compute_at(sb, k1)
@@ -1197,5 +1215,11 @@ def sch_outer_reduction(
11971215
sch.vectorize(sch.get_loops(qb)[-1])
11981216
sch.vectorize(sch.get_loops(sb)[-1])
11991217

1218+
if epilogue_block is not None:
1219+
sch.reverse_compute_at(epilogue_block, mi, preserve_unit_loops=True)
1220+
sch.set_scope(wmat_block, 0, "local")
1221+
sch.compute_inline(wmat_block)
1222+
sch.vectorize(sch.get_loops(epilogue_block)[-1])
1223+
12001224
sch.decompose_reduction(matmul_block, k0)
12011225
return sch

tests/python/dlight/test_gpu_matmul.py

Lines changed: 48 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -685,47 +685,54 @@ def expected(var_inp0: T.handle, inp1: T.Buffer((T.int64(4096), T.int64(4096)),
685685
class TestFusedDequantMatmulAndroid(AndroidBeforeAfter):
686686
# fmt: off
687687
@T.prim_func
688-
def before(lv840: T.Buffer((T.int64(512), T.int64(12288)), "uint32"), lv841: T.Buffer((T.int64(128), T.int64(12288)), "float16"), p_rms_norm260: T.handle, p_output0: T.handle):
688+
def before(lv452: T.Buffer((T.int64(512), T.int64(12288)), "uint32"), lv453: T.Buffer((T.int64(128), T.int64(12288)), "float16"), p_rms_norm130: T.handle, transformer_h_0_attn_c_attn_bias3: T.Buffer((T.int64(12288),), "float16"), p_output0: T.handle):
689689
T.func_attr({"tir.noalias": T.bool(True)})
690690
seq_len = T.int64()
691-
rms_norm260 = T.match_buffer(p_rms_norm260, (T.int64(1), seq_len, T.int64(4096)), "float16")
692-
matmul_intermediate = T.match_buffer(p_output0, (T.int64(1), seq_len, T.int64(12288)), "float16")
691+
rms_norm130 = T.match_buffer(p_rms_norm130, (T.int64(1), seq_len, T.int64(4096)), "float16")
692+
T_add_intermediate_intermediate = T.match_buffer(p_output0, (T.int64(1), seq_len, T.int64(12288)), "float16")
693693
# with T.block("root"):
694694
compute = T.alloc_buffer((T.int64(4096), T.int64(12288)), "float16")
695695
dequantize_intermediate_intermediate = T.alloc_buffer((T.int64(4096), T.int64(12288)), "float16")
696+
matmul_intermediate = T.alloc_buffer((T.int64(1), seq_len, T.int64(12288)), "float16")
696697
for i0, i1 in T.grid(T.int64(4096), T.int64(12288)):
697698
with T.block("compute"):
698699
v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
699-
T.reads(lv840[v_i0 // T.int64(8), v_i1])
700+
T.reads(lv452[v_i0 // T.int64(8), v_i1])
700701
T.writes(compute[v_i0, v_i1])
701-
compute[v_i0, v_i1] = T.Cast("float16", T.bitwise_and(T.shift_right(lv840[v_i0 // T.int64(8), v_i1], T.Cast("uint32", v_i0 % T.int64(8) * T.int64(4))), T.uint32(15)))
702+
compute[v_i0, v_i1] = T.Cast("float16", T.bitwise_and(T.shift_right(lv452[v_i0 // T.int64(8), v_i1], T.Cast("uint32", v_i0 % T.int64(8) * T.int64(4))), T.uint32(15)))
702703
for i0, i1 in T.grid(T.int64(4096), T.int64(12288)):
703704
with T.block("dequantize"):
704705
v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
705-
T.reads(compute[v_i0, v_i1], lv841[v_i0 // T.int64(32), v_i1])
706+
T.reads(compute[v_i0, v_i1], lv453[v_i0 // T.int64(32), v_i1])
706707
T.writes(dequantize_intermediate_intermediate[v_i0, v_i1])
707-
dequantize_intermediate_intermediate[v_i0, v_i1] = (compute[v_i0, v_i1] - T.float16(7)) * lv841[v_i0 // T.int64(32), v_i1]
708+
dequantize_intermediate_intermediate[v_i0, v_i1] = (compute[v_i0, v_i1] - T.float16(7)) * lv453[v_i0 // T.int64(32), v_i1]
708709
for i0, i1, i2, k in T.grid(T.int64(1), seq_len, T.int64(12288), T.int64(4096)):
709710
with T.block("matmul"):
710711
v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
711-
T.reads(rms_norm260[v_i0, v_i1, v_k], dequantize_intermediate_intermediate[v_k, v_i2])
712+
T.reads(rms_norm130[v_i0, v_i1, v_k], dequantize_intermediate_intermediate[v_k, v_i2])
712713
T.writes(matmul_intermediate[v_i0, v_i1, v_i2])
713714
with T.init():
714715
matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)
715-
matmul_intermediate[v_i0, v_i1, v_i2] = matmul_intermediate[v_i0, v_i1, v_i2] + rms_norm260[v_i0, v_i1, v_k] * dequantize_intermediate_intermediate[v_k, v_i2]
716+
matmul_intermediate[v_i0, v_i1, v_i2] = matmul_intermediate[v_i0, v_i1, v_i2] + rms_norm130[v_i0, v_i1, v_k] * dequantize_intermediate_intermediate[v_k, v_i2]
717+
for ax0, ax1, ax2 in T.grid(T.int64(1), seq_len, T.int64(12288)):
718+
with T.block("T_add"):
719+
v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
720+
T.reads(matmul_intermediate[v_ax0, v_ax1, v_ax2], transformer_h_0_attn_c_attn_bias3[v_ax2])
721+
T.writes(T_add_intermediate_intermediate[v_ax0, v_ax1, v_ax2])
722+
T_add_intermediate_intermediate[v_ax0, v_ax1, v_ax2] = matmul_intermediate[v_ax0, v_ax1, v_ax2] + transformer_h_0_attn_c_attn_bias3[v_ax2]
716723

717724
@T.prim_func
718-
def expected(lv840: T.Buffer((T.int64(512), T.int64(12288)), "uint32"), lv841: T.Buffer((T.int64(128), T.int64(12288)), "float16"), p_rms_norm260: T.handle, p_output0: T.handle):
725+
def expected(lv452: T.Buffer((T.int64(512), T.int64(12288)), "uint32"), lv453: T.Buffer((T.int64(128), T.int64(12288)), "float16"), p_rms_norm130: T.handle, transformer_h_0_attn_c_attn_bias3: T.Buffer((T.int64(12288),), "float16"), p_output0: T.handle):
719726
T.func_attr({"global_symbol": "main", "tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
720727
seq_len = T.int64()
721-
rms_norm260 = T.match_buffer(p_rms_norm260, (T.int64(1), seq_len, T.int64(4096)), "float16")
722-
matmul_intermediate = T.match_buffer(p_output0, (T.int64(1), seq_len, T.int64(12288)), "float16")
728+
rms_norm130 = T.match_buffer(p_rms_norm130, (T.int64(1), seq_len, T.int64(4096)), "float16")
729+
T_add_intermediate_intermediate = T.match_buffer(p_output0, (T.int64(1), seq_len, T.int64(12288)), "float16")
723730
# with T.block("root"):
724731
dequantize_intermediate_intermediate_local = T.alloc_buffer((T.int64(4096), T.int64(12288)), "float16", scope="local")
725-
rms_norm260_pad_shared = T.alloc_buffer((T.int64(1), (seq_len + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(4096)), "float16", scope="shared")
732+
rms_norm130_pad_shared = T.alloc_buffer((T.int64(1), (seq_len + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(4096)), "float16", scope="shared")
726733
matmul_intermediate_pad_local = T.alloc_buffer((T.int64(1), (seq_len + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(12288)), "float16", scope="local")
727-
lv840_local = T.alloc_buffer((T.int64(512), T.int64(12288)), "uint32", scope="local")
728-
lv841_local = T.alloc_buffer((T.int64(128), T.int64(12288)), "float16", scope="local")
734+
lv452_local = T.alloc_buffer((T.int64(512), T.int64(12288)), "uint32", scope="local")
735+
lv453_local = T.alloc_buffer((T.int64(128), T.int64(12288)), "float16", scope="local")
729736
for i2_0 in T.thread_binding(T.int64(48), thread="blockIdx.x"):
730737
for i0_i1_fused_0 in T.thread_binding((seq_len + T.int64(31)) // T.int64(32), thread="blockIdx.y"):
731738
for i2_1 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
@@ -743,57 +750,57 @@ def expected(lv840: T.Buffer((T.int64(512), T.int64(12288)), "uint32"), lv841: T
743750
for ax0 in range(T.int64(4)):
744751
for ax1_0 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
745752
for ax1_1 in T.vectorized(T.int64(8)):
746-
with T.block("rms_norm260_pad"):
753+
with T.block("rms_norm130_pad"):
747754
v0 = T.axis.spatial(T.int64(1), T.int64(0))
748755
v1 = T.axis.spatial((seq_len + T.int64(31)) // T.int64(32) * T.int64(32), i0_i1_fused_0 * T.int64(32) + i0_i1_fused_1 * T.int64(4) + ax0)
749756
v2 = T.axis.spatial(T.int64(4096), k_0 * T.int64(256) + ax1_0 * T.int64(8) + ax1_1)
750-
T.reads(rms_norm260[v0, v1, v2])
751-
T.writes(rms_norm260_pad_shared[v0, v1, v2])
752-
rms_norm260_pad_shared[v0, v1, v2] = T.if_then_else(v1 < seq_len, rms_norm260[v0, v1, v2], T.float16(0))
757+
T.reads(rms_norm130[v0, v1, v2])
758+
T.writes(rms_norm130_pad_shared[v0, v1, v2])
759+
rms_norm130_pad_shared[v0, v1, v2] = T.if_then_else(v1 < seq_len, rms_norm130[v0, v1, v2], T.float16(0))
753760
for k_1 in range(T.int64(8)):
754761
for ax0 in T.vectorized(T.int64(8)):
755-
with T.block("lv841_local"):
762+
with T.block("lv453_local"):
756763
v0 = T.axis.spatial(T.int64(128), k_0 * T.int64(8) + k_1)
757764
v1 = T.axis.spatial(T.int64(12288), i2_0 * T.int64(256) + i2_1 * T.int64(8) + ax0)
758-
T.reads(lv841[v0, v1])
759-
T.writes(lv841_local[v0, v1])
760-
lv841_local[v0, v1] = lv841[v0, v1]
765+
T.reads(lv453[v0, v1])
766+
T.writes(lv453_local[v0, v1])
767+
lv453_local[v0, v1] = lv453[v0, v1]
761768
for k_2 in range(T.int64(4)):
762769
for ax0 in T.vectorized(T.int64(8)):
763-
with T.block("lv840_local"):
770+
with T.block("lv452_local"):
764771
v0 = T.axis.spatial(T.int64(512), k_0 * T.int64(32) + k_1 * T.int64(4) + k_2)
765772
v1 = T.axis.spatial(T.int64(12288), i2_0 * T.int64(256) + i2_1 * T.int64(8) + ax0)
766-
T.reads(lv840[v0, v1])
767-
T.writes(lv840_local[v0, v1])
768-
lv840_local[v0, v1] = lv840[v0, v1]
773+
T.reads(lv452[v0, v1])
774+
T.writes(lv452_local[v0, v1])
775+
lv452_local[v0, v1] = lv452[v0, v1]
769776
for k_3 in range(T.int64(8)):
770777
for ax0 in T.vectorized(T.int64(8)):
771778
with T.block("dequantize"):
772779
v_i0 = T.axis.spatial(T.int64(4096), k_0 * T.int64(256) + k_1 * T.int64(32) + k_2 * T.int64(8) + k_3)
773780
v_i1 = T.axis.spatial(T.int64(12288), i2_0 * T.int64(256) + i2_1 * T.int64(8) + ax0)
774-
T.reads(lv840_local[v_i0 // T.int64(8), v_i1], lv841_local[v_i0 // T.int64(32), v_i1])
781+
T.reads(lv452_local[v_i0 // T.int64(8), v_i1], lv453_local[v_i0 // T.int64(32), v_i1])
775782
T.writes(dequantize_intermediate_intermediate_local[v_i0, v_i1])
776-
dequantize_intermediate_intermediate_local[v_i0, v_i1] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv840_local[v_i0 // T.int64(8), v_i1], T.Cast("uint32", v_i0 % T.int64(8) * T.int64(4))), T.uint32(15))) - T.float16(7)) * lv841_local[v_i0 // T.int64(32), v_i1]
783+
dequantize_intermediate_intermediate_local[v_i0, v_i1] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv452_local[v_i0 // T.int64(8), v_i1], T.Cast("uint32", v_i0 % T.int64(8) * T.int64(4))), T.uint32(15))) - T.float16(7)) * lv453_local[v_i0 // T.int64(32), v_i1]
777784
for i0_i1_fused_2 in range(T.int64(4)):
778785
for i2_2 in T.vectorized(T.int64(8)):
779786
with T.block("matmul_update"):
780787
v_i0 = T.axis.spatial(T.int64(1), T.int64(0))
781788
v_i1 = T.axis.spatial((seq_len + T.int64(31)) // T.int64(32) * T.int64(32), i0_i1_fused_0 * T.int64(32) + i0_i1_fused_1 * T.int64(4) + i0_i1_fused_2)
782789
v_i2 = T.axis.spatial(T.int64(12288), i2_0 * T.int64(256) + i2_1 * T.int64(8) + i2_2)
783790
v_k = T.axis.reduce(T.int64(4096), k_0 * T.int64(256) + k_1 * T.int64(32) + k_2 * T.int64(8) + k_3)
784-
T.reads(matmul_intermediate_pad_local[v_i0, v_i1, v_i2], rms_norm260_pad_shared[v_i0, v_i1, v_k], dequantize_intermediate_intermediate_local[v_k, v_i2])
791+
T.reads(matmul_intermediate_pad_local[v_i0, v_i1, v_i2], rms_norm130_pad_shared[v_i0, v_i1, v_k], dequantize_intermediate_intermediate_local[v_k, v_i2])
785792
T.writes(matmul_intermediate_pad_local[v_i0, v_i1, v_i2])
786-
matmul_intermediate_pad_local[v_i0, v_i1, v_i2] = matmul_intermediate_pad_local[v_i0, v_i1, v_i2] + rms_norm260_pad_shared[v_i0, v_i1, v_k] * dequantize_intermediate_intermediate_local[v_k, v_i2]
787-
for ax0 in range(T.int64(4)):
788-
for ax1 in T.vectorized(T.int64(8)):
789-
with T.block("matmul_intermediate_pad"):
790-
v0 = T.axis.spatial(T.int64(1), T.int64(0))
791-
v1 = T.axis.spatial(seq_len, i0_i1_fused_0 * T.int64(32) + i0_i1_fused_1 * T.int64(4) + ax0)
792-
v2 = T.axis.spatial(T.int64(12288), i2_0 * T.int64(256) + i2_1 * T.int64(8) + ax1)
793-
T.where((i0_i1_fused_0 - (seq_len + T.int64(31)) // T.int64(32) < T.int64(0) or i0_i1_fused_0 == T.int64(0)) and i0_i1_fused_0 * T.int64(32) + i0_i1_fused_1 * T.int64(4) + ax0 < seq_len)
794-
T.reads(matmul_intermediate_pad_local[v0, v1, v2])
795-
T.writes(matmul_intermediate[v0, v1, v2])
796-
matmul_intermediate[v0, v1, v2] = matmul_intermediate_pad_local[v0, v1, v2]
793+
matmul_intermediate_pad_local[v_i0, v_i1, v_i2] = matmul_intermediate_pad_local[v_i0, v_i1, v_i2] + rms_norm130_pad_shared[v_i0, v_i1, v_k] * dequantize_intermediate_intermediate_local[v_k, v_i2]
794+
for ax0, ax1 in T.grid(T.int64(1), T.int64(4)):
795+
for ax2 in T.vectorized(T.int64(8)):
796+
with T.block("T_add"):
797+
v_ax0 = T.axis.spatial(T.int64(1), ax0)
798+
v_ax1 = T.axis.spatial(seq_len, i0_i1_fused_0 * T.int64(32) + i0_i1_fused_1 * T.int64(4) + ax1)
799+
v_ax2 = T.axis.spatial(T.int64(12288), i2_0 * T.int64(256) + i2_1 * T.int64(8) + ax2)
800+
T.where(i0_i1_fused_0 * T.int64(32) + i0_i1_fused_1 * T.int64(4) + ax1 < seq_len)
801+
T.reads(matmul_intermediate_pad_local[v_ax0, v_ax1, v_ax2], transformer_h_0_attn_c_attn_bias3[v_ax2])
802+
T.writes(T_add_intermediate_intermediate[v_ax0, v_ax1, v_ax2])
803+
T_add_intermediate_intermediate[v_ax0, v_ax1, v_ax2] = matmul_intermediate_pad_local[v_ax0, v_ax1, v_ax2] + transformer_h_0_attn_c_attn_bias3[v_ax2]
797804
# fmt: on
798805

799806

0 commit comments

Comments
 (0)