Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions python/tvm/dlight/base/schedule_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,3 +103,18 @@ def apply(
return _Rule()

return decorator

def is_target_available(self, target: Target) -> bool: # pylint: disable=unused-argument
"""Check whether the rule is available for the given target.

Parameters
----------
target : Target
The compilation target the schedule is supposed to be built for.

Returns
-------
available : bool
Whether the rule is available for the given target.
"""
return True
40 changes: 40 additions & 0 deletions python/tvm/dlight/gpu/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Base schedule rule for GPU operators."""

from tvm.target import Target

from ..base import ScheduleRule


class GPUScheduleRule(ScheduleRule): # pylint: disable=too-few-public-methods
"""The Schedule Rule specific to GPU targets, will return None if the target is not GPU."""

def is_target_available(self, target: Target) -> bool:
"""Check whether the target is available for gpu rule.

Parameters
----------
target : Target
The compilation target to check.

Returns
-------
available : bool
Whether the target is available for this rule.
"""
return super().is_target_available(target) and "gpu" in target.keys
7 changes: 5 additions & 2 deletions python/tvm/dlight/gpu/fallback.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,12 @@
from tvm import tir
from tvm.target import Target

from ..base import ScheduleRule, normalize_prim_func, try_inline
from ..base import normalize_prim_func, try_inline
from . import utils
from .base import GPUScheduleRule


class Fallback(ScheduleRule):
class Fallback(GPUScheduleRule):
"""
A fallback schedule rule for all GPU operators. It will try to inline all the blocks first,
and then apply a simple block/grid mapping to the spatial loops on top of the remaining blocks.
Expand All @@ -37,6 +38,8 @@ def apply( # pylint: disable=too-many-locals,missing-docstring
target: Target,
_: bool,
) -> tir.Schedule:
if not isinstance(func, tir.PrimFunc) or not self.is_target_available(target):
return None
max_threads_per_block = utils.max_threads_per_block(target)

sch = tir.Schedule(func)
Expand Down
6 changes: 3 additions & 3 deletions python/tvm/dlight/gpu/gemv.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,13 @@

from ..base import (
BlockInfo,
ScheduleRule,
collect_vars_used_in_access_region,
detect_dominant_read,
is_broadcast_epilogue,
normalize_prim_func,
try_inline_contiguous_spatial,
)
from .base import GPUScheduleRule


def _get_reduction_expr(block: tir.Block) -> Optional[tir.PrimExpr]:
Expand Down Expand Up @@ -154,7 +154,7 @@ def normalize(
return is_inner_reduction


class GEMV(ScheduleRule):
class GEMV(GPUScheduleRule):
"""A rule for GEMV and DecodeGEMV."""

def apply( # pylint: disable=too-many-locals,too-many-branches,too-many-return-statements
Expand All @@ -163,7 +163,7 @@ def apply( # pylint: disable=too-many-locals,too-many-branches,too-many-return-
target: Target,
_: bool,
) -> Union[None, tir.Schedule, List[tir.Schedule]]:
if not isinstance(func, tir.PrimFunc):
if not isinstance(func, tir.PrimFunc) or not self.is_target_available(target):
return None
sch = tir.Schedule(func)
block_infos = normalize_prim_func(sch)
Expand Down
7 changes: 4 additions & 3 deletions python/tvm/dlight/gpu/general_reduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,11 @@
from tvm import tir
from tvm.target import Target

from ..base import ScheduleRule, normalize_prim_func, try_inline_contiguous_spatial
from ..base import normalize_prim_func, try_inline_contiguous_spatial
from .base import GPUScheduleRule


class GeneralReduction(ScheduleRule):
class GeneralReduction(GPUScheduleRule):
"""General Reduction rule for operators including softmax, layer norm, RMS norm, etc"""

def apply( # pylint: disable=too-many-locals
Expand All @@ -33,7 +34,7 @@ def apply( # pylint: disable=too-many-locals
target: Target,
_: bool,
) -> Union[None, tir.Schedule, List[tir.Schedule]]:
if not isinstance(func, tir.PrimFunc):
if not isinstance(func, tir.PrimFunc) or not self.is_target_available(target):
return None

if target.kind.name == "cuda":
Expand Down
15 changes: 11 additions & 4 deletions python/tvm/dlight/gpu/matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@
from tvm.tir.analysis import undefined_vars
from tvm.tir.schedule.schedule import BlockRV

from ..base import ScheduleRule, analysis
from ..base import analysis
from .base import GPUScheduleRule


def _collect_producers(sch: tir.Schedule, block: tir.schedule.BlockRV):
Expand Down Expand Up @@ -312,7 +313,7 @@ def check_sm_version(arch: str) -> int:
return int(sm_version) if sm_version.isdigit() else -1


class MatmulTensorization(ScheduleRule):
class MatmulTensorization(GPUScheduleRule):
"""
The schedule rule for float16 tensor core matmul computation.
func with attr 'dlight.do_not_tensorize' will not be tensorized.
Expand All @@ -328,6 +329,8 @@ def apply( # pylint: disable=too-many-locals,missing-docstring
get_wmma_intrin_group,
)

if not isinstance(func, tir.PrimFunc) or not self.is_target_available(target):
return None
sch = tir.Schedule(func)
root_block = analysis.get_root_block(sch)
blocks = sch.get_child_blocks(root_block)
Expand Down Expand Up @@ -531,7 +534,7 @@ def tensorize_init_store_compute():
return sch if tensorize_success else None


class MatmulInt8Tensorization(ScheduleRule):
class MatmulInt8Tensorization(GPUScheduleRule):
"""
The schedule rule for int8 tensor core matmul computation.
func with attr 'dlight.do_not_tensorize' will not be tensorized.
Expand All @@ -547,6 +550,8 @@ def apply( # pylint: disable=too-many-locals,missing-docstring
get_wmma_intrin_group,
)

if not isinstance(func, tir.PrimFunc) or not self.is_target_available(target):
return None
sch = tir.Schedule(func)
root_block = analysis.get_root_block(sch)
blocks = sch.get_child_blocks(root_block)
Expand Down Expand Up @@ -734,7 +739,7 @@ def tensorize_init_store_compute():
return sch


class Matmul(ScheduleRule):
class Matmul(GPUScheduleRule):
"""The schedule rule for matmul-like computation"""

@dataclass
Expand Down Expand Up @@ -793,6 +798,8 @@ def apply( # pylint: disable=too-many-locals,missing-docstring
target: Target,
_: bool,
) -> Optional[tir.Schedule]:
if not isinstance(func, tir.PrimFunc) or not self.is_target_available(target):
return None
sch = tir.Schedule(func)
root_block = analysis.get_root_block(sch)
blocks = sch.get_child_blocks(root_block)
Expand Down
6 changes: 3 additions & 3 deletions python/tvm/dlight/gpu/reduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,13 @@

from ..base import (
BlockInfo,
ScheduleRule,
normalize_prim_func,
try_inline_contiguous_spatial,
detect_dominant_read,
is_broadcast_epilogue,
)
from . import utils
from .base import GPUScheduleRule


def _get_reduction_expr(block: tir.Block) -> Optional[tir.PrimExpr]:
Expand All @@ -48,7 +48,7 @@ def _get_reduction_expr(block: tir.Block) -> Optional[tir.PrimExpr]:
return buffer_store.value.b


class Reduction(ScheduleRule):
class Reduction(GPUScheduleRule):
"""A rule for Reduction."""

def apply( # pylint: disable=too-many-locals,too-many-branches,too-many-return-statements
Expand All @@ -57,7 +57,7 @@ def apply( # pylint: disable=too-many-locals,too-many-branches,too-many-return-
target: Target,
_: bool,
) -> Union[None, tir.Schedule, List[tir.Schedule]]:
if not isinstance(func, tir.PrimFunc):
if not isinstance(func, tir.PrimFunc) or not self.is_target_available(target):
return None
sch = tir.Schedule(func)
block_infos = normalize_prim_func(sch)
Expand Down
11 changes: 6 additions & 5 deletions python/tvm/dlight/gpu/transpose.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,20 @@
"""Reduction rule for operators including softmax, layer norm, RMS norm, etc"""
from typing import List, Union

from tvm import tir, arith
from tvm import arith, tir
from tvm.target import Target
from tvm.tir import Schedule
from tvm.tir.schedule import BlockRV


from ..base import (
ScheduleRule,
detect_dominant_read,
normalize_prim_func,
try_inline_contiguous_spatial,
detect_dominant_read,
)
from .base import GPUScheduleRule


class Transpose(ScheduleRule):
class Transpose(GPUScheduleRule):
"""Schedule rule for transpose"""

def is_transpose(self, sch: Schedule, block_rv: BlockRV):
Expand All @@ -52,6 +51,8 @@ def apply( # pylint: disable=too-many-locals
_: bool,
) -> Union[None, tir.Schedule, List[tir.Schedule]]:
# pylint: disable=invalid-name
if not isinstance(func, tir.PrimFunc) or not self.is_target_available(target):
return None
if target.kind.name == "cuda":
len_tx = 16
len_ty = 8
Expand Down