Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 33 additions & 7 deletions xinference/model/embedding/vllm/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@


class VLLMEmbeddingModel(EmbeddingModel):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._context_length = None
Expand All @@ -42,6 +41,19 @@ def load(self):
]

raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
if self.model_family.model_name in {
"Qwen3-Embedding-0.6B",
"Qwen3-Embedding-4B",
"Qwen3-Embedding-8B",
}:
if "hf_overrides" not in self._kwargs:
self._kwargs["hf_overrides"] = {
"is_matryoshka": True,
}
elif isinstance(self._kwargs["hf_overrides"], dict):
self._kwargs["hf_overrides"].update(
is_matryoshka=True,
)

self._model = LLM(model=self._model_path, task="embed", **self._kwargs)
self._tokenizer = self._model.get_tokenizer()
Expand All @@ -56,14 +68,15 @@ def create_embedding(
sentences: Union[str, List[str]],
**kwargs,
):
from packaging.version import Version
from vllm import PoolingParams
from vllm import __version__ as vllm_version

sentences = self._fix_langchain_openai_inputs(sentences)
model_uid = kwargs.pop("model_uid", None)

normalize_embedding = kwargs.get("normalize_embedding", True)
if not normalize_embedding:
raise ValueError(
"vllm embedding engine does not support setting `normalize_embedding=False`"
)
dimensions = kwargs.get("dimensions", None)

assert self._model is not None

Expand Down Expand Up @@ -92,8 +105,21 @@ def create_embedding(
sentences = truncated_sentences[0]
else:
sentences = truncated_sentences

outputs = self._model.embed(sentences, use_tqdm=False)
if Version(vllm_version) > Version("0.10.1"):
pool_params = PoolingParams(
dimensions=dimensions, normalize=normalize_embedding
)
else:
if not normalize_embedding:
raise ValueError(
f"vLLM version {vllm_version} does not support "
f"unnormalized embeddings. "
f"Please upgrade to v0.10.1 or later."
)
pool_params = PoolingParams(dimensions=dimensions)
outputs = self._model.embed(
sentences, use_tqdm=False, pooling_params=pool_params
)
embedding_list = []
all_token_nums = 0
for index, output in enumerate(outputs):
Expand Down
31 changes: 31 additions & 0 deletions xinference/model/embedding/vllm/tests/test_vllm_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import pytest

from .....client import Client
from ...cache_manager import EmbeddingCacheManager as CacheManager
from ...core import (
EmbeddingModelFamilyV2,
Expand Down Expand Up @@ -156,3 +157,33 @@ def test_embedding_model_with_vllm_long_text():
finally:
if model_path is not None:
shutil.rmtree(model_path, ignore_errors=True)


@pytest.mark.skipif(not VLLMEmbeddingModel.check_lib(), reason="vllm not installed")
def test_change_dim(setup):
    """Verify the vLLM embedding engine honors the ``dimensions`` parameter.

    Launches Qwen3-Embedding-0.6B (a matryoshka-capable model) through the
    vllm engine and checks that:
      * an explicit ``dimensions=500`` yields a 500-dim embedding, and
      * omitting ``dimensions`` yields the model's native 1024-dim embedding.

    The launched model is terminated afterwards so it does not leak GPU
    memory into subsequent tests.
    """
    endpoint, _ = setup
    client = Client(endpoint)
    model_uid = client.launch_model(
        model_name="Qwen3-Embedding-0.6B",
        model_type="embedding",
        model_engine="vllm",
    )

    try:
        model = client.get_model(model_uid)

        content = (
            "We are testing the behavior of the VLLM embedding model when processing text that contains "
            "significantly more tokens than the specified maximum limit. This test is important because "
            "it helps us understand how the model handles token truncation or other processing strategies "
            "when dealing with extremely long input sequences. The model should either truncate the input "
            "or handle it gracefully without crashing. We expect the embedding dimension to remain consistent "
            "at 384 dimensions regardless of the input length, as the model architecture should maintain "
            "the same output dimensionality. This comprehensive test ensures robustness and reliability "
            "of the embedding generation process under edge case conditions with very long text inputs."
        )

        # Explicit dimensions: matryoshka truncation to the requested size.
        embeds = model.create_embedding(content, dimensions=500)
        assert len(embeds["data"][0]["embedding"]) == 500

        # No dimensions given: fall back to the model's native dimension.
        embeds = model.create_embedding(content)
        assert len(embeds["data"][0]["embedding"]) == 1024
    finally:
        # Fix: the original test never released the model, leaving the vLLM
        # engine (and its GPU memory) alive for the rest of the session.
        client.terminate_model(model_uid)
2 changes: 2 additions & 0 deletions xinference/model/rerank/vllm/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import List, Optional

from ....types import Document, DocumentObj, Meta, Rerank, RerankTokens
from ...utils import cache_clean
from ..core import RerankModel, RerankModelFamilyV2, RerankSpecV1

SUPPORTED_MODELS_PREFIXES = ["bge", "gte", "text2vec", "m3e", "gte", "Qwen3"]
Expand Down Expand Up @@ -42,6 +43,7 @@ def load(self):
self._model = LLM(model=self._model_path, task="score", **self._kwargs)
self._tokenizer = self._model.get_tokenizer()

@cache_clean
def rerank(
self,
documents: List[str],
Expand Down
Loading