2727from tvm .tir .analysis import undefined_vars
2828from tvm .tir .schedule .schedule import BlockRV
2929
30- from ..base import ScheduleRule , analysis
30+ from ..base import analysis
31+ from .base import GPUScheduleRule
3132
3233
3334def _collect_producers (sch : tir .Schedule , block : tir .schedule .BlockRV ):
@@ -312,7 +313,7 @@ def check_sm_version(arch: str) -> int:
312313 return int (sm_version ) if sm_version .isdigit () else - 1
313314
314315
315- class MatmulTensorization (ScheduleRule ):
316+ class MatmulTensorization (GPUScheduleRule ):
316317 """
317318 The schedule rule for float16 tensor core matmul computation.
318319 func with attr 'dlight.do_not_tensorize' will not be tensorized.
@@ -328,6 +329,8 @@ def apply( # pylint: disable=too-many-locals,missing-docstring
328329 get_wmma_intrin_group ,
329330 )
330331
332+ if not isinstance (func , tir .PrimFunc ) or not self .is_target_available (target ):
333+ return None
331334 sch = tir .Schedule (func )
332335 root_block = analysis .get_root_block (sch )
333336 blocks = sch .get_child_blocks (root_block )
@@ -531,7 +534,7 @@ def tensorize_init_store_compute():
531534 return sch if tensorize_success else None
532535
533536
534- class MatmulInt8Tensorization (ScheduleRule ):
537+ class MatmulInt8Tensorization (GPUScheduleRule ):
535538 """
536539 The schedule rule for int8 tensor core matmul computation.
537540 func with attr 'dlight.do_not_tensorize' will not be tensorized.
@@ -547,6 +550,8 @@ def apply( # pylint: disable=too-many-locals,missing-docstring
547550 get_wmma_intrin_group ,
548551 )
549552
553+ if not isinstance (func , tir .PrimFunc ) or not self .is_target_available (target ):
554+ return None
550555 sch = tir .Schedule (func )
551556 root_block = analysis .get_root_block (sch )
552557 blocks = sch .get_child_blocks (root_block )
@@ -734,7 +739,7 @@ def tensorize_init_store_compute():
734739 return sch
735740
736741
737- class Matmul (ScheduleRule ):
742+ class Matmul (GPUScheduleRule ):
738743 """The schedule rule for matmul-like computation"""
739744
740745 @dataclass
@@ -793,6 +798,8 @@ def apply( # pylint: disable=too-many-locals,missing-docstring
793798 target : Target ,
794799 _ : bool ,
795800 ) -> Optional [tir .Schedule ]:
801+ if not isinstance (func , tir .PrimFunc ) or not self .is_target_available (target ):
802+ return None
796803 sch = tir .Schedule (func )
797804 root_block = analysis .get_root_block (sch )
798805 blocks = sch .get_child_blocks (root_block )
0 commit comments