diff --git a/tests/entrypoints/pooling/embed/test_cohere_openai_parity.py b/tests/entrypoints/pooling/embed/test_cohere_openai_parity.py
index d23e1461b997..080c7e797221 100644
--- a/tests/entrypoints/pooling/embed/test_cohere_openai_parity.py
+++ b/tests/entrypoints/pooling/embed/test_cohere_openai_parity.py
@@ -10,7 +10,7 @@
 import pytest
 import requests
 
-from tests.utils import RemoteOpenAIServer
+from tests.utils import ROCM_EXTRA_ARGS, RemoteOpenAIServer
 
 MODEL_NAME = "BAAI/bge-base-en-v1.5"
 DTYPE = "bfloat16"
@@ -28,7 +28,7 @@ def server():
         "512",
         "--gpu-memory-utilization",
         "0.02",
-    ]
+    ] + ROCM_EXTRA_ARGS
 
     with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
         yield remote_server
diff --git a/tests/entrypoints/pooling/embed/test_online_dimensions.py b/tests/entrypoints/pooling/embed/test_online_dimensions.py
index 0545b8a0ae2f..638f5218989d 100644
--- a/tests/entrypoints/pooling/embed/test_online_dimensions.py
+++ b/tests/entrypoints/pooling/embed/test_online_dimensions.py
@@ -10,7 +10,7 @@
 from tests.conftest import HfRunner
 from tests.models.language.pooling.embed_utils import run_embedding_correctness_test
 from tests.models.utils import EmbedModelInfo
-from tests.utils import RemoteOpenAIServer
+from tests.utils import ROCM_EXTRA_ARGS, RemoteOpenAIServer
 from vllm.entrypoints.pooling.embed.protocol import EmbeddingResponse
 from vllm.platforms import current_platform
 
@@ -49,7 +49,7 @@ def server(model_info, dtype: str):
         "--enforce-eager",
         "--max-model-len",
         "512",
-    ]
+    ] + ROCM_EXTRA_ARGS
 
     if model_info.name == "Snowflake/snowflake-arctic-embed-m-v1.5":
         # Manually enable Matryoshka Embeddings
diff --git a/vllm/entrypoints/pooling/base/serving.py b/vllm/entrypoints/pooling/base/serving.py
index 9bbdde5bbc80..312eed6bf167 100644
--- a/vllm/entrypoints/pooling/base/serving.py
+++ b/vllm/entrypoints/pooling/base/serving.py
@@ -118,6 +118,7 @@ async def _prepare_generators(
         )
 
         pooling_params = self.io_processor.create_pooling_params(ctx.request)
+        pooling_params.verify(self.model_config)
 
         for i, engine_prompt in enumerate(ctx.engine_prompts):
             prompt_request_id = (
diff --git a/vllm/entrypoints/utils.py b/vllm/entrypoints/utils.py
index 1c5abecda863..e3682280ec50 100644
--- a/vllm/entrypoints/utils.py
+++ b/vllm/entrypoints/utils.py
@@ -309,6 +309,9 @@ def create_error_response(
 
     if isinstance(message, Exception):
         exc = message
+        logger.debug(
+            "create_error_response called with %s: %s", type(exc).__name__, exc
+        )
 
     from vllm.exceptions import VLLMNotFoundError, VLLMValidationError