model : add LFM2-ColBert-350M #18607

Merged
CISC merged 3 commits into ggml-org:master from tdakhran:tarek/feat/lfm2-colbert-350m
Jan 5, 2026

Conversation


@tdakhran tdakhran commented Jan 5, 2026

This PR adds support for LFM2-ColBert-350M by introducing n_embd_out, a separate output embedding dimension that can differ from the input embedding dimension (n_embd).

Initially, I introduced LLAMA_POOLING_TYPE_TOKEN, which applied cls_out and output all token embeddings, but then switched to n_embd_out.

n_embd_out will be used in future multimodal models as well.

New GGUF key and API:

  • LLM_KV_EMBEDDING_LENGTH_OUT - stores the output embedding dimension
  • llama_model_n_embd_out() - returns hparams.n_embd_out if set, falling back to hparams.n_embd
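
For illustration, a minimal sketch of the same fallback semantics from the Python side, reading the converted GGUF with gguf-py; the serialized key name lfm2.embedding_length_out is my assumption based on LLM_KV_EMBEDDING_LENGTH_OUT, not confirmed against the code:

from gguf import GGUFReader  # gguf-py, bundled with llama.cpp

def read_scalar(field):
    # a ReaderField keeps its payload in .parts, indexed by .data
    return field.parts[field.data[-1]][0]

reader = GGUFReader("LFM2-ColBert-350M-F16.gguf")
arch = "lfm2"  # assumption: architecture prefix used by this model
n_embd = int(read_scalar(reader.fields[f"{arch}.embedding_length"]))

# mirror llama_model_n_embd_out(): prefer the output dim, fall back to n_embd
out_field = reader.fields.get(f"{arch}.embedding_length_out")  # assumed key name
n_embd_out = int(read_scalar(out_field)) if out_field is not None else n_embd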

Testing

Convert

python convert_hf_to_gguf.py /data/playground/checkpoints/LFM2-ColBert-350M

Launch server

bin/llama-server -m /data/playground/checkpoints/LFM2-ColBert-350M/LFM2-ColBert-350M-F16.gguf --embeddings --pooling none

Run the attached Python script

❯ uv run rerank.py
Score: 29.74 | Q: What is panda? | D: hi
Score: 29.90 | Q: What is panda? | D: it is a bear
Score: 30.52 | Q: What is panda? | D: The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.

rerank.py
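
For readers without the attachment, a minimal sketch of what such a script can look like (this is not the attached rerank.py): it assumes the server was started with --embeddings --pooling none, and that its /embedding endpoint returns one vector per input token; the exact response shape may differ between server versions.

import numpy as np
import requests

def embed_tokens(text: str) -> np.ndarray:
    # assumed response shape: [{"index": 0, "embedding": [[...], ...]}]
    resp = requests.post("http://localhost:8080/embedding", json={"content": text})
    resp.raise_for_status()
    emb = np.array(resp.json()[0]["embedding"], dtype=np.float32)  # (n_tokens, n_embd_out)
    return emb / np.linalg.norm(emb, axis=-1, keepdims=True)  # L2-normalize per token

def maxsim(query: np.ndarray, doc: np.ndarray) -> float:
    # late interaction: best-matching document token per query token, summed
    return float((query @ doc.T).max(axis=1).sum())

q = embed_tokens("What is panda?")
for d in ["hi", "it is a bear"]:
    print(f"Score: {maxsim(q, embed_tokens(d)):.2f} | D: {d}")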

cc: @ngxson

Comment on lines 629 to 632
if (!classifier_labels.empty()) {
    hparams.n_cls_out = classifier_labels.size();
    hparams.n_embd_out = classifier_labels.size();
}
Collaborator

This crashes e.g. jina-bert-v2, as observed in CI, because n_cls_out defaults to 1 for models that don't have any classifier labels.

Contributor Author


Thanks @CISC, rolled back to using n_cls_out for pooling = rank for now.

@tdakhran tdakhran force-pushed the tarek/feat/lfm2-colbert-350m branch from 894f76d to e8a0336 on January 5, 2026 04:19
Collaborator

@CISC CISC left a comment

LGTM, @ggerganov should look over n_embd_out changes.


def set_vocab(self):
super().set_vocab()
self.gguf_writer.add_add_bos_token(False)
Collaborator

Why set this to False BTW? It's true in the config, and it's used in TemplateProcessing for both single and pair.

Collaborator

So, if the config is correct, both add_add_bos_token and add_add_sep_token should be True, and the sep token should be set to the bos token.

Ideally this would be done automatically by SpecialVocab's TemplateProcessing handling, but IIRC this pattern isn't accepted (a warning is logged); I don't remember exactly why.
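
For reference, a hypothetical tokenizers TemplateProcessing matching that description, where BOS also acts as the separator between the two segments of a pair (the token string and id below are made up):

from tokenizers.processors import TemplateProcessing

proc = TemplateProcessing(
    single="<s> $A",               # BOS prepended to a single sequence
    pair="<s> $A <s> $B",          # BOS also separates the two segments
    special_tokens=[("<s>", 1)],   # hypothetical BOS string and id
)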

Contributor Author

Nice catch, that's a debug leftover (I was dealing with double BOS), will remove.

Collaborator

I think you still need to add sep metadata as mentioned for reranking to work.

Contributor Author

ColBERT doesn't need a sep token: it embeds queries and documents separately, and the similarity score is then computed pairwise using MaxSim. See the script attached here: #18607 (comment).

This allows embedding documents once, caching the document embeddings in a database, and then embedding only the query and computing similarity scores against the precomputed document embeddings.
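
For reference, the standard ColBERT MaxSim score over per-token query embeddings $q_i$ and document embeddings $d_j$ (typically L2-normalized, so the dot product is cosine similarity):

$$ s(q, d) = \sum_{i=1}^{|q|} \max_{1 \le j \le |d|} q_i \cdot d_j $$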

Contributor Author

I considered adding this logic to llama.cpp, but it would require document database management, so I decided to leave it to the client.

Collaborator

Ah, so TemplateProcessing is not used at all.

Member

@ggerganov ggerganov left a comment

Should update `examples/embedding` and `examples/model-conversion` to use the new `llama_model_n_embd_out()`.

@tdakhran tdakhran requested a review from danbev as a code owner January 5, 2026 11:58

tdakhran commented Jan 5, 2026

Thanks for the feedback, addressed the comments.
Attaching an updated usage script for reference that sends tokens instead of strings to avoid double BOS; I'll push it to the GGUF repo later.
rerank.py

from safetensors.torch import load_file
tensors_file = self.dir_model / "1_Dense" / "model.safetensors"
assert tensors_file.is_file()
tensor = load_file(tensors_file)["linear.weight"]
Collaborator

Not a change request, but I'm wondering if we should introduce an extra_model_files() API in the near future @CISC @compilade?

Something like this:

class LFM2ColBertModel(LFM2Model):
    def extra_model_files(self):
        return [self.dir_model / "1_Dense" / "model.safetensors"]
# tensors will be loaded and processed via `modify_tensors()`
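
And a hypothetical sketch of how the base class could consume it; none of this exists in convert_hf_to_gguf.py today, and _main_model_tensors() is a made-up stand-in for the existing loading path:

from safetensors.torch import load_file

class ModelBase:
    def extra_model_files(self):
        return []  # subclasses override, e.g. LFM2ColBertModel above

    def get_tensors(self):
        yield from self._main_model_tensors()  # placeholder for the current loader
        for path in self.extra_model_files():
            for name, tensor in load_file(path).items():
                yield name, tensor  # later routed through modify_tensors()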

Collaborator

Yep, would be handy to deduplicate some code; I guess we'll see more of this as more sentence-transformers (ST) support gets added.

I checked out some other ColBERTv2 models BTW, and they seem to have linear.weight embedded in the main file.

Member

@danbev danbev left a comment

I'll take a look at updating the model-conversion example if that is alright; I'm currently working on #18464 and can either update it there or do it in a separate PR.

I can also look at updating llama-embedding as a separate PR, though it looks like it only requires a small update:

diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp
index 81111e81b..89e626353 100644
--- a/examples/embedding/embedding.cpp
+++ b/examples/embedding/embedding.cpp
@@ -252,7 +252,7 @@ int main(int argc, char ** argv) {
     }
 
     // allocate output
-    const int n_embd = llama_model_n_embd(model);
+    const int n_embd = llama_model_n_embd_out(model);
     std::vector<float> embeddings(n_embd_count * n_embd, 0);
     float * emb = embeddings.data();


tdakhran commented Jan 5, 2026

> I'll take a look at updating the model-conversion example as a separate PR if that is alright. [...] I can also look at updating llama-embedding as a separate PR though it looks like it only requires a small update: [...]

@danbev I updated them in this commit 98d7da6; I can revert if you prefer to do it in a separate PR.


danbev commented Jan 5, 2026

> @danbev I updated them in this commit 98d7da6, can revert if you prefer to do it in a separate PR.

No, this is great! I just missed that commit.

I was able to verify the converted model using the embedding-verify-logits-st target, but it required some modifications to the scripts and to how sentence-transformers (st) models are handled. I need to think about how we should handle this, as pooling is currently assumed for the st target/option, since that was the use case we had prior to this model. I'll look into this and follow up in a separate PR.

@CISC CISC merged commit 73d284a into ggml-org:master Jan 5, 2026
76 checks passed
@CISC CISC mentioned this pull request Jan 5, 2026
@tdakhran tdakhran deleted the tarek/feat/lfm2-colbert-350m branch January 5, 2026 23:03

tdakhran commented Jan 5, 2026

For visibility, GGUFs and usage instructions are uploaded to https://huggingface.co/LiquidAI/LFM2-ColBERT-350M-GGUF

danbev added a commit to danbev/llama.cpp that referenced this pull request Jan 7, 2026
This commit adds a Python script to automatically detect the pooling
configuration from a sentence-transformers model directory.

The motivation for this change is that I made a mistake when adding the
sentence-transformers support: I incorrectly assumed that if an
embedding model uses sentence-transformers, it always uses pooling. With
the recent addition of support for late-interaction models, which can
have a down-projection but do not use pooling (like LFM2-ColBert-350M),
that assumption no longer holds.

This commit builds upon ggml-org#18464,
which needs to be merged first.

Refs: ggml-org#18607 (comment)