diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index c9f906d7930..0f495caf94f 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -64,6 +64,7 @@ typedef struct VkPhysicalDeviceCooperativeMatrixDecodeVectorFeaturesNV { #include #include #include +#include #include #if defined(_MSC_VER) @@ -158,8 +159,9 @@ struct vk_pipeline_struct { uint32_t align; // true if fields have been set by ggml_vk_create_pipeline bool initialized {}; - // set to true to request the pipeline is compiled - std::atomic needed {}; + // true while a compile is in flight, used to dedupe concurrent claims. + // Protected by device->compile_mutex. + bool compile_pending {}; // set to true when the shader has been compiled std::atomic compiled {}; // number of registers used, extracted from pipeline executable properties @@ -619,6 +621,13 @@ static constexpr std::initializer_list> rms_norm_mul_rope_vie struct vk_device_struct { std::recursive_mutex mutex; + // Guards compile_pending, all_pipelines, and the dynamic pipeline maps + // (flash_attn, fa_mask_opt, solve_tri, conv2d, etc). The actual compile + // runs with no lock held, so different pipelines can compile in parallel. + // Lock order is device->mutex -> compile_mutex, never the reverse. + std::mutex compile_mutex; + std::condition_variable compile_cv; + vk::PhysicalDevice physical_device; vk::PhysicalDeviceProperties properties; std::string name; @@ -1726,7 +1735,7 @@ struct ggml_vk_garbage_collector { }; static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx, vk_context subctx); -static void ggml_vk_load_shaders(vk_device& device); +static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested = nullptr); static void ggml_pipeline_allocate_descriptor_sets(ggml_backend_vk_context * ctx); static bool vk_memory_logger_enabled = false; @@ -2193,11 +2202,6 @@ static void ggml_vk_wait_for_fence(ggml_backend_vk_context * ctx) { ctx->device->device.resetFences({ ctx->fence }); } -// variables to track number of compiles in progress -static uint32_t compile_count = 0; -static std::mutex compile_count_mutex; -static std::condition_variable compile_count_cond; - static constexpr uint32_t kSpvOpCooperativeMatrixLoadTensorNV = 5367; static constexpr uint32_t kSpvCapabilityCooperativeMatrixDecodeVectorNV = 5447; static constexpr uint32_t kSpvTensorAddressingDecodeVectorFuncBit = 0x4; @@ -2492,7 +2496,6 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin std::cerr << "ggml_vulkan: " << e.what() << std::endl; throw e; } - pipeline->compiled = true; if (vk_instance.debug_utils_support) { vk::DebugUtilsObjectNameInfoEXT duoni; @@ -2541,14 +2544,13 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin } } - device->all_pipelines.push_back(pipeline); - { - std::lock_guard guard(compile_count_mutex); - assert(compile_count > 0); - compile_count--; + std::lock_guard guard(device->compile_mutex); + device->all_pipelines.push_back(pipeline); + pipeline->compiled = true; + pipeline->compile_pending = false; } - compile_count_cond.notify_all(); + device->compile_cv.notify_all(); } static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline) { @@ -2564,8 +2566,7 @@ static void ggml_pipeline_request_descriptor_sets(ggml_backend_vk_context *ctx, VK_LOG_DEBUG("ggml_pipeline_request_descriptor_sets(" << pipeline->name << ", " << n << ")"); ctx->pipeline_descriptor_set_requirements += n; if (!pipeline->compiled) { - pipeline->needed = true; - ggml_vk_load_shaders(ctx->device); + ggml_vk_load_shaders(ctx->device, pipeline); } ggml_pipeline_allocate_descriptor_sets(ctx); } @@ -3557,10 +3558,26 @@ static bool ggml_vk_fa_scalar_uses_mmq(const vk_device& device, ggml_type k_type #endif } -static void ggml_vk_load_shaders(vk_device& device) { +// load_shaders walks the pipeline list under compile_mutex and either claims +// the requested pipeline for compilation or, if another thread is already +// compiling it, drops the lock and waits on compile_cv. Compiles themselves +// run unlocked. +struct CompileTask { + vk_pipeline pipeline; + size_t spv_size; + const void * spv_data; + std::string entrypoint; + uint32_t parameter_count; + std::array wg_denoms; + std::vector specialization_constants; + bool disable_robustness; + bool require_full_subgroups; + uint32_t required_subgroup_size; +}; + +static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) { VK_LOG_DEBUG("ggml_vk_load_shaders(" << device->name << ")"); - std::lock_guard guard(device->mutex); // some shaders have a minimum subgroup size const uint32_t subgroup_size_8 = std::max(device->subgroup_size, 8u); const uint32_t subgroup_size_16 = std::max(device->subgroup_size, 16u); @@ -3590,6 +3607,15 @@ static void ggml_vk_load_shaders(vk_device& device) { l_mmqid_wg_denoms, m_mmqid_wg_denoms, s_mmqid_wg_denoms; uint32_t l_align, m_align, s_align; + + vk_pipeline wait_pipeline; + CompileTask claimed_task {}; + bool has_claimed_task = false; + + // The rest of the walk reads and writes shared device state, so hold the + // lock until we're done deciding what to compile. + std::unique_lock compile_lock(device->compile_mutex); + if (device->coopmat2) { // spec constants and tile sizes for non-quant matmul/matmul_id l_warptile = { 256, 128, 256, 64, 1 }; @@ -3775,7 +3801,6 @@ static void ggml_vk_load_shaders(vk_device& device) { device->pipeline_matmul_id_bf16 = std::make_shared(); } - std::vector> compiles; auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& base_pipeline, const char *name, size_t spv_size, const void* spv_data, const char *entrypoint, uint32_t parameter_count, uint32_t push_constant_size, std::array wg_denoms, const std::vector& specialization_constants, uint32_t align, bool disable_robustness = false, bool require_full_subgroups = false, uint32_t required_subgroup_size = 0) { @@ -3809,23 +3834,33 @@ static void ggml_vk_load_shaders(vk_device& device) { #endif } - if (!pipeline->needed || pipeline->compiled) { + // We only care about the pipeline this call asked for; the rest + // (including the 64-bit indexing variant) are handled by their + // own request_descriptor_sets / load_shaders calls. + if (pipeline.get() != requested.get()) { continue; } - // TODO: We're no longer benefitting from the async compiles (shaders are - // compiled individually, as needed) and this complexity can be removed. - { - // wait until fewer than N compiles are in progress - uint32_t N = std::max(1u, std::thread::hardware_concurrency()); - std::unique_lock guard(compile_count_mutex); - while (compile_count >= N) { - compile_count_cond.wait(guard); - } - compile_count++; + + if (pipeline->compiled) { + continue; } - compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), spv_size, spv_data, entrypoint, - parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size)); + wait_pipeline = pipeline; + + if (!pipeline->compile_pending) { + pipeline->compile_pending = true; + claimed_task.pipeline = pipeline; + claimed_task.spv_size = spv_size; + claimed_task.spv_data = spv_data; + claimed_task.entrypoint = entrypoint; + claimed_task.parameter_count = parameter_count; + claimed_task.wg_denoms = wg_denoms; + claimed_task.specialization_constants = specialization_constants; + claimed_task.disable_robustness = disable_robustness; + claimed_task.require_full_subgroups = require_full_subgroups; + claimed_task.required_subgroup_size = required_subgroup_size; + has_claimed_task = true; + } } }; @@ -5291,8 +5326,25 @@ static void ggml_vk_load_shaders(vk_device& device) { } } - for (auto &c : compiles) { - c.wait(); + // Drop compile_mutex so other threads can walk while we compile. + compile_lock.unlock(); + + // Compile what we claimed; create_pipeline_func reacquires compile_mutex + // at the end to flip compile_pending/compiled and notify waiters. + if (has_claimed_task) { + auto & task = claimed_task; + ggml_vk_create_pipeline_func(device, task.pipeline, task.spv_size, task.spv_data, + task.entrypoint, task.parameter_count, task.wg_denoms, + task.specialization_constants, task.disable_robustness, + task.require_full_subgroups, task.required_subgroup_size); + } + + // Another thread may be compiling the pipeline we need; block on it here. + if (wait_pipeline) { + std::unique_lock wait_lock(device->compile_mutex); + device->compile_cv.wait(wait_lock, [&] { + return wait_pipeline->compiled.load(); + }); } } @@ -9650,7 +9702,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx vk_pipeline pipeline = nullptr; { - std::lock_guard guard(ctx->device->mutex); + std::lock_guard guard(ctx->device->compile_mutex); auto &pipelines = ctx->device->pipeline_flash_attn_f32_f16; auto it = pipelines.find(fa_pipeline_state); if (it != pipelines.end()) { @@ -9714,13 +9766,15 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx vk_pipeline pipeline_fa_mask_opt = nullptr; if (use_mask_opt) { - std::lock_guard guard(ctx->device->mutex); - auto &pipelines = ctx->device->pipeline_fa_mask_opt; - auto it = pipelines.find({Br, Bc}); - if (it != pipelines.end()) { - pipeline_fa_mask_opt = it->second; - } else { - pipelines[{Br, Bc}] = pipeline_fa_mask_opt = std::make_shared(); + { + std::lock_guard guard(ctx->device->compile_mutex); + auto &pipelines = ctx->device->pipeline_fa_mask_opt; + auto it = pipelines.find({Br, Bc}); + if (it != pipelines.end()) { + pipeline_fa_mask_opt = it->second; + } else { + pipelines[{Br, Bc}] = pipeline_fa_mask_opt = std::make_shared(); + } } assert(pipeline_fa_mask_opt); ggml_pipeline_request_descriptor_sets(ctx, pipeline_fa_mask_opt, 1); @@ -10254,7 +10308,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const vk_pipeline pipeline = nullptr; { - std::lock_guard guard(ctx->device->mutex); + std::lock_guard guard(ctx->device->compile_mutex); auto it = ctx->device->pipeline_solve_tri_f32.find(solve_tri_pipeline_state); if (it != ctx->device->pipeline_solve_tri_f32.end()) { pipeline = it->second; @@ -10413,7 +10467,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const vk_pipeline pipeline = nullptr; { - std::lock_guard guard(ctx->device->mutex); + std::lock_guard guard(ctx->device->compile_mutex); auto it = pipelines->find(conv2d_pipeline_state); if (it != pipelines->end()) { pipeline = it->second;