diff --git a/backends/vulkan/runtime/api/Adapter.cpp b/backends/vulkan/runtime/api/Adapter.cpp index b1930ad1de4..a02a6aa3e0a 100644 --- a/backends/vulkan/runtime/api/Adapter.cpp +++ b/backends/vulkan/runtime/api/Adapter.cpp @@ -6,6 +6,8 @@ * LICENSE file in the root directory of this source tree. */ +// @lint-ignore-every CLANGTIDY clang-diagnostic-missing-field-initializers + #include #include @@ -21,15 +23,33 @@ PhysicalDevice::PhysicalDevice(VkPhysicalDevice physical_device_handle) : handle(physical_device_handle), properties{}, memory_properties{}, + shader_16bit_storage{ + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_16BIT_STORAGE_FEATURES}, + shader_8bit_storage{ + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_8BIT_STORAGE_FEATURES}, + shader_float16_int8_types{ + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_FLOAT16_INT8_FEATURES_KHR}, queue_families{}, num_compute_queues(0), has_unified_memory(false), has_timestamps(properties.limits.timestampComputeAndGraphics), - timestamp_period(properties.limits.timestampPeriod) { + timestamp_period(properties.limits.timestampPeriod), + extension_features(&shader_16bit_storage) { // Extract physical device properties vkGetPhysicalDeviceProperties(handle, &properties); vkGetPhysicalDeviceMemoryProperties(handle, &memory_properties); + VkPhysicalDeviceFeatures2 features2{ + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2}; + + // Create linked list to query availability of extensions + features2.pNext = &shader_16bit_storage; + shader_16bit_storage.pNext = &shader_8bit_storage; + shader_8bit_storage.pNext = &shader_float16_int8_types; + shader_float16_int8_types.pNext = nullptr; + + vkGetPhysicalDeviceFeatures2(handle, &features2); + // Check if there are any memory types have both the HOST_VISIBLE and the // DEVICE_LOCAL property flags const VkMemoryPropertyFlags unified_memory_flags = @@ -140,6 +160,9 @@ VkDevice create_logical_device( #ifdef VK_KHR_portability_subset VK_KHR_PORTABILITY_SUBSET_EXTENSION_NAME, #endif /* VK_KHR_portability_subset */ + VK_KHR_16BIT_STORAGE_EXTENSION_NAME, + VK_KHR_8BIT_STORAGE_EXTENSION_NAME, + VK_KHR_SHADER_FLOAT16_INT8_EXTENSION_NAME, }; std::vector enabled_device_extensions; @@ -148,7 +171,7 @@ VkDevice create_logical_device( enabled_device_extensions, requested_device_extensions); - const VkDeviceCreateInfo device_create_info{ + VkDeviceCreateInfo device_create_info{ VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO, // sType nullptr, // pNext 0u, // flags @@ -162,6 +185,8 @@ VkDevice create_logical_device( nullptr, // pEnabledFeatures }; + device_create_info.pNext = physical_device.extension_features; + VkDevice handle = nullptr; VK_CHECK(vkCreateDevice( physical_device.handle, &device_create_info, nullptr, &handle)); @@ -371,33 +396,53 @@ std::string Adapter::stringize() const { ss << " deviceType: " << device_type << std::endl; ss << " deviceName: " << properties.deviceName << std::endl; -#define PRINT_LIMIT_PROP(name) \ - ss << " " << std::left << std::setw(36) << #name << limits.name \ +#define PRINT_PROP(struct, name) \ + ss << " " << std::left << std::setw(36) << #name << struct.name \ << std::endl; -#define PRINT_LIMIT_PROP_VEC3(name) \ - ss << " " << std::left << std::setw(36) << #name << limits.name[0] \ - << "," << limits.name[1] << "," << limits.name[2] << std::endl; +#define PRINT_PROP_VEC3(struct, name) \ + ss << " " << std::left << std::setw(36) << #name << struct.name[0] \ + << "," << struct.name[1] << "," << struct.name[2] << std::endl; ss << " Physical Device Limits {" << std::endl; - PRINT_LIMIT_PROP(maxImageDimension1D); - PRINT_LIMIT_PROP(maxImageDimension2D); - PRINT_LIMIT_PROP(maxImageDimension3D); - PRINT_LIMIT_PROP(maxTexelBufferElements); - PRINT_LIMIT_PROP(maxPushConstantsSize); - PRINT_LIMIT_PROP(maxMemoryAllocationCount); - PRINT_LIMIT_PROP(maxSamplerAllocationCount); - PRINT_LIMIT_PROP(maxComputeSharedMemorySize); - PRINT_LIMIT_PROP_VEC3(maxComputeWorkGroupCount); - PRINT_LIMIT_PROP(maxComputeWorkGroupInvocations); - PRINT_LIMIT_PROP_VEC3(maxComputeWorkGroupSize); + PRINT_PROP(limits, maxImageDimension1D); + PRINT_PROP(limits, maxImageDimension2D); + PRINT_PROP(limits, maxImageDimension3D); + PRINT_PROP(limits, maxTexelBufferElements); + PRINT_PROP(limits, maxPushConstantsSize); + PRINT_PROP(limits, maxMemoryAllocationCount); + PRINT_PROP(limits, maxSamplerAllocationCount); + PRINT_PROP(limits, maxComputeSharedMemorySize); + PRINT_PROP_VEC3(limits, maxComputeWorkGroupCount); + PRINT_PROP(limits, maxComputeWorkGroupInvocations); + PRINT_PROP_VEC3(limits, maxComputeWorkGroupSize); + ss << " }" << std::endl; + + ss << " 16bit Storage Features {" << std::endl; + PRINT_PROP(physical_device_.shader_16bit_storage, storageBuffer16BitAccess); + PRINT_PROP( + physical_device_.shader_16bit_storage, + uniformAndStorageBuffer16BitAccess); + PRINT_PROP(physical_device_.shader_16bit_storage, storagePushConstant16); + PRINT_PROP(physical_device_.shader_16bit_storage, storageInputOutput16); + ss << " }" << std::endl; + + ss << " 8bit Storage Features {" << std::endl; + PRINT_PROP(physical_device_.shader_8bit_storage, storageBuffer8BitAccess); + PRINT_PROP( + physical_device_.shader_8bit_storage, uniformAndStorageBuffer8BitAccess); + PRINT_PROP(physical_device_.shader_8bit_storage, storagePushConstant8); + ss << " }" << std::endl; + + ss << " Shader 16bit and 8bit Features {" << std::endl; + PRINT_PROP(physical_device_.shader_float16_int8_types, shaderFloat16); + PRINT_PROP(physical_device_.shader_float16_int8_types, shaderInt8); ss << " }" << std::endl; - ss << " }" << std::endl; - ; const VkPhysicalDeviceMemoryProperties& mem_props = physical_device_.memory_properties; + ss << " }" << std::endl; ss << " Memory Info {" << std::endl; ss << " Memory Types [" << std::endl; for (size_t i = 0; i < mem_props.memoryTypeCount; ++i) { @@ -432,6 +477,9 @@ std::string Adapter::stringize() const { ss << " ]" << std::endl; ss << "}"; +#undef PRINT_PROP +#undef PRINT_PROP_VEC3 + return ss.str(); } diff --git a/backends/vulkan/runtime/api/Adapter.h b/backends/vulkan/runtime/api/Adapter.h index afbb48f4059..b038aea9fa8 100644 --- a/backends/vulkan/runtime/api/Adapter.h +++ b/backends/vulkan/runtime/api/Adapter.h @@ -30,6 +30,12 @@ struct PhysicalDevice final { // Properties obtained from Vulkan VkPhysicalDeviceProperties properties; VkPhysicalDeviceMemoryProperties memory_properties; + // Additional features available from extensions + VkPhysicalDevice16BitStorageFeatures shader_16bit_storage; + VkPhysicalDevice8BitStorageFeatures shader_8bit_storage; + VkPhysicalDeviceShaderFloat16Int8Features shader_float16_int8_types; + + // Available GPU queues std::vector queue_families; // Metadata @@ -38,6 +44,9 @@ struct PhysicalDevice final { bool has_timestamps; float timestamp_period; + // Head of the linked list of extensions to be requested + void* extension_features{nullptr}; + explicit PhysicalDevice(VkPhysicalDevice); }; @@ -189,6 +198,34 @@ class Adapter final { return vma_; } + // Physical Device Features + + inline bool has_16bit_storage() { + return physical_device_.shader_16bit_storage.storageBuffer16BitAccess == + VK_TRUE; + } + + inline bool has_8bit_storage() { + return physical_device_.shader_8bit_storage.storageBuffer8BitAccess == + VK_TRUE; + } + + inline bool has_16bit_compute() { + return physical_device_.shader_float16_int8_types.shaderFloat16 == VK_TRUE; + } + + inline bool has_8bit_compute() { + return physical_device_.shader_float16_int8_types.shaderInt8 == VK_TRUE; + } + + inline bool has_full_float16_buffers_support() { + return has_16bit_storage() && has_16bit_compute(); + } + + inline bool has_full_int8_buffers_support() { + return has_8bit_storage() && has_8bit_compute(); + } + // Command Buffer Submission void diff --git a/backends/vulkan/runtime/api/Runtime.cpp b/backends/vulkan/runtime/api/Runtime.cpp index dee9f94fd3b..e113a4e3b4f 100644 --- a/backends/vulkan/runtime/api/Runtime.cpp +++ b/backends/vulkan/runtime/api/Runtime.cpp @@ -85,7 +85,7 @@ VkInstance create_instance(const RuntimeConfiguration& config) { 0, // applicationVersion nullptr, // pEngineName 0, // engineVersion - VK_API_VERSION_1_0, // apiVersion + VK_API_VERSION_1_1, // apiVersion }; std::vector enabled_layers; diff --git a/backends/vulkan/runtime/api/Tensor.cpp b/backends/vulkan/runtime/api/Tensor.cpp index 019c3ab736f..bffe00c836b 100644 --- a/backends/vulkan/runtime/api/Tensor.cpp +++ b/backends/vulkan/runtime/api/Tensor.cpp @@ -228,7 +228,14 @@ vTensor::vTensor( memory_layout_, gpu_sizes_, dtype_, - allocate_memory)) {} + allocate_memory)) { + if (dtype == api::kHalf) { + VK_CHECK_COND( + api::context()->adapter_ptr()->has_16bit_storage(), + "Half dtype is only available if the physical device supports float16 " + "storage buffers!"); + } +} vTensor::vTensor( api::Context* const context, diff --git a/backends/vulkan/runtime/api/Types.h b/backends/vulkan/runtime/api/Types.h index 3c1f9e2056b..c63f164aa8f 100644 --- a/backends/vulkan/runtime/api/Types.h +++ b/backends/vulkan/runtime/api/Types.h @@ -23,15 +23,15 @@ #define VK_FORMAT_FLOAT4 VK_FORMAT_R32G32B32A32_SFLOAT #endif /* USE_VULKAN_FP16_INFERENCE */ -#define VK_FORALL_SCALAR_TYPES(_) \ - _(uint8_t, VK_FORMAT_R8G8B8A8_UINT, Byte) \ - _(int8_t, VK_FORMAT_R8G8B8A8_SINT, Char) \ - _(int32_t, VK_FORMAT_R32G32B32A32_SINT, Int) \ - _(bool, VK_FORMAT_R8G8B8A8_SINT, Bool) \ - _(float, VK_FORMAT_R16G16B16A16_SFLOAT, Half) \ - _(float, VK_FORMAT_FLOAT4, Float) \ - _(int8_t, VK_FORMAT_R8G8B8A8_SINT, QInt8) \ - _(uint8_t, VK_FORMAT_R8G8B8A8_UINT, QUInt8) \ +#define VK_FORALL_SCALAR_TYPES(_) \ + _(uint8_t, VK_FORMAT_R8G8B8A8_UINT, Byte) \ + _(int8_t, VK_FORMAT_R8G8B8A8_SINT, Char) \ + _(int32_t, VK_FORMAT_R32G32B32A32_SINT, Int) \ + _(bool, VK_FORMAT_R8G8B8A8_SINT, Bool) \ + _(uint16_t, VK_FORMAT_R16G16B16A16_SFLOAT, Half) \ + _(float, VK_FORMAT_FLOAT4, Float) \ + _(int8_t, VK_FORMAT_R8G8B8A8_SINT, QInt8) \ + _(uint8_t, VK_FORMAT_R8G8B8A8_UINT, QUInt8) \ _(int32_t, VK_FORMAT_R32G32B32A32_SINT, QInt32) namespace vkcompute { diff --git a/backends/vulkan/runtime/api/gen_vulkan_spv.py b/backends/vulkan/runtime/api/gen_vulkan_spv.py index 89f6353944b..c7a5eda8b13 100644 --- a/backends/vulkan/runtime/api/gen_vulkan_spv.py +++ b/backends/vulkan/runtime/api/gen_vulkan_spv.py @@ -100,6 +100,22 @@ def get_buffer_scalar_type(dtype: str) -> str: return dtype +def get_buffer_gvec_type(dtype: str, n: int) -> str: + if n == 1: + return get_buffer_scalar_type(dtype) + + if dtype == "float": + return f"vec{n}" + elif dtype == "half": + return f"f16vec{n}" + elif dtype == "int8": + return f"i8vec{n}" + elif dtype == "uint8": + return f"u8vec{n}" + + raise AssertionError(f"Invalid dtype: {dtype}") + + def get_texel_type(dtype: str) -> str: image_format = TYPE_MAPPINGS["IMAGE_FORMAT"][dtype] if image_format[-1] == "f": @@ -134,6 +150,7 @@ def get_texel_component_type(dtype: str) -> str: 2: lambda pos: f"{pos}.xy", }, "buffer_scalar_type": get_buffer_scalar_type, + "buffer_gvec_type": get_buffer_gvec_type, "texel_type": get_texel_type, "gvec_type": get_gvec_type, "texel_component_type": get_texel_component_type, @@ -456,7 +473,7 @@ def generateSPV(self, output_dir: str) -> Dict[str, str]: glsl_out_path, "-o", spv_out_path, - "--target-env=vulkan1.0", + "--target-env=vulkan1.1", "-Werror", ] + [ arg diff --git a/backends/vulkan/test/glsl/all_shaders.yaml b/backends/vulkan/test/glsl/all_shaders.yaml index 09f051e04b5..46a1cd2bf5f 100644 --- a/backends/vulkan/test/glsl/all_shaders.yaml +++ b/backends/vulkan/test/glsl/all_shaders.yaml @@ -33,6 +33,17 @@ fill_texture__test: shader_variants: - NAME: fill_texture__test +idx_fill_buffer: + parameter_names_with_default_values: + DTYPE: float + generate_variant_forall: + DTYPE: + - VALUE: float + - VALUE: half + - VALUE: int8 + shader_variants: + - NAME: idx_fill_buffer + idx_fill_texture: parameter_names_with_default_values: DTYPE: float diff --git a/backends/vulkan/test/glsl/idx_fill_buffer.glsl b/backends/vulkan/test/glsl/idx_fill_buffer.glsl new file mode 100644 index 00000000000..98cf04e338d --- /dev/null +++ b/backends/vulkan/test/glsl/idx_fill_buffer.glsl @@ -0,0 +1,48 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +#define VEC4_T ${buffer_gvec_type(DTYPE, 4)} + +#include "indexing_utils.h" + +$if DTYPE == "half": + #extension GL_EXT_shader_16bit_storage : require + #extension GL_EXT_shader_explicit_arithmetic_types_float16 : require +$elif DTYPE == "int8": + #extension GL_EXT_shader_8bit_storage : require + #extension GL_EXT_shader_explicit_arithmetic_types_int8 : require +$elif DTYPE == "uint8": + #extension GL_EXT_shader_8bit_storage : require + #extension GL_EXT_shader_explicit_arithmetic_types_uint8 : require + +layout(std430) buffer; + +layout(set = 0, binding = 0) buffer PRECISION restrict writeonly Buffer { + VEC4_T data[]; +} +buffer_in; + +layout(set = 0, binding = 1) uniform PRECISION restrict Params { + int len; +} +params; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +void main() { + const int i = ivec3(gl_GlobalInvocationID).x; + + const int base = 4 * i; + if (base < params.len) { + buffer_in.data[i] = VEC4_T(base, base + 1, base + 2, base + 3); + } +} diff --git a/backends/vulkan/test/vulkan_compute_api_test.cpp b/backends/vulkan/test/vulkan_compute_api_test.cpp index bf9580d1bfc..ff0b03546fa 100644 --- a/backends/vulkan/test/vulkan_compute_api_test.cpp +++ b/backends/vulkan/test/vulkan_compute_api_test.cpp @@ -8,6 +8,8 @@ #include +#include + #include #include @@ -37,6 +39,10 @@ class VulkanComputeAPITest : public ::testing::Test { } }; +TEST_F(VulkanComputeAPITest, print_adapter) { + std::cout << *(api::context()->adapter_ptr()) << std::endl; +} + TEST_F(VulkanComputeAPITest, retrieve_custom_shader_test) { // Try to get shader from custom shader library const api::ShaderInfo& kernel = VK_KERNEL(test_shader); @@ -99,6 +105,72 @@ TEST_F(VulkanComputeAPITest, update_params_between_submit) { check_staging_buffer(staging_buffer, 4.0f); } +template +void test_storage_buffer_type(const size_t len) { + api::StorageBuffer buffer(api::context(), dtype, len); + + std::string kernel_name("idx_fill_buffer"); + switch (dtype) { + case api::kFloat: + kernel_name += "_float"; + break; + case api::kHalf: + kernel_name += "_half"; + break; + case api::kQInt8: + kernel_name += "_int8"; + break; + case api::kQUInt8: + kernel_name += "_uint8"; + break; + default: + throw std::runtime_error("Unsupported dtype"); + break; + } + + api::UniformParamsBuffer params(api::context(), int32_t(len)); + + { + uint32_t len_div4 = api::utils::div_up(uint32_t(len), uint32_t(4)); + api::PipelineBarrier pipeline_barrier{}; + api::context()->submit_compute_job( + VK_KERNEL_FROM_STR(kernel_name), + pipeline_barrier, + {64, 1, 1}, + {len_div4, 1, 1}, + VK_NULL_HANDLE, + buffer.buffer(), + params.buffer()); + } + + submit_to_gpu(); + + std::vector data(len); + copy_staging_to_ptr(buffer, data.data(), buffer.nbytes()); + + for (size_t i = 0; i < len; ++i) { + CHECK_VALUE(data, i, T(i)); + } +} + +TEST_F(VulkanComputeAPITest, test_buffer_float) { + test_storage_buffer_type(16); +} + +TEST_F(VulkanComputeAPITest, test_buffer_float16) { + if (!api::context()->adapter_ptr()->has_full_float16_buffers_support()) { + GTEST_SKIP(); + } + test_storage_buffer_type(16); +} + +TEST_F(VulkanComputeAPITest, test_buffer_int8) { + if (!api::context()->adapter_ptr()->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + test_storage_buffer_type(16); +} + TEST_F(VulkanComputeAPITest, texture_add_sanity_check) { // Simple test that performs a + b -> c