
Commit

add first pass assistant w/ cleanup
emrgnt-cmplxty committed Jul 25, 2024
1 parent 985333b commit 729cb38
Showing 5 changed files with 1 addition and 159 deletions.
r2r/base/__init__.py (8 changes: 0 additions & 8 deletions)
@@ -19,13 +19,9 @@
 from .abstractions.exception import R2RDocumentProcessingError, R2RException
 from .abstractions.llama_abstractions import VectorStoreQuery
 from .abstractions.llm import (
-    FunctionCall,
     GenerationConfig,
     LLMChatCompletion,
     LLMChatCompletionChunk,
-    LLMChatMessage,
-    LLMConversation,
-    LLMIterationResult,
     RAGCompletion,
 )
 from .abstractions.prompt import Prompt
@@ -104,7 +100,6 @@
     "Assistant",
     "AssistantConfig",
     "Tool",
-    "FunctionCall",
     "RedisKVLoggingProvider",
     "KVLoggingSingleton",
     "RunManager",
@@ -156,13 +151,10 @@
     "PromptConfig",
     "PromptProvider",
     "GenerationConfig",
-    "LLMChatMessage",
     "RAGCompletion",
     "VectorStoreQuery",
     "LLMChatCompletion",
     "LLMChatCompletionChunk",
-    "LLMConversation",
-    "LLMIterationResult",
     "LLMConfig",
     "LLMProvider",
     "AuthConfig",
r2r/base/abstractions/llm.py (147 changes: 1 addition & 146 deletions)
@@ -1,16 +1,6 @@
 """Abstractions for the LLM model."""
 
-import json
-import re
-from typing import (
-    TYPE_CHECKING,
-    ClassVar,
-    NamedTuple,
-    Optional,
-    Sequence,
-    Tuple,
-    Union,
-)
+from typing import TYPE_CHECKING, ClassVar, Optional
 
 from openai.types.chat import ChatCompletion, ChatCompletionChunk
 from pydantic import BaseModel, Field
@@ -23,141 +13,6 @@
 LLMChatCompletionChunk = ChatCompletionChunk
 
 
-class FunctionCall(NamedTuple):
-    """A class representing function call to be made by the OpenAI agent."""
-
-    name: str
-    arguments: dict[str, str]
-
-    def to_dict(self) -> dict[str, Union[dict[str, str], str]]:
-        """Convert the function call to a dictionary."""
-
-        return {
-            "name": self.name,
-            "arguments": json.dumps(self.arguments),
-        }
-
-    @classmethod
-    def from_response_dict(
-        cls, response_dict: dict[str, str]
-    ) -> "FunctionCall":
-        """Create a FunctionCall from a response dictionary."""
-
-        def preprocess_json_string(json_string: str) -> str:
-            """Preprocess the JSON string to handle control characters."""
-            import re
-
-            # Match only the newline characters that are not preceded by a backslash
-            json_string = re.sub(r"(?<!\\)\n", "\\n", json_string)
-            # Do the same for tabs or any other control characters
-            json_string = re.sub(r"(?<!\\)\t", "\\t", json_string)
-            return json_string
-
-        if (
-            response_dict["name"] == "call-termination"
-            and '"result":' in response_dict["arguments"]
-        ):
-            return cls(
-                name=response_dict["name"],
-                arguments=FunctionCall.handle_termination(
-                    response_dict["arguments"]
-                ),
-            )
-        try:
-            return cls(
-                name=response_dict["name"],
-                arguments=json.loads(
-                    preprocess_json_string(response_dict["arguments"])
-                ),
-            )
-        except Exception as e:
-            # TODO - put robust infra so this bubbles back up to the agent
-            return cls(
-                name="error-occurred",
-                arguments={"error": f"Error occurred: {e}"},
-            )
-
-    @staticmethod
-    def handle_termination(arguments: str) -> dict[str, str]:
-        """
-        Handle the termination message from the conversation.
-        Note/FIXME - This is a hacky solution to the problem of parsing Markdown
-        with JSON. It needs to be made more robust and generalizable.
-        Further, we need to be sure that this is adequate to solve all
-        possible problems we might face due to adopting a Markdown return format.
-        """
-
-        try:
-            return json.loads(arguments)
-        except json.decoder.JSONDecodeError as e:
-            split_result = arguments.split('{"result":')
-            if len(split_result) <= 1:
-                raise ValueError(
-                    "Invalid arguments for call-termination"
-                ) from e
-            result_str = split_result[1].strip().replace('"}', "")
-            if result_str[0] != '"':
-                raise ValueError(
-                    "Invalid format for call-termination arguments"
-                ) from e
-            result_str = result_str[1:]
-            return {"result": result_str}
-
-    def __str__(self) -> str:
-        return json.dumps(self._asdict())
-
-
-class LLMChatMessage(BaseModel):
-    """Base class for different types of LLM chat messages."""
-
-    role: str
-    content: Optional[str] = None
-    function_call: Optional[FunctionCall] = (None,)
-
-    def to_dict(self) -> dict[str, str]:
-        return {"role": self.role, "content": self.content}
-
-
-LLMIterationResult = Optional[Tuple[LLMChatMessage, LLMChatMessage]]
-
-
-class LLMConversation:
-    """A class to represent a conversation with the OpenAI API."""
-
-    def __init__(self) -> None:
-        super().__init__()
-        self._messages: list[LLMChatMessage] = []
-
-    def __len__(self) -> int:
-        return len(self._messages)
-
-    @property
-    def messages(self) -> Sequence[LLMChatMessage]:
-        return self._messages
-
-    def add_message(self, message: LLMChatMessage) -> None:
-        """Add a message to the conversation."""
-
-        if not isinstance(message, LLMChatMessage):
-            raise Exception(
-                f"Message must be of type {LLMChatMessage}, but got {type(message)}"
-            )
-        self._messages.append(message)
-
-    def to_dictarray(self) -> list[dict[str, any]]:
-        """Get the messages for the next completion."""
-        return [message.to_dict() for message in self._messages]
-
-    def get_latest_message(self) -> LLMChatMessage:
-        """Get the latest message in the conversation."""
-        return self._messages[-1]
-
-    def reset_conversation(self) -> None:
-        """Reset the conversation."""
-        self._messages = []
-
-
 class RAGCompletion:
     completion: LLMChatCompletion
     search_results: "AggregateSearchResult"
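
For context on what this file loses: the removed FunctionCall.from_response_dict special-cased the call-termination tool, escaped control characters before json.loads, and converted parse failures into an error-occurred call rather than raising. Two details of the deleted code are worth flagging. First, LLMChatMessage defaulted function_call to the tuple (None,) rather than None, which looks unintended. Second, the preprocessing regexes appear to have been no-ops: in a re.sub replacement string, "\\n" is itself decoded back into a newline, so raw newlines were replaced with raw newlines. A standalone sketch of what that preprocessing presumably intended (corrected illustration code, not code from this repository):

    import json
    import re

    def escape_control_chars(json_string: str) -> str:
        # Turn raw newlines/tabs inside a JSON payload into their
        # two-character escape sequences so json.loads accepts them.
        # Note the doubled backslashes: with r"\\n" the replacement
        # survives re.sub's own escape processing as a literal \n.
        json_string = re.sub(r"(?<!\\)\n", r"\\n", json_string)
        return re.sub(r"(?<!\\)\t", r"\\t", json_string)

    # Arguments as an LLM might emit them, with a raw newline that
    # plain json.loads rejects as an invalid control character.
    raw = '{"result": "line one\nline two"}'
    print(json.loads(escape_control_chars(raw)))
    # {'result': 'line one\nline two'}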
r2r/base/providers/prompt.py (1 change: 0 additions & 1 deletion)
@@ -1,5 +1,4 @@
 import logging
-import os
 from abc import abstractmethod
 from pathlib import Path
 from typing import Any, Optional
r2r/pipelines/search_pipeline.py (1 change: 0 additions & 1 deletion)
@@ -1,7 +1,6 @@
 import asyncio
 import logging
 from asyncio import Queue
-from copy import copy
 from typing import Any, Optional
 
 from ..base.abstractions.search import (
r2r/pipes/retrieval/kg_agent_search_pipe.py (3 changes: 0 additions & 3 deletions)
@@ -64,9 +64,6 @@ async def _run_logic(
     ):
         async for message in input.message:
             # TODO - Remove hard code
-            formatted_prompt = self.prompt_provider.get_prompt(
-                "kg_agent", {"input": message}
-            )
             messages = self.prompt_provider._get_message_payload(
                 task_prompt_name="kg_agent", task_inputs={"input": message}
             )
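
This pipe now asks the prompt provider for a chat-style message payload instead of rendering a bare prompt string itself. A rough sketch of the shape of that change, using a hypothetical stub (the real PromptProvider lives in r2r/base/providers/prompt.py, and whether _get_message_payload prepends a system prompt is an assumption here):

    # Hypothetical stub mirroring the two call shapes in this diff.
    class StubPromptProvider:
        def get_prompt(self, name: str, inputs: dict) -> str:
            # Old call shape: render one named task prompt to a string.
            return f"[{name}] input={inputs['input']}"

        def _get_message_payload(
            self, task_prompt_name: str, task_inputs: dict
        ) -> list[dict]:
            # New call shape: wrap the rendered task prompt in a chat
            # message list (the system prompt content is assumed).
            return [
                {"role": "system", "content": "You are a helpful assistant."},
                {
                    "role": "user",
                    "content": self.get_prompt(task_prompt_name, task_inputs),
                },
            ]

    provider = StubPromptProvider()
    messages = provider._get_message_payload(
        task_prompt_name="kg_agent", task_inputs={"input": "example query"}
    )
    print(messages)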
