     ChatCompletionResponse,
     FunctionCall,
     ToolCall,
+    Usage,
 )

 # MLX Imports
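For reference, the newly imported Usage type is the OpenAI-style usage object defined elsewhere in this repo; since the response models call model_dump(), they appear to be Pydantic models. A minimal sketch of what Usage presumably looks like (the real definition may differ):

# Hypothetical sketch of the imported Usage schema; the actual class lives
# in the project's schema module and may differ in field names or defaults.
from typing import Optional

from pydantic import BaseModel


class Usage(BaseModel):
    prompt_tokens: Optional[int] = None
    completion_tokens: Optional[int] = None
    total_tokens: Optional[int] = None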
@@ -161,7 +162,9 @@ def apply_lm_chat_template(
     return request.messages[-1].content


-def handle_function_calls(output: str, request):
+def handle_function_calls(
+    output: str, request: ChatCompletionRequest, token_info: Usage
+) -> ChatCompletionResponse:
     tool_calls = []

     # Check for JSON format tool calls
@@ -264,6 +267,7 @@ def handle_function_calls(output: str, request):
         id=f"chatcmpl-{os.urandom(4).hex()}",
         created=int(time.time()),
         model=request.model,
+        usage=token_info,
         choices=[
             {
                 "index": 0,
@@ -290,7 +294,9 @@ def load_vlm_model(model_name: str, config: Dict[str, Any]) -> Dict[str, Any]:


 def load_lm_model(model_name: str, config: Dict[str, Any]) -> Dict[str, Any]:
-    model, tokenizer = lm_load(model_name)
+    time_start = time.time()
+    model, tokenizer = lm_load(model_name, model_config=config)
+    print(f"Model loaded in {time.time() - time_start:.2f} seconds.")
     return {"model": model, "tokenizer": tokenizer, "config": config}

@@ -303,7 +309,15 @@ def vlm_stream_generator(
     image_processor,
     max_tokens,
     temperature,
+    stream_options,
 ):
+    INCLUDE_USAGE = (
+        False if stream_options == None else stream_options.get("include_usage", False)
+    )
+    completion_tokens = 0
+    prompt_tokens = len(mx.array(processor.encode(prompt))) if INCLUDE_USAGE else None
+    empty_usage: Usage = None
+
     for token in vlm_stream_generate(
         model,
         processor,
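Note: stream_options follows the OpenAI chat-completions convention here: it is optional, and only {"include_usage": true} changes behavior. A request body that would switch the new accounting on might look like the following sketch (the model id is a placeholder, not taken from this diff):

# Illustrative request payload; stream_options is the newly supported part.
payload = {
    "model": "some-mlx-model",  # placeholder model id
    "messages": [{"role": "user", "content": "Describe this image."}],
    "stream": True,
    "stream_options": {"include_usage": True},
}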
@@ -313,10 +327,15 @@ def vlm_stream_generator(
         max_tokens=max_tokens,
         temp=temperature,
     ):
+        # Update token length info
+        if INCLUDE_USAGE:
+            completion_tokens += 1
+
         chunk = ChatCompletionChunk(
             id=f"chatcmpl-{os.urandom(4).hex()}",
             created=int(time.time()),
             model=model_name,
+            usage=empty_usage,
             choices=[
                 {
                     "index": 0,
@@ -326,6 +345,20 @@ def vlm_stream_generator(
             ],
         )
         yield f"data: {json.dumps(chunk.model_dump())}\n\n"
+
+    if INCLUDE_USAGE:
+        chunk = ChatCompletionChunk(
+            id=f"chatcmpl-{os.urandom(4).hex()}",
+            created=int(time.time()),
+            model=model_name,
+            choices=[],
+            usage=Usage(
+                prompt_tokens=prompt_tokens,
+                completion_tokens=completion_tokens,
+                total_tokens=prompt_tokens + completion_tokens,
+            ),
+        )
+        yield f"data: {json.dumps(chunk.model_dump())}\n\n"
     yield "data: [DONE]\n\n"

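For illustration, the extra chunk emitted just before [DONE] when include_usage is set carries empty choices plus the totals; with placeholder values it serializes to roughly the dict below (extra fields from ChatCompletionChunk's defaults may also appear):

# Approximate shape of the final usage chunk; all values are illustrative.
final_chunk = {
    "id": "chatcmpl-1a2b3c4d",
    "created": 1700000000,
    "model": "some-mlx-model",
    "choices": [],
    "usage": {"prompt_tokens": 42, "completion_tokens": 128, "total_tokens": 170},
}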
@@ -361,6 +394,7 @@ def lm_generate(
     )

     prompt_tokens = mx.array(tokenizer.encode(prompt))
+    prompt_token_len = len(prompt_tokens)
     detokenizer = tokenizer.detokenizer

     detokenizer.reset()
@@ -377,24 +411,49 @@ def lm_generate(
         detokenizer.add_token(token)

     detokenizer.finalize()
-    return detokenizer.text
+
+    _completion_tokens = len(detokenizer.tokens)
+    token_length_info: Usage = Usage(
+        prompt_tokens=prompt_token_len,
+        completion_tokens=_completion_tokens,
+        total_tokens=prompt_token_len + _completion_tokens,
+    )
+    return detokenizer.text, token_length_info


 def lm_stream_generator(
-    model, model_name, tokenizer, prompt, max_tokens, temperature, **kwargs
+    model,
+    model_name,
+    tokenizer,
+    prompt,
+    max_tokens,
+    temperature,
+    stream_options,
+    **kwargs,
 ):
     stop_words = kwargs.pop("stop_words", [])
+    INCLUDE_USAGE = (
+        False if stream_options == None else stream_options.get("include_usage", False)
+    )
+    prompt_tokens = len(tokenizer.encode(prompt)) if INCLUDE_USAGE else None
+    completion_tokens = 0
+    empty_usage: Usage = None

     for token in lm_stream_generate(
         model, tokenizer, prompt, max_tokens=max_tokens, temp=temperature
     ):
         if stop_words and token in stop_words:
             break

+        # Update token length info
+        if INCLUDE_USAGE:
+            completion_tokens += 1
+
         chunk = ChatCompletionChunk(
             id=f"chatcmpl-{os.urandom(4).hex()}",
             created=int(time.time()),
             model=model_name,
+            usage=empty_usage,
             choices=[
                 {
                     "index": 0,
@@ -405,4 +464,18 @@ def lm_stream_generator(
         )
         yield f"data: {json.dumps(chunk.model_dump())}\n\n"

+    if INCLUDE_USAGE:
+        chunk = ChatCompletionChunk(
+            id=f"chatcmpl-{os.urandom(4).hex()}",
+            created=int(time.time()),
+            model=model_name,
+            choices=[],
+            usage=Usage(
+                prompt_tokens=prompt_tokens,
+                completion_tokens=completion_tokens,
+                total_tokens=prompt_tokens + completion_tokens,
+            ),
+        )
+        yield f"data: {json.dumps(chunk.model_dump())}\n\n"
+
     yield "data: [DONE]\n\n"