Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions ggml/src/ggml-vulkan/ggml-vulkan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -759,6 +759,7 @@ struct vk_device_struct {
vk_pipeline pipeline_pad_f32;
vk_pipeline pipeline_roll_f32;
vk_pipeline pipeline_repeat_f32, pipeline_repeat_back_f32;
vk_pipeline pipeline_repeat_f16;
vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16, pipeline_cpy_f16_f32, pipeline_cpy_f32_bf16, pipeline_cpy_bf16_f32, pipeline_cpy_f32_i32, pipeline_cpy_i32_f32;
vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16, pipeline_contig_cpy_f16_f32, pipeline_contig_cpy_f32_bf16, pipeline_contig_cpy_bf16_f32, pipeline_contig_cpy_f32_i32, pipeline_contig_cpy_i32_f32;
vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT];
Expand Down Expand Up @@ -4693,6 +4694,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_repeat_f32, "repeat_f32", repeat_f32_len, repeat_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_repeat_back_f32, "repeat_back_f32", repeat_back_f32_len, repeat_back_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);

ggml_vk_create_pipeline(device, device->pipeline_repeat_f16, "repeat_f16", repeat_f16_len, repeat_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
Comment thread
l8bloom marked this conversation as resolved.
Outdated

#define CREATE_UNARY(name) \
ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \
ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
Expand Down Expand Up @@ -9629,9 +9632,12 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
}
return nullptr;
case GGML_OP_REPEAT:
if (ggml_type_size(src0->type) == sizeof(float) && ggml_type_size(dst->type) == sizeof(float)) {
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
Comment thread
l8bloom marked this conversation as resolved.
Outdated
return ctx->device->pipeline_repeat_f32;
}
if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
return ctx->device->pipeline_repeat_f16;
}
return nullptr;
case GGML_OP_REPEAT_BACK:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
Expand Down Expand Up @@ -16020,7 +16026,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
return false;
}
case GGML_OP_REPEAT:
return ggml_type_size(op->type) == sizeof(float) && ggml_type_size(op->src[0]->type) == sizeof(float);
return (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && op->type == op->src[0]->type;
case GGML_OP_REPEAT_BACK:
return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_ROPE:
Expand Down
2 changes: 2 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -801,6 +801,8 @@ void process_shaders() {
string_to_spv("repeat_f32", "repeat.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("repeat_back_f32", "repeat_back.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});

string_to_spv("repeat_f16", "repeat.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});

string_to_spv("scale_f32", "scale.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});

string_to_spv("sqr_f32", "square.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
Expand Down