Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
169 changes: 158 additions & 11 deletions ggml/src/ggml-backend-meta.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,59 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state(
return true;
};

// Fold a multi-segment split state into a single segment: for every buffer the
// per-segment extents are summed into slot [buffer] and n_segments becomes 1.
// States that already have at most one segment are returned as an unmodified copy.
auto collapse_split_segments = [&](const ggml_backend_meta_split_state & ss) {
    ggml_backend_meta_split_state collapsed = ss;

    const bool needs_collapse = collapsed.n_segments > 1;
    if (needs_collapse) {
        // Start from a clean slate; only the first n_bufs slots are refilled below.
        memset(collapsed.ne, 0, sizeof(collapsed.ne));

        // Walk the source segment by segment; entries are laid out as [segment*n_bufs + buffer].
        for (size_t seg = 0; seg < ss.n_segments; ++seg) {
            const size_t base = seg*n_bufs;
            for (size_t buf = 0; buf < n_bufs; ++buf) {
                collapsed.ne[buf] += ss.ne[base + buf];
            }
        }

        collapsed.n_segments = 1;
    }

    return collapsed;
};

// Raised by exact_split_state() once a handler has produced per-buffer extents
// that were verified against the tensor's extent on the split axis.
bool split_state_ne_is_exact = false;

// Validate a split state whose `ne` entries are meant to be final.
// When the axis is a real tensor dimension, the extents over all segments and
// buffers must add up to tensor->ne[axis]; on success the exactness flag is set.
// States on any other axis value are passed through untouched.
auto exact_split_state = [&](ggml_backend_meta_split_state ret) {
    const bool on_tensor_dim = ret.axis >= 0 && ret.axis < GGML_MAX_DIMS;
    if (!on_tensor_dim) {
        return ret;
    }

    // Entries are stored contiguously as [segment*n_bufs + buffer], so a flat
    // walk over n_segments*n_bufs slots visits exactly the same elements.
    int64_t ne_sum = 0;
    const size_t n_slots = ret.n_segments*n_bufs;
    for (size_t slot = 0; slot < n_slots; ++slot) {
        ne_sum += ret.ne[slot];
    }

    GGML_ASSERT(ne_sum == tensor->ne[int(ret.axis)]);
    split_state_ne_is_exact = true;

    return ret;
};

// Sum the extents assigned to buffer `j` across every segment of `ss`.
auto split_total = [&](const ggml_backend_meta_split_state & ss, size_t j) {
    int64_t acc = 0;
    size_t  seg = 0;
    while (seg < ss.n_segments) {
        acc += ss.ne[seg*n_bufs + j];
        ++seg;
    }
    return acc;
};

// True when both split states give every buffer the same total extent,
// irrespective of how that extent is distributed over segments.
auto split_totals_equal = [&](const ggml_backend_meta_split_state & a, const ggml_backend_meta_split_state & b) {
    size_t j = 0;
    // Advance while the per-buffer totals agree; stop at the first mismatch.
    while (j < n_bufs && split_total(a, j) == split_total(b, j)) {
        ++j;
    }
    // Reaching n_bufs means no buffer disagreed.
    return j == n_bufs;
};

