Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/add lightweight assistant #759

Merged
merged 5 commits into from
Jul 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions r2r/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import logging

from .assistant import *

# Keep '*' imports for enhanced development velocity
# corresponding flake8 error codes are F403, F405
from .base import *
Expand Down
4 changes: 4 additions & 0 deletions r2r/base/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .abstractions.assistant import Assistant, AssistantConfig, Tool
from .abstractions.base import AsyncSyncMeta, UserStats, syncable
from .abstractions.document import (
DataType,
Expand Down Expand Up @@ -96,6 +97,9 @@
"RedisLoggingConfig",
"AsyncSyncMeta",
"syncable",
"Assistant",
"AssistantConfig",
"Tool",
"RedisKVLoggingProvider",
"KVLoggingSingleton",
"RunManager",
Expand Down
98 changes: 98 additions & 0 deletions r2r/base/abstractions/assistant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Optional, Sequence

from pydantic import BaseModel, Field

from ..providers.llm import GenerationConfig, LLMProvider


class Tool(BaseModel):
name: str
description: str
function: Callable

class Config:
arbitrary_types_allowed = True


class Message(BaseModel):
role: str
content: Optional[str] = None
name: Optional[str] = None
function_call: Optional[Dict[str, Any]] = None


class Conversation:
def __init__(self):
self.messages: List[Message] = []

def add_message(
self,
role: str,
content: Optional[str] = None,
name: Optional[str] = None,
function_call: Optional[Dict[str, Any]] = None,
):
self.messages.append(
Message(
role=role,
content=content,
name=name,
function_call=function_call,
)
)

def get_messages(self) -> List[Dict[str, Any]]:
return [msg.dict(exclude_none=True) for msg in self.messages]


class AssistantConfig(BaseModel):
system_instruction: str
tools: List[Tool] = Field(default_factory=list)
generation_config: GenerationConfig = GenerationConfig()


class Assistant(ABC):
def __init__(
self,
instructions: str,
llm_provider: LLMProvider,
config: AssistantConfig,
):
self.instructions = instructions
self.llm_provider = llm_provider
self.config = config
self.conversation = Conversation()
self.completed = False
self._setup()

def _setup(self):
self.conversation.add_message("system", self.config.system_instruction)
self.conversation.add_message("system", self.instructions)

@property
def tools(self) -> Sequence[Tool]:
return self.config.tools

@abstractmethod
async def run(self) -> str:
pass

@abstractmethod
async def process_llm_response(self, response: Dict[str, Any]) -> str:
pass

def add_user_message(self, content: str):
self.conversation.add_message("user", content)

async def search(self, query: str) -> str:
# Implement the search functionality here
# This could involve calling an external search API or querying a local database
return f"Simulated search results for: {query}"

async def execute_tool(self, tool_name: str, **kwargs) -> str:
tool = next((t for t in self.tools if t.name == tool_name), None)
if tool:
return await tool.function(**kwargs)
else:
return f"Error: Tool {tool_name} not found."
1 change: 1 addition & 0 deletions r2r/base/abstractions/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
if TYPE_CHECKING:
from .search import AggregateSearchResult


LLMChatCompletion = ChatCompletion
LLMChatCompletionChunk = ChatCompletionChunk

Expand Down
1 change: 0 additions & 1 deletion r2r/base/providers/prompt.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import logging
import os
from abc import abstractmethod
from pathlib import Path
from typing import Any, Optional
Expand Down
11 changes: 7 additions & 4 deletions r2r/main/__init__.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
from .abstractions import R2RPipelines, R2RProviders
from .api.client import R2RClient
from .api.requests import (
from .api.routes.ingestion.requests import (
R2RIngestFilesRequest,
R2RUpdateFilesRequest,
)
from .api.routes.management.requests import (
R2RAnalyticsRequest,
R2RDeleteRequest,
R2RDocumentChunksRequest,
R2RDocumentsOverviewRequest,
R2RUpdatePromptRequest,
R2RUsersOverviewRequest,
)
from .api.routes.ingestion import R2RIngestFilesRequest, R2RUpdateFilesRequest
from .api.routes.management import R2RUpdatePromptRequest
from .api.routes.retrieval import (
from .api.routes.retrieval.requests import (
R2REvalRequest,
R2RRAGRequest,
R2RSearchRequest,
Expand Down
15 changes: 8 additions & 7 deletions r2r/main/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,22 +14,23 @@
import requests
from fastapi.testclient import TestClient

from r2r.base import R2RException, UserCreate
from r2r.base import AnalysisTypes, FilterCriteria, R2RException, UserCreate

from .requests import (
AnalysisTypes,
FilterCriteria,
from .routes.ingestion.requests import (
R2RIngestFilesRequest,
R2RUpdateFilesRequest,
)
from .routes.management.requests import (
R2RAnalyticsRequest,
R2RDeleteRequest,
R2RDocumentChunksRequest,
R2RDocumentsOverviewRequest,
R2RLogsRequest,
R2RPrintRelationshipsRequest,
R2RUpdatePromptRequest,
R2RUsersOverviewRequest,
)
from .routes.ingestion import R2RIngestFilesRequest, R2RUpdateFilesRequest
from .routes.management import R2RUpdatePromptRequest
from .routes.retrieval import R2RRAGRequest, R2RSearchRequest
from .routes.retrieval.requests import R2RRAGRequest, R2RSearchRequest

nest_asyncio.apply()

Expand Down
Empty file.
22 changes: 7 additions & 15 deletions r2r/main/api/routes/auth.py → r2r/main/api/routes/auth/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,14 @@
from pydantic import BaseModel

from r2r.base import Token, User, UserCreate
from r2r.main.api.routes.auth.requests import (
PasswordChangeRequest,
PasswordResetConfirmRequest,
PasswordResetRequest,
)

from ...engine import R2REngine
from .base_router import BaseRouter
from ....engine import R2REngine
from ..base_router import BaseRouter

oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")

Expand All @@ -18,19 +23,6 @@ class TokenResponse(BaseModel):
results: dict[str, Token]


class PasswordChangeRequest(BaseModel):
current_password: str
new_password: str


class PasswordResetRequest(BaseModel):
email: str


class PasswordResetConfirmRequest(BaseModel):
new_password: str


class UserProfileUpdate(BaseModel):
email: str | None = None
name: str | None = None
Expand Down
14 changes: 14 additions & 0 deletions r2r/main/api/routes/auth/requests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from pydantic import BaseModel


class PasswordChangeRequest(BaseModel):
current_password: str
new_password: str


class PasswordResetRequest(BaseModel):
email: str


class PasswordResetConfirmRequest(BaseModel):
new_password: str
Empty file.
Original file line number Diff line number Diff line change
@@ -1,23 +1,13 @@
import uuid
from typing import Optional

from fastapi import Depends, File, UploadFile
from pydantic import BaseModel

from ...engine import R2REngine
from ...services.ingestion_service import IngestionService
from .base_router import BaseRouter


class R2RIngestFilesRequest(BaseModel):
document_ids: Optional[list[uuid.UUID]] = None
metadatas: Optional[list[dict]] = None
versions: Optional[list[str]] = None

from r2r.main.api.routes.ingestion.requests import (
R2RIngestFilesRequest,
R2RUpdateFilesRequest,
)

class R2RUpdateFilesRequest(BaseModel):
metadatas: Optional[list[dict]] = None
document_ids: Optional[list[uuid.UUID]] = None
from ....engine import R2REngine
from ....services.ingestion_service import IngestionService
from ..base_router import BaseRouter


class IngestionRouter(BaseRouter):
Expand Down
15 changes: 15 additions & 0 deletions r2r/main/api/routes/ingestion/requests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import uuid
from typing import Optional

from pydantic import BaseModel


class R2RUpdateFilesRequest(BaseModel):
metadatas: Optional[list[dict]] = None
document_ids: Optional[list[uuid.UUID]] = None


class R2RIngestFilesRequest(BaseModel):
document_ids: Optional[list[uuid.UUID]] = None
metadatas: Optional[list[dict]] = None
versions: Optional[list[str]] = None
Empty file.
Original file line number Diff line number Diff line change
@@ -1,51 +1,20 @@
import uuid
from typing import Optional, Union

from fastapi import Depends
from pydantic import BaseModel

from r2r.base import AnalysisTypes, FilterCriteria, R2RException

from ...engine import R2REngine
from .base_router import BaseRouter


class R2RUpdatePromptRequest(BaseModel):
name: str
template: Optional[str] = None
input_types: Optional[dict[str, str]] = {}


class R2RDeleteRequest(BaseModel):
keys: list[str]
values: list[Union[bool, int, str]]


class R2RAnalyticsRequest(BaseModel):
filter_criteria: FilterCriteria
analysis_types: AnalysisTypes


class R2RUsersOverviewRequest(BaseModel):
user_ids: Optional[list[uuid.UUID]]


class R2RDocumentsOverviewRequest(BaseModel):
document_ids: Optional[list[uuid.UUID]]
user_ids: Optional[list[uuid.UUID]]


class R2RDocumentChunksRequest(BaseModel):
document_id: uuid.UUID


class R2RLogsRequest(BaseModel):
log_type_filter: Optional[str] = (None,)
max_runs_requested: int = 100


class R2RPrintRelationshipsRequest(BaseModel):
limit: int = 100
from r2r.base import R2RException
from r2r.main.api.routes.management.requests import (
R2RAnalyticsRequest,
R2RDeleteRequest,
R2RDocumentChunksRequest,
R2RDocumentsOverviewRequest,
R2RLogsRequest,
R2RPrintRelationshipsRequest,
R2RUpdatePromptRequest,
R2RUsersOverviewRequest,
)

from ....engine import R2REngine
from ..base_router import BaseRouter


class ManagementRouter(BaseRouter):
Expand Down
44 changes: 44 additions & 0 deletions r2r/main/api/routes/management/requests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import uuid
from typing import Optional, Union

from pydantic import BaseModel

from r2r.base import AnalysisTypes, FilterCriteria


class R2RUpdatePromptRequest(BaseModel):
name: str
template: Optional[str] = None
input_types: Optional[dict[str, str]] = {}


class R2RDeleteRequest(BaseModel):
keys: list[str]
values: list[Union[bool, int, str]]


class R2RAnalyticsRequest(BaseModel):
filter_criteria: FilterCriteria
analysis_types: AnalysisTypes


class R2RUsersOverviewRequest(BaseModel):
user_ids: Optional[list[uuid.UUID]]


class R2RDocumentsOverviewRequest(BaseModel):
document_ids: Optional[list[uuid.UUID]]
user_ids: Optional[list[uuid.UUID]]


class R2RDocumentChunksRequest(BaseModel):
document_id: uuid.UUID


class R2RLogsRequest(BaseModel):
log_type_filter: Optional[str] = (None,)
max_runs_requested: int = 100


class R2RPrintRelationshipsRequest(BaseModel):
limit: int = 100
Empty file.
Loading
Loading