Commit cd199d8

ZachZimm and Blaizzy authored
Implement Token Usage Tracking (#27)
* Added token usage tracking in accordance with the OpenAI API spec
* Removed some unneeded, commented-out code
* Added optional stream_options dict with an include_usage option
* Removed extraneous comment
* Fixed indentation error in lm_stream_generator during final chunk creation and send
* Updated tests

---------

Co-authored-by: Prince Canuma <[email protected]>
1 parent c0ce5fd commit cd199d8

File tree

5 files changed: +160 -13 lines changed

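To illustrate the feature at the API level, here is a minimal client sketch. It assumes the server exposes the OpenAI-style /v1/chat/completions route on localhost; the URL and model path are placeholders, while the shape of the usage object comes from the diff below.

    import requests

    # Hypothetical endpoint and model path; only the "usage" field names
    # (prompt_tokens, completion_tokens, total_tokens) are taken from this commit.
    payload = {
        "model": "mlx-community/some-model-4bit",
        "messages": [{"role": "user", "content": "Hello!"}],
        "max_tokens": 64,
    }

    resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload)
    resp.raise_for_status()

    # Non-streaming responses now always carry token counts.
    print(resp.json()["usage"])
    # e.g. {"prompt_tokens": 12, "completion_tokens": 34, "total_tokens": 46}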

.gitignore (+2 -1)

@@ -1,4 +1,5 @@
 .DS_Store
 __pycache__
 *.egg-info
-venv/*
+env/
+venv/*

fastmlx/fastmlx.py (+5 -2)

@@ -19,6 +19,7 @@
     ChatCompletionRequest,
     ChatCompletionResponse,
     ChatMessage,
+    Usage,
 )
 from .types.model import SupportedModels

@@ -180,6 +181,7 @@ async def chat_completion(request: ChatCompletionRequest):
                     image_processor,
                     request.max_tokens,
                     request.temperature,
+                    stream_options=request.stream_options,
                 ),
                 media_type="text/event-stream",
             )
@@ -235,11 +237,12 @@ async def chat_completion(request: ChatCompletionRequest):
                     request.max_tokens,
                     request.temperature,
                     stop_words=stop_words,
+                    stream_options=request.stream_options,
                 ),
                 media_type="text/event-stream",
             )
         else:
-            output = lm_generate(
+            output, token_length_info = lm_generate(
                 model,
                 tokenizer,
                 prompt,
@@ -249,7 +252,7 @@ async def chat_completion(request: ChatCompletionRequest):
             )

         # Parse the output to check for function calls
-        return handle_function_calls(output, request)
+        return handle_function_calls(output, request, token_length_info)


 @app.get("/v1/supported_models", response_model=SupportedModels)
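In short, the streaming branches forward request.stream_options to the generators, while the non-streaming branch unpacks a (text, Usage) pair from lm_generate and hands the usage object to handle_function_calls. A condensed sketch of that control flow, with model loading, chat templating, and the VLM path omitted; the helper name and the argument order past the first three lm_generate parameters are assumptions:

    from fastapi.responses import StreamingResponse

    from fastmlx.utils import handle_function_calls, lm_generate, lm_stream_generator


    def complete_or_stream(request, model, tokenizer, prompt, stop_words):
        # Condensed sketch of the branching in chat_completion(); the real
        # handler also loads the model, applies the chat template, and
        # handles VLM inputs and tool calls.
        if request.stream:
            return StreamingResponse(
                lm_stream_generator(
                    model,
                    request.model,
                    tokenizer,
                    prompt,
                    request.max_tokens,
                    request.temperature,
                    stop_words=stop_words,
                    stream_options=request.stream_options,  # opt-in usage reporting
                ),
                media_type="text/event-stream",
            )
        # Non-streaming: lm_generate now returns (text, Usage).
        output, token_length_info = lm_generate(
            model, tokenizer, prompt, request.max_tokens, request.temperature
        )
        # The Usage object rides along into the ChatCompletionResponse.
        return handle_function_calls(output, request, token_length_info)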

fastmlx/types/chat/chat_completion.py (+9)

@@ -31,6 +31,12 @@ class ChatMessage(BaseModel):
     content: Union[str, List[ChatCompletionContentPartParam]]


+class Usage(BaseModel):
+    prompt_tokens: int
+    completion_tokens: int
+    total_tokens: int
+
+
 class ChatCompletionRequest(BaseModel):
     model: str
     messages: List[ChatMessage]
@@ -40,13 +46,15 @@ class ChatCompletionRequest(BaseModel):
     temperature: Optional[float] = Field(default=0.2)
     tools: Optional[List[Function]] = Field(default=None)
     tool_choice: Optional[str] = Field(default=None)
+    stream_options: Optional[Dict[str, Any]] = Field(default=None)


 class ChatCompletionResponse(BaseModel):
     id: str
     object: str = "chat.completion"
     created: int
     model: str
+    usage: Usage
     choices: List[dict]
     tool_calls: Optional[List[ToolCall]] = None

@@ -57,3 +65,4 @@ class ChatCompletionChunk(BaseModel):
     created: int
     model: str
     choices: List[Dict[str, Any]]
+    usage: Optional[Usage] = None
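With these models in place, the final usage-bearing chunk that the stream generators emit (see fastmlx/utils.py below) would be built roughly as follows. The id, timestamp, and model name are illustrative; in the server they come from os.urandom, time.time(), and the request:

    from fastmlx.types.chat.chat_completion import ChatCompletionChunk, Usage

    # Per the OpenAI spec, the usage chunk carries an empty choices list.
    final_chunk = ChatCompletionChunk(
        id="chatcmpl-deadbeef",
        object="chat.completion.chunk",
        created=1720000000,
        model="mlx-community/some-model-4bit",
        choices=[],
        usage=Usage(prompt_tokens=12, completion_tokens=34, total_tokens=46),
    )

    print(final_chunk.model_dump())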

fastmlx/utils.py (+77 -4)

@@ -14,6 +14,7 @@
     ChatCompletionResponse,
     FunctionCall,
     ToolCall,
+    Usage,
 )

 # MLX Imports
@@ -161,7 +162,9 @@ def apply_lm_chat_template(
     return request.messages[-1].content


-def handle_function_calls(output: str, request):
+def handle_function_calls(
+    output: str, request: ChatCompletionRequest, token_info: Usage
+) -> ChatCompletionResponse:
     tool_calls = []

     # Check for JSON format tool calls
@@ -264,6 +267,7 @@ def handle_function_calls(output: str, request):
         id=f"chatcmpl-{os.urandom(4).hex()}",
         created=int(time.time()),
         model=request.model,
+        usage=token_info,
         choices=[
             {
                 "index": 0,
@@ -290,7 +294,9 @@ def load_vlm_model(model_name: str, config: Dict[str, Any]) -> Dict[str, Any]:


 def load_lm_model(model_name: str, config: Dict[str, Any]) -> Dict[str, Any]:
-    model, tokenizer = lm_load(model_name)
+    time_start = time.time()
+    model, tokenizer = lm_load(model_name, model_config=config)
+    print(f"Model loaded in {time.time() - time_start:.2f} seconds.")
     return {"model": model, "tokenizer": tokenizer, "config": config}


@@ -303,7 +309,15 @@ def vlm_stream_generator(
     image_processor,
     max_tokens,
     temperature,
+    stream_options,
 ):
+    INCLUDE_USAGE = (
+        False if stream_options == None else stream_options.get("include_usage", False)
+    )
+    completion_tokens = 0
+    prompt_tokens = len(mx.array(processor.encode(prompt))) if INCLUDE_USAGE else None
+    empty_usage: Usage = None
+
     for token in vlm_stream_generate(
         model,
         processor,
@@ -313,10 +327,15 @@ def vlm_stream_generator(
         max_tokens=max_tokens,
         temp=temperature,
     ):
+        # Update token length info
+        if INCLUDE_USAGE:
+            completion_tokens += 1
+
         chunk = ChatCompletionChunk(
             id=f"chatcmpl-{os.urandom(4).hex()}",
             created=int(time.time()),
             model=model_name,
+            usage=empty_usage,
             choices=[
                 {
                     "index": 0,
@@ -326,6 +345,20 @@ def vlm_stream_generator(
             ],
         )
         yield f"data: {json.dumps(chunk.model_dump())}\n\n"
+
+    if INCLUDE_USAGE:
+        chunk = ChatCompletionChunk(
+            id=f"chatcmpl-{os.urandom(4).hex()}",
+            created=int(time.time()),
+            model=model_name,
+            choices=[],
+            usage=Usage(
+                prompt_tokens=prompt_tokens,
+                completion_tokens=completion_tokens,
+                total_tokens=prompt_tokens + completion_tokens,
+            ),
+        )
+        yield f"data: {json.dumps(chunk.model_dump())}\n\n"
     yield "data: [DONE]\n\n"


@@ -361,6 +394,7 @@ def lm_generate(
     )

     prompt_tokens = mx.array(tokenizer.encode(prompt))
+    prompt_token_len = len(prompt_tokens)
     detokenizer = tokenizer.detokenizer

     detokenizer.reset()
@@ -377,24 +411,49 @@ def lm_generate(
         detokenizer.add_token(token)

     detokenizer.finalize()
-    return detokenizer.text
+
+    _completion_tokens = len(detokenizer.tokens)
+    token_length_info: Usage = Usage(
+        prompt_tokens=prompt_token_len,
+        completion_tokens=_completion_tokens,
+        total_tokens=prompt_token_len + _completion_tokens,
+    )
+    return detokenizer.text, token_length_info


 def lm_stream_generator(
-    model, model_name, tokenizer, prompt, max_tokens, temperature, **kwargs
+    model,
+    model_name,
+    tokenizer,
+    prompt,
+    max_tokens,
+    temperature,
+    stream_options,
+    **kwargs,
 ):
     stop_words = kwargs.pop("stop_words", [])
+    INCLUDE_USAGE = (
+        False if stream_options == None else stream_options.get("include_usage", False)
+    )
+    prompt_tokens = len(tokenizer.encode(prompt)) if INCLUDE_USAGE else None
+    completion_tokens = 0
+    empty_usage: Usage = None

     for token in lm_stream_generate(
         model, tokenizer, prompt, max_tokens=max_tokens, temp=temperature
     ):
         if stop_words and token in stop_words:
             break

+        # Update token length info
+        if INCLUDE_USAGE:
+            completion_tokens += 1
+
         chunk = ChatCompletionChunk(
             id=f"chatcmpl-{os.urandom(4).hex()}",
             created=int(time.time()),
             model=model_name,
+            usage=empty_usage,
             choices=[
                 {
                     "index": 0,
@@ -405,4 +464,18 @@ def lm_stream_generator(
         )
         yield f"data: {json.dumps(chunk.model_dump())}\n\n"

+    if INCLUDE_USAGE:
+        chunk = ChatCompletionChunk(
+            id=f"chatcmpl-{os.urandom(4).hex()}",
+            created=int(time.time()),
+            model=model_name,
+            choices=[],
+            usage=Usage(
+                prompt_tokens=prompt_tokens,
+                completion_tokens=completion_tokens,
+                total_tokens=prompt_tokens + completion_tokens,
+            ),
+        )
+        yield f"data: {json.dumps(chunk.model_dump())}\n\n"
+
     yield "data: [DONE]\n\n"
