@@ -6,6 +6,7 @@
 from vllm.config import DecodingConfig, ModelConfig
 from vllm.core.scheduler import SchedulerOutputs
 from vllm.inputs.data import PromptType, TokensPrompt
+from vllm.inputs.preprocess import InputPreprocessor
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor.layers.sampler import SamplerOutput
@@ -59,7 +60,8 @@ def generate(
 
     async def beam_search(
         self,
-        prompt: Union[str, List[int]],
+        prompt: Union[PromptType, List[int]],
+        model_config: ModelConfig,
        request_id: str,
        params: BeamSearchParams,
    ) -> AsyncGenerator[RequestOutput, None]:
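With the widened signature above, `prompt` is no longer restricted to a string or raw token ids: any `PromptType` is accepted, including the dict forms that carry multimodal inputs. A hedged sketch of the shapes callers can now pass (illustrative only; the `<image>` placeholder tag is model-specific):

from PIL import Image

# Illustrative only: prompt shapes PromptType can take here.
image = Image.new("RGB", (224, 224))               # stand-in image payload
text_prompt = "Describe the scene:"                # plain string
tokens_prompt = {"prompt_token_ids": [101, 2054]}  # TokensPrompt dict
multimodal_prompt = {                              # TextPrompt dict with
    "prompt": "Describe the image: <image>",       # a multimodal payload
    "multi_modal_data": {"image": image},
}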
@@ -69,32 +71,40 @@ async def beam_search(
         ignore_eos = params.ignore_eos
         temperature = params.temperature
         length_penalty = params.length_penalty
+        include_stop_str_in_output = params.include_stop_str_in_output
 
-        tokenizer = await self.get_tokenizer(lora_request=None)
-        if isinstance(prompt, str):
-            tokenized_prompt = tokenizer.encode(prompt)
-            prompt_text = prompt
-        else:
-            tokenized_prompt = prompt
-            prompt_text = None
-        tokenized_length = len(tokenized_prompt)
+        tokenizer = await self.get_tokenizer()
+        input_preprocessor = InputPreprocessor(model_config, tokenizer)
+
+        (prompt_text, prompt_token_ids, multi_modal_data,
+         mm_processor_kwargs) = input_preprocessor._extract_prompt_components(
+             prompt,
+             request_id=request_id,
+         )
+        tokenized_length = len(prompt_token_ids)
 
         sort_beams_key = create_sort_beams_key_function(
             tokenizer.eos_token_id, length_penalty)
 
-        beam_search_params = SamplingParams(logprobs=2 * beam_width,
-                                            max_tokens=1,
-                                            temperature=temperature)
+        beam_search_params = SamplingParams(
+            logprobs=2 * beam_width,
+            max_tokens=1,
+            temperature=temperature,
+        )
         all_beams = [
-            BeamSearchSequence(tokens=tokenized_prompt,
+            BeamSearchSequence(tokens=prompt_token_ids,
+                               cum_logprob=0,
                                logprobs=[],
-                               cum_logprob=0)
+                               multi_modal_data=multi_modal_data,
+                               mm_processor_kwargs=mm_processor_kwargs)
         ]
         completed = []
 
         for _ in range(max_tokens):
             prompts_batch = [
-                TokensPrompt(prompt_token_ids=beam.tokens)
+                TokensPrompt(prompt_token_ids=beam.tokens,
+                             multi_modal_data=beam.multi_modal_data,
+                             mm_processor_kwargs=beam.mm_processor_kwargs)
                 for beam in all_beams
             ]
 
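For intuition, `_extract_prompt_components` folds the deleted `isinstance(prompt, str)` branch and the multimodal fields into one tuple. A simplified sketch of the behavior this hunk relies on (an assumption for illustration, not the actual vLLM implementation, which also handles encoder/decoder prompts):

# Simplified sketch (assumed behavior, not vLLM's real preprocessor).
def extract_prompt_components_sketch(prompt, tokenizer):
    if isinstance(prompt, str):
        # Old fast path: plain text, no multimodal payload.
        return prompt, tokenizer.encode(prompt), None, None
    if "prompt_token_ids" in prompt:
        # TokensPrompt dict: already tokenized.
        return (None, prompt["prompt_token_ids"],
                prompt.get("multi_modal_data"),
                prompt.get("mm_processor_kwargs"))
    # TextPrompt dict: tokenize the text, keep the multimodal fields.
    return (prompt["prompt"], tokenizer.encode(prompt["prompt"]),
            prompt.get("multi_modal_data"),
            prompt.get("mm_processor_kwargs"))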
@@ -120,17 +130,31 @@ async def beam_search(
                 if result.outputs[0].logprobs is not None:
                     logprobs = result.outputs[0].logprobs[0]
                     for token_id, logprob_obj in logprobs.items():
-                        new_beam = BeamSearchSequence(
-                            tokens=current_beam.tokens + [token_id],
-                            logprobs=current_beam.logprobs + [logprobs],
-                            cum_logprob=current_beam.cum_logprob +
-                            logprob_obj.logprob)
-
                         if token_id == tokenizer.eos_token_id and \
                             not ignore_eos:
-                            completed.append(new_beam)
+                            completed.append(
+                                BeamSearchSequence(
+                                    tokens=current_beam.tokens +
+                                    [token_id] if include_stop_str_in_output
+                                    else current_beam.tokens,
+                                    logprobs=current_beam.logprobs +
+                                    [logprobs],
+                                    cum_logprob=current_beam.cum_logprob +
+                                    logprob_obj.logprob,
+                                    finish_reason="stop",
+                                    stop_reason=tokenizer.eos_token_id))
                         else:
-                            new_beams.append(new_beam)
+                            new_beams.append(
+                                BeamSearchSequence(
+                                    tokens=current_beam.tokens + [token_id],
+                                    logprobs=current_beam.logprobs +
+                                    [logprobs],
+                                    cum_logprob=current_beam.cum_logprob +
+                                    logprob_obj.logprob,
+                                    multi_modal_data=current_beam.
+                                    multi_modal_data,
+                                    mm_processor_kwargs=current_beam.
+                                    mm_processor_kwargs))
 
             sorted_beams = sorted(new_beams, key=sort_beams_key, reverse=True)
             all_beams = sorted_beams[:beam_width]
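The ranking step above relies on `sort_beams_key` from `create_sort_beams_key_function(eos_token_id, length_penalty)`. Conceptually it orders beams by length-normalized cumulative log-probability, along the lines of this sketch (the normalization formula is the standard beam-search convention and an assumption here, not a quote of the helper):

# Assumed scoring: standard length-penalized beam score. A trailing EOS is
# excluded from the effective length so finished beams are not penalized
# for the stop token itself.
def beam_score_sketch(tokens, cum_logprob, eos_token_id, length_penalty):
    effective_len = len(tokens)
    if tokens and tokens[-1] == eos_token_id:
        effective_len -= 1
    return cum_logprob / (effective_len**length_penalty)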
@@ -151,16 +175,18 @@ async def beam_search(
             request_id=request_id,
             prompt=prompt_text,
             outputs=[
-                CompletionOutput(
-                    text=beam.text,
-                    cumulative_logprob=beam.cum_logprob,
-                    token_ids=beam.tokens[tokenized_length:],
-                    index=i,
-                    logprobs=beam.logprobs,
-                ) for (i, beam) in enumerate(best_beams)
+                CompletionOutput(text=beam.text,
+                                 cumulative_logprob=beam.cum_logprob,
+                                 token_ids=beam.tokens[tokenized_length:],
+                                 index=i,
+                                 logprobs=beam.logprobs,
+                                 finish_reason=beam.finish_reason if
+                                 beam.finish_reason is not None else "length",
+                                 stop_reason=beam.stop_reason)
+                for (i, beam) in enumerate(best_beams)
             ],
             finished=True,
-            prompt_token_ids=tokenized_prompt,
+            prompt_token_ids=prompt_token_ids,
             prompt_logprobs=None)
 
         yield beam_search_output
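Putting the pieces together, a caller now supplies both the prompt and the model config. A hedged end-to-end sketch (engine construction elided; `engine` is assumed to be an `EngineClient` implementation such as `AsyncLLMEngine`, and the `<image>` tag is model-specific):

from PIL import Image
from vllm.sampling_params import BeamSearchParams

# Hedged usage sketch: stream beam-search results for a multimodal prompt.
async def demo(engine):
    params = BeamSearchParams(beam_width=4, max_tokens=64)
    prompt = {
        "prompt": "Describe the image: <image>",
        "multi_modal_data": {"image": Image.new("RGB", (224, 224))},
    }
    async for out in engine.beam_search(
            prompt=prompt,
            model_config=await engine.get_model_config(),
            request_id="beam-demo",
            params=params):
        for completion in out.outputs:
            # finish_reason is now "stop" or "length" per the hunk above.
            print(completion.finish_reason, completion.text)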