Commit 5d715aa

[Dlight] Enhance Decode-GEMV Rules

1 parent b51bee8

4 files changed: +180 -68 lines

python/tvm/dlight/gpu/decode_gemv.py
Lines changed: 88 additions & 64 deletions

@@ -17,8 +17,7 @@
 # pylint: disable=missing-docstring
 """A fallback schedule rule for GPU operators."""
 # pylint: disable=invalid-name
-
-from typing import List, Optional, Union
+from typing import List, Optional, Set, Tuple, Union
 
 from tvm import tir
 from tvm._ffi import get_global_func
@@ -27,6 +26,7 @@
 from tvm.target import Target
 
 from ..base import ScheduleRule, normalize_prim_func, try_inline_contiguous_spatial
+from . import utils
 
 
 def _get_reduction_expr(block: tir.Block) -> Optional[tir.PrimExpr]:
@@ -47,7 +47,7 @@ def _get_reduction_expr(block: tir.Block) -> Optional[tir.PrimExpr]:
 
 def _detect_dominant_read(block: tir.Block) -> tir.PrimExpr:
     dominant_read, read_iters = None, None
-    tir_vars = set()
+    tir_vars: Set[tir.Var] = set()
     for buffer_region in block.reads:
         tir_vars.clear()
 
@@ -67,25 +67,37 @@ def _collect_tir_var(e):
     return result
 
 
-class DecodeGEMV(ScheduleRule):
-    def __init__(self) -> None:
-        super().__init__()
-        self.get_loop_iter_type = get_global_func("tir.schedule.GetLoopIterType")
+_get_loop_iter_type = get_global_func("tir.schedule.GetLoopIterType")
+
+
+def _fuse_spatial_reduction_loops(
+    sch: tir.Schedule,
+    loops: List[tir.schedule.LoopRV],
+) -> Tuple[tir.schedule.LoopRV, tir.schedule.LoopRV]:
+    s_loops, r_loops = [], []
+    for loop_rv in loops:
+        iter_type = _get_loop_iter_type(sch, loop_rv)
+        if iter_type == "S":
+            s_loops.append(loop_rv)
+        elif iter_type == "R":
+            r_loops.append(loop_rv)
+        else:
+            raise RuntimeError("Unknown loop type " + str(iter_type))
+    sch.reorder(*s_loops, *r_loops)
+    s_ctr = sch.fuse(*s_loops)
+    r_ctr = sch.fuse(*r_loops)
+    return s_ctr, r_ctr
 
-    def apply(  # pylint: disable=too-many-locals
+
+class DecodeGEMV(ScheduleRule):
+    def apply(  # pylint: disable=too-many-locals,too-many-branches,too-many-return-statements
         self,
         func: tir.PrimFunc,
         target: Target,
         _: bool,
     ) -> Union[None, tir.Schedule, List[tir.Schedule]]:
         if not isinstance(func, tir.PrimFunc):
             return None
-
-        if target.kind.name == "cuda":
-            len_tx, len_ty = 16, 16
-        else:
-            len_tx, len_ty = 8, 8
-
         sch = tir.Schedule(func)
         block_infos = try_inline_contiguous_spatial(sch, normalize_prim_func(sch))
 
@@ -113,6 +125,7 @@ def apply(  # pylint: disable=too-many-locals
             return None
         iter_to_info = {i.var: i for i in block_info.iters}
         s_loops, r_loops, c_loops = [], [], []
+        c_loop_factor = None
        for split in sorted_iter_access.args:
            block_var = split.source.source
            block_var_info = iter_to_info[block_var]
@@ -122,71 +135,82 @@ def apply(  # pylint: disable=too-many-locals
                c_loop_factor = split.lower_factor
                loop_rv, c_loop = sch.split(loop_rv, factors=[None, c_loop_factor])
                c_loops.append(c_loop)
-                is_loop_c_reduction = is_inner_reduction
+                if is_inner_reduction:
+                    c_loop_factor = None
            if is_inner_reduction:
                r_loops.append(loop_rv)
            else:
                s_loops.append(loop_rv)
 
-        if len(c_loops) > 1:
+        if len(c_loops) > 1 or len(s_loops) == 0 or len(r_loops) == 0:
            return None
        if len(s_loops) != len([_ for i in block_info.iters if i.kind == "S"]):
            return None
-        if len(s_loops) == 0 or len(r_loops) == 0:
-            return None
 
        sch.reorder(*s_loops, *r_loops, *c_loops)
        s = sch.fuse(*s_loops)
        r = sch.fuse(*r_loops)
-
-        if is_inner_reduction:
-            _, tx = sch.split(r, factors=[None, len_tx * len_ty])
-            rf = sch.rfactor(tx, 0)
-            s, r, tx = sch.get_loops(rf)[:3]
-            sch.reorder(s, tx, r)
-            sch.reverse_compute_at(block, s, preserve_unit_loops=True)
-            sch.bind(tx, "threadIdx.x")
-            sch.bind(s, "blockIdx.x")
-        else:
-            sch.split(s, factors=[None, len_tx])
-            _, ty = sch.split(r, factors=[None, len_ty])
-            rf = sch.rfactor(ty, 0)
-            bx, tx, r, ty = sch.get_loops(rf)[:4]
-            sch.reorder(bx, tx, ty, r)
-            sch.reverse_compute_at(block, bx, preserve_unit_loops=True)
-            sch.bind(tx, "threadIdx.x")
-            sch.bind(ty, "threadIdx.y")
-            sch.bind(bx, "blockIdx.x")
-
-        s_loops, r_loops = [], []
-        for loop_rv in sch.get_loops(block)[1:]:
-            iter_type = self.get_loop_iter_type(sch, loop_rv)
-            if iter_type == "S":
-                s_loops.append(loop_rv)
-            elif iter_type == "R":
-                r_loops.append(loop_rv)
-            else:
-                raise RuntimeError("Unknown loop type " + str(iter_type))
-        sch.reorder(*s_loops, *r_loops)
-        s_ctr = sch.fuse(*s_loops)
-        r_ctr = sch.fuse(*r_loops)
-
-        if c_loops and not is_loop_c_reduction:
-            s_ctr, inner = sch.split(s_ctr, factors=[None, c_loop_factor])
-            sch.reorder(s_ctr, r_ctr, inner)
-
        if is_inner_reduction:
-            sch.bind(r_ctr, "threadIdx.x")
-            sch.set_scope(rf, 0, "local")
-            sch.decompose_reduction(rf, sch.get_loops(rf)[2])
+            self._sch_inner_reduction(sch, block, target, s, r, c_loop_factor)
        else:
-            sch.bind(s_ctr, "threadIdx.x")
-            sch.bind(r_ctr, "threadIdx.y")
-            sch.set_scope(rf, 0, "local")
-            sch.decompose_reduction(rf, sch.get_loops(rf)[3])
-
+            self._sch_inner_spatial(sch, block, target, s, r, c_loop_factor)
        if len(block_infos) == 2:
            sch.set_scope(block, 0, "local")
            sch.reverse_compute_at(block_infos[1].block_rv, sch.get_loops(block)[0])
-
        return sch
+
+    def _sch_inner_reduction(  # pylint: disable=too-many-arguments
+        self,
+        sch: tir.Schedule,
+        block: tir.schedule.BlockRV,
+        target: Target,
+        _: tir.schedule.LoopRV,
+        r: tir.schedule.LoopRV,
+        unroll_spatial_factor: Optional[int],
+    ):
+        (len_tx,) = utils.suggest_threads_per_block(  # pylint: disable=unbalanced-tuple-unpacking
+            target, [sch.get(r)]
+        )
+
+        _, tx = sch.split(r, factors=[None, len_tx])
+        rf = sch.rfactor(tx, 0)
+        s, r, tx = sch.get_loops(rf)[:3]
+        sch.reorder(s, tx, r)
+        sch.reverse_compute_at(block, s, preserve_unit_loops=True)
+        sch.bind(tx, "threadIdx.x")
+        sch.bind(s, "blockIdx.x")
+        s_ctr, r_ctr = _fuse_spatial_reduction_loops(sch, sch.get_loops(block)[1:])
+        if unroll_spatial_factor:
+            s_ctr, inner = sch.split(s_ctr, factors=[None, unroll_spatial_factor])
+            sch.reorder(s_ctr, r_ctr, inner)
+        sch.bind(r_ctr, "threadIdx.x")
+        sch.set_scope(rf, 0, "local")
+        sch.decompose_reduction(rf, sch.get_loops(rf)[2])
+
+    def _sch_inner_spatial(  # pylint: disable=too-many-locals,too-many-arguments
+        self,
+        sch: tir.Schedule,
+        block: tir.schedule.BlockRV,
+        target: Target,
+        s: tir.schedule.LoopRV,
+        r: tir.schedule.LoopRV,
+        unroll_spatial_factor: Optional[int],
+    ):
+        len_tx, len_ty = 16, 16
+        sch.split(s, factors=[None, len_tx])
+        _, ty = sch.split(r, factors=[None, len_ty])
+        rf = sch.rfactor(ty, 0)
+        bx, tx, r, ty = sch.get_loops(rf)[:4]
+        sch.reorder(bx, tx, ty, r)
+        sch.reverse_compute_at(block, bx, preserve_unit_loops=True)
+        sch.bind(tx, "threadIdx.x")
+        sch.bind(ty, "threadIdx.y")
+        sch.bind(bx, "blockIdx.x")
+        s_ctr, r_ctr = _fuse_spatial_reduction_loops(sch, sch.get_loops(block)[1:])
+        if unroll_spatial_factor:
+            s_ctr, inner = sch.split(s_ctr, factors=[None, unroll_spatial_factor])
+            sch.reorder(s_ctr, r_ctr, inner)
+        sch.bind(s_ctr, "threadIdx.x")
+        sch.bind(r_ctr, "threadIdx.y")
+        sch.set_scope(rf, 0, "local")
+        sch.decompose_reduction(rf, sch.get_loops(rf)[3])
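
For orientation, the refactored rule can be driven directly on a PrimFunc, the same way the dlight tests exercise schedule rules. The sketch below is illustrative only and is not part of this commit: the GEMV workload, its shapes, and the plain "cuda" target are hypothetical, and DecodeGEMV.apply may legitimately return None if the workload does not match its decode-GEMV pattern.

# Hypothetical driver sketch; not part of this commit.
import tvm
from tvm.script import tir as T
from tvm.dlight.gpu.decode_gemv import DecodeGEMV


@T.prim_func
def gemv(
    A: T.Buffer((4096, 4096), "float16"),
    X: T.Buffer((1, 4096), "float16"),
    Y: T.Buffer((1, 4096), "float16"),
):
    # A toy workload with spatial axes (i, j) and one reduction axis (k).
    for i, j, k in T.grid(1, 4096, 4096):
        with T.block("gemv"):
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            with T.init():
                Y[vi, vj] = T.float16(0)
            Y[vi, vj] = Y[vi, vj] + A[vj, vk] * X[vi, vk]


# apply() returns a scheduled tir.Schedule, or None when the rule does not match.
sch = DecodeGEMV().apply(gemv, tvm.target.Target("cuda"), False)
if sch is not None:
    sch.mod.show()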

python/tvm/dlight/gpu/fallback.py
Lines changed: 3 additions & 2 deletions

@@ -21,7 +21,8 @@
 from tvm import tir
 from tvm.target import Target
 
-from ..base import ScheduleRule, analysis, normalize_prim_func, try_inline
+from ..base import ScheduleRule, normalize_prim_func, try_inline
+from . import utils
 
 
 class Fallback(ScheduleRule):
@@ -36,7 +37,7 @@ def apply(  # pylint: disable=too-many-locals,missing-docstring
         target: Target,
         _: bool,
     ) -> tir.Schedule:
-        max_threads_per_block = analysis.get_max_threads_per_block(target)
+        max_threads_per_block = utils.max_threads_per_block(target)
 
         sch = tir.Schedule(func)
         block_infos = try_inline(sch, normalize_prim_func(sch))
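
As a quick sanity check of the swapped-in helper (my own illustration, not part of the commit): utils.max_threads_per_block first consults the target's max_threads_per_block / max_num_threads attributes and only then falls back to the per-kind defaults defined in the new utils.py below.

# Illustrative check; not part of this commit.
from tvm.target import Target
from tvm.dlight.gpu import utils

print(utils.max_threads_per_block(Target("cuda")))    # target attribute, or the CUDA fallback of 1024
print(utils.max_threads_per_block(Target("opencl")))  # target attribute, or the generic fallback of 256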

python/tvm/dlight/gpu/utils.py
Lines changed: 87 additions & 0 deletions

@@ -0,0 +1,87 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+"""Utility methods for generic GPU."""
+from typing import List, Optional
+
+from tvm import tir
+from tvm.target import Target
+
+
+def max_threads_per_block(target: Target) -> int:
+    """Get the maximum number of threads per block for a given target.
+
+    Parameters
+    ----------
+    target : Target
+        The target to get the maximum number of threads per block for.
+
+    Returns
+    -------
+    max_threads_per_block : int
+        The maximum number of threads per block for the given target.
+    """
+    for name in ["max_threads_per_block", "max_num_threads"]:
+        result = target.attrs.get(name, None)
+        if result is not None:
+            return result
+    if target.kind.name == "cuda":
+        return 1024
+    return 256
+
+
+def suggest_threads_per_block(
+    target: Target,
+    loops: List[tir.For],
+    max_threads_for_dynamic_loop: int = 32,
+) -> List[int]:
+    if target.kind.name == "cuda":
+        threads = 256
+    else:
+        threads = 64
+    results: List[Optional[int]] = []
+    dynamic: List[int] = []
+    for i, loop in enumerate(loops):
+        loop_extent = loop.extent
+        if isinstance(loop_extent, tir.IntImm):
+            loop_extent = loop_extent.value
+            extent = 1
+            while extent <= loop_extent and extent <= threads:
+                extent *= 2
+            extent //= 2
+            assert extent >= 1
+            assert threads % extent == 0
+            threads //= extent
+            results.append(extent)
+        else:
+            results.append(None)
+            dynamic.append(i)
+
+    for i in dynamic:
+        extent = 1
+        while extent <= max_threads_for_dynamic_loop and extent <= threads:
+            extent *= 2
+        extent //= 2
+        assert extent >= 1
+        assert threads % extent == 0
+        threads //= extent
+        results[i] = extent
+
+    if dynamic:
+        results[dynamic[0]] *= threads
+
+    return results
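
To make the allocation policy concrete, here is a hand-built example (my own reading of the code above, not from the commit): with the 256-thread CUDA budget, a static extent of 10 receives the largest power of two not exceeding it, and a dynamic-extent loop receives what remains, capped by max_threads_for_dynamic_loop.

# Hand-built illustration of suggest_threads_per_block; not part of this commit.
from tvm import tir
from tvm.target import Target
from tvm.dlight.gpu.utils import suggest_threads_per_block

n = tir.Var("n", "int32")  # a dynamic (symbolic) extent
loops = [
    tir.For(tir.Var("i", "int32"), 0, 10, tir.ForKind.SERIAL, tir.Evaluate(0)),
    tir.For(tir.Var("j", "int32"), 0, n, tir.ForKind.SERIAL, tir.Evaluate(0)),
]
# Expected result by my reading: [8, 32], i.e. 8 threads for the static loop
# and the 32-thread dynamic cap (scaled by any leftover budget) for the dynamic one.
print(suggest_threads_per_block(Target("cuda"), loops))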

tests/python/dlight/test_gpu_reduction.py
Lines changed: 2 additions & 2 deletions

@@ -100,7 +100,7 @@ def main(p_lv44: T.handle, p_output0: T.handle):
                 v0 = T.axis.spatial(T.int64(32), ax0_ax1_fused // n + ax0)
                 v1 = T.axis.spatial(n, ax0_ax1_fused % n + ax1)
                 v2 = T.axis.reduce(m, ax2_fused_0 * T.int64(256) + ax2_fused_1)
-                T.where(T.int64(0) <= ax0_ax1_fused // n and ax0_ax1_fused // n < T.int64(32) and T.int64(0) <= ax0_ax1_fused % n and ax0_ax1_fused % n < n and ax2_fused_0 * T.int64(256) + ax2_fused_1 < m)
+                T.where(ax2_fused_0 * T.int64(256) + ax2_fused_1 < m)
                 T.reads(lv44[T.int64(0), v0, v1, v2])
                 T.writes(T_softmax_maxelem_shared[T.int64(0), v0, v1])
                 with T.init():
@@ -112,7 +112,7 @@ def main(p_lv44: T.handle, p_output0: T.handle):
                 v0 = T.axis.spatial(T.int64(32), ax0_ax1_fused // n + ax0)
                 v1 = T.axis.spatial(n, ax0_ax1_fused % n + ax1)
                 v2 = T.axis.reduce(m, ax2_fused_0 * T.int64(256) + ax2_fused_1)
-                T.where(T.int64(0) <= ax0_ax1_fused // n and ax0_ax1_fused // n < T.int64(32) and T.int64(0) <= ax0_ax1_fused % n and ax0_ax1_fused % n < n and ax2_fused_0 * T.int64(256) + ax2_fused_1 < m)
+                T.where(ax2_fused_0 * T.int64(256) + ax2_fused_1 < m)
                 T.reads(lv44[T.int64(0), v0, v1, v2], T_softmax_maxelem_shared[T.int64(0), v0, v1])
                 T.writes(T_softmax_expsum_shared[T.int64(0), v0, v1])
                 with T.init():

0 commit comments