Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove email in refresh #948

Merged
merged 1 commit into from
Aug 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions py/core/base/abstractions/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,10 @@ class Config:
},
}


class KGLocalSearchResult(BaseModel):
"""Result of a local knowledge graph search operation."""

query: str
entities: list[dict[str, Any]]
relationships: list[dict[str, Any]]
Expand All @@ -70,6 +72,7 @@ def __repr__(self) -> str:

class KGGlobalSearchResult(BaseModel):
"""Result of a global knowledge graph search operation."""

query: str
search_result: list[Dict[str, Any]]

Expand All @@ -82,6 +85,7 @@ def __repr__(self) -> str:

class KGSearchResult(BaseModel):
"""Result of a knowledge graph search operation."""

local_result: Optional[KGLocalSearchResult] = None
global_result: Optional[KGGlobalSearchResult] = None

Expand Down
4 changes: 1 addition & 3 deletions py/core/base/providers/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,7 @@ def login(self, email: str, password: str) -> Dict[str, Token]:
pass

@abstractmethod
def refresh_access_token(
self, user_email: str, refresh_token: str
) -> Dict[str, str]:
def refresh_access_token(self, refresh_token: str) -> Dict[str, str]:
pass

async def auth_wrapper(
Expand Down
1 change: 0 additions & 1 deletion py/core/main/api/routes/auth/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,6 @@ async def refresh_access_token_app(
This endpoint allows users to obtain a new access token using their refresh token.
"""
refresh_result = await self.engine.arefresh_access_token(
user_email=auth_user.email,
refresh_token=refresh_token,
)
return refresh_result
Expand Down
6 changes: 2 additions & 4 deletions py/core/main/services/auth_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,9 @@ async def user(self, token: str) -> UserResponse:

@telemetry_event("RefreshToken")
async def refresh_access_token(
self, user_email: str, refresh_token: str
self, refresh_token: str
) -> dict[str, Token]:
return self.providers.auth.refresh_access_token(
user_email, refresh_token
)
return self.providers.auth.refresh_access_token(refresh_token)

@telemetry_event("ChangePassword")
async def change_password(
Expand Down
18 changes: 15 additions & 3 deletions py/core/pipes/retrieval/kg_search_search_pipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,17 @@
PromptProvider,
RunLoggingSingleton,
)
from core.base.abstractions.search import (
KGGlobalSearchResult,
KGLocalSearchResult,
KGSearchResult,
)

from core.base.abstractions.search import KGLocalSearchResult, KGGlobalSearchResult, KGSearchResult
from ..abstractions.generator_pipe import GeneratorPipe

logger = logging.getLogger(__name__)


class KGSearchSearchPipe(GeneratorPipe):
"""
Embeds and stores documents using a specified embedding model and database.
Expand Down Expand Up @@ -127,7 +132,12 @@ async def local_search(
)
all_search_results.append(search_result)

yield KGLocalSearchResult(query=message, entities=all_search_results[0], relationships=all_search_results[1], communities=all_search_results[2])
yield KGLocalSearchResult(
query=message,
entities=all_search_results[0],
relationships=all_search_results[1],
communities=all_search_results[2],
)

async def global_search(
self,
Expand Down Expand Up @@ -209,7 +219,9 @@ async def process_community(merged_report):

output = output.choices[0].message.content

yield KGGlobalSearchResult(query=message, search_result=output, citations=None)
yield KGGlobalSearchResult(
query=message, search_result=output, citations=None
)

async def _run_logic(
self,
Expand Down
15 changes: 3 additions & 12 deletions py/core/providers/auth/r2r_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,35 +197,26 @@ def login(self, email: str, password: str) -> Dict[str, Token]:
raise R2RException(status_code=401, message="Email not verified")

access_token = self.create_access_token(data={"sub": user.email})
refresh_token = self.create_refresh_token(data={"sub": user.email})
refresh_token = self.create_refresh_token()
return {
"access_token": Token(token=access_token, token_type="access"),
"refresh_token": Token(token=refresh_token, token_type="refresh"),
}

def refresh_access_token(
self, user_email: str, refresh_token: str
) -> Dict[str, Token]:
def refresh_access_token(self, refresh_token: str) -> Dict[str, Token]:
token_data = self.decode_token(refresh_token)
if token_data.token_type != "refresh":
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removing the user_email parameter from refresh_access_token may reduce security if the email was used for validation. Ensure that the token itself provides sufficient validation.

raise R2RException(
status_code=401, message="Invalid refresh token"
)
if token_data.email != user_email:
raise R2RException(
status_code=402,
message="Invalid email address attached to token",
)

# Invalidate the old refresh token and create a new one
self.db_provider.relational.blacklist_token(refresh_token)

new_access_token = self.create_access_token(
data={"sub": token_data.email}
)
new_refresh_token = self.create_refresh_token(
data={"sub": token_data.email}
)
new_refresh_token = self.create_refresh_token()
return {
"access_token": Token(token=new_access_token, token_type="access"),
"refresh_token": Token(
Expand Down
5 changes: 3 additions & 2 deletions py/sdk/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,13 +200,14 @@ class Config:
{
"Paris": {
"name": "Paris",
"description": "Paris is the capital of France."
"description": "Paris is the capital of France.",
}
}
]
],
}
}


class R2RException(Exception):
def __init__(
self, message: str, status_code: int, detail: Optional[Any] = None
Expand Down
Loading