diff --git a/tests/entrypoints/pooling/classify/test_offline.py b/tests/entrypoints/pooling/classify/test_offline.py
index a02d07ab0695..1320385564be 100644
--- a/tests/entrypoints/pooling/classify/test_offline.py
+++ b/tests/entrypoints/pooling/classify/test_offline.py
@@ -7,7 +7,7 @@
 import torch
 
 from tests.models.utils import softmax
-from vllm import LLM, ClassificationRequestOutput, PoolingParams, PoolingRequestOutput
+from vllm import LLM, ClassificationRequestOutput, PoolingParams
 from vllm.distributed import cleanup_dist_env_and_memory
 from vllm.tasks import PoolingTask
 
@@ -66,15 +66,6 @@ def test_list_prompts(llm: LLM):
         assert len(outputs[i].outputs.probs) == num_labels
 
 
-@pytest.mark.skip_global_cleanup
-def test_token_classify(llm: LLM):
-    outputs = llm.encode(prompt, pooling_task="token_classify", use_tqdm=False)
-    assert len(outputs) == 1
-    assert isinstance(outputs[0], PoolingRequestOutput)
-    assert outputs[0].prompt_token_ids == prompt_token_ids
-    assert outputs[0].outputs.data.shape == (len(prompt_token_ids), num_labels)
-
-
 @pytest.mark.skip_global_cleanup
 def test_pooling_params(llm: LLM):
     def get_outputs(use_activation):
@@ -107,8 +98,12 @@ def test_score_api(llm: LLM):
         llm.score("ping", "pong", use_tqdm=False)
 
 
-@pytest.mark.parametrize("task", ["embed", "token_embed", "plugin"])
+@pytest.mark.parametrize("task", ["token_classify", "embed", "token_embed"])
 def test_unsupported_tasks(llm: LLM, task: PoolingTask):
-    err_msg = f"Unsupported task: '{task}' Supported tasks.+"
+    if task == "token_classify":
+        err_msg = "Try switching the model's pooling_task via.+"
+    else:
+        err_msg = "Embedding API is not supported by this model.+"
+
     with pytest.raises(ValueError, match=err_msg):
         llm.encode(prompt, pooling_task=task, use_tqdm=False)
diff --git a/tests/entrypoints/pooling/classify/test_online.py b/tests/entrypoints/pooling/classify/test_online.py
index e23918fb8db8..8af1acc64668 100644
--- a/tests/entrypoints/pooling/classify/test_online.py
+++ b/tests/entrypoints/pooling/classify/test_online.py
@@ -436,26 +436,7 @@ async def test_pooling_classify(server: RemoteOpenAIServer, model_name: str):
 
 @pytest.mark.asyncio
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
-async def test_pooling_token_classify(server: RemoteOpenAIServer, model_name: str):
-    task = "token_classify"
-    response = requests.post(
-        server.url_for("pooling"),
-        json={
-            "model": model_name,
-            "input": input_text,
-            "encoding_format": "float",
-            "task": task,
-        },
-    )
-    poolings = PoolingResponse.model_validate(response.json())
-    assert len(poolings.data) == 1
-    assert len(poolings.data[0].data) == 8
-    assert len(poolings.data[0].data[0]) == 2
-
-
-@pytest.mark.asyncio
-@pytest.mark.parametrize("model_name", [MODEL_NAME])
-@pytest.mark.parametrize("task", ["embed", "token_embed", "plugin"])
+@pytest.mark.parametrize("task", ["token_classify", "embed", "token_embed", "plugin"])
 async def test_pooling_not_supported(
     server: RemoteOpenAIServer, model_name: str, task: str
 ):
@@ -469,4 +450,10 @@
         },
     )
     assert response.json()["error"]["type"] == "BadRequestError"
-    assert response.json()["error"]["message"].startswith(f"Unsupported task: {task!r}")
+
+    if task == "token_classify":
+        err_msg = "Try switching the model's pooling_task via"
+    else:
+        err_msg = f"Unsupported task: {task!r}"
+
+    assert response.json()["error"]["message"].startswith(err_msg)
diff --git a/tests/entrypoints/pooling/embed/test_offline.py b/tests/entrypoints/pooling/embed/test_offline.py
index 44328343f6d5..a9dbf98698c6 100644
--- a/tests/entrypoints/pooling/embed/test_offline.py
+++ b/tests/entrypoints/pooling/embed/test_offline.py
@@ -7,13 +7,16 @@
 import torch
 import torch.nn.functional as F
 
-from vllm import LLM, PoolingParams
+from vllm import LLM, EmbeddingRequestOutput, PoolingParams
 from vllm.distributed import cleanup_dist_env_and_memory
 from vllm.platforms import current_platform
+from vllm.tasks import PoolingTask
 
 MODEL_NAME = "intfloat/multilingual-e5-small"
-prompts = ["The chef prepared a delicious meal."]
+prompt = "The chef prepared a delicious meal."
+prompt_token_ids = [0, 581, 21861, 133888, 10, 8, 150, 60744, 109911, 5, 2]
+embedding_size = 384
 
 
 @pytest.fixture(scope="module")
 def llm():
@@ -35,25 +38,47 @@ def llm():
         seed=0,
         attention_config=attention_config,
     )
+    assert embedding_size == llm.model_config.embedding_size
 
     yield weakref.proxy(llm)
 
     del llm
-
     cleanup_dist_env_and_memory()
 
 
 @pytest.mark.skip_global_cleanup
-def test_token_embed(llm: LLM):
-    outputs = llm.encode(prompts, pooling_task="token_embed", use_tqdm=False)
-    multi_vector = outputs[0].outputs.data
-    assert multi_vector.shape == (11, 384)
+def test_str_prompts(llm: LLM):
+    outputs = llm.embed(prompt, use_tqdm=False)
+    assert len(outputs) == 1
+    assert isinstance(outputs[0], EmbeddingRequestOutput)
+    assert outputs[0].prompt_token_ids == prompt_token_ids
+    assert len(outputs[0].outputs.embedding) == embedding_size
+
+
+@pytest.mark.skip_global_cleanup
+def test_token_ids_prompts(llm: LLM):
+    outputs = llm.embed([prompt_token_ids], use_tqdm=False)
+    assert len(outputs) == 1
+    assert isinstance(outputs[0], EmbeddingRequestOutput)
+    assert outputs[0].prompt_token_ids == prompt_token_ids
+    assert len(outputs[0].outputs.embedding) == embedding_size
+
+
+@pytest.mark.skip_global_cleanup
+def test_list_prompts(llm: LLM):
+    outputs = llm.embed([prompt, prompt_token_ids], use_tqdm=False)
+    assert len(outputs) == 2
+    for i in range(len(outputs)):
+        assert isinstance(outputs[i], EmbeddingRequestOutput)
+        assert outputs[i].prompt_token_ids == prompt_token_ids
+        assert len(outputs[i].outputs.embedding) == embedding_size
+
+
+@pytest.mark.skip_global_cleanup
 def test_pooling_params(llm: LLM):
     def get_outputs(normalize):
         outputs = llm.embed(
-            prompts,
+            [prompt],
             pooling_params=PoolingParams(use_activation=normalize),
             use_tqdm=False,
         )
@@ -70,3 +95,14 @@ def get_outputs(normalize):
     assert torch.allclose(w_normal, F.normalize(wo_normal, p=2, dim=-1), atol=1e-2), (
         "w_normal should be close to normal(wo_normal)."
     )
 
 
+@pytest.mark.parametrize("task", ["token_embed", "classify", "token_classify"])
+def test_unsupported_tasks(llm: LLM, task: PoolingTask):
+    if task == "token_embed":
+        err_msg = "Try switching the model's pooling_task via.+"
+    else:
+        err_msg = "Classification API is not supported by this model.+"
+
+    with pytest.raises(ValueError, match=err_msg):
+        llm.encode(prompt, pooling_task=task, use_tqdm=False)
diff --git a/tests/entrypoints/pooling/embed/test_online.py b/tests/entrypoints/pooling/embed/test_online.py
index 56ab09bc7afc..7df552420089 100644
--- a/tests/entrypoints/pooling/embed/test_online.py
+++ b/tests/entrypoints/pooling/embed/test_online.py
@@ -732,28 +732,9 @@ async def test_pooling_embed(server: RemoteOpenAIServer, model_name: str):
 
 @pytest.mark.asyncio
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
-async def test_pooling_token_embed(server: RemoteOpenAIServer, model_name: str):
-    task = "token_embed"
-    response = requests.post(
-        server.url_for("pooling"),
-        json={
-            "model": model_name,
-            "input": input_text,
-            "encoding_format": "float",
-            "task": task,
-        },
-    )
-
-    poolings = PoolingResponse.model_validate(response.json())
-
-    assert len(poolings.data) == 1
-    assert len(poolings.data[0].data) == len(input_tokens)
-    assert len(poolings.data[0].data[0]) == 384
-
-
-@pytest.mark.asyncio
-@pytest.mark.parametrize("model_name", [MODEL_NAME])
-@pytest.mark.parametrize("task", ["classify", "token_classify", "plugin"])
+@pytest.mark.parametrize(
+    "task", ["token_embed", "classify", "token_classify", "plugin"]
+)
 async def test_pooling_not_supported(
     server: RemoteOpenAIServer, model_name: str, task: str
 ):
@@ -767,4 +748,10 @@
         },
     )
     assert response.json()["error"]["type"] == "BadRequestError"
-    assert response.json()["error"]["message"].startswith(f"Unsupported task: {task!r}")
+
+    if task == "token_embed":
+        err_msg = "Try switching the model's pooling_task via"
+    else:
+        err_msg = f"Unsupported task: {task!r}"
+
+    assert response.json()["error"]["message"].startswith(err_msg)
diff --git a/tests/entrypoints/pooling/score/test_online_rerank.py b/tests/entrypoints/pooling/score/test_online_rerank.py
index b0e8152aed72..38d843532922 100644
--- a/tests/entrypoints/pooling/score/test_online_rerank.py
+++ b/tests/entrypoints/pooling/score/test_online_rerank.py
@@ -203,22 +203,7 @@ async def test_pooling_classify(server: RemoteOpenAIServer, model_name: str):
 
 @pytest.mark.asyncio
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
-async def test_pooling_token_classify(server: RemoteOpenAIServer, model_name: str):
-    response = requests.post(
-        server.url_for("pooling"),
-        json={"model": model_name, "input": input_text, "encoding_format": "float"},
-    )
-
-    poolings = PoolingResponse.model_validate(response.json())
-
-    assert len(poolings.data) == 1
-    assert len(poolings.data[0].data) == len(input_tokens)
-    assert len(poolings.data[0].data[0]) == 1
-
-
-@pytest.mark.asyncio
-@pytest.mark.parametrize("model_name", [MODEL_NAME])
-@pytest.mark.parametrize("task", ["embed", "token_embed", "plugin"])
+@pytest.mark.parametrize("task", ["token_classify", "embed", "token_embed", "plugin"])
 async def test_pooling_not_supported(
     server: RemoteOpenAIServer, model_name: str, task: str
 ):
@@ -232,4 +217,9 @@
         },
     )
     assert response.json()["error"]["type"] == "BadRequestError"
-    assert response.json()["error"]["message"].startswith(f"Unsupported task: {task!r}")
+    if task == "token_classify":
+        err_msg = "Try switching the model's pooling_task via"
+    else:
+        err_msg = f"Unsupported task: {task!r}"
+
+    assert response.json()["error"]["message"].startswith(err_msg)
diff --git a/tests/entrypoints/pooling/token_classify/__init__.py b/tests/entrypoints/pooling/token_classify/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/entrypoints/pooling/token_classify/test_offline.py b/tests/entrypoints/pooling/token_classify/test_offline.py
new file mode 100644
index 000000000000..19de08b271ec
--- /dev/null
+++ b/tests/entrypoints/pooling/token_classify/test_offline.py
@@ -0,0 +1,76 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import weakref
+
+import pytest
+
+from vllm import LLM, PoolingRequestOutput
+from vllm.config import PoolerConfig
+from vllm.distributed import cleanup_dist_env_and_memory
+from vllm.tasks import PoolingTask
+
+MODEL_NAME = "jason9693/Qwen2.5-1.5B-apeach"
+
+prompt = "The chef prepared a delicious meal."
+prompt_token_ids = [785, 29706, 10030, 264, 17923, 15145, 13]
+num_labels = 2
+
+
+@pytest.fixture(scope="module")
+def llm():
+    # pytest caches the fixture so we use weakref.proxy to
+    # enable garbage collection
+    llm = LLM(
+        model=MODEL_NAME,
+        pooler_config=PoolerConfig(pooling_task="token_classify"),
+        max_num_batched_tokens=32768,
+        tensor_parallel_size=1,
+        gpu_memory_utilization=0.75,
+        enforce_eager=True,
+        seed=0,
+    )
+
+    yield weakref.proxy(llm)
+
+    del llm
+
+    cleanup_dist_env_and_memory()
+
+
+@pytest.mark.skip_global_cleanup
+def test_str_prompts(llm: LLM):
+    outputs = llm.encode(prompt, pooling_task="token_classify", use_tqdm=False)
+    assert len(outputs) == 1
+    assert isinstance(outputs[0], PoolingRequestOutput)
+    assert outputs[0].prompt_token_ids == prompt_token_ids
+    assert outputs[0].outputs.data.shape == (len(prompt_token_ids), num_labels)
+
+
+@pytest.mark.skip_global_cleanup
+def test_token_ids_prompts(llm: LLM):
+    outputs = llm.encode(
+        [prompt_token_ids], pooling_task="token_classify", use_tqdm=False
+    )
+    assert len(outputs) == 1
+    assert isinstance(outputs[0], PoolingRequestOutput)
+    assert outputs[0].prompt_token_ids == prompt_token_ids
+    assert outputs[0].outputs.data.shape == (len(prompt_token_ids), num_labels)
+
+
+@pytest.mark.skip_global_cleanup
+def test_score_api(llm: LLM):
+    err_msg = "Score API is only enabled for num_labels == 1."
+    with pytest.raises(ValueError, match=err_msg):
+        llm.score("ping", "pong", use_tqdm=False)
+
+
+@pytest.mark.parametrize("task", ["classify", "embed", "token_embed"])
+def test_unsupported_tasks(llm: LLM, task: PoolingTask):
+    if task == "classify":
+        err_msg = "Try switching the model's pooling_task via.+"
+    else:
+        err_msg = "Embedding API is not supported by this model.+"
+
+    with pytest.raises(ValueError, match=err_msg):
+        llm.encode(prompt, pooling_task=task, use_tqdm=False)
diff --git a/tests/entrypoints/pooling/token_classify/test_online.py b/tests/entrypoints/pooling/token_classify/test_online.py
new file mode 100644
index 000000000000..101841370a5d
--- /dev/null
+++ b/tests/entrypoints/pooling/token_classify/test_online.py
@@ -0,0 +1,72 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import pytest
+import requests
+
+from tests.utils import RemoteOpenAIServer
+from vllm.entrypoints.pooling.pooling.protocol import PoolingResponse
+
+MODEL_NAME = "jason9693/Qwen2.5-1.5B-apeach"
+DTYPE = "float32"  # Use float32 to avoid NaN issue
+input_text = "This product was excellent and exceeded my expectations"
+input_tokens = [1986, 1985, 572, 9073, 323, 33808, 847, 16665]
+
+
+@pytest.fixture(scope="module")
+def server():
+    args = [
+        "--enforce-eager",
+        "--max-model-len",
+        "512",
+        "--dtype",
+        DTYPE,
+        "--pooler-config",
+        '{"pooling_task": "token_classify"}',
+    ]
+
+    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
+        yield remote_server
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+async def test_pooling_token_classify(server: RemoteOpenAIServer, model_name: str):
+    task = "token_classify"
+    response = requests.post(
+        server.url_for("pooling"),
+        json={
+            "model": model_name,
+            "input": input_text,
+            "encoding_format": "float",
+            "task": task,
+        },
+    )
+    poolings = PoolingResponse.model_validate(response.json())
+    assert len(poolings.data) == 1
+    assert len(poolings.data[0].data) == 8
+    assert len(poolings.data[0].data[0]) == 2
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+@pytest.mark.parametrize("task", ["classify", "embed", "token_embed", "plugin"])
+async def test_pooling_not_supported(
+    server: RemoteOpenAIServer, model_name: str, task: str
+):
+    response = requests.post(
+        server.url_for("pooling"),
+        json={
+            "model": model_name,
+            "input": input_text,
+            "encoding_format": "float",
+            "task": task,
+        },
+    )
+    assert response.json()["error"]["type"] == "BadRequestError"
+
+    if task == "classify":
+        err_msg = "Try switching the model's pooling_task via"
+    else:
+        err_msg = f"Unsupported task: {task!r}"
+    assert response.json()["error"]["message"].startswith(err_msg)
diff --git a/tests/entrypoints/pooling/token_embed/__init__.py b/tests/entrypoints/pooling/token_embed/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/entrypoints/pooling/token_embed/test_offline.py b/tests/entrypoints/pooling/token_embed/test_offline.py
new file mode 100644
index 000000000000..dbd99e4fa9a5
--- /dev/null
+++ b/tests/entrypoints/pooling/token_embed/test_offline.py
@@ -0,0 +1,73 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import weakref
+
+import pytest
+
+from vllm import LLM, PoolingRequestOutput
+from vllm.config import PoolerConfig
+from vllm.distributed import cleanup_dist_env_and_memory
+from vllm.platforms import current_platform
+from vllm.tasks import PoolingTask
+
+MODEL_NAME = "intfloat/multilingual-e5-small"
+
+prompt = "The chef prepared a delicious meal."
+prompt_token_ids = [0, 581, 21861, 133888, 10, 8, 150, 60744, 109911, 5, 2]
+embedding_size = 384
+
+
+@pytest.fixture(scope="module")
+def llm():
+    # ROCm: Use FLEX_ATTENTION backend as it's the only attention backend
+    # that supports encoder-only models on ROCm.
+    attention_config = None
+    if current_platform.is_rocm():
+        attention_config = {"backend": "FLEX_ATTENTION"}
+
+    # pytest caches the fixture so we use weakref.proxy to
+    # enable garbage collection
+    llm = LLM(
+        model=MODEL_NAME,
+        pooler_config=PoolerConfig(pooling_task="token_embed"),
+        max_num_batched_tokens=32768,
+        tensor_parallel_size=1,
+        gpu_memory_utilization=0.75,
+        enforce_eager=True,
+        seed=0,
+        attention_config=attention_config,
+    )
+    assert embedding_size == llm.model_config.embedding_size
+
+    yield weakref.proxy(llm)
+
+    del llm
+    cleanup_dist_env_and_memory()
+
+
+@pytest.mark.skip_global_cleanup
+def test_str_prompts(llm: LLM):
+    outputs = llm.encode(prompt, pooling_task="token_embed", use_tqdm=False)
+    assert len(outputs) == 1
+    assert isinstance(outputs[0], PoolingRequestOutput)
+    assert outputs[0].outputs.data.shape == (11, 384)
+
+
+@pytest.mark.skip_global_cleanup
+def test_token_ids_prompts(llm: LLM):
+    outputs = llm.encode([prompt_token_ids], pooling_task="token_embed", use_tqdm=False)
+    assert len(outputs) == 1
+    assert isinstance(outputs[0], PoolingRequestOutput)
+    assert outputs[0].outputs.data.shape == (11, 384)
+
+
+@pytest.mark.parametrize("task", ["embed", "classify", "token_classify"])
+def test_unsupported_tasks(llm: LLM, task: PoolingTask):
+    if task == "embed":
+        err_msg = "Try switching the model's pooling_task via.+"
+    else:
+        err_msg = "Classification API is not supported by this model.+"
+
+    with pytest.raises(ValueError, match=err_msg):
+        llm.encode(prompt, pooling_task=task, use_tqdm=False)
diff --git a/tests/entrypoints/pooling/token_embed/test_online.py b/tests/entrypoints/pooling/token_embed/test_online.py
new file mode 100644
index 000000000000..67be29533aa0
--- /dev/null
+++ b/tests/entrypoints/pooling/token_embed/test_online.py
@@ -0,0 +1,96 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+
+import pytest
+import requests
+
+from tests.utils import RemoteOpenAIServer
+from vllm.entrypoints.pooling.pooling.protocol import PoolingResponse
+
+MODEL_NAME = "intfloat/multilingual-e5-small"
+DTYPE = "bfloat16"
+input_text = "The best thing about vLLM is that it supports many different models"
+input_tokens = [
+    0,
+    581,
+    2965,
+    13580,
+    1672,
+    81,
+    23708,
+    594,
+    83,
+    450,
+    442,
+    8060,
+    7,
+    5941,
+    12921,
+    115774,
+    2,
+]
+
+
+@pytest.fixture(scope="module")
+def server():
+    args = [
+        "--runner",
+        "pooling",
+        "--dtype",
+        DTYPE,
+        "--enforce-eager",
+        "--max-model-len",
+        "512",
+        "--pooler-config",
+        '{"pooling_task": "token_embed"}',
+    ]
+
+    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
+        yield remote_server
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+async def test_pooling_token_embed(server: RemoteOpenAIServer, model_name: str):
+    task = "token_embed"
+    response = requests.post(
+        server.url_for("pooling"),
+        json={
+            "model": model_name,
+            "input": input_text,
+            "encoding_format": "float",
+            "task": task,
+        },
+    )
+
+    poolings = PoolingResponse.model_validate(response.json())
+
+    assert len(poolings.data) == 1
+    assert len(poolings.data[0].data) == len(input_tokens)
+    assert len(poolings.data[0].data[0]) == 384
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+@pytest.mark.parametrize("task", ["embed", "classify", "token_classify", "plugin"])
+async def test_pooling_not_supported(
+    server: RemoteOpenAIServer, model_name: str, task: str
+):
+    response = requests.post(
+        server.url_for("pooling"),
+        json={
+            "model": model_name,
+            "input": "test",
+            "encoding_format": "float",
+            "task": task,
+        },
+    )
+    assert response.json()["error"]["type"] == "BadRequestError"
+
+    if task == "embed":
+        err_msg = "Try switching the model's pooling_task via"
+    else:
+        err_msg = f"Unsupported task: {task!r}"
+
+    assert response.json()["error"]["message"].startswith(err_msg)
diff --git a/tests/models/language/pooling/test_all_pooling_plus_chunked_prefill.py b/tests/models/language/pooling/test_all_pooling_plus_chunked_prefill.py
index c259c532220b..514221b3e023 100644
--- a/tests/models/language/pooling/test_all_pooling_plus_chunked_prefill.py
+++ b/tests/models/language/pooling/test_all_pooling_plus_chunked_prefill.py
@@ -6,6 +6,7 @@
 
 from tests.models.utils import check_embeddings_close
 from vllm import TokensPrompt
+from vllm.config import PoolerConfig
 
 
@@ -21,6 +22,7 @@ def test_embed_models(hf_runner, vllm_runner, model: str):
     with vllm_runner(
         model,
         runner="pooling",
+        pooler_config=PoolerConfig(pooling_task="token_embed"),
        max_model_len=128,
         max_num_batched_tokens=chunk_size,
         enforce_eager=True,
diff --git a/tests/models/language/pooling/test_bge_m3.py b/tests/models/language/pooling/test_bge_m3.py
index c0ef263c7781..b825435541fe 100644
--- a/tests/models/language/pooling/test_bge_m3.py
+++ b/tests/models/language/pooling/test_bge_m3.py
@@ -25,29 +25,41 @@
 similarity_reference = [[0.6259, 0.3474], [0.3309, 0.6734]]
 lexical_score_reference = [0.19554901123046875, 0.0]
 colbert_score_reference = [0.7797, 0.4620]
 
+SUPPORTED_TASKS = ["embed", "token_embed", "token_classify"]
+
+
+@pytest.fixture(scope="module", params=SUPPORTED_TASKS)
+def pooling_task(request):
+    yield request.param
+
 
 @pytest.fixture(scope="module")
-def server():
+def server(pooling_task):
     args = [
         "--max-model-len",
         str(MAX_MODEL_LEN),
         "--hf-overrides",
         '{"architectures": ["BgeM3EmbeddingModel"]}',
+        "--pooler-config",
+        '{"pooling_task": "' + pooling_task + '"}',
     ]
 
     with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
         yield remote_server
 
 
-@pytest_asyncio.fixture
-async def client(server):
-    async with server.get_async_client() as async_client:
-        yield async_client
-
-
 @pytest.mark.asyncio
-async def test_bge_m3_api_server_embedding(client: openai.AsyncOpenAI):
+async def test_bge_m3_api_server_embedding(server, pooling_task):
+    client = server.get_async_client()
+
+    if pooling_task != "embed":
+        with pytest.raises(openai.NotFoundError):
+            await run_client_embeddings(
+                client,
+                MODEL_NAME,
+                sentences_1,
+            )
+        return
+
     embeddings_list_1 = await run_client_embeddings(
         client,
         MODEL_NAME,
@@ -117,7 +129,14 @@ def compute_lexical_matching_score(
 
 
 @pytest.mark.asyncio
-async def test_bge_m3_api_server_sparse_embedding(client: openai.AsyncOpenAI):
+async def test_bge_m3_api_server_sparse_embedding(server, pooling_task):
+    client = server.get_async_client()
+
+    if pooling_task != "token_classify":
+        with pytest.raises(openai.BadRequestError):
+            await sparse_embeddings(client, sentences_1)
+        return
+
     embeddings_1 = await sparse_embeddings(client, sentences_1)
     embeddings_2 = await sparse_embeddings(client, sentences_2)
 
@@ -138,8 +157,12 @@ async def test_bge_m3_api_server_sparse_embedding(client: openai.AsyncOpenAI):
 
 @pytest.mark.asyncio
 async def test_bge_m3_api_server_sparse_embedding_corner_case(
-    client: openai.AsyncOpenAI,
+    server, pooling_task
 ):
+    if pooling_task != "token_classify":
+        return
+
+    client = server.get_async_client()
     embeddings = await sparse_embeddings(client, ["Hi"])
     assert len(embeddings) == 1
     assert 2673 in embeddings[0]
@@ -155,7 +178,18 @@ def colbert_score(q_reps: torch.Tensor, p_reps: torch.Tensor) -> torch.Tensor:
 
 @pytest.mark.asyncio
-async def test_bge_m3_api_server_multi_vector(client: openai.AsyncOpenAI):
+async def test_bge_m3_api_server_multi_vector(server, pooling_task):
+    client = server.get_async_client()
+
+    if pooling_task != "token_embed":
+        with pytest.raises(openai.BadRequestError):
+            await client.post(
+                "../pooling",
+                body={"model": MODEL_NAME, "input": sentences_1, "task": "token_embed"},
+                cast_to=httpx.Response,
+            )
+        return
+
     result_1 = await client.post(
         "../pooling",
         body={"model": MODEL_NAME, "input": sentences_1, "task": "token_embed"},
diff --git a/tests/models/language/pooling/test_extract_hidden_states.py b/tests/models/language/pooling/test_extract_hidden_states.py
index 488b27e2da0f..6f22c350a918 100644
--- a/tests/models/language/pooling/test_extract_hidden_states.py
+++ b/tests/models/language/pooling/test_extract_hidden_states.py
@@ -4,6 +4,7 @@
 import torch
 
 from vllm import TokensPrompt
+from vllm.config import PoolerConfig
 
 
@@ -20,6 +21,7 @@ def test_extract_hidden_states(hf_runner, vllm_runner, model: str):
         max_model_len=128,
         enforce_eager=True,
         runner="pooling",
+        pooler_config=PoolerConfig(pooling_task="token_embed"),
         enable_prefix_caching=True,
     ) as vllm_model:
         pooling_outputs = vllm_model.llm.encode(
@@ -44,14 +46,3 @@ def test_extract_hidden_states(hf_runner, vllm_runner, model: str):
         assert len(output.prompt_token_ids) == n
         assert len(output.outputs.data) == n
         assert output.num_cached_tokens == 0
-
-        # skip_reading_prefix_cache can still write to cache
-        # to accelerate following requests
-        pooling_outputs = vllm_model.llm.encode(
-            [TokensPrompt(prompt_token_ids=t) for t in token_prompts],
-            pooling_task="embed",
-        )
-
-        for n, output in zip(n_prompt_tokens, pooling_outputs):
-            assert len(output.prompt_token_ids) == n
-            assert output.num_cached_tokens > 0
diff --git a/tests/models/language/pooling/test_multi_vector_retrieval.py b/tests/models/language/pooling/test_multi_vector_retrieval.py
index 302f2df13557..7ad1dbe46d4e 100644
--- a/tests/models/language/pooling/test_multi_vector_retrieval.py
+++ b/tests/models/language/pooling/test_multi_vector_retrieval.py
@@ -5,6 +5,7 @@
 from transformers import AutoModel
 
 from tests.models.utils import check_embeddings_close
+from vllm.config import PoolerConfig
 
 
@@ -17,6 +18,7 @@ def test_embed_models(hf_runner, vllm_runner, example_prompts, model: str, dtype
     with vllm_runner(
         model,
         runner="pooling",
+        pooler_config=PoolerConfig(pooling_task="token_embed"),
         max_model_len=None,
     ) as vllm_model:
         vllm_outputs = vllm_model.token_embed(example_prompts)
diff --git a/tests/models/language/pooling/test_pooler_config_init_behaviour.py b/tests/models/language/pooling/test_pooler_config_init_behaviour.py
index a5a0c07e0c5d..9bb27386cd00 100644
--- a/tests/models/language/pooling/test_pooler_config_init_behaviour.py
+++ b/tests/models/language/pooling/test_pooler_config_init_behaviour.py
@@ -146,7 +146,7 @@ def test_multi_vector_retrieval_models_using_normalize(
         model,
         max_model_len=512,
         dtype=dtype,
-        pooler_config=PoolerConfig(use_activation=False),
+        pooler_config=PoolerConfig(use_activation=False, pooling_task="token_embed"),
     ) as vllm_model:
         wo_normalize = vllm_model.token_embed(example_prompts)
 
@@ -154,7 +154,7 @@ def test_multi_vector_retrieval_models_using_normalize(
         model,
         max_model_len=512,
         dtype=dtype,
-        pooler_config=PoolerConfig(use_activation=True),
+        pooler_config=PoolerConfig(use_activation=True, pooling_task="token_embed"),
     ) as vllm_model:
         w_normalize = vllm_model.token_embed(example_prompts)
 
diff --git a/tests/plugins_tests/test_bge_m3_sparse_io_processor_plugins.py b/tests/plugins_tests/test_bge_m3_sparse_io_processor_plugins.py
index 85293e55cd81..2ff12c99fe14 100644
--- a/tests/plugins_tests/test_bge_m3_sparse_io_processor_plugins.py
+++ b/tests/plugins_tests/test_bge_m3_sparse_io_processor_plugins.py
@@ -102,7 +102,7 @@ async def test_bge_m3_sparse_plugin_online(
     """Test BGE-M3 sparse plugin in online mode via API."""
     request_payload = {
         "model": model_config["model_name"],
-        "task": "token_classify",
+        "task": "plugin",
         "data": {"input": model_config["test_input"], "return_tokens": return_tokens},
     }
 
@@ -166,7 +166,7 @@ def test_bge_m3_sparse_plugin_offline(vllm_runner, return_tokens: bool):
         default_torch_num_threads=1,
     ) as llm_runner:
         llm = llm_runner.get_llm()
-        pooler_output = llm.encode(prompt, pooling_task="token_classify")
+        pooler_output = llm.encode(prompt, pooling_task="plugin")
 
         outputs = pooler_output[0]
 
@@ -213,7 +213,7 @@ def test_bge_m3_sparse_plugin_offline_multiple_inputs(vllm_runner):
         default_torch_num_threads=1,
     ) as llm_runner:
         llm = llm_runner.get_llm()
-        pooler_output = llm.encode(prompts, pooling_task="token_classify")
+        pooler_output = llm.encode(prompts, pooling_task="plugin")
 
         outputs = pooler_output[0]
 
diff --git a/vllm/config/model.py b/vllm/config/model.py
index 122d5eabd722..45ee591e5a51 100644
--- a/vllm/config/model.py
+++ b/vllm/config/model.py
@@ -25,7 +25,7 @@
 from vllm.config.utils import config, getattr_iter
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
-from vllm.tasks import ScoreType
+from vllm.tasks import PoolingTask, ScoreType, SupportedTask
 from vllm.transformers_utils.config import (
     ConfigFormat,
     get_config,
@@ -1409,6 +1409,41 @@ def get_diff_sampling_param(self) -> dict[str, Any]:
 
         return diff_sampling_param
 
+    def get_pooling_task(
+        self, supported_tasks: tuple[SupportedTask, ...]
+    ) -> PoolingTask | None:
+        if self.pooler_config is None:
+            return None
+
+        pooling_task = self.pooler_config.pooling_task
+
+        if pooling_task is not None:
+            if self.pooler_config.pooling_task in supported_tasks:
+                return self.pooler_config.pooling_task
+            else:
+                raise RuntimeError(
+                    f"Unsupported task: {pooling_task!r} "
+                    f"Supported tasks: {supported_tasks}"
+                )
+
+        if "token_classify" in supported_tasks:
+            for architecture in self.architectures:
+                if "ForTokenClassification" in architecture:
+                    return "token_classify"
+
+        priority: list[PoolingTask] = [
+            "embed&token_classify",
+            "embed",
+            "classify",
+            "token_embed",
+            "token_classify",
+            "plugin",
+        ]
+        for task in priority:
+            if task in supported_tasks:
+                return task
+        return None
+
     @cached_property
     def is_encoder_decoder(self) -> bool:
         """Extract the HF encoder/decoder model flag."""
diff --git a/vllm/config/pooler.py b/vllm/config/pooler.py
index 63aa1220b8ef..145adb26d0b6 100644
--- a/vllm/config/pooler.py
+++ b/vllm/config/pooler.py
@@ -5,6 +5,7 @@
 
 from vllm.config.utils import config
 from vllm.logger import init_logger
+from vllm.tasks import PoolingTask
 from vllm.utils.hashing import safe_hash
 
 logger = init_logger(__name__)
@@ -20,6 +21,11 @@ class PoolerConfig:
     """Controls the behavior of output pooling in pooling models."""
 
+    pooling_task: PoolingTask | None = None
+    """
+    The pooling task used for pooling.
+    """
+
     pooling_type: SequencePoolingType | TokenPoolingType | None = None
     """
     The pooling method used for pooling.
diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index 4b617333c02f..85e6c75132f4 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -382,16 +382,19 @@ def _make_config(value: Any, cls: type[_R]) -> _R:
         self.llm_engine = LLMEngine.from_engine_args(
             engine_args=engine_args, usage_context=UsageContext.LLM_CLASS
         )
+        self.model_config = self.llm_engine.model_config
         self.engine_class = type(self.llm_engine)
 
         self.request_counter = Counter()
         self.default_sampling_params: dict[str, Any] | None = None
 
         supported_tasks = self.llm_engine.get_supported_tasks()
-        logger.info("Supported tasks: %s", supported_tasks)
         self.supported_tasks = supported_tasks
+        self.pooling_task = self.model_config.get_pooling_task(supported_tasks)
+        if self.pooling_task is not None:
+            logger.info("Supported pooling task: %s", self.pooling_task)
 
-        self.model_config = self.llm_engine.model_config
+        self.runner_type = self.model_config.runner_type
         self.renderer = self.llm_engine.renderer
         self.chat_template = load_chat_template(chat_template)
         self.io_processor = self.llm_engine.io_processor
@@ -1072,31 +1075,7 @@ def encode(
             pooled hidden states in the same order as the input prompts.
         """
 
-        if pooling_task is None:
-            raise ValueError(
-                "pooling_task required for `LLM.encode`\n"
-                "Please use one of the more specific methods or set the "
-                "pooling_task when using `LLM.encode`:\n"
-                " - For embeddings, use `LLM.embed(...)` "
-                'or `pooling_task="embed"`.\n'
-                " - For classification logits, use `LLM.classify(...)` "
-                'or `pooling_task="classify"`.\n'
-                " - For similarity scores, use `LLM.score(...)`.\n"
-                " - For rewards, use `LLM.reward(...)` "
-                'or `pooling_task="token_classify"`\n'
-                " - For token classification, "
-                'use `pooling_task="token_classify"`\n'
-                ' - For multi-vector retrieval, use `pooling_task="token_embed"`'
-            )
-
-        model_config = self.model_config
-        runner_type = model_config.runner_type
-        if runner_type != "pooling":
-            raise ValueError(
" - "Try passing `--runner pooling` to use the model as a " - "pooling model." - ) + self._verify_pooling_task(pooling_task) if isinstance(prompts, dict) and "data" in prompts: if self.io_processor is None: @@ -1206,6 +1185,62 @@ def encode( ) return outputs + def _verify_pooling_task(self, pooling_task: PoolingTask | None): + if self.runner_type != "pooling": + raise ValueError( + "LLM.encode() is only supported for pooling models. " + "Try passing `--runner pooling` to use the model as a " + "pooling model." + ) + + if pooling_task is None: + raise ValueError( + "pooling_task required for `LLM.encode`\n" + "Please use one of the more specific methods or set the " + "pooling_task when using `LLM.encode`:\n" + " - For embeddings, use `LLM.embed(...)` " + 'or `pooling_task="embed"`.\n' + " - For classification logits, use `LLM.classify(...)` " + 'or `pooling_task="classify"`.\n' + " - For similarity scores, use `LLM.score(...)`.\n" + " - For rewards, use `LLM.reward(...)` " + 'or `pooling_task="token_classify"`\n' + " - For token classification, " + 'use `pooling_task="token_classify"`\n' + ' - For multi-vector retrieval, use `pooling_task="token_embed"`' + ) + + if ( + pooling_task in ("embed", "token_embed") + and pooling_task not in self.supported_tasks + ): + raise ValueError( + "Embedding API is not supported by this model. " + "Try converting the model using `--convert embed`." + ) + + if ( + pooling_task in ("classify", "token_classify") + and pooling_task not in self.supported_tasks + ): + raise ValueError( + "Classification API is not supported by this model. " + "Try converting the model using `--convert classify`." + ) + + # plugin task uses io_processor.parse_request to verify inputs + if pooling_task != "plugin" and pooling_task != self.pooling_task: + if pooling_task not in self.supported_tasks: + raise ValueError( + f"Unsupported task: {pooling_task!r} " + f"Supported tasks: {self.pooling_task}" + ) + else: + raise ValueError( + f"Try switching the model's pooling_task " + f'via `PoolerConfig(pooling_task="{pooling_task}"`)' + ) + def embed( self, prompts: PromptType | Sequence[PromptType], @@ -1239,11 +1274,6 @@ def embed( A list of `EmbeddingRequestOutput` objects containing the embedding vectors in the same order as the input prompts. """ - if "embed" not in self.supported_tasks: - raise ValueError( - "Embedding API is not supported by this model. " - "Try converting the model using `--convert embed`." - ) items = self.encode( prompts, @@ -1289,11 +1319,6 @@ def classify( A list of `ClassificationRequestOutput` objects containing the embedding vectors in the same order as the input prompts. """ - if "classify" not in self.supported_tasks: - raise ValueError( - "Classification API is not supported by this model. " - "Try converting the model using `--convert classify`." 
-            )
 
         items = self.encode(
             prompts,
diff --git a/vllm/entrypoints/pooling/__init__.py b/vllm/entrypoints/pooling/__init__.py
index e115b710ceeb..7fe0679ca067 100644
--- a/vllm/entrypoints/pooling/__init__.py
+++ b/vllm/entrypoints/pooling/__init__.py
@@ -27,10 +27,15 @@ def enable_scoring_api(
     supported_tasks: tuple["SupportedTask", ...],
     model_config: ModelConfig | None = None,
 ) -> bool:
-    if any(t in supported_tasks for t in ("embed", "token_embed")):
+    if model_config is None:
+        return False
+
+    pooling_task = model_config.get_pooling_task(supported_tasks)
+
+    if pooling_task in ("embed", "token_embed"):
         return True
 
-    if model_config is not None and "classify" in supported_tasks:
+    if pooling_task == "classify":
         num_labels = getattr(model_config.hf_config, "num_labels", 0)
         if num_labels != 1:
             logger.debug_once("Score API is only enabled for num_labels == 1.")
@@ -45,18 +50,24 @@ def register_pooling_api_routers(
     supported_tasks: tuple["SupportedTask", ...],
     model_config: ModelConfig | None = None,
 ):
-    from vllm.entrypoints.pooling.pooling.api_router import router as pooling_router
+    if model_config is None:
+        return
+
+    pooling_task = model_config.get_pooling_task(supported_tasks)
 
-    app.include_router(pooling_router)
+    if pooling_task is not None:
+        from vllm.entrypoints.pooling.pooling.api_router import router as pooling_router
 
-    if "classify" in supported_tasks:
+        app.include_router(pooling_router)
+
+    if pooling_task == "classify":
         from vllm.entrypoints.pooling.classify.api_router import (
             router as classify_router,
         )
 
         app.include_router(classify_router)
 
-    if "embed" in supported_tasks:
+    if pooling_task == "embed":
         from vllm.entrypoints.pooling.embed.api_router import router as embed_router
 
         app.include_router(embed_router)
@@ -79,10 +90,13 @@ def init_pooling_state(
     from vllm.entrypoints.pooling.embed.serving import ServingEmbedding
     from vllm.entrypoints.pooling.pooling.serving import OpenAIServingPooling
    from vllm.entrypoints.pooling.score.serving import ServingScores
-    from vllm.tasks import POOLING_TASKS
 
     model_config = engine_client.model_config
+    if model_config is None:
+        return
+
+    pooling_task = model_config.get_pooling_task(supported_tasks)
     resolved_chat_template = load_chat_template(args.chat_template)
 
     state.serving_pooling = (
             engine_client,
             state.openai_serving_models,
             state.openai_serving_render,
+            supported_tasks=supported_tasks,
             request_logger=request_logger,
             chat_template=resolved_chat_template,
             chat_template_content_format=args.chat_template_content_format,
             trust_request_chat_template=args.trust_request_chat_template,
         )
     )
-        if any(t in supported_tasks for t in POOLING_TASKS)
+        if pooling_task is not None
         else None
     )
     state.serving_embedding = (
@@ -109,7 +124,7 @@ def init_pooling_state(
             chat_template_content_format=args.chat_template_content_format,
             trust_request_chat_template=args.trust_request_chat_template,
         )
-        if "embed" in supported_tasks
+        if pooling_task == "embed"
         else None
     )
     state.serving_classification = (
@@ -121,7 +136,7 @@ def init_pooling_state(
             chat_template_content_format=args.chat_template_content_format,
             trust_request_chat_template=args.trust_request_chat_template,
         )
-        if "classify" in supported_tasks
+        if pooling_task == "classify"
         else None
     )
     state.serving_scores = (
diff --git a/vllm/entrypoints/pooling/pooling/serving.py b/vllm/entrypoints/pooling/pooling/serving.py
index 54151ccb7130..4bcd2f9f265d 100644
--- a/vllm/entrypoints/pooling/pooling/serving.py
+++ b/vllm/entrypoints/pooling/pooling/serving.py
@@ -37,6 +37,7 @@
 from vllm.logger import init_logger
 from vllm.outputs import PoolingRequestOutput
 from vllm.renderers.inputs.preprocess import prompt_to_seq
+from vllm.tasks import SupportedTask
 from vllm.utils.async_utils import merge_async_iterators
 from vllm.utils.serial_utils import EmbedDType, EncodingFormat, Endianness
 
@@ -49,6 +50,7 @@ def __init__(
         engine_client: EngineClient,
         models: OpenAIServingModels,
         openai_serving_render: OpenAIServingRender,
+        supported_tasks: tuple[SupportedTask, ...],
         *,
         request_logger: RequestLogger | None,
         chat_template: str | None,
@@ -60,7 +62,8 @@ def __init__(
             models=models,
             request_logger=request_logger,
         )
-
+        self.supported_tasks = supported_tasks
+        self.pooling_task = self.model_config.get_pooling_task(supported_tasks)
         self.openai_serving_render = openai_serving_render
         self.chat_template = chat_template
         self.chat_template_content_format: Final = chat_template_content_format
@@ -86,9 +89,25 @@ async def create_pooling(
         lora_request = self._maybe_get_adapters(request)
 
+        if request.task is None:
+            request.task = self.pooling_task
+
         if getattr(request, "dimensions", None) is not None:
             return self.create_error_response("dimensions is currently not supported")
 
+        # plugin task uses io_processor.parse_request to verify inputs
+        if request.task != "plugin" and request.task != self.pooling_task:
+            if request.task not in self.supported_tasks:
+                raise ValueError(
+                    f"Unsupported task: {request.task!r} "
+                    f"Supported tasks: {self.pooling_task}"
+                )
+            else:
+                raise ValueError(
+                    "Try switching the model's pooling_task "
+                    'via `--pooler-config {"pooling_task": "' + request.task + '"}`'
+                )
+
         engine_prompts: Sequence[ProcessorInputs]
         if use_io_processor := isinstance(request, IOProcessorRequest):
             if self.io_processor is None:
diff --git a/vllm/entrypoints/sagemaker/api_router.py b/vllm/entrypoints/sagemaker/api_router.py
index e8c48d1c6d53..577217e19dc3 100644
--- a/vllm/entrypoints/sagemaker/api_router.py
+++ b/vllm/entrypoints/sagemaker/api_router.py
@@ -18,7 +18,7 @@
 from vllm.entrypoints.pooling.base.serving import PoolingServing
 from vllm.entrypoints.serve.instrumentator.basic import base
 from vllm.entrypoints.serve.instrumentator.health import health
-from vllm.tasks import POOLING_TASKS, SupportedTask
+from vllm.tasks import SupportedTask
 
 # TODO: RequestType = TypeForm[BaseModel] when recognized by type checkers
 # (requires typing_extensions >= 4.13)
@@ -53,50 +53,56 @@ def get_invocation_types(
         (CompletionRequest, (completion, create_completion)),
     ]
 
-    if "embed" in supported_tasks:
-        from vllm.entrypoints.pooling.embed.api_router import (
-            create_embedding,
-            embedding,
-        )
-        from vllm.entrypoints.pooling.embed.protocol import EmbeddingRequest
-
-        INVOCATION_TYPES += [
-            (EmbeddingRequest, (embedding, create_embedding)),
-        ]
-
-    if "classify" in supported_tasks:
-        from vllm.entrypoints.pooling.classify.api_router import (
-            classify,
-            create_classify,
-        )
-        from vllm.entrypoints.pooling.classify.protocol import ClassificationRequest
-
-        INVOCATION_TYPES += [
-            (ClassificationRequest, (classify, create_classify)),
-        ]
-
-    if enable_scoring_api(supported_tasks, model_config):
-        from vllm.entrypoints.pooling.score.api_router import do_rerank, rerank
-        from vllm.entrypoints.pooling.score.protocol import RerankRequest
-
-        INVOCATION_TYPES += [
-            (RerankRequest, (rerank, do_rerank)),
-        ]
-
-        from vllm.entrypoints.pooling.score.api_router import create_score, score
-        from vllm.entrypoints.pooling.score.protocol import ScoreRequest
-
-        INVOCATION_TYPES += [
-            (ScoreRequest, (score, create_score)),
-        ]
-
-    if any(task in POOLING_TASKS for task in supported_tasks):
-        from vllm.entrypoints.pooling.pooling.api_router import create_pooling, pooling
-        from vllm.entrypoints.pooling.pooling.protocol import PoolingRequest
-
-        INVOCATION_TYPES += [
-            (PoolingRequest, (pooling, create_pooling)),
-        ]
+    if model_config:
+        pooling_task = model_config.get_pooling_task(supported_tasks)
+
+        if pooling_task == "embed":
+            from vllm.entrypoints.pooling.embed.api_router import (
+                create_embedding,
+                embedding,
+            )
+            from vllm.entrypoints.pooling.embed.protocol import EmbeddingRequest
+
+            INVOCATION_TYPES += [
+                (EmbeddingRequest, (embedding, create_embedding)),
+            ]
+
+        if pooling_task == "classify":
+            from vllm.entrypoints.pooling.classify.api_router import (
+                classify,
+                create_classify,
+            )
+            from vllm.entrypoints.pooling.classify.protocol import ClassificationRequest
+
+            INVOCATION_TYPES += [
+                (ClassificationRequest, (classify, create_classify)),
+            ]
+
+        if enable_scoring_api(supported_tasks, model_config):
+            from vllm.entrypoints.pooling.score.api_router import do_rerank, rerank
+            from vllm.entrypoints.pooling.score.protocol import RerankRequest
+
+            INVOCATION_TYPES += [
+                (RerankRequest, (rerank, do_rerank)),
+            ]
+
+            from vllm.entrypoints.pooling.score.api_router import create_score, score
+            from vllm.entrypoints.pooling.score.protocol import ScoreRequest
+
+            INVOCATION_TYPES += [
+                (ScoreRequest, (score, create_score)),
+            ]
+
+        if pooling_task is not None:
+            from vllm.entrypoints.pooling.pooling.api_router import (
+                create_pooling,
+                pooling,
+            )
+            from vllm.entrypoints.pooling.pooling.protocol import PoolingRequest
+
+            INVOCATION_TYPES += [
+                (PoolingRequest, (pooling, create_pooling)),
+            ]
 
     return INVOCATION_TYPES
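
Usage note (not part of the patch): the snippet below is a minimal sketch of how the `pooler_config.pooling_task` field introduced in this diff is exercised offline. The model name, prompt, and expected output shape are copied from the new `tests/entrypoints/pooling/token_classify/test_offline.py`; treat the exact values as test fixtures rather than recommended defaults.

```python
# Minimal sketch, mirroring the token_classify offline test added in this diff.
from vllm import LLM
from vllm.config import PoolerConfig

# Pin the pooling task in the model's pooler config, then request the same task
# through LLM.encode(); per this diff, asking for a different task raises a
# ValueError that points at PoolerConfig(pooling_task=...).
llm = LLM(
    model="jason9693/Qwen2.5-1.5B-apeach",
    pooler_config=PoolerConfig(pooling_task="token_classify"),
    enforce_eager=True,
)

outputs = llm.encode(
    "The chef prepared a delicious meal.", pooling_task="token_classify"
)
# One PoolingRequestOutput per prompt; data has shape (num_prompt_tokens, num_labels).
print(outputs[0].outputs.data.shape)
```

For the online server, the equivalent selection is made with `--pooler-config '{"pooling_task": "token_classify"}'` (or `"token_embed"`), as in the new `test_online.py` fixtures; per the updated tests, `/pooling` requests that name a different task are rejected with a `BadRequestError`.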