Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 6 additions & 7 deletions examples/model-conversion/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -136,18 +136,17 @@ embedding-run-original-model-st: USE_SENTENCE_TRANSFORMERS=1
embedding-run-original-model-st: embedding-run-original-model

embedding-run-converted-model:
@./scripts/embedding/run-converted-model.sh $(CONVERTED_EMBEDDING_MODEL) \
$(if $(PROMPTS_FILE),--prompts-file "$(PROMPTS_FILE)") \
$(if $(USE_POOLING),--pooling)

embedding-run-converted-model-st: USE_POOLING=1
embedding-run-converted-model-st: embedding-run-converted-model
@POOLING_FLAG=$$(./scripts/utils/detect_pooling.py $(EMBEDDING_MODEL_PATH)); \
echo "pooling: $$POOLING_FLAG"; \
./scripts/embedding/run-converted-model.sh $(CONVERTED_EMBEDDING_MODEL) \
--pooling "$$POOLING_FLAG" \
$(if $(PROMPTS_FILE),--prompts-file "$(PROMPTS_FILE)")

embedding-verify-logits: embedding-run-original-model embedding-run-converted-model
@./scripts/embedding/compare-embeddings-logits.sh \
$(if $(PROMPTS_FILE),--prompts-file "$(PROMPTS_FILE)")

embedding-verify-logits-st: embedding-run-original-model-st embedding-run-converted-model-st
embedding-verify-logits-st: embedding-run-original-model-st embedding-run-converted-model
@./scripts/embedding/compare-embeddings-logits.sh \
$(if $(PROMPTS_FILE),--prompts-file "$(PROMPTS_FILE)")

Expand Down
17 changes: 6 additions & 11 deletions examples/model-conversion/scripts/embedding/run-converted-model.sh
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
#!/usr/bin/env bash

set -e
set -ex

# Parse command line arguments
CONVERTED_MODEL=""
PROMPTS_FILE=""
USE_POOLING=""
POOLING=""

while [[ $# -gt 0 ]]; do
case $1 in
Expand All @@ -14,8 +14,8 @@ while [[ $# -gt 0 ]]; do
shift 2
;;
--pooling)
USE_POOLING="1"
shift
POOLING="$2"
shift 2
;;
*)
if [ -z "$CONVERTED_MODEL" ]; then
Expand Down Expand Up @@ -50,10 +50,5 @@ fi

echo $CONVERTED_MODEL

cmake --build ../../build --target llama-logits -j8
# TODO: update logits.cpp to accept a --file/-f option for the prompt
if [ -n "$USE_POOLING" ]; then
../../build/bin/llama-logits -m "$CONVERTED_MODEL" -embd-mode -pooling "$PROMPT"
else
../../build/bin/llama-logits -m "$CONVERTED_MODEL" -embd-mode "$PROMPT"
fi
cmake --build ../../build --target llama-debug -j8
../../build/bin/llama-debug -m "$CONVERTED_MODEL" --embedding $POOLING -p "$PROMPT" --save-logits
44 changes: 44 additions & 0 deletions examples/model-conversion/scripts/utils/detect_pooling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
#!/usr/bin/env python3
"""
Detect pooling configuration from sentence-transformers model.
Usage: detect_pooling.py <model_dir>
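Outputs: the pooling flag that run-converted-model.sh forwards to llama-debug: one of "--pooling mean", "--pooling cls", "--pooling last" or "--pooling none". Prints an empty line if the pooling config cannot be read.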
"""

import sys
import json
from pathlib import Path

def detect_pooling(model_dir: str) -> str:
    """Return the pooling command-line flag for a sentence-transformers model.

    Looks for a ``*_Pooling/config.json`` file directly under ``model_dir`` and
    maps the first enabled ``pooling_mode_*`` key to the corresponding flag.

    Args:
        model_dir: Path to the model directory.

    Returns:
        ``"--pooling mean"``, ``"--pooling cls"`` or ``"--pooling last"`` for the
        enabled mode (checked in that order); ``"--pooling none"`` when there is
        no pooling config or no supported mode is enabled (the latter also
        prints a warning to stderr); ``""`` when the config cannot be read or
        is not a JSON object (the error is printed to stderr).
    """
    # glob() gives no ordering guarantee; sort so that the same config is
    # chosen on every run when several *_Pooling directories exist.
    pooling_configs = sorted(Path(model_dir).glob("*_Pooling/config.json"))
    if not pooling_configs:
        return "--pooling none"

    config_path = pooling_configs[0]
    try:
        # JSON is UTF-8; do not depend on the locale's default encoding.
        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        print(f"Error reading pooling config: {e}", file=sys.stderr)
        return ""

    if not isinstance(config, dict):
        # Valid JSON but not an object: there are no mode keys to look up.
        print(
            f"Error reading pooling config: {config_path} is not a JSON object",
            file=sys.stderr,
        )
        return ""

    # Priority order matters: the first enabled mode wins.
    mode_flags = (
        ("pooling_mode_mean_tokens", "--pooling mean"),
        ("pooling_mode_cls_token", "--pooling cls"),
        ("pooling_mode_lasttoken", "--pooling last"),
    )
    for key, flag in mode_flags:
        if config.get(key, False):
            return flag

    print(f"Warning: Unsupported pooling mode in {config_path}", file=sys.stderr)
    return "--pooling none"

if __name__ == "__main__":
    # Exactly one positional argument is expected: the model directory.
    cli_args = sys.argv[1:]
    if len(cli_args) != 1:
        print("Usage: detect_pooling.py <model_dir>", file=sys.stderr)
        sys.exit(1)

    # The detected flag goes to stdout so a caller can capture it.
    pooling_flag = detect_pooling(cli_args[0])
    print(pooling_flag)