Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Fix db session used in different threads #4381

Merged
merged 2 commits into from
Nov 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/backend/base/langflow/api/v1/login.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@


@router.post("/login", response_model=Token)
async def login_to_get_access_token(
def login_to_get_access_token(
ogabrielluiz marked this conversation as resolved.
Show resolved Hide resolved
response: Response,
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
db: DbSession,
Expand Down
65 changes: 33 additions & 32 deletions src/backend/base/langflow/services/auth/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from langflow.services.database.models.api_key.model import ApiKey
from langflow.services.database.models.user.crud import get_user_by_id, get_user_by_username, update_user_last_login_at
from langflow.services.database.models.user.model import User, UserRead
from langflow.services.deps import get_session, get_settings_service
from langflow.services.deps import get_db_service, get_session, get_settings_service
from langflow.services.settings.service import SettingsService

oauth2_login = OAuth2PasswordBearer(tokenUrl="api/v1/login", auto_error=False)
Expand All @@ -36,41 +36,42 @@
def api_key_security(
query_param: Annotated[str, Security(api_key_query)],
header_param: Annotated[str, Security(api_key_header)],
db: Annotated[Session, Depends(get_session)],
) -> UserRead | None:
settings_service = get_settings_service()
result: ApiKey | User | None = None
if settings_service.auth_settings.AUTO_LOGIN:
# Get the first user
if not settings_service.auth_settings.SUPERUSER:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Missing first superuser credentials",
)

result = get_user_by_username(db, settings_service.auth_settings.SUPERUSER)
with get_db_service().with_session() as db:
if settings_service.auth_settings.AUTO_LOGIN:
# Get the first user
if not settings_service.auth_settings.SUPERUSER:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Missing first superuser credentials",
)

elif not query_param and not header_param:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="An API key must be passed as query or header",
)
result = get_user_by_username(db, settings_service.auth_settings.SUPERUSER)

elif query_param:
result = check_key(db, query_param)
elif not query_param and not header_param:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="An API key must be passed as query or header",
)

else:
result = check_key(db, header_param)
elif query_param:
result = check_key(db, query_param)

if not result:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Invalid or missing API key",
)
if isinstance(result, ApiKey):
return UserRead.model_validate(result.user, from_attributes=True)
if isinstance(result, User):
return UserRead.model_validate(result, from_attributes=True)
else:
result = check_key(db, header_param)

if not result:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Invalid or missing API key",
)
if isinstance(result, ApiKey):
return UserRead.model_validate(result.user, from_attributes=True)
if isinstance(result, User):
return UserRead.model_validate(result, from_attributes=True)
msg = "Invalid result type"
raise ValueError(msg)

Expand All @@ -83,7 +84,7 @@ async def get_current_user(
) -> User:
if token:
return await get_current_user_by_jwt(token, db)
user = await asyncio.to_thread(api_key_security, query_param, header_param, db)
user = await asyncio.to_thread(api_key_security, query_param, header_param)
if user:
return user

Expand Down Expand Up @@ -164,17 +165,17 @@ async def get_current_user_for_websocket(
if token:
return await get_current_user_by_jwt(token, db)
if api_key:
return await asyncio.to_thread(api_key_security, api_key, query_param, db)
return await asyncio.to_thread(api_key_security, api_key, query_param)
return None


def get_current_active_user(current_user: Annotated[User, Depends(get_current_user)]):
async def get_current_active_user(current_user: Annotated[User, Depends(get_current_user)]):
ogabrielluiz marked this conversation as resolved.
Show resolved Hide resolved
if not current_user.is_active:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Inactive user")
return current_user


def get_current_active_superuser(current_user: Annotated[User, Depends(get_current_user)]) -> User:
async def get_current_active_superuser(current_user: Annotated[User, Depends(get_current_user)]) -> User:
if not current_user.is_active:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Inactive user")
if not current_user.is_superuser:
Expand Down
6 changes: 5 additions & 1 deletion src/backend/base/langflow/services/database/service.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import asyncio
import time
from contextlib import contextmanager
from datetime import datetime, timezone
Expand Down Expand Up @@ -317,7 +318,7 @@ def create_db_and_tables(self) -> None:

logger.debug("Database and tables created successfully")

async def teardown(self) -> None:
def _teardown(self) -> None:
logger.debug("Tearing down database")
try:
settings_service = get_settings_service()
Expand All @@ -330,3 +331,6 @@ async def teardown(self) -> None:
logger.exception("Error tearing down database")

self.engine.dispose()

async def teardown(self) -> None:
await asyncio.to_thread(self._teardown)
9 changes: 8 additions & 1 deletion src/backend/base/langflow/services/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import asyncio

from loguru import logger
from sqlmodel import Session, select

Expand Down Expand Up @@ -110,10 +112,15 @@ def teardown_superuser(settings_service, session) -> None:
raise RuntimeError(msg) from exc


def _teardown_superuser():
with get_db_service().with_session() as session:
teardown_superuser(get_settings_service(), session)


async def teardown_services() -> None:
"""Teardown all the services."""
try:
teardown_superuser(get_settings_service(), next(get_session()))
await asyncio.to_thread(_teardown_superuser)
except Exception as exc: # noqa: BLE001
logger.exception(exc)
try:
Expand Down
Loading