Skip to content

Commit

Permalink
Feature/add lightweight assistant (#759)
Browse files Browse the repository at this point in the history
* refactoring requests

* last cleanups

* last cleanups

* up

* add first pass assistant w/ cleanup
  • Loading branch information
emrgnt-cmplxty authored Jul 25, 2024
1 parent a11a907 commit ed657f2
Show file tree
Hide file tree
Showing 12 changed files with 122 additions and 22 deletions.
2 changes: 2 additions & 0 deletions r2r/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import logging

from .assistant import *

# Keep '*' imports for enhanced development velocity
# corresponding flake8 error codes are F403, F405
from .base import *
Expand Down
4 changes: 4 additions & 0 deletions r2r/base/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .abstractions.assistant import Assistant, AssistantConfig, Tool
from .abstractions.base import AsyncSyncMeta, UserStats, syncable
from .abstractions.document import (
DataType,
Expand Down Expand Up @@ -96,6 +97,9 @@
"RedisLoggingConfig",
"AsyncSyncMeta",
"syncable",
"Assistant",
"AssistantConfig",
"Tool",
"RedisKVLoggingProvider",
"KVLoggingSingleton",
"RunManager",
Expand Down
98 changes: 98 additions & 0 deletions r2r/base/abstractions/assistant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Optional, Sequence

from pydantic import BaseModel, Field

from ..providers.llm import GenerationConfig, LLMProvider


class Tool(BaseModel):
name: str
description: str
function: Callable

class Config:
arbitrary_types_allowed = True


class Message(BaseModel):
role: str
content: Optional[str] = None
name: Optional[str] = None
function_call: Optional[Dict[str, Any]] = None


class Conversation:
def __init__(self):
self.messages: List[Message] = []

def add_message(
self,
role: str,
content: Optional[str] = None,
name: Optional[str] = None,
function_call: Optional[Dict[str, Any]] = None,
):
self.messages.append(
Message(
role=role,
content=content,
name=name,
function_call=function_call,
)
)

def get_messages(self) -> List[Dict[str, Any]]:
return [msg.dict(exclude_none=True) for msg in self.messages]


class AssistantConfig(BaseModel):
system_instruction: str
tools: List[Tool] = Field(default_factory=list)
generation_config: GenerationConfig = GenerationConfig()


class Assistant(ABC):
def __init__(
self,
instructions: str,
llm_provider: LLMProvider,
config: AssistantConfig,
):
self.instructions = instructions
self.llm_provider = llm_provider
self.config = config
self.conversation = Conversation()
self.completed = False
self._setup()

def _setup(self):
self.conversation.add_message("system", self.config.system_instruction)
self.conversation.add_message("system", self.instructions)

@property
def tools(self) -> Sequence[Tool]:
return self.config.tools

@abstractmethod
async def run(self) -> str:
pass

@abstractmethod
async def process_llm_response(self, response: Dict[str, Any]) -> str:
pass

def add_user_message(self, content: str):
self.conversation.add_message("user", content)

async def search(self, query: str) -> str:
# Implement the search functionality here
# This could involve calling an external search API or querying a local database
return f"Simulated search results for: {query}"

async def execute_tool(self, tool_name: str, **kwargs) -> str:
tool = next((t for t in self.tools if t.name == tool_name), None)
if tool:
return await tool.function(**kwargs)
else:
return f"Error: Tool {tool_name} not found."
1 change: 1 addition & 0 deletions r2r/base/abstractions/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
if TYPE_CHECKING:
from .search import AggregateSearchResult


LLMChatCompletion = ChatCompletion
LLMChatCompletionChunk = ChatCompletionChunk

Expand Down
1 change: 0 additions & 1 deletion r2r/base/providers/prompt.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import logging
import os
from abc import abstractmethod
from pathlib import Path
from typing import Any, Optional
Expand Down
8 changes: 4 additions & 4 deletions r2r/main/assembly/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,13 +172,13 @@ def create_llm_provider(
) -> LLMProvider:
llm_provider: Optional[LLMProvider] = None
if llm_config.provider == "openai":
from r2r.providers import OpenAILLM
from r2r.providers import OpenAILLMProvider

llm_provider = OpenAILLM(llm_config)
llm_provider = OpenAILLMProvider(llm_config)
elif llm_config.provider == "litellm":
from r2r.providers import LiteLLM
from r2r.providers import LiteLLMProvider

llm_provider = LiteLLM(llm_config)
llm_provider = LiteLLMProvider(llm_config)
else:
raise ValueError(
f"Language model provider {llm_config.provider} not supported"
Expand Down
1 change: 0 additions & 1 deletion r2r/pipelines/search_pipeline.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import asyncio
import logging
from asyncio import Queue
from copy import copy
from typing import Any, Optional

from ..base.abstractions.search import (
Expand Down
3 changes: 0 additions & 3 deletions r2r/pipes/retrieval/kg_agent_search_pipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,6 @@ async def _run_logic(
):
async for message in input.message:
# TODO - Remove hard code
formatted_prompt = self.prompt_provider.get_prompt(
"kg_agent", {"input": message}
)
messages = self.prompt_provider._get_message_payload(
task_prompt_name="kg_agent", task_inputs={"input": message}
)
Expand Down
6 changes: 3 additions & 3 deletions r2r/providers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
)
from .eval import LLMEvalProvider
from .kg import Neo4jKGProvider
from .llm import LiteLLM, OpenAILLM
from .llm import LiteLLMProvider, OpenAILLMProvider
from .prompts import R2RPromptProvider

__all__ = [
Expand All @@ -23,7 +23,7 @@
"SentenceTransformerEmbeddingProvider",
"LLMEvalProvider",
"Neo4jKGProvider",
"OpenAILLM",
"LiteLLM",
"OpenAILLMProvider",
"LiteLLMProvider",
"R2RPromptProvider",
]
8 changes: 4 additions & 4 deletions r2r/providers/llm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from .litellm import LiteLLM
from .openai import OpenAILLM
from .litellm import LiteLLMProvider
from .openai import OpenAILLMProvider

__all__ = [
"LiteLLM",
"OpenAILLM",
"LiteLLMProvider",
"OpenAILLMProvider",
]
6 changes: 3 additions & 3 deletions r2r/providers/llm/litellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
logger = logging.getLogger(__name__)


class LiteLLM(LLMProvider):
"""A concrete class for creating LiteLLM models with request throttling."""
class LiteLLMProvider(LLMProvider):
"""A concrete class for creating LiteLLMProvider models with request throttling."""

def __init__(
self,
Expand Down Expand Up @@ -133,7 +133,7 @@ def _get_base_args(
generation_config: GenerationConfig,
prompt=None,
) -> dict:
"""Get the base arguments for the LiteLLM API."""
"""Get the base arguments for the LiteLLMProvider API."""
args = {
"model": generation_config.model,
"temperature": generation_config.temperature,
Expand Down
6 changes: 3 additions & 3 deletions r2r/providers/llm/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
logger = logging.getLogger(__name__)


class OpenAILLM(LLMProvider):
class OpenAILLMProvider(LLMProvider):
"""A concrete class for creating OpenAI models."""

def __init__(
Expand All @@ -32,11 +32,11 @@ def __init__(
from openai import AsyncOpenAI, OpenAI # noqa
except ImportError:
raise ImportError(
"Error, `openai` is required to run an OpenAILLM. Please install it using `pip install openai`."
"Error, `openai` is required to run an OpenAILLMProvider. Please install it using `pip install openai`."
)
if config.provider != "openai":
raise ValueError(
"OpenAILLM must be initialized with config with `openai` provider."
"OpenAILLMProvider must be initialized with config with `openai` provider."
)
if not os.getenv("OPENAI_API_KEY"):
raise ValueError(
Expand Down

0 comments on commit ed657f2

Please sign in to comment.