25 changes: 25 additions & 0 deletions convert_hf_to_gguf.py
@@ -9948,6 +9948,31 @@ def _is_audio_tensor(self, name: str):
return any(p in name for p in ["audio", "codebook", "conformer", "depth_embedding", "depthformer", "depth_linear"])


@ModelBase.register("Lfm2Model")
class LFM2ColBertModel(LFM2Model):
model_arch = gguf.MODEL_ARCH.LFM2
dense_tensor_name = "dense_2"

def set_vocab(self):
super().set_vocab()
self.gguf_writer.add_add_bos_token(False)

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
if not name.startswith(self.dense_tensor_name):
name = "model." + name

return super().modify_tensors(data_torch, name, bid)

def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
# dense tensor is stored in a separate safetensors file
from safetensors.torch import load_file
tensors_file = self.dir_model / "1_Dense" / "model.safetensors"
assert tensors_file.is_file()
tensor = load_file(tensors_file)["linear.weight"]
self.gguf_writer.add_embedding_length_out(tensor.shape[0])
yield f"{self.dense_tensor_name}.weight", tensor.clone()


@ModelBase.register("Lfm2MoeForCausalLM")
class LFM2MoeModel(TextModel):
model_arch = gguf.MODEL_ARCH.LFM2MOE
2 changes: 2 additions & 0 deletions gguf-py/gguf/constants.py
@@ -104,6 +104,7 @@ class LLM:
VOCAB_SIZE = "{arch}.vocab_size"
CONTEXT_LENGTH = "{arch}.context_length"
EMBEDDING_LENGTH = "{arch}.embedding_length"
EMBEDDING_LENGTH_OUT = "{arch}.embedding_length_out"
FEATURES_LENGTH = "{arch}.features_length"
BLOCK_COUNT = "{arch}.block_count"
LEADING_DENSE_BLOCK_COUNT = "{arch}.leading_dense_block_count"
@@ -3038,6 +3039,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.DENSE_2_OUT, # LFM2-ColBert-350M
],
MODEL_ARCH.LFM2MOE: [
MODEL_TENSOR.TOKEN_EMBD,
3 changes: 3 additions & 0 deletions gguf-py/gguf/gguf_writer.py
@@ -681,6 +681,9 @@ def add_context_length(self, length: int) -> None:
def add_embedding_length(self, length: int) -> None:
self.add_uint32(Keys.LLM.EMBEDDING_LENGTH.format(arch=self.arch), length)

def add_embedding_length_out(self, length: int) -> None:
self.add_uint32(Keys.LLM.EMBEDDING_LENGTH_OUT.format(arch=self.arch), length)

def add_features_length(self, length: int) -> None:
self.add_uint32(Keys.LLM.FEATURES_LENGTH.format(arch=self.arch), length)

1 change: 1 addition & 0 deletions include/llama.h
@@ -535,6 +535,7 @@ extern "C" {
LLAMA_API int32_t llama_model_n_ctx_train(const struct llama_model * model);
LLAMA_API int32_t llama_model_n_embd (const struct llama_model * model);
LLAMA_API int32_t llama_model_n_embd_inp (const struct llama_model * model);
LLAMA_API int32_t llama_model_n_embd_out (const struct llama_model * model);
LLAMA_API int32_t llama_model_n_layer (const struct llama_model * model);
LLAMA_API int32_t llama_model_n_head (const struct llama_model * model);
LLAMA_API int32_t llama_model_n_head_kv (const struct llama_model * model);
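The new getter is the output-side counterpart of llama_model_n_embd: callers that size embedding buffers from the model should query it so that models with a dense output projection (such as LFM2-ColBert-350M) get n_embd_out floats per output row. A minimal caller-side sketch, not part of this diff and with a made-up helper name:

#include <vector>
#include "llama.h"

// Illustrative helper: allocate one embedding row per requested output.
static std::vector<float> alloc_embd_buffer(const llama_model * model, int32_t n_outputs) {
    // Falls back to n_embd when the model defines no separate output projection.
    const int32_t n_embd_out = llama_model_n_embd_out(model);
    return std::vector<float>((size_t) n_outputs * (size_t) n_embd_out, 0.0f);
}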
2 changes: 2 additions & 0 deletions src/llama-arch.cpp
@@ -152,6 +152,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_VOCAB_SIZE, "%s.vocab_size" },
{ LLM_KV_CONTEXT_LENGTH, "%s.context_length" },
{ LLM_KV_EMBEDDING_LENGTH, "%s.embedding_length" },
{ LLM_KV_EMBEDDING_LENGTH_OUT, "%s.embedding_length_out" },
{ LLM_KV_FEATURES_LENGTH, "%s.features_length" },
{ LLM_KV_BLOCK_COUNT, "%s.block_count" },
{ LLM_KV_LEADING_DENSE_BLOCK_COUNT, "%s.leading_dense_block_count" },
@@ -2075,6 +2076,7 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
LLM_TENSOR_TOKEN_EMBD,
LLM_TENSOR_OUTPUT_NORM_LFM2,
LLM_TENSOR_OUTPUT,
LLM_TENSOR_DENSE_2_OUT,
};
case LLM_ARCH_LFM2MOE:
return {
1 change: 1 addition & 0 deletions src/llama-arch.h
@@ -156,6 +156,7 @@ enum llm_kv {
LLM_KV_VOCAB_SIZE,
LLM_KV_CONTEXT_LENGTH,
LLM_KV_EMBEDDING_LENGTH,
LLM_KV_EMBEDDING_LENGTH_OUT,
LLM_KV_FEATURES_LENGTH,
LLM_KV_BLOCK_COUNT,
LLM_KV_LEADING_DENSE_BLOCK_COUNT,
39 changes: 21 additions & 18 deletions src/llama-context.cpp
@@ -758,7 +758,8 @@ float * llama_context::get_embeddings_ith(int32_t i) {
throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs));
}

return embd + j*model.hparams.n_embd;
const uint32_t n_embd_out = model.hparams.get_n_embd_out();
return embd + j*n_embd_out;
} catch (const std::exception & err) {
LLAMA_LOG_ERROR("%s: invalid embeddings id %d, reason: %s\n", __func__, i, err.what());
#ifndef NDEBUG
@@ -1194,9 +1195,10 @@ int llama_context::encode(const llama_batch & batch_inp) {
{
// extract token embeddings
GGML_ASSERT(embd != nullptr);
const uint32_t n_embd_out = hparams.get_n_embd_out();

GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_size);
ggml_backend_tensor_get_async(backend_embd, t_embd, embd, 0, n_tokens*n_embd*sizeof(float));
GGML_ASSERT(n_tokens*n_embd_out <= (int64_t) embd_size);
ggml_backend_tensor_get_async(backend_embd, t_embd, embd, 0, n_tokens*n_embd_out*sizeof(float));
} break;
case LLAMA_POOLING_TYPE_MEAN:
case LLAMA_POOLING_TYPE_CLS:
@@ -1215,17 +1217,17 @@ int llama_context::encode(const llama_batch & batch_inp) {
} break;
case LLAMA_POOLING_TYPE_RANK:
{
// extract the rerank score - n_cls_out floats per sequence
// extract the rerank score - n_embd_out floats per sequence
auto & embd_seq_out = embd_seq;

const uint32_t n_cls_out = hparams.n_cls_out;
const uint32_t n_embd_out = hparams.get_n_embd_out();

for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
const llama_seq_id seq_id = ubatch.seq_id_unq[s];
const int32_t seq_idx = ubatch.seq_idx[seq_id];

embd_seq_out[seq_id].resize(n_cls_out);
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_idx)*sizeof(float), n_cls_out*sizeof(float));
embd_seq_out[seq_id].resize(n_embd_out);
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd_out*seq_idx)*sizeof(float), n_embd_out*sizeof(float));
}
} break;
case LLAMA_POOLING_TYPE_UNSPECIFIED:
@@ -1600,12 +1602,13 @@ int llama_context::decode(const llama_batch & batch_inp) {
{
// extract token embeddings
GGML_ASSERT(embd != nullptr);
float * embd_out = embd + n_outputs_prev*n_embd;
const uint32_t n_embd_out = hparams.get_n_embd_out();
float * embd_out = embd + n_outputs_prev*n_embd_out;

if (n_outputs) {
GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all);
GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd <= (int64_t) embd_size);
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_outputs*n_embd*sizeof(float));
GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd_out <= (int64_t) embd_size);
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_outputs*n_embd_out*sizeof(float));
}
} break;
case LLAMA_POOLING_TYPE_MEAN:
@@ -1625,17 +1628,17 @@ int llama_context::decode(const llama_batch & batch_inp) {
} break;
case LLAMA_POOLING_TYPE_RANK:
{
// extract the rerank score - n_cls_out floats per sequence
// extract the rerank score - n_embd_out floats per sequence
auto & embd_seq_out = embd_seq;

const uint32_t n_cls_out = hparams.n_cls_out;
const uint32_t n_embd_out = hparams.get_n_embd_out();

for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
const llama_seq_id seq_id = ubatch.seq_id_unq[s];
const int32_t seq_idx = ubatch.seq_idx[seq_id];

embd_seq_out[seq_id].resize(n_cls_out);
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_idx)*sizeof(float), n_cls_out*sizeof(float));
embd_seq_out[seq_id].resize(n_embd_out);
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd_out*seq_idx)*sizeof(float), n_embd_out*sizeof(float));
}
} break;
case LLAMA_POOLING_TYPE_UNSPECIFIED:
@@ -1730,9 +1733,9 @@ uint32_t llama_context::output_reserve(int32_t n_outputs, const llama_batch & ba

const int64_t n_outputs_max = std::max<int64_t>(n_outputs, n_seq_max());

const auto n_batch = cparams.n_batch;
const auto n_vocab = vocab.n_tokens();
const auto n_embd = hparams.n_embd;
const auto n_batch = cparams.n_batch;
const auto n_vocab = vocab.n_tokens();
const auto n_embd_out = hparams.get_n_embd_out();

bool has_logits = true;
bool has_embd = cparams.embeddings;
@@ -1773,7 +1776,7 @@

// Allocate CPU logits buffer only if needed by sequences in this batch
logits_size = (has_logits && cpu_logits) ? n_vocab*n_outputs_max : 0;
embd_size = has_embd ? n_embd*n_outputs_max : 0;
embd_size = has_embd ? n_embd_out*n_outputs_max : 0;

// TODO: avoid this branching by working with the worst-case
if (!has_sampling) {
10 changes: 7 additions & 3 deletions src/llama-graph.cpp
@@ -2067,14 +2067,18 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
void llm_graph_context::build_dense_out(
ggml_tensor * dense_2,
ggml_tensor * dense_3) const {
if (!cparams.embeddings || dense_2 == nullptr || dense_3 == nullptr) {
if (!cparams.embeddings || !(dense_2 || dense_3)) {
return;
}
ggml_tensor * cur = res->t_embd_pooled != nullptr ? res->t_embd_pooled : res->t_embd;
GGML_ASSERT(cur != nullptr && "missing t_embd_pooled/t_embd");

cur = ggml_mul_mat(ctx0, dense_2, cur);
cur = ggml_mul_mat(ctx0, dense_3, cur);
if (dense_2) {
cur = ggml_mul_mat(ctx0, dense_2, cur);
}
if (dense_3) {
cur = ggml_mul_mat(ctx0, dense_3, cur);
}
cb(cur, "result_embd_pooled", -1);
res->t_embd_pooled = cur;
ggml_build_forward_expand(gf, cur);
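For LFM2-ColBert-350M only dense_2 exists, so the relaxed null check applies a single projection instead of requiring both tensors. Since dense_2 is created with shape {n_embd, n_embd_out} and the pooled embedding has ne0 == n_embd, ggml_mul_mat(ctx0, dense_2, cur) contracts over n_embd and leaves n_embd_out values per output, matching the buffer sizes in llama-context.cpp above. Outside ggml the step is an ordinary linear layer without bias (the converter exports only linear.weight); a minimal sketch, assuming W2 holds the dense_2 rows (n_embd_out x n_embd) and h is one pooled embedding of length n_embd:

#include <vector>

// out[i] = sum_j W2[i][j] * h[j] -- the same contraction ggml_mul_mat(dense_2, cur) performs per token.
static std::vector<float> dense_project(const std::vector<std::vector<float>> & W2,
                                        const std::vector<float> & h) {
    std::vector<float> out(W2.size(), 0.0f);
    for (size_t i = 0; i < W2.size(); ++i) {
        for (size_t j = 0; j < h.size(); ++j) {
            out[i] += W2[i][j] * h[j];
        }
    }
    return out;
}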
4 changes: 4 additions & 0 deletions src/llama-hparams.cpp
@@ -72,6 +72,10 @@ uint32_t llama_hparams::n_embd_inp() const {
return n_embd_inp;
}

uint32_t llama_hparams::get_n_embd_out() const {
return n_embd_out > 0 ? n_embd_out : n_embd;
}

uint32_t llama_hparams::n_embd_k_gqa(uint32_t il) const {
const uint32_t n_head_kv = this->n_head_kv(il);

6 changes: 6 additions & 0 deletions src/llama-hparams.h
@@ -162,6 +162,9 @@ struct llama_hparams {
// for Classifiers
uint32_t n_cls_out = 1;

// output embedding dimension (0 = use n_embd)
uint32_t n_embd_out = 0;

// llama4 smallthinker
uint32_t n_moe_layer_step = 0;
uint32_t n_no_rope_layer_step = 4;
@@ -234,6 +237,9 @@ struct llama_hparams {
// dimension of main + auxiliary input embeddings
uint32_t n_embd_inp() const;

// dimension of output embeddings
uint32_t get_n_embd_out() const;

// dimension of key embeddings across all k-v heads
uint32_t n_embd_k_gqa(uint32_t il = 0) const;

3 changes: 3 additions & 0 deletions src/llama-model-saver.cpp
@@ -146,6 +146,9 @@ void llama_model_saver::add_kv_from_model() {
add_kv(LLM_KV_VOCAB_SIZE, vocab.n_tokens());
add_kv(LLM_KV_CONTEXT_LENGTH, hparams.n_ctx_train);
add_kv(LLM_KV_EMBEDDING_LENGTH, hparams.n_embd);
if (hparams.n_embd_out > 0) {
add_kv(LLM_KV_EMBEDDING_LENGTH_OUT, hparams.n_embd_out);
}
add_kv(LLM_KV_BLOCK_COUNT, hparams.n_layer);
add_kv(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead);
add_kv(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff_arr, true);
9 changes: 9 additions & 0 deletions src/llama-model.cpp
@@ -507,6 +507,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {

ml.get_key(LLM_KV_CONTEXT_LENGTH, hparams.n_ctx_train);
ml.get_key(LLM_KV_EMBEDDING_LENGTH, hparams.n_embd);
ml.get_key(LLM_KV_EMBEDDING_LENGTH_OUT, hparams.n_embd_out, false);
ml.get_key(LLM_KV_BLOCK_COUNT, hparams.n_layer);
ml.get_key(LLM_KV_EXPERT_COUNT, hparams.n_expert, false);
ml.get_key(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used, false);
@@ -627,6 +628,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
ml.get_arr(LLM_KV_CLASSIFIER_OUTPUT_LABELS, classifier_labels, false);
if (!classifier_labels.empty()) {
hparams.n_cls_out = classifier_labels.size();
hparams.n_embd_out = classifier_labels.size();
}

// arch-specific KVs
@@ -6446,6 +6448,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
layer.shortconv.out_proj = create_tensor(tn(LLM_TENSOR_SHORTCONV_OUTPROJ, "weight", i), {n_embd, n_embd}, 0);
}
}

// for LFM2-ColBert-350M
dense_2_out_layers = create_tensor(tn(LLM_TENSOR_DENSE_2_OUT, "weight"), {n_embd, hparams.get_n_embd_out()}, TENSOR_NOT_REQUIRED);
} break;
case LLM_ARCH_SMALLTHINKER:
{
@@ -7976,6 +7981,10 @@ int32_t llama_model_n_embd_inp(const llama_model * model) {
return model->hparams.n_embd_inp();
}

int32_t llama_model_n_embd_out(const llama_model * model) {
return model->hparams.get_n_embd_out();
}

int32_t llama_model_n_layer(const llama_model * model) {
return model->hparams.n_layer;
}
10 changes: 5 additions & 5 deletions tools/server/server-context.cpp
@@ -1505,9 +1505,9 @@ struct server_context_impl {
res->n_tokens = slot.task->n_tokens();
res->res_type = slot.task->params.res_type;

const int n_embd = llama_model_n_embd(model);
const int n_embd_out = llama_model_n_embd_out(model);

std::vector<float> embd_res(n_embd, 0.0f);
std::vector<float> embd_res(n_embd_out, 0.0f);

for (int i = 0; i < batch.n_tokens; ++i) {
if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
@@ -1524,18 +1524,18 @@
if (embd == nullptr) {
SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]);

res->embedding.push_back(std::vector<float>(n_embd, 0.0f));
res->embedding.push_back(std::vector<float>(n_embd_out, 0.0f));
continue;
}

// normalize only when there is pooling
if (llama_pooling_type(slot.ctx) != LLAMA_POOLING_TYPE_NONE) {
common_embd_normalize(embd, embd_res.data(), n_embd, slot.task->params.embd_normalize);
common_embd_normalize(embd, embd_res.data(), n_embd_out, slot.task->params.embd_normalize);
res->embedding.push_back(embd_res);
break;
}

res->embedding.emplace_back(embd, embd + n_embd);
res->embedding.emplace_back(embd, embd + n_embd_out);
}

SLT_DBG(slot, "%s", "sending embeddings\n");