diff --git a/UPDATING.md b/UPDATING.md
index 0f22ef36f037..635a7d2f3a0e 100644
--- a/UPDATING.md
+++ b/UPDATING.md
@@ -24,6 +24,41 @@ assists people when migrating to a new version.
 
 ## Next
 
+### Engine Manager for Connection Pooling
+
+A new `EngineManager` class has been introduced to centralize SQLAlchemy engine creation and management. This enables connection pooling for analytics databases and provides a more flexible architecture for engine configuration.
+
+#### Breaking Changes
+
+1. **Removed `SSH_TUNNEL_MANAGER_CLASS` config**: SSH tunnel handling is now integrated into the EngineManager. If you have custom SSH tunnel managers, you'll need to migrate to the new architecture.
+
+2. **Removed `nullpool` parameter**: The `get_sqla_engine()` and `get_raw_connection()` methods on the `Database` model no longer accept a `nullpool` parameter. Pool configuration is now controlled through the engine manager.
+
+3. **Removed `_get_sqla_engine()` method**: The private `_get_sqla_engine()` method has been removed from the `Database` model. All engine creation now goes through the `EngineManager`.
+
+#### New Configuration Options
+
+```python
+# Engine manager mode:
+# - EngineModes.NEW: Creates a new engine for every connection (default, original behavior)
+# - EngineModes.SINGLETON: Reuses engines with connection pooling
+from superset.engines.manager import EngineModes
+ENGINE_MANAGER_MODE = EngineModes.NEW
+
+# Cleanup interval for abandoned locks (default: 5 minutes)
+from datetime import timedelta
+ENGINE_MANAGER_CLEANUP_INTERVAL = timedelta(minutes=5)
+
+# Automatically start cleanup thread for SINGLETON mode (default: True)
+ENGINE_MANAGER_AUTO_START_CLEANUP = True
+```
+
+#### Migration Guide
+
+- If you were using the `nullpool` parameter, remove it from your calls
+- If you had a custom `SSH_TUNNEL_MANAGER_CLASS`, refactor to use the new EngineManager architecture
+- If you need connection pooling, set `ENGINE_MANAGER_MODE = EngineModes.SINGLETON` and configure the pool in your database's `extra` JSON field, as sketched below
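+
+A minimal sketch of a database `extra` field with a pool configured (illustrative values; keys under `engine_params` other than `poolclass` are passed straight through to SQLAlchemy's `create_engine`):
+
+```json
+{
+    "engine_params": {
+        "poolclass": "queue",
+        "pool_size": 10,
+        "max_overflow": 5,
+        "pool_timeout": 30
+    }
+}
+```
+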
 ### WebSocket config for GAQ with Docker
 
 [35896](https://github.com/apache/superset/pull/35896) and [37624](https://github.com/apache/superset/pull/37624) updated documentation on how to run and configure Superset with Docker. Specifically for the WebSocket configuration, a new `docker/superset-websocket/config.example.json` was added to the repo, so that users could copy it to create a `docker/superset-websocket/config.json` file.
 The existing `docker/superset-websocket/config.json` was removed and git-ignored, so if you're using GAQ / WebSocket make sure to:
diff --git a/superset/config.py b/superset/config.py
index 9731d80bd587..4110d902ffa9 100644
--- a/superset/config.py
+++ b/superset/config.py
@@ -52,10 +52,15 @@
 from superset.advanced_data_type.plugins.internet_port import internet_port
 from superset.advanced_data_type.types import AdvancedDataType
 from superset.constants import CHANGE_ME_SECRET_KEY
+from superset.engines.manager import EngineModes
 from superset.jinja_context import BaseTemplateProcessor
 from superset.key_value.types import JsonKeyValueCodec
 from superset.stats_logger import DummyStatsLogger
-from superset.superset_typing import CacheConfig
+from superset.superset_typing import (
+    CacheConfig,
+    DBConnectionMutator,
+    EngineContextManager,
+)
 from superset.tasks.types import ExecutorType
 from superset.themes.types import Theme
 from superset.utils import core as utils
@@ -260,6 +265,22 @@ def _try_json_readsha(filepath: str, length: int) -> str | None:
 # SQLALCHEMY_CUSTOM_PASSWORD_STORE = lookup_password
 SQLALCHEMY_CUSTOM_PASSWORD_STORE = None
 
+# ---------------------------------------------------------
+# Engine Manager Configuration
+# ---------------------------------------------------------
+
+# Engine manager mode: "NEW" creates a new engine for every connection (default),
+# "SINGLETON" reuses engines with connection pooling
+ENGINE_MANAGER_MODE = EngineModes.NEW
+
+# Cleanup interval for abandoned locks (default: 5 minutes)
+ENGINE_MANAGER_CLEANUP_INTERVAL = timedelta(minutes=5)
+
+# Automatically start cleanup thread for SINGLETON mode (default: True)
+ENGINE_MANAGER_AUTO_START_CLEANUP = True
+
+# ---------------------------------------------------------
+
 #
 # The EncryptedFieldTypeAdapter is used whenever we're building SqlAlchemy models
 # which include sensitive fields that should be app-encrypted BEFORE sending
@@ -809,7 +830,6 @@ class D3TimeFormat(TypedDict, total=False):
 # FIREWALL (only port 22 is open)
 # ----------------------------------------------------------------------
 
-SSH_TUNNEL_MANAGER_CLASS = "superset.extensions.ssh.SSHManager"
 SSH_TUNNEL_LOCAL_BIND_ADDRESS = "127.0.0.1"
 #: Timeout (seconds) for tunnel connection (open_channel timeout)
 SSH_TUNNEL_TIMEOUT_SEC = 10.0
@@ -1684,7 +1704,7 @@ def engine_context_manager(  # pylint: disable=unused-argument
     yield None
 
 
-ENGINE_CONTEXT_MANAGER = engine_context_manager
+ENGINE_CONTEXT_MANAGER: EngineContextManager = engine_context_manager
 
 # A callable that allows altering the database connection URL and params
 # on the fly, at runtime. This allows for things like impersonation or
@@ -1701,7 +1721,7 @@ def engine_context_manager(  # pylint: disable=unused-argument
 #
 # Note that the returned uri and params are passed directly to sqlalchemy's
 # as such `create_engine(url, **params)`
-DB_CONNECTION_MUTATOR = None
+DB_CONNECTION_MUTATOR: DBConnectionMutator | None = None
 
 # A callable that is invoked for every invocation of DB Engine Specs
diff --git a/tests/unit_tests/extensions/ssh_test.py b/superset/engines/__init__.py
similarity index 57%
rename from tests/unit_tests/extensions/ssh_test.py
rename to superset/engines/__init__.py
index a36f0fe03eb0..13a83393a912 100644
--- a/tests/unit_tests/extensions/ssh_test.py
+++ b/superset/engines/__init__.py
@@ -14,23 +14,3 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-from unittest.mock import Mock - -import sshtunnel - -from superset.extensions.ssh import SSHManagerFactory - - -def test_ssh_tunnel_timeout_setting() -> None: - app = Mock() - app.config = { - "SSH_TUNNEL_MAX_RETRIES": 2, - "SSH_TUNNEL_LOCAL_BIND_ADDRESS": "test", - "SSH_TUNNEL_TIMEOUT_SEC": 123.0, - "SSH_TUNNEL_PACKET_TIMEOUT_SEC": 321.0, - "SSH_TUNNEL_MANAGER_CLASS": "superset.extensions.ssh.SSHManager", - } - factory = SSHManagerFactory() - factory.init_app(app) - assert sshtunnel.TUNNEL_TIMEOUT == 123.0 - assert sshtunnel.SSH_TIMEOUT == 321.0 diff --git a/superset/engines/manager.py b/superset/engines/manager.py new file mode 100644 index 000000000000..9f3722d5eb89 --- /dev/null +++ b/superset/engines/manager.py @@ -0,0 +1,630 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import enum +import hashlib +import logging +import threading +from contextlib import contextmanager +from datetime import timedelta +from io import StringIO +from typing import Any, Iterator, TYPE_CHECKING + +import sshtunnel +from paramiko import RSAKey +from sqlalchemy import create_engine, event, pool +from sqlalchemy.engine import Engine +from sqlalchemy.engine.url import URL +from sshtunnel import SSHTunnelForwarder + +from superset.databases.utils import make_url_safe +from superset.superset_typing import DBConnectionMutator, EngineContextManager +from superset.utils.core import get_query_source_from_request, get_user_id, QuerySource + +if TYPE_CHECKING: + from superset.databases.ssh_tunnel.models import SSHTunnel + from superset.models.core import Database + + +logger = logging.getLogger(__name__) + + +class _LockManager: + """ + Manages per-key locks safely without defaultdict race conditions. + + This class provides a thread-safe way to create and manage locks for specific keys, + avoiding the race conditions that occur when using defaultdict with threading.Lock. + + The implementation uses a two-level locking strategy: + 1. A meta-lock to protect the lock dictionary itself + 2. Per-key locks to protect specific resources + + This ensures that: + - Different keys can be locked concurrently (scalability) + - Lock creation is thread-safe (no race conditions) + - The same key always gets the same lock instance + """ + + def __init__(self) -> None: + self._locks: dict[str, threading.RLock] = {} + self._meta_lock = threading.Lock() + + def get_lock(self, key: str) -> threading.RLock: + """ + Get or create a lock for the given key. + + This method uses double-checked locking to ensure thread safety: + 1. First check without lock (fast path) + 2. Acquire meta-lock if needed + 3. Double-check inside the lock to prevent race conditions + + This approach minimizes lock contention while ensuring correctness. 
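+
+        A minimal usage sketch (illustrative; ``"some_key"`` stands in for any
+        cache key string)::
+
+            locks = _LockManager()
+            with locks.get_lock("some_key"):
+                ...  # work guarded per key; other keys proceed concurrently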
+
+        :param key: The key to get a lock for
+        :returns: An RLock instance for the given key
+        """
+        if lock := self._locks.get(key):
+            return lock
+
+        with self._meta_lock:
+            # Double-check inside the lock
+            lock = self._locks.get(key)
+            if lock is None:
+                lock = threading.RLock()
+                self._locks[key] = lock
+            return lock
+
+    def cleanup(self, active_keys: set[str]) -> None:
+        """
+        Remove locks for keys that are no longer in use.
+
+        This prevents memory leaks from accumulating locks for resources
+        that have been disposed.
+
+        :param active_keys: Set of keys that are still active
+        """
+        with self._meta_lock:
+            # Find locks to remove
+            locks_to_remove = self._locks.keys() - active_keys
+            for key in locks_to_remove:
+                self._locks.pop(key, None)
+
+
+EngineKey = str
+TunnelKey = str
+
+
+def _generate_cache_key(*args: Any) -> str:
+    """
+    Generate a deterministic cache key from arbitrary arguments.
+
+    Uses repr() for serialization and SHA-256 for hashing. The resulting key
+    is a 32-character hex string that:
+    1. Is deterministic for the same inputs
+    2. Does not expose sensitive data (everything is hashed)
+    3. Has sufficient entropy to avoid collisions
+
+    :param args: Arguments to include in the cache key
+    :returns: 32-character hex string
+    """
+    # Use repr() which works with most Python objects and is deterministic
+    serialized = repr(args).encode("utf-8")
+    return hashlib.sha256(serialized).hexdigest()[:32]
+
+
+class EngineModes(enum.Enum):
+    # reuse existing engine if available, otherwise create a new one; this mode should
+    # have a connection pool configured in the database
+    SINGLETON = enum.auto()
+
+    # always create a new engine for every connection; this mode will use a NullPool
+    # and is the default behavior for Superset
+    NEW = enum.auto()
+
+
+class EngineManager:
+    """
+    A manager for SQLAlchemy engines.
+
+    This class handles the creation and management of SQLAlchemy engines, allowing them
+    to be configured with connection pools and reused across requests. The default mode
+    is the original behavior for Superset, where we create a new engine for every
+    connection, using a NullPool. The `SINGLETON` mode, on the other hand, allows
+    engines to be reused, with the pool configured through the database settings.
+    """
+
+    def __init__(
+        self,
+        engine_context_manager: EngineContextManager,
+        db_connection_mutator: DBConnectionMutator | None = None,
+        mode: EngineModes = EngineModes.NEW,
+        cleanup_interval: timedelta = timedelta(minutes=5),
+        local_bind_address: str = "127.0.0.1",
+        tunnel_timeout: timedelta = timedelta(seconds=30),
+        ssh_timeout: timedelta = timedelta(seconds=1),
+    ) -> None:
+        self.engine_context_manager = engine_context_manager
+        self.db_connection_mutator = db_connection_mutator
+        self.mode = mode
+        self.cleanup_interval = cleanup_interval
+        self.local_bind_address = local_bind_address
+
+        sshtunnel.TUNNEL_TIMEOUT = tunnel_timeout.total_seconds()
+        sshtunnel.SSH_TIMEOUT = ssh_timeout.total_seconds()
+
+        self._engines: dict[EngineKey, Engine] = {}
+        self._engine_locks = _LockManager()
+        self._tunnels: dict[TunnelKey, SSHTunnelForwarder] = {}
+        self._tunnel_locks = _LockManager()
+
+        # Background cleanup thread management
+        self._cleanup_thread: threading.Thread | None = None
+        self._cleanup_stop_event = threading.Event()
+        self._cleanup_thread_lock = threading.Lock()
+
+    def __del__(self) -> None:
+        """
+        Ensure cleanup thread is stopped when the manager is destroyed.
+ """ + try: + self.stop_cleanup_thread() + except Exception as ex: + # Avoid exceptions during garbage collection, but log if possible + try: + logger.warning("Error stopping cleanup thread: %s", ex) + except Exception: # noqa: S110 + # If logging fails during destruction, we can't do anything + pass + + @contextmanager + def get_engine( + self, + database: "Database", + catalog: str | None, + schema: str | None, + source: QuerySource | None, + ) -> Iterator[Engine]: + """ + Context manager to get a SQLAlchemy engine. + """ + # users can wrap the engine in their own context manager for different + # reasons + with self.engine_context_manager(database, catalog, schema): + # we need to check for errors indicating that OAuth2 is needed, and + # return the proper exception so it starts the authentication flow + from superset.utils.oauth2 import check_for_oauth2 + + with check_for_oauth2(database): + yield self._get_engine(database, catalog, schema, source) + + def _get_engine( + self, + database: "Database", + catalog: str | None, + schema: str | None, + source: QuerySource | None, + ) -> Engine: + """ + Get a specific engine, or create it if none exists. + """ + source = source or get_query_source_from_request() + user_id = get_user_id() + + if self.mode == EngineModes.NEW: + return self._create_engine( + database, + catalog, + schema, + source, + user_id, + ) + + engine_key = self._get_engine_key( + database, + catalog, + schema, + source, + user_id, + ) + + if engine := self._engines.get(engine_key): + return engine + + lock = self._engine_locks.get_lock(engine_key) + with lock: + # Double-check inside the lock + if engine := self._engines.get(engine_key): + return engine + + # Create and cache the engine + engine = self._create_engine( + database, + catalog, + schema, + source, + user_id, + ) + self._engines[engine_key] = engine + self._add_disposal_listener(engine, engine_key) + return engine + + def _get_engine_key( + self, + database: "Database", + catalog: str | None, + schema: str | None, + source: QuerySource | None, + user_id: int | None, + ) -> EngineKey: + """ + Generate a cache key for the engine. + + The key is a hash of all parameters that affect the engine, ensuring + proper cache isolation without exposing sensitive data. + + :returns: 32-character hex string + """ + uri, kwargs = self._get_engine_args( + database, + catalog, + schema, + source, + user_id, + ) + + return _generate_cache_key( + database.id, + catalog, + schema, + str(uri), + source, + user_id, + kwargs, + ) + + def _get_engine_args( + self, + database: "Database", + catalog: str | None, + schema: str | None, + source: QuerySource | None, + user_id: int | None, + ) -> tuple[URL, dict[str, Any]]: + """ + Build the almost final SQLAlchemy URI and engine kwargs. + + "Almost" final because we may still need to mutate the URI if an SSH tunnel is + needed, since it needs to connect to the tunnel instead of the original DB. But + that information is only available after the tunnel is created. 
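+
+        An illustrative result for a Postgres database with a queue pool
+        configured in its ``extra`` field (values are examples only)::
+
+            uri    -> postgresql://user@localhost/examples
+            kwargs -> {"poolclass": QueuePool, "pool_size": 10}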
+        """
+        # Import here to avoid circular imports
+        from superset.extensions import security_manager
+        from superset.utils.feature_flag_manager import FeatureFlagManager
+
+        uri = make_url_safe(database.sqlalchemy_uri_decrypted)
+
+        extra = database.get_extra(source)
+        # Make a copy to avoid mutating the original extra dict
+        kwargs = dict(extra.get("engine_params", {}))
+
+        # get pool class
+        if self.mode == EngineModes.NEW or "poolclass" not in kwargs:
+            kwargs["poolclass"] = pool.NullPool
+        else:
+            # map the pool name configured in the database extra to the
+            # SQLAlchemy pool class, e.g.
+            # {"engine_params": {"poolclass": "queue", "pool_size": 10}}
+            pools = {
+                "queue": pool.QueuePool,
+                "singleton": pool.SingletonThreadPool,
+                "assertion": pool.AssertionPool,
+                "null": pool.NullPool,
+                "static": pool.StaticPool,
+            }
+            kwargs["poolclass"] = pools.get(kwargs["poolclass"], pool.QueuePool)
+
+        # update URI for specific catalog/schema
+        connect_args = dict(kwargs.get("connect_args", {}))
+        uri, connect_args = database.db_engine_spec.adjust_engine_params(
+            uri,
+            connect_args,
+            catalog,
+            schema,
+        )
+        kwargs["connect_args"] = connect_args
+
+        # get effective username
+        username = database.get_effective_user(uri)
+
+        feature_flag_manager = FeatureFlagManager()
+        if username and feature_flag_manager.is_feature_enabled(
+            "IMPERSONATE_WITH_EMAIL_PREFIX"
+        ):
+            user = security_manager.find_user(username=username)
+            if user and user.email and "@" in user.email:
+                username = user.email.split("@")[0]
+
+        # update URI/kwargs for user impersonation
+        if database.impersonate_user:
+            oauth2_config = database.get_oauth2_config()
+            # Import here to avoid circular imports
+            from superset.utils.oauth2 import get_oauth2_access_token
+
+            access_token = (
+                get_oauth2_access_token(
+                    oauth2_config,
+                    database.id,
+                    user_id,
+                    database.db_engine_spec,
+                )
+                if oauth2_config and user_id
+                else None
+            )
+
+            uri, kwargs = database.db_engine_spec.impersonate_user(
+                database,
+                username,
+                access_token,
+                uri,
+                kwargs,
+            )
+
+        # update kwargs from params stored encrypted at rest
+        database.update_params_from_encrypted_extra(kwargs)
+
+        # mutate URI
+        if self.db_connection_mutator:
+            source = source or get_query_source_from_request()
+            uri, kwargs = self.db_connection_mutator(
+                uri,
+                kwargs,
+                username,
+                security_manager,
+                source,
+            )
+
+        # validate final URI
+        database.db_engine_spec.validate_database_uri(uri)
+
+        return uri, kwargs
+
+    def _create_engine(
+        self,
+        database: "Database",
+        catalog: str | None,
+        schema: str | None,
+        source: QuerySource | None,
+        user_id: int | None,
+    ) -> Engine:
+        """
+        Create the actual engine.
+
+        This should be the only place in Superset where a SQLAlchemy engine is created.
+        """
+        uri, kwargs = self._get_engine_args(
+            database,
+            catalog,
+            schema,
+            source,
+            user_id,
+        )
+
+        if database.ssh_tunnel:
+            tunnel = self._get_tunnel(database.ssh_tunnel, uri)
+            uri = uri.set(
+                host=tunnel.local_bind_address[0],
+                port=tunnel.local_bind_port,
+            )
+
+        try:
+            engine = create_engine(uri, **kwargs)
+        except Exception as ex:
+            raise database.db_engine_spec.get_dbapi_mapped_exception(ex) from ex
+
+        return engine
+
+    def _get_tunnel(self, ssh_tunnel: "SSHTunnel", uri: URL) -> SSHTunnelForwarder:
+        tunnel_key = self._get_tunnel_key(ssh_tunnel, uri)
+
+        tunnel = self._tunnels.get(tunnel_key)
+        if tunnel is not None and tunnel.is_active:
+            return tunnel
+
+        lock = self._tunnel_locks.get_lock(tunnel_key)
+        with lock:
+            # Double-check inside the lock
+            tunnel = self._tunnels.get(tunnel_key)
+            if tunnel is not None and tunnel.is_active:
+                return tunnel
+
+            # Create or replace tunnel
+            return self._replace_tunnel(tunnel_key, ssh_tunnel, uri, tunnel)
+
+    def _replace_tunnel(
+        self,
+        tunnel_key: str,
+        ssh_tunnel: "SSHTunnel",
+        uri: URL,
+        old_tunnel: SSHTunnelForwarder | None,
+    ) -> SSHTunnelForwarder:
+        """
+        Replace tunnel with proper cleanup.
+
+        This method assumes the caller holds the lock.
+        """
+        if old_tunnel:
+            try:
+                old_tunnel.stop()
+            except Exception:
+                logger.exception("Error stopping old tunnel")
+
+        try:
+            new_tunnel = self._create_tunnel(ssh_tunnel, uri)
+            self._tunnels[tunnel_key] = new_tunnel
+        except Exception:
+            # Remove failed tunnel from cache
+            self._tunnels.pop(tunnel_key, None)
+            logger.exception("Failed to create tunnel")
+            raise
+
+        return new_tunnel
+
+    def _get_tunnel_key(self, ssh_tunnel: "SSHTunnel", uri: URL) -> TunnelKey:
+        """
+        Generate a cache key for the SSH tunnel.
+
+        :returns: 32-character hex string
+        """
+        tunnel_kwargs = self._get_tunnel_kwargs(ssh_tunnel, uri)
+        return _generate_cache_key(tunnel_kwargs)
+
+    def _create_tunnel(self, ssh_tunnel: "SSHTunnel", uri: URL) -> SSHTunnelForwarder:
+        kwargs = self._get_tunnel_kwargs(ssh_tunnel, uri)
+        # Use open_tunnel which handles debug_level properly
+        tunnel = sshtunnel.open_tunnel(**kwargs)
+        tunnel.start()
+
+        return tunnel
+
+    def _get_tunnel_kwargs(self, ssh_tunnel: "SSHTunnel", uri: URL) -> dict[str, Any]:
+        # Import here to avoid circular imports
+        from superset.utils.ssh_tunnel import get_default_port
+
+        backend = uri.get_backend_name()
+        kwargs = {
+            "ssh_address_or_host": (ssh_tunnel.server_address, ssh_tunnel.server_port),
+            "ssh_username": ssh_tunnel.username,
+            "remote_bind_address": (uri.host, uri.port or get_default_port(backend)),
+            "local_bind_address": (self.local_bind_address,),
+            "debug_level": logging.getLogger("flask_appbuilder").level,
+        }
+
+        if ssh_tunnel.password:
+            kwargs["ssh_password"] = ssh_tunnel.password
+        elif ssh_tunnel.private_key:
+            private_key_file = StringIO(ssh_tunnel.private_key)
+            private_key = RSAKey.from_private_key(
+                private_key_file,
+                ssh_tunnel.private_key_password,
+            )
+            kwargs["ssh_pkey"] = private_key
+
+        if self.mode == EngineModes.NEW:
+            kwargs["set_keepalive"] = 0  # disable keepalive for one-time tunnels
+
+        return kwargs
+
+    def start_cleanup_thread(self) -> None:
+        """
+        Start the background cleanup thread.
+
+        The thread will periodically clean up abandoned locks at the configured
+        interval. This is safe to call multiple times; subsequent calls are no-ops.
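+
+        An illustrative lifecycle (assuming an already-configured manager)::
+
+            manager.start_cleanup_thread()  # begin periodic lock cleanup
+            ...                             # serve requests
+            manager.stop_cleanup_thread()   # graceful shutdown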
+ """ + with self._cleanup_thread_lock: + if self._cleanup_thread is None or not self._cleanup_thread.is_alive(): + self._cleanup_stop_event.clear() + self._cleanup_thread = threading.Thread( + target=self._cleanup_worker, + name=f"EngineManager-cleanup-{id(self)}", + daemon=True, + ) + self._cleanup_thread.start() + logger.info( + "Started cleanup thread with %ds interval", + self.cleanup_interval.total_seconds(), + ) + + def stop_cleanup_thread(self) -> None: + """ + Stop the background cleanup thread gracefully. + + This will signal the thread to stop and wait for it to finish. + Safe to call even if no thread is running. + """ + with self._cleanup_thread_lock: + if self._cleanup_thread is not None and self._cleanup_thread.is_alive(): + self._cleanup_stop_event.set() + self._cleanup_thread.join(timeout=5.0) # 5 second timeout + if self._cleanup_thread.is_alive(): + logger.warning("Cleanup thread did not stop within timeout") + else: + logger.info("Cleanup thread stopped") + self._cleanup_thread = None + + def _cleanup_worker(self) -> None: + """ + Background thread worker that periodically cleans up abandoned locks. + """ + while not self._cleanup_stop_event.is_set(): + try: + self._cleanup_abandoned_locks() + except Exception: + logger.exception("Error during background cleanup") + + # Use wait() instead of sleep() to allow for immediate shutdown + if self._cleanup_stop_event.wait( + timeout=self.cleanup_interval.total_seconds() + ): + break # Stop event was set + + def cleanup(self) -> None: + """ + Public method to manually trigger cleanup of abandoned locks. + + This can be called periodically by external systems to prevent + memory leaks from accumulating locks. + """ + self._cleanup_abandoned_locks() + + def _cleanup_abandoned_locks(self) -> None: + """ + Clean up locks for engines and tunnels that no longer exist. + + This prevents memory leaks from accumulating locks when engines/tunnels + are disposed outside of normal cleanup paths. + """ + # Clean up engine locks for inactive engines + active_engine_keys = set(self._engines.keys()) + self._engine_locks.cleanup(active_engine_keys) + + # Clean up tunnel locks for inactive tunnels + active_tunnel_keys = set(self._tunnels.keys()) + self._tunnel_locks.cleanup(active_tunnel_keys) + + # Log for debugging + if active_engine_keys or active_tunnel_keys: + logger.debug( + "EngineManager resources - Engines: %d, Tunnels: %d", + len(active_engine_keys), + len(active_tunnel_keys), + ) + + def _add_disposal_listener(self, engine: Engine, engine_key: EngineKey) -> None: + @event.listens_for(engine, "engine_disposed") + def on_engine_disposed(engine_instance: Engine) -> None: + try: + # Remove engine from cache - no per-key locks to clean up anymore + if self._engines.pop(engine_key, None): + # Log only first 8 chars of hash for safety + # (still enough for debugging, but doesn't expose full key) + log_key = engine_key[:8] + "..." 
+ logger.info("Engine disposed and removed from cache: %s", log_key) + except Exception as ex: + logger.error("Error during engine disposal cleanup: %s", str(ex)) + # Don't log engine_key to avoid exposing credential hash diff --git a/superset/extensions/__init__.py b/superset/extensions/__init__.py index 628af40cd621..a396e207ee9e 100644 --- a/superset/extensions/__init__.py +++ b/superset/extensions/__init__.py @@ -41,7 +41,7 @@ def get_sqla_class() -> Any: from superset.async_events.async_query_manager import AsyncQueryManager from superset.async_events.async_query_manager_factory import AsyncQueryManagerFactory -from superset.extensions.ssh import SSHManagerFactory +from superset.extensions.engine_manager import EngineManagerExtension from superset.extensions.stats_logger import BaseStatsLoggerManager from superset.security.manager import SupersetSecurityManager from superset.utils.cache_manager import CacheManager @@ -136,6 +136,7 @@ def init_app(self, app: Flask) -> None: celery_app = celery.Celery() csrf = CSRFProtect() db = get_sqla_class()() +engine_manager_extension = EngineManagerExtension() _event_logger: dict[str, Any] = {} encrypted_field_factory = EncryptedFieldFactory() event_logger = LocalProxy(lambda: _event_logger.get("event_logger")) @@ -146,6 +147,5 @@ def init_app(self, app: Flask) -> None: profiling = ProfilingExtension() results_backend_manager = ResultsBackendManager() security_manager: SupersetSecurityManager = LocalProxy(lambda: appbuilder.sm) -ssh_manager_factory = SSHManagerFactory() stats_logger_manager = BaseStatsLoggerManager() talisman = Talisman() diff --git a/superset/extensions/engine_manager.py b/superset/extensions/engine_manager.py new file mode 100644 index 000000000000..e15ead09b43d --- /dev/null +++ b/superset/extensions/engine_manager.py @@ -0,0 +1,100 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import logging +from datetime import timedelta + +from flask import Flask + +from superset.engines.manager import EngineManager, EngineModes + +logger = logging.getLogger(__name__) + + +class EngineManagerExtension: + """ + Flask extension for managing SQLAlchemy engines in Superset. + + This extension creates and configures an EngineManager instance based on + Flask configuration, handling startup and shutdown of background cleanup + threads as needed. + """ + + def __init__(self) -> None: + self.engine_manager: EngineManager | None = None + + def init_app(self, app: Flask) -> None: + """ + Initialize the EngineManager with Flask app configuration. 
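+
+        Every config key read below must be present. A minimal sketch of the
+        engine-manager entries in ``superset_config.py`` (values shown are
+        illustrative defaults; ``ENGINE_CONTEXT_MANAGER`` and
+        ``DB_CONNECTION_MUTATOR`` already default in ``superset/config.py``)::
+
+            ENGINE_MANAGER_MODE = EngineModes.NEW
+            ENGINE_MANAGER_CLEANUP_INTERVAL = timedelta(minutes=5)
+            ENGINE_MANAGER_AUTO_START_CLEANUP = True
+            SSH_TUNNEL_LOCAL_BIND_ADDRESS = "127.0.0.1"
+            SSH_TUNNEL_TIMEOUT_SEC = 10.0
+            SSH_TUNNEL_PACKET_TIMEOUT_SEC = 1.0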
+        """
+        engine_context_manager = app.config["ENGINE_CONTEXT_MANAGER"]
+        db_connection_mutator = app.config["DB_CONNECTION_MUTATOR"]
+        mode = app.config["ENGINE_MANAGER_MODE"]
+        cleanup_interval = app.config["ENGINE_MANAGER_CLEANUP_INTERVAL"]
+        local_bind_address = app.config["SSH_TUNNEL_LOCAL_BIND_ADDRESS"]
+        tunnel_timeout = timedelta(seconds=app.config["SSH_TUNNEL_TIMEOUT_SEC"])
+        ssh_timeout = timedelta(seconds=app.config["SSH_TUNNEL_PACKET_TIMEOUT_SEC"])
+        auto_start_cleanup = app.config["ENGINE_MANAGER_AUTO_START_CLEANUP"]
+
+        # Create the engine manager
+        self.engine_manager = EngineManager(
+            engine_context_manager,
+            db_connection_mutator,
+            mode,
+            cleanup_interval,
+            local_bind_address,
+            tunnel_timeout,
+            ssh_timeout,
+        )
+
+        # Start cleanup thread if requested and in SINGLETON mode
+        if auto_start_cleanup and mode == EngineModes.SINGLETON:
+            self.engine_manager.start_cleanup_thread()
+            logger.info("Started EngineManager cleanup thread")
+
+        # Register shutdown handler with atexit for clean shutdown
+        def shutdown_engine_manager() -> None:
+            if self.engine_manager:
+                self.engine_manager.stop_cleanup_thread()
+
+        import atexit
+
+        atexit.register(shutdown_engine_manager)
+
+        logger.info(
+            "Initialized EngineManager with mode=%s, cleanup_interval=%ds",
+            mode,
+            cleanup_interval.total_seconds(),
+        )
+
+    @property
+    def manager(self) -> EngineManager:
+        """
+        Get the EngineManager instance.
+
+        Raises:
+            RuntimeError: If the extension hasn't been initialized with an app.
+        """
+        if self.engine_manager is None:
+            raise RuntimeError(
+                "EngineManager extension not initialized. "
+                "Call init_app() with a Flask app first."
+            )
+        return self.engine_manager
diff --git a/superset/extensions/ssh.py b/superset/extensions/ssh.py
deleted file mode 100644
index 74fb44cfd75e..000000000000
--- a/superset/extensions/ssh.py
+++ /dev/null
@@ -1,94 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
- -import logging -from io import StringIO -from typing import TYPE_CHECKING - -import sshtunnel -from flask import Flask -from paramiko import RSAKey - -from superset.commands.database.ssh_tunnel.exceptions import SSHTunnelDatabasePortError -from superset.databases.utils import make_url_safe -from superset.utils.class_utils import load_class_from_name - -if TYPE_CHECKING: - from superset.databases.ssh_tunnel.models import SSHTunnel - - -class SSHManager: - def __init__(self, app: Flask) -> None: - super().__init__() - self.local_bind_address = app.config["SSH_TUNNEL_LOCAL_BIND_ADDRESS"] - sshtunnel.TUNNEL_TIMEOUT = app.config["SSH_TUNNEL_TIMEOUT_SEC"] - sshtunnel.SSH_TIMEOUT = app.config["SSH_TUNNEL_PACKET_TIMEOUT_SEC"] - - def build_sqla_url( - self, sqlalchemy_url: str, server: sshtunnel.SSHTunnelForwarder - ) -> str: - # override any ssh tunnel configuration object - url = make_url_safe(sqlalchemy_url) - return url.set( - host=server.local_bind_address[0], - port=server.local_bind_port, - ) - - def create_tunnel( - self, - ssh_tunnel: "SSHTunnel", - sqlalchemy_database_uri: str, - ) -> sshtunnel.SSHTunnelForwarder: - from superset.utils.ssh_tunnel import get_default_port - - url = make_url_safe(sqlalchemy_database_uri) - backend = url.get_backend_name() - port = url.port or get_default_port(backend) - if not port: - raise SSHTunnelDatabasePortError() - params = { - "ssh_address_or_host": (ssh_tunnel.server_address, ssh_tunnel.server_port), - "ssh_username": ssh_tunnel.username, - "remote_bind_address": (url.host, port), - "local_bind_address": (self.local_bind_address,), - "debug_level": logging.getLogger("flask_appbuilder").level, - } - - if ssh_tunnel.password: - params["ssh_password"] = ssh_tunnel.password - elif ssh_tunnel.private_key: - private_key_file = StringIO(ssh_tunnel.private_key) - private_key = RSAKey.from_private_key( - private_key_file, ssh_tunnel.private_key_password - ) - params["ssh_pkey"] = private_key - - return sshtunnel.open_tunnel(**params) - - -class SSHManagerFactory: - def __init__(self) -> None: - self._ssh_manager = None - - def init_app(self, app: Flask) -> None: - self._ssh_manager = load_class_from_name( - app.config["SSH_TUNNEL_MANAGER_CLASS"] - )(app) - - @property - def instance(self) -> SSHManager: - return self._ssh_manager # type: ignore diff --git a/superset/initialization/__init__.py b/superset/initialization/__init__.py index 3a34d315bf52..4cefa0c7337f 100644 --- a/superset/initialization/__init__.py +++ b/superset/initialization/__init__.py @@ -49,13 +49,13 @@ csrf, db, encrypted_field_factory, + engine_manager_extension, feature_flag_manager, machine_auth_provider_factory, manifest_processor, migrate, profiling, results_backend_manager, - ssh_manager_factory, stats_logger_manager, talisman, ) @@ -585,8 +585,8 @@ def init_app_in_ctx(self) -> None: self.configure_url_map_converters() self.configure_data_sources() self.configure_auth_provider() + self.configure_engine_manager() self.configure_async_queries() - self.configure_ssh_manager() self.configure_stats_manager() # Hook that provides administrators a handle on the Flask APP @@ -761,8 +761,8 @@ def set_db_default_isolation(self) -> None: def configure_auth_provider(self) -> None: machine_auth_provider_factory.init_app(self.superset_app) - def configure_ssh_manager(self) -> None: - ssh_manager_factory.init_app(self.superset_app) + def configure_engine_manager(self) -> None: + engine_manager_extension.init_app(self.superset_app) def configure_stats_manager(self) -> None: 
stats_logger_manager.init_app(self.superset_app) diff --git a/superset/models/core.py b/superset/models/core.py index b9b7b6059119..f20e0fce1e56 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -25,24 +25,22 @@ import logging import textwrap from ast import literal_eval -from contextlib import closing, contextmanager, nullcontext, suppress +from contextlib import closing, contextmanager, suppress from copy import deepcopy from datetime import datetime from functools import lru_cache from inspect import signature -from typing import Any, Callable, cast, Optional, TYPE_CHECKING +from typing import Any, Callable, cast, Iterator, Optional, TYPE_CHECKING import numpy import pandas as pd import sqlalchemy as sqla -import sshtunnel from flask import current_app as app, g, has_app_context from flask_appbuilder import Model from marshmallow.exceptions import ValidationError from sqlalchemy import ( Boolean, Column, - create_engine, DateTime, ForeignKey, Integer, @@ -57,7 +55,6 @@ from sqlalchemy.exc import NoSuchModuleError from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.orm import relationship -from sqlalchemy.pool import NullPool from sqlalchemy.schema import UniqueConstraint from sqlalchemy.sql import ColumnElement, expression, Select from superset_core.api.models import Database as CoreDatabase @@ -72,7 +69,6 @@ encrypted_field_factory, event_logger, security_manager, - ssh_manager_factory, ) from superset.models.helpers import AuditMixinNullable, ImportExportMixin, UUIDMixin from superset.result_set import SupersetResultSet @@ -84,10 +80,9 @@ ) from superset.utils import cache as cache_util, core as utils, json from superset.utils.backports import StrEnum -from superset.utils.core import get_query_source_from_request, get_username +from superset.utils.core import get_username from superset.utils.oauth2 import ( check_for_oauth2, - get_oauth2_access_token, OAuth2ClientConfigSchema, ) @@ -424,130 +419,31 @@ def get_effective_user(self, object_url: URL) -> str | None: ) @contextmanager - def get_sqla_engine( # pylint: disable=too-many-arguments + def get_sqla_engine( self, catalog: str | None = None, schema: str | None = None, - nullpool: bool = True, source: utils.QuerySource | None = None, - ) -> Engine: + ) -> Iterator[Engine]: """ Context manager for a SQLAlchemy engine. - This method will return a context manager for a SQLAlchemy engine. Using the - context manager (as opposed to the engine directly) is important because we need - to potentially establish SSH tunnels before the connection is created, and clean - them up once the engine is no longer used. + This method will return a context manager for a SQLAlchemy engine. The engine + manager handles connection pooling, SSH tunnels, and other connection details + based on the configured mode (NEW or SINGLETON). 
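+
+        Illustrative usage::
+
+            from sqlalchemy import text
+
+            with database.get_sqla_engine(schema="public") as engine:
+                with engine.connect() as conn:
+                    conn.execute(text("SELECT 1"))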
""" + # Import here to avoid circular imports + from superset.extensions import engine_manager_extension - sqlalchemy_uri = self.sqlalchemy_uri_decrypted - - ssh_context_manager = ( - ssh_manager_factory.instance.create_tunnel( - ssh_tunnel=self.ssh_tunnel, - sqlalchemy_database_uri=sqlalchemy_uri, - ) - if self.ssh_tunnel - else nullcontext() - ) - - with ssh_context_manager as ssh_context: - if ssh_context: - logger.info( - "[SSH] Successfully created tunnel w/ %s tunnel_timeout + %s " - "ssh_timeout at %s", - sshtunnel.TUNNEL_TIMEOUT, - sshtunnel.SSH_TIMEOUT, - ssh_context.local_bind_address, - ) - sqlalchemy_uri = ssh_manager_factory.instance.build_sqla_url( - sqlalchemy_uri, - ssh_context, - ) - - engine_context_manager = app.config["ENGINE_CONTEXT_MANAGER"] - with engine_context_manager(self, catalog, schema): - with check_for_oauth2(self): - yield self._get_sqla_engine( - catalog=catalog, - schema=schema, - nullpool=nullpool, - source=source, - sqlalchemy_uri=sqlalchemy_uri, - ) - - def _get_sqla_engine( # pylint: disable=too-many-locals # noqa: C901 - self, - catalog: str | None = None, - schema: str | None = None, - nullpool: bool = True, - source: utils.QuerySource | None = None, - sqlalchemy_uri: str | None = None, - ) -> Engine: - sqlalchemy_url = make_url_safe( - sqlalchemy_uri if sqlalchemy_uri else self.sqlalchemy_uri_decrypted - ) - self.db_engine_spec.validate_database_uri(sqlalchemy_url) - - extra = self.get_extra(source) - engine_kwargs = extra.get("engine_params", {}) - if nullpool: - engine_kwargs["poolclass"] = NullPool - connect_args = engine_kwargs.setdefault("connect_args", {}) - - # modify URL/args for a specific catalog/schema - sqlalchemy_url, connect_args = self.db_engine_spec.adjust_engine_params( - uri=sqlalchemy_url, - connect_args=connect_args, + # Use the engine manager to get the engine + engine_manager = engine_manager_extension.manager + with engine_manager.get_engine( + database=self, catalog=catalog, schema=schema, - ) - - effective_username = self.get_effective_user(sqlalchemy_url) - if effective_username and is_feature_enabled("IMPERSONATE_WITH_EMAIL_PREFIX"): - user = security_manager.find_user(username=effective_username) - if user and user.email: - effective_username = user.email.split("@")[0] - - oauth2_config = self.get_oauth2_config() - access_token = ( - get_oauth2_access_token( - oauth2_config, - self.id, - g.user.id, - self.db_engine_spec, - ) - if oauth2_config and hasattr(g, "user") and hasattr(g.user, "id") - else None - ) - masked_url = self.get_password_masked_url(sqlalchemy_url) - logger.debug("Database._get_sqla_engine(). 
Masked URL: %s", str(masked_url)) - - if self.impersonate_user: - sqlalchemy_url, engine_kwargs = self.db_engine_spec.impersonate_user( - self, - effective_username, - access_token, - sqlalchemy_url, - engine_kwargs, - ) - - self.update_params_from_encrypted_extra(engine_kwargs) - - if DB_CONNECTION_MUTATOR := app.config["DB_CONNECTION_MUTATOR"]: # noqa: N806 - source = source or get_query_source_from_request() - - sqlalchemy_url, engine_kwargs = DB_CONNECTION_MUTATOR( - sqlalchemy_url, - engine_kwargs, - effective_username, - security_manager, - source, - ) - try: - return create_engine(sqlalchemy_url, **engine_kwargs) - except Exception as ex: - raise self.db_engine_spec.get_dbapi_mapped_exception(ex) from ex + source=source, + ) as engine: + yield engine def add_database_to_signature( self, @@ -572,13 +468,11 @@ def get_raw_connection( self, catalog: str | None = None, schema: str | None = None, - nullpool: bool = True, source: utils.QuerySource | None = None, ) -> Connection: with self.get_sqla_engine( catalog=catalog, schema=schema, - nullpool=nullpool, source=source, ) as engine: with check_for_oauth2(self): diff --git a/superset/superset_typing.py b/superset/superset_typing.py index 02e294a08cfb..e3252483e84c 100644 --- a/superset/superset_typing.py +++ b/superset/superset_typing.py @@ -18,17 +18,42 @@ from collections.abc import Hashable, Sequence from datetime import datetime -from typing import Any, Literal, TYPE_CHECKING, TypeAlias, TypedDict +from typing import ( + Any, + Callable, + ContextManager, + Literal, + TYPE_CHECKING, + TypeAlias, + TypedDict, +) +from sqlalchemy.engine.url import URL from sqlalchemy.sql.type_api import TypeEngine from typing_extensions import NotRequired from werkzeug.wrappers import Response if TYPE_CHECKING: - from superset.utils.core import GenericDataType, QueryObjectFilterClause + from superset.models.core import Database + from superset.utils.core import ( + GenericDataType, + QueryObjectFilterClause, + QuerySource, + ) SQLType: TypeAlias = TypeEngine | type[TypeEngine] +# Type alias for database connection mutator function +DBConnectionMutator: TypeAlias = Callable[ + [URL, dict[str, Any], str | None, Any, "QuerySource | None"], + tuple[URL, dict[str, Any]], +] + +# Type alias for engine context manager +EngineContextManager: TypeAlias = Callable[ + ["Database", str | None, str | None], ContextManager[None] +] + class LegacyMetric(TypedDict): label: str | None diff --git a/tests/integration_tests/conftest.py b/tests/integration_tests/conftest.py index 4f6ce10b0f9a..95f1015e85f1 100644 --- a/tests/integration_tests/conftest.py +++ b/tests/integration_tests/conftest.py @@ -170,7 +170,6 @@ def __call__(self) -> Database: return self._db def _load_lazy_data_to_decouple_from_session(self) -> None: - self._db._get_sqla_engine() # type: ignore self._db.backend # type: ignore # noqa: B018 def remove(self) -> None: diff --git a/tests/integration_tests/databases/commands_tests.py b/tests/integration_tests/databases/commands_tests.py index 27c1ce565428..21939386abba 100644 --- a/tests/integration_tests/databases/commands_tests.py +++ b/tests/integration_tests/databases/commands_tests.py @@ -897,7 +897,7 @@ def test_import_v1_rollback(self, mock_add_permissions, mock_import_dataset): class TestTestConnectionDatabaseCommand(SupersetTestCase): - @patch("superset.models.core.Database._get_sqla_engine") + @patch("superset.models.core.Database.get_sqla_engine") @patch("superset.commands.database.test_connection.event_logger.log_with_context") 
@patch("superset.utils.core.g") def test_connection_db_exception( @@ -906,19 +906,23 @@ def test_connection_db_exception( """Test to make sure event_logger is called when an exception is raised""" database = get_example_database() mock_g.user = security_manager.find_user("admin") - mock_get_sqla_engine.side_effect = Exception("An error has occurred!") + mock_get_sqla_engine.return_value.__enter__.side_effect = Exception( + "An error has occurred!" + ) db_uri = database.sqlalchemy_uri_decrypted json_payload = {"sqlalchemy_uri": db_uri} command_without_db_name = TestConnectionDatabaseCommand(json_payload) - with pytest.raises(DatabaseTestConnectionUnexpectedError) as excinfo: # noqa: PT012 + with pytest.raises(DatabaseTestConnectionUnexpectedError) as excinfo: command_without_db_name.run() - assert str(excinfo.value) == ( - "Unexpected error occurred, please check your logs for details" - ) + # Exception wraps errors from db_engine_spec.extract_errors() + assert ( + excinfo.value.errors[0].error_type + == SupersetErrorType.GENERIC_DB_ENGINE_ERROR + ) mock_event_logger.assert_called() - @patch("superset.models.core.Database._get_sqla_engine") + @patch("superset.models.core.Database.get_sqla_engine") @patch("superset.commands.database.test_connection.event_logger.log_with_context") @patch("superset.utils.core.g") def test_connection_do_ping_exception( @@ -927,9 +931,8 @@ def test_connection_do_ping_exception( """Test to make sure do_ping exceptions gets captured""" database = get_example_database() mock_g.user = security_manager.find_user("admin") - mock_get_sqla_engine.return_value.dialect.do_ping.side_effect = Exception( - "An error has occurred!" - ) + mock_engine = mock_get_sqla_engine.return_value.__enter__.return_value + mock_engine.dialect.do_ping.side_effect = Exception("An error has occurred!") db_uri = database.sqlalchemy_uri_decrypted json_payload = {"sqlalchemy_uri": db_uri} command_without_db_name = TestConnectionDatabaseCommand(json_payload) @@ -967,7 +970,7 @@ def test_connection_do_ping_timeout( == SupersetErrorType.CONNECTION_DATABASE_TIMEOUT ) - @patch("superset.models.core.Database._get_sqla_engine") + @patch("superset.models.core.Database.get_sqla_engine") @patch("superset.commands.database.test_connection.event_logger.log_with_context") @patch("superset.utils.core.g") def test_connection_superset_security_connection( @@ -977,20 +980,20 @@ def test_connection_superset_security_connection( connection exc is raised""" database = get_example_database() mock_g.user = security_manager.find_user("admin") - mock_get_sqla_engine.side_effect = SupersetSecurityException( - SupersetError(error_type=500, message="test", level="info") + mock_get_sqla_engine.return_value.__enter__.side_effect = ( + SupersetSecurityException( + SupersetError(error_type=500, message="test", level="info") + ) ) db_uri = database.sqlalchemy_uri_decrypted json_payload = {"sqlalchemy_uri": db_uri} command_without_db_name = TestConnectionDatabaseCommand(json_payload) - with pytest.raises(DatabaseSecurityUnsafeError) as excinfo: # noqa: PT012 + with pytest.raises(DatabaseSecurityUnsafeError): command_without_db_name.run() - assert str(excinfo.value) == ("Stopped an unsafe database connection") - mock_event_logger.assert_called() - @patch("superset.models.core.Database._get_sqla_engine") + @patch("superset.models.core.Database.get_sqla_engine") @patch("superset.commands.database.test_connection.event_logger.log_with_context") @patch("superset.utils.core.g") def test_connection_db_api_exc( @@ -999,19 +1002,20 @@ 
def test_connection_db_api_exc( """Test to make sure event_logger is called when DBAPIError is raised""" database = get_example_database() mock_g.user = security_manager.find_user("admin") - mock_get_sqla_engine.side_effect = DBAPIError( + mock_get_sqla_engine.return_value.__enter__.side_effect = DBAPIError( statement="error", params={}, orig={} ) db_uri = database.sqlalchemy_uri_decrypted json_payload = {"sqlalchemy_uri": db_uri} command_without_db_name = TestConnectionDatabaseCommand(json_payload) - with pytest.raises(SupersetErrorsException) as excinfo: # noqa: PT012 + with pytest.raises(SupersetErrorsException) as excinfo: command_without_db_name.run() - assert str(excinfo.value) == ( - "Connection failed, please check your connection settings" - ) - + # Exception wraps errors from db_engine_spec.extract_errors() + assert ( + excinfo.value.errors[0].error_type + == SupersetErrorType.GENERIC_DB_ENGINE_ERROR + ) mock_event_logger.assert_called() @@ -1147,7 +1151,7 @@ def test_database_tables_list_with_unknown_database(self, mock_find_by_id): with pytest.raises(DatabaseNotFoundError) as excinfo: # noqa: PT012 command.run() - assert str(excinfo.value) == ("Database not found.") + assert str(excinfo.value) == ("Database not found.") @patch("superset.daos.database.DatabaseDAO.find_by_id") @patch("superset.security.manager.SupersetSecurityManager.can_access_database") @@ -1166,26 +1170,35 @@ def test_database_tables_superset_exception( command = TablesDatabaseCommand(database.id, None, "main", False) with pytest.raises(SupersetException) as excinfo: # noqa: PT012 command.run() - assert str(excinfo.value) == "Test Error" + assert str(excinfo.value) == "Test Error" @patch("superset.daos.database.DatabaseDAO.find_by_id") + @patch("superset.models.core.Database.get_all_materialized_view_names_in_schema") + @patch("superset.models.core.Database.get_all_view_names_in_schema") + @patch("superset.models.core.Database.get_all_table_names_in_schema") @patch("superset.security.manager.SupersetSecurityManager.can_access_database") @patch("superset.utils.core.g") def test_database_tables_exception( - self, mock_g, mock_can_access_database, mock_find_by_id + self, + mock_g, + mock_can_access_database, + mock_get_tables, + mock_get_views, + mock_get_mvs, + mock_find_by_id, ): database = get_example_database() mock_find_by_id.return_value = database + mock_get_tables.return_value = {("table1", "main", None)} + mock_get_views.return_value = set() + mock_get_mvs.return_value = [] mock_can_access_database.side_effect = Exception("Test Error") mock_g.user = security_manager.find_user("admin") command = TablesDatabaseCommand(database.id, None, "main", False) with pytest.raises(DatabaseTablesUnexpectedError) as excinfo: # noqa: PT012 command.run() - assert ( - str(excinfo.value) - == "Unexpected error occurred, please check your logs for details" - ) + assert str(excinfo.value) == "Test Error" @patch("superset.daos.database.DatabaseDAO.find_by_id") @patch("superset.security.manager.SupersetSecurityManager.can_access_database") diff --git a/tests/integration_tests/model_tests.py b/tests/integration_tests/model_tests.py index e80fea7db883..4e9987fe3f35 100644 --- a/tests/integration_tests/model_tests.py +++ b/tests/integration_tests/model_tests.py @@ -145,7 +145,7 @@ def test_database_impersonate_user(self): username = make_url(engine.url).username assert example_user.username != username - @mock.patch("superset.models.core.create_engine") + @mock.patch("superset.engines.manager.create_engine") @unittest.skipUnless( 
SupersetTestCase.is_module_installed("pyhive"), "pyhive not installed" ) @@ -172,7 +172,8 @@ def test_impersonate_user_presto(self, mocked_create_engine): database_name="test_database", sqlalchemy_uri=uri, extra=extra ) model.impersonate_user = True - model._get_sqla_engine() + with model.get_sqla_engine(): + pass call_args = mocked_create_engine.call_args assert str(call_args[0][0]) == "presto://gamma@localhost/" @@ -185,7 +186,8 @@ def test_impersonate_user_presto(self, mocked_create_engine): } model.impersonate_user = False - model._get_sqla_engine() + with model.get_sqla_engine(): + pass call_args = mocked_create_engine.call_args assert str(call_args[0][0]) == "presto://localhost/" @@ -199,13 +201,14 @@ def test_impersonate_user_presto(self, mocked_create_engine): @unittest.skipUnless( SupersetTestCase.is_module_installed("mysqlclient"), "mysqlclient not installed" ) - @mock.patch("superset.models.core.create_engine") + @mock.patch("superset.engines.manager.create_engine") def test_adjust_engine_params_mysql(self, mocked_create_engine): model = Database( database_name="test_database1", sqlalchemy_uri="mysql://user:password@localhost", ) - model._get_sqla_engine() + with model.get_sqla_engine(): + pass call_args = mocked_create_engine.call_args assert str(call_args[0][0]) == "mysql://user:password@localhost" @@ -215,13 +218,14 @@ def test_adjust_engine_params_mysql(self, mocked_create_engine): database_name="test_database2", sqlalchemy_uri="mysql+mysqlconnector://user:password@localhost", ) - model._get_sqla_engine() + with model.get_sqla_engine(): + pass call_args = mocked_create_engine.call_args assert str(call_args[0][0]) == "mysql+mysqlconnector://user:password@localhost" assert call_args[1]["connect_args"]["allow_local_infile"] == 0 - @mock.patch("superset.models.core.create_engine") + @mock.patch("superset.engines.manager.create_engine") def test_impersonate_user_trino(self, mocked_create_engine): principal_user = security_manager.find_user(username="gamma") @@ -230,7 +234,8 @@ def test_impersonate_user_trino(self, mocked_create_engine): database_name="test_database", sqlalchemy_uri="trino://localhost" ) model.impersonate_user = True - model._get_sqla_engine() + with model.get_sqla_engine(): + pass call_args = mocked_create_engine.call_args assert str(call_args[0][0]) == "trino://localhost/" @@ -242,7 +247,8 @@ def test_impersonate_user_trino(self, mocked_create_engine): ) model.impersonate_user = True - model._get_sqla_engine() + with model.get_sqla_engine(): + pass call_args = mocked_create_engine.call_args assert ( @@ -251,7 +257,7 @@ def test_impersonate_user_trino(self, mocked_create_engine): ) assert call_args[1]["connect_args"]["user"] == "gamma" - @mock.patch("superset.models.core.create_engine") + @mock.patch("superset.engines.manager.create_engine") @unittest.skipUnless( SupersetTestCase.is_module_installed("pyhive"), "pyhive not installed" ) @@ -281,7 +287,8 @@ def test_impersonate_user_hive(self, mocked_create_engine): database_name="test_database", sqlalchemy_uri=uri, extra=extra ) model.impersonate_user = True - model._get_sqla_engine() + with model.get_sqla_engine(): + pass call_args = mocked_create_engine.call_args assert str(call_args[0][0]) == "hive://localhost" @@ -294,7 +301,8 @@ def test_impersonate_user_hive(self, mocked_create_engine): } model.impersonate_user = False - model._get_sqla_engine() + with model.get_sqla_engine(): + pass call_args = mocked_create_engine.call_args assert str(call_args[0][0]) == "hive://localhost" @@ -376,7 +384,7 @@ def 
test_multi_statement(self): df = main_db.get_df("USE superset; SELECT ';';", None, None) assert df.iat[0, 0] == ";" - @mock.patch("superset.models.core.create_engine") + @mock.patch("superset.engines.manager.create_engine") def test_get_sqla_engine(self, mocked_create_engine): model = Database( database_name="test_database", @@ -387,7 +395,8 @@ def test_get_sqla_engine(self, mocked_create_engine): ) mocked_create_engine.side_effect = Exception() with self.assertRaises(SupersetException): # noqa: PT027 - model._get_sqla_engine() + with model.get_sqla_engine(): + pass class TestSqlaTableModel(SupersetTestCase): diff --git a/tests/unit_tests/engines/manager_test.py b/tests/unit_tests/engines/manager_test.py new file mode 100644 index 000000000000..e0abf88f57e4 --- /dev/null +++ b/tests/unit_tests/engines/manager_test.py @@ -0,0 +1,527 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Unit tests for EngineManager.""" + +import threading +from collections.abc import Iterator +from unittest.mock import MagicMock, patch + +import pytest + +from superset.engines.manager import _LockManager, EngineManager, EngineModes + + +class TestLockManager: + """Test the _LockManager class.""" + + def test_get_lock_creates_new_lock(self): + """Test that get_lock creates a new lock when needed.""" + manager = _LockManager() + lock1 = manager.get_lock("key1") + + assert isinstance(lock1, type(threading.RLock())) + assert lock1 is manager.get_lock("key1") # Same lock returned + + def test_get_lock_different_keys_different_locks(self): + """Test that different keys get different locks.""" + manager = _LockManager() + lock1 = manager.get_lock("key1") + lock2 = manager.get_lock("key2") + + assert lock1 is not lock2 + + def test_cleanup_removes_unused_locks(self): + """Test that cleanup removes locks for inactive keys.""" + manager = _LockManager() + + # Create locks + _ = manager.get_lock("key1") # noqa: F841 + lock2 = manager.get_lock("key2") + + # Cleanup with only key1 active + manager.cleanup({"key1"}) + + # key2 lock should be removed + lock3 = manager.get_lock("key2") + assert lock3 is not lock2 # New lock created + + def test_concurrent_lock_creation(self): + """Test that concurrent lock creation doesn't create duplicates.""" + manager = _LockManager() + locks_created = [] + exceptions = [] + + def create_lock(): + try: + lock = manager.get_lock("concurrent_key") + locks_created.append(lock) + except Exception as e: + exceptions.append(e) + + # Create multiple threads trying to get the same lock + threads = [threading.Thread(target=create_lock) for _ in range(10)] + for t in threads: + t.start() + for t in threads: + t.join() + + assert len(exceptions) == 0 + assert len(locks_created) == 10 + + # All should be the same lock + first_lock = locks_created[0] + for lock in 
locks_created[1:]:
+            assert lock is first_lock
+
+
+class TestEngineManager:
+    """Test the EngineManager class."""
+
+    @pytest.fixture
+    def engine_manager(self):
+        """Create a mock EngineManager instance."""
+        from contextlib import contextmanager
+
+        @contextmanager
+        def dummy_context_manager(
+            database: MagicMock, catalog: str | None, schema: str | None
+        ) -> Iterator[None]:
+            yield
+
+        return EngineManager(engine_context_manager=dummy_context_manager)
+
+    @pytest.fixture
+    def mock_database(self):
+        """Create a mock database."""
+        database = MagicMock()
+        database.sqlalchemy_uri_decrypted = "postgresql://user:pass@localhost/test"
+        database.get_extra.return_value = {"engine_params": {}}
+        database.get_effective_user.return_value = "test_user"
+        database.impersonate_user = False
+        database.update_params_from_encrypted_extra = MagicMock()
+        database.db_engine_spec = MagicMock()
+        database.db_engine_spec.adjust_engine_params.return_value = (MagicMock(), {})
+        database.db_engine_spec.impersonate_user = MagicMock(
+            return_value=(MagicMock(), {})
+        )
+        database.db_engine_spec.validate_database_uri = MagicMock()
+        database.ssh_tunnel = None
+        return database
+
+    @patch("superset.engines.manager.create_engine")
+    @patch("superset.engines.manager.make_url_safe")
+    def test_get_engine_new_mode(
+        self, mock_make_url, mock_create_engine, engine_manager, mock_database
+    ):
+        """Test getting an engine in NEW mode (no caching)."""
+        engine_manager.mode = EngineModes.NEW
+
+        mock_make_url.return_value = MagicMock()
+        mock_engine1 = MagicMock()
+        mock_engine2 = MagicMock()
+        mock_create_engine.side_effect = [mock_engine1, mock_engine2]
+
+        result = engine_manager._get_engine(mock_database, "catalog1", "schema1", None)
+
+        assert result is mock_engine1
+        mock_create_engine.assert_called_once()
+
+        # Calling again with the *same* params should still create a new
+        # engine, proving there is no caching in NEW mode
+        mock_create_engine.reset_mock()
+        result2 = engine_manager._get_engine(mock_database, "catalog1", "schema1", None)
+
+        assert result2 is mock_engine2  # Different engine
+        mock_create_engine.assert_called_once()
+
+    @patch("superset.engines.manager.create_engine")
+    @patch("superset.engines.manager.make_url_safe")
+    def test_get_engine_singleton_mode_caching(
+        self, mock_make_url, mock_create_engine, engine_manager, mock_database
+    ):
+        """Test that engines are cached in SINGLETON mode."""
+        engine_manager.mode = EngineModes.SINGLETON
+
+        # Use a real engine instead of MagicMock to avoid event listener issues
+        from sqlalchemy import create_engine
+        from sqlalchemy.pool import StaticPool
+
+        real_engine = create_engine("sqlite:///:memory:", poolclass=StaticPool)
+        mock_create_engine.return_value = real_engine
+        mock_make_url.return_value = real_engine
+
+        # Call twice with same params - should be cached
+        result1 = engine_manager._get_engine(mock_database, "catalog1", "schema1", None)
+        result2 = engine_manager._get_engine(mock_database, "catalog1", "schema1", None)
+
+        assert result1 is result2  # Same engine returned (cached)
+        mock_create_engine.assert_called_once()  # Only created once
+
+    @patch("superset.engines.manager.create_engine")
+    @patch("superset.engines.manager.make_url_safe")
+    def test_concurrent_engine_creation(
+        self, mock_make_url, mock_create_engine, engine_manager, mock_database
+    ):
+        """Test concurrent engine creation doesn't create duplicates."""
+        engine_manager.mode = EngineModes.SINGLETON
+
+        # Use a real engine to avoid event listener issues with MagicMock
+        
from sqlalchemy import create_engine + from sqlalchemy.pool import StaticPool + + real_engine = create_engine("sqlite:///:memory:", poolclass=StaticPool) + mock_make_url.return_value = real_engine + + create_count = [0] + + def counting_create_engine(*args, **kwargs): + create_count[0] += 1 + return real_engine + + mock_create_engine.side_effect = counting_create_engine + + results = [] + exceptions = [] + + def get_engine_thread(): + try: + engine = engine_manager._get_engine( + mock_database, "catalog1", "schema1", None + ) + results.append(engine) + except Exception as e: + exceptions.append(e) + + # Run multiple threads + threads = [threading.Thread(target=get_engine_thread) for _ in range(10)] + for t in threads: + t.start() + for t in threads: + t.join() + + assert len(exceptions) == 0 + assert len(results) == 10 + assert create_count[0] == 1 # Engine created only once + + # All results should be the same engine + for engine in results: + assert engine is real_engine + + @patch("superset.engines.manager.sshtunnel.open_tunnel") + def test_ssh_tunnel_creation(self, mock_open_tunnel, engine_manager): + """Test SSH tunnel creation and caching.""" + ssh_tunnel = MagicMock() + ssh_tunnel.server_address = "ssh.example.com" + ssh_tunnel.server_port = 22 + ssh_tunnel.username = "ssh_user" + ssh_tunnel.password = "ssh_pass" # noqa: S105 + ssh_tunnel.private_key = None + ssh_tunnel.private_key_password = None + + tunnel_instance = MagicMock() + tunnel_instance.is_active = True + tunnel_instance.local_bind_address = ("127.0.0.1", 12345) + mock_open_tunnel.return_value = tunnel_instance + + uri = MagicMock() + uri.host = "db.example.com" + uri.port = 5432 + uri.get_backend_name.return_value = "postgresql" + + result = engine_manager._get_tunnel(ssh_tunnel, uri) + + assert result is tunnel_instance + mock_open_tunnel.assert_called_once() + tunnel_instance.start.assert_called_once() + + # Getting same tunnel again should return cached version + mock_open_tunnel.reset_mock() + result2 = engine_manager._get_tunnel(ssh_tunnel, uri) + + assert result2 is tunnel_instance + mock_open_tunnel.assert_not_called() + + @patch("superset.engines.manager.sshtunnel.open_tunnel") + def test_ssh_tunnel_recreation_when_inactive( + self, mock_open_tunnel, engine_manager + ): + """Test that inactive tunnels are replaced.""" + ssh_tunnel = MagicMock() + ssh_tunnel.server_address = "ssh.example.com" + ssh_tunnel.server_port = 22 + ssh_tunnel.username = "ssh_user" + ssh_tunnel.password = "ssh_pass" # noqa: S105 + ssh_tunnel.private_key = None + ssh_tunnel.private_key_password = None + + # First tunnel is inactive + inactive_tunnel = MagicMock() + inactive_tunnel.is_active = False + inactive_tunnel.local_bind_address = ("127.0.0.1", 12345) + + # Second tunnel is active + active_tunnel = MagicMock() + active_tunnel.is_active = True + active_tunnel.local_bind_address = ("127.0.0.1", 23456) + + mock_open_tunnel.side_effect = [inactive_tunnel, active_tunnel] + + uri = MagicMock() + uri.host = "db.example.com" + uri.port = 5432 + uri.get_backend_name.return_value = "postgresql" + + # First call creates inactive tunnel + result1 = engine_manager._get_tunnel(ssh_tunnel, uri) + assert result1 is inactive_tunnel + + # Second call should create new tunnel since first is inactive + result2 = engine_manager._get_tunnel(ssh_tunnel, uri) + assert result2 is active_tunnel + assert mock_open_tunnel.call_count == 2 + + @patch("superset.engines.manager.create_engine") + @patch("superset.engines.manager.make_url_safe") + def 
test_get_engine_args_basic( + self, mock_make_url, mock_create_engine, engine_manager + ): + """Test _get_engine_args returns correct URI and kwargs.""" + from sqlalchemy.engine.url import make_url + + from superset.engines.manager import EngineModes + + engine_manager.mode = EngineModes.NEW + + mock_uri = make_url("trino://") + mock_make_url.return_value = mock_uri + + database = MagicMock() + database.id = 1 + database.sqlalchemy_uri_decrypted = "trino://" + database.get_extra.return_value = { + "engine_params": {}, + "connect_args": {"source": "Apache Superset"}, + } + database.get_effective_user.return_value = "alice" + database.impersonate_user = False + database.update_params_from_encrypted_extra = MagicMock() + database.db_engine_spec = MagicMock() + database.db_engine_spec.adjust_engine_params.return_value = ( + mock_uri, + {"source": "Apache Superset"}, + ) + database.db_engine_spec.validate_database_uri = MagicMock() + + uri, kwargs = engine_manager._get_engine_args(database, None, None, None, None) + + assert str(uri) == "trino://" + assert "connect_args" in database.get_extra.return_value + + @patch("superset.engines.manager.create_engine") + @patch("superset.engines.manager.make_url_safe") + def test_get_engine_args_user_impersonation( + self, mock_make_url, mock_create_engine, engine_manager + ): + """Test user impersonation in _get_engine_args.""" + from sqlalchemy.engine.url import make_url + + from superset.engines.manager import EngineModes + + engine_manager.mode = EngineModes.NEW + + mock_uri = make_url("trino://") + mock_make_url.return_value = mock_uri + + database = MagicMock() + database.id = 1 + database.sqlalchemy_uri_decrypted = "trino://" + database.get_extra.return_value = { + "engine_params": {}, + "connect_args": {"source": "Apache Superset"}, + } + database.get_effective_user.return_value = "alice" + database.impersonate_user = True + database.get_oauth2_config.return_value = None + database.update_params_from_encrypted_extra = MagicMock() + database.db_engine_spec = MagicMock() + database.db_engine_spec.adjust_engine_params.return_value = ( + mock_uri, + {"source": "Apache Superset"}, + ) + database.db_engine_spec.impersonate_user.return_value = ( + mock_uri, + {"connect_args": {"user": "alice", "source": "Apache Superset"}}, + ) + database.db_engine_spec.validate_database_uri = MagicMock() + + uri, kwargs = engine_manager._get_engine_args(database, None, None, None, None) + + # Verify impersonate_user was called + database.db_engine_spec.impersonate_user.assert_called_once() + call_args = database.db_engine_spec.impersonate_user.call_args + assert call_args[0][0] is database # database + assert call_args[0][1] == "alice" # username + assert call_args[0][2] is None # access_token (no OAuth2) + + @patch("superset.engines.manager.create_engine") + @patch("superset.engines.manager.make_url_safe") + def test_get_engine_args_user_impersonation_email_prefix( + self, + mock_make_url, + mock_create_engine, + engine_manager, + ): + """Test user impersonation with IMPERSONATE_WITH_EMAIL_PREFIX feature flag.""" + from sqlalchemy.engine.url import make_url + + from superset.engines.manager import EngineModes + + engine_manager.mode = EngineModes.NEW + + mock_uri = make_url("trino://") + mock_make_url.return_value = mock_uri + + # Mock user with email + mock_user = MagicMock() + mock_user.email = "alice.doe@example.org" + + database = MagicMock() + database.id = 1 + database.sqlalchemy_uri_decrypted = "trino://" + database.get_extra.return_value = { + "engine_params": 
{}, + "connect_args": {"source": "Apache Superset"}, + } + database.get_effective_user.return_value = "alice" + database.impersonate_user = True + database.get_oauth2_config.return_value = None + database.update_params_from_encrypted_extra = MagicMock() + database.db_engine_spec = MagicMock() + database.db_engine_spec.adjust_engine_params.return_value = ( + mock_uri, + {"source": "Apache Superset"}, + ) + database.db_engine_spec.impersonate_user.return_value = ( + mock_uri, + {"connect_args": {"user": "alice.doe", "source": "Apache Superset"}}, + ) + database.db_engine_spec.validate_database_uri = MagicMock() + + with ( + patch( + "superset.utils.feature_flag_manager.FeatureFlagManager.is_feature_enabled", + return_value=True, + ), + patch( + "superset.extensions.security_manager.find_user", + return_value=mock_user, + ), + ): + uri, kwargs = engine_manager._get_engine_args( + database, None, None, None, None + ) + + # Verify impersonate_user was called with the email prefix + database.db_engine_spec.impersonate_user.assert_called_once() + call_args = database.db_engine_spec.impersonate_user.call_args + assert call_args[0][1] == "alice.doe" # username from email prefix + + @patch("superset.engines.manager.create_engine") + @patch("superset.engines.manager.make_url_safe") + def test_engine_context_manager_called( + self, mock_make_url, mock_create_engine, engine_manager, mock_database + ): + """Test that the engine context manager is properly called.""" + from sqlalchemy.engine.url import make_url + + mock_uri = make_url("trino://") + mock_make_url.return_value = mock_uri + mock_engine = MagicMock() + mock_create_engine.return_value = mock_engine + + # Track context manager calls + context_manager_calls = [] + + def tracking_context_manager(database, catalog, schema): + from contextlib import contextmanager + + @contextmanager + def inner(): + context_manager_calls.append(("enter", database, catalog, schema)) + yield + context_manager_calls.append(("exit", database, catalog, schema)) + + return inner() + + engine_manager.engine_context_manager = tracking_context_manager + + with engine_manager.get_engine(mock_database, "catalog1", "schema1", None): + pass + + assert len(context_manager_calls) == 2 + assert context_manager_calls[0][0] == "enter" + assert context_manager_calls[0][1] is mock_database + assert context_manager_calls[0][2] == "catalog1" + assert context_manager_calls[0][3] == "schema1" + assert context_manager_calls[1][0] == "exit" + + @patch("superset.utils.oauth2.check_for_oauth2") + @patch("superset.engines.manager.create_engine") + @patch("superset.engines.manager.make_url_safe") + def test_engine_oauth2_error_handling( + self, + mock_make_url, + mock_create_engine, + mock_check_for_oauth2, + engine_manager, + mock_database, + ): + """Test that OAuth2 errors are properly propagated from get_engine.""" + from contextlib import contextmanager + + from sqlalchemy.engine.url import make_url + + mock_uri = make_url("trino://") + mock_make_url.return_value = mock_uri + + # Simulate OAuth2 error during engine creation + class OAuth2TestError(Exception): + pass + + oauth_error = OAuth2TestError("OAuth2 required") + mock_create_engine.side_effect = oauth_error + + # Make get_dbapi_mapped_exception return the original exception + mock_database.db_engine_spec.get_dbapi_mapped_exception.return_value = ( + oauth_error + ) + + # Mock check_for_oauth2 to re-raise the exception + @contextmanager + def mock_oauth2_context(database): + try: + yield + except OAuth2TestError: + raise + + 
mock_check_for_oauth2.return_value = mock_oauth2_context(mock_database) + + with pytest.raises(OAuth2TestError, match="OAuth2 required"): + with engine_manager.get_engine(mock_database, "catalog1", "schema1", None): + pass diff --git a/tests/unit_tests/initialization_test.py b/tests/unit_tests/initialization_test.py index 01fde0967c93..93fdf4d352ec 100644 --- a/tests/unit_tests/initialization_test.py +++ b/tests/unit_tests/initialization_test.py @@ -123,7 +123,7 @@ def test_init_app_in_ctx_calls_sync_config_to_db(self, mock_logger): patch.object(app_initializer, "configure_data_sources"), patch.object(app_initializer, "configure_auth_provider"), patch.object(app_initializer, "configure_async_queries"), - patch.object(app_initializer, "configure_ssh_manager"), + patch.object(app_initializer, "configure_engine_manager"), patch.object(app_initializer, "configure_stats_manager"), patch.object(app_initializer, "init_views"), ): diff --git a/tests/unit_tests/models/core_test.py b/tests/unit_tests/models/core_test.py index 7d7aa96ea198..b2a48df05926 100644 --- a/tests/unit_tests/models/core_test.py +++ b/tests/unit_tests/models/core_test.py @@ -19,7 +19,6 @@ from datetime import datetime import pytest -from flask import current_app from pytest_mock import MockerFixture from sqlalchemy import ( Column, @@ -29,7 +28,6 @@ Table as SqlalchemyTable, ) from sqlalchemy.engine.reflection import Inspector -from sqlalchemy.engine.url import make_url from sqlalchemy.orm.session import Session from sqlalchemy.sql import Select @@ -525,60 +523,6 @@ class DriverSpecificError(Exception): assert excinfo.value.error.error_type == SupersetErrorType.OAUTH2_REDIRECT -def test_get_sqla_engine(mocker: MockerFixture) -> None: - """ - Test `_get_sqla_engine`. - """ - from superset.models.core import Database - - user = mocker.MagicMock() - user.email = "alice.doe@example.org" - mocker.patch( - "superset.models.core.security_manager.find_user", - return_value=user, - ) - mocker.patch("superset.models.core.get_username", return_value="alice") - - create_engine = mocker.patch("superset.models.core.create_engine") - - database = Database(database_name="my_db", sqlalchemy_uri="trino://") - database._get_sqla_engine(nullpool=False) - - create_engine.assert_called_with( - make_url("trino:///"), - connect_args={"source": "Apache Superset"}, - ) - - -def test_get_sqla_engine_user_impersonation(mocker: MockerFixture) -> None: - """ - Test user impersonation in `_get_sqla_engine`. - """ - from superset.models.core import Database - - user = mocker.MagicMock() - user.email = "alice.doe@example.org" - mocker.patch( - "superset.models.core.security_manager.find_user", - return_value=user, - ) - mocker.patch("superset.models.core.get_username", return_value="alice") - - create_engine = mocker.patch("superset.models.core.create_engine") - - database = Database( - database_name="my_db", - sqlalchemy_uri="trino://", - impersonate_user=True, - ) - database._get_sqla_engine(nullpool=False) - - create_engine.assert_called_with( - make_url("trino:///"), - connect_args={"user": "alice", "source": "Apache Superset"}, - ) - - def test_add_database_to_signature(): args = ["param1", "param2"] @@ -604,36 +548,6 @@ def func_with_db_end(param1, param2, database): assert args3 == ["param1", "param2", database] -@with_feature_flags(IMPERSONATE_WITH_EMAIL_PREFIX=True) -def test_get_sqla_engine_user_impersonation_email(mocker: MockerFixture) -> None: - """ - Test user impersonation in `_get_sqla_engine` with `username_from_email`. 
- """ - from superset.models.core import Database - - user = mocker.MagicMock() - user.email = "alice.doe@example.org" - mocker.patch( - "superset.models.core.security_manager.find_user", - return_value=user, - ) - mocker.patch("superset.models.core.get_username", return_value="alice") - - create_engine = mocker.patch("superset.models.core.create_engine") - - database = Database( - database_name="my_db", - sqlalchemy_uri="trino://", - impersonate_user=True, - ) - database._get_sqla_engine(nullpool=False) - - create_engine.assert_called_with( - make_url("trino:///"), - connect_args={"user": "alice.doe", "source": "Apache Superset"}, - ) - - def test_is_oauth2_enabled() -> None: """ Test the `is_oauth2_enabled` method. @@ -753,37 +667,6 @@ def test_get_oauth2_config_redirect_uri_from_config( assert config["redirect_uri"] == custom_redirect_uri -def test_raw_connection_oauth_engine(mocker: MockerFixture) -> None: - """ - Test that we can start OAuth2 from `raw_connection()` errors. - - With OAuth2, some databases will raise an exception when the engine is first created - (eg, BigQuery). Others, like, Snowflake, when the connection is created. And - finally, GSheets will raise an exception when the query is executed. - - This tests verifies that when calling `raw_connection()` the OAuth2 flow is - triggered when the engine is created. - """ - g = mocker.patch("superset.db_engine_specs.base.g") - g.user = mocker.MagicMock() - g.user.id = 42 - - database = Database( - id=1, - database_name="my_db", - sqlalchemy_uri="sqlite://", - encrypted_extra=json.dumps(oauth2_client_info), - ) - database.db_engine_spec.oauth2_exception = OAuth2Error # type: ignore - _get_sqla_engine = mocker.patch.object(database, "_get_sqla_engine") - _get_sqla_engine.side_effect = OAuth2Error("OAuth2 required") - - with pytest.raises(OAuth2RedirectError) as excinfo: - with database.get_raw_connection() as conn: - conn.cursor() - assert str(excinfo.value) == "You don't have permission to access the data." - - def test_raw_connection_oauth_connection(mocker: MockerFixture) -> None: """ Test that we can start OAuth2 from `raw_connection()` errors. @@ -879,56 +762,6 @@ def test_get_schema_access_for_file_upload() -> None: assert database.get_schema_access_for_file_upload() == {"public"} -def test_engine_context_manager(mocker: MockerFixture, app_context: None) -> None: - """ - Test the engine context manager. - """ - from unittest.mock import MagicMock - - engine_context_manager = MagicMock() - mocker.patch.dict( - current_app.config, - {"ENGINE_CONTEXT_MANAGER": engine_context_manager}, - ) - _get_sqla_engine = mocker.patch.object(Database, "_get_sqla_engine") - - database = Database(database_name="my_db", sqlalchemy_uri="trino://") - with database.get_sqla_engine("catalog", "schema"): - pass - - engine_context_manager.assert_called_once_with(database, "catalog", "schema") - engine_context_manager().__enter__.assert_called_once() - engine_context_manager().__exit__.assert_called_once_with(None, None, None) - _get_sqla_engine.assert_called_once_with( - catalog="catalog", - schema="schema", - nullpool=True, - source=None, - sqlalchemy_uri="trino://", - ) - - -def test_engine_oauth2(mocker: MockerFixture) -> None: - """ - Test that we handle OAuth2 when `create_engine` fails. 
- """ - database = Database(database_name="my_db", sqlalchemy_uri="trino://") - mocker.patch.object(database, "_get_sqla_engine", side_effect=Exception) - mocker.patch.object(database, "is_oauth2_enabled", return_value=True) - mocker.patch.object(database.db_engine_spec, "needs_oauth2", return_value=True) - start_oauth2_dance = mocker.patch.object( - database.db_engine_spec, - "start_oauth2_dance", - side_effect=OAuth2Error("OAuth2 required"), - ) - - with pytest.raises(OAuth2Error): - with database.get_sqla_engine("catalog", "schema"): - pass - - start_oauth2_dance.assert_called_with(database) - - def test_purge_oauth2_tokens(session: Session) -> None: """ Test the `purge_oauth2_tokens` method. diff --git a/tests/unit_tests/sql/execution/conftest.py b/tests/unit_tests/sql/execution/conftest.py index 630f68845294..fba41e6e376a 100644 --- a/tests/unit_tests/sql/execution/conftest.py +++ b/tests/unit_tests/sql/execution/conftest.py @@ -202,7 +202,6 @@ def setup_mock_raw_connection( def _raw_connection( catalog: str | None = None, schema: str | None = None, - nullpool: bool = True, source: Any | None = None, ): yield mock_connection