diff --git a/python/tvm/dlight/base/schedule_rule.py b/python/tvm/dlight/base/schedule_rule.py index 3bb7e5c1a929..dda66b7cfe9c 100644 --- a/python/tvm/dlight/base/schedule_rule.py +++ b/python/tvm/dlight/base/schedule_rule.py @@ -103,3 +103,18 @@ def apply( return _Rule() return decorator + + def is_target_available(self, target: Target) -> bool: # pylint: disable=unused-argument + """Check whether the rule is available for the given target. + + Parameters + ---------- + target : Target + The compilation target the schedule is supposed to be built for. + + Returns + ------- + available : bool + Whether the rule is available for the given target. + """ + return True diff --git a/python/tvm/dlight/gpu/base.py b/python/tvm/dlight/gpu/base.py new file mode 100644 index 000000000000..b5cf0bb7a9b4 --- /dev/null +++ b/python/tvm/dlight/gpu/base.py @@ -0,0 +1,40 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""Base schedule rule for GPU operators.""" + +from tvm.target import Target + +from ..base import ScheduleRule + + +class GPUScheduleRule(ScheduleRule):  # pylint: disable=too-few-public-methods + """The schedule rule specific to GPU targets; returns None if the target is not GPU.""" + + def is_target_available(self, target: Target) -> bool: + """Check whether the target is available for the GPU rule. + + Parameters + ---------- + target : Target + The compilation target to check. + + Returns + ------- + available : bool + Whether the target is available for this rule. + """ + return super().is_target_available(target) and "gpu" in target.keys diff --git a/python/tvm/dlight/gpu/fallback.py b/python/tvm/dlight/gpu/fallback.py index 2c1e7424dcfe..7139c7ea4199 100644 --- a/python/tvm/dlight/gpu/fallback.py +++ b/python/tvm/dlight/gpu/fallback.py @@ -21,11 +21,12 @@ from tvm import tir from tvm.target import Target -from ..base import ScheduleRule, normalize_prim_func, try_inline +from ..base import normalize_prim_func, try_inline from . import utils +from .base import GPUScheduleRule -class Fallback(ScheduleRule): +class Fallback(GPUScheduleRule): """ A fallback schedule rule for all GPU operators. It will try to inline all the blocks first, and then apply a simple block/grid mapping to the spatial loops on top of the remaining blocks.
@@ -37,6 +38,8 @@ def apply( # pylint: disable=too-many-locals,missing-docstring target: Target, _: bool, ) -> tir.Schedule: + if not isinstance(func, tir.PrimFunc) or not self.is_target_available(target): + return None max_threads_per_block = utils.max_threads_per_block(target) sch = tir.Schedule(func) diff --git a/python/tvm/dlight/gpu/gemv.py b/python/tvm/dlight/gpu/gemv.py index 76839d41662d..27b155c6a754 100644 --- a/python/tvm/dlight/gpu/gemv.py +++ b/python/tvm/dlight/gpu/gemv.py @@ -24,13 +24,13 @@ from ..base import ( BlockInfo, - ScheduleRule, collect_vars_used_in_access_region, detect_dominant_read, is_broadcast_epilogue, normalize_prim_func, try_inline_contiguous_spatial, ) +from .base import GPUScheduleRule def _get_reduction_expr(block: tir.Block) -> Optional[tir.PrimExpr]: @@ -154,7 +154,7 @@ def normalize( return is_inner_reduction -class GEMV(ScheduleRule): +class GEMV(GPUScheduleRule): """A rule for GEMV and DecodeGEMV.""" def apply( # pylint: disable=too-many-locals,too-many-branches,too-many-return-statements @@ -163,7 +163,7 @@ def apply( # pylint: disable=too-many-locals,too-many-branches,too-many-return- target: Target, _: bool, ) -> Union[None, tir.Schedule, List[tir.Schedule]]: - if not isinstance(func, tir.PrimFunc): + if not isinstance(func, tir.PrimFunc) or not self.is_target_available(target): return None sch = tir.Schedule(func) block_infos = normalize_prim_func(sch) diff --git a/python/tvm/dlight/gpu/general_reduction.py b/python/tvm/dlight/gpu/general_reduction.py index bbd42a7524d6..28b68a8b62a7 100644 --- a/python/tvm/dlight/gpu/general_reduction.py +++ b/python/tvm/dlight/gpu/general_reduction.py @@ -21,10 +21,11 @@ from tvm import tir from tvm.target import Target -from ..base import ScheduleRule, normalize_prim_func, try_inline_contiguous_spatial +from ..base import normalize_prim_func, try_inline_contiguous_spatial +from .base import GPUScheduleRule -class GeneralReduction(ScheduleRule): +class 
GeneralReduction(GPUScheduleRule): """General Reduction rule for operators including softmax, layer norm, RMS norm, etc""" def apply( # pylint: disable=too-many-locals @@ -33,7 +34,7 @@ def apply( # pylint: disable=too-many-locals target: Target, _: bool, ) -> Union[None, tir.Schedule, List[tir.Schedule]]: - if not isinstance(func, tir.PrimFunc): + if not isinstance(func, tir.PrimFunc) or not self.is_target_available(target): return None if target.kind.name == "cuda": diff --git a/python/tvm/dlight/gpu/matmul.py b/python/tvm/dlight/gpu/matmul.py index 7d5d6489cb56..9318b9149245 100644 --- a/python/tvm/dlight/gpu/matmul.py +++ b/python/tvm/dlight/gpu/matmul.py @@ -27,7 +27,8 @@ from tvm.tir.analysis import undefined_vars from tvm.tir.schedule.schedule import BlockRV -from ..base import ScheduleRule, analysis +from ..base import analysis +from .base import GPUScheduleRule def _collect_producers(sch: tir.Schedule, block: tir.schedule.BlockRV): @@ -312,7 +313,7 @@ def check_sm_version(arch: str) -> int: return int(sm_version) if sm_version.isdigit() else -1 -class MatmulTensorization(ScheduleRule): +class MatmulTensorization(GPUScheduleRule): """ The schedule rule for float16 tensor core matmul computation. func with attr 'dlight.do_not_tensorize' will not be tensorized. @@ -328,6 +329,8 @@ def apply( # pylint: disable=too-many-locals,missing-docstring get_wmma_intrin_group, ) + if not isinstance(func, tir.PrimFunc) or not self.is_target_available(target): + return None sch = tir.Schedule(func) root_block = analysis.get_root_block(sch) blocks = sch.get_child_blocks(root_block) @@ -531,7 +534,7 @@ def tensorize_init_store_compute(): return sch if tensorize_success else None -class MatmulInt8Tensorization(ScheduleRule): +class MatmulInt8Tensorization(GPUScheduleRule): """ The schedule rule for int8 tensor core matmul computation. func with attr 'dlight.do_not_tensorize' will not be tensorized. 
@@ -547,6 +550,8 @@ def apply( # pylint: disable=too-many-locals,missing-docstring get_wmma_intrin_group, ) + if not isinstance(func, tir.PrimFunc) or not self.is_target_available(target): + return None sch = tir.Schedule(func) root_block = analysis.get_root_block(sch) blocks = sch.get_child_blocks(root_block) @@ -734,7 +739,7 @@ def tensorize_init_store_compute(): return sch -class Matmul(ScheduleRule): +class Matmul(GPUScheduleRule): """The schedule rule for matmul-like computation""" @dataclass @@ -793,6 +798,8 @@ def apply( # pylint: disable=too-many-locals,missing-docstring target: Target, _: bool, ) -> Optional[tir.Schedule]: + if not isinstance(func, tir.PrimFunc) or not self.is_target_available(target): + return None sch = tir.Schedule(func) root_block = analysis.get_root_block(sch) blocks = sch.get_child_blocks(root_block) diff --git a/python/tvm/dlight/gpu/reduction.py b/python/tvm/dlight/gpu/reduction.py index 3e2e5ee5326d..2ccc11f7f49e 100644 --- a/python/tvm/dlight/gpu/reduction.py +++ b/python/tvm/dlight/gpu/reduction.py @@ -23,13 +23,13 @@ from ..base import ( BlockInfo, - ScheduleRule, normalize_prim_func, try_inline_contiguous_spatial, detect_dominant_read, is_broadcast_epilogue, ) from . 
import utils +from .base import GPUScheduleRule def _get_reduction_expr(block: tir.Block) -> Optional[tir.PrimExpr]: @@ -48,7 +48,7 @@ def _get_reduction_expr(block: tir.Block) -> Optional[tir.PrimExpr]: return buffer_store.value.b -class Reduction(ScheduleRule): +class Reduction(GPUScheduleRule): """A rule for Reduction.""" def apply( # pylint: disable=too-many-locals,too-many-branches,too-many-return-statements @@ -57,7 +57,7 @@ def apply( # pylint: disable=too-many-locals,too-many-branches,too-many-return- target: Target, _: bool, ) -> Union[None, tir.Schedule, List[tir.Schedule]]: - if not isinstance(func, tir.PrimFunc): + if not isinstance(func, tir.PrimFunc) or not self.is_target_available(target): return None sch = tir.Schedule(func) block_infos = normalize_prim_func(sch) diff --git a/python/tvm/dlight/gpu/transpose.py b/python/tvm/dlight/gpu/transpose.py index a51fcdc87336..d4496756a2d0 100644 --- a/python/tvm/dlight/gpu/transpose.py +++ b/python/tvm/dlight/gpu/transpose.py @@ -17,21 +17,20 @@ """Reduction rule for operators including softmax, layer norm, RMS norm, etc""" from typing import List, Union -from tvm import tir, arith +from tvm import arith, tir from tvm.target import Target from tvm.tir import Schedule from tvm.tir.schedule import BlockRV - from ..base import ( - ScheduleRule, + detect_dominant_read, normalize_prim_func, try_inline_contiguous_spatial, - detect_dominant_read, ) +from .base import GPUScheduleRule -class Transpose(ScheduleRule): +class Transpose(GPUScheduleRule): """Schedule rule for transpose""" def is_transpose(self, sch: Schedule, block_rv: BlockRV): @@ -52,6 +51,8 @@ def apply( # pylint: disable=too-many-locals _: bool, ) -> Union[None, tir.Schedule, List[tir.Schedule]]: # pylint: disable=invalid-name + if not isinstance(func, tir.PrimFunc) or not self.is_target_available(target): + return None if target.kind.name == "cuda": len_tx = 16 len_ty = 8