Skip to content

Commit

Permalink
Remove email from refresh (#948)
Browse files Browse the repository at this point in the history
  • Loading branch information
NolanTrem authored Aug 23, 2024
1 parent cf8097c commit 76cd6db
Show file tree
Hide file tree
Showing 7 changed files with 28 additions and 25 deletions.
4 changes: 4 additions & 0 deletions py/core/base/abstractions/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,10 @@ class Config:
},
}


class KGLocalSearchResult(BaseModel):
"""Result of a local knowledge graph search operation."""

query: str
entities: list[dict[str, Any]]
relationships: list[dict[str, Any]]
Expand All @@ -70,6 +72,7 @@ def __repr__(self) -> str:

class KGGlobalSearchResult(BaseModel):
"""Result of a global knowledge graph search operation."""

query: str
search_result: list[Dict[str, Any]]

Expand All @@ -82,6 +85,7 @@ def __repr__(self) -> str:

class KGSearchResult(BaseModel):
"""Result of a knowledge graph search operation."""

local_result: Optional[KGLocalSearchResult] = None
global_result: Optional[KGGlobalSearchResult] = None

Expand Down
4 changes: 1 addition & 3 deletions py/core/base/providers/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,7 @@ def login(self, email: str, password: str) -> Dict[str, Token]:
pass

@abstractmethod
def refresh_access_token(
self, user_email: str, refresh_token: str
) -> Dict[str, str]:
def refresh_access_token(self, refresh_token: str) -> Dict[str, str]:
pass

async def auth_wrapper(
Expand Down
1 change: 0 additions & 1 deletion py/core/main/api/routes/auth/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,6 @@ async def refresh_access_token_app(
This endpoint allows users to obtain a new access token using their refresh token.
"""
refresh_result = await self.engine.arefresh_access_token(
user_email=auth_user.email,
refresh_token=refresh_token,
)
return refresh_result
Expand Down
6 changes: 2 additions & 4 deletions py/core/main/services/auth_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,9 @@ async def user(self, token: str) -> UserResponse:

@telemetry_event("RefreshToken")
async def refresh_access_token(
self, user_email: str, refresh_token: str
self, refresh_token: str
) -> dict[str, Token]:
return self.providers.auth.refresh_access_token(
user_email, refresh_token
)
return self.providers.auth.refresh_access_token(refresh_token)

@telemetry_event("ChangePassword")
async def change_password(
Expand Down
18 changes: 15 additions & 3 deletions py/core/pipes/retrieval/kg_search_search_pipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,17 @@
PromptProvider,
RunLoggingSingleton,
)
from core.base.abstractions.search import (
KGGlobalSearchResult,
KGLocalSearchResult,
KGSearchResult,
)

from core.base.abstractions.search import KGLocalSearchResult, KGGlobalSearchResult, KGSearchResult
from ..abstractions.generator_pipe import GeneratorPipe

logger = logging.getLogger(__name__)


class KGSearchSearchPipe(GeneratorPipe):
"""
Embeds and stores documents using a specified embedding model and database.
Expand Down Expand Up @@ -127,7 +132,12 @@ async def local_search(
)
all_search_results.append(search_result)

yield KGLocalSearchResult(query=message, entities=all_search_results[0], relationships=all_search_results[1], communities=all_search_results[2])
yield KGLocalSearchResult(
query=message,
entities=all_search_results[0],
relationships=all_search_results[1],
communities=all_search_results[2],
)

async def global_search(
self,
Expand Down Expand Up @@ -209,7 +219,9 @@ async def process_community(merged_report):

output = output.choices[0].message.content

yield KGGlobalSearchResult(query=message, search_result=output, citations=None)
yield KGGlobalSearchResult(
query=message, search_result=output, citations=None
)

async def _run_logic(
self,
Expand Down
15 changes: 3 additions & 12 deletions py/core/providers/auth/r2r_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,35 +197,26 @@ def login(self, email: str, password: str) -> Dict[str, Token]:
raise R2RException(status_code=401, message="Email not verified")

access_token = self.create_access_token(data={"sub": user.email})
refresh_token = self.create_refresh_token(data={"sub": user.email})
refresh_token = self.create_refresh_token()
return {
"access_token": Token(token=access_token, token_type="access"),
"refresh_token": Token(token=refresh_token, token_type="refresh"),
}

def refresh_access_token(
self, user_email: str, refresh_token: str
) -> Dict[str, Token]:
def refresh_access_token(self, refresh_token: str) -> Dict[str, Token]:
token_data = self.decode_token(refresh_token)
if token_data.token_type != "refresh":
raise R2RException(
status_code=401, message="Invalid refresh token"
)
if token_data.email != user_email:
raise R2RException(
status_code=402,
message="Invalid email address attached to token",
)

# Invalidate the old refresh token and create a new one
self.db_provider.relational.blacklist_token(refresh_token)

new_access_token = self.create_access_token(
data={"sub": token_data.email}
)
new_refresh_token = self.create_refresh_token(
data={"sub": token_data.email}
)
new_refresh_token = self.create_refresh_token()
return {
"access_token": Token(token=new_access_token, token_type="access"),
"refresh_token": Token(
Expand Down
5 changes: 3 additions & 2 deletions py/sdk/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,13 +200,14 @@ class Config:
{
"Paris": {
"name": "Paris",
"description": "Paris is the capital of France."
"description": "Paris is the capital of France.",
}
}
]
],
}
}


class R2RException(Exception):
def __init__(
self, message: str, status_code: int, detail: Optional[Any] = None
Expand Down

0 comments on commit 76cd6db

Please sign in to comment.