diff --git a/ggml/src/ggml-backend-meta.cpp b/ggml/src/ggml-backend-meta.cpp index df0f405ed9f..440f882f0f2 100644 --- a/ggml/src/ggml-backend-meta.cpp +++ b/ggml/src/ggml-backend-meta.cpp @@ -536,6 +536,12 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state(co if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) { return {GGML_BACKEND_SPLIT_AXIS_MIRRORED, {0}, 1}; } + if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && + (src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_2 || src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_3)) { + ggml_backend_meta_split_state ret = src_ss[1]; + ret.n_segments = 1; + return ret; + } if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_1 && src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) { ggml_backend_meta_split_state ret = src_ss[0]; ret.axis = GGML_BACKEND_SPLIT_AXIS_0; diff --git a/src/llama-context.cpp b/src/llama-context.cpp index d62abc4009b..b3c3a7adf5b 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -3366,10 +3366,6 @@ llama_context * llama_init_from_model( LLAMA_LOG_ERROR("%s: SPLIT_MODE_TENSOR requires flash_attn to be enabled\n", __func__); return nullptr; } - if (ggml_is_quantized(params.type_k) || ggml_is_quantized(params.type_v)) { - LLAMA_LOG_ERROR("%s: simultaneous use of SPLIT_MODE_TENSOR and KV cache quantization not implemented\n", __func__); - return nullptr; - } } if (params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED && ggml_is_quantized(params.type_k)) { diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 858c297dd76..1abd2f6a75a 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -65,10 +65,11 @@ static ggml_tensor * ggml_mul_mat_aux( ggml_tensor * res; + GGML_ASSERT(cur->ne[0] % n == 0); if (!ggml_is_contiguous(cur)) { - res = ggml_cont_2d (ctx, cur, n, ggml_nelements(cur)/n); + res = ggml_cont_4d(ctx, cur, n, cur->ne[0]/n, cur->ne[1], cur->ne[2]*cur->ne[3]); } else { - res = ggml_reshape_2d(ctx, cur, n, ggml_nelements(cur)/n); + res = ggml_reshape_4d(ctx, cur, n, cur->ne[0]/n, cur->ne[1], cur->ne[2]*cur->ne[3]); } res = ggml_mul_mat (ctx, rot, res); ggml_mul_mat_set_hint(res, GGML_HINT_SRC0_IS_HADAMARD); diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index a49a055a630..dff325b8473 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -65,7 +65,8 @@ static ggml_tensor * ggml_mul_mat_aux( ggml_tensor * res; - res = ggml_reshape_2d(ctx, cur, n, ggml_nelements(cur)/n); + GGML_ASSERT(cur->ne[0] % n == 0); + res = ggml_reshape_4d(ctx, cur, n, cur->ne[0]/n, cur->ne[1], cur->ne[2]*cur->ne[3]); res = ggml_mul_mat (ctx, rot, res); ggml_mul_mat_set_hint(res, GGML_HINT_SRC0_IS_HADAMARD); res = ggml_reshape_4d(ctx, res, cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]);