Skip to content
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 25 additions & 17 deletions portkey_ai/llms/llama_index/portkey_llama_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,30 +126,38 @@ def on_event_end(
"""Run when an event ends."""
span_id = event_id

if event_type == "llm":
response_payload = self.llm_event_end(payload, event_id)
elif event_type == "embedding":
response_payload = self.embedding_event_end(payload, event_id)
elif event_type == "agent_step":
response_payload = self.agent_step_event_end(payload, event_id)
elif event_type == "function_call":
response_payload = self.function_call_event_end(payload, event_id)
elif event_type == "query":
response_payload = self.query_event_end(payload, event_id)
elif event_type == "retrieve":
response_payload = self.retrieve_event_end(payload, event_id)
elif event_type == "templating":
response_payload = self.templating_event_end(payload, event_id)
if payload is None:
response_payload = {}
if span_id in self.event_map:
event = self.event_map[event_id]
start_time = event["start_time"]
end_time = int(datetime.now().timestamp())
total_time = (end_time - start_time) * 1000
response_payload["response_time"] = total_time
else:
response_payload = payload
if event_type == "llm":
response_payload = self.llm_event_end(payload, event_id)
elif event_type == "embedding":
response_payload = self.embedding_event_end(payload, event_id)
elif event_type == "agent_step":
response_payload = self.agent_step_event_end(payload, event_id)
elif event_type == "function_call":
response_payload = self.function_call_event_end(payload, event_id)
elif event_type == "query":
response_payload = self.query_event_end(payload, event_id)
elif event_type == "retrieve":
response_payload = self.retrieve_event_end(payload, event_id)
elif event_type == "templating":
response_payload = self.templating_event_end(payload, event_id)
else:
response_payload = payload

self.event_map[span_id]["response"] = response_payload

self.event_array.append(self.event_map[span_id])

def start_trace(self, trace_id: Optional[str] = None) -> None:
"""Run when an overall trace is launched."""

if trace_id == "index_construction":
self.global_trace_id = self.metadata.get("traceId", str(uuid4())) # type: ignore [union-attr]

Expand Down Expand Up @@ -230,7 +238,7 @@ def llm_event_end(self, payload: Any, event_id) -> Any:
)
self.response["body"].update({"id": event_id})
self.response["body"].update({"created": int(time.time())})
self.response["body"].update({"model": data.raw.get("model", "")})
self.response["body"].update({"model": getattr(data, "model", "")})
self.response["headers"] = {}
self.response["streamingMode"] = self.streamingMode

Expand Down