Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 99 additions & 45 deletions ggml/src/ggml-vulkan/ggml-vulkan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ typedef struct VkPhysicalDeviceCooperativeMatrixDecodeVectorFeaturesNV {
#include <unordered_map>
#include <mutex>
#include <future>
#include <condition_variable>
#include <thread>

#if defined(_MSC_VER)
Expand Down Expand Up @@ -158,8 +159,9 @@ struct vk_pipeline_struct {
uint32_t align;
// true if fields have been set by ggml_vk_create_pipeline
bool initialized {};
// set to true to request the pipeline is compiled
std::atomic<bool> needed {};
// true while a compile is in flight, used to dedupe concurrent claims.
// Protected by device->compile_mutex.
bool compile_pending {};
// set to true when the shader has been compiled
std::atomic<bool> compiled {};
// number of registers used, extracted from pipeline executable properties
Expand Down Expand Up @@ -619,6 +621,13 @@ static constexpr std::initializer_list<std::array<int, 3>> rms_norm_mul_rope_vie
struct vk_device_struct {
std::recursive_mutex mutex;

// Guards compile_pending, all_pipelines, and the dynamic pipeline maps
// (flash_attn, fa_mask_opt, solve_tri, conv2d, etc). The actual compile
// runs with no lock held, so different pipelines can compile in parallel.
// Lock order is device->mutex -> compile_mutex, never the reverse.
std::mutex compile_mutex;
std::condition_variable compile_cv;

vk::PhysicalDevice physical_device;
vk::PhysicalDeviceProperties properties;
std::string name;
Expand Down Expand Up @@ -1726,7 +1735,7 @@ struct ggml_vk_garbage_collector {
};

static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx, vk_context subctx);
static void ggml_vk_load_shaders(vk_device& device);
static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested = nullptr);
Comment thread
0cc4m marked this conversation as resolved.
static void ggml_pipeline_allocate_descriptor_sets(ggml_backend_vk_context * ctx);

static bool vk_memory_logger_enabled = false;
Expand Down Expand Up @@ -2193,11 +2202,6 @@ static void ggml_vk_wait_for_fence(ggml_backend_vk_context * ctx) {
ctx->device->device.resetFences({ ctx->fence });
}

// variables to track number of compiles in progress
static uint32_t compile_count = 0;
static std::mutex compile_count_mutex;
static std::condition_variable compile_count_cond;

static constexpr uint32_t kSpvOpCooperativeMatrixLoadTensorNV = 5367;
static constexpr uint32_t kSpvCapabilityCooperativeMatrixDecodeVectorNV = 5447;
static constexpr uint32_t kSpvTensorAddressingDecodeVectorFuncBit = 0x4;
Expand Down Expand Up @@ -2492,7 +2496,6 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin
std::cerr << "ggml_vulkan: " << e.what() << std::endl;
throw e;
}
pipeline->compiled = true;

if (vk_instance.debug_utils_support) {
vk::DebugUtilsObjectNameInfoEXT duoni;
Expand Down Expand Up @@ -2541,14 +2544,13 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin
}
}

device->all_pipelines.push_back(pipeline);

{
std::lock_guard<std::mutex> guard(compile_count_mutex);
assert(compile_count > 0);
compile_count--;
std::lock_guard<std::mutex> guard(device->compile_mutex);
device->all_pipelines.push_back(pipeline);
pipeline->compiled = true;
pipeline->compile_pending = false;
}
compile_count_cond.notify_all();
device->compile_cv.notify_all();
}

