Skip to content

Commit 50d1c97

Browse files
authored
[DLIGHT][GPU] Add OpenCL dequant matmul schedule (#17187)
* [DLIGHT][GPU] Add OpenCL dequant matmul schedule 1. Enhanced the GPU matmul schedule for the OpenCL Android and Windows backends. 2. It provides a ~2X performance gain for the Llama-2-7B prefill process. Model device Earlier prefill perf Optimized prefill perf Llama-2-7B-chat-hf Snapdragon® 8 Gen 3 27 tok/sec 50 tok/sec * Update matmul.py
1 parent bbc97c7 commit 50d1c97

File tree

2 files changed

+292
-44
lines changed

2 files changed

+292
-44
lines changed

python/tvm/dlight/gpu/matmul.py

Lines changed: 134 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from tvm.tir.analysis import undefined_vars
2828
from tvm.tir.schedule.schedule import BlockRV
2929

30-
from ..base import analysis
30+
from ..base import analysis, BlockInfo, IterInfo
3131
from .base import GPUScheduleRule
3232

3333

@@ -273,6 +273,32 @@ def get_index_map(block: tir.Block) -> Optional[Tuple[tir.IndexMap, ...]]:
273273
)
274274

275275

276+
def get_block_info(sch: tir.Schedule, block: tir.schedule.BlockRV) -> BlockInfo:
    """Build a ``BlockInfo`` descriptor for *block*.

    Maps each of the block's iter vars to an ``IterInfo`` (kind, var, domain
    extent, and the matching loop RV) and records whether the block carries a
    reduction axis.
    """
    block_stmt = sch.get(block)

    def _kind_of(iter_var: tir.IterVar) -> str:
        # Data-parallel axes are "S" (spatial), commutative reductions are
        # "R"; anything else falls back to "O" (other).
        if iter_var.iter_type == tir.IterVar.DataPar:
            return "S"
        if iter_var.iter_type == tir.IterVar.CommReduce:
            return "R"
        return "O"

    # Pair each loop RV with its iter var to describe the iteration space.
    iter_infos = [
        IterInfo(
            kind=_kind_of(iter_var),
            var=iter_var.var,
            dom=iter_var.dom.extent,
            loop_rv=loop_rv,
        )
        for loop_rv, iter_var in zip(sch.get_loops(block), block_stmt.iter_vars)
    ]

    # The block is a reduction block iff any of its iter vars is a reduction.
    has_reduction = any(_kind_of(iter_var) == "R" for iter_var in block_stmt.iter_vars)

    return BlockInfo(
        name=block_stmt.name_hint,
        iters=iter_infos,
        block_rv=block,
        reduction_block=has_reduction,
    )
