Skip to content

Commit daea033

Browse files
committed
[Metaschedule] Support rocm and spirv
1 parent eb0cae2 commit daea033

File tree

2 files changed

+6
-3
lines changed

2 files changed

+6
-3
lines changed

python/tvm/meta_schedule/tune.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -412,7 +412,7 @@ def _sch_rules(sch_rules: Optional[FnScheduleRule], target: Target) -> List[Sche
412412
# pylint: disable=protected-access
413413
if target.kind.name == "llvm":
414414
return DefaultLLVM._sch_rules()
415-
if target.kind.name == "cuda":
415+
if target.kind.name in ["cuda", "rocm", "vulkan"]:
416416
return DefaultCUDA._sch_rules()
417417
# pylint: enable=protected-access
418418
raise ValueError(f"Unsupported target: {target}")
@@ -426,7 +426,7 @@ def _postproc(postproc: Optional[FnPostproc], target: Target) -> List[Postproc]:
426426
# pylint: disable=protected-access
427427
if target.kind.name == "llvm":
428428
return DefaultLLVM._postproc()
429-
if target.kind.name == "cuda":
429+
if target.kind.name in ["cuda", "rocm", "vulkan"]:
430430
return DefaultCUDA._postproc()
431431
# pylint: enable=protected-access
432432
raise ValueError(f"Unsupported target: {target}")
@@ -445,7 +445,7 @@ def _mutator_probs(
445445
# pylint: disable=protected-access
446446
if target.kind.name == "llvm":
447447
return DefaultLLVM._mutator_probs()
448-
if target.kind.name == "cuda":
448+
if target.kind.name in ["cuda", "rocm", "vulkan"]:
449449
return DefaultCUDA._mutator_probs()
450450
# pylint: enable=protected-access
451451
raise ValueError(f"Unsupported target: {target}")

src/target/target_kind.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,9 @@ TVM_REGISTER_TARGET_KIND("rocm", kDLROCM)
308308
.add_attr_option<String>("mtriple")
309309
.add_attr_option<Bool>("system-lib")
310310
.add_attr_option<Integer>("max_num_threads", Integer(256))
311+
.add_attr_option<Integer>("max_threads_per_block", Integer(256))
311312
.add_attr_option<Integer>("thread_warp_size", Integer(64))
313+
.add_attr_option<Integer>("max_shared_memory_per_block", Integer(64000))
312314
.set_default_keys({"rocm", "gpu"})
313315
.set_attrs_preprocessor(UpdateROCmAttrs);
314316

@@ -349,6 +351,7 @@ TVM_REGISTER_TARGET_KIND("vulkan", kDLVulkan)
349351
.add_attr_option<Integer>("supported_subgroup_operations")
350352
// Physical device limits
351353
.add_attr_option<Integer>("max_num_threads", Integer(256))
354+
.add_attr_option<Integer>("max_threads_per_block", Integer(256))
352355
.add_attr_option<Integer>("thread_warp_size", Integer(1))
353356
.add_attr_option<Integer>("max_block_size_x")
354357
.add_attr_option<Integer>("max_block_size_y")

0 commit comments

Comments
 (0)