Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions ggml/src/ggml-vulkan/ggml-vulkan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -687,6 +687,7 @@ struct vk_device_struct {
vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT];
vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT];
vk_pipeline pipeline_acc_f32;
vk_pipeline pipeline_set_f32;

// [src0 0=fp32,1=fp16][src1 0=fp32,1=fp16][dst 0=fp32,1=fp16]
vk_pipeline pipeline_add[2][2][2];
Expand Down Expand Up @@ -4181,7 +4182,8 @@ static void ggml_vk_load_shaders(vk_device& device) {

ggml_vk_create_pipeline(device, device->pipeline_add_id_f32, "add_id_f32", add_id_f32_len, add_id_f32_data, "main", 4, sizeof(vk_op_add_id_push_constants), {1, 1, 1}, {}, 1);

ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0, 1}, 1);
ggml_vk_create_pipeline(device, device->pipeline_set_f32, "set_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0, 0}, 1);

ggml_vk_create_pipeline(device, device->pipeline_concat_f32, "concat_f32", concat_f32_len, concat_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_concat_f16, "concat_f16", concat_f16_len, concat_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
Expand Down Expand Up @@ -8815,6 +8817,12 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return ctx->device->pipeline_acc_f32;
}
return nullptr;
case GGML_OP_SET:
if (src0->type == src1->type && src0->type == dst->type &&
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32)) {
return ctx->device->pipeline_set_f32;
}
return nullptr;
case GGML_OP_ADD:
case GGML_OP_SUB:
case GGML_OP_MUL:
Expand Down Expand Up @@ -9806,7 +9814,7 @@ static void ggml_vk_acc(ggml_backend_vk_context * ctx, vk_context& subctx, const
int nb3 = dst->op_params[2] / src0_type_size; // 4 bytes of float32
int offset = dst->op_params[3] / src0_type_size; // offset in bytes

ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_ACC, {
ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, dst->op, {
(uint32_t)ggml_nelements(src0),
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t)nb3,
(uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
Expand Down Expand Up @@ -12500,6 +12508,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr

break;
case GGML_OP_ACC:
case GGML_OP_SET:
ggml_vk_acc(ctx, compute_ctx, src0, src1, node);

break;
Expand Down Expand Up @@ -14960,7 +14969,10 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
}
return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_ACC:
return op->src[0]->type == GGML_TYPE_F32;
return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
case GGML_OP_SET:
return op->src[0]->type == op->src[1]->type && op->src[0]->type == op->type &&
(op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_I32);
case GGML_OP_CONCAT:
return ggml_type_size(op->src[0]->type) == ggml_type_size(GGML_TYPE_F32);
case GGML_OP_ADD1:
Expand Down Expand Up @@ -15611,6 +15623,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
tensor_clone = ggml_add(ggml_ctx, src_clone[0], src_clone[1]);
} else if (tensor->op == GGML_OP_ACC) {
tensor_clone = ggml_acc(ggml_ctx, src_clone[0], src_clone[1], tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3]);
} else if (tensor->op == GGML_OP_SET) {
tensor_clone = ggml_set(ggml_ctx, src_clone[0], src_clone[1], tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3]);
} else if (tensor->op == GGML_OP_NORM) {
tensor_clone = ggml_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params);
} else if (tensor->op == GGML_OP_GROUP_NORM) {
Expand Down
9 changes: 8 additions & 1 deletion ggml/src/ggml-vulkan/vulkan-shaders/acc.comp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
#include "types.glsl"
#include "generic_binary_head.glsl"

// false for SET, true for ACC
layout(constant_id = 1) const bool ACC = true;

layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;

void main() {
Expand All @@ -23,7 +26,11 @@ void main() {
uint i00, i01, i02, i03;

if (i0 < p.ne10 && i1 < p.ne11 && i2 < p.ne12 && i3 < p.ne13) {
data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + idx]) + FLOAT_TYPE(data_b[get_boffset() + src1_idx(i0, i1, i2, i3)]));
if (ACC) {
data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + idx]) + FLOAT_TYPE(data_b[get_boffset() + src1_idx(i0, i1, i2, i3)]));
} else {
data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_b[get_boffset() + src1_idx(i0, i1, i2, i3)]));
}
} else {
data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + idx]));
}
Expand Down
Loading