Commit ff26451

feat: use async lc tracer instead of run_sync (#1529)

Co-authored-by: Mathijs de Bruin <[email protected]>

1 parent 2bd47f5 commit ff26451

1 file changed (+55 −46): backend/chainlit/langchain/callbacks.py
@@ -3,17 +3,19 @@
 from typing import Any, Dict, List, Optional, Sequence, Tuple, TypedDict, Union
 from uuid import UUID
 
-from chainlit.context import context_var
-from chainlit.message import Message
-from chainlit.step import Step
-from langchain.callbacks.tracers.base import BaseTracer
+import pydantic
 from langchain.callbacks.tracers.schemas import Run
 from langchain.schema import BaseMessage
 from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
+from langchain_core.tracers.base import AsyncBaseTracer
 from literalai import ChatGeneration, CompletionGeneration, GenerationMessage
 from literalai.helper import utc_now
 from literalai.observability.step import TrueStepType
 
+from chainlit.context import context_var
+from chainlit.message import Message
+from chainlit.step import Step
+
 DEFAULT_ANSWER_PREFIX_TOKENS = ["Final", "Answer", ":"]
 
 
@@ -122,6 +124,14 @@ def ensure_values_serializable(self, data):
                 key: self.ensure_values_serializable(value)
                 for key, value in data.items()
             }
+        elif isinstance(data, pydantic.BaseModel):
+            # Fallback to support pydantic v1
+            # https://docs.pydantic.dev/latest/migration/#changes-to-pydanticbasemodel
+            if pydantic.VERSION.startswith("1"):
+                return data.dict()
+
+            # pydantic v2
+            return data.model_dump()  # pyright: ignore reportAttributeAccessIssue
         elif isinstance(data, list):
             return [self.ensure_values_serializable(item) for item in data]
         elif isinstance(data, (str, int, float, bool, type(None))):
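The new `pydantic.BaseModel` branch bridges the v1/v2 API split during serialization. A minimal standalone sketch of the same version check, runnable under either major version (the `Point` model is hypothetical, for illustration only):

```python
import pydantic


class Point(pydantic.BaseModel):  # hypothetical example model
    x: int
    y: int


def dump_model(model: pydantic.BaseModel) -> dict:
    # pydantic v1 exposes .dict(); v2 renamed it to .model_dump()
    if pydantic.VERSION.startswith("1"):
        return model.dict()
    return model.model_dump()


print(dump_model(Point(x=1, y=2)))  # {'x': 1, 'y': 2}
```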
@@ -249,7 +259,7 @@ def process_content(content: Any) -> Tuple[Dict, Optional[str]]:
 DEFAULT_TO_KEEP = ["retriever", "llm", "agent", "chain", "tool"]
 
 
-class LangchainTracer(BaseTracer, GenerationHelper, FinalStreamHelper):
+class LangchainTracer(AsyncBaseTracer, GenerationHelper, FinalStreamHelper):
     steps: Dict[str, Step]
     parent_id_map: Dict[str, str]
     ignored_runs: set
@@ -268,7 +278,7 @@ def __init__(
         to_keep: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> None:
-        BaseTracer.__init__(self, **kwargs)
+        AsyncBaseTracer.__init__(self, **kwargs)
         GenerationHelper.__init__(self)
         FinalStreamHelper.__init__(
             self,
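Since the class now derives from `langchain_core`'s `AsyncBaseTracer`, its lifecycle hooks are coroutines and `__init__` delegates to the async base. A minimal sketch of the subclassing pattern, independent of Chainlit (assuming, as this diff suggests, that `_persist_run` is the abstract hook and `_on_run_update` may be overridden as a coroutine):

```python
from langchain.callbacks.tracers.schemas import Run
from langchain_core.tracers.base import AsyncBaseTracer


class LoggingTracer(AsyncBaseTracer):
    """Hypothetical tracer that just prints run lifecycle events."""

    async def _persist_run(self, run: Run) -> None:
        # Required hook; nothing to persist in this sketch.
        pass

    async def _on_run_update(self, run: Run) -> None:
        print(f"run updated: {run.name} ({run.run_type})")
```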
@@ -296,7 +306,7 @@ def __init__(
         else:
             self.to_keep = to_keep
 
-    def on_chat_model_start(
+    async def on_chat_model_start(
         self,
         serialized: Dict[str, Any],
         messages: List[List[BaseMessage]],
@@ -305,8 +315,9 @@ def on_chat_model_start(
         parent_run_id: Optional["UUID"] = None,
         tags: Optional[List[str]] = None,
         metadata: Optional[Dict[str, Any]] = None,
+        name: Optional[str] = None,
         **kwargs: Any,
-    ) -> Any:
+    ) -> Run:
         lc_messages = messages[0]
         self.chat_generations[str(run_id)] = {
             "input_messages": lc_messages,
@@ -315,54 +326,63 @@
             "tt_first_token": None,
         }
 
-        return super().on_chat_model_start(
+        return await super().on_chat_model_start(
             serialized,
             messages,
             run_id=run_id,
             parent_run_id=parent_run_id,
             tags=tags,
             metadata=metadata,
+            name=name,
             **kwargs,
         )
 
-    def on_llm_start(
+    async def on_llm_start(
         self,
         serialized: Dict[str, Any],
         prompts: List[str],
         *,
         run_id: "UUID",
+        parent_run_id: Optional[UUID] = None,
         tags: Optional[List[str]] = None,
-        parent_run_id: Optional["UUID"] = None,
         metadata: Optional[Dict[str, Any]] = None,
-        name: Optional[str] = None,
         **kwargs: Any,
-    ) -> Run:
-        self.completion_generations[str(run_id)] = {
-            "prompt": prompts[0],
-            "start": time.time(),
-            "token_count": 0,
-            "tt_first_token": None,
-        }
-        return super().on_llm_start(
+    ) -> None:
+        await super().on_llm_start(
             serialized,
             prompts,
             run_id=run_id,
             parent_run_id=parent_run_id,
             tags=tags,
             metadata=metadata,
-            name=name,
             **kwargs,
         )
 
-    def on_llm_new_token(
+        self.completion_generations[str(run_id)] = {
+            "prompt": prompts[0],
+            "start": time.time(),
+            "token_count": 0,
+            "tt_first_token": None,
+        }
+
+        return None
+
+    async def on_llm_new_token(
         self,
         token: str,
         *,
         chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
         run_id: "UUID",
         parent_run_id: Optional["UUID"] = None,
         **kwargs: Any,
-    ) -> Run:
+    ) -> None:
+        await super().on_llm_new_token(
+            token=token,
+            chunk=chunk,
+            run_id=run_id,
+            parent_run_id=parent_run_id,
+            **kwargs,
+        )
         if isinstance(chunk, ChatGenerationChunk):
            start = self.chat_generations[str(run_id)]
         else:
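Both `chat_generations` and `completion_generations` record the same four fields so that later token callbacks can compute token counts and time-to-first-token. A small self-contained sketch of that bookkeeping (the `record_token` helper and the millisecond unit are assumptions inferred from the fields shown in the diff):

```python
import time
from typing import Dict, Optional, TypedDict


class CompletionGen(TypedDict):  # field names taken from the diff
    prompt: str
    start: float
    token_count: int
    tt_first_token: Optional[float]


generations: Dict[str, CompletionGen] = {}


def start_generation(run_id: str, prompt: str) -> None:
    generations[run_id] = {
        "prompt": prompt,
        "start": time.time(),
        "token_count": 0,
        "tt_first_token": None,
    }


def record_token(run_id: str) -> None:
    gen = generations[run_id]
    if gen["tt_first_token"] is None:
        # elapsed time until the first streamed token, in ms (assumed unit)
        gen["tt_first_token"] = (time.time() - gen["start"]) * 1000
    gen["token_count"] += 1
```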
@@ -377,24 +397,13 @@ def on_llm_new_token(
         if self.answer_reached:
             if not self.final_stream:
                 self.final_stream = Message(content="")
-                self._run_sync(self.final_stream.send())
-            self._run_sync(self.final_stream.stream_token(token))
+                await self.final_stream.send()
+            await self.final_stream.stream_token(token)
             self.has_streamed_final_answer = True
         else:
             self.answer_reached = self._check_if_answer_reached()
 
-        return super().on_llm_new_token(
-            token,
-            chunk=chunk,
-            run_id=run_id,
-            parent_run_id=parent_run_id,
-        )
-
-    def _run_sync(self, co): # TODO: WHAT TO DO WITH THIS?
-        context_var.set(self.context)
-        self.context.loop.create_task(co)
-
-    def _persist_run(self, run: Run) -> None:
+    async def _persist_run(self, run: Run) -> None:
         pass
 
     def _get_run_parent_id(self, run: Run):
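This hunk deletes `_run_sync` entirely: it used `loop.create_task` to fire coroutines without awaiting them, so step updates raced the rest of the callback and their exceptions could vanish. A minimal plain-asyncio sketch of the behavioral difference (not Chainlit-specific):

```python
import asyncio


async def send() -> None:
    await asyncio.sleep(0.01)
    print("sent")


async def fire_and_forget() -> None:
    # Old _run_sync style: schedule on the loop and keep going;
    # "sent" prints after "moving on", and failures go unnoticed.
    asyncio.get_running_loop().create_task(send())
    print("moving on")
    await asyncio.sleep(0.05)  # give the task a chance to finish


async def awaited() -> None:
    # New style: the caller resumes only after send() completes.
    await send()
    print("moving on")


asyncio.run(fire_and_forget())
asyncio.run(awaited())
```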
@@ -445,8 +454,8 @@ def _should_ignore_run(self, run: Run):
         self.ignored_runs.add(str(run.id))
         return ignore, parent_id
 
-    def _start_trace(self, run: Run) -> None:
-        super()._start_trace(run)
+    async def _start_trace(self, run: Run) -> None:
+        await super()._start_trace(run)
         context_var.set(self.context)
 
         ignore, parent_id = self._should_ignore_run(run)
@@ -489,9 +498,9 @@ def _start_trace(self, run: Run) -> None:
 
         self.steps[str(run.id)] = step
 
-        self._run_sync(step.send())
+        await step.send()
 
-    def _on_run_update(self, run: Run) -> None:
+    async def _on_run_update(self, run: Run) -> None:
         """Process a run upon update."""
         context_var.set(self.context)
 
@@ -576,10 +585,10 @@ def _on_run_update(self, run: Run) -> None:
 
             if current_step:
                 current_step.end = utc_now()
-                self._run_sync(current_step.update())
+                await current_step.update()
 
             if self.final_stream and self.has_streamed_final_answer:
-                self._run_sync(self.final_stream.update())
+                await self.final_stream.update()
 
             return
 
@@ -599,16 +608,16 @@ def _on_run_update(self, run: Run) -> None:
             else output
         )
         current_step.end = utc_now()
-        self._run_sync(current_step.update())
+        await current_step.update()
 
-    def _on_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any):
+    async def _on_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any):
         context_var.set(self.context)
 
         if current_step := self.steps.get(str(run_id), None):
             current_step.is_error = True
             current_step.output = str(error)
             current_step.end = utc_now()
-            self._run_sync(current_step.update())
+            await current_step.update()
 
     on_llm_error = _on_error
     on_chain_error = _on_error
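With every hook now a coroutine, the tracer slots into LangChain's async callback path. A hedged usage sketch (assuming Chainlit's public `AsyncLangchainCallbackHandler` alias for this tracer, and `my_chain` as a placeholder for a LangChain runnable built elsewhere):

```python
import chainlit as cl


@cl.on_message
async def on_message(message: cl.Message):
    # my_chain is a placeholder for any LangChain runnable.
    result = await my_chain.ainvoke(
        {"question": message.content},
        config={"callbacks": [cl.AsyncLangchainCallbackHandler()]},
    )
    await cl.Message(content=str(result)).send()
```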
