diff --git a/py/core/base/abstractions/search.py b/py/core/base/abstractions/search.py index affc9cb9f..565728794 100644 --- a/py/core/base/abstractions/search.py +++ b/py/core/base/abstractions/search.py @@ -54,8 +54,10 @@ class Config: }, } + class KGLocalSearchResult(BaseModel): """Result of a local knowledge graph search operation.""" + query: str entities: list[dict[str, Any]] relationships: list[dict[str, Any]] @@ -70,6 +72,7 @@ def __repr__(self) -> str: class KGGlobalSearchResult(BaseModel): """Result of a global knowledge graph search operation.""" + query: str search_result: list[Dict[str, Any]] @@ -82,6 +85,7 @@ def __repr__(self) -> str: class KGSearchResult(BaseModel): """Result of a knowledge graph search operation.""" + local_result: Optional[KGLocalSearchResult] = None global_result: Optional[KGGlobalSearchResult] = None diff --git a/py/core/base/providers/auth.py b/py/core/base/providers/auth.py index f979f38ba..17f2a2508 100644 --- a/py/core/base/providers/auth.py +++ b/py/core/base/providers/auth.py @@ -90,9 +90,7 @@ def login(self, email: str, password: str) -> Dict[str, Token]: pass @abstractmethod - def refresh_access_token( - self, user_email: str, refresh_token: str - ) -> Dict[str, str]: + def refresh_access_token(self, refresh_token: str) -> Dict[str, str]: pass async def auth_wrapper( diff --git a/py/core/main/api/routes/auth/base.py b/py/core/main/api/routes/auth/base.py index 408ee555e..9dac245b9 100644 --- a/py/core/main/api/routes/auth/base.py +++ b/py/core/main/api/routes/auth/base.py @@ -141,7 +141,6 @@ async def refresh_access_token_app( This endpoint allows users to obtain a new access token using their refresh token. """ refresh_result = await self.engine.arefresh_access_token( - user_email=auth_user.email, refresh_token=refresh_token, ) return refresh_result diff --git a/py/core/main/services/auth_service.py b/py/core/main/services/auth_service.py index 5229ed3a0..d704ae91d 100644 --- a/py/core/main/services/auth_service.py +++ b/py/core/main/services/auth_service.py @@ -74,11 +74,9 @@ async def user(self, token: str) -> UserResponse: @telemetry_event("RefreshToken") async def refresh_access_token( - self, user_email: str, refresh_token: str + self, refresh_token: str ) -> dict[str, Token]: - return self.providers.auth.refresh_access_token( - user_email, refresh_token - ) + return self.providers.auth.refresh_access_token(refresh_token) @telemetry_event("ChangePassword") async def change_password( diff --git a/py/core/pipes/retrieval/kg_search_search_pipe.py b/py/core/pipes/retrieval/kg_search_search_pipe.py index e5664fae5..fcd37889a 100644 --- a/py/core/pipes/retrieval/kg_search_search_pipe.py +++ b/py/core/pipes/retrieval/kg_search_search_pipe.py @@ -14,12 +14,17 @@ PromptProvider, RunLoggingSingleton, ) +from core.base.abstractions.search import ( + KGGlobalSearchResult, + KGLocalSearchResult, + KGSearchResult, +) -from core.base.abstractions.search import KGLocalSearchResult, KGGlobalSearchResult, KGSearchResult from ..abstractions.generator_pipe import GeneratorPipe logger = logging.getLogger(__name__) + class KGSearchSearchPipe(GeneratorPipe): """ Embeds and stores documents using a specified embedding model and database. @@ -127,7 +132,12 @@ async def local_search( ) all_search_results.append(search_result) - yield KGLocalSearchResult(query=message, entities=all_search_results[0], relationships=all_search_results[1], communities=all_search_results[2]) + yield KGLocalSearchResult( + query=message, + entities=all_search_results[0], + relationships=all_search_results[1], + communities=all_search_results[2], + ) async def global_search( self, @@ -209,7 +219,9 @@ async def process_community(merged_report): output = output.choices[0].message.content - yield KGGlobalSearchResult(query=message, search_result=output, citations=None) + yield KGGlobalSearchResult( + query=message, search_result=output, citations=None + ) async def _run_logic( self, diff --git a/py/core/providers/auth/r2r_auth.py b/py/core/providers/auth/r2r_auth.py index 7dab2d02a..d82257f73 100644 --- a/py/core/providers/auth/r2r_auth.py +++ b/py/core/providers/auth/r2r_auth.py @@ -197,25 +197,18 @@ def login(self, email: str, password: str) -> Dict[str, Token]: raise R2RException(status_code=401, message="Email not verified") access_token = self.create_access_token(data={"sub": user.email}) - refresh_token = self.create_refresh_token(data={"sub": user.email}) + refresh_token = self.create_refresh_token() return { "access_token": Token(token=access_token, token_type="access"), "refresh_token": Token(token=refresh_token, token_type="refresh"), } - def refresh_access_token( - self, user_email: str, refresh_token: str - ) -> Dict[str, Token]: + def refresh_access_token(self, refresh_token: str) -> Dict[str, Token]: token_data = self.decode_token(refresh_token) if token_data.token_type != "refresh": raise R2RException( status_code=401, message="Invalid refresh token" ) - if token_data.email != user_email: - raise R2RException( - status_code=402, - message="Invalid email address attached to token", - ) # Invalidate the old refresh token and create a new one self.db_provider.relational.blacklist_token(refresh_token) @@ -223,9 +216,7 @@ def refresh_access_token( new_access_token = self.create_access_token( data={"sub": token_data.email} ) - new_refresh_token = self.create_refresh_token( - data={"sub": token_data.email} - ) + new_refresh_token = self.create_refresh_token() return { "access_token": Token(token=new_access_token, token_type="access"), "refresh_token": Token( diff --git a/py/sdk/models.py b/py/sdk/models.py index bba344f50..571659dfb 100644 --- a/py/sdk/models.py +++ b/py/sdk/models.py @@ -200,13 +200,14 @@ class Config: { "Paris": { "name": "Paris", - "description": "Paris is the capital of France." + "description": "Paris is the capital of France.", } } - ] + ], } } + class R2RException(Exception): def __init__( self, message: str, status_code: int, detail: Optional[Any] = None