77import signal
88import socket
99import sys
10- from typing import Any , Dict , Optional , Union
10+ from operator import itemgetter
11+ from typing import Any , Optional
1112
1213import sglang as sgl
1314import uvloop
@@ -210,6 +211,7 @@ async def generate(self, request: dict):
210211 else request ["batch_token_ids" ],
211212 sampling_params = sampling_params ,
212213 stream = True ,
214+ return_logprob = True ,
213215 bootstrap_host = bootstrap_host ,
214216 bootstrap_port = bootstrap_port ,
215217 bootstrap_room = bootstrap_room ,
@@ -231,54 +233,49 @@ async def generate(self, request: dict):
231233 else request ["batch_token_ids" ],
232234 sampling_params = sampling_params ,
233235 stream = True ,
236+ return_logprob = True ,
234237 )
235238
236239 async for out in self ._process_stream (g , unpack = False , is_batch = is_batch ):
237240 yield out
238241
239242 async def _process_stream (self , stream_source , unpack : bool , is_batch : bool ):
240- # Initialize based on batch mode
241- num_output_tokens_so_far : Union [Dict [int , int ], int ]
242- if is_batch :
243- num_output_tokens_so_far = {}
244- else :
245- num_output_tokens_so_far = 0
243+ assert not is_batch , "Batch processing is not supported."
244+ num_output_tokens_so_far = 0
246245
247246 async for res in stream_source :
248247 data = res .data () if unpack else res
249248 finish_reason = data ["meta_info" ]["finish_reason" ]
250249
251- if is_batch :
252- # Handle batch response
253- assert isinstance (num_output_tokens_so_far , dict )
254- index = data .get ("index" , 0 )
255- if index not in num_output_tokens_so_far :
256- num_output_tokens_so_far [index ] = 0
257-
258- if finish_reason :
259- out = {
260- "token_ids" : [],
261- "finish_reason" : finish_reason ["type" ],
262- "index" : index ,
263- }
264- else :
265- next_total_toks = len (data ["output_ids" ])
266- new_tokens = data ["output_ids" ][num_output_tokens_so_far [index ] :]
267- out = {
268- "token_ids" : new_tokens ,
269- "index" : index ,
270- }
271- num_output_tokens_so_far [index ] = next_total_toks
250+ # Handle single response
251+ assert isinstance (num_output_tokens_so_far , int )
252+ if finish_reason :
253+ out = {"token_ids" : [], "finish_reason" : finish_reason ["type" ]}
272254 else :
273- # Handle single response
274- assert isinstance (num_output_tokens_so_far , int )
275- if finish_reason :
276- out = {"token_ids" : [], "finish_reason" : finish_reason ["type" ]}
277- else :
278- next_total_toks = len (data ["output_ids" ])
279- out = {"token_ids" : data ["output_ids" ][num_output_tokens_so_far :]}
280- num_output_tokens_so_far = next_total_toks
255+ next_total_toks = len (res ["meta_info" ]["output_token_logprobs" ])
256+ new_tokens = list (
257+ map (
258+ itemgetter (1 ),
259+ res ["meta_info" ]["output_token_logprobs" ][
260+ num_output_tokens_so_far :
261+ ],
262+ )
263+ )
264+ new_logprobs = list (
265+ map (
266+ itemgetter (0 ),
267+ res ["meta_info" ]["output_token_logprobs" ][
268+ num_output_tokens_so_far :
269+ ],
270+ )
271+ )
272+ out = {
273+ "token_ids" : new_tokens ,
274+ "log_probs" : new_logprobs ,
275+ }
276+ num_output_tokens_so_far = next_total_toks
281277
278+ logging .debug (f"Generated output: { out } " )
282279 yield out
283280
284281 async def _prefill_generator (self , prefill ):
0 commit comments