static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline) {
Expand All @@ -2564,8 +2566,7 @@ static void ggml_pipeline_request_descriptor_sets(ggml_backend_vk_context *ctx,
VK_LOG_DEBUG("ggml_pipeline_request_descriptor_sets(" << pipeline->name << ", " << n << ")");
ctx->pipeline_descriptor_set_requirements += n;
if (!pipeline->compiled) {
pipeline->needed = true;
ggml_vk_load_shaders(ctx->device);
ggml_vk_load_shaders(ctx->device, pipeline);
}
ggml_pipeline_allocate_descriptor_sets(ctx);
}
Expand Down Expand Up @@ -3557,10 +3558,26 @@ static bool ggml_vk_fa_scalar_uses_mmq(const vk_device& device, ggml_type k_type
#endif
}

static void ggml_vk_load_shaders(vk_device& device) {
// load_shaders walks the pipeline list under compile_mutex and either claims
// the requested pipeline for compilation or, if another thread is already
// compiling it, drops the lock and waits on compile_cv. Compiles themselves
// run unlocked.
struct CompileTask {
vk_pipeline pipeline;
size_t spv_size;
const void * spv_data;
std::string entrypoint;
uint32_t parameter_count;
std::array<uint32_t, 3> wg_denoms;
std::vector<uint32_t> specialization_constants;
bool disable_robustness;
bool require_full_subgroups;
uint32_t required_subgroup_size;
};

static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
VK_LOG_DEBUG("ggml_vk_load_shaders(" << device->name << ")");

std::lock_guard<std::recursive_mutex> guard(device->mutex);
// some shaders have a minimum subgroup size
const uint32_t subgroup_size_8 = std::max(device->subgroup_size, 8u);
const uint32_t subgroup_size_16 = std::max(device->subgroup_size, 16u);
Expand Down Expand Up @@ -3590,6 +3607,15 @@ static void ggml_vk_load_shaders(vk_device& device) {
l_mmqid_wg_denoms, m_mmqid_wg_denoms, s_mmqid_wg_denoms;

uint32_t l_align, m_align, s_align;

vk_pipeline wait_pipeline;
CompileTask claimed_task {};
bool has_claimed_task = false;

// The rest of the walk reads and writes shared device state, so hold the
// lock until we're done deciding what to compile.
std::unique_lock<std::mutex> compile_lock(device->compile_mutex);

if (device->coopmat2) {
// spec constants and tile sizes for non-quant matmul/matmul_id
l_warptile = { 256, 128, 256, 64, 1 };
Expand Down Expand Up @@ -3775,7 +3801,6 @@ static void ggml_vk_load_shaders(vk_device& device) {
device->pipeline_matmul_id_bf16 = std::make_shared<vk_matmul_pipeline_struct>();
}

std::vector<std::future<void>> compiles;
auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& base_pipeline, const char *name, size_t spv_size, const void* spv_data, const char *entrypoint,
uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, const std::vector<uint32_t>& specialization_constants,
uint32_t align, bool disable_robustness = false, bool require_full_subgroups = false, uint32_t required_subgroup_size = 0) {
Expand Down Expand Up @@ -3809,23 +3834,33 @@ static void ggml_vk_load_shaders(vk_device& device) {
#endif
}

if (!pipeline->needed || pipeline->compiled) {
Comment thread
0cc4m marked this conversation as resolved.
// We only care about the pipeline this call asked for; the rest
// (including the 64-bit indexing variant) are handled by their
// own request_descriptor_sets / load_shaders calls.
if (pipeline.get() != requested.get()) {
continue;
}
// TODO: We're no longer benefitting from the async compiles (shaders are
// compiled individually, as needed) and this complexity can be removed.
{
// wait until fewer than N compiles are in progress
uint32_t N = std::max(1u, std::thread::hardware_concurrency());
std::unique_lock<std::mutex> guard(compile_count_mutex);
while (compile_count >= N) {
compile_count_cond.wait(guard);
}
compile_count++;

if (pipeline->compiled) {
continue;
}

compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), spv_size, spv_data, entrypoint,
parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size));
wait_pipeline = pipeline;

if (!pipeline->compile_pending) {
pipeline->compile_pending = true;
claimed_task.pipeline = pipeline;
claimed_task.spv_size = spv_size;
claimed_task.spv_data = spv_data;
claimed_task.entrypoint = entrypoint;
claimed_task.parameter_count = parameter_count;
claimed_task.wg_denoms = wg_denoms;
claimed_task.specialization_constants = specialization_constants;
claimed_task.disable_robustness = disable_robustness;
claimed_task.require_full_subgroups = require_full_subgroups;
claimed_task.required_subgroup_size = required_subgroup_size;
has_claimed_task = true;
}
}
};

Expand Down Expand Up @@ -5291,8 +5326,25 @@ static void ggml_vk_load_shaders(vk_device& device) {
}
}

