diff --git a/src/backend/base/langflow/services/database/service.py b/src/backend/base/langflow/services/database/service.py
index d5c2bd68558a..885f1267ec5b 100644
--- a/src/backend/base/langflow/services/database/service.py
+++ b/src/backend/base/langflow/services/database/service.py
@@ -14,7 +14,7 @@ from alembic import command, util
 from alembic.config import Config
 from loguru import logger
-from sqlalchemy import event, exc, inspect
+from sqlalchemy import AsyncAdaptedQueuePool, event, exc, inspect
 from sqlalchemy.dialects import sqlite as dialect_sqlite
 from sqlalchemy.engine import Engine
 from sqlalchemy.exc import OperationalError
@@ -39,6 +39,7 @@ class DatabaseService(Service):
     name = "database_service"
 
     def __init__(self, settings_service: SettingsService):
+        self._logged_pragma = False
         self.settings_service = settings_service
         if settings_service.settings.database_url is None:
             msg = "No database URL provided"
@@ -67,8 +68,6 @@ def __init__(self, settings_service: SettingsService):
         else:
             self.alembic_log_path = Path(langflow_dir) / alembic_log_file
 
-        self._logged_pragma = False
-
     async def initialize_alembic_log_file(self):
         # Ensure the directory and file for the alembic log file exists
         await anyio.Path(self.alembic_log_path.parent).mkdir(parents=True, exist_ok=True)
@@ -89,22 +88,49 @@ def _sanitize_database_url(self):
                 "To avoid this warning, update the database URL."
             )
 
+    def _build_connection_kwargs(self):
+        """Build connection kwargs by merging deprecated settings with db_connection_settings.
+
+        Returns:
+            dict: Connection kwargs with deprecated settings overriding db_connection_settings
+        """
+        settings = self.settings_service.settings
+        # db_connection_settings is typed ``dict | None``; treat None as an empty base instead of crashing
+        connection_kwargs = dict(settings.db_connection_settings or {})
+
+        # Override individual settings if explicitly set
+        if "pool_size" in settings.model_fields_set:
+            logger.warning("pool_size is deprecated. Use db_connection_settings['pool_size'] instead.")
+            connection_kwargs["pool_size"] = settings.pool_size
+        if "max_overflow" in settings.model_fields_set:
+            logger.warning("max_overflow is deprecated. Use db_connection_settings['max_overflow'] instead.")
+            connection_kwargs["max_overflow"] = settings.max_overflow
+
+        return connection_kwargs
+
     def _create_engine(self) -> AsyncEngine:
         """Create the engine for the database."""
         url_components = self.database_url.split("://", maxsplit=1)
+
+        # Get connection settings from config, with defaults if not specified
+        # if the user specifies an empty dict, we allow it.
+        kwargs = self._build_connection_kwargs()
+
         if url_components[0].startswith("sqlite"):
             scheme = "sqlite+aiosqlite"
-            kwargs = {}
+            # Even though the docs say this is the default, it raises an error
+            # if we don't specify it.
+            # https://docs.sqlalchemy.org/en/20/errors.html#pool-class-cannot-be-used-with-asyncio-engine-or-vice-versa
+            pool = AsyncAdaptedQueuePool
         else:
-            kwargs = {
-                "pool_size": self.settings_service.settings.pool_size,
-                "max_overflow": self.settings_service.settings.max_overflow,
-            }
             scheme = "postgresql+psycopg" if url_components[0].startswith("postgresql") else url_components[0]
+            pool = None
+
         database_url = f"{scheme}://{url_components[1]}"
         return create_async_engine(
             database_url,
             connect_args=self._get_connect_args(),
+            poolclass=pool,
             **kwargs,
         )
 
diff --git a/src/backend/base/langflow/services/settings/base.py b/src/backend/base/langflow/services/settings/base.py
index ec752be77257..d73a8c8f3cb3 100644
--- a/src/backend/base/langflow/services/settings/base.py
+++ b/src/backend/base/langflow/services/settings/base.py
@@ -74,9 +74,11 @@ class Settings(BaseSettings):
     database_connection_retry: bool = False
     """If True, Langflow will retry to connect to the database if it fails."""
     pool_size: int = 10
-    """The number of connections to keep open in the connection pool. If not provided, the default is 10."""
+    """DEPRECATED: Use db_connection_settings['pool_size'] instead.
+    The number of connections to keep open in the connection pool. If not provided, the default is 10."""
     max_overflow: int = 20
-    """The number of connections to allow that can be opened beyond the pool size.
+    """DEPRECATED: Use db_connection_settings['max_overflow'] instead.
+    The number of connections to allow that can be opened beyond the pool size.
     If not provided, the default is 20."""
     db_connect_timeout: int = 20
     """The number of seconds to wait before giving up on a lock to released or establishing a connection to the
@@ -86,6 +88,14 @@ class Settings(BaseSettings):
     sqlite_pragmas: dict | None = {"synchronous": "NORMAL", "journal_mode": "WAL"}
     """SQLite pragmas to use when connecting to the database."""
 
+    db_connection_settings: dict | None = {
+        "pool_size": 10,
+        "max_overflow": 20,
+        "pool_timeout": 30,
+        "pool_pre_ping": True,
+    }
+    """Keyword arguments forwarded to the database engine factory (e.g. pool_size, max_overflow, pool_timeout)."""
+
     # cache configuration
     cache_type: Literal["async", "redis", "memory", "disk"] = "async"
     """The cache type can be 'async' or 'redis'."""
diff --git a/src/backend/tests/unit/test_initial_setup.py b/src/backend/tests/unit/test_initial_setup.py
index 9601e8f20d2e..4d0fe8d3f968 100644
--- a/src/backend/tests/unit/test_initial_setup.py
+++ b/src/backend/tests/unit/test_initial_setup.py
@@ -2,6 +2,7 @@
 import uuid
 from datetime import datetime
 from pathlib import Path
+from unittest.mock import AsyncMock, patch
 
 import anyio
 import pytest
@@ -212,7 +213,20 @@ async def test_refresh_starter_projects():
     ],
 )
 async def test_detect_github_url(url, expected):
-    assert await detect_github_url(url) == expected
+    # Mock the GitHub API response for the default branch case
+    mock_response = AsyncMock()
+    mock_response.json = lambda: {"default_branch": "main"}  # Not async, just returns a dict
+    mock_response.raise_for_status.return_value = None
+
+    with patch("httpx.AsyncClient.get", return_value=mock_response) as mock_get:
+        result = await detect_github_url(url)
+        assert result == expected
+
+        # Verify the API call was only made for GitHub repo URLs
+        if "github.com" in url and not any(x in url for x in ["/tree/", "/releases/", "/commit/"]):
+            mock_get.assert_called_once()
+        else:
+            mock_get.assert_not_called()
 
 
 @pytest.mark.usefixtures("client")