diff --git a/portkey_ai/langchain/__init__.py b/portkey_ai/langchain/__init__.py
new file mode 100644
index 00000000..4ee80547
--- /dev/null
+++ b/portkey_ai/langchain/__init__.py
@@ -0,0 +1,3 @@
+from .portkey_langchain_callback_handler import LangchainCallbackHandler
+
+__all__ = ["LangchainCallbackHandler"]
diff --git a/portkey_ai/llms/langchain/portkey_langchain_callback.py b/portkey_ai/langchain/portkey_langchain_callback_handler.py
similarity index 99%
rename from portkey_ai/llms/langchain/portkey_langchain_callback.py
rename to portkey_ai/langchain/portkey_langchain_callback_handler.py
index c45864f9..7d2899dc 100644
--- a/portkey_ai/llms/langchain/portkey_langchain_callback.py
+++ b/portkey_ai/langchain/portkey_langchain_callback_handler.py
@@ -13,7 +13,7 @@
     raise ImportError("Please pip install langchain-core to use PortkeyLangchain")
 
 
-class PortkeyLangchain(BaseCallbackHandler):
+class LangchainCallbackHandler(BaseCallbackHandler):
     def __init__(
         self,
         api_key: str,
diff --git a/portkey_ai/llamaindex/__init__.py b/portkey_ai/llamaindex/__init__.py
new file mode 100644
index 00000000..55e48a8c
--- /dev/null
+++ b/portkey_ai/llamaindex/__init__.py
@@ -0,0 +1,3 @@
+from .portkey_llama_callback_handler import LlamaIndexCallbackHandler
+
+__all__ = ["LlamaIndexCallbackHandler"]
diff --git a/portkey_ai/llms/llama_index/portkey_llama_callback.py b/portkey_ai/llamaindex/portkey_llama_callback_handler.py
similarity index 93%
rename from portkey_ai/llms/llama_index/portkey_llama_callback.py
rename to portkey_ai/llamaindex/portkey_llama_callback_handler.py
index 0d5c77a0..615fcfe1 100644
--- a/portkey_ai/llms/llama_index/portkey_llama_callback.py
+++ b/portkey_ai/llamaindex/portkey_llama_callback_handler.py
@@ -23,7 +23,7 @@
     raise ImportError("Please pip install llama-index to use Portkey Callback Handler")
 
 
-class PortkeyLlamaindex(LlamaIndexBaseCallbackHandler):
+class LlamaIndexCallbackHandler(LlamaIndexBaseCallbackHandler):
     def __init__(
         self,
         api_key: str,
@@ -126,22 +126,31 @@ def on_event_end(
         """Run when an event ends."""
         span_id = event_id
 
-        if event_type == "llm":
-            response_payload = self.llm_event_end(payload, event_id)
-        elif event_type == "embedding":
-            response_payload = self.embedding_event_end(payload, event_id)
-        elif event_type == "agent_step":
-            response_payload = self.agent_step_event_end(payload, event_id)
-        elif event_type == "function_call":
-            response_payload = self.function_call_event_end(payload, event_id)
-        elif event_type == "query":
-            response_payload = self.query_event_end(payload, event_id)
-        elif event_type == "retrieve":
-            response_payload = self.retrieve_event_end(payload, event_id)
-        elif event_type == "templating":
-            response_payload = self.templating_event_end(payload, event_id)
+        if payload is None:
+            response_payload = {}
+            if span_id in self.event_map:
+                event = self.event_map[event_id]
+                start_time = event["start_time"]
+                end_time = int(datetime.now().timestamp())
+                total_time = (end_time - start_time) * 1000
+                response_payload["response_time"] = total_time
         else:
-            response_payload = payload
+            if event_type == "llm":
+                response_payload = self.llm_event_end(payload, event_id)
+            elif event_type == "embedding":
+                response_payload = self.embedding_event_end(payload, event_id)
+            elif event_type == "agent_step":
+                response_payload = self.agent_step_event_end(payload, event_id)
+            elif event_type == "function_call":
+                response_payload = self.function_call_event_end(payload, event_id)
+            elif event_type == "query":
+                response_payload = self.query_event_end(payload, event_id)
+            elif event_type == "retrieve":
+                response_payload = self.retrieve_event_end(payload, event_id)
+            elif event_type == "templating":
+                response_payload = self.templating_event_end(payload, event_id)
+            else:
+                response_payload = payload
 
         self.event_map[span_id]["response"] = response_payload
 
@@ -149,7 +158,6 @@ def on_event_end(
 
     def start_trace(self, trace_id: Optional[str] = None) -> None:
         """Run when an overall trace is launched."""
-
        if trace_id == "index_construction":
            self.global_trace_id = self.metadata.get("traceId", str(uuid4()))  # type: ignore [union-attr]
 
@@ -230,7 +238,7 @@ def llm_event_end(self, payload: Any, event_id) -> Any:
         )
         self.response["body"].update({"id": event_id})
         self.response["body"].update({"created": int(time.time())})
-        self.response["body"].update({"model": data.raw.get("model", "")})
+        self.response["body"].update({"model": getattr(data, "model", "")})
         self.response["headers"] = {}
         self.response["streamingMode"] = self.streamingMode
 
diff --git a/portkey_ai/llms/langchain/__init__.py b/portkey_ai/llms/langchain/__init__.py
index de0f7f0d..07e0ff1d 100644
--- a/portkey_ai/llms/langchain/__init__.py
+++ b/portkey_ai/llms/langchain/__init__.py
@@ -1,5 +1,7 @@
 from .chat import ChatPortkey
 from .completion import PortkeyLLM
-from .portkey_langchain_callback import PortkeyLangchain
 
-__all__ = ["ChatPortkey", "PortkeyLLM", "PortkeyLangchain"]
+__all__ = [
+    "ChatPortkey",
+    "PortkeyLLM",
+]
diff --git a/portkey_ai/llms/llama_index/__init__.py b/portkey_ai/llms/llama_index/__init__.py
index 9530d3eb..e69de29b 100644
--- a/portkey_ai/llms/llama_index/__init__.py
+++ b/portkey_ai/llms/llama_index/__init__.py
@@ -1,3 +0,0 @@
-from .portkey_llama_callback import PortkeyLlamaindex
-
-__all__ = ["PortkeyLlamaindex"]
diff --git a/tests/test_llm_langchain.py b/tests/test_llm_langchain.py
index 3441d85e..8cfe75e8 100644
--- a/tests/test_llm_langchain.py
+++ b/tests/test_llm_langchain.py
@@ -6,7 +6,7 @@
 import pytest
 from tests.utils import read_json_file
 
-from portkey_ai.llms.langchain import PortkeyLangchain
+from portkey_ai.langchain import LangchainCallbackHandler
 from langchain.chat_models import ChatOpenAI
 from langchain_core.prompts import ChatPromptTemplate
 from langchain.chains import LLMChain
@@ -15,7 +15,7 @@
 
 
 class TestLLMLangchain:
-    client = PortkeyLangchain
+    client = LangchainCallbackHandler
     parametrize = pytest.mark.parametrize("client", [client], ids=["strict"])
     models = read_json_file("./tests/models.json")
 
diff --git a/tests/test_llm_llamaindex.py b/tests/test_llm_llamaindex.py
index 959dd81d..a0e9fff6 100644
--- a/tests/test_llm_llamaindex.py
+++ b/tests/test_llm_llamaindex.py
@@ -6,7 +6,7 @@
 import pytest
 from tests.utils import read_json_file
 
-from portkey_ai.llms.llama_index import PortkeyLlamaindex
+from portkey_ai.llamaindex import LlamaIndexCallbackHandler
 from llama_index.llms.openai import OpenAI
 
 
@@ -24,7 +24,7 @@
 
 
 class TestLLMLlamaindex:
-    client = PortkeyLlamaindex
+    client = LlamaIndexCallbackHandler
     parametrize = pytest.mark.parametrize("client", [client], ids=["strict"])
     models = read_json_file("./tests/models.json")
 
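
For downstream consumers: this diff moves the callback handlers out of `portkey_ai.llms.*` into new top-level `portkey_ai.langchain` and `portkey_ai.llamaindex` packages and renames the classes, while the old `portkey_ai.llms.langchain.PortkeyLangchain` and `portkey_ai.llms.llama_index.PortkeyLlamaindex` exports are removed. Below is a minimal sketch of the new import paths. Only the `api_key` parameter is visible in the constructor signatures above; the key value and the LangChain/LlamaIndex wiring are illustrative assumptions, not part of this diff.

```python
# Sketch of the new import paths introduced by this change.
# Only `api_key` is confirmed by the constructor signatures in the diff;
# the model setup and callback wiring below are illustrative.
from langchain.chat_models import ChatOpenAI
from llama_index.core import Settings
from llama_index.core.callbacks import CallbackManager

from portkey_ai.langchain import LangchainCallbackHandler
from portkey_ai.llamaindex import LlamaIndexCallbackHandler

# LangChain: attach the handler as a callback on the chat model.
lc_handler = LangchainCallbackHandler(api_key="PORTKEY_API_KEY")  # hypothetical key
chat = ChatOpenAI(callbacks=[lc_handler])

# LlamaIndex: register the handler via the global callback manager.
li_handler = LlamaIndexCallbackHandler(api_key="PORTKEY_API_KEY")
Settings.callback_manager = CallbackManager([li_handler])
```

Code that imported the old names must update both the module path and the class name; the class bodies themselves are near-identical (99% and 93% similarity per the rename headers), so no behavioral changes are expected beyond the `on_event_end` None-payload handling and the `model` field lookup shown above.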