301+
276302
def get_reduction_blocks(sch, blocks) -> bool:
277303
# Get the main computation block
278304
def is_reduction(block: BlockRV) -> bool:
@@ -914,17 +940,19 @@ def get_configs(self, target: Target) -> Config:
914940
storage_align=True,
915941
inner_x=False,
916942
)
917-
elif target.kind.name == "opencl" and "android" in str(target.host):
943+
elif target.kind.name == "opencl" and (
944+
("android" in str(target.host)) or ("windows" in str(target.host))
945+
):
918946
return Matmul.Config(
919-
block_size_x=8,
920-
block_size_y=16,
947+
block_size_x=32,
948+
block_size_y=8,
921949
vthread_x=1,
922950
vthread_y=1,
923951
micro_size_x=8,
924952
micro_size_y=2,
925953
micro_size_k=16,
926954
vector_size=8,
927-
unroll=64,
955+
unroll=4,
928956
use_shared=False,
929957
storage_align=False,
930958
inner_x=True,
@@ -941,6 +969,7 @@ def apply( # pylint: disable=too-many-locals,missing-docstring
941969
if not isinstance(func, tir.PrimFunc) or not self.is_target_available(target):
942970
return None
943971
sch = tir.Schedule(func)
972+
config = self.get_configs(target)
944973
root_block = analysis.get_root_block(sch)
945974
blocks = sch.get_child_blocks(root_block)
946975

@@ -953,9 +982,22 @@ def apply( # pylint: disable=too-many-locals,missing-docstring
953982
index_maps = get_index_map(block_stmt)
954983
if index_maps is None:
955984
return None
956-
matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps
985+
986+
main_block_info = get_block_info(sch, main_block)
987+
iter_infos = main_block_info.iters
988+
989+
# Checks whether it is an inner reduction by inspecting the last matrix's inner index
990+
def is_inner_reduction(block_stmt, iter_infos):
991+
end_it = block_stmt.reads[-1].region[-1].min
992+
return {it.var: it.kind for it in iter_infos}.get(end_it, "O") == "R"
993+
994+
if target.kind.name == "opencl" and not is_inner_reduction(block_stmt, iter_infos):
995+
ret = self.sch_outer_reduction(sch, config, main_block, blocks)
996+
if ret is not None:
997+
return ret
957998

958999
# Step 0. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K]
1000+
matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps
9591001
block = sch.reindex(main_block, ("read", 0))
9601002
sch.transform_layout(block, ("write", 0), a_index_map)
9611003
block = sch.reindex(main_block, ("read", 1))
@@ -994,10 +1036,7 @@ def apply( # pylint: disable=too-many-locals,missing-docstring
9941036
except: # pylint: disable=bare-except
9951037
pass
9961038

997-
# Step 2. Get schedule config.
998-
config = self.get_configs(target)
999-
1000-
# Step 3. Schedule matmul
1039+
# Step 2. Schedule matmul
10011040
y_kernel_size = config.vthread_y * config.block_size_y * config.micro_size_y
10021041
x_kernel_size = config.vthread_x * config.block_size_x * config.micro_size_x
10031042
if config.inner_x:
@@ -1075,3 +1114,88 @@ def _cooperative_fetch(index, vec_len):
10751114

10761115
sch.decompose_reduction(main_block, ko)
10771116
return sch
1117+
1118+
def sch_outer_reduction(
    self,
    sch: tir.Schedule,
    config: Config,
    reduction_block: tir.schedule.BlockRV,
    blocks: List[tir.schedule.BlockRV],
) -> Optional[tir.Schedule]:
    """Schedule an outer-reduction (dequant) matmul for the OpenCL backend.

    Expects exactly four loops on ``reduction_block`` — (m_batch, m_seq, n, k)
    — with static n and m_batch extents and a dynamic (Var) m_seq extent;
    returns None when that pattern does not match so the caller can fall back
    to the generic matmul schedule.
    """
    reduction_loops = sch.get_loops(reduction_block)
    if not len(reduction_loops) == 4:
        return None

    mb, ms, n, k = reduction_loops
    # Only handle static n/mb and a dynamic (symbolic) ms extent.
    if not (
        isinstance(sch.get(n).extent, tir.IntImm)
        and isinstance(sch.get(mb).extent, tir.IntImm)
        and isinstance(sch.get(ms).extent, tir.Var)
    ):
        return None

    # Tiling knobs taken from the target-specific config.
    Threads_X, Threads_Y, VecSize, Unroll_M = (
        config.block_size_x,
        config.block_size_y,
        config.vector_size,
        config.unroll,
    )

    # More than one child block means a dequantize pipeline:
    # compute -> dequant -> matmul; the compute block is inlined away.
    is_dequant_block = len(blocks) > 1
    if is_dequant_block:
        compute_block, dequant_block, matmul_block = blocks
        sch.compute_inline(compute_block)
    else:
        (matmul_block,) = blocks

    # Fuse batch and sequence dims into a single m axis.
    m = sch.fuse(mb, ms)

    # Pad the einsum so m and n divide evenly by the tile sizes below.
    sch.pad_einsum(matmul_block, [1, Threads_Y * Unroll_M, Threads_X * VecSize, 1])

    rmat_block, wmat_block = (
        sch.get_producers(matmul_block)[0],
        sch.get_consumers(matmul_block)[0],
    )
    # m -> (blocks, threads.y, unroll); n -> (blocks, threads.x, vector lanes).
    mo, mi, mu = sch.split(m, [None, Threads_Y, Unroll_M])
    no, ni, nv = sch.split(n, [None, Threads_X, VecSize])
    # k split factors appear tuned for this tile shape — (Threads_X*VecSize)//32
    # ties the k1 extent to the n tile; TODO(review): confirm the 32/4/8 choice.
    k0, k1, k2, k3 = sch.split(k, [None, (Threads_X * VecSize) // 32, 4, 8])
    sch.reorder(no, mo, ni, mi, k0, k1, k2, k3, mu, nv)

    # Stage the A-read at the outer k loop, dequant at the innermost k loop.
    sch.compute_at(rmat_block, k0)
    if is_dequant_block:
        sch.compute_at(dequant_block, k3)
    sch.reverse_compute_at(wmat_block, mi)
    sch.set_scope(rmat_block, 0, "shared")
    sch.set_scope(matmul_block, 0, "local")
    if is_dequant_block:
        sch.set_scope(dequant_block, 0, "local")

    # GPU thread/block bindings for the spatial tiles.
    sch.bind(mo, "blockIdx.y")
    sch.bind(no, "blockIdx.x")
    sch.bind(mi, "threadIdx.y")
    sch.bind(ni, "threadIdx.x")
    # Vectorize the innermost (nv) loop of the matmul (and dequant) body.
    sch.vectorize(sch.get_loops(matmul_block)[-1])
    if is_dequant_block:
        sch.vectorize(sch.get_loops(dequant_block)[-1])

    # Co-operative Memory Fetch: threads along x jointly load the shared
    # buffer in VecSize-wide vector chunks.
    ro, rv = sch.split(sch.get_loops(rmat_block)[-1], [None, VecSize])
    sch.bind(ro, "threadIdx.x")
    sch.vectorize(rv)

    # Vectorized write-back of the output block.
    wv = sch.get_loops(wmat_block)[-1]
    sch.vectorize(wv)

    # Scale and Quant Cache: cache the quantized weights (read 0) and scales
    # (read 1) in local memory, staged at different k depths, then vectorize
    # their copy loops.
    if is_dequant_block:
        qb = sch.cache_read(dequant_block, 0, "local")
        sb = sch.cache_read(dequant_block, 1, "local")
        sch.compute_at(sb, k1)
        sch.compute_at(qb, k2)
        sch.set_scope(sb, 0, "local")
        sch.set_scope(qb, 0, "local")
        sch.vectorize(sch.get_loops(qb)[-1])
        sch.vectorize(sch.get_loops(sb)[-1])

    # Split the reduction init out of the k0 loop.
    sch.decompose_reduction(matmul_block, k0)
    return sch

0 commit comments

Comments
 (0)