Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions portkey_ai/langchain/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .portkey_langchain_callback_handler import LangchainCallbackHandler

__all__ = ["LangchainCallbackHandler"]
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
raise ImportError("Please pip install langchain-core to use PortkeyLangchain")


class PortkeyLangchain(BaseCallbackHandler):
class LangchainCallbackHandler(BaseCallbackHandler):
def __init__(
self,
api_key: str,
Expand Down
3 changes: 3 additions & 0 deletions portkey_ai/llamaindex/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .portkey_llama_callback_handler import LlamaIndexCallbackHandler

__all__ = ["LlamaIndexCallbackHandler"]
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
raise ImportError("Please pip install llama-index to use Portkey Callback Handler")


class PortkeyLlamaindex(LlamaIndexBaseCallbackHandler):
class LlamaIndexCallbackHandler(LlamaIndexBaseCallbackHandler):
def __init__(
self,
api_key: str,
Expand Down Expand Up @@ -126,30 +126,38 @@ def on_event_end(
"""Run when an event ends."""
span_id = event_id

if event_type == "llm":
response_payload = self.llm_event_end(payload, event_id)
elif event_type == "embedding":
response_payload = self.embedding_event_end(payload, event_id)
elif event_type == "agent_step":
response_payload = self.agent_step_event_end(payload, event_id)
elif event_type == "function_call":
response_payload = self.function_call_event_end(payload, event_id)
elif event_type == "query":
response_payload = self.query_event_end(payload, event_id)
elif event_type == "retrieve":
response_payload = self.retrieve_event_end(payload, event_id)
elif event_type == "templating":
response_payload = self.templating_event_end(payload, event_id)
if payload is None:
response_payload = {}
if span_id in self.event_map:
event = self.event_map[event_id]
start_time = event["start_time"]
end_time = int(datetime.now().timestamp())
total_time = (end_time - start_time) * 1000
response_payload["response_time"] = total_time
else:
response_payload = payload
if event_type == "llm":
response_payload = self.llm_event_end(payload, event_id)
elif event_type == "embedding":
response_payload = self.embedding_event_end(payload, event_id)
elif event_type == "agent_step":
response_payload = self.agent_step_event_end(payload, event_id)
elif event_type == "function_call":
response_payload = self.function_call_event_end(payload, event_id)
elif event_type == "query":
response_payload = self.query_event_end(payload, event_id)
elif event_type == "retrieve":
response_payload = self.retrieve_event_end(payload, event_id)
elif event_type == "templating":
response_payload = self.templating_event_end(payload, event_id)
else:
response_payload = payload

self.event_map[span_id]["response"] = response_payload

self.event_array.append(self.event_map[span_id])

def start_trace(self, trace_id: Optional[str] = None) -> None:
"""Run when an overall trace is launched."""

if trace_id == "index_construction":
self.global_trace_id = self.metadata.get("traceId", str(uuid4())) # type: ignore [union-attr]

Expand Down Expand Up @@ -230,7 +238,7 @@ def llm_event_end(self, payload: Any, event_id) -> Any:
)
self.response["body"].update({"id": event_id})
self.response["body"].update({"created": int(time.time())})
self.response["body"].update({"model": data.raw.get("model", "")})
self.response["body"].update({"model": getattr(data, "model", "")})
self.response["headers"] = {}
self.response["streamingMode"] = self.streamingMode

Expand Down
6 changes: 4 additions & 2 deletions portkey_ai/llms/langchain/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from .chat import ChatPortkey
from .completion import PortkeyLLM
from .portkey_langchain_callback import PortkeyLangchain

__all__ = ["ChatPortkey", "PortkeyLLM", "PortkeyLangchain"]
__all__ = [
"ChatPortkey",
"PortkeyLLM",
]
3 changes: 0 additions & 3 deletions portkey_ai/llms/llama_index/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +0,0 @@
from .portkey_llama_callback import PortkeyLlamaindex

__all__ = ["PortkeyLlamaindex"]
4 changes: 2 additions & 2 deletions tests/test_llm_langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import pytest

from tests.utils import read_json_file
from portkey_ai.llms.langchain import PortkeyLangchain
from portkey_ai.langchain import LangchainCallbackHandler
from langchain.chat_models import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain.chains import LLMChain
Expand All @@ -15,7 +15,7 @@


class TestLLMLangchain:
client = PortkeyLangchain
client = LangchainCallbackHandler
parametrize = pytest.mark.parametrize("client", [client], ids=["strict"])
models = read_json_file("./tests/models.json")

Expand Down
4 changes: 2 additions & 2 deletions tests/test_llm_llamaindex.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import pytest

from tests.utils import read_json_file
from portkey_ai.llms.llama_index import PortkeyLlamaindex
from portkey_ai.llamaindex import LlamaIndexCallbackHandler


from llama_index.llms.openai import OpenAI
Expand All @@ -24,7 +24,7 @@


class TestLLMLlamaindex:
client = PortkeyLlamaindex
client = LlamaIndexCallbackHandler
parametrize = pytest.mark.parametrize("client", [client], ids=["strict"])
models = read_json_file("./tests/models.json")

Expand Down