diff --git a/xinference/model/embedding/vllm/core.py b/xinference/model/embedding/vllm/core.py index 5b63c75874..eb398c3401 100644 --- a/xinference/model/embedding/vllm/core.py +++ b/xinference/model/embedding/vllm/core.py @@ -25,7 +25,6 @@ class VLLMEmbeddingModel(EmbeddingModel): - def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._context_length = None @@ -42,6 +41,19 @@ def load(self): ] raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}") + if self.model_family.model_name in { + "Qwen3-Embedding-0.6B", + "Qwen3-Embedding-4B", + "Qwen3-Embedding-8B", + }: + if "hf_overrides" not in self._kwargs: + self._kwargs["hf_overrides"] = { + "is_matryoshka": True, + } + elif isinstance(self._kwargs["hf_overrides"], dict): + self._kwargs["hf_overrides"].update( + is_matryoshka=True, + ) self._model = LLM(model=self._model_path, task="embed", **self._kwargs) self._tokenizer = self._model.get_tokenizer() @@ -56,14 +68,15 @@ def create_embedding( sentences: Union[str, List[str]], **kwargs, ): + from packaging.version import Version + from vllm import PoolingParams + from vllm import __version__ as vllm_version + sentences = self._fix_langchain_openai_inputs(sentences) model_uid = kwargs.pop("model_uid", None) normalize_embedding = kwargs.get("normalize_embedding", True) - if not normalize_embedding: - raise ValueError( - "vllm embedding engine does not support setting `normalize_embedding=False`" - ) + dimensions = kwargs.get("dimensions", None) assert self._model is not None @@ -92,8 +105,21 @@ def create_embedding( sentences = truncated_sentences[0] else: sentences = truncated_sentences - - outputs = self._model.embed(sentences, use_tqdm=False) + if Version(vllm_version) > Version("0.10.1"): + pool_params = PoolingParams( + dimensions=dimensions, normalize=normalize_embedding + ) + else: + if not normalize_embedding: + raise ValueError( + f"vLLM version {vllm_version} does not support " + f"unnormalized embeddings. 
" + f"Please upgrade to v0.10.1 or later." + ) + pool_params = PoolingParams(dimensions=dimensions) + outputs = self._model.embed( + sentences, use_tqdm=False, pooling_params=pool_params + ) embedding_list = [] all_token_nums = 0 for index, output in enumerate(outputs): diff --git a/xinference/model/embedding/vllm/tests/test_vllm_embedding.py b/xinference/model/embedding/vllm/tests/test_vllm_embedding.py index 981883408c..41979ad237 100644 --- a/xinference/model/embedding/vllm/tests/test_vllm_embedding.py +++ b/xinference/model/embedding/vllm/tests/test_vllm_embedding.py @@ -16,6 +16,7 @@ import pytest +from .....client import Client from ...cache_manager import EmbeddingCacheManager as CacheManager from ...core import ( EmbeddingModelFamilyV2, @@ -156,3 +157,33 @@ def test_embedding_model_with_vllm_long_text(): finally: if model_path is not None: shutil.rmtree(model_path, ignore_errors=True) + + +@pytest.mark.skipif(not VLLMEmbeddingModel.check_lib(), reason="vllm not installed") +def test_change_dim(setup): + endpoint, _ = setup + client = Client(endpoint) + model_uid = client.launch_model( + model_name="Qwen3-Embedding-0.6B", + model_type="embedding", + model_engine="vllm", + ) + + model = client.get_model(model_uid) + + content = ( + "We are testing the behavior of the VLLM embedding model when processing text that contains " + "significantly more tokens than the specified maximum limit. This test is important because " + "it helps us understand how the model handles token truncation or other processing strategies " + "when dealing with extremely long input sequences. The model should either truncate the input " + "or handle it gracefully without crashing. We expect the embedding dimension to remain consistent " + "at 384 dimensions regardless of the input length, as the model architecture should maintain " + "the same output dimensionality. 
This comprehensive test ensures robustness and reliability " + "of the embedding generation process under edge case conditions with very long text inputs." + ) + + embeds = model.create_embedding(content, dimensions=500) + assert len(embeds["data"][0]["embedding"]) == 500 + + embeds = model.create_embedding(content) + assert len(embeds["data"][0]["embedding"]) == 1024 diff --git a/xinference/model/rerank/vllm/core.py b/xinference/model/rerank/vllm/core.py index 2be23ab457..eac173b40c 100644 --- a/xinference/model/rerank/vllm/core.py +++ b/xinference/model/rerank/vllm/core.py @@ -3,6 +3,7 @@ from typing import List, Optional from ....types import Document, DocumentObj, Meta, Rerank, RerankTokens +from ...utils import cache_clean from ..core import RerankModel, RerankModelFamilyV2, RerankSpecV1 SUPPORTED_MODELS_PREFIXES = ["bge", "gte", "text2vec", "m3e", "gte", "Qwen3"] @@ -42,6 +43,7 @@ def load(self): self._model = LLM(model=self._model_path, task="score", **self._kwargs) self._tokenizer = self._model.get_tokenizer() + @cache_clean def rerank( self, documents: List[str],