Skip to content
Merged
128 changes: 95 additions & 33 deletions ggml/src/ggml-vulkan/ggml-vulkan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ struct vk_queue;

// A Vulkan command buffer handle together with the bookkeeping used to
// recycle it from a command pool.
struct vk_command_buffer {
    // The underlying Vulkan command buffer handle.
    vk::CommandBuffer buf;

    // Reuse generation. ggml_vk_get_or_create_cmd_buffer increments this each
    // time it hands an idle buffer out again, so a holder of a stale pointer
    // (e.g. vk_event::cmd_buffer_use_counter) can detect that the buffer has
    // since been reused by comparing its saved value against this one.
    uint64_t use_counter{0};

    // True while the buffer is handed out; cleared again once the work that
    // used it has been synchronized, which makes it eligible for reuse.
    bool in_use{false};
};

Expand Down Expand Up @@ -938,19 +939,24 @@ struct vk_subbuffer {
}
};

// vk_event is used for the event-related backend interfaces. It uses 'event' for
// event_wait and 'fence' for event_synchronize. Polling on an event for
// Pairs a Vulkan semaphore handle with a 64-bit counter value.
// Instances are pushed onto a submission's wait/signal semaphore lists. For the
// timeline semaphore owned by vk_event, `value` is bumped before each signal
// and is the value that event_synchronize later waits for.
struct vk_semaphore {
// The Vulkan semaphore handle.
vk::Semaphore s;
// Counter value associated with `s`. There is no default initializer, so
// whoever creates the struct must set it (event_new initializes it to 0).
uint64_t value;
};

// vk_event is used for the event-related backend interfaces. It uses vk::Events for
// event_wait and a timeline semaphore for event_synchronize. Polling on an event for
// event_synchronize wouldn't be sufficient to wait for command buffers to complete,
// and would lead to validation errors.
struct vk_event {
std::vector<vk::Event> events_free; // Events available for reuse
std::vector<vk::Event> events_submitted; // Events that are fully submitted and can be reused on next synchronize
vk::Event event;
vk::Fence fence;
vk_command_buffer* cmd_buffer = nullptr;
};
bool has_event;

struct vk_semaphore {
vk::Semaphore s;
uint64_t value;
vk_semaphore tl_semaphore;
vk_command_buffer* cmd_buffer = nullptr;
uint64_t cmd_buffer_use_counter = 0;
};

struct vk_submission {
Expand Down Expand Up @@ -2319,7 +2325,7 @@ static vk_command_buffer* ggml_vk_create_cmd_buffer(vk_device& device, vk_comman
vk::CommandBufferLevel::ePrimary,
1);
const std::vector<vk::CommandBuffer> cmd_buffers = device->device.allocateCommandBuffers(command_buffer_alloc_info);
p.cmd_buffers.push_back({ cmd_buffers.front(), true });
p.cmd_buffers.push_back({ cmd_buffers.front(), 0, true });
return &p.cmd_buffers[p.cmd_buffers.size()-1];
}

Expand Down Expand Up @@ -2788,6 +2794,15 @@ static void ggml_vk_sync_buffers(ggml_backend_vk_context* ctx, vk_context& subct
);
}

// Records a command into the context's current command buffer that resets
// `event`, using the stage flags of the queue the context submits to.
//
// ctx:   context whose active command buffer (ctx->s->buffer) receives the command.
// event: the Vulkan event to reset.
//
// This only records the reset; nothing happens on the device until the
// command buffer is submitted.
static void ggml_vk_reset_event(vk_context& ctx, vk::Event& event) {
    // Log this function's own name (the message previously read
    // "ggml_vk_set_event()", which made debug traces misleading).
    VK_LOG_DEBUG("ggml_vk_reset_event()");

    ctx->s->buffer->buf.resetEvent(
        event,
        ctx->p->q->stage_flags
    );
}

