diff --git a/portkey_ai/langchain/portkey_langchain_callback_handler.py b/portkey_ai/langchain/portkey_langchain_callback_handler.py index c6a1d163..945f758d 100644 --- a/portkey_ai/langchain/portkey_langchain_callback_handler.py +++ b/portkey_ai/langchain/portkey_langchain_callback_handler.py @@ -5,7 +5,6 @@ from uuid import UUID, uuid4 from portkey_ai.api_resources.apis.logger import Logger import re -from datetime import datetime try: from langchain_core.callbacks.base import BaseCallbackHandler @@ -17,12 +16,13 @@ class LangchainCallbackHandler(BaseCallbackHandler): def __init__( self, api_key: str, - metadata: Optional[Dict[str, Any]] = {}, + metadata: Optional[Dict[str, Any]] = None, ) -> None: super().__init__() self.api_key = api_key - self.metadata = metadata + self.metadata: Dict[str, Any] = metadata or {} + self.metadata.update({"_source": "Langchain", "_source_type": "Agent"}) self.portkey_logger = Logger(api_key=api_key) @@ -68,9 +68,11 @@ def on_llm_start( info_obj = self.start_event_information( run_id, parent_run_id, - "llm_start", + serialized.get("name", "llm"), self.global_trace_id, request_payload, + "llm", + tags, self.metadata, ) self.event_map["llm_start_" + str(run_id)] = info_obj @@ -88,8 +90,8 @@ def on_llm_end( """Run when LLM ends running.""" start_time = self.event_map["llm_start_" + str(run_id)]["start_time"] - end_time = int(datetime.now().timestamp()) - total_time = (end_time - start_time) * 1000 + end_time = time.time() + total_time = f"{((end_time - start_time) * 1000):04.0f}" response_payload = self.on_llm_end_transformer(response, kwargs=kwargs) self.event_map["llm_start_" + str(run_id)]["response"] = response_payload @@ -126,12 +128,13 @@ def on_chain_start( info_obj = self.start_event_information( run_id, parent_span_id, - "chain_start", + request_payload.get("name", "chain"), self.global_trace_id, request_payload, + "chain", + tags, self.metadata, ) - self.event_map["chain_start_" + str(run_id)] = info_obj pass @@ -147,8 +150,8 @@ def on_chain_end( """Run when chain ends running.""" start_time = self.event_map["chain_start_" + str(run_id)]["start_time"] - end_time = int(datetime.now().timestamp()) - total_time = (end_time - start_time) * 1000 + end_time = time.time() + total_time = f"{((end_time - start_time) * 1000):04.0f}" response_payload = self.on_chain_end_transformer(outputs) @@ -185,9 +188,11 @@ def on_tool_start( info_obj = self.start_event_information( run_id, parent_run_id, - "tool_start", + request_payload.get("serialized", {}).get("name", "tool"), self.global_trace_id, request_payload, + "tool", + tags, self.metadata, ) self.event_map["tool_start_" + str(run_id)] = info_obj @@ -205,8 +210,8 @@ def on_tool_end( """Run when tool ends running.""" start_time = self.event_map["tool_start_" + str(run_id)]["start_time"] - end_time = int(datetime.now().timestamp()) - total_time = (end_time - start_time) * 1000 + end_time = time.time() + total_time = f"{((end_time - start_time) * 1000):04.0f}" response_payload = self.on_tool_end_transformer(output) self.event_map["tool_start_" + str(run_id)]["response"] = response_payload @@ -250,9 +255,20 @@ def start_event_information( span_name, trace_id, request_payload, + span_type, + tags, metadata=None, ): - start_time = int(datetime.now().timestamp()) + start_time = time.time() + source_metadata = {} + portkey_metadata = {} + portkey_metadata.update(metadata) + if span_type: + source_metadata.update({"type": span_type}) + if len(tags): + source_metadata.update({"tags": tags}) + if source_metadata: + portkey_metadata.update({"source_metadata": json.dumps(source_metadata)}) return { "span_id": str(span_id), "parent_span_id": str(parent_span_id), @@ -260,7 +276,7 @@ def start_event_information( "trace_id": trace_id, "request": request_payload, "start_time": start_time, - "metadata": metadata, + "metadata": portkey_metadata, } def serialize(self, obj):