Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions conversion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
"DeepseekForCausalLM": "deepseek",
"DeepseekV2ForCausalLM": "deepseek",
"DeepseekV3ForCausalLM": "deepseek",
"DeepseekV32ForCausalLM": "deepseek",
"DistilBertForMaskedLM": "bert",
"DistilBertForSequenceClassification": "bert",
"DistilBertModel": "bert",
Expand Down
2 changes: 2 additions & 0 deletions conversion/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -844,6 +844,8 @@ def prepare_tensors(self):
gguf.MODEL_TENSOR.SSM_CONV1D_Q,
gguf.MODEL_TENSOR.SSM_CONV1D_K,
gguf.MODEL_TENSOR.SSM_CONV1D_V,
# DSA indexer weights should be F32
gguf.MODEL_TENSOR.INDEXER_PROJ,
)
)
or new_name[-7:] not in (".weight", ".lora_a", ".lora_b")
Expand Down
29 changes: 29 additions & 0 deletions conversion/deepseek.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,3 +386,32 @@ def prepare_tensors(self):
experts = [k for d in self._experts for k in d.keys()]
if len(experts) > 0:
raise ValueError(f"Unprocessed experts: {experts}")


@ModelBase.register("DeepseekV32ForCausalLM")
class DeepseekV32Model(DeepseekV2Model):
    """Converter for checkpoints whose architecture is ``DeepseekV32ForCausalLM``.

    Builds on :class:`DeepseekV2Model` and additionally:

    * counts the NextN/MTP prediction layers as regular blocks when building
      the tensor name map,
    * requires the tokenizer to have ``add_bos_token`` enabled,
    * writes the NextN layer count and the DSA indexer hyper-parameters
      into the GGUF metadata.
    """

    model_arch = gguf.MODEL_ARCH.DEEPSEEK32
    # NOTE(review): presumably tells the base class to keep the MTP tensors
    # instead of dropping them - confirm against DeepseekV2Model.
    skip_mtp = False

    def __init__(self, *args, **kwargs):
        """Initialise the base converter, then extend the block count.

        The tensor map is rebuilt so that it also covers the extra
        ``num_nextn_predict_layers`` blocks (0 when the key is absent).
        """
        super().__init__(*args, **kwargs)

        n_hidden_layers = self.hparams["num_hidden_layers"]
        n_nextn_layers = self.hparams.get("num_nextn_predict_layers", 0)

        self.block_count = n_hidden_layers + n_nextn_layers
        self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)

    def set_vocab(self):
        """Write the vocabulary via the GPT-2 style path.

        The Hugging Face tokenizer is loaded only to verify that
        ``add_bos_token`` is truthy; conversion is refused otherwise.
        """
        from transformers import AutoTokenizer

        hf_tokenizer = AutoTokenizer.from_pretrained(self.dir_model)
        bos_enabled = getattr(hf_tokenizer, "add_bos_token", False)
        assert bos_enabled, "Change value of add_bos_token to true in tokenizer_config.json file."

        self._set_vocab_gpt2()

    def set_gguf_parameters(self):
        """Emit the base metadata plus NextN and DSA indexer parameters."""
        super().set_gguf_parameters()

        # NextN/MTP prediction layers - only written when the config has the key
        nextn_layers = self.hparams.get("num_nextn_predict_layers")
        if nextn_layers is not None:
            self.gguf_writer.add_nextn_predict_layers(nextn_layers)

        # DSA indexer parameters - all three keys are mandatory (KeyError if missing)
        hparams = self.hparams
        self.gguf_writer.add_indexer_head_count(hparams["index_n_heads"])
        self.gguf_writer.add_indexer_key_length(hparams["index_head_dim"])
        self.gguf_writer.add_indexer_top_k(hparams["index_topk"])
36 changes: 35 additions & 1 deletion ggml/src/ggml-cpu/ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2235,8 +2235,42 @@ static void ggml_compute_forward_fill_f32(const ggml_compute_params * params, gg
}
}