auto handle_generic = [&](const std::vector<ggml_backend_meta_split_state> & src_ss, bool scalar_only) -> ggml_backend_meta_split_state {
ggml_backend_meta_split_state ret = {GGML_BACKEND_SPLIT_AXIS_NONE, {0}, 1};
for (size_t i = 0; i < GGML_MAX_SRC; i++) {
Expand Down Expand Up @@ -539,6 +592,14 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state(
return src_ss[0];
};

// Unary ops whose output keeps the input's extent on the split axis can reuse
// the input's split state verbatim (validated via exact_split_state).
// Anything else is delegated to the generic handler.
auto handle_shape_preserving_unary = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state {
    const ggml_backend_meta_split_state & in = src_ss[0];

    // The range check must come first: it guards the ne[] indexing below.
    const bool on_tensor_dim = in.axis >= 0 && in.axis < GGML_MAX_DIMS;
    if (!on_tensor_dim || tensor->ne[in.axis] != tensor->src[0]->ne[in.axis]) {
        return handle_generic(src_ss, /*scalar_only =*/ false);
    }

    return exact_split_state(in);
};

// Some ops broadcast the src1 data across src0:
auto handle_bin_bcast = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state {
if (src_ss[0].axis >= 0 && src_ss[0].axis < GGML_MAX_DIMS &&
Expand All @@ -563,7 +624,10 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state(
GGML_ASSERT(concat_axis != src_ss[0].axis);
return src_ss[0];
}
if (src_ss[0].axis == src_ss[1].axis && src_ss[0].axis != concat_axis) {
// If concat is performed on a non-split axis, both inputs must have the
// same local extent on every backend for the split axis. Matching axes
// alone is not sufficient.
if (split_states_equal(src_ss[0], src_ss[1]) && src_ss[0].axis != concat_axis) {
return src_ss[0];
}
return handle_generic(src_ss, /*scalar_only =*/ true);
Expand All @@ -574,15 +638,19 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state(
return {GGML_BACKEND_SPLIT_AXIS_MIRRORED, {0}, 1};
}
if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_1 && src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) {
ggml_backend_meta_split_state ret = src_ss[0];
ggml_backend_meta_split_state ret = collapse_split_segments(src_ss[0]);
ret.axis = GGML_BACKEND_SPLIT_AXIS_0;
ret.n_segments = 1;
return ret;
}
if (src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_1 && src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) {
ggml_backend_meta_split_state ret = src_ss[1];
ret.n_segments = 1;
return ret;
return collapse_split_segments(src_ss[1]);
}
// A mirrored, B split on a pass-through dim (dims >= 2 of B are carried to the output unchanged).
// This covers the PR #21038 rotation matmul where activations are reshaped to preserve the head
// axis and the head dim ends up as dim 2 or 3 of the 4D input.
if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED &&
(src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_2 || src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_3)) {
return collapse_split_segments(src_ss[1]);
}
if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_0 && src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_0) {
GGML_ASSERT(split_states_equal(src_ss[0], src_ss[1]));
Expand Down Expand Up @@ -659,6 +727,73 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state(
};

auto handle_view = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state {
// Compute the split state of a strided view of tensor->src[0].
//   view_axis: dimension of the view (`tensor`) that the result is split on.
//   src_axis:  dimension of the source tensor that src_ss[0] is split on.
// The returned state always carries `view_axis` as its axis. Its per-buffer
// extents are filled in only when they can be derived exactly from the source
// split; in every other case they are left (or reset to) zero.
auto split_view_state = [&](int view_axis, int src_axis) -> ggml_backend_meta_split_state {
// Default result: axis set, extents zeroed (third initializer is presumably n_segments = 1).
ggml_backend_meta_split_state ret = {ggml_backend_meta_split_axis(view_axis), {0}, 1};

// Either axis not being a real tensor dimension: nothing to derive.
if (src_axis < 0 || src_axis >= GGML_MAX_DIMS || view_axis < 0 || view_axis >= GGML_MAX_DIMS) {
return ret;
}

const ggml_tensor * src = tensor->src[0];

// The view offset must be a whole number of source-axis strides (also avoids dividing by zero).
if (src->nb[src_axis] == 0 || tensor->view_offs % src->nb[src_axis] != 0) {
return ret;
}

// The view's stride on view_axis must be a whole multiple of the source stride.
if (tensor->nb[view_axis] % src->nb[src_axis] != 0) {
return ret;
}

// Step of the view along the source axis, measured in source elements.
const int64_t view_stride_src = tensor->nb[view_axis] / src->nb[src_axis];

if (view_stride_src <= 0) {
return ret;
}

// Half-open range [view_begin, view_end) covered by the view along the source axis.
// NOTE(review): for view_stride_src > 1 view_end extends past the last touched
// element by up to view_stride_src - 1; confirm this is the intended bound.
const int64_t view_begin = (tensor->view_offs / src->nb[src_axis]) % src->ne[src_axis];
const int64_t view_end = view_begin + tensor->ne[view_axis] * view_stride_src;

if (view_begin < 0 || view_end > src->ne[src_axis]) {
return ret;
}

// Walk the source layout: segments follow each other along the axis, and within
// a segment the buffers' slices follow each other in buffer order.
int64_t segment_begin = 0;
for (size_t s = 0; s < src_ss[0].n_segments; ++s) {
// Total extent of this segment over all buffers.
int64_t segment_size = 0;
for (size_t j = 0; j < n_bufs; ++j) {
segment_size += src_ss[0].ne[s*n_bufs + j];
}

int64_t backend_begin = segment_begin;
for (size_t j = 0; j < n_bufs; ++j) {
const int64_t backend_end = backend_begin + src_ss[0].ne[s*n_bufs + j];
// Length of the intersection of the view range with this buffer's slice.
const int64_t overlap =
std::max<int64_t>(0, std::min(view_end, backend_end) - std::max(view_begin, backend_begin));

if (overlap > 0) {
// Aborts if a buffer boundary does not fall on a multiple of the view stride.
GGML_ASSERT(overlap % view_stride_src == 0);
// Convert source elements to view elements and accumulate per buffer
// (all segments are merged into the single output segment).
ret.ne[j] += overlap / view_stride_src;
}

backend_begin = backend_end;
}

segment_begin += segment_size;
}

// Accept the result only if the per-buffer extents account for the whole view axis.
int64_t ret_ne_sum = 0;
for (size_t j = 0; j < n_bufs; ++j) {
ret_ne_sum += ret.ne[j];
}

if (ret_ne_sum == tensor->ne[view_axis]) {
return exact_split_state(ret);
}

// Otherwise discard the partial extents and return the axis-only state.
memset(ret.ne, 0, sizeof(ret.ne));
return ret;
};

if (ggml_is_contiguous(tensor) && ggml_is_contiguous(tensor->src[0])) {
return handle_reshape(src_ss);
}
Expand All @@ -681,7 +816,7 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state(
if (!ggml_is_permuted(tensor) && !ggml_is_permuted(tensor->src[0]) && axis >= 0 && axis < GGML_MAX_DIMS-1) {
for (int dim = 0; dim < GGML_MAX_DIMS-1; dim++) {
if (tensor->nb[dim+1] == tensor->src[0]->nb[axis+1]) {
return {ggml_backend_meta_split_axis(dim), {0}, 1};
return split_view_state(dim, axis);
}
}
GGML_ABORT("fatal error");
Expand Down Expand Up @@ -769,11 +904,21 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state(

auto handle_ssm_conv = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state {
if (src_ss[0].axis == src_ss[1].axis) {
GGML_ASSERT(split_totals_equal(src_ss[0], src_ss[1]));

// SSM_CONV is channel-wise. Preserve the exact segmented channel
// layout (Q/K/V groups for Qwen recurrent layers) so downstream
// VIEWs of q/k/v can recover matching head splits.
ggml_backend_meta_split_state ret =
src_ss[1].n_segments >= src_ss[0].n_segments ? src_ss[1] : src_ss[0];

if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_0) {
return {GGML_BACKEND_SPLIT_AXIS_1, {0}, 1};
ret.axis = GGML_BACKEND_SPLIT_AXIS_1;
return exact_split_state(ret);
}
if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_1) {
return {GGML_BACKEND_SPLIT_AXIS_0, {0}, 1};
ret.axis = GGML_BACKEND_SPLIT_AXIS_0;
return exact_split_state(ret);
}
}
return handle_generic(src_ss, /*scalar_only =*/ false);
Expand All @@ -797,6 +942,8 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state(
};

auto calculate_split_state = [&]() -> ggml_backend_meta_split_state {
split_state_ne_is_exact = false;

if (ggml_nelements(tensor) == 0) {
return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, 1};
}
Expand Down Expand Up @@ -997,7 +1144,7 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state(
split_state = handle_gated_delta_net(src_ss);
} break;
case GGML_OP_UNARY: {
split_state = handle_generic(src_ss, /*scalar_only =*/ false);
split_state = handle_shape_preserving_unary(src_ss);
} break;
case GGML_OP_MAP_CUSTOM1:
case GGML_OP_MAP_CUSTOM2:
Expand All @@ -1019,7 +1166,7 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state(
split_state = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, 1};
} break;
}
if (split_state.axis >= 0 && split_state.axis < GGML_MAX_DIMS) {
if (!split_state_ne_is_exact && split_state.axis >= 0 && split_state.axis < GGML_MAX_DIMS) {
bool first_src_split_by_axis = true;
const size_t n_bufs = ggml_backend_meta_buffer_n_bufs(tensor->buffer);

Expand Down
4 changes: 0 additions & 4 deletions src/llama-context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3403,10 +3403,6 @@ llama_context * llama_init_from_model(
LLAMA_LOG_ERROR("%s: SPLIT_MODE_TENSOR requires flash_attn to be enabled\n", __func__);
return nullptr;
}
if (ggml_is_quantized(params.type_k) || ggml_is_quantized(params.type_v)) {
LLAMA_LOG_ERROR("%s: simultaneous use of SPLIT_MODE_TENSOR and KV cache quantization not implemented\n", __func__);
return nullptr;
}
}

if (params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED && ggml_is_quantized(params.type_k)) {
Expand Down
8 changes: 5 additions & 3 deletions src/llama-graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,12 +68,14 @@ static ggml_tensor * ggml_mul_mat_aux(
const auto n = rot->ne[0];

ggml_tensor * res;
GGML_ASSERT(cur->ne[0] % n == 0);

if (!ggml_is_contiguous(cur)) {
res = ggml_cont_2d (ctx, cur, n, ggml_nelements(cur)/n);
} else {
res = ggml_reshape_2d(ctx, cur, n, ggml_nelements(cur)/n);
cur = ggml_cont(ctx, cur);
}
// Keep the head axis visible through the rotation matmul so the meta backend
// can propagate tensor-parallel splits on the head/batch dimensions.
res = ggml_reshape_4d(ctx, cur, n, cur->ne[0]/n, cur->ne[1], cur->ne[2]*cur->ne[3]);
res = ggml_mul_mat (ctx, rot, res);
ggml_mul_mat_set_hint(res, GGML_HINT_SRC0_IS_HADAMARD);
res = ggml_reshape_4d(ctx, res, cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]);
Expand Down
6 changes: 5 additions & 1 deletion src/llama-kv-cache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,14 @@ static ggml_tensor * ggml_mul_mat_aux(
ggml_tensor * cur,
ggml_tensor * rot) {
const auto n = rot->ne[0];
GGML_ASSERT(cur->ne[0] % n == 0);

ggml_tensor * res;

res = ggml_reshape_2d(ctx, cur, n, ggml_nelements(cur)/n);
// Preserve the head dim through the matmul so SPLIT_MODE_TENSOR's split-axis
// inference can track a head-axis split. Collapsing heads and tokens together
// (reshape_2d) drops that information and trips the meta backend.
res = ggml_reshape_4d(ctx, cur, n, cur->ne[0]/n, cur->ne[1], cur->ne[2]*cur->ne[3]);
res = ggml_mul_mat (ctx, rot, res);
ggml_mul_mat_set_hint(res, GGML_HINT_SRC0_IS_HADAMARD);
res = ggml_reshape_4d(ctx, res, cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]);
Expand Down
56 changes: 56 additions & 0 deletions src/llama-model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,63 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str
return {axis, tensor_axis_0, il, rotation};
};

// True when the model architecture is one of the Qwen variants listed below.
auto is_qwen_recurrent_arch = [&]() {
    switch (ud->model->arch) {
        case LLM_ARCH_QWEN3NEXT:
        case LLM_ARCH_QWEN35:
        case LLM_ARCH_QWEN35MOE:
            return true;
        default:
            return false;
    }
};

// Extract the layer index encoded in `tensor_name`.
//   "blk.<N>.<rest>"      -> N
//   "cache_<...>_l<N>..." -> N (digits following the first "_l" at or after offset 6)
// Returns -1 for names that match neither prefix. Aborts via GGML_ASSERT when a
// recognised prefix is not followed by the expected separator.
// Note: std::stoi throws if the extracted text does not start with a number.
auto get_layer_index = [&]() -> int {
    if (tensor_name.compare(0, 4, "blk.") == 0) {
        const size_t length_prefix = tensor_name.find('.', 4);
        GGML_ASSERT(length_prefix != std::string::npos);
        // substr takes (pos, count): pass the length of the index field, not its end position.
        return std::stoi(tensor_name.substr(4, length_prefix - 4));
    }
    if (tensor_name.compare(0, 6, "cache_") == 0) {
        const size_t layer_index_start = tensor_name.find("_l", 6);
        GGML_ASSERT(layer_index_start != std::string::npos);
        // Skip the two characters of "_l"; stoi consumes the leading digits.
        return std::stoi(tensor_name.substr(layer_index_start + 2));
    }
    return -1;
};

// True when the architecture passes is_qwen_recurrent_arch() and the tensor's
// name yields a layer index for which hparams.is_recurrent() holds.
auto is_qwen_recurrent_tensor = [&]() {
    bool recurrent = false;

    // Check the architecture first so names are only parsed for matching models.
    if (is_qwen_recurrent_arch()) {
        const int layer = get_layer_index();
        if (layer >= 0) {
            recurrent = hparams.is_recurrent(layer);
        }
    }

    return recurrent;
};

auto get_tensor_config = [&]() -> tensor_config {
const bool qwen_recurrent_tensor = is_qwen_recurrent_tensor();

// Qwen recurrent tensors that meet in the DeltaNet/SSM path must use one
// canonical Q/K/V split layout. Anchor them to attn_qkv.weight so all of
// them use the same rotation and quantization block granularity.
if (qwen_recurrent_tensor) {
if (std::regex_match(tensor_name, pattern_ssm_dt) || std::regex_match(tensor_name, pattern_ssm_a)) {
return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_0, "attn_qkv.weight");
}
if (std::regex_match(tensor_name, pattern_ssm_alpha) || std::regex_match(tensor_name, pattern_ssm_beta) ||
std::regex_match(tensor_name, pattern_ssm_beta_alpha)) {
return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_1, "attn_qkv.weight");
}
if (std::regex_match(tensor_name, pattern_r_cache) || std::regex_match(tensor_name, pattern_s_cache)) {
return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_0, "attn_qkv.weight");
}
if (std::regex_match(tensor_name, pattern_ssm_conv1d)) {
return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_1, "attn_qkv.weight");
}
if (std::regex_match(tensor_name, pattern_ssm_out_weight)) {
return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_0, "attn_qkv.weight");
}
if (std::regex_match(tensor_name, pattern_attn_gate_weight)) {
return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_1, "attn_qkv.weight");
}
}

// standard attention
if (std::regex_match(tensor_name, pattern_q_weight) || std::regex_match(tensor_name, pattern_kv_weight)) {
return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_1, "attn_output.weight");
Expand Down