diff --git a/tests/models/language/pooling/test_extract_hidden_states.py b/tests/models/language/pooling/test_extract_hidden_states.py
index 488b27e2da0f..d539ad27aa83 100644
--- a/tests/models/language/pooling/test_extract_hidden_states.py
+++ b/tests/models/language/pooling/test_extract_hidden_states.py
@@ -3,7 +3,7 @@
 import pytest
 import torch
 
-from vllm import TokensPrompt
+from vllm import SamplingParams, TokensPrompt
 
 
 @pytest.mark.parametrize(
@@ -14,16 +14,17 @@ def test_extract_hidden_states(hf_runner, vllm_runner, model: str):
     n_prompt_tokens = [55, 56, 57]
     token_prompts = [[1024 + i for i in range(n)] for n in n_prompt_tokens]
+    prompts = [TokensPrompt(prompt_token_ids=t) for t in token_prompts]
 
     with vllm_runner(
         model,
         max_model_len=128,
         enforce_eager=True,
-        runner="pooling",
+        runner="generate",
         enable_prefix_caching=True,
     ) as vllm_model:
         pooling_outputs = vllm_model.llm.encode(
-            [TokensPrompt(prompt_token_ids=t) for t in token_prompts],
+            prompts=prompts,
             pooling_task="token_embed",
         )
 
         for n, output in zip(n_prompt_tokens, pooling_outputs):
@@ -36,7 +37,7 @@ def test_extract_hidden_states(hf_runner, vllm_runner, model: str):
 
         # we need to skip reading cache at this request by
         # request.skip_reading_prefix_cache
         pooling_outputs = vllm_model.llm.encode(
-            [TokensPrompt(prompt_token_ids=t) for t in token_prompts],
+            prompts=prompts,
             pooling_task="token_embed",
         )
@@ -48,10 +49,19 @@ def test_extract_hidden_states(hf_runner, vllm_runner, model: str):
         # skip_reading_prefix_cache can still write to cache
         # to accelerate following requests
         pooling_outputs = vllm_model.llm.encode(
-            [TokensPrompt(prompt_token_ids=t) for t in token_prompts],
+            prompts=prompts,
             pooling_task="embed",
         )
 
         for n, output in zip(n_prompt_tokens, pooling_outputs):
             assert len(output.prompt_token_ids) == n
             assert output.num_cached_tokens > 0
+
+        # Support generating text and returning prompt hidden states
+        generate_outputs = vllm_model.llm.generate(
+            prompts=prompts,
+            sampling_params=SamplingParams(max_tokens=1),
+        )
+        for n, output in zip(n_prompt_tokens, generate_outputs):
+            assert len(output.prompt_token_ids) == n
+            assert output.num_cached_tokens > 0
diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index 31319cf64aeb..2612f9b04719 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -1025,13 +1025,8 @@ def encode(
             raise ValueError(error_str)
 
         model_config = self.model_config
-        runner_type = model_config.runner_type
-        if runner_type != "pooling":
-            raise ValueError(
-                "LLM.encode() is only supported for pooling models. "
-                "Try passing `--runner pooling` to use the model as a "
-                "pooling model."
-            )
+        if pooling_task not in self.supported_tasks:
+            raise ValueError(f"pooling_task must be one of {self.supported_tasks}.")
 
         io_processor_prompt = False
         if isinstance(prompts, dict) and "data" in prompts:
@@ -1069,9 +1064,6 @@
                 # Use default pooling params.
                 pooling_params = PoolingParams()
 
-        if pooling_task not in self.supported_tasks:
-            raise ValueError(f"pooling_task must be one of {self.supported_tasks}.")
-
         for param in as_iter(pooling_params):
             param.verify(pooling_task, model_config)
             # for backwards compatibility
diff --git a/vllm/pooling_params.py b/vllm/pooling_params.py
index 4a5caa7e27fc..7ae10aaaaa37 100644
--- a/vllm/pooling_params.py
+++ b/vllm/pooling_params.py
@@ -108,6 +108,15 @@ def verify(
     def _merge_default_parameters(
         self, model_config: Optional["ModelConfig"] = None
     ) -> None:
+        if self.skip_reading_prefix_cache is None:
+            # If prefix caching is enabled,
+            # the output of all-token pooling may be shorter than
+            # n_prompt_tokens, so we need to skip reading the cache
+            # for this request.
+            if self.task in ["token_embed", "token_classify"]:
+                self.skip_reading_prefix_cache = True
+            else:
+                self.skip_reading_prefix_cache = False
+
         if model_config is None:
             return
 
@@ -125,15 +134,6 @@ def _merge_default_parameters(
                 if getattr(self, k, None) is None:
                     setattr(self, k, getattr(pooler_config, k))
 
-        if self.skip_reading_prefix_cache is None:
-            # If prefix caching is enabled,
-            # the output of all pooling may less than n_prompt_tokens,
-            # we need to skip reading cache at this request.
-            if self.task in ["token_embed", "token_classify"]:
-                self.skip_reading_prefix_cache = True
-            else:
-                self.skip_reading_prefix_cache = False
-
         self._verify_step_pooling(pooler_config, valid_parameters)
 
     def _verify_step_pooling(
diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py
index ead7a3619dea..1bcd7474ca37 100644
--- a/vllm/v1/worker/gpu_input_batch.py
+++ b/vllm/v1/worker/gpu_input_batch.py
@@ -479,8 +479,8 @@ def remove_request(self, req_id: str) -> int | None:
             del self.lora_id_to_lora_request[lora_id]
         self.request_lora_mapping[req_index] = 0
 
-        if self.is_pooling_model:
-            self.pooling_params.pop(req_id, None)
+        pooling_params = self.pooling_params.pop(req_id, None)
+        if pooling_params is not None:
             self.pooling_states.pop(req_id, None)
 
         return req_index
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 978224faae65..ab745e847f67 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -35,7 +35,9 @@
     CUDAGraphMode,
+    PoolerConfig,
     VllmConfig,
     get_layers_from_vllm_config,
+    set_current_vllm_config,
     update_config,
 )
 from vllm.distributed.ec_transfer import get_ec_transfer, has_ec_transfer
 from vllm.distributed.eplb.eplb_state import EplbState
@@ -173,6 +173,7 @@
     sanity_check_mm_encoder_outputs,
     scatter_mm_placeholders,
 )
+from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
 
 if TYPE_CHECKING:
     from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
@@ -817,8 +818,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
             else:
                 generator = None
 
-            if self.is_pooling_model:
-                assert pooling_params is not None
+            if pooling_params is not None:
                 task = pooling_params.task
                 assert task is not None, "You did not set `task` in the API"
 
@@ -2295,12 +2295,12 @@ def get_model(self) -> nn.Module:
             return self.model.unwrap()
         return self.model
 
-    def get_supported_generation_tasks(self) -> list[GenerationTask]:
+    def get_supported_generation_tasks(self) -> list[GenerationTask | PoolingTask]:
         model = self.get_model()
-        supported_tasks = list[GenerationTask]()
+        supported_tasks = list[GenerationTask | PoolingTask]()
 
         if is_text_generation_model(model):
-            supported_tasks.append("generate")
+            supported_tasks.extend(["generate", "embed", "token_embed"])
 
         if supports_transcription(model):
             if model.supports_transcription_only:
@@ -3110,7 +3110,7 @@ def execute_model(
             self.kv_connector_output = kv_connector_output
             return hidden_states
 
-        if self.is_pooling_model:
+        if len(self.input_batch.pooling_params) > 0:
             # Return the pooling output.
             output = self._pool(
                 hidden_states, num_scheduled_tokens, num_scheduled_tokens_np
@@ -3674,6 +3674,18 @@ def load_model(self, eep_scale_up: bool = False) -> None:
             and mm_config.is_multimodal_pruning_enabled()
         )
 
+        if not self.is_pooling_model:
+            # Attach a default last-token pooler so generate-runner models
+            # can also serve "embed" and "token_embed" requests.
+            with set_current_vllm_config(self.vllm_config):
+                pooler_config = PoolerConfig(pooling_type="LAST")
+                self.model.pooler = DispatchPooler(
+                    {
+                        "token_embed": Pooler.for_token_embed(pooler_config),
+                        "embed": Pooler.for_embed(pooler_config),
+                    },
+                )
+
         if is_mixture_of_experts(self.model) and self.parallel_config.enable_eplb:
             logger.info_once("EPLB is enabled for model %s.", self.model_config.model)
             global_expert_load = (
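
Taken together, these changes let a model loaded with `runner="generate"` serve `LLM.encode()` pooling tasks ("embed", "token_embed") alongside ordinary generation, via the default last-token pooler attached in `load_model()`. Below is a minimal usage sketch of the resulting API, assuming a build with this patch applied; the model name and token ids are placeholders, and the output attribute access follows vLLM's existing `PoolingRequestOutput`/`RequestOutput` types.

```python
# Minimal sketch of the API this patch enables; not part of the diff.
# Assumptions: placeholder model name and token ids, and a vLLM build
# that includes the changes above.
from vllm import LLM, SamplingParams, TokensPrompt

llm = LLM(
    model="Qwen/Qwen3-0.6B",  # placeholder model
    runner="generate",        # no `--runner pooling` needed anymore
    enable_prefix_caching=True,
)

prompt = TokensPrompt(prompt_token_ids=[1024 + i for i in range(8)])

# Per-token hidden states: one row per prompt token. token_embed now
# defaults to skip_reading_prefix_cache=True, so the output is never
# truncated by a prefix-cache hit.
token_outputs = llm.encode(prompts=[prompt], pooling_task="token_embed")
print(token_outputs[0].outputs.data.shape)  # (n_prompt_tokens, hidden_size)

# One embedding per prompt, pooled from the last token.
embed_outputs = llm.encode(prompts=[prompt], pooling_task="embed")
print(embed_outputs[0].outputs.data.shape)

# Ordinary generation still works on the same engine instance.
gen_outputs = llm.generate(
    prompts=[prompt],
    sampling_params=SamplingParams(max_tokens=8),
)
print(gen_outputs[0].outputs[0].text)
```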