diff --git a/ggml/src/ggml-backend-meta.cpp b/ggml/src/ggml-backend-meta.cpp index 48b2027fac3..b7d374b9ba0 100644 --- a/ggml/src/ggml-backend-meta.cpp +++ b/ggml/src/ggml-backend-meta.cpp @@ -510,6 +510,59 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( return true; }; + auto collapse_split_segments = [&](const ggml_backend_meta_split_state & ss) { + ggml_backend_meta_split_state ret = ss; + + if (ret.n_segments <= 1) { + return ret; + } + + memset(ret.ne, 0, sizeof(ret.ne)); + + for (size_t j = 0; j < n_bufs; ++j) { + for (size_t s = 0; s < ss.n_segments; ++s) { + ret.ne[j] += ss.ne[s*n_bufs + j]; + } + } + + ret.n_segments = 1; + + return ret; + }; + + bool split_state_ne_is_exact = false; + + auto exact_split_state = [&](ggml_backend_meta_split_state ret) { + if (ret.axis >= 0 && ret.axis < GGML_MAX_DIMS) { + int64_t ne_sum = 0; + for (size_t s = 0; s < ret.n_segments; ++s) { + for (size_t j = 0; j < n_bufs; ++j) { + ne_sum += ret.ne[s*n_bufs + j]; + } + } + GGML_ASSERT(ne_sum == tensor->ne[int(ret.axis)]); + split_state_ne_is_exact = true; + } + return ret; + }; + + auto split_total = [&](const ggml_backend_meta_split_state & ss, size_t j) { + int64_t total = 0; + for (size_t s = 0; s < ss.n_segments; ++s) { + total += ss.ne[s*n_bufs + j]; + } + return total; + }; + + auto split_totals_equal = [&](const ggml_backend_meta_split_state & a, const ggml_backend_meta_split_state & b) { + for (size_t j = 0; j < n_bufs; ++j) { + if (split_total(a, j) != split_total(b, j)) { + return false; + } + } + return true; + }; + auto handle_generic = [&](const std::vector & src_ss, bool scalar_only) -> ggml_backend_meta_split_state { ggml_backend_meta_split_state ret = {GGML_BACKEND_SPLIT_AXIS_NONE, {0}, 1}; for (size_t i = 0; i < GGML_MAX_SRC; i++) { @@ -539,6 +592,14 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( return src_ss[0]; }; + auto handle_shape_preserving_unary = [&](const std::vector & src_ss) -> ggml_backend_meta_split_state { + if (src_ss[0].axis >= 0 && src_ss[0].axis < GGML_MAX_DIMS && + tensor->ne[src_ss[0].axis] == tensor->src[0]->ne[src_ss[0].axis]) { + return exact_split_state(src_ss[0]); + } + return handle_generic(src_ss, /*scalar_only =*/ false); + }; + // Some ops broadcast the src1 data across src0: auto handle_bin_bcast = [&](const std::vector & src_ss) -> ggml_backend_meta_split_state { if (src_ss[0].axis >= 0 && src_ss[0].axis < GGML_MAX_DIMS && @@ -563,7 +624,10 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( GGML_ASSERT(concat_axis != src_ss[0].axis); return src_ss[0]; } - if (src_ss[0].axis == src_ss[1].axis && src_ss[0].axis != concat_axis) { + // If concat is performed on a non-split axis, both inputs must have the + // same local extent on every backend for the split axis. Matching axes + // alone is not sufficient. + if (split_states_equal(src_ss[0], src_ss[1]) && src_ss[0].axis != concat_axis) { return src_ss[0]; } return handle_generic(src_ss, /*scalar_only =*/ true); @@ -574,15 +638,19 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( return {GGML_BACKEND_SPLIT_AXIS_MIRRORED, {0}, 1}; } if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_1 && src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) { - ggml_backend_meta_split_state ret = src_ss[0]; + ggml_backend_meta_split_state ret = collapse_split_segments(src_ss[0]); ret.axis = GGML_BACKEND_SPLIT_AXIS_0; - ret.n_segments = 1; return ret; } if (src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_1 && src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) { - ggml_backend_meta_split_state ret = src_ss[1]; - ret.n_segments = 1; - return ret; + return collapse_split_segments(src_ss[1]); + } + // A mirrored, B split on a pass-through dim (dims >= 2 of B are carried to the output unchanged). + // This covers the PR #21038 rotation matmul where activations are reshaped to preserve the head + // axis and the head dim ends up as dim 2 or 3 of the 4D input. + if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && + (src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_2 || src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_3)) { + return collapse_split_segments(src_ss[1]); } if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_0 && src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_0) { GGML_ASSERT(split_states_equal(src_ss[0], src_ss[1])); @@ -659,6 +727,73 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( }; auto handle_view = [&](const std::vector & src_ss) -> ggml_backend_meta_split_state { + auto split_view_state = [&](int view_axis, int src_axis) -> ggml_backend_meta_split_state { + ggml_backend_meta_split_state ret = {ggml_backend_meta_split_axis(view_axis), {0}, 1}; + + if (src_axis < 0 || src_axis >= GGML_MAX_DIMS || view_axis < 0 || view_axis >= GGML_MAX_DIMS) { + return ret; + } + + const ggml_tensor * src = tensor->src[0]; + + if (src->nb[src_axis] == 0 || tensor->view_offs % src->nb[src_axis] != 0) { + return ret; + } + + if (tensor->nb[view_axis] % src->nb[src_axis] != 0) { + return ret; + } + + const int64_t view_stride_src = tensor->nb[view_axis] / src->nb[src_axis]; + + if (view_stride_src <= 0) { + return ret; + } + + const int64_t view_begin = (tensor->view_offs / src->nb[src_axis]) % src->ne[src_axis]; + const int64_t view_end = view_begin + tensor->ne[view_axis] * view_stride_src; + + if (view_begin < 0 || view_end > src->ne[src_axis]) { + return ret; + } + + int64_t segment_begin = 0; + for (size_t s = 0; s < src_ss[0].n_segments; ++s) { + int64_t segment_size = 0; + for (size_t j = 0; j < n_bufs; ++j) { + segment_size += src_ss[0].ne[s*n_bufs + j]; + } + + int64_t backend_begin = segment_begin; + for (size_t j = 0; j < n_bufs; ++j) { + const int64_t backend_end = backend_begin + src_ss[0].ne[s*n_bufs + j]; + const int64_t overlap = + std::max(0, std::min(view_end, backend_end) - std::max(view_begin, backend_begin)); + + if (overlap > 0) { + GGML_ASSERT(overlap % view_stride_src == 0); + ret.ne[j] += overlap / view_stride_src; + } + + backend_begin = backend_end; + } + + segment_begin += segment_size; + } + + int64_t ret_ne_sum = 0; + for (size_t j = 0; j < n_bufs; ++j) { + ret_ne_sum += ret.ne[j]; + } + + if (ret_ne_sum == tensor->ne[view_axis]) { + return exact_split_state(ret); + } + + memset(ret.ne, 0, sizeof(ret.ne)); + return ret; + }; + if (ggml_is_contiguous(tensor) && ggml_is_contiguous(tensor->src[0])) { return handle_reshape(src_ss); } @@ -681,7 +816,7 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( if (!ggml_is_permuted(tensor) && !ggml_is_permuted(tensor->src[0]) && axis >= 0 && axis < GGML_MAX_DIMS-1) { for (int dim = 0; dim < GGML_MAX_DIMS-1; dim++) { if (tensor->nb[dim+1] == tensor->src[0]->nb[axis+1]) { - return {ggml_backend_meta_split_axis(dim), {0}, 1}; + return split_view_state(dim, axis); } } GGML_ABORT("fatal error"); @@ -769,11 +904,21 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( auto handle_ssm_conv = [&](const std::vector & src_ss) -> ggml_backend_meta_split_state { if (src_ss[0].axis == src_ss[1].axis) { + GGML_ASSERT(split_totals_equal(src_ss[0], src_ss[1])); + + // SSM_CONV is channel-wise. Preserve the exact segmented channel + // layout (Q/K/V groups for Qwen recurrent layers) so downstream + // VIEWs of q/k/v can recover matching head splits. + ggml_backend_meta_split_state ret = + src_ss[1].n_segments >= src_ss[0].n_segments ? src_ss[1] : src_ss[0]; + if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_0) { - return {GGML_BACKEND_SPLIT_AXIS_1, {0}, 1}; + ret.axis = GGML_BACKEND_SPLIT_AXIS_1; + return exact_split_state(ret); } if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_1) { - return {GGML_BACKEND_SPLIT_AXIS_0, {0}, 1}; + ret.axis = GGML_BACKEND_SPLIT_AXIS_0; + return exact_split_state(ret); } } return handle_generic(src_ss, /*scalar_only =*/ false); @@ -797,6 +942,8 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( }; auto calculate_split_state = [&]() -> ggml_backend_meta_split_state { + split_state_ne_is_exact = false; + if (ggml_nelements(tensor) == 0) { return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, 1}; } @@ -997,7 +1144,7 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( split_state = handle_gated_delta_net(src_ss); } break; case GGML_OP_UNARY: { - split_state = handle_generic(src_ss, /*scalar_only =*/ false); + split_state = handle_shape_preserving_unary(src_ss); } break; case GGML_OP_MAP_CUSTOM1: case GGML_OP_MAP_CUSTOM2: @@ -1019,7 +1166,7 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( split_state = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, 1}; } break; } - if (split_state.axis >= 0 && split_state.axis < GGML_MAX_DIMS) { + if (!split_state_ne_is_exact && split_state.axis >= 0 && split_state.axis < GGML_MAX_DIMS) { bool first_src_split_by_axis = true; const size_t n_bufs = ggml_backend_meta_buffer_n_bufs(tensor->buffer); diff --git a/src/llama-context.cpp b/src/llama-context.cpp index ad36c06667d..99bdf092b3c 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -3403,10 +3403,6 @@ llama_context * llama_init_from_model( LLAMA_LOG_ERROR("%s: SPLIT_MODE_TENSOR requires flash_attn to be enabled\n", __func__); return nullptr; } - if (ggml_is_quantized(params.type_k) || ggml_is_quantized(params.type_v)) { - LLAMA_LOG_ERROR("%s: simultaneous use of SPLIT_MODE_TENSOR and KV cache quantization not implemented\n", __func__); - return nullptr; - } } if (params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED && ggml_is_quantized(params.type_k)) { diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index e6ec3054daf..881c0e9a52b 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -68,12 +68,14 @@ static ggml_tensor * ggml_mul_mat_aux( const auto n = rot->ne[0]; ggml_tensor * res; + GGML_ASSERT(cur->ne[0] % n == 0); if (!ggml_is_contiguous(cur)) { - res = ggml_cont_2d (ctx, cur, n, ggml_nelements(cur)/n); - } else { - res = ggml_reshape_2d(ctx, cur, n, ggml_nelements(cur)/n); + cur = ggml_cont(ctx, cur); } + // Keep the head axis visible through the rotation matmul so the meta backend + // can propagate tensor-parallel splits on the head/batch dimensions. + res = ggml_reshape_4d(ctx, cur, n, cur->ne[0]/n, cur->ne[1], cur->ne[2]*cur->ne[3]); res = ggml_mul_mat (ctx, rot, res); ggml_mul_mat_set_hint(res, GGML_HINT_SRC0_IS_HADAMARD); res = ggml_reshape_4d(ctx, res, cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]); diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index ac11f96c22d..823191477bc 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -62,10 +62,14 @@ static ggml_tensor * ggml_mul_mat_aux( ggml_tensor * cur, ggml_tensor * rot) { const auto n = rot->ne[0]; + GGML_ASSERT(cur->ne[0] % n == 0); ggml_tensor * res; - res = ggml_reshape_2d(ctx, cur, n, ggml_nelements(cur)/n); + // Preserve the head dim through the matmul so SPLIT_MODE_TENSOR's split-axis + // inference can track a head-axis split. Collapsing heads and tokens together + // (reshape_2d) drops that information and trips the meta backend. + res = ggml_reshape_4d(ctx, cur, n, cur->ne[0]/n, cur->ne[1], cur->ne[2]*cur->ne[3]); res = ggml_mul_mat (ctx, rot, res); ggml_mul_mat_set_hint(res, GGML_HINT_SRC0_IS_HADAMARD); res = ggml_reshape_4d(ctx, res, cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]); diff --git a/src/llama-model.cpp b/src/llama-model.cpp index a8323c8fb1e..712d83893b2 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -407,7 +407,63 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str return {axis, tensor_axis_0, il, rotation}; }; + auto is_qwen_recurrent_arch = [&]() { + return ud->model->arch == LLM_ARCH_QWEN3NEXT || + ud->model->arch == LLM_ARCH_QWEN35 || + ud->model->arch == LLM_ARCH_QWEN35MOE; + }; + + auto get_layer_index = [&]() -> int { + if (tensor_name.substr(0, 4) == "blk.") { + const size_t length_prefix = tensor_name.find('.', 4); + GGML_ASSERT(length_prefix != std::string::npos); + return std::stoi(tensor_name.substr(4, length_prefix)); + } + if (tensor_name.substr(0, 6) == "cache_") { + const size_t layer_index_start = tensor_name.find("_l", 6); + GGML_ASSERT(layer_index_start != std::string::npos); + return std::stoi(tensor_name.substr(layer_index_start + 2)); + } + return -1; + }; + + auto is_qwen_recurrent_tensor = [&]() { + if (!is_qwen_recurrent_arch()) { + return false; + } + + const int il = get_layer_index(); + return il >= 0 && hparams.is_recurrent(il); + }; + auto get_tensor_config = [&]() -> tensor_config { + const bool qwen_recurrent_tensor = is_qwen_recurrent_tensor(); + + // Qwen recurrent tensors that meet in the DeltaNet/SSM path must use one + // canonical Q/K/V split layout. Anchor them to attn_qkv.weight so all of + // them use the same rotation and quantization block granularity. + if (qwen_recurrent_tensor) { + if (std::regex_match(tensor_name, pattern_ssm_dt) || std::regex_match(tensor_name, pattern_ssm_a)) { + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_0, "attn_qkv.weight"); + } + if (std::regex_match(tensor_name, pattern_ssm_alpha) || std::regex_match(tensor_name, pattern_ssm_beta) || + std::regex_match(tensor_name, pattern_ssm_beta_alpha)) { + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_1, "attn_qkv.weight"); + } + if (std::regex_match(tensor_name, pattern_r_cache) || std::regex_match(tensor_name, pattern_s_cache)) { + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_0, "attn_qkv.weight"); + } + if (std::regex_match(tensor_name, pattern_ssm_conv1d)) { + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_1, "attn_qkv.weight"); + } + if (std::regex_match(tensor_name, pattern_ssm_out_weight)) { + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_0, "attn_qkv.weight"); + } + if (std::regex_match(tensor_name, pattern_attn_gate_weight)) { + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_1, "attn_qkv.weight"); + } + } + // standard attention if (std::regex_match(tensor_name, pattern_q_weight) || std::regex_match(tensor_name, pattern_kv_weight)) { return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_1, "attn_output.weight");