6
6
import sqlite3
7
7
import threading
8
8
import uuid
9
- from typing import TYPE_CHECKING , Any , Dict , List , Tuple , Union
9
+ from typing import TYPE_CHECKING , Any , Callable , Dict , List , Tuple , TypeVar , Union
10
10
11
11
from openai import AzureOpenAI , OpenAI
12
12
from openai .types .chat import ChatCompletion
25
25
26
26
__all__ = ("SqliteLogger" ,)
27
27
28
+ F = TypeVar ("F" , bound = Callable [..., Any ])
29
+
30
+
31
+ def safe_serialize (obj : Any ) -> str :
32
+ def default (o : Any ) -> str :
33
+ if hasattr (o , "to_json" ):
34
+ return str (o .to_json ())
35
+ else :
36
+ return f"<<non-serializable: { type (o ).__qualname__ } >>"
37
+
38
+ return json .dumps (obj , default = default )
39
+
28
40
29
41
class SqliteLogger (BaseLogger ):
30
42
schema_version = 1
@@ -49,6 +61,7 @@ def start(self) -> str:
49
61
client_id INTEGER,
50
62
wrapper_id INTEGER,
51
63
session_id TEXT,
64
+ source_name TEXT,
52
65
request TEXT,
53
66
response TEXT,
54
67
is_cached INEGER,
@@ -118,6 +131,18 @@ class TEXT, -- type or class name of cli
118
131
"""
119
132
self ._run_query (query = query )
120
133
134
+ query = """
135
+ CREATE TABLE IF NOT EXISTS function_calls (
136
+ source_id INTEGER,
137
+ source_name TEXT,
138
+ function_name TEXT,
139
+ args TEXT DEFAULT NULL,
140
+ returns TEXT DEFAULT NULL,
141
+ timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
142
+ );
143
+ """
144
+ self ._run_query (query = query )
145
+
121
146
current_verion = self ._get_current_db_version ()
122
147
if current_verion is None :
123
148
self ._run_query (
@@ -192,6 +217,7 @@ def log_chat_completion(
192
217
invocation_id : uuid .UUID ,
193
218
client_id : int ,
194
219
wrapper_id : int ,
220
+ source : Union [str , Agent ],
195
221
request : Dict [str , Union [float , str , List [Dict [str , str ]]]],
196
222
response : Union [str , ChatCompletion ],
197
223
is_cached : int ,
@@ -208,10 +234,16 @@ def log_chat_completion(
208
234
else :
209
235
response_messages = json .dumps (to_dict (response ), indent = 4 )
210
236
237
+ source_name = None
238
+ if isinstance (source , str ):
239
+ source_name = source
240
+ else :
241
+ source_name = source .name
242
+
211
243
query = """
212
244
INSERT INTO chat_completions (
213
- invocation_id, client_id, wrapper_id, session_id, request, response, is_cached, cost, start_time, end_time
214
- ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
245
+ invocation_id, client_id, wrapper_id, session_id, request, response, is_cached, cost, start_time, end_time, source_name
246
+ ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ? )
215
247
"""
216
248
args = (
217
249
invocation_id ,
@@ -224,6 +256,7 @@ def log_chat_completion(
224
256
cost ,
225
257
start_time ,
226
258
end_time ,
259
+ source_name ,
227
260
)
228
261
229
262
self ._run_query (query = query , args = args )
@@ -335,6 +368,24 @@ def log_new_wrapper(self, wrapper: OpenAIWrapper, init_args: Dict[str, Union[LLM
335
368
)
336
369
self ._run_query (query = query , args = args )
337
370
371
+ def log_function_use (self , source : Union [str , Agent ], function : F , args : Dict [str , Any ], returns : Any ) -> None :
372
+
373
+ if self .con is None :
374
+ return
375
+
376
+ query = """
377
+ INSERT INTO function_calls (source_id, source_name, function_name, args, returns, timestamp) VALUES (?, ?, ?, ?, ?, ?)
378
+ """
379
+ query_args : Tuple [Any , ...] = (
380
+ id (source ),
381
+ source .name if hasattr (source , "name" ) else source ,
382
+ function .__name__ ,
383
+ safe_serialize (args ),
384
+ safe_serialize (returns ),
385
+ get_current_ts (),
386
+ )
387
+ self ._run_query (query = query , args = query_args )
388
+
338
389
def log_new_client (
339
390
self , client : Union [AzureOpenAI , OpenAI , GeminiClient ], wrapper : OpenAIWrapper , init_args : Dict [str , Any ]
340
391
) -> None :
0 commit comments