Skip to content

Commit 21566a4

Browse files
committed
[Transform] unsupported_dtype_legalize.cc - Only check cuda compute version for fp8 support on cuda target
1 parent 901fd0b commit 21566a4

File tree

1 file changed

+7
-5
lines changed

1 file changed

+7
-5
lines changed

src/tir/transforms/unsupported_dtype_legalize.cc

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -695,11 +695,13 @@ namespace transform {
695695

696696
bool CheckDataTypeSupport(const Target& target, const std::string& support_func_name) {
697697
bool has_native_support = false;
698-
if (const PackedFunc* get_cv =
699-
tvm::runtime::Registry::Get("tvm.contrib.nvcc.get_compute_version")) {
700-
std::string compute_version = (*get_cv)(target);
701-
if (const PackedFunc* check_support = tvm::runtime::Registry::Get(support_func_name)) {
702-
has_native_support = (*check_support)(compute_version);
698+
if (target->kind->name == "cuda") {
699+
if (const PackedFunc* get_cv =
700+
tvm::runtime::Registry::Get("tvm.contrib.nvcc.get_compute_version")) {
701+
std::string compute_version = (*get_cv)(target);
702+
if (const PackedFunc* check_support = tvm::runtime::Registry::Get(support_func_name)) {
703+
has_native_support = (*check_support)(compute_version);
704+
}
703705
}
704706
}
705707
return has_native_support;

0 commit comments

Comments
 (0)