|
22 | 22 | #include "shaderop_mul_mat_q4_1.h" |
23 | 23 | #include "shaderop_mul_mat_q6_k.h" |
24 | 24 | #include "shaderop_mul_mat_mat_f32.h" |
| 25 | +#include "shaderop_getrows_f32.h" |
25 | 26 | #include "shaderop_getrows_f16.h" |
26 | 27 | #include "shaderop_getrows_q4_0.h" |
27 | 28 | #include "shaderop_getrows_q4_1.h" |
@@ -1228,6 +1229,14 @@ static void ggml_vk_get_rows( |
1228 | 1229 | seq.record<kp::OpAlgoDispatch>(s_algo); |
1229 | 1230 | } |
1230 | 1231 |
|
| 1232 | +template <typename... Args> |
| 1233 | +static void ggml_vk_get_rows_f32(Args&&... args) { |
| 1234 | + const static auto spirv = getSpirvShader(kp::shader_data::op_getrows_f32_comp_spv, |
| 1235 | + kp::shader_data::op_getrows_f32_comp_spv_len); |
| 1236 | + |
| 1237 | + ggml_vk_get_rows(spirv, "f32", sizeof(float), 0, std::forward<Args>(args)...); |
| 1238 | +} |
| 1239 | + |
1231 | 1240 | template <typename... Args> |
1232 | 1241 | static void ggml_vk_get_rows_f16(Args&&... args) { |
1233 | 1242 | const static auto spirv = getSpirvShader(kp::shader_data::op_getrows_f16_comp_spv, |
@@ -1453,6 +1462,7 @@ static bool ggml_vk_supports_op(const struct ggml_tensor * op) { |
1453 | 1462 | return op->ne[3] == 1; |
1454 | 1463 | case GGML_OP_GET_ROWS: |
1455 | 1464 | switch (op->src[0]->type) { |
| 1465 | + case GGML_TYPE_F32: |
1456 | 1466 | case GGML_TYPE_F16: |
1457 | 1467 | case GGML_TYPE_Q4_0: |
1458 | 1468 | case GGML_TYPE_Q4_1: |
@@ -1737,7 +1747,9 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml |
1737 | 1747 | } break; |
1738 | 1748 | case GGML_OP_GET_ROWS: |
1739 | 1749 | { |
1740 | | - if (src0t == GGML_TYPE_F16) { |
| 1750 | + if (src0t == GGML_TYPE_F32) { |
| 1751 | + ggml_vk_get_rows_f32(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, nb01, nb1, ggml_nelements(src1)); |
| 1752 | + } else if (src0t == GGML_TYPE_F16) { |
1741 | 1753 | ggml_vk_get_rows_f16(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, nb01, nb1, ggml_nelements(src1)); |
1742 | 1754 | } else if (src0t == GGML_TYPE_Q4_0) { |
1743 | 1755 | ggml_vk_get_rows_q4_0(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, nb01, nb1, ggml_nelements(src1)); |
|
0 commit comments