for (auto &c : compiles) {
c.wait();
// Drop compile_mutex so other threads can walk while we compile.
compile_lock.unlock();

// Compile what we claimed; create_pipeline_func reacquires compile_mutex
// at the end to flip compile_pending/compiled and notify waiters.
if (has_claimed_task) {
auto & task = claimed_task;
ggml_vk_create_pipeline_func(device, task.pipeline, task.spv_size, task.spv_data,
task.entrypoint, task.parameter_count, task.wg_denoms,
task.specialization_constants, task.disable_robustness,
task.require_full_subgroups, task.required_subgroup_size);
}

// Another thread may be compiling the pipeline we need; block on it here.
if (wait_pipeline) {
std::unique_lock<std::mutex> wait_lock(device->compile_mutex);
device->compile_cv.wait(wait_lock, [&] {
return wait_pipeline->compiled.load();
});
}
}

Expand Down Expand Up @@ -9650,7 +9702,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
vk_pipeline pipeline = nullptr;

{
std::lock_guard<std::recursive_mutex> guard(ctx->device->mutex);
std::lock_guard<std::mutex> guard(ctx->device->compile_mutex);
auto &pipelines = ctx->device->pipeline_flash_attn_f32_f16;
auto it = pipelines.find(fa_pipeline_state);
if (it != pipelines.end()) {
Expand Down Expand Up @@ -9714,13 +9766,15 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx

vk_pipeline pipeline_fa_mask_opt = nullptr;
if (use_mask_opt) {
std::lock_guard<std::recursive_mutex> guard(ctx->device->mutex);
auto &pipelines = ctx->device->pipeline_fa_mask_opt;
auto it = pipelines.find({Br, Bc});
if (it != pipelines.end()) {
pipeline_fa_mask_opt = it->second;
} else {
pipelines[{Br, Bc}] = pipeline_fa_mask_opt = std::make_shared<vk_pipeline_struct>();
{
std::lock_guard<std::mutex> guard(ctx->device->compile_mutex);
auto &pipelines = ctx->device->pipeline_fa_mask_opt;
auto it = pipelines.find({Br, Bc});
if (it != pipelines.end()) {
pipeline_fa_mask_opt = it->second;
} else {
pipelines[{Br, Bc}] = pipeline_fa_mask_opt = std::make_shared<vk_pipeline_struct>();
}
}
assert(pipeline_fa_mask_opt);
ggml_pipeline_request_descriptor_sets(ctx, pipeline_fa_mask_opt, 1);
Expand Down Expand Up @@ -10254,7 +10308,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
vk_pipeline pipeline = nullptr;

{
std::lock_guard<std::recursive_mutex> guard(ctx->device->mutex);
std::lock_guard<std::mutex> guard(ctx->device->compile_mutex);
auto it = ctx->device->pipeline_solve_tri_f32.find(solve_tri_pipeline_state);
if (it != ctx->device->pipeline_solve_tri_f32.end()) {
pipeline = it->second;
Expand Down Expand Up @@ -10413,7 +10467,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
vk_pipeline pipeline = nullptr;

{
std::lock_guard<std::recursive_mutex> guard(ctx->device->mutex);
std::lock_guard<std::mutex> guard(ctx->device->compile_mutex);
auto it = pipelines->find(conv2d_pipeline_state);
if (it != pipelines->end()) {
pipeline = it->second;
Expand Down
Loading