static void ggml_compute_forward_fill_f16(const ggml_compute_params * params, ggml_tensor * dst) {
const ggml_fp16_t c = GGML_CPU_FP32_TO_FP16(ggml_get_op_params_f32(dst, 0));

GGML_TENSOR_LOCALS(int64_t, ne, dst, ne);
GGML_TENSOR_LOCALS(size_t, nb, dst, nb);

const auto [ir0, ir1] = get_thread_range(params, dst);

for (int64_t ir = ir0; ir < ir1; ++ir) {
const int64_t i03 = ir/(ne2*ne1);
const int64_t i02 = (ir - i03*ne2*ne1)/ne1;
const int64_t i01 = (ir - i03*ne2*ne1 - i02*ne1);

ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1);

ggml_vec_set_f16(ne0, dst_ptr, c);
}
}

// Entry point for the fill op: dispatch to the kernel that matches the
// element type of dst->src[0].
//
// Supported types are F32 and F16; any other type aborts with a message that
// names the offending type.
//
// Exactly one kernel runs per call. The F32 kernel must never be invoked
// ahead of the switch: it would fill F32 tensors twice and would be applied
// to F16 tensors, which it is not written for.
void ggml_compute_forward_fill(const ggml_compute_params * params, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];

    switch (src0->type) {
        case GGML_TYPE_F32:
            {
                ggml_compute_forward_fill_f32(params, dst);
            } break;
        case GGML_TYPE_F16:
            {
                ggml_compute_forward_fill_f16(params, dst);
            } break;
        default:
            {
                GGML_ABORT("unsupported type for ggml_compute_forward_fill: %s", ggml_type_name(src0->type));
            }
    }
}

