2525 resolve_chat_template_content_format )
2626from vllm .entrypoints .score_utils import (_cosine_similarity ,
2727 _validate_score_input_lens )
28+ from vllm .entrypoints .utils import _validate_truncation_size
2829from vllm .inputs import PromptType , SingletonPrompt , TextPrompt , TokensPrompt
2930from vllm .inputs .parse import is_token_prompt , parse_and_batch_prompt
3031from vllm .logger import init_logger
@@ -793,6 +794,7 @@ def encode(
793794 pooling_params : Optional [Union [PoolingParams ,
794795 Sequence [PoolingParams ]]] = None ,
795796 * ,
797+ truncate_prompt_tokens : Optional [int ] = None ,
796798 use_tqdm : bool = True ,
797799 lora_request : Optional [Union [list [LoRARequest ], LoRARequest ]] = None ,
798800 prompt_adapter_request : Optional [PromptAdapterRequest ] = None ,
@@ -807,6 +809,7 @@ def encode(
807809 pooling_params : Optional [Union [PoolingParams ,
808810 Sequence [PoolingParams ]]] = None ,
809811 prompt_token_ids : Optional [list [int ]] = None ,
812+ truncate_prompt_tokens : Optional [int ] = None ,
810813 use_tqdm : bool = True ,
811814 lora_request : Optional [Union [list [LoRARequest ], LoRARequest ]] = None ,
812815 prompt_adapter_request : Optional [PromptAdapterRequest ] = None ,
@@ -821,6 +824,7 @@ def encode(
821824 pooling_params : Optional [Union [PoolingParams ,
822825 Sequence [PoolingParams ]]] = None ,
823826 prompt_token_ids : Optional [list [list [int ]]] = None ,
827+ truncate_prompt_tokens : Optional [int ] = None ,
824828 use_tqdm : bool = True ,
825829 lora_request : Optional [Union [list [LoRARequest ], LoRARequest ]] = None ,
826830 prompt_adapter_request : Optional [PromptAdapterRequest ] = None ,
@@ -836,6 +840,7 @@ def encode(
836840 Sequence [PoolingParams ]]] = None ,
837841 * ,
838842 prompt_token_ids : list [int ],
843+ truncate_prompt_tokens : Optional [int ] = None ,
839844 use_tqdm : bool = True ,
840845 lora_request : Optional [Union [list [LoRARequest ], LoRARequest ]] = None ,
841846 prompt_adapter_request : Optional [PromptAdapterRequest ] = None ,
@@ -851,6 +856,7 @@ def encode(
851856 Sequence [PoolingParams ]]] = None ,
852857 * ,
853858 prompt_token_ids : list [list [int ]],
859+ truncate_prompt_tokens : Optional [int ] = None ,
854860 use_tqdm : bool = True ,
855861 lora_request : Optional [Union [list [LoRARequest ], LoRARequest ]] = None ,
856862 prompt_adapter_request : Optional [PromptAdapterRequest ] = None ,
@@ -864,6 +870,7 @@ def encode(
864870 prompts : None ,
865871 pooling_params : None ,
866872 prompt_token_ids : Union [list [int ], list [list [int ]]],
873+ truncate_prompt_tokens : Optional [int ] = None ,
867874 use_tqdm : bool = True ,
868875 lora_request : Optional [Union [list [LoRARequest ], LoRARequest ]] = None ,
869876 prompt_adapter_request : Optional [PromptAdapterRequest ] = None ,
@@ -882,6 +889,7 @@ def encode(
882889 pooling_params : Optional [Union [PoolingParams ,
883890 Sequence [PoolingParams ]]] = None ,
884891 prompt_token_ids : Optional [Union [list [int ], list [list [int ]]]] = None ,
892+ truncate_prompt_tokens : Optional [int ] = None ,
885893 use_tqdm : bool = True ,
886894 lora_request : Optional [Union [list [LoRARequest ], LoRARequest ]] = None ,
887895 prompt_adapter_request : Optional [PromptAdapterRequest ] = None ,
@@ -946,10 +954,15 @@ def encode(
946954 for pooling_param in pooling_params :
947955 pooling_param .verify (self .llm_engine .model_config )
948956
957+ tokenization_kwargs : dict [str , Any ] = {}
958+ _validate_truncation_size (self .llm_engine .model_config .max_model_len ,
959+ truncate_prompt_tokens , tokenization_kwargs )
960+
949961 self ._validate_and_add_requests (
950962 prompts = parsed_prompts ,
951963 params = pooling_params ,
952964 lora_request = lora_request ,
965+ tokenization_kwargs = tokenization_kwargs ,
953966 prompt_adapter_request = prompt_adapter_request ,
954967 )
955968
@@ -962,6 +975,7 @@ def embed(
962975 prompts : Union [PromptType , Sequence [PromptType ]],
963976 / ,
964977 * ,
978+ truncate_prompt_tokens : Optional [int ] = None ,
965979 use_tqdm : bool = True ,
966980 pooling_params : Optional [Union [PoolingParams ,
967981 Sequence [PoolingParams ]]] = None ,
@@ -995,6 +1009,7 @@ def embed(
9951009 "Embedding API is only enabled for `--task embed`" )
9961010
9971011 items = self .encode (prompts ,
1012+ truncate_prompt_tokens = truncate_prompt_tokens ,
9981013 use_tqdm = use_tqdm ,
9991014 pooling_params = pooling_params ,
10001015 lora_request = lora_request ,
@@ -1055,6 +1070,7 @@ def _embedding_score(
10551070
10561071 encoded_output : list [PoolingRequestOutput ] = self .encode (
10571072 text_1 + text_2 ,
1073+ truncate_prompt_tokens = truncate_prompt_tokens ,
10581074 use_tqdm = use_tqdm ,
10591075 lora_request = lora_request ,
10601076 prompt_adapter_request = prompt_adapter_request )
@@ -1098,9 +1114,8 @@ def _cross_encoding_score(
10981114 pooling_params = PoolingParams ()
10991115
11001116 tokenization_kwargs : dict [str , Any ] = {}
1101- if truncate_prompt_tokens is not None :
1102- tokenization_kwargs ["truncation" ] = True
1103- tokenization_kwargs ["max_length" ] = truncate_prompt_tokens
1117+ _validate_truncation_size (self .llm_engine .model_config .max_model_len ,
1118+ truncate_prompt_tokens , tokenization_kwargs )
11041119
11051120 parsed_prompts = []
11061121
@@ -1323,6 +1338,7 @@ def _validate_and_add_requests(
13231338 Sequence [PoolingParams ]],
13241339 lora_request : Optional [Union [Sequence [LoRARequest ], LoRARequest ]],
13251340 prompt_adapter_request : Optional [PromptAdapterRequest ],
1341+ tokenization_kwargs : Optional [dict [str , Any ]] = None ,
13261342 guided_options : Optional [GuidedDecodingRequest ] = None ,
13271343 priority : Optional [list [int ]] = None ,
13281344 ) -> None :
@@ -1359,6 +1375,7 @@ def _validate_and_add_requests(
13591375 self ._add_request (
13601376 prompt ,
13611377 params [i ] if isinstance (params , Sequence ) else params ,
1378+ tokenization_kwargs = tokenization_kwargs ,
13621379 lora_request = lora_request [i ] if isinstance (
13631380 lora_request , Sequence ) else lora_request ,
13641381 prompt_adapter_request = prompt_adapter_request ,
@@ -1369,6 +1386,7 @@ def _add_request(
13691386 self ,
13701387 prompt : PromptType ,
13711388 params : Union [SamplingParams , PoolingParams ],
1389+ tokenization_kwargs : Optional [dict [str , Any ]] = None ,
13721390 lora_request : Optional [LoRARequest ] = None ,
13731391 prompt_adapter_request : Optional [PromptAdapterRequest ] = None ,
13741392 priority : int = 0 ,
@@ -1379,6 +1397,7 @@ def _add_request(
13791397 prompt ,
13801398 params ,
13811399 lora_request = lora_request ,
1400+ tokenization_kwargs = tokenization_kwargs ,
13821401 prompt_adapter_request = prompt_adapter_request ,
13831402 priority = priority ,
13841403 )
0 commit comments