@@ -412,7 +412,7 @@ def _sch_rules(sch_rules: Optional[FnScheduleRule], target: Target) -> List[Sche
412412 # pylint: disable=protected-access
413413 if target .kind .name == "llvm" :
414414 return DefaultLLVM ._sch_rules ()
415- if target .kind .name == "cuda" :
415+ if target .kind .name in [ "cuda" , "rocm" , "vulkan" ] :
416416 return DefaultCUDA ._sch_rules ()
417417 # pylint: enable=protected-access
418418 raise ValueError (f"Unsupported target: { target } " )
@@ -426,7 +426,7 @@ def _postproc(postproc: Optional[FnPostproc], target: Target) -> List[Postproc]:
426426 # pylint: disable=protected-access
427427 if target .kind .name == "llvm" :
428428 return DefaultLLVM ._postproc ()
429- if target .kind .name == "cuda" :
429+ if target .kind .name in [ "cuda" , "rocm" , "vulkan" ] :
430430 return DefaultCUDA ._postproc ()
431431 # pylint: enable=protected-access
432432 raise ValueError (f"Unsupported target: { target } " )
@@ -445,7 +445,7 @@ def _mutator_probs(
445445 # pylint: disable=protected-access
446446 if target .kind .name == "llvm" :
447447 return DefaultLLVM ._mutator_probs ()
448- if target .kind .name == "cuda" :
448+ if target .kind .name in [ "cuda" , "rocm" , "vulkan" ] :
449449 return DefaultCUDA ._mutator_probs ()
450450 # pylint: enable=protected-access
451451 raise ValueError (f"Unsupported target: { target } " )
0 commit comments