diff --git a/examples/model-conversion/Makefile b/examples/model-conversion/Makefile index f8dc525a77..6a0c96f0af 100644 --- a/examples/model-conversion/Makefile +++ b/examples/model-conversion/Makefile @@ -136,18 +136,17 @@ embedding-run-original-model-st: USE_SENTENCE_TRANSFORMERS=1 embedding-run-original-model-st: embedding-run-original-model embedding-run-converted-model: - @./scripts/embedding/run-converted-model.sh $(CONVERTED_EMBEDDING_MODEL) \ - $(if $(PROMPTS_FILE),--prompts-file "$(PROMPTS_FILE)") \ - $(if $(USE_POOLING),--pooling) - -embedding-run-converted-model-st: USE_POOLING=1 -embedding-run-converted-model-st: embedding-run-converted-model + @POOLING_FLAG=$$(./scripts/utils/detect_pooling.py $(EMBEDDING_MODEL_PATH)); \ + echo "pooling: $$POOLING_FLAG"; \ + ./scripts/embedding/run-converted-model.sh $(CONVERTED_EMBEDDING_MODEL) \ + --pooling "$$POOLING_FLAG" \ + $(if $(PROMPTS_FILE),--prompts-file "$(PROMPTS_FILE)") embedding-verify-logits: embedding-run-original-model embedding-run-converted-model @./scripts/embedding/compare-embeddings-logits.sh \ $(if $(PROMPTS_FILE),--prompts-file "$(PROMPTS_FILE)") -embedding-verify-logits-st: embedding-run-original-model-st embedding-run-converted-model-st +embedding-verify-logits-st: embedding-run-original-model-st embedding-run-converted-model @./scripts/embedding/compare-embeddings-logits.sh \ $(if $(PROMPTS_FILE),--prompts-file "$(PROMPTS_FILE)") diff --git a/examples/model-conversion/scripts/embedding/run-converted-model.sh b/examples/model-conversion/scripts/embedding/run-converted-model.sh index 0f490e6c3b..d52fa37c12 100755 --- a/examples/model-conversion/scripts/embedding/run-converted-model.sh +++ b/examples/model-conversion/scripts/embedding/run-converted-model.sh @@ -1,11 +1,11 @@ #!/usr/bin/env bash -set -e +set -ex # Parse command line arguments CONVERTED_MODEL="" PROMPTS_FILE="" -USE_POOLING="" +POOLING="" while [[ $# -gt 0 ]]; do case $1 in @@ -14,8 +14,8 @@ while [[ $# -gt 0 ]]; do shift 2 ;; --pooling) - USE_POOLING="1" - shift + POOLING="$2" + shift 2 ;; *) if [ -z "$CONVERTED_MODEL" ]; then @@ -50,10 +50,5 @@ fi echo $CONVERTED_MODEL -cmake --build ../../build --target llama-logits -j8 -# TODO: update logits.cpp to accept a --file/-f option for the prompt -if [ -n "$USE_POOLING" ]; then - ../../build/bin/llama-logits -m "$CONVERTED_MODEL" -embd-mode -pooling "$PROMPT" -else - ../../build/bin/llama-logits -m "$CONVERTED_MODEL" -embd-mode "$PROMPT" -fi +cmake --build ../../build --target llama-debug -j8 +../../build/bin/llama-debug -m "$CONVERTED_MODEL" --embedding $POOLING -p "$PROMPT" --save-logits diff --git a/examples/model-conversion/scripts/utils/detect_pooling.py b/examples/model-conversion/scripts/utils/detect_pooling.py new file mode 100755 index 0000000000..b6e51b11cb --- /dev/null +++ b/examples/model-conversion/scripts/utils/detect_pooling.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python3 +""" +Detect pooling configuration from sentence-transformers model. +Usage: detect_pooling.py +Outputs: pooling flag for llama-cli (e.g., "--pooling mean") or "--pooling none" +""" + +import sys +import json +from pathlib import Path + +def detect_pooling(model_dir: str) -> str: + model_path = Path(model_dir) + + pooling_configs = list(model_path.glob("*_Pooling/config.json")) + + if not pooling_configs: + return "--pooling none" + + config_path = pooling_configs[0] + try: + with open(config_path, 'r') as f: + config = json.load(f) + + if config.get("pooling_mode_mean_tokens", False): + return "--pooling mean" + elif config.get("pooling_mode_cls_token", False): + return "--pooling cls" + elif config.get("pooling_mode_lasttoken", False): + return "--pooling last" + else: + print(f"Warning: Unsupported pooling mode in {config_path}", file=sys.stderr) + return "--pooling none" + + except Exception as e: + print(f"Error reading pooling config: {e}", file=sys.stderr) + return "" + +if __name__ == "__main__": + if len(sys.argv) != 2: + print("Usage: detect_pooling.py ", file=sys.stderr) + sys.exit(1) + + print(detect_pooling(sys.argv[1]))