From 421add5000dadce26e1f57f6341ff5dea4243f62 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 25 May 2021 09:58:10 -0700 Subject: [PATCH] [Vulkan][Codegen] Read vulkan device capabilities/limits from Target - Previously, the codegen assumed that all device features were present. Now, the codegen reads device capabilities from the Target, and throws an error if codegen would require use of an unsupported feature. --- src/target/spirv/spirv_support.cc | 50 ++++++++++++++++++++++++------- 1 file changed, 39 insertions(+), 11 deletions(-) diff --git a/src/target/spirv/spirv_support.cc b/src/target/spirv/spirv_support.cc index ff9aee406574c..e06bde08895d9 100644 --- a/src/target/spirv/spirv_support.cc +++ b/src/target/spirv/spirv_support.cc @@ -35,17 +35,45 @@ SPIRVSupport::SPIRVSupport(tvm::Target target) { ICHECK_EQ(target->kind->device_type, kDLVulkan) << "SPIRVSupport can only be checked for vulkan device type"; - // Currently, this codifies the assumptions that were present and - // implicit in previous implementations. In the future, this will - // pull information from the specified `Target`. - - supports_storage_buffer_storage_class = (SPV_VERSION >= 0x10300); - supports_storage_buffer_8bit_access = true; - supports_storage_buffer_16bit_access = true; - supports_float16 = true; - supports_int8 = true; - supports_int16 = true; - supports_int64 = true; + if (target->GetAttr("supported_subgroup_operations")) { + supported_subgroup_operations = + target->GetAttr("supported_subgroup_operations").value(); + } + if (target->GetAttr("max_push_constants_size")) { + max_push_constants_size = target->GetAttr("max_push_constants_size").value(); + } + if (target->GetAttr("max_uniform_buffer_range")) { + max_uniform_buffer_range = target->GetAttr("max_uniform_buffer_range").value(); + } + if (target->GetAttr("max_storage_buffer_range")) { + max_storage_buffer_range = target->GetAttr("max_storage_buffer_range").value(); + } + if (target->GetAttr("max_per_stage_descriptor_storage_buffer")) { + max_per_stage_descriptor_storage_buffers = + target->GetAttr("max_per_stage_descriptor_storage_buffer").value(); + } + if (target->GetAttr("supports_storage_buffer_storage_class")) { + supports_storage_buffer_storage_class = + target->GetAttr("supports_storage_buffer_storage_class").value(); + } + if (target->GetAttr("supports_8bit_buffer")) { + supports_storage_buffer_8bit_access = target->GetAttr("supports_8bit_buffer").value(); + } + if (target->GetAttr("supports_16bit_buffer")) { + supports_storage_buffer_16bit_access = target->GetAttr("supports_16bit_buffer").value(); + } + if (target->GetAttr("supports_float16")) { + supports_float16 = target->GetAttr("supports_float16").value(); + } + if (target->GetAttr("supports_int8")) { + supports_int8 = target->GetAttr("supports_int8").value(); + } + if (target->GetAttr("supports_int16")) { + supports_int16 = target->GetAttr("supports_int16").value(); + } + if (target->GetAttr("supports_int64")) { + supports_int64 = target->GetAttr("supports_int64").value(); + } } } // namespace codegen