diff --git a/backends/vulkan/runtime/api/Context.cpp b/backends/vulkan/runtime/api/Context.cpp index 5f2d2eb72c7..a51fd4ed526 100644 --- a/backends/vulkan/runtime/api/Context.cpp +++ b/backends/vulkan/runtime/api/Context.cpp @@ -59,17 +59,25 @@ Context::~Context() { DescriptorSet Context::get_descriptor_set( const ShaderInfo& shader_descriptor, - const utils::uvec3& local_workgroup_size) { + const utils::uvec3& local_workgroup_size, + const SpecVarList& additional_constants) { VkDescriptorSetLayout shader_layout = shader_layout_cache().retrieve(shader_descriptor.kernel_layout); VkPipelineLayout pipeline_layout = pipeline_layout_cache().retrieve(shader_layout); + SpecVarList spec_constants = { + SV(local_workgroup_size.data[0u]), + SV(local_workgroup_size.data[1u]), + SV(local_workgroup_size.data[2u])}; + + spec_constants.append(additional_constants); + VkPipeline pipeline = pipeline_cache().retrieve( {pipeline_layout_cache().retrieve(shader_layout), shader_cache().retrieve(shader_descriptor), - local_workgroup_size}); + spec_constants}); cmd_.bind_pipeline(pipeline, pipeline_layout, local_workgroup_size); diff --git a/backends/vulkan/runtime/api/Context.h b/backends/vulkan/runtime/api/Context.h index 0813d4190de..59ea2e2d88d 100644 --- a/backends/vulkan/runtime/api/Context.h +++ b/backends/vulkan/runtime/api/Context.h @@ -172,7 +172,16 @@ class Context final { } } - DescriptorSet get_descriptor_set(const ShaderInfo&, const utils::uvec3&); + DescriptorSet get_descriptor_set( + const ShaderInfo&, + const utils::uvec3&, + const SpecVarList&); + + inline DescriptorSet get_descriptor_set( + const ShaderInfo& shader_descriptor, + const utils::uvec3& local_work_group_size) { + return get_descriptor_set(shader_descriptor, local_work_group_size, {}); + } void register_shader_dispatch( const DescriptorSet&, diff --git a/backends/vulkan/runtime/api/Pipeline.cpp b/backends/vulkan/runtime/api/Pipeline.cpp index 7207814707c..f4be0039e67 100644 --- a/backends/vulkan/runtime/api/Pipeline.cpp +++ b/backends/vulkan/runtime/api/Pipeline.cpp @@ -98,6 +98,101 @@ VkImageLayout vk_layout( return VK_IMAGE_LAYOUT_UNDEFINED; } +// +// SpecVar +// + +SpecVar::SpecVar() : type(SpecVar::Type::INT) { + value.as_int32 = 0; +} + +SpecVar::SpecVar(const float val) : type(SpecVar::Type::FLOAT) { + value.as_float = val; +} + +SpecVar::SpecVar(const int32_t val) : type(SpecVar::Type::INT) { + value.as_int32 = val; +} + +SpecVar::SpecVar(const uint32_t val) : type(SpecVar::Type::UINT) { + value.as_uint32 = val; +} + +SpecVar::SpecVar(const bool val) : type(SpecVar::Type::BOOL) { + value.as_bool = val; +} + +uint32_t SpecVar::val_size() const { + switch (type) { + case SpecVar::Type::FLOAT: + return sizeof(float); + case SpecVar::Type::INT: + return sizeof(int32_t); + case SpecVar::Type::UINT: + return sizeof(uint32_t); + case SpecVar::Type::BOOL: + return sizeof(bool); + } + return 4; +} + +uint32_t SpecVar::val_offset() const { + return api::utils::safe_downcast(offsetof(SpecVar, value)); +} + +bool operator==(const SpecVar& lhs, const SpecVar& rhs) { + if (lhs.type != rhs.type) { + return false; + } + switch (lhs.type) { + case SpecVar::Type::FLOAT: + return lhs.value.as_float == rhs.value.as_float; + case SpecVar::Type::INT: + return lhs.value.as_int32 == rhs.value.as_int32; + case SpecVar::Type::UINT: + return lhs.value.as_uint32 == rhs.value.as_uint32; + case SpecVar::Type::BOOL: + return lhs.value.as_bool == rhs.value.as_bool; + } + return false; +} + +SpecVarList::SpecVarList() {} + +SpecVarList::SpecVarList(std::initializer_list init_list) { + vars.resize(init_list.size()); + std::copy(init_list.begin(), init_list.end(), vars.begin()); +} + +void SpecVarList::append(const SpecVarList& other) { + vars.insert(vars.end(), other.vars.begin(), other.vars.end()); +} + +std::vector SpecVarList::generate_map_entries() + const { + std::vector map_entries; + map_entries.resize(vars.size()); + uint32_t cur_offset = 0u; + for (uint32_t i = 0; i < vars.size(); ++i) { + map_entries.at(i) = { + i, cur_offset + vars.at(i).val_offset(), vars.at(i).val_size()}; + cur_offset += sizeof(SpecVar); + } + return map_entries; +} + +bool operator==(const SpecVarList& lhs, const SpecVarList& rhs) { + if (lhs.size() != rhs.size()) { + return false; + } + for (uint32_t i = 0; i < lhs.size(); ++i) { + if (lhs.vars.at(i) != rhs.vars.at(i)) { + return false; + } + } + return true; +} + // // PipelineLayout // @@ -154,33 +249,14 @@ ComputePipeline::ComputePipeline( const ComputePipeline::Descriptor& descriptor, VkPipelineCache pipeline_cache) : device_(device), handle_{VK_NULL_HANDLE} { - // NOLINTNEXTLINE - constexpr VkSpecializationMapEntry specialization_map_entries[3]{ - // X - { - 0u, - offsetof(utils::uvec3, data[0u]), - sizeof(utils::uvec3::data[0u]), - }, - // Y - { - 1u, - offsetof(utils::uvec3, data[1u]), - sizeof(utils::uvec3::data[1u]), - }, - // Z - { - 2u, - offsetof(utils::uvec3, data[2u]), - sizeof(utils::uvec3::data[2u]), - }, - }; + std::vector map_entries = + descriptor.specialization_constants.generate_map_entries(); const VkSpecializationInfo specialization_info{ - 3u, // mapEntryCount - specialization_map_entries, // pMapEntries - sizeof(descriptor.local_work_group), // dataSize - &descriptor.local_work_group, // pData + descriptor.specialization_constants.size(), // mapEntryCount + map_entries.data(), // pMapEntries + descriptor.specialization_constants.data_nbytes(), // dataSize + descriptor.specialization_constants.data(), // pData }; const VkPipelineShaderStageCreateInfo shader_stage_create_info{ @@ -242,7 +318,7 @@ bool operator==( return ( _1.pipeline_layout == _2.pipeline_layout && _1.shader_module == _2.shader_module && - _1.local_work_group == _2.local_work_group); + _1.specialization_constants == _2.specialization_constants); } // diff --git a/backends/vulkan/runtime/api/Pipeline.h b/backends/vulkan/runtime/api/Pipeline.h index 409fd2afa87..b8c16efd910 100644 --- a/backends/vulkan/runtime/api/Pipeline.h +++ b/backends/vulkan/runtime/api/Pipeline.h @@ -18,9 +18,73 @@ #include #include +#define SV(x) ::vkcompute::api::SpecVar(x) + namespace vkcompute { namespace api { +struct SpecVar final { + enum class Type : uint8_t { + FLOAT, + INT, + UINT, + BOOL, + }; + + union Value { + int32_t as_int32; + uint32_t as_uint32; + float as_float; + bool as_bool; + }; + + Value value; + Type type; + + SpecVar(); + SpecVar(const float val); + SpecVar(const int32_t val); + SpecVar(const uint32_t val); + SpecVar(const bool val); + + uint32_t val_size() const; + uint32_t val_offset() const; +}; + +bool operator==(const SpecVar& lhs, const SpecVar& rhs); + +class SpecVarList final { + std::vector vars; + + public: + SpecVarList(); + SpecVarList(std::initializer_list init_list); + + inline const SpecVar& at(const size_t index) const { + return vars.at(index); + } + + inline const SpecVar* data() const { + return vars.data(); + } + + inline uint32_t size() const { + return api::utils::safe_downcast(vars.size()); + } + + inline uint32_t data_nbytes() const { + return vars.size() * sizeof(SpecVar); + } + + void append(const SpecVarList& other); + + std::vector generate_map_entries() const; + + friend bool operator==(const SpecVarList& lhs, const SpecVarList& rhs); +}; + +bool operator==(const SpecVarList& lhs, const SpecVarList& rhs); + struct PipelineBarrier final { struct Stages final { VkPipelineStageFlags src; @@ -83,7 +147,7 @@ class ComputePipeline final { struct Descriptor final { VkPipelineLayout pipeline_layout; VkShaderModule shader_module; - utils::uvec3 local_work_group; + SpecVarList specialization_constants; }; explicit ComputePipeline( @@ -171,12 +235,29 @@ class ComputePipelineCache final { seed, std::hash()(descriptor.pipeline_layout)); seed = utils::hash_combine( seed, std::hash()(descriptor.shader_module)); - seed = utils::hash_combine( - seed, std::hash()(descriptor.local_work_group.data[0u])); - seed = utils::hash_combine( - seed, std::hash()(descriptor.local_work_group.data[1u])); - seed = utils::hash_combine( - seed, std::hash()(descriptor.local_work_group.data[2u])); + + const SpecVarList& spec_vars = descriptor.specialization_constants; + seed = utils::hash_combine(seed, std::hash()(spec_vars.size())); + + for (int i = 0; i < spec_vars.size(); ++i) { + const SpecVar& spec_var = spec_vars.at(i); + size_t new_seed = 0; + switch (spec_var.type) { + case SpecVar::Type::FLOAT: + new_seed = std::hash()(spec_var.value.as_float); + break; + case SpecVar::Type::INT: + new_seed = std::hash()(spec_var.value.as_int32); + break; + case SpecVar::Type::UINT: + new_seed = std::hash()(spec_var.value.as_uint32); + break; + case SpecVar::Type::BOOL: + new_seed = std::hash()(spec_var.value.as_bool); + break; + } + seed = utils::hash_combine(seed, new_seed); + } return seed; } diff --git a/backends/vulkan/test/vulkan_compute_api_test.cpp b/backends/vulkan/test/vulkan_compute_api_test.cpp index 271977c1450..64712f381ca 100644 --- a/backends/vulkan/test/vulkan_compute_api_test.cpp +++ b/backends/vulkan/test/vulkan_compute_api_test.cpp @@ -50,6 +50,53 @@ TEST_F(VulkanComputeAPITest, retrieve_custom_shader_test) { ASSERT_TRUE(kernel.kernel_name == "test_shader"); } +TEST_F(VulkanComputeAPITest, spec_var_classes_test) { + // Check equality operator + ASSERT_TRUE(SV(1.5f) == SV(1.5f)); + ASSERT_FALSE(SV(15.0f) == SV(15)); + ASSERT_FALSE(SV(1u) == SV(true)); + + size_t sv_size = sizeof(api::SpecVar); + + api::SpecVarList spec_vars = {}; + ASSERT_TRUE(spec_vars.size() == 0); + spec_vars = {SV(1.1f), SV(32), SV(45)}; + ASSERT_TRUE(spec_vars.size() == 3); + api::SpecVarList spec_vars_other = {SV(2.6f), SV(true), SV(78u), SV(5.5f)}; + spec_vars.append(spec_vars_other); + ASSERT_TRUE(spec_vars.size() == 7); + + // Check validity of the data + const api::SpecVar* data = spec_vars.data(); + ASSERT_TRUE(*(reinterpret_cast(data + 3)) == 2.6f); + ASSERT_TRUE(*(reinterpret_cast(data + 1)) == 32); + ASSERT_TRUE(*(reinterpret_cast(data + 5)) == 78u); + + // Check validity of the map entries + std::vector entries = + spec_vars.generate_map_entries(); + + for (size_t i = 0; i < spec_vars.size(); ++i) { + ASSERT_TRUE(entries[i].constantID == i); + ASSERT_TRUE(entries[i].offset == sv_size * i); + if (i != 4) { + ASSERT_TRUE(entries[i].size == 4); + } else { + ASSERT_TRUE(entries[i].size == 1); + } + } + + // Check copy + api::SpecVarList spec_vars_copy(spec_vars); + ASSERT_TRUE(spec_vars_copy.size() == 7); + + // Check validity of the copied data + const api::SpecVar* copy_data = spec_vars_copy.data(); + ASSERT_TRUE(*(reinterpret_cast(copy_data + 4)) == true); + ASSERT_TRUE(*(reinterpret_cast(copy_data + 2)) == 45); + ASSERT_TRUE(*(reinterpret_cast(copy_data + 6)) == 5.5f); +} + TEST_F(VulkanComputeAPITest, update_params_between_submit) { api::context()->set_cmd(/*reusable = */ true); std::vector sizes = {4, 4, 2};