@@ -3,17 +3,21 @@
 from typing import Any, Dict, List, Optional, Sequence, Tuple, TypedDict, Union
 from uuid import UUID
 
-from chainlit.context import context_var
-from chainlit.message import Message
-from chainlit.step import Step
-from langchain.callbacks.tracers.base import BaseTracer
+import pydantic
 from langchain.callbacks.tracers.schemas import Run
 from langchain.schema import BaseMessage
 from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
+from langchain_core.tracers.base import AsyncBaseTracer
 from literalai import ChatGeneration, CompletionGeneration, GenerationMessage
 from literalai.helper import utc_now
 from literalai.observability.step import TrueStepType
 
+from chainlit.context import context_var
+from chainlit.message import Message
+from chainlit.step import Step
+
 DEFAULT_ANSWER_PREFIX_TOKENS = ["Final", "Answer", ":"]
@@ -122,6 +124,14 @@ def ensure_values_serializable(self, data):
                 key: self.ensure_values_serializable(value)
                 for key, value in data.items()
             }
+        elif isinstance(data, pydantic.BaseModel):
+            # Fallback to support pydantic v1
+            # https://docs.pydantic.dev/latest/migration/#changes-to-pydanticbasemodel
+            if pydantic.VERSION.startswith("1"):
+                return data.dict()
+
+            # pydantic v2
+            return data.model_dump()  # pyright: ignore reportAttributeAccessIssue
         elif isinstance(data, list):
             return [self.ensure_values_serializable(item) for item in data]
         elif isinstance(data, (str, int, float, bool, type(None))):
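For reference, a minimal standalone sketch of the version check the new branch relies on; the `Point` model and `dump_model` helper are illustrative, not part of this change:

```python
import pydantic


class Point(pydantic.BaseModel):
    x: int
    y: int


def dump_model(model: pydantic.BaseModel) -> dict:
    # pydantic v1 exposes .dict(); v2 renamed it to .model_dump().
    if pydantic.VERSION.startswith("1"):
        return model.dict()
    return model.model_dump()


print(dump_model(Point(x=1, y=2)))  # {'x': 1, 'y': 2}
```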
@@ -249,7 +259,7 @@ def process_content(content: Any) -> Tuple[Dict, Optional[str]]:
 DEFAULT_TO_KEEP = ["retriever", "llm", "agent", "chain", "tool"]
 
 
-class LangchainTracer(BaseTracer, GenerationHelper, FinalStreamHelper):
+class LangchainTracer(AsyncBaseTracer, GenerationHelper, FinalStreamHelper):
     steps: Dict[str, Step]
     parent_id_map: Dict[str, str]
     ignored_runs: set
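Switching the base class to `AsyncBaseTracer` makes every tracer hook a coroutine that LangChain awaits on its async execution path. A sketch of how the tracer is typically attached from a Chainlit app, assuming `cl.LangchainCallbackHandler` still resolves to this tracer; `my_chain` is a placeholder:

```python
import chainlit as cl


@cl.on_message
async def handle(msg: cl.Message):
    # The handler's async hooks are awaited by LangChain directly,
    # so steps stream without a fire-and-forget indirection.
    result = await my_chain.ainvoke(
        msg.content,
        config={"callbacks": [cl.LangchainCallbackHandler()]},
    )
    await cl.Message(content=str(result)).send()
```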
@@ -268,7 +278,7 @@ def __init__(
         to_keep: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> None:
-        BaseTracer.__init__(self, **kwargs)
+        AsyncBaseTracer.__init__(self, **kwargs)
         GenerationHelper.__init__(self)
         FinalStreamHelper.__init__(
             self,
@@ -296,7 +306,7 @@ def __init__(
         else:
             self.to_keep = to_keep
 
-    def on_chat_model_start(
+    async def on_chat_model_start(
         self,
         serialized: Dict[str, Any],
         messages: List[List[BaseMessage]],
@@ -305,8 +315,9 @@ def on_chat_model_start(
         parent_run_id: Optional["UUID"] = None,
         tags: Optional[List[str]] = None,
         metadata: Optional[Dict[str, Any]] = None,
+        name: Optional[str] = None,
         **kwargs: Any,
-    ) -> Any:
+    ) -> Run:
         lc_messages = messages[0]
         self.chat_generations[str(run_id)] = {
             "input_messages": lc_messages,
@@ -315,54 +326,63 @@ def on_chat_model_start(
             "tt_first_token": None,
         }
 
-        return super().on_chat_model_start(
+        return await super().on_chat_model_start(
             serialized,
             messages,
             run_id=run_id,
             parent_run_id=parent_run_id,
             tags=tags,
             metadata=metadata,
+            name=name,
             **kwargs,
         )
 
-    def on_llm_start(
+    async def on_llm_start(
         self,
         serialized: Dict[str, Any],
         prompts: List[str],
         *,
         run_id: "UUID",
+        parent_run_id: Optional[UUID] = None,
         tags: Optional[List[str]] = None,
-        parent_run_id: Optional["UUID"] = None,
         metadata: Optional[Dict[str, Any]] = None,
-        name: Optional[str] = None,
         **kwargs: Any,
-    ) -> Run:
-        self.completion_generations[str(run_id)] = {
-            "prompt": prompts[0],
-            "start": time.time(),
-            "token_count": 0,
-            "tt_first_token": None,
-        }
-        return super().on_llm_start(
+    ) -> None:
+        await super().on_llm_start(
             serialized,
             prompts,
             run_id=run_id,
             parent_run_id=parent_run_id,
             tags=tags,
             metadata=metadata,
-            name=name,
             **kwargs,
         )
 
-    def on_llm_new_token(
+        self.completion_generations[str(run_id)] = {
+            "prompt": prompts[0],
+            "start": time.time(),
+            "token_count": 0,
+            "tt_first_token": None,
+        }
+
+        return None
+
+    async def on_llm_new_token(
         self,
         token: str,
         *,
         chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
         run_id: "UUID",
         parent_run_id: Optional["UUID"] = None,
         **kwargs: Any,
-    ) -> Run:
+    ) -> None:
+        await super().on_llm_new_token(
+            token=token,
+            chunk=chunk,
+            run_id=run_id,
+            parent_run_id=parent_run_id,
+            **kwargs,
+        )
         if isinstance(chunk, ChatGenerationChunk):
             start = self.chat_generations[str(run_id)]
         else:
@@ -377,24 +397,13 @@ def on_llm_new_token(
             if self.answer_reached:
                 if not self.final_stream:
                     self.final_stream = Message(content="")
-                    self._run_sync(self.final_stream.send())
-                self._run_sync(self.final_stream.stream_token(token))
+                    await self.final_stream.send()
+                await self.final_stream.stream_token(token)
                 self.has_streamed_final_answer = True
             else:
                 self.answer_reached = self._check_if_answer_reached()
 
-        return super().on_llm_new_token(
-            token,
-            chunk=chunk,
-            run_id=run_id,
-            parent_run_id=parent_run_id,
-        )
-
-    def _run_sync(self, co):  # TODO: WHAT TO DO WITH THIS?
-        context_var.set(self.context)
-        self.context.loop.create_task(co)
-
-    def _persist_run(self, run: Run) -> None:
+    async def _persist_run(self, run: Run) -> None:
         pass
 
     def _get_run_parent_id(self, run: Run):
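The streaming branch above only fires once `_check_if_answer_reached()` matches `DEFAULT_ANSWER_PREFIX_TOKENS`. A simplified illustration of that prefix matching; the real logic lives in `FinalStreamHelper`, outside this diff:

```python
DEFAULT_ANSWER_PREFIX_TOKENS = ["Final", "Answer", ":"]


def answer_reached(last_tokens: list) -> bool:
    # Compare the tail of the streamed tokens against the expected prefix.
    tail = [t.strip() for t in last_tokens[-len(DEFAULT_ANSWER_PREFIX_TOKENS):]]
    return tail == DEFAULT_ANSWER_PREFIX_TOKENS


assert answer_reached(["The", "Final", "Answer", ":"])
assert not answer_reached(["Final", "Answer"])
```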
@@ -445,8 +454,8 @@ def _should_ignore_run(self, run: Run):
             self.ignored_runs.add(str(run.id))
         return ignore, parent_id
 
-    def _start_trace(self, run: Run) -> None:
-        super()._start_trace(run)
+    async def _start_trace(self, run: Run) -> None:
+        await super()._start_trace(run)
         context_var.set(self.context)
 
         ignore, parent_id = self._should_ignore_run(run)
@@ -489,9 +498,9 @@ def _start_trace(self, run: Run) -> None:
 
         self.steps[str(run.id)] = step
 
-        self._run_sync(step.send())
+        await step.send()
 
-    def _on_run_update(self, run: Run) -> None:
+    async def _on_run_update(self, run: Run) -> None:
         """Process a run upon update."""
         context_var.set(self.context)
@@ -576,10 +585,10 @@ def _on_run_update(self, run: Run) -> None:
 
         if current_step:
             current_step.end = utc_now()
-            self._run_sync(current_step.update())
+            await current_step.update()
 
         if self.final_stream and self.has_streamed_final_answer:
-            self._run_sync(self.final_stream.update())
+            await self.final_stream.update()
 
         return
@@ -599,16 +608,16 @@ def _on_run_update(self, run: Run) -> None:
             else output
         )
         current_step.end = utc_now()
-        self._run_sync(current_step.update())
+        await current_step.update()
 
-    def _on_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any):
+    async def _on_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any):
         context_var.set(self.context)
 
         if current_step := self.steps.get(str(run_id), None):
             current_step.is_error = True
             current_step.output = str(error)
             current_step.end = utc_now()
-            self._run_sync(current_step.update())
+            await current_step.update()
 
     on_llm_error = _on_error
     on_chain_error = _on_error
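Since `_on_error` is now a coroutine, the class-level aliases at the end expose the same awaitable handler under each hook name the tracer interface expects. A toy example of the pattern:

```python
class Handler:
    async def _on_error(self, error: BaseException) -> None:
        # Shared error path: every aliased hook awaits this coroutine.
        print(f"run failed: {error}")

    # Class-level assignment reuses one coroutine function for several hooks.
    on_llm_error = _on_error
    on_chain_error = _on_error
```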