static void ggml_vk_set_event(vk_context& ctx, vk::Event& event) {
VK_LOG_DEBUG("ggml_vk_set_event()");

Expand Down Expand Up @@ -6392,6 +6407,7 @@ static vk_subbuffer ggml_vk_tensor_subbuffer(
static vk_command_buffer* ggml_vk_get_or_create_cmd_buffer(vk_device& device, vk_command_pool& pool) {
for (auto& cmd_buffer : pool.cmd_buffers) {
if (!cmd_buffer.in_use) {
cmd_buffer.use_counter++;
cmd_buffer.in_use = true;
return &cmd_buffer;
}
Expand Down Expand Up @@ -6496,14 +6512,15 @@ static void ggml_vk_ctx_begin(vk_device& device, vk_context& subctx) {
}

static vk_context ggml_vk_get_compute_ctx(ggml_backend_vk_context * ctx) {
vk_context result;
if (!ctx->compute_ctx.expired()) {
return ctx->compute_ctx.lock();
}

vk_context result = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
result = ctx->compute_ctx.lock();
} else {
result = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);

ctx->compute_ctx = result;
ggml_vk_ctx_begin(ctx->device, result);
ctx->compute_ctx = result;
ggml_vk_ctx_begin(ctx->device, result);
}

if (ctx->device->async_use_transfer_queue && ctx->transfer_semaphore_last_submitted < ctx->transfer_semaphore.value) {
result->s->wait_semaphores.push_back(ctx->transfer_semaphore);
Expand Down Expand Up @@ -13797,6 +13814,7 @@ static void ggml_vk_synchronize(ggml_backend_vk_context * ctx) {
ctx->submit_pending = false;
if (cmd_buf) {
cmd_buf->in_use = false;
cmd_buf->buf.reset();
}
}

Expand Down Expand Up @@ -14858,18 +14876,31 @@ static void ggml_backend_vk_event_record(ggml_backend_t backend, ggml_backend_ev
vk_context compute_ctx = ggml_vk_get_compute_ctx(ctx);
auto* cmd_buf = compute_ctx->s->buffer; // retrieve pointer before it gets reset

// the backend interface doesn't have an explicit reset, so reset it here
// before we record the command to set it
ctx->device->device.resetEvent(vkev->event);
ctx->device->device.resetFences({ vkev->fence });
if (vkev->has_event) {
// Move existing event into submitted
vkev->events_submitted.push_back(vkev->event);
}

// Grab the next event and record it, create one if necessary
if (vkev->events_free.empty()) {
vkev->event = ctx->device->device.createEvent({});
} else {
vkev->event = vkev->events_free.back();
vkev->events_free.pop_back();
}

vkev->has_event = true;

ggml_vk_set_event(compute_ctx, vkev->event);

vkev->tl_semaphore.value++;
compute_ctx->s->signal_semaphores.push_back(vkev->tl_semaphore);
ggml_vk_ctx_end(compute_ctx);

ggml_vk_submit(compute_ctx, {vkev->fence});
ggml_vk_submit(compute_ctx, {});
ctx->submit_pending = true;
vkev->cmd_buffer = cmd_buf;
vkev->cmd_buffer_use_counter = cmd_buf->use_counter;
ctx->compute_ctx.reset();
}

Expand All @@ -14880,9 +14911,10 @@ static void ggml_backend_vk_event_wait(ggml_backend_t backend, ggml_backend_even

vk_context compute_ctx = ggml_vk_get_compute_ctx(ctx);

ggml_vk_wait_events(compute_ctx, {vkev->event});
ggml_vk_ctx_end(compute_ctx);
ctx->compute_ctx.reset();
if (vkev->has_event) {
// Wait for latest event
ggml_vk_wait_events(compute_ctx, { vkev->event });
}
}

// TODO: enable async and synchronize
Expand Down Expand Up @@ -15672,10 +15704,13 @@ static ggml_backend_event_t ggml_backend_vk_device_event_new(ggml_backend_dev_t
return nullptr;
}

// The event/fence is expected to initially be in the signaled state.
vkev->event = device->device.createEvent({});
vkev->fence = device->device.createFence({vk::FenceCreateFlagBits::eSignaled});
device->device.setEvent(vkev->event);
// No events initially, they get created on demand
vkev->has_event = false;

vk::SemaphoreTypeCreateInfo tci{ vk::SemaphoreType::eTimeline, 0 };
vk::SemaphoreCreateInfo ci{};
ci.setPNext(&tci);
vkev->tl_semaphore = { device->device.createSemaphore(ci), 0 };

return new ggml_backend_event {
/* .device = */ dev,
Expand All @@ -15689,8 +15724,16 @@ static void ggml_backend_vk_device_event_free(ggml_backend_dev_t dev, ggml_backe

vk_event *vkev = (vk_event *)event->context;

device->device.destroyFence(vkev->fence);
device->device.destroyEvent(vkev->event);
device->device.destroySemaphore(vkev->tl_semaphore.s);
for (auto& event : vkev->events_free) {
device->device.destroyEvent(event);
}
for (auto& event : vkev->events_submitted) {
device->device.destroyEvent(event);
}
if (vkev->has_event) {
device->device.destroyEvent(vkev->event);
}
delete vkev;
delete event;
}
Expand All @@ -15701,10 +15744,29 @@ static void ggml_backend_vk_device_event_synchronize(ggml_backend_dev_t dev, ggm
auto device = ggml_vk_get_device(ctx->device);
vk_event *vkev = (vk_event *)event->context;

VK_CHECK(device->device.waitForFences({ vkev->fence }, true, UINT64_MAX), "event_synchronize");
// Finished using current command buffer so we flag for reuse
if (vkev->cmd_buffer) {
vkev->cmd_buffer->in_use = false;
// Only do something if the event has actually been used
if (vkev->has_event) {
vk::Semaphore sem = vkev->tl_semaphore.s;
uint64_t val = vkev->tl_semaphore.value;
vk::SemaphoreWaitInfo swi{vk::SemaphoreWaitFlags{}, sem, val};
VK_CHECK(device->device.waitSemaphores(swi, UINT64_MAX), "event_synchronize");

// Reset and move submitted events
for (auto& event : vkev->events_submitted) {
device->device.resetEvent(event);
}
vkev->events_free.insert(vkev->events_free.end(), vkev->events_submitted.begin(), vkev->events_submitted.end());
vkev->events_submitted.clear();

// Finished using current command buffer so we flag for reuse
if (vkev->cmd_buffer) {
// Only flag for reuse if it hasn't been reused already
if (vkev->cmd_buffer_use_counter == vkev->cmd_buffer->use_counter) {
vkev->cmd_buffer->in_use = false;
vkev->cmd_buffer->buf.reset();
}
vkev->cmd_buffer = nullptr;
}
}
}

Expand Down
Loading