2727from tvm .tir .analysis import undefined_vars
2828from tvm .tir .schedule .schedule import BlockRV
2929
30- from ..base import ScheduleRule , analysis
30+ from ..base import analysis
31+ from .base import GPUScheduleRule
3132
3233
3334def _collect_producers (sch : tir .Schedule , block : tir .schedule .BlockRV ):
@@ -312,7 +313,7 @@ def check_sm_version(arch: str) -> int:
312313 return int (sm_version ) if sm_version .isdigit () else - 1
313314
314315
315- class MatmulTensorization (ScheduleRule ):
316+ class MatmulTensorization (GPUScheduleRule ):
316317 """
317318 The schedule rule for float16 tensor core matmul computation.
318319 func with attr 'dlight.do_not_tensorize' will not be tensorized.
@@ -327,7 +328,8 @@ def apply( # pylint: disable=too-many-locals,missing-docstring
327328 from tvm .tir .tensor_intrin .cuda import ( # pylint: disable=import-outside-toplevel
328329 get_wmma_intrin_group ,
329330 )
330-
331+ if not isinstance (func , tir .PrimFunc ) or not self .is_target_available (target ):
332+ return None
331333 sch = tir .Schedule (func )
332334 root_block = analysis .get_root_block (sch )
333335 blocks = sch .get_child_blocks (root_block )
@@ -531,7 +533,7 @@ def tensorize_init_store_compute():
531533 return sch if tensorize_success else None
532534
533535
534- class MatmulInt8Tensorization (ScheduleRule ):
536+ class MatmulInt8Tensorization (GPUScheduleRule ):
535537 """
536538 The schedule rule for int8 tensor core matmul computation.
537539 func with attr 'dlight.do_not_tensorize' will not be tensorized.
@@ -546,7 +548,8 @@ def apply( # pylint: disable=too-many-locals,missing-docstring
546548 from tvm .tir .tensor_intrin .cuda import ( # pylint: disable=import-outside-toplevel
547549 get_wmma_intrin_group ,
548550 )
549-
551+ if not isinstance (func , tir .PrimFunc ) or not self .is_target_available (target ):
552+ return None
550553 sch = tir .Schedule (func )
551554 root_block = analysis .get_root_block (sch )
552555 blocks = sch .get_child_blocks (root_block )
@@ -734,7 +737,7 @@ def tensorize_init_store_compute():
734737 return sch
735738
736739
737- class Matmul (ScheduleRule ):
740+ class Matmul (GPUScheduleRule ):
738741 """The schedule rule for matmul-like computation"""
739742
740743 @dataclass
@@ -793,6 +796,8 @@ def apply( # pylint: disable=too-many-locals,missing-docstring
793796 target : Target ,
794797 _ : bool ,
795798 ) -> Optional [tir .Schedule ]:
799+ if not isinstance (func , tir .PrimFunc ) or not self .is_target_available (target ):
800+ return None
796801 sch = tir .Schedule (func )
797802 root_block = analysis .get_root_block (sch )
798803 blocks = sch .get_child_blocks (root_block )
0 commit comments