Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 31 additions & 15 deletions portkey_ai/langchain/portkey_langchain_callback_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from uuid import UUID, uuid4
from portkey_ai.api_resources.apis.logger import Logger
import re
from datetime import datetime

try:
from langchain_core.callbacks.base import BaseCallbackHandler
Expand All @@ -17,12 +16,13 @@ class LangchainCallbackHandler(BaseCallbackHandler):
def __init__(
self,
api_key: str,
metadata: Optional[Dict[str, Any]] = {},
metadata: Optional[Dict[str, Any]] = None,
) -> None:
super().__init__()

self.api_key = api_key
self.metadata = metadata
self.metadata: Dict[str, Any] = metadata or {}
self.metadata.update({"_source": "Langchain", "_source_type": "Agent"})

self.portkey_logger = Logger(api_key=api_key)

Expand Down Expand Up @@ -68,9 +68,11 @@ def on_llm_start(
info_obj = self.start_event_information(
run_id,
parent_run_id,
"llm_start",
serialized.get("name", "llm"),
self.global_trace_id,
request_payload,
"llm",
tags,
self.metadata,
)
self.event_map["llm_start_" + str(run_id)] = info_obj
Expand All @@ -88,8 +90,8 @@ def on_llm_end(
"""Run when LLM ends running."""

start_time = self.event_map["llm_start_" + str(run_id)]["start_time"]
end_time = int(datetime.now().timestamp())
total_time = (end_time - start_time) * 1000
end_time = time.time()
total_time = f"{((end_time - start_time) * 1000):04.0f}"

response_payload = self.on_llm_end_transformer(response, kwargs=kwargs)
self.event_map["llm_start_" + str(run_id)]["response"] = response_payload
Expand Down Expand Up @@ -126,12 +128,13 @@ def on_chain_start(
info_obj = self.start_event_information(
run_id,
parent_span_id,
"chain_start",
request_payload.get("name", "chain"),
self.global_trace_id,
request_payload,
"chain",
tags,
self.metadata,
)

self.event_map["chain_start_" + str(run_id)] = info_obj
pass

Expand All @@ -147,8 +150,8 @@ def on_chain_end(
"""Run when chain ends running."""

start_time = self.event_map["chain_start_" + str(run_id)]["start_time"]
end_time = int(datetime.now().timestamp())
total_time = (end_time - start_time) * 1000
end_time = time.time()
total_time = f"{((end_time - start_time) * 1000):04.0f}"

response_payload = self.on_chain_end_transformer(outputs)

Expand Down Expand Up @@ -185,9 +188,11 @@ def on_tool_start(
info_obj = self.start_event_information(
run_id,
parent_run_id,
"tool_start",
request_payload.get("serialized", {}).get("name", "tool"),
self.global_trace_id,
request_payload,
"tool",
tags,
self.metadata,
)
self.event_map["tool_start_" + str(run_id)] = info_obj
Expand All @@ -205,8 +210,8 @@ def on_tool_end(
"""Run when tool ends running."""

start_time = self.event_map["tool_start_" + str(run_id)]["start_time"]
end_time = int(datetime.now().timestamp())
total_time = (end_time - start_time) * 1000
end_time = time.time()
total_time = f"{((end_time - start_time) * 1000):04.0f}"

response_payload = self.on_tool_end_transformer(output)
self.event_map["tool_start_" + str(run_id)]["response"] = response_payload
Expand Down Expand Up @@ -250,17 +255,28 @@ def start_event_information(
span_name,
trace_id,
request_payload,
span_type,
tags,
metadata=None,
):
start_time = int(datetime.now().timestamp())
start_time = time.time()
source_metadata = {}
portkey_metadata = {}
portkey_metadata.update(metadata)
if span_type:
source_metadata.update({"type": span_type})
if len(tags):
source_metadata.update({"tags": tags})
if source_metadata:
portkey_metadata.update({"source_metadata": json.dumps(source_metadata)})
return {
"span_id": str(span_id),
"parent_span_id": str(parent_span_id),
"span_name": span_name,
"trace_id": trace_id,
"request": request_payload,
"start_time": start_time,
"metadata": metadata,
"metadata": portkey_metadata,
}

def serialize(self, obj):
Expand Down