// ggml_compute_tri
Expand Down
2 changes: 1 addition & 1 deletion ggml/src/ggml.c
Original file line number Diff line number Diff line change
Expand Up @@ -5223,7 +5223,7 @@ static struct ggml_tensor * ggml_fill_impl(
struct ggml_tensor * a,
float c,
bool inplace) {
GGML_ASSERT(a->type == GGML_TYPE_F32);
GGML_ASSERT(a->type == GGML_TYPE_F32 || a->type == GGML_TYPE_F16);
GGML_ASSERT(ggml_is_contiguous(a));

struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
Expand Down
46 changes: 46 additions & 0 deletions gguf-py/gguf/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,7 @@ class MODEL_ARCH(IntEnum):
DEEPSEEK = auto()
DEEPSEEK2 = auto()
DEEPSEEK2OCR = auto()
DEEPSEEK32 = auto()
CHATGLM = auto()
GLM4 = auto()
GLM4_MOE = auto()
Expand Down Expand Up @@ -966,6 +967,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_ARCH.DEEPSEEK: "deepseek",
MODEL_ARCH.DEEPSEEK2: "deepseek2",
MODEL_ARCH.DEEPSEEK2OCR: "deepseek2-ocr",
MODEL_ARCH.DEEPSEEK32: "deepseek32",
MODEL_ARCH.CHATGLM: "chatglm",
MODEL_ARCH.GLM4: "glm4",
MODEL_ARCH.GLM4_MOE: "glm4moe",
Expand Down Expand Up @@ -2928,6 +2930,46 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.FFN_UP_SHEXP,
MODEL_TENSOR.FFN_EXP_PROBS_B,
],
MODEL_ARCH.DEEPSEEK32: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_Q_A,
MODEL_TENSOR.ATTN_Q_B,
MODEL_TENSOR.ATTN_KV_A_MQA,
MODEL_TENSOR.ATTN_K_B,
MODEL_TENSOR.ATTN_V_B,
MODEL_TENSOR.ATTN_Q_A_NORM,
MODEL_TENSOR.ATTN_KV_A_NORM,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.ATTN_ROT_EMBD,
MODEL_TENSOR.FFN_GATE_INP,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.FFN_GATE_EXP,
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
MODEL_TENSOR.FFN_GATE_SHEXP,
MODEL_TENSOR.FFN_DOWN_SHEXP,
MODEL_TENSOR.FFN_UP_SHEXP,
MODEL_TENSOR.FFN_EXP_PROBS_B,
MODEL_TENSOR.INDEXER_K_NORM,
MODEL_TENSOR.INDEXER_PROJ,
MODEL_TENSOR.INDEXER_ATTN_K,
MODEL_TENSOR.INDEXER_ATTN_Q_B,
# NextN/MTP tensors - preserved but unused
MODEL_TENSOR.NEXTN_EH_PROJ,
MODEL_TENSOR.NEXTN_EMBED_TOKENS,
MODEL_TENSOR.NEXTN_ENORM,
MODEL_TENSOR.NEXTN_HNORM,
MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD,
MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM,
],
MODEL_ARCH.ERNIE4_5_MOE: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
Expand Down Expand Up @@ -4062,6 +4104,10 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.ATTN_ROT_EMBD,
],
MODEL_ARCH.DEEPSEEK32: [
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.ATTN_ROT_EMBD,
],
MODEL_ARCH.CHATGLM: [
MODEL_TENSOR.ROPE_FREQS,
],
Expand Down
1 change: 1 addition & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ add_library(llama
llama-io.cpp
llama-kv-cache.cpp
llama-kv-cache-iswa.cpp
llama-kv-cache-dsa.cpp
llama-memory.cpp
llama-memory-hybrid.cpp
llama-memory-hybrid-iswa.cpp
Expand Down
2 changes: 2 additions & 0 deletions src/llama-arch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_DEEPSEEK, "deepseek" },
{ LLM_ARCH_DEEPSEEK2, "deepseek2" },
{ LLM_ARCH_DEEPSEEK2OCR, "deepseek2-ocr" },
{ LLM_ARCH_DEEPSEEK32, "deepseek32" },
{ LLM_ARCH_CHATGLM, "chatglm" },
{ LLM_ARCH_GLM4, "glm4" },
{ LLM_ARCH_GLM4_MOE, "glm4moe" },
Expand Down Expand Up @@ -902,6 +903,7 @@ bool llm_arch_supports_sm_tensor(const llm_arch & arch) {
case LLM_ARCH_OLMO2:
case LLM_ARCH_OLMOE:
case LLM_ARCH_DEEPSEEK2:
case LLM_ARCH_DEEPSEEK32:
case LLM_ARCH_GLM_DSA:
case LLM_ARCH_BITNET:
case LLM_ARCH_T5:
Expand Down
1 change: 1 addition & 0 deletions src/llama-arch.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ enum llm_arch {
LLM_ARCH_DEEPSEEK,
LLM_ARCH_DEEPSEEK2,
LLM_ARCH_DEEPSEEK2OCR,
LLM_ARCH_DEEPSEEK32,
LLM_ARCH_CHATGLM,
LLM_ARCH_GLM4,
LLM_ARCH_GLM4_MOE,
Expand Down
128 changes: 128 additions & 0 deletions src/llama-graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

#include "llama-kv-cache.h"
#include "llama-kv-cache-iswa.h"
#include "llama-kv-cache-dsa.h"
#include "llama-memory-hybrid.h"
#include "llama-memory-hybrid-iswa.h"
#include "llama-memory-recurrent.h"
Expand Down Expand Up @@ -499,6 +500,34 @@ bool llm_graph_input_attn_k::can_reuse(const llm_graph_params & params) {
return res;
}

// Populate the graph input tensors of the DSA attention input for `ubatch`.
//
// The DSA memory context exposes two sub-contexts, reached through get_mla()
// and get_lid(). For each of them the K index tensor and the KQ mask are
// filled; the "lid" one additionally receives its K rotation input.
void llm_graph_input_attn_k_dsa::set_input(const llama_ubatch * ubatch) {
    // inputs backed by the context returned from get_mla()
    {
        auto * mla = mctx->get_mla();

        mla->set_input_k_idxs (self_k_idxs_mla,  ubatch);
        mla->set_input_kq_mask(self_kq_mask_mla, ubatch, cparams.causal_attn);
    }

    // inputs backed by the context returned from get_lid()
    {
        auto * lid = mctx->get_lid();

        lid->set_input_k_idxs (self_k_idxs_lid,  ubatch);
        lid->set_input_kq_mask(self_kq_mask_lid, ubatch, cparams.causal_attn);
        lid->set_input_k_rot  (self_k_rot_lid);
    }
}

// Decide whether this input object can be reused for the graph described by
// `params`.
//
// Side effect: the stored memory-context pointer is always rebound to
// params.mctx (cast to the DSA context type), whatever the result.
//
// Reuse is possible only when both K index tensors match the token count of
// the new ubatch and both KQ masks pass can_reuse_kq_mask(). All four checks
// are evaluated unconditionally.
bool llm_graph_input_attn_k_dsa::can_reuse(const llm_graph_params & params) {
    const auto * dsa_ctx = static_cast<const llama_kv_cache_dsa_context *>(params.mctx);

    this->mctx = dsa_ctx;

    const auto n_tokens = params.ubatch.n_tokens;

    const bool idxs_mla_ok = self_k_idxs_mla->ne[0] == n_tokens;
    const bool idxs_lid_ok = self_k_idxs_lid->ne[0] == n_tokens;

    const bool mask_mla_ok = can_reuse_kq_mask(self_kq_mask_mla, dsa_ctx->get_mla(), params.ubatch, params.cparams);
    const bool mask_lid_ok = can_reuse_kq_mask(self_kq_mask_lid, dsa_ctx->get_lid(), params.ubatch, params.cparams);

    return idxs_mla_ok && idxs_lid_ok && mask_mla_ok && mask_lid_ok;
}

void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) {
// base tensors may not be allocated if there are no non-SWA attention layers
if (self_k_idxs && self_k_idxs->buffer) {
Expand Down Expand Up @@ -2354,6 +2383,81 @@ ggml_tensor * llm_graph_context::build_attn(
return cur;
}

// Attention over the DSA K-only cache, restricted to the `top_k` positions.
//
// Steps:
//   1. add Q/V/K to the graph and store K into the cache returned by get_mla()
//   2. build a mask that is -INFINITY everywhere except at the indices listed
//      in `top_k`, where it is 0
//   3. add that mask to the regular KQ mask
//   4. run build_attn_mha() with V taken as a view into the cached K
//   5. optionally apply the output projection `wo` and the bias `wo_b`
//
// Returns the resulting tensor.
ggml_tensor * llm_graph_context::build_attn(
        llm_graph_input_attn_k_dsa * inp,
        ggml_tensor * wo,
        ggml_tensor * wo_b,
        ggml_tensor * q_cur,
        ggml_tensor * k_cur,
        ggml_tensor * v_cur,
        ggml_tensor * kq_b,
        ggml_tensor * sinks,
        ggml_tensor * v_mla,
        ggml_tensor * top_k,
        float         kq_scale,
        int           il) const {
    // Q, V and K are expanded back-to-back so that they are not reordered,
    // which reduces the number of graph splits. K goes last to enable rope
    // fusion that writes directly into the K-V cache.
    ggml_build_forward_expand(gf, q_cur);
    ggml_build_forward_expand(gf, v_cur);
    ggml_build_forward_expand(gf, k_cur);

    const auto * mla_ctx = inp->mctx->get_mla();

    // write the current K into the cache
    {
        const auto & k_idxs = inp->get_k_idxs_mla();

        ggml_tensor * k_store = mla_ctx->cpy_k(ctx0, k_cur, k_idxs, il);
        ggml_build_forward_expand(gf, k_store);
    }

    const auto & base_mask = inp->get_kq_mask_mla();

    // fresh mask with the shape of the KQ mask, everything masked out
    ggml_tensor * all_masked = ggml_fill(ctx0, base_mask, -INFINITY);

    // view it as rows of a single element so that individual entries can be
    // addressed by ggml_set_rows:
    // [n_kv, n_batch, 1, n_stream] -> [1, n_kv, n_batch, n_stream]
    ggml_tensor * all_masked_rows = ggml_view_4d(ctx0, all_masked,
            1, all_masked->ne[0], all_masked->ne[1], all_masked->ne[3],
            all_masked->nb[0], all_masked->nb[1], all_masked->nb[2],
            0);

    // top-k indices: [n_top_k, n_batch, 1, n_stream] -> [n_top_k, n_batch, n_stream, 1]
    ggml_tensor * top_k_idxs = ggml_view_4d(ctx0, top_k,
            top_k->ne[0], top_k->ne[1], top_k->ne[3], 1,
            top_k->nb[1], top_k->nb[2], top_k->ne[3]*top_k->nb[3],
            0);

    // source of the zero values written at the top-k positions:
    // [1, n_top_k, n_batch, n_stream]
    ggml_tensor * zero_rows = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32,
            1, top_k_idxs->ne[0], top_k_idxs->ne[1], top_k_idxs->ne[2]);
    zero_rows = ggml_fill(ctx0, zero_rows, 0.0f);

    // unmask the selected entries:
    // ggml_set_rows([1, n_kv, n_batch, n_stream], [1, n_top_k, n_batch, n_stream], [n_top_k, n_batch, n_stream, 1])
    ggml_tensor * unmasked_rows = ggml_set_rows(ctx0, all_masked_rows, zero_rows, top_k_idxs);

    // back to the layout of the KQ mask:
    // [1, n_kv, n_batch, n_stream] -> [n_kv, n_batch, 1, n_stream]
    ggml_tensor * top_k_mask = ggml_view_4d(ctx0, unmasked_rows,
            unmasked_rows->ne[1], unmasked_rows->ne[2], 1, unmasked_rows->ne[3],
            unmasked_rows->nb[2], unmasked_rows->nb[3], unmasked_rows->nb[3],
            0);

    // merge with the regular KQ mask
    ggml_tensor * final_mask = ggml_add(ctx0, top_k_mask, base_mask);

    ggml_tensor * k_cache = mla_ctx->get_k(ctx0, il);

    // V is a view into the cached K: the first v_cur->ne[0] elements of each row
    ggml_tensor * v_view = ggml_view_4d(ctx0, k_cache,
            v_cur->ne[0], k_cache->ne[1], k_cache->ne[2], k_cache->ne[3],
            k_cache->nb[1], k_cache->nb[2], k_cache->nb[3],
            0);

    ggml_tensor * out = build_attn_mha(q_cur, k_cache, v_view, kq_b, final_mask, sinks, v_mla, kq_scale, il);
    cb(out, "kqv_out", il);

    // optional output projection
    if (wo) {
        out = build_lora_mm(wo, out);
    }

    // optional output bias
    if (wo_b) {
        out = ggml_add(ctx0, out, wo_b);
    }

    return out;
}

ggml_tensor * llm_graph_context::build_attn(
llm_graph_input_attn_kv_iswa * inp,
ggml_tensor * wo,
Expand Down Expand Up @@ -2497,6 +2601,30 @@ ggml_tensor * llm_graph_context::build_attn(
return cur;
}

// Create the attention input object for the DSA K-only cache and register it
// with the graph result.
//
// For each of the two sub-contexts (get_mla() / get_lid()) this builds the K
// index input and the KQ mask, plus a "_cnv" variant of the mask that is cast
// to F16 when flash attention is enabled and aliases the mask otherwise.
// The get_lid() side additionally gets a K rotation input.
//
// Returns the pointer handed back by res->add_input(), cast to the concrete
// input type.
llm_graph_input_attn_k_dsa * llm_graph_context::build_attn_inp_k_dsa() const {
    const auto * dsa_ctx = static_cast<const llama_kv_cache_dsa_context *>(mctx);

    auto input = std::make_unique<llm_graph_input_attn_k_dsa>(hparams, cparams, dsa_ctx);

    // inputs backed by get_mla()
    {
        input->self_k_idxs_mla  = dsa_ctx->get_mla()->build_input_k_idxs(ctx0, ubatch);
        input->self_kq_mask_mla = build_attn_inp_kq_mask(ctx0, dsa_ctx->get_mla(), ubatch, cparams);

        if (cparams.flash_attn) {
            input->self_kq_mask_mla_cnv = ggml_cast(ctx0, input->self_kq_mask_mla, GGML_TYPE_F16);
        } else {
            input->self_kq_mask_mla_cnv = input->self_kq_mask_mla;
        }
    }

    // inputs backed by get_lid()
    {
        input->self_k_idxs_lid  = dsa_ctx->get_lid()->build_input_k_idxs(ctx0, ubatch);
        input->self_kq_mask_lid = build_attn_inp_kq_mask(ctx0, dsa_ctx->get_lid(), ubatch, cparams);

        if (cparams.flash_attn) {
            input->self_kq_mask_lid_cnv = ggml_cast(ctx0, input->self_kq_mask_lid, GGML_TYPE_F16);
        } else {
            input->self_kq_mask_lid_cnv = input->self_kq_mask_lid;
        }

        input->self_k_rot_lid = dsa_ctx->get_lid()->build_input_k_rot(ctx0);
    }

    // ownership moves to the graph result; keep only a typed observer pointer
    auto * registered = res->add_input(std::move(input));

    return (llm_graph_input_attn_k_dsa *) registered;
}

// TODO: maybe separate the inner implementation into a separate function
// like with the non-sliding window equivalent
// once sliding-window hybrid caches are a thing.
Expand Down
Loading
Loading