diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h index b6f73739809..2924fdbe988 100644 --- a/ggml/include/ggml-backend.h +++ b/ggml/include/ggml-backend.h @@ -381,11 +381,15 @@ extern "C" { // - most tensors have n_segments == 1 and a contiguous slice of the tensor data // - some tensors have an inhomogenenous data layout along the split axis, // those tensors are divided into segments which are each individually split across devices - // - ne has one entry per segment and device that add up to ggml_tensor::ne for that axis, - // the outer/inner loops are over segments/devices like [seg0_dev0, seg0_dev1, seg1_dev0, seg1_dev1], + // - ne has one entry per segment and device and that segment repeats nr times, + // in total when accounting for repetitions the segments add up to ggml_tensor::ne for that axis, + // the outer/inner loops are over segments/devices like [seg0_dev0_r0, seg0_dev1_r0, seg0_dev0_r1, seg0_dev1_r1, seg1_dev0_r0, seg1_dev1_r0], // - for example, a transformer may have a fused QKV matrix rather than 3 matrices, those would be 3 separate segments - // that each need to be split individually across devices so that each device gets a slice of Q, K, and V + // that each need to be split individually across devices so that each device gets a slice of Q, K, and V, + // the Q matrix can be larger than the K and V matrices so this can either be expressed as 3 segments or as 2 segments + // where the segment for K/V repeats twice int64_t ne[16*GGML_BACKEND_META_MAX_DEVICES]; + uint32_t nr[16]; uint32_t n_segments; }; diff --git a/ggml/src/ggml-backend-meta.cpp b/ggml/src/ggml-backend-meta.cpp index 48b2027fac3..8c44c3e44ae 100644 --- a/ggml/src/ggml-backend-meta.cpp +++ b/ggml/src/ggml-backend-meta.cpp @@ -487,6 +487,9 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state(co static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( ggml_backend_meta_simple_tensor_container & stc, const struct ggml_tensor * tensor, bool assume_sync) { + // FIXME Currently this function preserves/erases the information in n_segments and nr in an inconsistent way. + // Since the operations in question are developed specifically for llama.cpp this currently does not manifest as a bug there. + // However, in a broader ggml context with arbitrary ggml graphs this can lead to unexpected results. const size_t n_bufs = ggml_backend_meta_buffer_n_bufs(tensor->buffer); ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) tensor->buffer->context; @@ -497,11 +500,11 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( for (size_t j = 0; j < n_bufs; j++) { int64_t sum_a = 0; for (size_t s = 0; s < a.n_segments; s++) { - sum_a += a.ne[s*n_bufs + j]; + sum_a += a.ne[s*n_bufs + j] * a.nr[s]; } int64_t sum_b = 0; for (size_t s = 0; s < b.n_segments; s++) { - sum_b += b.ne[s*n_bufs + j]; + sum_b += b.ne[s*n_bufs + j] * b.nr[s]; } if (sum_a != sum_b) { return false; @@ -511,7 +514,7 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( }; auto handle_generic = [&](const std::vector & src_ss, bool scalar_only) -> ggml_backend_meta_split_state { - ggml_backend_meta_split_state ret = {GGML_BACKEND_SPLIT_AXIS_NONE, {0}, 1}; + ggml_backend_meta_split_state ret = {GGML_BACKEND_SPLIT_AXIS_NONE, {0}, {1}, 1}; for (size_t i = 0; i < GGML_MAX_SRC; i++) { if (tensor->src[i] == nullptr || tensor->src[i] == tensor) { continue; @@ -519,15 +522,15 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( if (ret.axis == GGML_BACKEND_SPLIT_AXIS_NONE) { ret = src_ss[i]; } else if (!split_states_equal(src_ss[i], ret)) { - ret = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, 1}; + ret = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1}; break; } } if (ret.axis == GGML_BACKEND_SPLIT_AXIS_NONE) { - ret = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, 1}; + ret = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1}; } if (scalar_only && ret.axis >= 0 && ret.axis < GGML_MAX_DIMS) { - ret = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, 1}; + ret = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1}; } GGML_ASSERT(ret.axis != GGML_BACKEND_SPLIT_AXIS_UNKNOWN); return ret; @@ -571,42 +574,24 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( auto handle_mul_mat = [&](const std::vector & src_ss) -> ggml_backend_meta_split_state { if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) { - return {GGML_BACKEND_SPLIT_AXIS_MIRRORED, {0}, 1}; + return {GGML_BACKEND_SPLIT_AXIS_MIRRORED, {0}, {1}, 1}; } if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_1 && src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) { ggml_backend_meta_split_state ret = src_ss[0]; ret.axis = GGML_BACKEND_SPLIT_AXIS_0; + ret.nr[0] = 1; ret.n_segments = 1; return ret; } if (src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_1 && src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) { - ggml_backend_meta_split_state ret = src_ss[1]; - ret.n_segments = 1; - return ret; + return src_ss[1]; } if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_0 && src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_0) { GGML_ASSERT(split_states_equal(src_ss[0], src_ss[1])); - return {assume_sync ? GGML_BACKEND_SPLIT_AXIS_MIRRORED : GGML_BACKEND_SPLIT_AXIS_PARTIAL, {0}, 1}; + return {assume_sync ? GGML_BACKEND_SPLIT_AXIS_MIRRORED : GGML_BACKEND_SPLIT_AXIS_PARTIAL, {0}, {1}, 1}; } GGML_ABORT("fatal error"); - //return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, 1}; - }; - - auto handle_cpy = [&](const std::vector & src_ss) -> ggml_backend_meta_split_state { - if (src_ss[0].axis >= 0 && src_ss[0].axis < GGML_MAX_DIMS) { - int64_t ne_split_src = tensor->src[0]->ne[0]; - for (int dim = 1; dim <= src_ss[0].axis; dim++) { - ne_split_src *= tensor->src[0]->ne[dim]; - } - int64_t ne_split_dst = 1; - for (int dim = 0; dim < GGML_MAX_DIMS; dim++) { - ne_split_dst *= tensor->ne[dim]; - if (ne_split_dst == ne_split_src) { - return {ggml_backend_meta_split_axis(dim), {0}, 1}; - } - } - } - return handle_generic(src_ss, /*scalar_only =*/ false); + //return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1}; }; auto handle_reshape = [&](const std::vector & src_ss) -> ggml_backend_meta_split_state { @@ -615,33 +600,25 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( case GGML_BACKEND_SPLIT_AXIS_1: case GGML_BACKEND_SPLIT_AXIS_2: case GGML_BACKEND_SPLIT_AXIS_3: { - GGML_ASSERT(!ggml_is_permuted(tensor) && !ggml_is_permuted(tensor->src[0])); - if (src_ss[0].axis == ggml_n_dims(tensor->src[0]) - 1) { - return {ggml_backend_meta_split_axis(ggml_n_dims(tensor) - 1), {0}, 1}; + GGML_ASSERT(src_ss[0].n_segments == 1); + if (src_ss[0].axis == ggml_n_dims(tensor->src[0]) - 1 && src_ss[0].nr[0] == 1) { + return {ggml_backend_meta_split_axis(ggml_n_dims(tensor) - 1), {0}, {1}, 1}; } - std::vector base_ne_in; - base_ne_in.reserve(GGML_MAX_DIMS - src_ss[0].axis); - { - base_ne_in.push_back(1); - int dim = 0; - for (; dim <= src_ss[0].axis; dim++) { - base_ne_in[0] *= tensor->src[0]->ne[dim]; - } - for (; dim <= GGML_MAX_DIMS; dim++) { - base_ne_in.push_back(base_ne_in.back() * tensor->src[0]->ne[dim]); - } + int64_t base_ne_in = tensor->src[0]->ne[0]; + for (int dim = 1; dim <= src_ss[0].axis; dim++) { + base_ne_in *= tensor->src[0]->ne[dim]; } + base_ne_in /= src_ss[0].nr[0]; int64_t base_ne_out = 1; for (int dim = 0; dim < GGML_MAX_DIMS; dim++) { const int64_t base_ne_out_next = base_ne_out *= tensor->ne[dim]; - for (const int64_t & bni : base_ne_in) { - if (bni == base_ne_out_next) { - return {ggml_backend_meta_split_axis(dim), {0}, 1}; - } + if (base_ne_out_next % base_ne_in == 0) { + return {ggml_backend_meta_split_axis(dim), {0}, {uint32_t(base_ne_out_next/base_ne_in)}, 1}; } - if (base_ne_out_next > base_ne_in[0]) { - GGML_ASSERT(dim + 1 < GGML_MAX_DIMS); - return {ggml_backend_meta_split_axis(dim + 1), {0}, 1}; + if (base_ne_out_next > base_ne_in) { + GGML_ASSERT(src_ss[0].n_segments == 1); + GGML_ASSERT(src_ss[0].nr[0] == 1); + return {ggml_backend_meta_split_axis(dim), {0}, {1}, 1}; } base_ne_out = base_ne_out_next; } @@ -653,11 +630,18 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( } default: { GGML_ABORT("fatal error"); - //return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, 1}; + //return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1}; } } }; + auto handle_cpy = [&](const std::vector & src_ss) -> ggml_backend_meta_split_state { + if (src_ss[0].axis >= 0 && src_ss[0].axis < GGML_MAX_DIMS) { + return handle_reshape(src_ss); + } + return handle_generic(src_ss, /*scalar_only =*/ false); + }; + auto handle_view = [&](const std::vector & src_ss) -> ggml_backend_meta_split_state { if (ggml_is_contiguous(tensor) && ggml_is_contiguous(tensor->src[0])) { return handle_reshape(src_ss); @@ -681,7 +665,7 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( if (!ggml_is_permuted(tensor) && !ggml_is_permuted(tensor->src[0]) && axis >= 0 && axis < GGML_MAX_DIMS-1) { for (int dim = 0; dim < GGML_MAX_DIMS-1; dim++) { if (tensor->nb[dim+1] == tensor->src[0]->nb[axis+1]) { - return {ggml_backend_meta_split_axis(dim), {0}, 1}; + return {ggml_backend_meta_split_axis(dim), {0}, {1}, 1}; } } GGML_ABORT("fatal error"); @@ -690,7 +674,7 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( return src_ss[0]; } GGML_ABORT("view of permuted tensor not implemented"); - //return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, 1}; + //return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1}; }; auto handle_permute = [&](const std::vector & src_ss) -> ggml_backend_meta_split_state { @@ -699,7 +683,8 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( case GGML_BACKEND_SPLIT_AXIS_1: case GGML_BACKEND_SPLIT_AXIS_2: case GGML_BACKEND_SPLIT_AXIS_3: { - return {ggml_backend_meta_split_axis(tensor->op_params[src_ss[0].axis]), {0}, 1}; + GGML_ASSERT(src_ss[0].n_segments == 1 || src_ss[0].nr[0] == 1); + return {ggml_backend_meta_split_axis(tensor->op_params[src_ss[0].axis]), {0}, {src_ss[0].nr[0]}, 1}; } case GGML_BACKEND_SPLIT_AXIS_MIRRORED: case GGML_BACKEND_SPLIT_AXIS_PARTIAL: { @@ -707,7 +692,7 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( } default: { GGML_ABORT("fatal error"); - //return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, 1}; + //return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1}; } } }; @@ -716,7 +701,8 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( switch (src_ss[0].axis) { case GGML_BACKEND_SPLIT_AXIS_0: case GGML_BACKEND_SPLIT_AXIS_1: { - return {ggml_backend_meta_split_axis(int(src_ss[0].axis) ^ 1), {0}, 1}; + GGML_ASSERT(src_ss[0].n_segments == 1 || src_ss[0].nr[0] == 1); + return {ggml_backend_meta_split_axis(int(src_ss[0].axis) ^ 1), {0}, {src_ss[0].nr[0]}, 1}; } case GGML_BACKEND_SPLIT_AXIS_2: case GGML_BACKEND_SPLIT_AXIS_3: @@ -726,7 +712,7 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( } default: { GGML_ABORT("fatal error"); - //return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, 1}; + //return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1}; } } }; @@ -764,16 +750,16 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( GGML_ASSERT( src_ss[2].axis == GGML_BACKEND_SPLIT_AXIS_2); GGML_ASSERT(tensor->src[4] == nullptr || src_ss[3].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED); GGML_ASSERT(tensor->src[4] == nullptr || src_ss[4].axis == GGML_BACKEND_SPLIT_AXIS_0); - return {GGML_BACKEND_SPLIT_AXIS_1, {0}, 1}; + return {GGML_BACKEND_SPLIT_AXIS_1, {0}, {1}, 1}; }; auto handle_ssm_conv = [&](const std::vector & src_ss) -> ggml_backend_meta_split_state { if (src_ss[0].axis == src_ss[1].axis) { if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_0) { - return {GGML_BACKEND_SPLIT_AXIS_1, {0}, 1}; + return {GGML_BACKEND_SPLIT_AXIS_1, {0}, {1}, 1}; } if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_1) { - return {GGML_BACKEND_SPLIT_AXIS_0, {0}, 1}; + return {GGML_BACKEND_SPLIT_AXIS_0, {0}, {1}, 1}; } } return handle_generic(src_ss, /*scalar_only =*/ false); @@ -781,8 +767,8 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( auto handle_gated_delta_net = [&](const std::vector & src_ss) -> ggml_backend_meta_split_state { if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && - src_ss[2].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && src_ss[3].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && - src_ss[4].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) { + src_ss[2].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && src_ss[3].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && + src_ss[4].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) { return src_ss[0]; } GGML_ASSERT(src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_1); @@ -793,12 +779,12 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( // state shape is (S_v*S_v*H, K, n_seqs); the heads dim is nested inside axis 0, // so a head-aligned split on the input cache reshapes to axis 0 here (not axis 2). GGML_ASSERT(src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_2 || src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_1 || src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_0); - return {GGML_BACKEND_SPLIT_AXIS_0, {0}, 1}; + return {GGML_BACKEND_SPLIT_AXIS_0, {0}, {1}, 1}; }; auto calculate_split_state = [&]() -> ggml_backend_meta_split_state { if (ggml_nelements(tensor) == 0) { - return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, 1}; + return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1}; } if (ggml_backend_buffer_get_usage(tensor->buffer) != GGML_BACKEND_BUFFER_USAGE_COMPUTE && tensor->view_src == nullptr) { ggml_backend_dev_t dev = ggml_backend_buft_get_device(ggml_backend_buffer_get_type(tensor->buffer)); @@ -807,19 +793,21 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( if (ret.axis >= 0 && ret.axis <= GGML_MAX_DIMS) { const int64_t granularity = ret.axis == GGML_BACKEND_SPLIT_AXIS_0 ? ggml_blck_size(tensor->type) : 1; int64_t ne_sum = 0; - for (size_t sj = 0; sj < ret.n_segments*n_bufs; sj++) { - GGML_ASSERT(ret.ne[sj] % granularity == 0); - ne_sum += ret.ne[sj]; + for (size_t s = 0; s < ret.n_segments; s++) { + for (size_t j = 0; j < n_bufs; j++) { + GGML_ASSERT(ret.ne[s*n_bufs + j] % granularity == 0); + ne_sum += ret.ne[s*n_bufs + j] * ret.nr[s]; + } } GGML_ASSERT(ne_sum == tensor->ne[ret.axis]); } return ret; } - std::vector src_ss(GGML_MAX_SRC, {GGML_BACKEND_SPLIT_AXIS_NONE, {0}, 1}); + std::vector src_ss(GGML_MAX_SRC, {GGML_BACKEND_SPLIT_AXIS_NONE, {0}, {1}, 1}); for (size_t i = 0; i < GGML_MAX_SRC; i++) { if (tensor->src[i] == nullptr || tensor->src[i] == tensor) { - src_ss[i] = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, 1}; + src_ss[i] = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1}; continue; } src_ss[i] = ggml_backend_meta_get_split_state(stc, tensor->src[i], /*assume_sync =*/ true); @@ -829,7 +817,7 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( ggml_backend_meta_split_state split_state; switch (tensor->op) { case GGML_OP_NONE: { - split_state = {GGML_BACKEND_SPLIT_AXIS_MIRRORED, {0}, 1}; + split_state = {GGML_BACKEND_SPLIT_AXIS_MIRRORED, {0}, {1}, 1}; } break; case GGML_OP_DUP: { split_state = handle_generic(src_ss, /*scalar_only =*/ true); @@ -1016,7 +1004,7 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( } break; default: { GGML_ABORT("ggml op not implemented: %s", ggml_op_name(tensor->op)); - split_state = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, 1}; + split_state = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1}; } break; } if (split_state.axis >= 0 && split_state.axis < GGML_MAX_DIMS) { @@ -1034,23 +1022,25 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( split_state.ne[s*n_bufs + j] = 0; } for (size_t s = 0; s < src_ss[i].n_segments; s++) { - split_state.ne[j] += src_ss[i].ne[s*n_bufs + j]; + split_state.ne[j] += src_ss[i].ne[s*n_bufs + j] * src_ss[i].nr[s]; } split_state.ne[j] *= tensor->ne[split_state.axis]; if (split_state.ne[j] != 0 || tensor->src[i]->ne[src_ss[i].axis] != 0) { - GGML_ASSERT(split_state.ne[j] % tensor->src[i]->ne[src_ss[i].axis] == 0); - split_state.ne[j] /= tensor->src[i]->ne[src_ss[i].axis]; + const int64_t div = tensor->src[i]->ne[src_ss[i].axis] * split_state.nr[0]; + GGML_ASSERT(split_state.ne[j] % div == 0); + split_state.ne[j] /= div; } } } else { + GGML_ASSERT(split_state.n_segments == 1); for (size_t j = 0; j < n_bufs; j++) { + // Assert that ratio is consistent: int64_t sum = 0; for (size_t s = 0; s < src_ss[i].n_segments; s++) { - sum += src_ss[i].ne[s*n_bufs + j]; + sum += src_ss[i].ne[s*n_bufs + j] * src_ss[i].nr[s]; } - // Assert that ratio is consistent: - GGML_ASSERT(split_state.ne[j] * tensor->src[i]->ne[src_ss[i].axis] - == sum * tensor->ne[split_state.axis]); + GGML_ASSERT(split_state.ne[j]*split_state.nr[0] * tensor->src[i]->ne[src_ss[i].axis] + == sum * tensor->ne[split_state.axis]); } } first_src_split_by_axis = false; @@ -1080,13 +1070,14 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( srcs_info += ", "; } const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor->src[0], true); + GGML_ASSERT(split_state.n_segments == 1); const char * axis_name = ggml_backend_meta_split_axis_name(split_state.axis); std::string ne_info; for (size_t j = 0; j < n_bufs; j++) { if (!ne_info.empty()) { ne_info += ", "; } - ne_info += std::to_string(split_state.ne[j]); + ne_info += std::to_string(split_state.ne[j]) + "x" + std::to_string(split_state.nr[0]); } srcs_info += std::string(tensor->src[i]->name) + "[" + ggml_op_name(tensor->src[i]->op) + ", " + axis_name + ", {" + ne_info + "}]"; } @@ -1095,7 +1086,8 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( if (!ne_info.empty()) { ne_info += ", "; } - ne_info += std::to_string(buf_ctx->split_state_cache[key].first.ne[j]); + const ggml_backend_meta_split_state & ss = buf_ctx->split_state_cache[key].first; + ne_info += std::to_string(ss.ne[j]) + "x" + std::to_string(ss.nr[0]); } GGML_LOG_DEBUG("SPLIT_STATE: {%s} -> %s[%s, %s, {%s}]\n", srcs_info.c_str(), tensor->name, ggml_op_name(tensor->op), ggml_backend_meta_split_axis_name(buf_ctx->split_state_cache[key].first.axis), ne_info.c_str()); @@ -1107,8 +1099,10 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( #ifndef NDEBUG if (ret.axis >= 0 && ret.axis < GGML_MAX_DIMS) { int64_t ne_ret = 0; - for (size_t sj = 0; sj < ret.n_segments*n_bufs; sj++) { - ne_ret += ret.ne[sj]; + for (size_t s = 0; s < ret.n_segments; s++) { + for (size_t j = 0; j < n_bufs; j++) { + ne_ret += ret.ne[s*n_bufs + j] * ret.nr[s]; + } } assert(ne_ret == tensor->ne[int(ret.axis)]); } @@ -1155,7 +1149,7 @@ static enum ggml_status ggml_backend_meta_buffer_init_tensor_impl(ggml_backend_m // GGML_ASSERT(ggml_is_contiguously_allocated(tensor)); ne[split_dim] = 0; for (size_t s = 0; s < split_state.n_segments; s++) { - ne[split_dim] += split_state.ne[s*n_simple_bufs + j]; + ne[split_dim] += split_state.ne[s*n_simple_bufs + j] * split_state.nr[s]; } for (int i = 0; i < GGML_MAX_DIMS; i++) { if (tensor->nb[i] > tensor->nb[split_dim]) { @@ -1229,7 +1223,7 @@ static enum ggml_status ggml_backend_meta_buffer_init_tensor_impl(ggml_backend_m for (size_t j = 0; j < n_simple_bufs; j++) { int64_t ne_sum = 0; for (size_t s = 0; s < split_state_src.n_segments; s++) { - ne_sum += split_state_src.ne[s*n_simple_bufs + j]; + ne_sum += split_state_src.ne[s*n_simple_bufs + j] * split_state_src.nr[s]; } if (ne_sum == 0) { simple_tensors[j]->flags &= ~GGML_TENSOR_FLAG_COMPUTE; @@ -1255,8 +1249,9 @@ static void ggml_backend_meta_buffer_set_tensor(ggml_backend_buffer_t buffer, gg const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor, /*assume_sync =*/ false); - if (split_state.n_segments != 1) { + if (split_state.n_segments != 1 || split_state.nr[0] != 1) { GGML_ASSERT(split_state.axis >= 0 && split_state.axis < GGML_MAX_DIMS); + GGML_ASSERT(split_state.nr[0] != 0); GGML_ASSERT(tensor->ne[3] == 1); size_t offset_data = 0; @@ -1267,24 +1262,26 @@ static void ggml_backend_meta_buffer_set_tensor(ggml_backend_buffer_t buffer, gg const size_t row_stride = tensor->nb[1]; GGML_ASSERT(offset % row_stride == 0); GGML_ASSERT(size % row_stride == 0); - const int64_t r_start = offset / row_stride; - const int64_t r_count = size / row_stride; - GGML_ASSERT(r_start + r_count <= tensor->ne[1]); + const int64_t row_start = offset / row_stride; + const int64_t row_count = size / row_stride; + GGML_ASSERT(row_start + row_count <= tensor->ne[1]); const int64_t blck_size = ggml_blck_size(tensor->type); for (size_t s = 0; s < split_state.n_segments; s++) { - for (size_t j = 0; j < n_bufs; j++) { - ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); - GGML_ASSERT(split_state.ne[s*n_bufs + j] % blck_size == 0); - const size_t nbytes = split_state.ne[s*n_bufs + j]/blck_size * tensor->nb[0]; - ggml_backend_tensor_set_2d(simple_tensor, (const char *) data + offset_data, - simple_offsets[j] + r_start * simple_tensor->nb[1], nbytes, - r_count, simple_tensor->nb[1], tensor->nb[1]); - offset_data += nbytes; - simple_offsets[j] += nbytes; + for (size_t r = 0; r < split_state.nr[s]; r++) { + for (size_t j = 0; j < n_bufs; j++) { + ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); + GGML_ASSERT(split_state.ne[s*n_bufs + j] % blck_size == 0); + const size_t nbytes = split_state.ne[s*n_bufs + j]/blck_size * tensor->nb[0]; + ggml_backend_tensor_set_2d(simple_tensor, (const char *) data + offset_data, + simple_offsets[j] + row_start * simple_tensor->nb[1], nbytes, + row_count, simple_tensor->nb[1], tensor->nb[1]); + offset_data += nbytes; + simple_offsets[j] += nbytes; + } } } - GGML_ASSERT(offset_data*r_count == size); + GGML_ASSERT(offset_data*row_count == size); return; } GGML_ASSERT(split_state.axis == GGML_BACKEND_SPLIT_AXIS_1); @@ -1292,22 +1289,24 @@ static void ggml_backend_meta_buffer_set_tensor(ggml_backend_buffer_t buffer, gg const size_t row_stride = tensor->nb[2]; GGML_ASSERT(offset % row_stride == 0); GGML_ASSERT(size % row_stride == 0); - const int64_t r_start = offset / row_stride; - const int64_t r_count = size / row_stride; - GGML_ASSERT(r_start + r_count <= tensor->ne[2]); + const int64_t row_start = offset / row_stride; + const int64_t row_count = size / row_stride; + GGML_ASSERT(row_start + row_count <= tensor->ne[2]); for (size_t s = 0; s < split_state.n_segments; s++) { - for (size_t j = 0; j < n_bufs; j++) { - ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); - const size_t nbytes = split_state.ne[s*n_bufs + j] * tensor->nb[1]; - ggml_backend_tensor_set_2d(simple_tensor, (const char *) data + offset_data, - simple_offsets[j] + r_start * simple_tensor->nb[2], nbytes, - r_count, simple_tensor->nb[2], tensor->nb[2]); - offset_data += nbytes; - simple_offsets[j] += nbytes; + for (size_t r = 0; r < split_state.nr[s]; r++) { + for (size_t j = 0; j < n_bufs; j++) { + ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); + const size_t nbytes = split_state.ne[s*n_bufs + j] * tensor->nb[1]; + ggml_backend_tensor_set_2d(simple_tensor, (const char *) data + offset_data, + simple_offsets[j] + row_start * simple_tensor->nb[2], nbytes, + row_count, simple_tensor->nb[2], tensor->nb[2]); + offset_data += nbytes; + simple_offsets[j] += nbytes; + } } } - GGML_ASSERT(offset_data*r_count == size); + GGML_ASSERT(offset_data*row_count == size); return; } @@ -1365,8 +1364,9 @@ static void ggml_backend_meta_buffer_get_tensor(ggml_backend_buffer_t buffer, co const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor, /*assume_sync =*/ false); - if (split_state.n_segments != 1) { + if (split_state.n_segments != 1 || split_state.nr[0] != 1) { GGML_ASSERT(split_state.axis >= 0 && split_state.axis < GGML_MAX_DIMS); + GGML_ASSERT(split_state.nr[0] != 0); GGML_ASSERT(tensor->ne[3] == 1); size_t offset_data = 0; @@ -1377,24 +1377,26 @@ static void ggml_backend_meta_buffer_get_tensor(ggml_backend_buffer_t buffer, co const size_t row_stride = tensor->nb[1]; GGML_ASSERT(offset % row_stride == 0); GGML_ASSERT(size % row_stride == 0); - const int64_t r_start = offset / row_stride; - const int64_t r_count = size / row_stride; - GGML_ASSERT(r_start + r_count <= tensor->ne[1]); + const int64_t row_start = offset / row_stride; + const int64_t row_count = size / row_stride; + GGML_ASSERT(row_start + row_count <= tensor->ne[1]); const int64_t blck_size = ggml_blck_size(tensor->type); for (size_t s = 0; s < split_state.n_segments; s++) { - for (size_t j = 0; j < n_bufs; j++) { - const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); - GGML_ASSERT(split_state.ne[s*n_bufs + j] % blck_size == 0); - const size_t nbytes = split_state.ne[s*n_bufs + j]/blck_size * tensor->nb[0]; - ggml_backend_tensor_get_2d(simple_tensor, (char *) data + offset_data, - simple_offsets[j] + r_start * simple_tensor->nb[1], nbytes, - r_count, simple_tensor->nb[1], tensor->nb[1]); - offset_data += nbytes; - simple_offsets[j] += nbytes; + for (size_t r = 0; r < split_state.nr[s]; r++) { + for (size_t j = 0; j < n_bufs; j++) { + const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); + GGML_ASSERT(split_state.ne[s*n_bufs + j] % blck_size == 0); + const size_t nbytes = split_state.ne[s*n_bufs + j]/blck_size * tensor->nb[0]; + ggml_backend_tensor_get_2d(simple_tensor, (char *) data + offset_data, + simple_offsets[j] + row_start * simple_tensor->nb[1], nbytes, + row_count, simple_tensor->nb[1], tensor->nb[1]); + offset_data += nbytes; + simple_offsets[j] += nbytes; + } } } - GGML_ASSERT(offset_data*r_count == size); + GGML_ASSERT(offset_data*row_count == size); return; } GGML_ASSERT(split_state.axis == GGML_BACKEND_SPLIT_AXIS_1); @@ -1402,22 +1404,24 @@ static void ggml_backend_meta_buffer_get_tensor(ggml_backend_buffer_t buffer, co const size_t row_stride = tensor->nb[2]; GGML_ASSERT(offset % row_stride == 0); GGML_ASSERT(size % row_stride == 0); - const int64_t r_start = offset / row_stride; - const int64_t r_count = size / row_stride; - GGML_ASSERT(r_start + r_count <= tensor->ne[2]); + const int64_t row_start = offset / row_stride; + const int64_t row_count = size / row_stride; + GGML_ASSERT(row_start + row_count <= tensor->ne[2]); for (size_t s = 0; s < split_state.n_segments; s++) { - for (size_t j = 0; j < n_bufs; j++) { - const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); - const size_t nbytes = split_state.ne[s*n_bufs + j] * tensor->nb[1]; - ggml_backend_tensor_get_2d(simple_tensor, (char *) data + offset_data, - simple_offsets[j] + r_start * simple_tensor->nb[2], nbytes, - r_count, simple_tensor->nb[2], tensor->nb[2]); - offset_data += nbytes; - simple_offsets[j] += nbytes; + for (size_t r = 0; r < split_state.nr[s]; r++) { + for (size_t j = 0; j < n_bufs; j++) { + const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); + const size_t nbytes = split_state.ne[s*n_bufs + j] * tensor->nb[1]; + ggml_backend_tensor_get_2d(simple_tensor, (char *) data + offset_data, + simple_offsets[j] + row_start * simple_tensor->nb[2], nbytes, + row_count, simple_tensor->nb[2], tensor->nb[2]); + offset_data += nbytes; + simple_offsets[j] += nbytes; + } } } - GGML_ASSERT(offset_data*r_count == size); + GGML_ASSERT(offset_data*row_count == size); return; } @@ -1675,6 +1679,7 @@ static void ggml_backend_meta_set_tensor_async(ggml_backend_t backend, ggml_tens const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor, /*assume_sync =*/ false); GGML_ASSERT(split_state.n_segments == 1); + GGML_ASSERT(split_state.nr[0] == 1); switch (split_state.axis) { case GGML_BACKEND_SPLIT_AXIS_0: @@ -1719,6 +1724,7 @@ static void ggml_backend_meta_get_tensor_async(ggml_backend_t backend, const ggm const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor, /*assume_sync =*/ false); GGML_ASSERT(split_state.n_segments == 1); + GGML_ASSERT(split_state.nr[0] == 1); switch (split_state.axis) { case GGML_BACKEND_SPLIT_AXIS_0: diff --git a/src/llama-context.cpp b/src/llama-context.cpp index ad36c06667d..99bdf092b3c 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -3403,10 +3403,6 @@ llama_context * llama_init_from_model( LLAMA_LOG_ERROR("%s: SPLIT_MODE_TENSOR requires flash_attn to be enabled\n", __func__); return nullptr; } - if (ggml_is_quantized(params.type_k) || ggml_is_quantized(params.type_v)) { - LLAMA_LOG_ERROR("%s: simultaneous use of SPLIT_MODE_TENSOR and KV cache quantization not implemented\n", __func__); - return nullptr; - } } if (params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED && ggml_is_quantized(params.type_k)) { diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 914fc423b1f..3e236f8c17d 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -488,7 +488,7 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_MIRRORED); }; - auto get_split_segments = [&](int axis, uint32_t il) -> std::vector { + auto get_split_segments = [&](int axis, uint32_t il) -> std::vector> { if (ud->model->arch == LLM_ARCH_QWEN3NEXT || ud->model->arch == LLM_ARCH_QWEN35 || ud->model->arch == LLM_ARCH_QWEN35MOE) { const int64_t head_k_dim = hparams.ssm_d_state; const int64_t head_v_dim = hparams.ssm_d_state; @@ -503,26 +503,26 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str if (ud->model->arch == LLM_ARCH_QWEN3NEXT) { if (std::regex_match(tensor_name, pattern_qkv_weight) || std::regex_match(tensor_name, pattern_ssm_conv1d)) { GGML_ASSERT(tensor->ne[axis] == 2*key_dim + value_dim); - return {key_dim, key_dim, value_dim}; + return {{key_dim, 2}, {value_dim, 1}}; } } else { const int64_t head_ratio = n_v_heads / n_k_heads; if (std::regex_match(tensor_name, pattern_qkv_weight) || std::regex_match(tensor_name, pattern_ssm_conv1d)) { GGML_ASSERT(tensor->ne[axis] == 2*key_dim + value_dim); - return std::vector(2 + head_ratio, key_dim); + return {{key_dim, 2 + head_ratio}}; } if (std::regex_match(tensor_name, pattern_attn_gate_weight) || std::regex_match(tensor_name, pattern_ssm_out_weight)) { - return std::vector(head_ratio, key_dim); + return {{key_dim, head_ratio}}; } if (std::regex_match(tensor_name, pattern_ssm_dt) || std::regex_match(tensor_name, pattern_ssm_a) || std::regex_match(tensor_name, pattern_ssm_alpha) || std::regex_match(tensor_name, pattern_ssm_beta)) { - return std::vector(head_ratio, n_k_heads); + return {{n_k_heads, head_ratio}}; } if (std::regex_match(tensor_name, pattern_r_cache)) { - return std::vector(2 + head_ratio, key_dim * (hparams.ssm_d_conv - 1)); + return {{key_dim * (hparams.ssm_d_conv - 1), 2 + head_ratio}}; } if (std::regex_match(tensor_name, pattern_s_cache)) { - return std::vector(head_ratio, n_k_heads * head_v_dim * head_v_dim); + return {{n_k_heads * head_v_dim * head_v_dim, head_ratio}}; } } @@ -530,9 +530,9 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str if (std::regex_match(tensor_name, pattern_ffn_gate_up_weight)) { const int64_t n_ff_exp = hparams.n_ff_exp; GGML_ASSERT(tensor->ne[axis] == 2*n_ff_exp); - return {n_ff_exp, n_ff_exp}; + return {{n_ff_exp, 2}}; } - return {tensor->ne[axis]}; + return {{tensor->ne[axis], 1}}; } if (std::regex_match(tensor_name, pattern_qkv_weight) || std::regex_match(tensor_name, pattern_qkv_bias)) { @@ -540,17 +540,17 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str const int64_t n_embd_gqa = hparams.n_embd_v_gqa(il); GGML_ASSERT(hparams.n_embd_k_gqa() == n_embd_gqa); GGML_ASSERT(tensor->ne[axis] == n_embd + 2*n_embd_gqa); - return {n_embd, n_embd_gqa, n_embd_gqa}; + return {{n_embd, 1}, {n_embd_gqa, 2}}; } if (std::regex_match(tensor_name, pattern_ffn_gate_up_weight)) { const int64_t n_ff_exp = hparams.n_ff_exp; GGML_ASSERT(tensor->ne[axis] == 2*n_ff_exp); - return {n_ff_exp, n_ff_exp}; + return {{n_ff_exp, 2}}; } - return {tensor->ne[axis]}; + return {{tensor->ne[axis], 1}}; }; - auto get_split_granularity = [&](int64_t blck_size, uint32_t il, const std::vector & segments) -> std::vector { + auto get_split_granularity = [&](int64_t blck_size, uint32_t il, const std::vector> & segments) -> std::vector { if (hparams.is_recurrent(il)) { // linear attention const int64_t head_dim = hparams.ssm_d_state; @@ -603,16 +603,16 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str return {granularity_kv}; } if (std::regex_match(tensor_name, pattern_qkv_weight) || std::regex_match(tensor_name, pattern_qkv_bias)) { - GGML_ASSERT(segments.size() == 3); - return {granularity_q, granularity_kv, granularity_kv}; + GGML_ASSERT(segments.size() == 2); + return {granularity_q, granularity_kv}; } } // FFN if (std::regex_match(tensor_name, pattern_ffn_up_gate_weight) || std::regex_match(tensor_name, pattern_ffn_up_gate_bias) || std::regex_match(tensor_name, pattern_ffn_gate_up_weight) || std::regex_match(tensor_name, pattern_ffn_down_weight)) { - GGML_ASSERT(segments.size() <= 2); - return std::vector(segments.size(), blck_size); + GGML_ASSERT(segments.size() == 1); + return {blck_size}; } // everything else @@ -636,11 +636,12 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str tensor_split_scan[j] += tensor_split_scan[j - 1]; } } - const std::vector segments = get_split_segments(split_state.axis, tc.il); + const std::vector> segments = get_split_segments(split_state.axis, tc.il); const std::vector granularity = get_split_granularity(blck_size, tc.il, segments); for (size_t is = 0; is < segments.size(); is++) { - const int64_t ne_s = segments[is]; - const int64_t g_s = granularity[is]; + const int64_t ne_s = segments[is].first; + const uint32_t nr_s = segments[is].second; + const int64_t g_s = granularity[is]; GGML_ASSERT(ne_full % g_s == 0); int64_t low = 0; size_t j = 0; @@ -654,10 +655,12 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str low = high; } split_state.ne[is*ud->n_devices + (j + tc.rotation) % ud->n_devices] = ne_s - low; + split_state.nr[is] = nr_s; } split_state.n_segments = segments.size(); } else { memset(split_state.ne, 0, sizeof(split_state.ne)); + split_state.nr[0] = 1; split_state.n_segments = 1; } return split_state;