Skip to content

Commit 5c49ac9

Browse files
author
Siyuan Feng
committed
[DLight] Skip rule if target is not suitable
This PR adds a check for GPU rules to skip if the target is not suitable for the rule.
1 parent 4a7e4fe commit 5c49ac9

File tree

8 files changed

+87
-20
lines changed

8 files changed

+87
-20
lines changed

python/tvm/dlight/base/schedule_rule.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,3 +103,18 @@ def apply(
103103
return _Rule()
104104

105105
return decorator
106+
107+
def is_target_available(self, target: Target) -> bool: # pylint: disable=unused-argument
108+
"""Check whether the rule is available for the given target.
109+
110+
Parameters
111+
----------
112+
target : Target
113+
The compilation target the schedule is supposed to be built for.
114+
115+
Returns
116+
-------
117+
available : bool
118+
Whether the rule is available for the given target.
119+
"""
120+
return True

python/tvm/dlight/gpu/base.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""Base schedule rule for GPU operators."""
18+
19+
from tvm.target import Target
20+
21+
from ..base import ScheduleRule
22+
23+
24+
class GPUScheduleRule(ScheduleRule): # pylint: disable=too-few-public-methods
25+
"""The Schedule Rule specific to GPU targets, will return None if the target is not GPU."""
26+
27+
def is_target_available(self, target: Target) -> bool:
28+
"""Check whether the target is available for gpu rule.
29+
30+
Parameters
31+
----------
32+
target : Target
33+
The compilation target to check.
34+
35+
Returns
36+
-------
37+
available : bool
38+
Whether the target is available for this rule.
39+
"""
40+
return super().is_target_available(target) and "gpu" in target.keys

python/tvm/dlight/gpu/fallback.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,12 @@
2121
from tvm import tir
2222
from tvm.target import Target
2323

24-
from ..base import ScheduleRule, normalize_prim_func, try_inline
24+
from ..base import normalize_prim_func, try_inline
2525
from . import utils
26+
from .base import GPUScheduleRule
2627

2728

28-
class Fallback(ScheduleRule):
29+
class Fallback(GPUScheduleRule):
2930
"""
3031
A fallback schedule rule for all GPU operators. It will try to inline all the blocks first,
3132
and then apply a simple block/grid mapping to the spatial loops on top of the remaining blocks.
@@ -37,6 +38,8 @@ def apply( # pylint: disable=too-many-locals,missing-docstring
3738
target: Target,
3839
_: bool,
3940
) -> tir.Schedule:
41+
if not isinstance(func, tir.PrimFunc) or not self.is_target_available(target):
42+
return None
4043
max_threads_per_block = utils.max_threads_per_block(target)
4144

4245
sch = tir.Schedule(func)

python/tvm/dlight/gpu/gemv.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,13 @@
2424

2525
from ..base import (
2626
BlockInfo,
27-
ScheduleRule,
2827
collect_vars_used_in_access_region,
2928
detect_dominant_read,
3029
is_broadcast_epilogue,
3130
normalize_prim_func,
3231
try_inline_contiguous_spatial,
3332
)
33+
from .base import GPUScheduleRule
3434

3535

3636
def _get_reduction_expr(block: tir.Block) -> Optional[tir.PrimExpr]:
@@ -154,7 +154,7 @@ def normalize(
154154
return is_inner_reduction
155155

156156

157-
class GEMV(ScheduleRule):
157+
class GEMV(GPUScheduleRule):
158158
"""A rule for GEMV and DecodeGEMV."""
159159

160160
def apply( # pylint: disable=too-many-locals,too-many-branches,too-many-return-statements
@@ -163,7 +163,7 @@ def apply( # pylint: disable=too-many-locals,too-many-branches,too-many-return-
163163
target: Target,
164164
_: bool,
165165
) -> Union[None, tir.Schedule, List[tir.Schedule]]:
166-
if not isinstance(func, tir.PrimFunc):
166+
if not isinstance(func, tir.PrimFunc) or not self.is_target_available(target):
167167
return None
168168
sch = tir.Schedule(func)
169169
block_infos = normalize_prim_func(sch)

python/tvm/dlight/gpu/general_reduction.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,11 @@
2121
from tvm import tir
2222
from tvm.target import Target
2323

24-
from ..base import ScheduleRule, normalize_prim_func, try_inline_contiguous_spatial
24+
from ..base import normalize_prim_func, try_inline_contiguous_spatial
25+
from .base import GPUScheduleRule
2526

2627

27-
class GeneralReduction(ScheduleRule):
28+
class GeneralReduction(GPUScheduleRule):
2829
"""General Reduction rule for operators including softmax, layer norm, RMS norm, etc"""
2930

3031
def apply( # pylint: disable=too-many-locals
@@ -33,7 +34,7 @@ def apply( # pylint: disable=too-many-locals
3334
target: Target,
3435
_: bool,
3536
) -> Union[None, tir.Schedule, List[tir.Schedule]]:
36-
if not isinstance(func, tir.PrimFunc):
37+
if not isinstance(func, tir.PrimFunc) or not self.is_target_available(target):
3738
return None
3839

3940
if target.kind.name == "cuda":

python/tvm/dlight/gpu/matmul.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@
2727
from tvm.tir.analysis import undefined_vars
2828
from tvm.tir.schedule.schedule import BlockRV
2929

30-
from ..base import ScheduleRule, analysis
30+
from ..base import analysis
31+
from .base import GPUScheduleRule
3132

3233

3334
def _collect_producers(sch: tir.Schedule, block: tir.schedule.BlockRV):
@@ -312,7 +313,7 @@ def check_sm_version(arch: str) -> int:
312313
return int(sm_version) if sm_version.isdigit() else -1
313314

314315

315-
class MatmulTensorization(ScheduleRule):
316+
class MatmulTensorization(GPUScheduleRule):
316317
"""
317318
The schedule rule for float16 tensor core matmul computation.
318319
func with attr 'dlight.do_not_tensorize' will not be tensorized.
@@ -328,6 +329,8 @@ def apply( # pylint: disable=too-many-locals,missing-docstring
328329
get_wmma_intrin_group,
329330
)
330331

332+
if not isinstance(func, tir.PrimFunc) or not self.is_target_available(target):
333+
return None
331334
sch = tir.Schedule(func)
332335
root_block = analysis.get_root_block(sch)
333336
blocks = sch.get_child_blocks(root_block)
@@ -531,7 +534,7 @@ def tensorize_init_store_compute():
531534
return sch if tensorize_success else None
532535

533536

534-
class MatmulInt8Tensorization(ScheduleRule):
537+
class MatmulInt8Tensorization(GPUScheduleRule):
535538
"""
536539
The schedule rule for int8 tensor core matmul computation.
537540
func with attr 'dlight.do_not_tensorize' will not be tensorized.
@@ -547,6 +550,8 @@ def apply( # pylint: disable=too-many-locals,missing-docstring
547550
get_wmma_intrin_group,
548551
)
549552

553+
if not isinstance(func, tir.PrimFunc) or not self.is_target_available(target):
554+
return None
550555
sch = tir.Schedule(func)
551556
root_block = analysis.get_root_block(sch)
552557
blocks = sch.get_child_blocks(root_block)
@@ -734,7 +739,7 @@ def tensorize_init_store_compute():
734739
return sch
735740

736741

737-
class Matmul(ScheduleRule):
742+
class Matmul(GPUScheduleRule):
738743
"""The schedule rule for matmul-like computation"""
739744

740745
@dataclass
@@ -793,6 +798,8 @@ def apply( # pylint: disable=too-many-locals,missing-docstring
793798
target: Target,
794799
_: bool,
795800
) -> Optional[tir.Schedule]:
801+
if not isinstance(func, tir.PrimFunc) or not self.is_target_available(target):
802+
return None
796803
sch = tir.Schedule(func)
797804
root_block = analysis.get_root_block(sch)
798805
blocks = sch.get_child_blocks(root_block)

python/tvm/dlight/gpu/reduction.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,13 @@
2323

2424
from ..base import (
2525
BlockInfo,
26-
ScheduleRule,
2726
normalize_prim_func,
2827
try_inline_contiguous_spatial,
2928
detect_dominant_read,
3029
is_broadcast_epilogue,
3130
)
3231
from . import utils
32+
from .base import GPUScheduleRule
3333

3434

3535
def _get_reduction_expr(block: tir.Block) -> Optional[tir.PrimExpr]:
@@ -48,7 +48,7 @@ def _get_reduction_expr(block: tir.Block) -> Optional[tir.PrimExpr]:
4848
return buffer_store.value.b
4949

5050

51-
class Reduction(ScheduleRule):
51+
class Reduction(GPUScheduleRule):
5252
"""A rule for Reduction."""
5353

5454
def apply( # pylint: disable=too-many-locals,too-many-branches,too-many-return-statements
@@ -57,7 +57,7 @@ def apply( # pylint: disable=too-many-locals,too-many-branches,too-many-return-
5757
target: Target,
5858
_: bool,
5959
) -> Union[None, tir.Schedule, List[tir.Schedule]]:
60-
if not isinstance(func, tir.PrimFunc):
60+
if not isinstance(func, tir.PrimFunc) or not self.is_target_available(target):
6161
return None
6262
sch = tir.Schedule(func)
6363
block_infos = normalize_prim_func(sch)

python/tvm/dlight/gpu/transpose.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,21 +17,20 @@
1717
"""Reduction rule for operators including softmax, layer norm, RMS norm, etc"""
1818
from typing import List, Union
1919

20-
from tvm import tir, arith
20+
from tvm import arith, tir
2121
from tvm.target import Target
2222
from tvm.tir import Schedule
2323
from tvm.tir.schedule import BlockRV
2424

25-
2625
from ..base import (
27-
ScheduleRule,
26+
detect_dominant_read,
2827
normalize_prim_func,
2928
try_inline_contiguous_spatial,
30-
detect_dominant_read,
3129
)
30+
from .base import GPUScheduleRule
3231

3332

34-
class Transpose(ScheduleRule):
33+
class Transpose(GPUScheduleRule):
3534
"""Schedule rule for transpose"""
3635

3736
def is_transpose(self, sch: Schedule, block_rv: BlockRV):
@@ -52,6 +51,8 @@ def apply( # pylint: disable=too-many-locals
5251
_: bool,
5352
) -> Union[None, tir.Schedule, List[tir.Schedule]]:
5453
# pylint: disable=invalid-name
54+
if not isinstance(func, tir.PrimFunc) or not self.is_target_available(target):
55+
return None
5556
if target.kind.name == "cuda":
5657
len_tx = 16
5758
len_ty = 8

0 commit comments

Comments
 (0)