@@ -36,6 +36,8 @@ class OpenAIWrapper:
36
36
cache_path_root : str = ".cache"
37
37
extra_kwargs = {"cache_seed" , "filter_func" , "allow_format_str_template" , "context" , "api_version" }
38
38
openai_kwargs = set (inspect .getfullargspec (OpenAI .__init__ ).kwonlyargs )
39
+ total_usage_summary : Dict = None
40
+ actual_usage_summary : Dict = None
39
41
40
42
def __init__ (self , * , config_list : List [Dict ] = None , ** base_config ):
41
43
"""
@@ -233,14 +235,15 @@ def yes_or_no_filter(context, response):
233
235
# Try to get the response from cache
234
236
key = get_key (params )
235
237
response = cache .get (key , None )
238
+ if response is not None :
239
+ self ._update_usage_summary (response , use_cache = True )
236
240
if response is not None :
237
241
# check the filter
238
242
pass_filter = filter_func is None or filter_func (context = context , response = response )
239
243
if pass_filter or i == last :
240
244
# Return the response if it passes the filter or it is the last client
241
245
response .config_id = i
242
246
response .pass_filter = pass_filter
243
- response .cost = self .cost (response )
244
247
return response
245
248
continue # filter is not passed; try the next config
246
249
try :
@@ -250,6 +253,9 @@ def yes_or_no_filter(context, response):
250
253
if i == last :
251
254
raise
252
255
else :
256
+ # add cost calculation before caching, no matter whether the filter is passed or not
257
+ response .cost = self .cost (response )
258
+ self ._update_usage_summary (response , use_cache = False )
253
259
if cache_seed is not None :
254
260
# Cache the response
255
261
with diskcache .Cache (f"{ self .cache_path_root } /{ cache_seed } " ) as cache :
@@ -261,25 +267,9 @@ def yes_or_no_filter(context, response):
261
267
# Return the response if it passes the filter or it is the last client
262
268
response .config_id = i
263
269
response .pass_filter = pass_filter
264
- response .cost = self .cost (response )
265
270
return response
266
271
continue # filter is not passed; try the next config
267
272
268
def cost(self, response: Union[ChatCompletion, Completion]) -> float:
    """Compute the price of a single response from the ``oai_price1k`` table.

    Args:
        response: Completion whose ``model`` selects the rate and whose
            ``usage`` supplies the prompt and completion token counts.

    Returns:
        The cost for the response, or 0 when the model has no entry in
        ``oai_price1k``.
    """
    model_name = response.model
    if model_name not in oai_price1k:
        # TODO: add logging to warn that the model is not found
        return 0

    usage = response.usage
    prompt_count = usage.prompt_tokens
    completion_count = usage.completion_tokens
    rate_per_1k = oai_price1k[model_name]

    if not isinstance(rate_per_1k, tuple):
        # A single rate applies to prompt and completion tokens alike.
        return rate_per_1k * (prompt_count + completion_count) / 1000

    # Tuple rates: index 0 is the input-token rate, index 1 the output-token rate.
    weighted_total = rate_per_1k[0] * prompt_count + rate_per_1k[1] * completion_count
    return weighted_total / 1000
282
-
283
273
def _completions_create (self , client , params ):
284
274
completions = client .chat .completions if "messages" in params else client .completions
285
275
# If streaming is enabled, has messages, and does not have functions, then
@@ -342,6 +332,105 @@ def _completions_create(self, client, params):
342
332
response = completions .create (** params )
343
333
return response
344
334
335
def _update_usage_summary(self, response: ChatCompletion | Completion, use_cache: bool) -> None:
    """Fold one response's cost and token counts into the usage summaries.

    Usage is recorded no matter whether the filter is passed or not.
    ``self.total_usage_summary`` is always updated; ``self.actual_usage_summary``
    is updated only when ``use_cache`` is False.

    Each summary is a dict with a ``"total_cost"`` entry plus one entry per
    model name holding ``cost``, ``prompt_tokens``, ``completion_tokens`` and
    ``total_tokens``.

    Args:
        response: Completion whose ``model``, ``usage`` and ``cost`` are recorded.
        use_cache: True when the response was served from the cache.
    """
    # A response can reach this point without a ``cost`` attribute (for
    # example a cache entry that was stored without one). Compute it on
    # demand and backfill it instead of failing with AttributeError.
    response_cost = getattr(response, "cost", None)
    if response_cost is None:
        response_cost = self.cost(response)
        response.cost = response_cost

    model = response.model
    usage = response.usage

    def update_usage(usage_summary):
        # Start a fresh summary on first use; otherwise accumulate in place.
        if usage_summary is None:
            usage_summary = {"total_cost": response_cost}
        else:
            usage_summary["total_cost"] += response_cost

        previous = usage_summary.get(model, {})
        usage_summary[model] = {
            "cost": previous.get("cost", 0) + response_cost,
            "prompt_tokens": previous.get("prompt_tokens", 0) + usage.prompt_tokens,
            "completion_tokens": previous.get("completion_tokens", 0) + usage.completion_tokens,
            "total_tokens": previous.get("total_tokens", 0) + usage.total_tokens,
        }
        return usage_summary

    self.total_usage_summary = update_usage(self.total_usage_summary)
    if not use_cache:
        self.actual_usage_summary = update_usage(self.actual_usage_summary)
361
+
362
def print_usage_summary(self, mode: Union[str, List[str]] = ["actual", "total"]) -> None:
    """Print the usage summary.

    Args:
        mode: Which summary to print: ``"actual"`` (excluding cached usage),
            ``"total"`` (including cached usage), ``"both"``, or a list of
            one or two entries containing ``"actual"`` and/or ``"total"``.
            The default list is never mutated here.

    Raises:
        ValueError: If ``mode`` does not resolve to a known mode. The mode is
            resolved before anything is printed, so no partial output is
            emitted for an invalid mode.
    """

    def print_usage(usage_summary, usage_type="total"):
        if usage_summary is None:
            print("No actual cost incurred (all completions are using cache).", flush=True)
            return

        word_from_type = "including" if usage_type == "total" else "excluding"
        print(f"Usage summary {word_from_type} cached usage: ", flush=True)
        print(f"Total cost: {round(usage_summary['total_cost'], 5)}", flush=True)
        for model, counts in usage_summary.items():
            if model == "total_cost":
                continue  # aggregate entry, not a per-model record
            print(
                f"* Model '{model}': cost: {round(counts['cost'], 5)}, prompt_tokens: {counts['prompt_tokens']}, completion_tokens: {counts['completion_tokens']}, total_tokens: {counts['total_tokens']}",
                flush=True,
            )

    if self.total_usage_summary is None:
        print('No usage summary. Please call "create" first.', flush=True)
        return

    invalid_mode_message = f'Invalid mode: {mode}, choose from "actual", "total", ["actual", "total"]'

    # Resolve the requested mode up front so that an invalid value raises
    # before the separator line is written.
    if isinstance(mode, list):
        if len(mode) == 0 or len(mode) > 2:
            raise ValueError(invalid_mode_message)
        if "actual" in mode and "total" in mode:
            resolved = "both"
        elif "actual" in mode:
            resolved = "actual"
        elif "total" in mode:
            resolved = "total"
        else:
            raise ValueError(invalid_mode_message)
    elif mode == "both" or mode == "total" or mode == "actual":
        resolved = mode
    else:
        raise ValueError(invalid_mode_message)

    print("-" * 100, flush=True)
    if resolved == "both":
        print_usage(self.actual_usage_summary, "actual")
        print()
        if self.total_usage_summary != self.actual_usage_summary:
            print_usage(self.total_usage_summary, "total")
        else:
            print(
                "All completions are non-cached: the total cost with cached completions is the same as actual cost.",
                flush=True,
            )
    elif resolved == "total":
        print_usage(self.total_usage_summary, "total")
    else:
        print_usage(self.actual_usage_summary, "actual")
    print("-" * 100, flush=True)
413
+
414
def clear_usage_summary(self) -> None:
    """Reset both the total and the actual usage summaries to ``None``."""
    self.total_usage_summary = self.actual_usage_summary = None
418
+
419
def cost(self, response: Union[ChatCompletion, Completion]) -> float:
    """Return the cost of ``response`` according to the ``oai_price1k`` table.

    Args:
        response: Completion providing ``model`` and the ``usage`` token counts
            (``prompt_tokens`` and ``completion_tokens``).

    Returns:
        The computed cost, or 0 if the model is not listed in ``oai_price1k``.
    """
    model = response.model
    if model not in oai_price1k:
        # TODO: add logging to warn that the model is not found
        return 0

    input_tokens = response.usage.prompt_tokens
    output_tokens = response.usage.completion_tokens
    price_1k = oai_price1k[model]

    if isinstance(price_1k, tuple):
        # First value is the input token rate, second value is the output token rate.
        input_cost = price_1k[0] * input_tokens
        output_cost = price_1k[1] * output_tokens
        return (input_cost + output_cost) / 1000

    # Flat rate: every token is billed at the same price.
    return price_1k * (input_tokens + output_tokens) / 1000
433
+
345
434
@classmethod
346
435
def extract_text_or_function_call (cls , response : ChatCompletion | Completion ) -> List [str ]:
347
436
"""Extract the text or function calls from a completion or chat response.
0 commit comments