12 changes: 2 additions & 10 deletions src/llama-model.cpp
@@ -13928,7 +13928,6 @@ struct llm_build_deepseek3_2 : public llm_graph_context {
ctx0, model, il, cur, n_tokens, mctx_cur2, inp_attn->get_k_idxs(), KQmask2, top_k,
inp_pos, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow,
cb_wrapper, gf, sched, backend_cpu);
ggml_backend_sched_set_tensor_backend(sched, kvaware_indices, backend_cpu);
cur = llama::sparse_mla_fwd::apply_sparse_attention_kvaware(
ctx0, Qcur, Kcache, Vcache, kvaware_indices, n_tokens, top_k, kq_scale, KQmask2, hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f, cb_wrapper);
// Sanity checks for MLA sparse attention output vs expected V-dim (kv_lora_rank)
@@ -13946,10 +13945,7 @@ struct llm_build_deepseek3_2 : public llm_graph_context {
}
}

if (!cparams.offload_kqv) {
/* sparse: follow dense behavior by placing on CPU when not offloading */
ggml_backend_sched_set_tensor_backend(sched, cur, backend_cpu);
}
/* keep sparse attention output on device to avoid backend hops */

// Project kv_lora_rank -> n_embd_head_v per head using wv_b and flatten heads before WO
ggml_tensor * cur_perm = ggml_permute(ctx0, cur, 0, 2, 1, 3); // [kv_lora_rank, n_tokens, n_head]
@@ -14039,7 +14035,6 @@ struct llm_build_deepseek3_2 : public llm_graph_context {
ctx0, model, il, cur, n_tokens, mctx_cur2, inp_attn->get_k_idxs(), KQmask2, top_k,
inp_pos, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow,
cb_wrapper, gf, sched, backend_cpu);
ggml_backend_sched_set_tensor_backend(sched, kvaware_indices, backend_cpu);
cur = llama::sparse_mla_fwd::apply_sparse_attention_kvaware(
ctx0, Qcur, Kcache, Vcache, kvaware_indices, n_tokens, top_k, kq_scale, KQmask2, hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f, cb_wrapper);
// Sanity checks for MHA sparse attention output vs expected V-dim (n_embd_head_v)
@@ -14049,10 +14044,7 @@ struct llm_build_deepseek3_2 : public llm_graph_context {
GGML_ASSERT(cur->ne[0] == (int64_t) n_embd_head_v);
}

if (!cparams.offload_kqv) {
/* sparse: follow dense behavior by placing on CPU when not offloading */
ggml_backend_sched_set_tensor_backend(sched, cur, backend_cpu);
}
/* keep sparse attention output on device to avoid backend hops */

// Flatten heads before WO
ggml_tensor * cur_perm2 = ggml_permute(ctx0, cur, 0, 2, 1, 3); // [n_embd_head_v, n_tokens, n_head]
237 changes: 208 additions & 29 deletions src/llama-sparse-topk.cpp
@@ -1,10 +1,189 @@
#include "llama-sparse-topk.h"
#include <algorithm>
#include <vector>
#include <cstdint>

#include "llama-impl.h"
#include <cstring>

#include <cmath>
#include <cinttypes>
#include <cstdio>
#include <cstdlib>

namespace {

static inline uint32_t float_to_key_desc(float x) {
uint32_t u;
memcpy(&u, &x, sizeof(u));
// Map float bits to unsigned keys that preserve numeric order (larger float -> larger key).
// TileLang-compatible mapping: negative -> bitwise NOT, non-negative -> set sign bit.
if ((int32_t)u < 0) {
u = ~u;
} else {
u |= 0x80000000u;
}
return u;
}
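// Illustrative check (not part of the patch): the mapping preserves numeric
// order, so plain unsigned comparison on keys matches float comparison on the
// inputs. A minimal sketch, assuming NaN-free scores:
//   assert(float_to_key_desc(-2.0f) < float_to_key_desc(-1.0f));
//   assert(float_to_key_desc(-1.0f) < float_to_key_desc( 0.0f));
//   assert(float_to_key_desc( 0.0f) < float_to_key_desc( 1.5f));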

struct radix_topk_userdata {
// currently unused; k is taken from dst->ne[0]
};

static void radix_topk_custom(ggml_tensor * dst, int ith, int nth, void * userdata) {
(void)userdata;
const char * ENV_SPARSE_DEBUG = getenv("LLAMA_SPARSE_DEBUG");
const bool dbg = (ENV_SPARSE_DEBUG && atoi(ENV_SPARSE_DEBUG) != 0);
ggml_tensor * src0 = dst->src[0];
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_I32);
const int64_t N = src0->ne[0];
const int64_t k = dst->ne[0];
const int64_t nr = ggml_nrows(src0);
const size_t src_nb1 = src0->nb[1];
const size_t src_nb0 = src0->nb[0];
const size_t dst_nb1 = dst->nb[1];

for (int64_t r = ith; r < nr; r += nth) {
const char * row0 = (const char *)src0->data + r * src_nb1;
int32_t * out_idx = (int32_t *)((char *)dst->data + r * dst_nb1);
const int64_t KK = k < N ? k : N;

// Precompute keys for this row
std::vector<uint32_t> keys(N);
for (int64_t i = 0; i < N; ++i) {
float v = *(const float *)(row0 + (size_t)i*src_nb0);
keys[i] = float_to_key_desc(v);
}

// Stage 1: histogram of high 8 bits (bits 31..24)
uint32_t counts[256] = {0};
for (int64_t i = 0; i < N; ++i) {
uint32_t bin = (keys[i] >> 24) & 0xFFu;
counts[bin]++;
}
// Find threshold bin: number of items with bin > thr0 is sum of counts above thr0
auto sum_greater = [&](int b){ uint32_t s=0; for (int bb=b+1; bb<256; ++bb) s += counts[bb]; return s; };
int thr0 = 0;
uint32_t gt = 0;
for (int b = 255; b >= 0; --b) {
uint32_t sgt = sum_greater(b);
uint32_t eq = counts[b];
if (sgt < (uint32_t)KK && sgt + eq >= (uint32_t)KK) { thr0 = b; gt = sgt; break; }
}
uint32_t eq0 = counts[thr0];
int64_t remaining = (int64_t)KK - (int64_t)gt;
if (remaining < 0) remaining = 0;
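// Hypothetical worked example (not part of the patch): with KK = 4 and
// counts[255] = 1, counts[254] = 2, counts[253] = 5, the scan stops at b = 253
// since sum_greater(253) = 3 < 4 and 3 + 5 >= 4, so thr0 = 253, gt = 3 items are
// taken unconditionally, and remaining = 1 slot is filled from the ties in
// bin 253 by the lower-byte passes below.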

// Collect selected (> thr0) and eq candidates
std::vector<int32_t> selected; selected.reserve(KK);
std::vector<int32_t> eq_list; eq_list.reserve(eq0);
for (int64_t i = 0; i < N; ++i) {
uint32_t bin = (keys[i] >> 24) & 0xFFu;
if ((int)bin > thr0) {
if ((int64_t)selected.size() < KK) selected.push_back((int32_t)i);
} else if ((int)bin == thr0) {
eq_list.push_back((int32_t)i);
}
}
remaining = (int64_t)KK - (int64_t)selected.size();

// Safety check: ensure we have enough candidates to fill K
if ((int64_t)selected.size() + (int64_t)eq_list.size() < KK) {
// Fallback: use partial_sort to guarantee correctness
std::vector<int32_t> idx(N);
for (int64_t i = 0; i < N; ++i) idx[i] = (int32_t)i;
auto cmp = [&](int32_t a, int32_t b){
float va = *(const float *)(row0 + (size_t)a*src_nb0);
float vb = *(const float *)(row0 + (size_t)b*src_nb0);
if (va != vb) return va > vb;
return a < b;
};
std::partial_sort(idx.begin(), idx.begin() + KK, idx.end(), cmp);
for (int64_t i = 0; i < KK; ++i) out_idx[i] = idx[i];
continue;
}

// Tail passes for equal bin
int shifts[3] = {16, 8, 0};
for (int pass = 0; pass < 3 && remaining > 0 && !eq_list.empty(); ++pass) {
uint32_t c2[256] = {0};
for (int idx : eq_list) {
uint32_t bin = (keys[idx] >> shifts[pass]) & 0xFFu;
c2[bin]++;
}
auto sum_greater2 = [&](int b){ uint32_t s=0; for (int bb=b+1; bb<256; ++bb) s += c2[bb]; return s; };
int thr = 255;
for (int b = 255; b >= 0; --b) {
uint32_t sgt = sum_greater2(b);
uint32_t eq = c2[b];
if (sgt < (uint32_t)remaining && sgt + eq >= (uint32_t)remaining) { thr = b; break; }
}
std::vector<int32_t> next_eq; next_eq.reserve(c2[thr]);
// Add strictly greater than thr
for (int idx : eq_list) {
uint32_t bin = (keys[idx] >> shifts[pass]) & 0xFFu;
if ((int)bin > thr) {
if ((int64_t)selected.size() < KK) { selected.push_back(idx); }
} else if ((int)bin == thr) {
next_eq.push_back(idx);
}
}
eq_list.swap(next_eq);
remaining = (int64_t)KK - (int64_t)selected.size();
if ((int64_t)selected.size() + (int64_t)eq_list.size() < KK) {
// Fallback safety: not enough candidates remain to fill K; the final fill / partial_sort below handles it
break;
}
}
// Final fill from eq_list if still remaining
for (int64_t i = 0; i < (int64_t)eq_list.size() && (int64_t)selected.size() < KK; ++i) {
selected.push_back(eq_list[i]);
}

// As a final fallback, if still not enough, use partial_sort
if ((int64_t)selected.size() < KK) {
std::vector<int32_t> idx(N);
for (int64_t i = 0; i < N; ++i) idx[i] = (int32_t)i;
auto cmp = [&](int32_t a, int32_t b){
float va = *(const float *)(row0 + (size_t)a*src_nb0);
float vb = *(const float *)(row0 + (size_t)b*src_nb0);
if (va != vb) return va > vb;
return a < b;
};
std::partial_sort(idx.begin(), idx.begin() + KK, idx.end(), cmp);
for (int64_t i = 0; i < KK; ++i) out_idx[i] = idx[i];
continue;
}

// Output first KK indices (order arbitrary)
for (int64_t i = 0; i < KK; ++i) out_idx[i] = selected[i];

// Debug: compare with partial_sort for a few rows
if (dbg && r < 8) {
std::vector<int32_t> ref(N);
for (int64_t i = 0; i < N; ++i) ref[i] = (int32_t)i;
auto cmp = [&](int32_t a, int32_t b){
float va = *(const float *)(row0 + (size_t)a*src_nb0);
float vb = *(const float *)(row0 + (size_t)b*src_nb0);
if (va != vb) return va > vb;
return a < b;
};
std::partial_sort(ref.begin(), ref.begin() + KK, ref.end(), cmp);
printf("[radix debug] row=%lld top: ", (long long)r);
for (int ii = 0; ii < (int)std::min<int64_t>(8, KK); ++ii) printf("%d ", out_idx[ii]);
printf("| ref: ");
for (int ii = 0; ii < (int)std::min<int64_t>(8, KK); ++ii) printf("%d ", ref[ii]);
printf("\n");
fflush(stdout);
}
}
}

} // anonymous namespace


namespace llama {

@@ -96,7 +275,7 @@ using std::function;
}

const int64_t k = std::min<int64_t>(top_k, N_kv);
int64_t TILE_T = 32;
int64_t TILE_T = 128; // larger default tile improves GEMM utilization; overridable via env
if (const char *env = getenv("LLAMA_SPARSE_TOPK_TILE_T")) {
long v = strtol(env, nullptr, 10);
if (v > 0 && v <= 4096) TILE_T = v;
@@ -238,25 +417,11 @@ using std::function;
ggml_build_forward_expand(gf, sft_sumsq);
}
}
// Clamp infinities from mask to large finite values to stabilize argsort/top-k
// Clamp infinities then compute top-k indices via custom CPU radix selection (no full sort)
ggml_tensor * scores_clamped = ggml_clamp(ctx, scores_for_topk, -1e30f, 1e30f);
ggml_tensor * topk_tc = ggml_top_k(ctx, scores_clamped, k);
if (dbg) {
cb(topk_tc->src[0], "idxkv_argsort", -1);
cb(topk_tc, "idxkv_topk", -1);
if (t0 == 0) {
ggml_tensor * topk_f32 = ggml_cast(ctx, topk_tc, GGML_TYPE_F32);
ggml_tensor * idxkv_topk_idx_sum = ggml_sum(ctx, topk_f32);
ggml_tensor * idxkv_topk_idx_sumsq = ggml_sum(ctx, ggml_sqr(ctx, topk_f32));
cb(idxkv_topk_idx_sum, "idxkv_topk_idx_sum", -1);
cb(idxkv_topk_idx_sumsq, "idxkv_topk_idx_sumsq", -1);
if (gf) {
ggml_set_output(idxkv_topk_idx_sum);
ggml_set_output(idxkv_topk_idx_sumsq);
ggml_build_forward_expand(gf, idxkv_topk_idx_sum);
ggml_build_forward_expand(gf, idxkv_topk_idx_sumsq);
}
}
ggml_tensor * topk_tc = sparse_attn_topk::topk_radix_indices(ctx, scores_clamped, k);
if (dbg && t0 == 0) {
cb(topk_tc, "idxkv_topk_radix", -1);
}
result = result ? ggml_concat(ctx, result, topk_tc, 1) : topk_tc;
}
@@ -267,7 +432,7 @@ using std::function;
ggml_build_forward_expand(gf, result);
}
// Also provide a float32 view for eval-callback visibility on platforms that skip integer dumps
if (result) {
if (dbg && result) {
ggml_tensor * result_f32 = ggml_cast(ctx, result, GGML_TYPE_F32);
cb(result_f32, "idxkv_topk_indices_k_T_f32", -1);
if (gf) {
@@ -280,16 +445,30 @@ using std::function;
result->ne[0], result->ne[1], result->ne[2], result->ne[3], (int)result->type);
fflush(stdout);
}
// prefer CPU backend for the final indices tensor using scheduler API
if (sched && backend_cpu) {
ggml_backend_sched_set_tensor_backend(sched, result, backend_cpu);
if (dbg) {
const char * bname2 = ggml_backend_name(backend_cpu);
printf("[TOPK] assigned backend for kvaware_topk_indices: %s (non-null=%d)\n", bname2 ? bname2 : "null", backend_cpu ? 1 : 0);
fflush(stdout);
}
}
// Keep indices on device by default to avoid host syncs during get_rows
return result;
}



ggml_tensor * llama::sparse_attn_topk::topk_radix_indices(
ggml_context * ctx,
ggml_tensor * scores, // [N, T]
int64_t k) {
GGML_ASSERT(scores->type == GGML_TYPE_F32);
ggml_tensor * args[1] = { scores };
return ggml_custom_4d(
ctx,
GGML_TYPE_I32,
/*ne0*/ k,
/*ne1*/ scores->ne[1],
/*ne2*/ 1,
/*ne3*/ 1,
args,
/*n_args*/ 1,
/*fun*/ radix_topk_custom,
/*n_tasks*/ GGML_N_TASKS_MAX,
/*userdata*/ nullptr);
}
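// Usage sketch (illustrative only, assuming an existing ggml_context and an F32
// scores tensor laid out as [N_kv, T]): the op yields a [k, T] INT32 tensor of
// per-column row indices, evaluated on CPU by radix_topk_custom above.
//   ggml_tensor * idx = llama::sparse_attn_topk::topk_radix_indices(ctx, scores, /*k=*/64);
//   // idx->type == GGML_TYPE_I32, idx->ne[0] == 64, idx->ne[1] == scores->ne[1]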

} // namespace llama
6 changes: 6 additions & 0 deletions src/llama-sparse-topk.h
@@ -23,6 +23,12 @@ struct sparse_attn_topk {
ggml_cgraph * gf,
ggml_backend_sched_t sched,
ggml_backend_t backend_cpu);

// new: compute top-k indices per column for a scores matrix [N, T]
static ggml_tensor * topk_radix_indices(
ggml_context * ctx,
ggml_tensor * scores, // [N, T]
int64_t k);
};

} // namespace llama
3 changes: 3 additions & 0 deletions tests/CMakeLists.txt
@@ -233,3 +233,6 @@ target_include_directories(test-sparse-attn PRIVATE ${PROJECT_SOURCE_DIR}/ggml/s

# Reproducer for no_alloc sparse indexer crash
llama_build_and_test(test-sparse-attn-noalloc.cpp)

# Radix top-k unit test for sparse indexer
llama_build_and_test(test-sparse-topk-radix.cpp)