Skip to content

Commit

Permalink
Add AsyncSession support for non-blocking db operations
Browse files Browse the repository at this point in the history
  • Loading branch information
cbornet committed Nov 5, 2024
1 parent 07d8f2e commit d07ddce
Show file tree
Hide file tree
Showing 8 changed files with 234 additions and 69 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ dependencies = [
"pymongo>=4.6.0",
"supabase~=2.6.0",
"certifi>=2023.11.17,<2025.0.0",
"psycopg>=3.1.9",
"psycopg[binary,pool]>=3.1.9",
"fastavro>=1.8.0",
"redis>=5.0.1",
"metaphor-python>=0.1.11",
Expand Down Expand Up @@ -111,6 +111,7 @@ dependencies = [
"langchain-elasticsearch>=0.2.0",
"langchain-ollama>=0.2.0",
"pymupdf~=1.24.13",
"aiosqlite>=0.20.0",
]

[project.urls]
Expand Down
50 changes: 29 additions & 21 deletions src/backend/base/langflow/__main__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import inspect
import platform
import socket
Expand Down Expand Up @@ -27,7 +28,7 @@
create_default_folder_if_it_doesnt_exist,
)
from langflow.services.database.utils import session_getter
from langflow.services.deps import get_db_service, get_settings_service, session_scope
from langflow.services.deps import async_session_scope, get_db_service, get_settings_service
from langflow.services.settings.constants import DEFAULT_SUPERUSER
from langflow.services.utils import initialize_services
from langflow.utils.version import fetch_latest_version, get_version_info
Expand Down Expand Up @@ -486,28 +487,35 @@ def api_key(
if not auth_settings.AUTO_LOGIN:
typer.echo("Auto login is disabled. API keys cannot be created through the CLI.")
return
with session_scope() as session:
from langflow.services.database.models.user.model import User

superuser = session.exec(select(User).where(User.username == DEFAULT_SUPERUSER)).first()
if not superuser:
typer.echo("Default superuser not found. This command requires a superuser and AUTO_LOGIN to be enabled.")
return
from langflow.services.database.models.api_key import ApiKey, ApiKeyCreate
from langflow.services.database.models.api_key.crud import (
create_api_key,
delete_api_key,
)

api_key = session.exec(select(ApiKey).where(ApiKey.user_id == superuser.id)).first()
if api_key:
delete_api_key(session, api_key.id)
async def aapi_key():
async with async_session_scope() as session:
from langflow.services.database.models.user.model import User

api_key_create = ApiKeyCreate(name="CLI")
unmasked_api_key = create_api_key(session, api_key_create, user_id=superuser.id)
session.commit()
# Create a banner to display the API key and tell the user it won't be shown again
api_key_banner(unmasked_api_key)
superuser = (await session.exec(select(User).where(User.username == DEFAULT_SUPERUSER))).first()
if not superuser:
typer.echo(
"Default superuser not found. This command requires a superuser and AUTO_LOGIN to be enabled."
)
return None
from langflow.services.database.models.api_key import ApiKey, ApiKeyCreate
from langflow.services.database.models.api_key.crud import (
create_api_key,
delete_api_key,
)

api_key = (await session.exec(select(ApiKey).where(ApiKey.user_id == superuser.id))).first()
if api_key:
await delete_api_key(session, api_key.id)

api_key_create = ApiKeyCreate(name="CLI")
unmasked_api_key = await create_api_key(session, api_key_create, user_id=superuser.id)
await session.commit()
return unmasked_api_key

unmasked_api_key = asyncio.run(aapi_key())
# Create a banner to display the API key and tell the user it won't be shown again
api_key_banner(unmasked_api_key)


def api_key_banner(unmasked_api_key) -> None:
Expand Down
4 changes: 3 additions & 1 deletion src/backend/base/langflow/api/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,15 @@
from loguru import logger
from sqlalchemy import delete
from sqlmodel import Session
from sqlmodel.ext.asyncio.session import AsyncSession

from langflow.graph.graph.base import Graph
from langflow.services.auth.utils import get_current_active_user
from langflow.services.database.models import User
from langflow.services.database.models.flow import Flow
from langflow.services.database.models.transactions.model import TransactionTable
from langflow.services.database.models.vertex_builds.model import VertexBuildTable
from langflow.services.deps import get_session
from langflow.services.deps import get_async_session, get_session
from langflow.services.store.utils import get_lf_version_from_pypi

if TYPE_CHECKING:
Expand All @@ -31,6 +32,7 @@

CurrentActiveUser = Annotated[User, Depends(get_current_active_user)]
DbSession = Annotated[Session, Depends(get_session)]
AsyncDbSession = Annotated[AsyncSession, Depends(get_async_session)]


def has_api_terms(word: str):
Expand Down
14 changes: 7 additions & 7 deletions src/backend/base/langflow/api/v1/api_key.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from fastapi import APIRouter, Depends, HTTPException, Response

from langflow.api.utils import CurrentActiveUser, DbSession
from langflow.api.utils import AsyncDbSession, CurrentActiveUser, DbSession
from langflow.api.v1.schemas import ApiKeyCreateRequest, ApiKeysResponse
from langflow.services.auth import utils as auth_utils

Expand All @@ -20,12 +20,12 @@

@router.get("/")
async def get_api_keys_route(
db: DbSession,
db: AsyncDbSession,
current_user: CurrentActiveUser,
) -> ApiKeysResponse:
try:
user_id = current_user.id
keys = get_api_keys(db, user_id)
keys = await get_api_keys(db, user_id)

return ApiKeysResponse(total_count=len(keys), user_id=user_id, api_keys=keys)
except Exception as exc:
Expand All @@ -36,22 +36,22 @@ async def get_api_keys_route(
async def create_api_key_route(
req: ApiKeyCreate,
current_user: CurrentActiveUser,
db: DbSession,
db: AsyncDbSession,
) -> UnmaskedApiKeyRead:
try:
user_id = current_user.id
return create_api_key(db, req, user_id=user_id)
return await create_api_key(db, req, user_id=user_id)
except Exception as e:
raise HTTPException(status_code=400, detail=str(e)) from e


@router.delete("/{api_key_id}", dependencies=[Depends(auth_utils.get_current_active_user)])
async def delete_api_key_route(
api_key_id: UUID,
db: DbSession,
db: AsyncDbSession,
):
try:
delete_api_key(db, api_key_id)
await delete_api_key(db, api_key_id)
except Exception as e:
raise HTTPException(status_code=400, detail=str(e)) from e
return {"detail": "API Key deleted"}
Expand Down
19 changes: 10 additions & 9 deletions src/backend/base/langflow/services/database/models/api_key/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,21 @@
from uuid import UUID

from sqlmodel import Session, select
from sqlmodel.ext.asyncio.session import AsyncSession

from langflow.services.database.models.api_key import ApiKey, ApiKeyCreate, ApiKeyRead, UnmaskedApiKeyRead

if TYPE_CHECKING:
from sqlmodel.sql.expression import SelectOfScalar


def get_api_keys(session: Session, user_id: UUID) -> list[ApiKeyRead]:
async def get_api_keys(session: AsyncSession, user_id: UUID) -> list[ApiKeyRead]:
query: SelectOfScalar = select(ApiKey).where(ApiKey.user_id == user_id)
api_keys = session.exec(query).all()
api_keys = (await session.exec(query)).all()
return [ApiKeyRead.model_validate(api_key) for api_key in api_keys]


def create_api_key(session: Session, api_key_create: ApiKeyCreate, user_id: UUID) -> UnmaskedApiKeyRead:
async def create_api_key(session: AsyncSession, api_key_create: ApiKeyCreate, user_id: UUID) -> UnmaskedApiKeyRead:
# Generate a random API key with 32 bytes of randomness
generated_api_key = f"sk-{secrets.token_urlsafe(32)}"

Expand All @@ -30,20 +31,20 @@ def create_api_key(session: Session, api_key_create: ApiKeyCreate, user_id: UUID
)

session.add(api_key)
session.commit()
session.refresh(api_key)
await session.commit()
await session.refresh(api_key)
unmasked = UnmaskedApiKeyRead.model_validate(api_key, from_attributes=True)
unmasked.api_key = generated_api_key
return unmasked


def delete_api_key(session: Session, api_key_id: UUID) -> None:
api_key = session.get(ApiKey, api_key_id)
async def delete_api_key(session: AsyncSession, api_key_id: UUID) -> None:
api_key = await session.get(ApiKey, api_key_id)
if api_key is None:
msg = "API Key not found"
raise ValueError(msg)
session.delete(api_key)
session.commit()
await session.delete(api_key)
await session.commit()


def check_key(session: Session, api_key: str) -> ApiKey | None:
Expand Down
82 changes: 56 additions & 26 deletions src/backend/base/langflow/services/database/service.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from __future__ import annotations

import asyncio
import sqlite3
import time
from contextlib import contextmanager
from contextlib import asynccontextmanager, contextmanager
from datetime import datetime, timezone
from pathlib import Path
from typing import TYPE_CHECKING
Expand All @@ -14,7 +15,9 @@
from sqlalchemy import event, inspect
from sqlalchemy.engine import Engine
from sqlalchemy.exc import OperationalError
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
from sqlmodel import Session, SQLModel, create_engine, select, text
from sqlmodel.ext.asyncio.session import AsyncSession

from langflow.services.base import Service
from langflow.services.database import models
Expand All @@ -39,12 +42,17 @@ def __init__(self, settings_service: SettingsService):
msg = "No database URL provided"
raise ValueError(msg)
self.database_url: str = settings_service.settings.database_url
self._sanitize_database_url()
# This file is in langflow.services.database.manager.py
# the ini is in langflow
langflow_dir = Path(__file__).parent.parent.parent
self.script_location = langflow_dir / "alembic"
self.alembic_cfg_path = langflow_dir / "alembic.ini"
# register the event listener for sqlite as part of this class.
# Using decorator will make the method not able to use self
event.listen(Engine, "connect", self.on_connection)
self.engine = self._create_engine()
self.async_engine = self._create_async_engine()
alembic_log_file = self.settings_service.settings.alembic_log_file

# Check if the provided path is absolute, cross-platform.
Expand All @@ -56,10 +64,47 @@ def __init__(self, settings_service: SettingsService):
self.alembic_log_path = Path(langflow_dir) / alembic_log_file

def reload_engine(self) -> None:
self._sanitize_database_url()
self.engine = self._create_engine()
self.async_engine = self._create_async_engine()

def _sanitize_database_url(self):
if self.database_url.startswith("postgres://"):
self.database_url = self.database_url.replace("postgres://", "postgresql://")
logger.warning(
"Fixed postgres dialect in database URL. Replacing postgres:// with postgresql://. "
"To avoid this warning, update the database URL."
)

def _create_engine(self) -> Engine:
"""Create the engine for the database."""
return create_engine(
self.database_url,
connect_args=self._get_connect_args(),
pool_size=self.settings_service.settings.pool_size,
max_overflow=self.settings_service.settings.max_overflow,
)

def _create_async_engine(self) -> AsyncEngine:
"""Create the engine for the database."""
url_components = self.database_url.split("://", maxsplit=1)
if url_components[0].startswith("sqlite"):
database_url = "sqlite+aiosqlite://"
kwargs = {}
else:
kwargs = {
"pool_size": self.settings_service.settings.pool_size,
"max_overflow": self.settings_service.settings.max_overflow,
}
database_url = "postgresql+psycopg://" if url_components[0].startswith("postgresql") else url_components[0]
database_url += url_components[1]
return create_async_engine(
database_url,
connect_args=self._get_connect_args(),
**kwargs,
)

def _get_connect_args(self):
if self.settings_service.settings.database_url and self.settings_service.settings.database_url.startswith(
"sqlite"
):
Expand All @@ -69,33 +114,12 @@ def _create_engine(self) -> Engine:
}
else:
connect_args = {}
try:
# register the event listener for sqlite as part of this class.
# Using decorator will make the method not able to use self
event.listen(Engine, "connect", self.on_connection)

return create_engine(
self.database_url,
connect_args=connect_args,
pool_size=self.settings_service.settings.pool_size,
max_overflow=self.settings_service.settings.max_overflow,
)
except sa.exc.NoSuchModuleError as exc:
if "postgres" in str(exc) and not self.database_url.startswith("postgresql"):
# https://stackoverflow.com/questions/62688256/sqlalchemy-exc-nosuchmoduleerror-cant-load-plugin-sqlalchemy-dialectspostgre
self.database_url = self.database_url.replace("postgres://", "postgresql://")
logger.warning(
"Fixed postgres dialect in database URL. Replacing postgres:// with postgresql://. "
"To avoid this warning, update the database URL."
)
return self._create_engine()
msg = "Error creating database engine"
raise RuntimeError(msg) from exc
return connect_args

def on_connection(self, dbapi_connection, _connection_record) -> None:
from sqlite3 import Connection as sqliteConnection

if isinstance(dbapi_connection, sqliteConnection):
if isinstance(
dbapi_connection, sqlite3.Connection | sa.dialects.sqlite.aiosqlite.AsyncAdapt_aiosqlite_connection
):
pragmas: dict = self.settings_service.settings.sqlite_pragmas or {}
pragmas_list = []
for key, val in pragmas.items():
Expand All @@ -117,6 +141,11 @@ def with_session(self):
with Session(self.engine) as session:
yield session

@asynccontextmanager
async def with_async_session(self):
async with AsyncSession(self.async_engine) as session:
yield session

def migrate_flows_if_auto_login(self) -> None:
# if auto_login is enabled, we need to migrate the flows
# to the default superuser if they don't have a user id
Expand Down Expand Up @@ -334,3 +363,4 @@ def _teardown(self) -> None:

async def teardown(self) -> None:
await asyncio.to_thread(self._teardown)
await self.async_engine.dispose()
Loading

0 comments on commit d07ddce

Please sign in to comment.