diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 825304db836d..bc34e39bc24c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -27,6 +27,7 @@ repos: args: [--check-untyped-defs] exclude: ^superset-extensions-cli/ additional_dependencies: [ + types-cachetools, types-simplejson, types-python-dateutil, types-requests, diff --git a/docs/static/feature-flags.json b/docs/static/feature-flags.json index bfe0955ea89a..95738dd913fc 100644 --- a/docs/static/feature-flags.json +++ b/docs/static/feature-flags.json @@ -114,6 +114,12 @@ "lifecycle": "testing", "description": "Allow users to export full CSV of table viz type. Warning: Could cause server memory/compute issues with large datasets." }, + { + "name": "AWS_DATABASE_IAM_AUTH", + "default": false, + "lifecycle": "testing", + "description": "Enable AWS IAM authentication for database connections (Aurora, Redshift). Allows cross-account role assumption via STS AssumeRole. Security note: When enabled, ensure Superset's IAM role has restricted sts:AssumeRole permissions to prevent unauthorized access." + }, { "name": "CACHE_IMPERSONATION", "default": false, diff --git a/pyproject.toml b/pyproject.toml index fc37dbe89c16..87496d5e9726 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -204,6 +204,7 @@ ydb = ["ydb-sqlalchemy>=0.1.2"] development = [ # no bounds for apache-superset-extensions-cli until a stable version "apache-superset-extensions-cli", + "boto3", "docker", "flask-testing", "freezegun", diff --git a/requirements/development.txt b/requirements/development.txt index c91b6a3646ec..d26c1c78b91b 100644 --- a/requirements/development.txt +++ b/requirements/development.txt @@ -76,6 +76,12 @@ blinker==1.9.0 # via # -c requirements/base-constraint.txt # flask +boto3==1.42.39 + # via apache-superset +botocore==1.42.39 + # via + # boto3 + # s3transfer bottleneck==1.5.0 # via # -c requirements/base-constraint.txt @@ -460,6 +466,10 @@ jinja2==3.1.6 # apache-superset-extensions-cli # flask # flask-babel +jmespath==1.1.0 + # via + # boto3 + # botocore jsonpath-ng==1.7.0 # via # -c requirements/base-constraint.txt @@ -812,6 +822,7 @@ python-dateutil==2.9.0.post0 # via # -c requirements/base-constraint.txt # apache-superset + # botocore # celery # croniter # flask-appbuilder @@ -915,6 +926,8 @@ rsa==4.9.1 # google-auth ruff==0.9.7 # via apache-superset +s3transfer==0.16.0 + # via boto3 secretstorage==3.5.0 # via keyring selenium==4.32.0 @@ -1066,6 +1079,7 @@ url-normalize==2.2.1 urllib3==2.6.3 # via # -c requirements/base-constraint.txt + # botocore # docker # requests # requests-cache diff --git a/superset/config.py b/superset/config.py index 5e66fe40f48e..b048fc2f8d54 100644 --- a/superset/config.py +++ b/superset/config.py @@ -656,6 +656,12 @@ class D3TimeFormat(TypedDict, total=False): # @lifecycle: testing # @docs: https://superset.apache.org/docs/configuration/setup-ssh-tunneling "SSH_TUNNELING": False, + # Enable AWS IAM authentication for database connections (Aurora, Redshift). + # Allows cross-account role assumption via STS AssumeRole. + # Security note: When enabled, ensure Superset's IAM role has restricted + # sts:AssumeRole permissions to prevent unauthorized access. + # @lifecycle: testing + "AWS_DATABASE_IAM_AUTH": False, # Use analogous colors in charts # @lifecycle: testing "USE_ANALOGOUS_COLORS": False, diff --git a/superset/db_engine_specs/aurora.py b/superset/db_engine_specs/aurora.py index 6dcbe6e1c0fd..bac6274f271c 100644 --- a/superset/db_engine_specs/aurora.py +++ b/superset/db_engine_specs/aurora.py @@ -54,3 +54,29 @@ class AuroraPostgresDataAPI(PostgresEngineSpec): "secret_arn={secret_arn}&" "region_name={region_name}" ) + + +class AuroraMySQLEngineSpec(MySQLEngineSpec): + """ + Aurora MySQL engine spec. + + IAM authentication is handled by the parent MySQLEngineSpec via + the aws_iam config in encrypted_extra. + """ + + engine = "mysql" + engine_name = "Aurora MySQL" + default_driver = "mysqldb" + + +class AuroraPostgresEngineSpec(PostgresEngineSpec): + """ + Aurora PostgreSQL engine spec. + + IAM authentication is handled by the parent PostgresEngineSpec via + the aws_iam config in encrypted_extra. + """ + + engine = "postgresql" + engine_name = "Aurora PostgreSQL" + default_driver = "psycopg2" diff --git a/superset/db_engine_specs/aws_iam.py b/superset/db_engine_specs/aws_iam.py new file mode 100644 index 000000000000..ce2960c3d283 --- /dev/null +++ b/superset/db_engine_specs/aws_iam.py @@ -0,0 +1,660 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +AWS IAM Authentication Mixin for database engine specs. + +This mixin provides cross-account IAM authentication support for AWS databases +(Aurora PostgreSQL, Aurora MySQL, Redshift). It handles: +- Assuming IAM roles via STS AssumeRole +- Generating RDS IAM auth tokens +- Generating Redshift Serverless credentials +- Configuring SSL (required for IAM auth) +- Caching STS credentials to reduce API calls +""" + +from __future__ import annotations + +import logging +import threading +from typing import Any, TYPE_CHECKING, TypedDict + +from cachetools import TTLCache + +from superset.databases.utils import make_url_safe +from superset.errors import ErrorLevel, SupersetError, SupersetErrorType +from superset.exceptions import SupersetSecurityException + +if TYPE_CHECKING: + from superset.models.core import Database + +logger = logging.getLogger(__name__) + +# Default session duration for STS AssumeRole (1 hour) +DEFAULT_SESSION_DURATION = 3600 + +# Default ports +DEFAULT_POSTGRES_PORT = 5432 +DEFAULT_MYSQL_PORT = 3306 +DEFAULT_REDSHIFT_PORT = 5439 + +# Cache STS credentials: key = (role_arn, region, external_id), TTL = 10 min +# Using a TTL shorter than the minimum supported session duration (900s) avoids +# reusing expired STS credentials when a short session_duration is configured. +_credentials_cache: TTLCache[tuple[str, str, str | None], dict[str, Any]] = TTLCache( + maxsize=100, ttl=600 +) +_credentials_lock = threading.RLock() + + +class AWSIAMConfig(TypedDict, total=False): + """Configuration for AWS IAM authentication.""" + + enabled: bool + role_arn: str + external_id: str + region: str + db_username: str + session_duration: int + # Redshift Serverless fields + workgroup_name: str + db_name: str + # Redshift provisioned cluster fields + cluster_identifier: str + + +class AWSIAMAuthMixin: + """ + Mixin that provides AWS IAM authentication for database connections. + + This mixin can be used with database engine specs that support IAM + authentication (Aurora PostgreSQL, Aurora MySQL, Redshift). + + Configuration is provided via the database's encrypted_extra JSON: + + { + "aws_iam": { + "enabled": true, + "role_arn": "arn:aws:iam::222222222222:role/SupersetDatabaseAccess", + "external_id": "superset-prod-12345", # optional + "region": "us-east-1", + "db_username": "superset_iam_user", + "session_duration": 3600 # optional, defaults to 3600 + } + } + """ + + # AWS error patterns for actionable error messages + aws_iam_custom_errors: dict[str, tuple[SupersetErrorType, str]] = { + "AccessDenied": ( + SupersetErrorType.CONNECTION_ACCESS_DENIED_ERROR, + "Unable to assume IAM role. Verify the role ARN and trust policy " + "allow access from Superset's IAM role.", + ), + "InvalidIdentityToken": ( + SupersetErrorType.CONNECTION_ACCESS_DENIED_ERROR, + "Invalid IAM credentials. Ensure Superset has a valid IAM role " + "with permissions to assume the target role.", + ), + "MalformedPolicyDocument": ( + SupersetErrorType.CONNECTION_MISSING_PARAMETERS_ERROR, + "Invalid IAM role ARN format. Please verify the role ARN.", + ), + "ExpiredTokenException": ( + SupersetErrorType.CONNECTION_ACCESS_DENIED_ERROR, + "AWS credentials have expired. Please refresh the connection.", + ), + } + + @classmethod + def get_iam_credentials( + cls, + role_arn: str, + region: str, + external_id: str | None = None, + session_duration: int = DEFAULT_SESSION_DURATION, + ) -> dict[str, Any]: + """ + Assume cross-account IAM role via STS AssumeRole with credential caching. + + Credentials are cached by (role_arn, region, external_id) with a 50-minute + TTL to reduce STS API calls while ensuring tokens are refreshed before the + default 1-hour expiration. + + :param role_arn: The ARN of the IAM role to assume + :param region: AWS region for the STS client + :param external_id: External ID for the role assumption (optional) + :param session_duration: Duration of the session in seconds + :returns: Dictionary with AccessKeyId, SecretAccessKey, SessionToken + :raises SupersetSecurityException: If role assumption fails + """ + cache_key = (role_arn, region, external_id) + + with _credentials_lock: + cached = _credentials_cache.get(cache_key) + if cached is not None: + return cached + + try: + # Lazy import to avoid errors when boto3 is not installed + import boto3 + from botocore.exceptions import ClientError + except ImportError as ex: + raise SupersetSecurityException( + SupersetError( + message="boto3 is required for AWS IAM authentication. " + "Install it with: pip install boto3", + error_type=SupersetErrorType.GENERIC_DB_ENGINE_ERROR, + level=ErrorLevel.ERROR, + ) + ) from ex + + try: + sts_client = boto3.client("sts", region_name=region) + + assume_role_kwargs: dict[str, Any] = { + "RoleArn": role_arn, + "RoleSessionName": "superset-iam-session", + "DurationSeconds": session_duration, + } + if external_id: + assume_role_kwargs["ExternalId"] = external_id + + response = sts_client.assume_role(**assume_role_kwargs) + credentials = response["Credentials"] + + with _credentials_lock: + _credentials_cache[cache_key] = credentials + + return credentials + + except ClientError as ex: + error_code = ex.response.get("Error", {}).get("Code", "") + error_message = ex.response.get("Error", {}).get("Message", "") + + # Handle ExternalId mismatch (shows as AccessDenied with specific message) + # Check this first before generic AccessDenied handling + if "external id" in error_message.lower(): + raise SupersetSecurityException( + SupersetError( + message="External ID mismatch. Verify the external_id " + "configuration matches the trust policy.", + error_type=SupersetErrorType.CONNECTION_ACCESS_DENIED_ERROR, + level=ErrorLevel.ERROR, + ) + ) from ex + + if error_code in cls.aws_iam_custom_errors: + error_type, message = cls.aws_iam_custom_errors[error_code] + raise SupersetSecurityException( + SupersetError( + message=message, + error_type=error_type, + level=ErrorLevel.ERROR, + ) + ) from ex + + raise SupersetSecurityException( + SupersetError( + message=f"Failed to assume IAM role: {ex}", + error_type=SupersetErrorType.CONNECTION_ACCESS_DENIED_ERROR, + level=ErrorLevel.ERROR, + ) + ) from ex + + @classmethod + def generate_rds_auth_token( + cls, + credentials: dict[str, Any], + hostname: str, + port: int, + username: str, + region: str, + ) -> str: + """ + Generate RDS IAM auth token using temporary credentials. + + :param credentials: STS credentials from assume_role + :param hostname: RDS/Aurora endpoint hostname + :param port: Database port + :param username: Database username configured for IAM auth + :param region: AWS region + :returns: IAM auth token to use as database password + :raises SupersetSecurityException: If token generation fails + """ + try: + import boto3 + from botocore.exceptions import ClientError + except ImportError as ex: + raise SupersetSecurityException( + SupersetError( + message="boto3 is required for AWS IAM authentication.", + error_type=SupersetErrorType.GENERIC_DB_ENGINE_ERROR, + level=ErrorLevel.ERROR, + ) + ) from ex + + try: + rds_client = boto3.client( + "rds", + region_name=region, + aws_access_key_id=credentials["AccessKeyId"], + aws_secret_access_key=credentials["SecretAccessKey"], + aws_session_token=credentials["SessionToken"], + ) + + token = rds_client.generate_db_auth_token( + DBHostname=hostname, + Port=port, + DBUsername=username, + ) + return token + + except ClientError as ex: + raise SupersetSecurityException( + SupersetError( + message=f"Failed to generate RDS auth token: {ex}", + error_type=SupersetErrorType.CONNECTION_ACCESS_DENIED_ERROR, + level=ErrorLevel.ERROR, + ) + ) from ex + + @classmethod + def generate_redshift_credentials( + cls, + credentials: dict[str, Any], + workgroup_name: str, + db_name: str, + region: str, + ) -> tuple[str, str]: + """ + Generate Redshift Serverless credentials using temporary STS credentials. + + :param credentials: STS credentials from assume_role + :param workgroup_name: Redshift Serverless workgroup name + :param db_name: Redshift database name + :param region: AWS region + :returns: Tuple of (username, password) for Redshift connection + :raises SupersetSecurityException: If credential generation fails + """ + try: + import boto3 + from botocore.exceptions import ClientError + except ImportError as ex: + raise SupersetSecurityException( + SupersetError( + message="boto3 is required for AWS IAM authentication.", + error_type=SupersetErrorType.GENERIC_DB_ENGINE_ERROR, + level=ErrorLevel.ERROR, + ) + ) from ex + + try: + client = boto3.client( + "redshift-serverless", + region_name=region, + aws_access_key_id=credentials["AccessKeyId"], + aws_secret_access_key=credentials["SecretAccessKey"], + aws_session_token=credentials["SessionToken"], + ) + + response = client.get_credentials( + workgroupName=workgroup_name, + dbName=db_name, + ) + return response["dbUser"], response["dbPassword"] + + except ClientError as ex: + raise SupersetSecurityException( + SupersetError( + message=f"Failed to get Redshift Serverless credentials: {ex}", + error_type=SupersetErrorType.CONNECTION_ACCESS_DENIED_ERROR, + level=ErrorLevel.ERROR, + ) + ) from ex + + @classmethod + def generate_redshift_cluster_credentials( + cls, + credentials: dict[str, Any], + cluster_identifier: str, + db_user: str, + db_name: str, + region: str, + auto_create: bool = False, + ) -> tuple[str, str]: + """ + Generate credentials for a provisioned Redshift cluster using temporary + STS credentials. + + :param credentials: STS credentials from assume_role + :param cluster_identifier: Redshift cluster identifier + :param db_user: Database username to get credentials for + :param db_name: Redshift database name + :param region: AWS region + :param auto_create: Whether to auto-create the database user if it doesn't exist + :returns: Tuple of (username, password) for Redshift connection + :raises SupersetSecurityException: If credential generation fails + """ + try: + import boto3 + from botocore.exceptions import ClientError + except ImportError as ex: + raise SupersetSecurityException( + SupersetError( + message="boto3 is required for AWS IAM authentication.", + error_type=SupersetErrorType.GENERIC_DB_ENGINE_ERROR, + level=ErrorLevel.ERROR, + ) + ) from ex + + try: + client = boto3.client( + "redshift", + region_name=region, + aws_access_key_id=credentials["AccessKeyId"], + aws_secret_access_key=credentials["SecretAccessKey"], + aws_session_token=credentials["SessionToken"], + ) + + response = client.get_cluster_credentials( + ClusterIdentifier=cluster_identifier, + DbUser=db_user, + DbName=db_name, + AutoCreate=auto_create, + ) + return response["DbUser"], response["DbPassword"] + + except ClientError as ex: + raise SupersetSecurityException( + SupersetError( + message=f"Failed to get Redshift cluster credentials: {ex}", + error_type=SupersetErrorType.CONNECTION_ACCESS_DENIED_ERROR, + level=ErrorLevel.ERROR, + ) + ) from ex + + @classmethod + def _apply_iam_authentication( + cls, + database: Database, + params: dict[str, Any], + iam_config: AWSIAMConfig, + ssl_args: dict[str, str] | None = None, + default_port: int = DEFAULT_POSTGRES_PORT, + ) -> None: + """ + Apply IAM authentication to the connection parameters. + + Full flow: assume role -> generate token -> update connect_args -> enable SSL. + + :param database: Database model instance + :param params: Engine parameters dict to modify + :param iam_config: IAM configuration from encrypted_extra + :param ssl_args: SSL args to apply (defaults to sslmode=require) + :param default_port: Default port if not specified in URI + :raises SupersetSecurityException: If any step fails + """ + from superset import feature_flag_manager + + if not feature_flag_manager.is_feature_enabled("AWS_DATABASE_IAM_AUTH"): + raise SupersetSecurityException( + SupersetError( + message="AWS IAM database authentication is not enabled. " + "Set the AWS_DATABASE_IAM_AUTH feature flag to True in your " + "Superset configuration to enable this feature.", + error_type=SupersetErrorType.GENERIC_DB_ENGINE_ERROR, + level=ErrorLevel.ERROR, + ) + ) + + if ssl_args is None: + ssl_args = {"sslmode": "require"} + + # Extract configuration + role_arn = iam_config.get("role_arn") + region = iam_config.get("region") + db_username = iam_config.get("db_username") + external_id = iam_config.get("external_id") + session_duration = iam_config.get("session_duration", DEFAULT_SESSION_DURATION) + + # Validate required fields + missing_fields = [] + if not role_arn: + missing_fields.append("role_arn") + if not region: + missing_fields.append("region") + if not db_username: + missing_fields.append("db_username") + + if missing_fields: + raise SupersetSecurityException( + SupersetError( + message="AWS IAM configuration missing required fields: " + f"{', '.join(missing_fields)}", + error_type=SupersetErrorType.CONNECTION_MISSING_PARAMETERS_ERROR, + level=ErrorLevel.ERROR, + ) + ) + + # Type assertions after validation (mypy doesn't narrow types from list check) + assert role_arn is not None + assert region is not None + assert db_username is not None + + # Get hostname and port from the database URI + uri = make_url_safe(database.sqlalchemy_uri_decrypted) + hostname = uri.host + port = uri.port or default_port + + if not hostname: + raise SupersetSecurityException( + SupersetError( + message=( + "Database URI must include a hostname for IAM authentication" + ), + error_type=SupersetErrorType.CONNECTION_MISSING_PARAMETERS_ERROR, + level=ErrorLevel.ERROR, + ) + ) + + logger.debug( + "Applying IAM authentication for %s:%d as user %s", + hostname, + port, + db_username, + ) + + # Step 1: Assume the IAM role + credentials = cls.get_iam_credentials( + role_arn=role_arn, + region=region, + external_id=external_id, + session_duration=session_duration, + ) + + # Step 2: Generate the RDS auth token + token = cls.generate_rds_auth_token( + credentials=credentials, + hostname=hostname, + port=port, + username=db_username, + region=region, + ) + + # Step 3: Update connection parameters + connect_args = params.setdefault("connect_args", {}) + + # Set the IAM token as the password + connect_args["password"] = token + + # Override username if different from URI + connect_args["user"] = db_username + + # Step 4: Enable SSL (required for IAM authentication) + connect_args.update(ssl_args) + + logger.debug("IAM authentication configured successfully") + + @classmethod + def _apply_redshift_iam_authentication( + cls, + database: Database, + params: dict[str, Any], + iam_config: AWSIAMConfig, + ) -> None: + """ + Apply Redshift IAM authentication to connection parameters. + + Supports both Redshift Serverless (workgroup_name) and provisioned + clusters (cluster_identifier). The method auto-detects which type + based on the configuration provided. + + Flow: assume role -> get Redshift credentials -> update connect_args -> SSL. + + :param database: Database model instance + :param params: Engine parameters dict to modify + :param iam_config: IAM configuration from encrypted_extra + :raises SupersetSecurityException: If any step fails + """ + from superset import feature_flag_manager + + if not feature_flag_manager.is_feature_enabled("AWS_DATABASE_IAM_AUTH"): + raise SupersetSecurityException( + SupersetError( + message="AWS IAM database authentication is not enabled. " + "Set the AWS_DATABASE_IAM_AUTH feature flag to True in your " + "Superset configuration to enable this feature.", + error_type=SupersetErrorType.GENERIC_DB_ENGINE_ERROR, + level=ErrorLevel.ERROR, + ) + ) + + # Extract configuration + role_arn = iam_config.get("role_arn") + region = iam_config.get("region") + external_id = iam_config.get("external_id") + session_duration = iam_config.get("session_duration", DEFAULT_SESSION_DURATION) + + # Serverless fields + workgroup_name = iam_config.get("workgroup_name") + + # Provisioned cluster fields + cluster_identifier = iam_config.get("cluster_identifier") + db_username = iam_config.get("db_username") + + # Common field + db_name = iam_config.get("db_name") + + # Determine deployment type + is_serverless = bool(workgroup_name) + is_provisioned = bool(cluster_identifier) + + if is_serverless and is_provisioned: + raise SupersetSecurityException( + SupersetError( + message="AWS IAM configuration cannot have both workgroup_name " + "(Serverless) and cluster_identifier (provisioned). " + "Please specify only one.", + error_type=SupersetErrorType.CONNECTION_MISSING_PARAMETERS_ERROR, + level=ErrorLevel.ERROR, + ) + ) + + if not is_serverless and not is_provisioned: + raise SupersetSecurityException( + SupersetError( + message="AWS IAM configuration must include either workgroup_name " + "(for Redshift Serverless) or cluster_identifier " + "(for provisioned Redshift clusters).", + error_type=SupersetErrorType.CONNECTION_MISSING_PARAMETERS_ERROR, + level=ErrorLevel.ERROR, + ) + ) + + # Validate common required fields + missing_fields = [] + if not role_arn: + missing_fields.append("role_arn") + if not region: + missing_fields.append("region") + if not db_name: + missing_fields.append("db_name") + + # Validate provisioned cluster specific fields + if is_provisioned and not db_username: + missing_fields.append("db_username") + + if missing_fields: + raise SupersetSecurityException( + SupersetError( + message="AWS IAM configuration missing required fields: " + f"{', '.join(missing_fields)}", + error_type=SupersetErrorType.CONNECTION_MISSING_PARAMETERS_ERROR, + level=ErrorLevel.ERROR, + ) + ) + + # Type assertions after validation + assert role_arn is not None + assert region is not None + assert db_name is not None + + # Step 1: Assume the IAM role + credentials = cls.get_iam_credentials( + role_arn=role_arn, + region=region, + external_id=external_id, + session_duration=session_duration, + ) + + # Step 2: Get Redshift credentials based on deployment type + if is_serverless: + assert workgroup_name is not None + logger.debug( + "Applying Redshift Serverless IAM authentication for workgroup %s", + workgroup_name, + ) + db_user, db_password = cls.generate_redshift_credentials( + credentials=credentials, + workgroup_name=workgroup_name, + db_name=db_name, + region=region, + ) + else: + assert cluster_identifier is not None + assert db_username is not None + logger.debug( + "Applying Redshift provisioned cluster IAM authentication for %s", + cluster_identifier, + ) + db_user, db_password = cls.generate_redshift_cluster_credentials( + credentials=credentials, + cluster_identifier=cluster_identifier, + db_user=db_username, + db_name=db_name, + region=region, + ) + + # Step 3: Update connection parameters + connect_args = params.setdefault("connect_args", {}) + connect_args["password"] = db_password + connect_args["user"] = db_user + + # Step 4: Enable SSL (required for Redshift IAM authentication) + connect_args["sslmode"] = "verify-ca" + + logger.debug("Redshift IAM authentication configured successfully") diff --git a/superset/db_engine_specs/mysql.py b/superset/db_engine_specs/mysql.py index 19554d5b9c8c..b6cba3906a6e 100644 --- a/superset/db_engine_specs/mysql.py +++ b/superset/db_engine_specs/mysql.py @@ -14,12 +14,15 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations + import contextlib +import logging import re from datetime import datetime from decimal import Decimal from re import Pattern -from typing import Any, Callable, Optional +from typing import Any, Callable, Optional, TYPE_CHECKING from urllib import parse from flask_babel import gettext as __ @@ -46,8 +49,14 @@ ) from superset.errors import SupersetErrorType from superset.models.sql_lab import Query +from superset.utils import json from superset.utils.core import GenericDataType +if TYPE_CHECKING: + from superset.models.core import Database + +logger = logging.getLogger(__name__) + # Regular expressions to catch custom errors CONNECTION_ACCESS_DENIED_REGEX = re.compile( "Access denied for user '(?P.*?)'@'(?P.*?)'" @@ -294,6 +303,54 @@ class MySQLEngineSpec(BasicParametersMixin, BaseEngineSpec): "mysqlconnector": {"allow_local_infile": 0}, } + # Sensitive fields that should be masked in encrypted_extra. + # This follows the pattern used by other engine specs (bigquery, snowflake, etc.) + # that specify exact paths rather than using the base class's catch-all "$.*". + encrypted_extra_sensitive_fields = { + "$.aws_iam.external_id", + "$.aws_iam.role_arn", + } + + @staticmethod + def update_params_from_encrypted_extra( + database: Database, + params: dict[str, Any], + ) -> None: + """ + Extract sensitive parameters from encrypted_extra. + + Handles AWS IAM authentication if configured, then merges any + remaining encrypted_extra keys into params. + """ + if not database.encrypted_extra: + return + + try: + encrypted_extra = json.loads(database.encrypted_extra) + except json.JSONDecodeError as ex: + logger.error(ex, exc_info=True) + raise + + # Handle AWS IAM auth: pop the key so it doesn't reach create_engine() + iam_config = encrypted_extra.pop("aws_iam", None) + if iam_config and iam_config.get("enabled"): + from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin + + AWSIAMAuthMixin._apply_iam_authentication( + database, + params, + iam_config, + # MySQL drivers (mysqlclient) use 'ssl' dict, not 'ssl_mode'. + # SSL is typically configured via the database's extra settings, + # so we pass empty ssl_args here to avoid driver compatibility issues. + ssl_args={}, + default_port=3306, + ) + + # Standard behavior: merge remaining keys into params + if encrypted_extra: + params.update(encrypted_extra) + @classmethod def convert_dttm( cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None diff --git a/superset/db_engine_specs/postgres.py b/superset/db_engine_specs/postgres.py index 10458f15c600..8ae844ff4a29 100644 --- a/superset/db_engine_specs/postgres.py +++ b/superset/db_engine_specs/postgres.py @@ -359,6 +359,14 @@ class PostgresEngineSpec(BasicParametersMixin, PostgresBaseEngineSpec): max_column_name_length = 63 try_remove_schema_from_table_name = False # pylint: disable=invalid-name + # Sensitive fields that should be masked in encrypted_extra. + # This follows the pattern used by other engine specs (bigquery, snowflake, etc.) + # that specify exact paths rather than using the base class's catch-all "$.*". + encrypted_extra_sensitive_fields = { + "$.aws_iam.external_id", + "$.aws_iam.role_arn", + } + column_type_mappings = ( ( re.compile(r"^double precision", re.IGNORECASE), @@ -461,6 +469,51 @@ def adjust_engine_params( return uri, connect_args + @staticmethod + def update_params_from_encrypted_extra( + database: Database, + params: dict[str, Any], + ) -> None: + """ + Extract sensitive parameters from encrypted_extra. + + Handles AWS IAM authentication if configured, then merges any + remaining encrypted_extra keys into params (standard behavior). + """ + if not database.encrypted_extra: + return + + try: + encrypted_extra = json.loads(database.encrypted_extra) + except json.JSONDecodeError as ex: + logger.error(ex, exc_info=True) + raise + + # Handle AWS IAM auth: pop the key so it doesn't reach create_engine() + iam_config = encrypted_extra.pop("aws_iam", None) + if iam_config and iam_config.get("enabled"): + from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin + + # Preserve a stricter existing sslmode (e.g. verify-full) if present + connect_args = params.get("connect_args") or {} + previous_sslmode = connect_args.get("sslmode") + + AWSIAMAuthMixin._apply_iam_authentication( + database, + params, + iam_config, + ssl_args={"sslmode": "require"}, + default_port=5432, + ) + + # Restore stricter sslmode if it was previously configured + if previous_sslmode in ("verify-ca", "verify-full"): + params.setdefault("connect_args", {})["sslmode"] = previous_sslmode + + # Standard behavior: merge remaining keys into params + if encrypted_extra: + params.update(encrypted_extra) + @classmethod def get_default_catalog(cls, database: Database) -> str: """ diff --git a/superset/db_engine_specs/redshift.py b/superset/db_engine_specs/redshift.py index ea49c479dea7..fcdfab16967e 100644 --- a/superset/db_engine_specs/redshift.py +++ b/superset/db_engine_specs/redshift.py @@ -31,6 +31,7 @@ from superset.models.core import Database from superset.models.sql_lab import Query from superset.sql.parse import Table +from superset.utils import json logger = logging.getLogger() @@ -201,6 +202,47 @@ def normalize_table_name_for_upload( schema_name.lower() if schema_name else None, ) + # Sensitive fields that should be masked in encrypted_extra. + # This follows the pattern used by other engine specs (bigquery, snowflake, etc.) + # that specify exact paths rather than using the base class's catch-all "$.*". + encrypted_extra_sensitive_fields = { + "$.aws_iam.external_id", + "$.aws_iam.role_arn", + } + + @staticmethod + def update_params_from_encrypted_extra( + database: Database, + params: dict[str, Any], + ) -> None: + """ + Extract sensitive parameters from encrypted_extra. + + Handles AWS IAM authentication for Redshift Serverless if configured, + then merges any remaining encrypted_extra keys into params. + """ + if not database.encrypted_extra: + return + + try: + encrypted_extra = json.loads(database.encrypted_extra) + except json.JSONDecodeError as ex: + logger.error(ex, exc_info=True) + raise + + # Handle AWS IAM auth: pop the key so it doesn't reach create_engine() + iam_config = encrypted_extra.pop("aws_iam", None) + if iam_config and iam_config.get("enabled"): + from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin + + AWSIAMAuthMixin._apply_redshift_iam_authentication( + database, params, iam_config + ) + + # Standard behavior: merge remaining keys into params + if encrypted_extra: + params.update(encrypted_extra) + @classmethod def df_to_sql( cls, diff --git a/tests/unit_tests/db_engine_specs/test_aurora.py b/tests/unit_tests/db_engine_specs/test_aurora.py new file mode 100644 index 000000000000..9b979c278bd6 --- /dev/null +++ b/tests/unit_tests/db_engine_specs/test_aurora.py @@ -0,0 +1,317 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# pylint: disable=import-outside-toplevel + +from __future__ import annotations + +from typing import Any +from unittest.mock import MagicMock, patch + +import pytest + +from superset.utils import json +from tests.unit_tests.conftest import with_feature_flags + + +def test_aurora_postgres_engine_spec_properties() -> None: + from superset.db_engine_specs.aurora import AuroraPostgresEngineSpec + + assert AuroraPostgresEngineSpec.engine == "postgresql" + assert AuroraPostgresEngineSpec.engine_name == "Aurora PostgreSQL" + assert AuroraPostgresEngineSpec.default_driver == "psycopg2" + + +def test_update_params_from_encrypted_extra_without_iam() -> None: + from superset.db_engine_specs.postgres import PostgresEngineSpec + + database = MagicMock() + database.encrypted_extra = json.dumps({}) + database.sqlalchemy_uri_decrypted = ( + "postgresql://user:password@mydb.us-east-1.rds.amazonaws.com:5432/mydb" + ) + + params: dict[str, Any] = {} + PostgresEngineSpec.update_params_from_encrypted_extra(database, params) + + # No modifications should be made + assert params == {} + + +def test_update_params_from_encrypted_extra_iam_disabled() -> None: + from superset.db_engine_specs.postgres import PostgresEngineSpec + + database = MagicMock() + database.encrypted_extra = json.dumps( + { + "aws_iam": { + "enabled": False, + "role_arn": "arn:aws:iam::123456789012:role/TestRole", + "region": "us-east-1", + "db_username": "superset_user", + } + } + ) + database.sqlalchemy_uri_decrypted = ( + "postgresql://user:password@mydb.us-east-1.rds.amazonaws.com:5432/mydb" + ) + + params: dict[str, Any] = {} + PostgresEngineSpec.update_params_from_encrypted_extra(database, params) + + # No modifications should be made when IAM is disabled + assert params == {} + + +@with_feature_flags(AWS_DATABASE_IAM_AUTH=True) +def test_update_params_from_encrypted_extra_with_iam() -> None: + from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin + from superset.db_engine_specs.postgres import PostgresEngineSpec + + database = MagicMock() + database.encrypted_extra = json.dumps( + { + "aws_iam": { + "enabled": True, + "role_arn": "arn:aws:iam::123456789012:role/TestRole", + "region": "us-east-1", + "db_username": "superset_iam_user", + } + } + ) + database.sqlalchemy_uri_decrypted = ( + "postgresql://user@mydb.cluster-xyz.us-east-1.rds.amazonaws.com:5432/mydb" + ) + + params: dict[str, Any] = {} + + with ( + patch.object( + AWSIAMAuthMixin, + "get_iam_credentials", + return_value={ + "AccessKeyId": "ASIA...", + "SecretAccessKey": "secret...", + "SessionToken": "token...", + }, + ), + patch.object( + AWSIAMAuthMixin, + "generate_rds_auth_token", + return_value="iam-auth-token", + ), + ): + PostgresEngineSpec.update_params_from_encrypted_extra(database, params) + + assert "connect_args" in params + assert params["connect_args"]["password"] == "iam-auth-token" # noqa: S105 + assert params["connect_args"]["user"] == "superset_iam_user" + assert params["connect_args"]["sslmode"] == "require" + + +def test_update_params_merges_remaining_encrypted_extra() -> None: + from superset.db_engine_specs.postgres import PostgresEngineSpec + + database = MagicMock() + database.encrypted_extra = json.dumps( + { + "aws_iam": {"enabled": False}, + "pool_size": 10, + } + ) + database.sqlalchemy_uri_decrypted = ( + "postgresql://user:password@mydb.us-east-1.rds.amazonaws.com:5432/mydb" + ) + + params: dict[str, Any] = {} + PostgresEngineSpec.update_params_from_encrypted_extra(database, params) + + # aws_iam should be consumed, pool_size should be merged + assert "aws_iam" not in params + assert params["pool_size"] == 10 + + +def test_update_params_from_encrypted_extra_no_encrypted_extra() -> None: + from superset.db_engine_specs.postgres import PostgresEngineSpec + + database = MagicMock() + database.encrypted_extra = None + + params: dict[str, Any] = {} + PostgresEngineSpec.update_params_from_encrypted_extra(database, params) + + # No modifications should be made + assert params == {} + + +def test_update_params_from_encrypted_extra_invalid_json() -> None: + from superset.db_engine_specs.postgres import PostgresEngineSpec + + database = MagicMock() + database.encrypted_extra = "not-valid-json" + + params: dict[str, Any] = {} + + with pytest.raises(json.JSONDecodeError): + PostgresEngineSpec.update_params_from_encrypted_extra(database, params) + + +def test_encrypted_extra_sensitive_fields() -> None: + from superset.db_engine_specs.postgres import PostgresEngineSpec + + # Verify sensitive fields are properly defined + assert ( + "$.aws_iam.external_id" in PostgresEngineSpec.encrypted_extra_sensitive_fields + ) + assert "$.aws_iam.role_arn" in PostgresEngineSpec.encrypted_extra_sensitive_fields + + +def test_mask_encrypted_extra() -> None: + from superset.db_engine_specs.postgres import PostgresEngineSpec + + encrypted_extra = json.dumps( + { + "aws_iam": { + "enabled": True, + "role_arn": "arn:aws:iam::123456789012:role/SecretRole", + "external_id": "secret-external-id-12345", + "region": "us-east-1", + "db_username": "superset_user", + } + } + ) + + masked = PostgresEngineSpec.mask_encrypted_extra(encrypted_extra) + assert masked is not None + + masked_config = json.loads(masked) + + # role_arn and external_id should be masked + assert ( + masked_config["aws_iam"]["role_arn"] + != "arn:aws:iam::123456789012:role/SecretRole" + ) + assert masked_config["aws_iam"]["external_id"] != "secret-external-id-12345" + + # Non-sensitive fields should remain unchanged + assert masked_config["aws_iam"]["enabled"] is True + assert masked_config["aws_iam"]["region"] == "us-east-1" + assert masked_config["aws_iam"]["db_username"] == "superset_user" + + +def test_aurora_postgres_inherits_from_postgres() -> None: + from superset.db_engine_specs.aurora import AuroraPostgresEngineSpec + from superset.db_engine_specs.postgres import PostgresEngineSpec + + # Verify inheritance + assert issubclass(AuroraPostgresEngineSpec, PostgresEngineSpec) + + # Verify it inherits PostgreSQL capabilities + assert AuroraPostgresEngineSpec.supports_dynamic_schema is True + assert AuroraPostgresEngineSpec.supports_catalog is True + + +def test_aurora_mysql_engine_spec_properties() -> None: + from superset.db_engine_specs.aurora import AuroraMySQLEngineSpec + + assert AuroraMySQLEngineSpec.engine == "mysql" + assert AuroraMySQLEngineSpec.engine_name == "Aurora MySQL" + assert AuroraMySQLEngineSpec.default_driver == "mysqldb" + + +def test_aurora_mysql_inherits_from_mysql() -> None: + from superset.db_engine_specs.aurora import AuroraMySQLEngineSpec + from superset.db_engine_specs.mysql import MySQLEngineSpec + + assert issubclass(AuroraMySQLEngineSpec, MySQLEngineSpec) + assert AuroraMySQLEngineSpec.supports_dynamic_schema is True + + +def test_aurora_mysql_has_iam_support() -> None: + from superset.db_engine_specs.aurora import AuroraMySQLEngineSpec + + # Verify it inherits encrypted_extra_sensitive_fields + assert ( + "$.aws_iam.external_id" + in AuroraMySQLEngineSpec.encrypted_extra_sensitive_fields + ) + assert ( + "$.aws_iam.role_arn" in AuroraMySQLEngineSpec.encrypted_extra_sensitive_fields + ) + + +@with_feature_flags(AWS_DATABASE_IAM_AUTH=True) +def test_aurora_mysql_update_params_from_encrypted_extra_with_iam() -> None: + from superset.db_engine_specs.aurora import AuroraMySQLEngineSpec + from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin + + database = MagicMock() + database.encrypted_extra = json.dumps( + { + "aws_iam": { + "enabled": True, + "role_arn": "arn:aws:iam::123456789012:role/TestRole", + "region": "us-east-1", + "db_username": "superset_iam_user", + } + } + ) + database.sqlalchemy_uri_decrypted = ( + "mysql://user@mydb.cluster-xyz.us-east-1.rds.amazonaws.com:3306/mydb" + ) + + params: dict[str, Any] = {} + + with ( + patch.object( + AWSIAMAuthMixin, + "get_iam_credentials", + return_value={ + "AccessKeyId": "ASIA...", + "SecretAccessKey": "secret...", + "SessionToken": "token...", + }, + ), + patch.object( + AWSIAMAuthMixin, + "generate_rds_auth_token", + return_value="iam-auth-token", + ), + ): + AuroraMySQLEngineSpec.update_params_from_encrypted_extra(database, params) + + assert "connect_args" in params + assert params["connect_args"]["password"] == "iam-auth-token" # noqa: S105 + assert params["connect_args"]["user"] == "superset_iam_user" + # Note: ssl_mode is not set because MySQL drivers don't support it. + # SSL should be configured via the database's extra settings. + + +def test_aurora_data_api_classes_unchanged() -> None: + from superset.db_engine_specs.aurora import ( + AuroraMySQLDataAPI, + AuroraPostgresDataAPI, + ) + + # Verify Data API classes are still available and unchanged + assert AuroraMySQLDataAPI.engine == "mysql" + assert AuroraMySQLDataAPI.default_driver == "auroradataapi" + assert AuroraMySQLDataAPI.engine_name == "Aurora MySQL (Data API)" + + assert AuroraPostgresDataAPI.engine == "postgresql" + assert AuroraPostgresDataAPI.default_driver == "auroradataapi" + assert AuroraPostgresDataAPI.engine_name == "Aurora PostgreSQL (Data API)" diff --git a/tests/unit_tests/db_engine_specs/test_aws_iam.py b/tests/unit_tests/db_engine_specs/test_aws_iam.py new file mode 100644 index 000000000000..602bd76f68fd --- /dev/null +++ b/tests/unit_tests/db_engine_specs/test_aws_iam.py @@ -0,0 +1,1045 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# pylint: disable=import-outside-toplevel, protected-access + +from __future__ import annotations + +from typing import Any +from unittest.mock import MagicMock, patch + +import pytest + +from superset.exceptions import SupersetSecurityException +from tests.unit_tests.conftest import with_feature_flags + + +def test_get_iam_credentials_success() -> None: + from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin + + mock_credentials = { + "AccessKeyId": "ASIA...", + "SecretAccessKey": "secret...", + "SessionToken": "token...", + "Expiration": "2025-01-01T00:00:00Z", + } + + with patch("boto3.client") as mock_boto3_client: + mock_sts = MagicMock() + mock_sts.assume_role.return_value = {"Credentials": mock_credentials} + mock_boto3_client.return_value = mock_sts + + credentials = AWSIAMAuthMixin.get_iam_credentials( + role_arn="arn:aws:iam::123456789012:role/TestRole", + region="us-east-1", + ) + + assert credentials == mock_credentials + mock_boto3_client.assert_called_once_with("sts", region_name="us-east-1") + mock_sts.assume_role.assert_called_once_with( + RoleArn="arn:aws:iam::123456789012:role/TestRole", + RoleSessionName="superset-iam-session", + DurationSeconds=3600, + ) + + +def test_get_iam_credentials_with_external_id() -> None: + from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin + + mock_credentials = { + "AccessKeyId": "ASIA...", + "SecretAccessKey": "secret...", + "SessionToken": "token...", + } + + with patch("boto3.client") as mock_boto3_client: + mock_sts = MagicMock() + mock_sts.assume_role.return_value = {"Credentials": mock_credentials} + mock_boto3_client.return_value = mock_sts + + credentials = AWSIAMAuthMixin.get_iam_credentials( + role_arn="arn:aws:iam::123456789012:role/TestRole", + region="us-west-2", + external_id="external-id-12345", + session_duration=900, + ) + + assert credentials == mock_credentials + mock_sts.assume_role.assert_called_once_with( + RoleArn="arn:aws:iam::123456789012:role/TestRole", + RoleSessionName="superset-iam-session", + DurationSeconds=900, + ExternalId="external-id-12345", + ) + + +def test_get_iam_credentials_access_denied() -> None: + from botocore.exceptions import ClientError + + from superset.db_engine_specs.aws_iam import ( + _credentials_cache, + _credentials_lock, + AWSIAMAuthMixin, + ) + + with _credentials_lock: + _credentials_cache.clear() + + with patch("boto3.client") as mock_boto3_client: + mock_sts = MagicMock() + mock_sts.assume_role.side_effect = ClientError( + {"Error": {"Code": "AccessDenied", "Message": "Access Denied"}}, + "AssumeRole", + ) + mock_boto3_client.return_value = mock_sts + + with pytest.raises(SupersetSecurityException) as exc_info: + AWSIAMAuthMixin.get_iam_credentials( + role_arn="arn:aws:iam::123456789012:role/TestRole", + region="us-east-1", + ) + + assert "Unable to assume IAM role" in str(exc_info.value) + + +def test_get_iam_credentials_external_id_mismatch() -> None: + from botocore.exceptions import ClientError + + from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin + + with patch("boto3.client") as mock_boto3_client: + mock_sts = MagicMock() + mock_sts.assume_role.side_effect = ClientError( + { + "Error": { + "Code": "AccessDenied", + "Message": "The external id does not match", + } + }, + "AssumeRole", + ) + mock_boto3_client.return_value = mock_sts + + with pytest.raises(SupersetSecurityException) as exc_info: + AWSIAMAuthMixin.get_iam_credentials( + role_arn="arn:aws:iam::123456789012:role/TestRole", + region="us-east-1", + external_id="wrong-id", + ) + + assert "External ID mismatch" in str(exc_info.value) + + +def test_generate_rds_auth_token() -> None: + from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin + + credentials = { + "AccessKeyId": "ASIA...", + "SecretAccessKey": "secret...", + "SessionToken": "token...", + } + + with patch("boto3.client") as mock_boto3_client: + mock_rds = MagicMock() + mock_rds.generate_db_auth_token.return_value = "iam-token-12345" + mock_boto3_client.return_value = mock_rds + + token = AWSIAMAuthMixin.generate_rds_auth_token( + credentials=credentials, + hostname="mydb.cluster-xyz.us-east-1.rds.amazonaws.com", + port=5432, + username="superset_user", + region="us-east-1", + ) + + assert token == "iam-token-12345" # noqa: S105 + mock_boto3_client.assert_called_once_with( + "rds", + region_name="us-east-1", + aws_access_key_id="ASIA...", + aws_secret_access_key="secret...", # noqa: S106 + aws_session_token="token...", # noqa: S106 + ) + mock_rds.generate_db_auth_token.assert_called_once_with( + DBHostname="mydb.cluster-xyz.us-east-1.rds.amazonaws.com", + Port=5432, + DBUsername="superset_user", + ) + + +def test_apply_iam_authentication_feature_flag_disabled() -> None: + """Test that IAM auth is blocked when feature flag is disabled.""" + from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin, AWSIAMConfig + + mock_database = MagicMock() + mock_database.sqlalchemy_uri_decrypted = ( + "postgresql://user@mydb.cluster-xyz.us-east-1.rds.amazonaws.com:5432/mydb" + ) + + iam_config: AWSIAMConfig = { + "enabled": True, + "role_arn": "arn:aws:iam::123456789012:role/TestRole", + "region": "us-east-1", + "db_username": "superset_iam_user", + } + + params: dict[str, Any] = {} + + # Feature flag is disabled by default + with pytest.raises(SupersetSecurityException) as exc_info: + AWSIAMAuthMixin._apply_iam_authentication( + mock_database, + params, + iam_config, + ) + + assert "AWS IAM database authentication is not enabled" in str(exc_info.value) + assert "AWS_DATABASE_IAM_AUTH" in str(exc_info.value) + + +@with_feature_flags(AWS_DATABASE_IAM_AUTH=True) +def test_apply_iam_authentication() -> None: + from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin, AWSIAMConfig + + mock_database = MagicMock() + mock_database.sqlalchemy_uri_decrypted = ( + "postgresql://user@mydb.cluster-xyz.us-east-1.rds.amazonaws.com:5432/mydb" + ) + + iam_config: AWSIAMConfig = { + "enabled": True, + "role_arn": "arn:aws:iam::123456789012:role/TestRole", + "region": "us-east-1", + "db_username": "superset_iam_user", + } + + params: dict[str, Any] = {} + + with ( + patch.object( + AWSIAMAuthMixin, + "get_iam_credentials", + return_value={ + "AccessKeyId": "ASIA...", + "SecretAccessKey": "secret...", + "SessionToken": "token...", + }, + ) as mock_get_creds, + patch.object( + AWSIAMAuthMixin, + "generate_rds_auth_token", + return_value="iam-auth-token", + ) as mock_gen_token, + ): + AWSIAMAuthMixin._apply_iam_authentication(mock_database, params, iam_config) + + mock_get_creds.assert_called_once_with( + role_arn="arn:aws:iam::123456789012:role/TestRole", + region="us-east-1", + external_id=None, + session_duration=3600, + ) + + mock_gen_token.assert_called_once() + token_call_kwargs = mock_gen_token.call_args[1] + assert ( + token_call_kwargs["hostname"] == "mydb.cluster-xyz.us-east-1.rds.amazonaws.com" + ) + assert token_call_kwargs["port"] == 5432 + assert token_call_kwargs["username"] == "superset_iam_user" + + assert params["connect_args"]["password"] == "iam-auth-token" # noqa: S105 + assert params["connect_args"]["user"] == "superset_iam_user" + assert params["connect_args"]["sslmode"] == "require" + + +@with_feature_flags(AWS_DATABASE_IAM_AUTH=True) +def test_apply_iam_authentication_with_external_id() -> None: + from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin, AWSIAMConfig + + mock_database = MagicMock() + mock_database.sqlalchemy_uri_decrypted = ( + "postgresql://user@mydb.us-west-2.rds.amazonaws.com:5432/mydb" + ) + + iam_config: AWSIAMConfig = { + "enabled": True, + "role_arn": "arn:aws:iam::222222222222:role/CrossAccountRole", + "external_id": "superset-prod-12345", + "region": "us-west-2", + "db_username": "iam_user", + "session_duration": 1800, + } + + params: dict[str, Any] = {} + + with ( + patch.object( + AWSIAMAuthMixin, + "get_iam_credentials", + return_value={ + "AccessKeyId": "ASIA...", + "SecretAccessKey": "secret...", + "SessionToken": "token...", + }, + ) as mock_get_creds, + patch.object( + AWSIAMAuthMixin, + "generate_rds_auth_token", + return_value="iam-auth-token", + ), + ): + AWSIAMAuthMixin._apply_iam_authentication(mock_database, params, iam_config) + + mock_get_creds.assert_called_once_with( + role_arn="arn:aws:iam::222222222222:role/CrossAccountRole", + region="us-west-2", + external_id="superset-prod-12345", + session_duration=1800, + ) + + +@with_feature_flags(AWS_DATABASE_IAM_AUTH=True) +def test_apply_iam_authentication_missing_role_arn() -> None: + from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin, AWSIAMConfig + + mock_database = MagicMock() + mock_database.sqlalchemy_uri_decrypted = ( + "postgresql://user@mydb.us-east-1.rds.amazonaws.com:5432/mydb" + ) + + iam_config: AWSIAMConfig = { + "enabled": True, + "region": "us-east-1", + "db_username": "superset_iam_user", + } + + params: dict[str, Any] = {} + + with pytest.raises(SupersetSecurityException) as exc_info: + AWSIAMAuthMixin._apply_iam_authentication(mock_database, params, iam_config) + + assert "role_arn" in str(exc_info.value) + + +@with_feature_flags(AWS_DATABASE_IAM_AUTH=True) +def test_apply_iam_authentication_missing_db_username() -> None: + from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin, AWSIAMConfig + + mock_database = MagicMock() + mock_database.sqlalchemy_uri_decrypted = ( + "postgresql://user@mydb.us-east-1.rds.amazonaws.com:5432/mydb" + ) + + iam_config: AWSIAMConfig = { + "enabled": True, + "role_arn": "arn:aws:iam::123456789012:role/TestRole", + "region": "us-east-1", + } + + params: dict[str, Any] = {} + + with pytest.raises(SupersetSecurityException) as exc_info: + AWSIAMAuthMixin._apply_iam_authentication(mock_database, params, iam_config) + + assert "db_username" in str(exc_info.value) + + +@with_feature_flags(AWS_DATABASE_IAM_AUTH=True) +def test_apply_iam_authentication_default_port() -> None: + from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin, AWSIAMConfig + + mock_database = MagicMock() + # URI without explicit port + mock_database.sqlalchemy_uri_decrypted = ( + "postgresql://user@mydb.us-east-1.rds.amazonaws.com/mydb" + ) + + iam_config: AWSIAMConfig = { + "enabled": True, + "role_arn": "arn:aws:iam::123456789012:role/TestRole", + "region": "us-east-1", + "db_username": "superset_iam_user", + } + + params: dict[str, Any] = {} + + with ( + patch.object( + AWSIAMAuthMixin, + "get_iam_credentials", + return_value={ + "AccessKeyId": "ASIA...", + "SecretAccessKey": "secret...", + "SessionToken": "token...", + }, + ), + patch.object( + AWSIAMAuthMixin, + "generate_rds_auth_token", + return_value="iam-auth-token", + ) as mock_gen_token, + ): + AWSIAMAuthMixin._apply_iam_authentication(mock_database, params, iam_config) + + # Should use default port 5432 + token_call_kwargs = mock_gen_token.call_args[1] + assert token_call_kwargs["port"] == 5432 + + +def test_get_iam_credentials_boto3_not_installed() -> None: + import builtins + + from superset.db_engine_specs.aws_iam import ( + _credentials_cache, + _credentials_lock, + AWSIAMAuthMixin, + ) + + with _credentials_lock: + _credentials_cache.clear() + + # Patch the import mechanism to simulate boto3 not being installed + real_import = builtins.__import__ + + def fake_import(name: str, *args: Any, **kwargs: Any) -> Any: + if name == "boto3" or name.startswith("boto3."): + raise ImportError("No module named 'boto3'") + return real_import(name, *args, **kwargs) + + with patch.object(builtins, "__import__", side_effect=fake_import): + with pytest.raises(SupersetSecurityException) as exc_info: + AWSIAMAuthMixin.get_iam_credentials( + role_arn="arn:aws:iam::123456789012:role/TestRole", + region="us-east-1", + ) + + assert "boto3 is required" in str(exc_info.value) + + +def test_get_iam_credentials_caching() -> None: + from superset.db_engine_specs.aws_iam import ( + _credentials_cache, + _credentials_lock, + AWSIAMAuthMixin, + ) + + mock_credentials = { + "AccessKeyId": "ASIA...", + "SecretAccessKey": "secret...", + "SessionToken": "token...", + } + + # Clear cache before test + with _credentials_lock: + _credentials_cache.clear() + + with patch("boto3.client") as mock_boto3_client: + mock_sts = MagicMock() + mock_sts.assume_role.return_value = {"Credentials": mock_credentials} + mock_boto3_client.return_value = mock_sts + + # First call should hit STS + result1 = AWSIAMAuthMixin.get_iam_credentials( + role_arn="arn:aws:iam::123456789012:role/CachedRole", + region="us-east-1", + ) + + # Second call should use cache + result2 = AWSIAMAuthMixin.get_iam_credentials( + role_arn="arn:aws:iam::123456789012:role/CachedRole", + region="us-east-1", + ) + + assert result1 == mock_credentials + assert result2 == mock_credentials + # STS should only be called once + mock_sts.assume_role.assert_called_once() + + # Clean up + with _credentials_lock: + _credentials_cache.clear() + + +def test_get_iam_credentials_cache_different_keys() -> None: + from superset.db_engine_specs.aws_iam import ( + _credentials_cache, + _credentials_lock, + AWSIAMAuthMixin, + ) + + creds_role1 = { + "AccessKeyId": "ASIA_ROLE1", + "SecretAccessKey": "secret1", + "SessionToken": "token1", + } + creds_role2 = { + "AccessKeyId": "ASIA_ROLE2", + "SecretAccessKey": "secret2", + "SessionToken": "token2", + } + + # Clear cache before test + with _credentials_lock: + _credentials_cache.clear() + + with patch("boto3.client") as mock_boto3_client: + mock_sts = MagicMock() + mock_sts.assume_role.side_effect = [ + {"Credentials": creds_role1}, + {"Credentials": creds_role2}, + ] + mock_boto3_client.return_value = mock_sts + + result1 = AWSIAMAuthMixin.get_iam_credentials( + role_arn="arn:aws:iam::111111111111:role/Role1", + region="us-east-1", + ) + result2 = AWSIAMAuthMixin.get_iam_credentials( + role_arn="arn:aws:iam::222222222222:role/Role2", + region="us-east-1", + ) + + assert result1 == creds_role1 + assert result2 == creds_role2 + # Both calls should hit STS (different cache keys) + assert mock_sts.assume_role.call_count == 2 + + # Clean up + with _credentials_lock: + _credentials_cache.clear() + + +@with_feature_flags(AWS_DATABASE_IAM_AUTH=True) +def test_apply_iam_authentication_custom_ssl_args() -> None: + from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin, AWSIAMConfig + + mock_database = MagicMock() + mock_database.sqlalchemy_uri_decrypted = ( + "mysql://user@mydb.cluster-xyz.us-east-1.rds.amazonaws.com:3306/mydb" + ) + + iam_config: AWSIAMConfig = { + "enabled": True, + "role_arn": "arn:aws:iam::123456789012:role/TestRole", + "region": "us-east-1", + "db_username": "superset_iam_user", + } + + params: dict[str, Any] = {} + + with ( + patch.object( + AWSIAMAuthMixin, + "get_iam_credentials", + return_value={ + "AccessKeyId": "ASIA...", + "SecretAccessKey": "secret...", + "SessionToken": "token...", + }, + ), + patch.object( + AWSIAMAuthMixin, + "generate_rds_auth_token", + return_value="iam-auth-token", + ), + ): + AWSIAMAuthMixin._apply_iam_authentication( + mock_database, + params, + iam_config, + ssl_args={"ssl_mode": "REQUIRED"}, + default_port=3306, + ) + + assert params["connect_args"]["ssl_mode"] == "REQUIRED" + assert "sslmode" not in params["connect_args"] + + +@with_feature_flags(AWS_DATABASE_IAM_AUTH=True) +def test_apply_iam_authentication_custom_default_port() -> None: + from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin, AWSIAMConfig + + mock_database = MagicMock() + # URI without explicit port + mock_database.sqlalchemy_uri_decrypted = ( + "mysql://user@mydb.cluster-xyz.us-east-1.rds.amazonaws.com/mydb" + ) + + iam_config: AWSIAMConfig = { + "enabled": True, + "role_arn": "arn:aws:iam::123456789012:role/TestRole", + "region": "us-east-1", + "db_username": "superset_iam_user", + } + + params: dict[str, Any] = {} + + with ( + patch.object( + AWSIAMAuthMixin, + "get_iam_credentials", + return_value={ + "AccessKeyId": "ASIA...", + "SecretAccessKey": "secret...", + "SessionToken": "token...", + }, + ), + patch.object( + AWSIAMAuthMixin, + "generate_rds_auth_token", + return_value="iam-auth-token", + ) as mock_gen_token, + ): + AWSIAMAuthMixin._apply_iam_authentication( + mock_database, + params, + iam_config, + default_port=3306, + ) + + token_call_kwargs = mock_gen_token.call_args[1] + assert token_call_kwargs["port"] == 3306 + + +def test_generate_redshift_credentials() -> None: + from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin + + credentials = { + "AccessKeyId": "ASIA...", + "SecretAccessKey": "secret...", + "SessionToken": "token...", + } + + with patch("boto3.client") as mock_boto3_client: + mock_redshift = MagicMock() + mock_redshift.get_credentials.return_value = { + "dbUser": "IAM:admin", + "dbPassword": "redshift-temp-password", + } + mock_boto3_client.return_value = mock_redshift + + db_user, db_password = AWSIAMAuthMixin.generate_redshift_credentials( + credentials=credentials, + workgroup_name="my-workgroup", + db_name="dev", + region="us-east-1", + ) + + assert db_user == "IAM:admin" + assert db_password == "redshift-temp-password" # noqa: S105 + mock_boto3_client.assert_called_once_with( + "redshift-serverless", + region_name="us-east-1", + aws_access_key_id="ASIA...", + aws_secret_access_key="secret...", # noqa: S106 + aws_session_token="token...", # noqa: S106 + ) + mock_redshift.get_credentials.assert_called_once_with( + workgroupName="my-workgroup", + dbName="dev", + ) + + +def test_generate_redshift_credentials_client_error() -> None: + from botocore.exceptions import ClientError + + from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin + + credentials = { + "AccessKeyId": "ASIA...", + "SecretAccessKey": "secret...", + "SessionToken": "token...", + } + + with patch("boto3.client") as mock_boto3_client: + mock_redshift = MagicMock() + mock_redshift.get_credentials.side_effect = ClientError( + {"Error": {"Code": "AccessDenied", "Message": "Access Denied"}}, + "GetCredentials", + ) + mock_boto3_client.return_value = mock_redshift + + with pytest.raises(SupersetSecurityException) as exc_info: + AWSIAMAuthMixin.generate_redshift_credentials( + credentials=credentials, + workgroup_name="my-workgroup", + db_name="dev", + region="us-east-1", + ) + + assert "Failed to get Redshift Serverless credentials" in str(exc_info.value) + + +@with_feature_flags(AWS_DATABASE_IAM_AUTH=True) +def test_apply_redshift_iam_authentication() -> None: + from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin, AWSIAMConfig + + mock_database = MagicMock() + mock_database.sqlalchemy_uri_decrypted = ( + "redshift+psycopg2://user@my-workgroup.123456789012.us-east-1" + ".redshift-serverless.amazonaws.com:5439/dev" + ) + + iam_config: AWSIAMConfig = { + "enabled": True, + "role_arn": "arn:aws:iam::123456789012:role/RedshiftRole", + "region": "us-east-1", + "workgroup_name": "my-workgroup", + "db_name": "dev", + } + + params: dict[str, Any] = {} + + with ( + patch.object( + AWSIAMAuthMixin, + "get_iam_credentials", + return_value={ + "AccessKeyId": "ASIA...", + "SecretAccessKey": "secret...", + "SessionToken": "token...", + }, + ) as mock_get_creds, + patch.object( + AWSIAMAuthMixin, + "generate_redshift_credentials", + return_value=("IAM:admin", "redshift-temp-password"), + ) as mock_gen_creds, + ): + AWSIAMAuthMixin._apply_redshift_iam_authentication( + mock_database, params, iam_config + ) + + mock_get_creds.assert_called_once_with( + role_arn="arn:aws:iam::123456789012:role/RedshiftRole", + region="us-east-1", + external_id=None, + session_duration=3600, + ) + + mock_gen_creds.assert_called_once_with( + credentials={ + "AccessKeyId": "ASIA...", + "SecretAccessKey": "secret...", + "SessionToken": "token...", + }, + workgroup_name="my-workgroup", + db_name="dev", + region="us-east-1", + ) + + assert params["connect_args"]["password"] == "redshift-temp-password" # noqa: S105 + assert params["connect_args"]["user"] == "IAM:admin" + assert params["connect_args"]["sslmode"] == "verify-ca" + + +@with_feature_flags(AWS_DATABASE_IAM_AUTH=True) +def test_apply_redshift_iam_authentication_missing_workgroup() -> None: + from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin, AWSIAMConfig + + mock_database = MagicMock() + mock_database.sqlalchemy_uri_decrypted = "redshift+psycopg2://user@host:5439/dev" + + iam_config: AWSIAMConfig = { + "enabled": True, + "role_arn": "arn:aws:iam::123456789012:role/RedshiftRole", + "region": "us-east-1", + "db_name": "dev", + } + + params: dict[str, Any] = {} + + with pytest.raises(SupersetSecurityException) as exc_info: + AWSIAMAuthMixin._apply_redshift_iam_authentication( + mock_database, params, iam_config + ) + + assert "workgroup_name" in str(exc_info.value) + + +@with_feature_flags(AWS_DATABASE_IAM_AUTH=True) +def test_apply_redshift_iam_authentication_missing_db_name() -> None: + from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin, AWSIAMConfig + + mock_database = MagicMock() + mock_database.sqlalchemy_uri_decrypted = "redshift+psycopg2://user@host:5439/dev" + + iam_config: AWSIAMConfig = { + "enabled": True, + "role_arn": "arn:aws:iam::123456789012:role/RedshiftRole", + "region": "us-east-1", + "workgroup_name": "my-workgroup", + } + + params: dict[str, Any] = {} + + with pytest.raises(SupersetSecurityException) as exc_info: + AWSIAMAuthMixin._apply_redshift_iam_authentication( + mock_database, params, iam_config + ) + + assert "db_name" in str(exc_info.value) + + +def test_generate_redshift_cluster_credentials() -> None: + from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin + + credentials = { + "AccessKeyId": "ASIA...", + "SecretAccessKey": "secret...", + "SessionToken": "token...", + } + + with patch("boto3.client") as mock_boto3_client: + mock_redshift = MagicMock() + mock_redshift.get_cluster_credentials.return_value = { + "DbUser": "IAM:superset_user", + "DbPassword": "redshift-cluster-temp-password", + } + mock_boto3_client.return_value = mock_redshift + + db_user, db_password = AWSIAMAuthMixin.generate_redshift_cluster_credentials( + credentials=credentials, + cluster_identifier="my-redshift-cluster", + db_user="superset_user", + db_name="analytics", + region="us-east-1", + ) + + assert db_user == "IAM:superset_user" + assert db_password == "redshift-cluster-temp-password" # noqa: S105 + mock_boto3_client.assert_called_once_with( + "redshift", + region_name="us-east-1", + aws_access_key_id="ASIA...", + aws_secret_access_key="secret...", # noqa: S106 + aws_session_token="token...", # noqa: S106 + ) + mock_redshift.get_cluster_credentials.assert_called_once_with( + ClusterIdentifier="my-redshift-cluster", + DbUser="superset_user", + DbName="analytics", + AutoCreate=False, + ) + + +def test_generate_redshift_cluster_credentials_with_auto_create() -> None: + from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin + + credentials = { + "AccessKeyId": "ASIA...", + "SecretAccessKey": "secret...", + "SessionToken": "token...", + } + + with patch("boto3.client") as mock_boto3_client: + mock_redshift = MagicMock() + mock_redshift.get_cluster_credentials.return_value = { + "DbUser": "IAM:new_user", + "DbPassword": "temp-password", + } + mock_boto3_client.return_value = mock_redshift + + AWSIAMAuthMixin.generate_redshift_cluster_credentials( + credentials=credentials, + cluster_identifier="my-cluster", + db_user="new_user", + db_name="dev", + region="us-west-2", + auto_create=True, + ) + + mock_redshift.get_cluster_credentials.assert_called_once_with( + ClusterIdentifier="my-cluster", + DbUser="new_user", + DbName="dev", + AutoCreate=True, + ) + + +def test_generate_redshift_cluster_credentials_client_error() -> None: + from botocore.exceptions import ClientError + + from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin + + credentials = { + "AccessKeyId": "ASIA...", + "SecretAccessKey": "secret...", + "SessionToken": "token...", + } + + with patch("boto3.client") as mock_boto3_client: + mock_redshift = MagicMock() + mock_redshift.get_cluster_credentials.side_effect = ClientError( + {"Error": {"Code": "ClusterNotFound", "Message": "Cluster not found"}}, + "GetClusterCredentials", + ) + mock_boto3_client.return_value = mock_redshift + + with pytest.raises(SupersetSecurityException) as exc_info: + AWSIAMAuthMixin.generate_redshift_cluster_credentials( + credentials=credentials, + cluster_identifier="nonexistent-cluster", + db_user="superset_user", + db_name="dev", + region="us-east-1", + ) + + assert "Failed to get Redshift cluster credentials" in str(exc_info.value) + + +@with_feature_flags(AWS_DATABASE_IAM_AUTH=True) +def test_apply_redshift_iam_authentication_provisioned_cluster() -> None: + from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin, AWSIAMConfig + + mock_database = MagicMock() + mock_database.sqlalchemy_uri_decrypted = ( + "redshift+psycopg2://user@my-cluster.abc123.us-east-1" + ".redshift.amazonaws.com:5439/analytics" + ) + + iam_config: AWSIAMConfig = { + "enabled": True, + "role_arn": "arn:aws:iam::123456789012:role/RedshiftRole", + "region": "us-east-1", + "cluster_identifier": "my-cluster", + "db_username": "superset_user", + "db_name": "analytics", + } + + params: dict[str, Any] = {} + + with ( + patch.object( + AWSIAMAuthMixin, + "get_iam_credentials", + return_value={ + "AccessKeyId": "ASIA...", + "SecretAccessKey": "secret...", + "SessionToken": "token...", + }, + ) as mock_get_creds, + patch.object( + AWSIAMAuthMixin, + "generate_redshift_cluster_credentials", + return_value=("IAM:superset_user", "cluster-temp-password"), + ) as mock_gen_creds, + ): + AWSIAMAuthMixin._apply_redshift_iam_authentication( + mock_database, params, iam_config + ) + + mock_get_creds.assert_called_once_with( + role_arn="arn:aws:iam::123456789012:role/RedshiftRole", + region="us-east-1", + external_id=None, + session_duration=3600, + ) + + mock_gen_creds.assert_called_once_with( + credentials={ + "AccessKeyId": "ASIA...", + "SecretAccessKey": "secret...", + "SessionToken": "token...", + }, + cluster_identifier="my-cluster", + db_user="superset_user", + db_name="analytics", + region="us-east-1", + ) + + assert params["connect_args"]["password"] == "cluster-temp-password" # noqa: S105 + assert params["connect_args"]["user"] == "IAM:superset_user" + assert params["connect_args"]["sslmode"] == "verify-ca" + + +@with_feature_flags(AWS_DATABASE_IAM_AUTH=True) +def test_apply_redshift_iam_authentication_provisioned_missing_db_username() -> None: + from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin, AWSIAMConfig + + mock_database = MagicMock() + mock_database.sqlalchemy_uri_decrypted = "redshift+psycopg2://user@host:5439/dev" + + iam_config: AWSIAMConfig = { + "enabled": True, + "role_arn": "arn:aws:iam::123456789012:role/RedshiftRole", + "region": "us-east-1", + "cluster_identifier": "my-cluster", + "db_name": "dev", + # Missing db_username - required for provisioned clusters + } + + params: dict[str, Any] = {} + + with pytest.raises(SupersetSecurityException) as exc_info: + AWSIAMAuthMixin._apply_redshift_iam_authentication( + mock_database, params, iam_config + ) + + assert "db_username" in str(exc_info.value) + + +@with_feature_flags(AWS_DATABASE_IAM_AUTH=True) +def test_apply_redshift_iam_authentication_both_workgroup_and_cluster() -> None: + from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin, AWSIAMConfig + + mock_database = MagicMock() + mock_database.sqlalchemy_uri_decrypted = "redshift+psycopg2://user@host:5439/dev" + + iam_config: AWSIAMConfig = { + "enabled": True, + "role_arn": "arn:aws:iam::123456789012:role/RedshiftRole", + "region": "us-east-1", + "workgroup_name": "my-workgroup", + "cluster_identifier": "my-cluster", + "db_name": "dev", + } + + params: dict[str, Any] = {} + + with pytest.raises(SupersetSecurityException) as exc_info: + AWSIAMAuthMixin._apply_redshift_iam_authentication( + mock_database, params, iam_config + ) + + assert "cannot have both" in str(exc_info.value) + + +@with_feature_flags(AWS_DATABASE_IAM_AUTH=True) +def test_apply_redshift_iam_authentication_neither_workgroup_nor_cluster() -> None: + from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin, AWSIAMConfig + + mock_database = MagicMock() + mock_database.sqlalchemy_uri_decrypted = "redshift+psycopg2://user@host:5439/dev" + + iam_config: AWSIAMConfig = { + "enabled": True, + "role_arn": "arn:aws:iam::123456789012:role/RedshiftRole", + "region": "us-east-1", + "db_name": "dev", + # Missing both workgroup_name and cluster_identifier + } + + params: dict[str, Any] = {} + + with pytest.raises(SupersetSecurityException) as exc_info: + AWSIAMAuthMixin._apply_redshift_iam_authentication( + mock_database, params, iam_config + ) + + assert "must include either workgroup_name" in str(exc_info.value) diff --git a/tests/unit_tests/db_engine_specs/test_mysql_iam.py b/tests/unit_tests/db_engine_specs/test_mysql_iam.py new file mode 100644 index 000000000000..9b5c25b53cf0 --- /dev/null +++ b/tests/unit_tests/db_engine_specs/test_mysql_iam.py @@ -0,0 +1,236 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# pylint: disable=import-outside-toplevel + +from __future__ import annotations + +from typing import Any +from unittest.mock import MagicMock, patch + +import pytest + +from superset.utils import json +from tests.unit_tests.conftest import with_feature_flags + + +def test_mysql_encrypted_extra_sensitive_fields() -> None: + from superset.db_engine_specs.mysql import MySQLEngineSpec + + assert "$.aws_iam.external_id" in MySQLEngineSpec.encrypted_extra_sensitive_fields + assert "$.aws_iam.role_arn" in MySQLEngineSpec.encrypted_extra_sensitive_fields + + +def test_mysql_update_params_no_encrypted_extra() -> None: + from superset.db_engine_specs.mysql import MySQLEngineSpec + + database = MagicMock() + database.encrypted_extra = None + + params: dict[str, Any] = {} + MySQLEngineSpec.update_params_from_encrypted_extra(database, params) + + assert params == {} + + +def test_mysql_update_params_empty_encrypted_extra() -> None: + from superset.db_engine_specs.mysql import MySQLEngineSpec + + database = MagicMock() + database.encrypted_extra = json.dumps({}) + + params: dict[str, Any] = {} + MySQLEngineSpec.update_params_from_encrypted_extra(database, params) + + assert params == {} + + +def test_mysql_update_params_iam_disabled() -> None: + from superset.db_engine_specs.mysql import MySQLEngineSpec + + database = MagicMock() + database.encrypted_extra = json.dumps( + { + "aws_iam": { + "enabled": False, + "role_arn": "arn:aws:iam::123456789012:role/TestRole", + "region": "us-east-1", + "db_username": "superset_user", + } + } + ) + + params: dict[str, Any] = {} + MySQLEngineSpec.update_params_from_encrypted_extra(database, params) + + assert params == {} + + +@with_feature_flags(AWS_DATABASE_IAM_AUTH=True) +def test_mysql_update_params_with_iam() -> None: + from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin + from superset.db_engine_specs.mysql import MySQLEngineSpec + + database = MagicMock() + database.encrypted_extra = json.dumps( + { + "aws_iam": { + "enabled": True, + "role_arn": "arn:aws:iam::123456789012:role/TestRole", + "region": "us-east-1", + "db_username": "superset_iam_user", + } + } + ) + database.sqlalchemy_uri_decrypted = ( + "mysql://user@mydb.cluster-xyz.us-east-1.rds.amazonaws.com:3306/mydb" + ) + + params: dict[str, Any] = {} + + with ( + patch.object( + AWSIAMAuthMixin, + "get_iam_credentials", + return_value={ + "AccessKeyId": "ASIA...", + "SecretAccessKey": "secret...", + "SessionToken": "token...", + }, + ), + patch.object( + AWSIAMAuthMixin, + "generate_rds_auth_token", + return_value="iam-auth-token", + ), + ): + MySQLEngineSpec.update_params_from_encrypted_extra(database, params) + + assert "connect_args" in params + assert params["connect_args"]["password"] == "iam-auth-token" # noqa: S105 + assert params["connect_args"]["user"] == "superset_iam_user" + # Note: ssl_mode is not set because MySQL drivers don't support it. + # SSL should be configured via the database's extra settings. + + +@with_feature_flags(AWS_DATABASE_IAM_AUTH=True) +def test_mysql_update_params_iam_uses_mysql_port() -> None: + from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin + from superset.db_engine_specs.mysql import MySQLEngineSpec + + database = MagicMock() + database.encrypted_extra = json.dumps( + { + "aws_iam": { + "enabled": True, + "role_arn": "arn:aws:iam::123456789012:role/TestRole", + "region": "us-east-1", + "db_username": "superset_iam_user", + } + } + ) + # URI without explicit port + database.sqlalchemy_uri_decrypted = ( + "mysql://user@mydb.cluster-xyz.us-east-1.rds.amazonaws.com/mydb" + ) + + params: dict[str, Any] = {} + + with ( + patch.object( + AWSIAMAuthMixin, + "get_iam_credentials", + return_value={ + "AccessKeyId": "ASIA...", + "SecretAccessKey": "secret...", + "SessionToken": "token...", + }, + ), + patch.object( + AWSIAMAuthMixin, + "generate_rds_auth_token", + return_value="iam-auth-token", + ) as mock_gen_token, + ): + MySQLEngineSpec.update_params_from_encrypted_extra(database, params) + + # Should use default MySQL port 3306 + token_call_kwargs = mock_gen_token.call_args[1] + assert token_call_kwargs["port"] == 3306 + + +def test_mysql_update_params_merges_remaining_encrypted_extra() -> None: + from superset.db_engine_specs.mysql import MySQLEngineSpec + + database = MagicMock() + database.encrypted_extra = json.dumps( + { + "aws_iam": {"enabled": False}, + "pool_size": 10, + } + ) + + params: dict[str, Any] = {} + MySQLEngineSpec.update_params_from_encrypted_extra(database, params) + + assert "aws_iam" not in params + assert params["pool_size"] == 10 + + +def test_mysql_update_params_invalid_json() -> None: + from superset.db_engine_specs.mysql import MySQLEngineSpec + + database = MagicMock() + database.encrypted_extra = "not-valid-json" + + params: dict[str, Any] = {} + + with pytest.raises(json.JSONDecodeError): + MySQLEngineSpec.update_params_from_encrypted_extra(database, params) + + +def test_mysql_mask_encrypted_extra() -> None: + from superset.db_engine_specs.mysql import MySQLEngineSpec + + encrypted_extra = json.dumps( + { + "aws_iam": { + "enabled": True, + "role_arn": "arn:aws:iam::123456789012:role/SecretRole", + "external_id": "secret-external-id-12345", + "region": "us-east-1", + "db_username": "superset_user", + } + } + ) + + masked = MySQLEngineSpec.mask_encrypted_extra(encrypted_extra) + assert masked is not None + + masked_config = json.loads(masked) + + # role_arn and external_id should be masked + assert ( + masked_config["aws_iam"]["role_arn"] + != "arn:aws:iam::123456789012:role/SecretRole" + ) + assert masked_config["aws_iam"]["external_id"] != "secret-external-id-12345" + + # Non-sensitive fields should remain unchanged + assert masked_config["aws_iam"]["enabled"] is True + assert masked_config["aws_iam"]["region"] == "us-east-1" + assert masked_config["aws_iam"]["db_username"] == "superset_user" diff --git a/tests/unit_tests/db_engine_specs/test_redshift_iam.py b/tests/unit_tests/db_engine_specs/test_redshift_iam.py new file mode 100644 index 000000000000..49657bad8911 --- /dev/null +++ b/tests/unit_tests/db_engine_specs/test_redshift_iam.py @@ -0,0 +1,387 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# pylint: disable=import-outside-toplevel + +from __future__ import annotations + +from typing import Any +from unittest.mock import MagicMock, patch + +import pytest + +from superset.utils import json +from tests.unit_tests.conftest import with_feature_flags + + +def test_redshift_encrypted_extra_sensitive_fields() -> None: + from superset.db_engine_specs.redshift import RedshiftEngineSpec + + assert ( + "$.aws_iam.external_id" in RedshiftEngineSpec.encrypted_extra_sensitive_fields + ) + assert "$.aws_iam.role_arn" in RedshiftEngineSpec.encrypted_extra_sensitive_fields + + +def test_redshift_update_params_no_encrypted_extra() -> None: + from superset.db_engine_specs.redshift import RedshiftEngineSpec + + database = MagicMock() + database.encrypted_extra = None + + params: dict[str, Any] = {} + RedshiftEngineSpec.update_params_from_encrypted_extra(database, params) + + assert params == {} + + +def test_redshift_update_params_empty_encrypted_extra() -> None: + from superset.db_engine_specs.redshift import RedshiftEngineSpec + + database = MagicMock() + database.encrypted_extra = json.dumps({}) + + params: dict[str, Any] = {} + RedshiftEngineSpec.update_params_from_encrypted_extra(database, params) + + assert params == {} + + +def test_redshift_update_params_iam_disabled() -> None: + from superset.db_engine_specs.redshift import RedshiftEngineSpec + + database = MagicMock() + database.encrypted_extra = json.dumps( + { + "aws_iam": { + "enabled": False, + "role_arn": "arn:aws:iam::123456789012:role/TestRole", + "region": "us-east-1", + "workgroup_name": "my-workgroup", + "db_name": "dev", + } + } + ) + + params: dict[str, Any] = {} + RedshiftEngineSpec.update_params_from_encrypted_extra(database, params) + + assert params == {} + + +@with_feature_flags(AWS_DATABASE_IAM_AUTH=True) +def test_redshift_update_params_with_iam() -> None: + from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin + from superset.db_engine_specs.redshift import RedshiftEngineSpec + + database = MagicMock() + database.encrypted_extra = json.dumps( + { + "aws_iam": { + "enabled": True, + "role_arn": "arn:aws:iam::123456789012:role/RedshiftRole", + "region": "us-east-1", + "workgroup_name": "my-workgroup", + "db_name": "dev", + } + } + ) + database.sqlalchemy_uri_decrypted = ( + "redshift+psycopg2://user@my-workgroup.123456789012.us-east-1" + ".redshift-serverless.amazonaws.com:5439/dev" + ) + + params: dict[str, Any] = {} + + with ( + patch.object( + AWSIAMAuthMixin, + "get_iam_credentials", + return_value={ + "AccessKeyId": "ASIA...", + "SecretAccessKey": "secret...", + "SessionToken": "token...", + }, + ), + patch.object( + AWSIAMAuthMixin, + "generate_redshift_credentials", + return_value=("IAM:admin", "redshift-temp-password"), + ), + ): + RedshiftEngineSpec.update_params_from_encrypted_extra(database, params) + + assert "connect_args" in params + assert params["connect_args"]["password"] == "redshift-temp-password" # noqa: S105 + assert params["connect_args"]["user"] == "IAM:admin" + assert params["connect_args"]["sslmode"] == "verify-ca" + + +@with_feature_flags(AWS_DATABASE_IAM_AUTH=True) +def test_redshift_update_params_with_external_id() -> None: + from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin + from superset.db_engine_specs.redshift import RedshiftEngineSpec + + database = MagicMock() + database.encrypted_extra = json.dumps( + { + "aws_iam": { + "enabled": True, + "role_arn": "arn:aws:iam::222222222222:role/CrossAccountRedshift", + "external_id": "superset-prod-12345", + "region": "us-west-2", + "workgroup_name": "prod-workgroup", + "db_name": "analytics", + "session_duration": 1800, + } + } + ) + database.sqlalchemy_uri_decrypted = ( + "redshift+psycopg2://user@prod-workgroup.222222222222.us-west-2" + ".redshift-serverless.amazonaws.com:5439/analytics" + ) + + params: dict[str, Any] = {} + + with ( + patch.object( + AWSIAMAuthMixin, + "get_iam_credentials", + return_value={ + "AccessKeyId": "ASIA...", + "SecretAccessKey": "secret...", + "SessionToken": "token...", + }, + ) as mock_get_creds, + patch.object( + AWSIAMAuthMixin, + "generate_redshift_credentials", + return_value=("IAM:admin", "redshift-temp-password"), + ), + ): + RedshiftEngineSpec.update_params_from_encrypted_extra(database, params) + + mock_get_creds.assert_called_once_with( + role_arn="arn:aws:iam::222222222222:role/CrossAccountRedshift", + region="us-west-2", + external_id="superset-prod-12345", + session_duration=1800, + ) + + +def test_redshift_update_params_merges_remaining_encrypted_extra() -> None: + from superset.db_engine_specs.redshift import RedshiftEngineSpec + + database = MagicMock() + database.encrypted_extra = json.dumps( + { + "aws_iam": {"enabled": False}, + "pool_size": 5, + } + ) + + params: dict[str, Any] = {} + RedshiftEngineSpec.update_params_from_encrypted_extra(database, params) + + assert "aws_iam" not in params + assert params["pool_size"] == 5 + + +def test_redshift_update_params_invalid_json() -> None: + from superset.db_engine_specs.redshift import RedshiftEngineSpec + + database = MagicMock() + database.encrypted_extra = "not-valid-json" + + params: dict[str, Any] = {} + + with pytest.raises(json.JSONDecodeError): + RedshiftEngineSpec.update_params_from_encrypted_extra(database, params) + + +def test_redshift_mask_encrypted_extra() -> None: + from superset.db_engine_specs.redshift import RedshiftEngineSpec + + encrypted_extra = json.dumps( + { + "aws_iam": { + "enabled": True, + "role_arn": "arn:aws:iam::123456789012:role/SecretRole", + "external_id": "secret-external-id-12345", + "region": "us-east-1", + "workgroup_name": "my-workgroup", + "db_name": "dev", + } + } + ) + + masked = RedshiftEngineSpec.mask_encrypted_extra(encrypted_extra) + assert masked is not None + + masked_config = json.loads(masked) + + # role_arn and external_id should be masked + assert ( + masked_config["aws_iam"]["role_arn"] + != "arn:aws:iam::123456789012:role/SecretRole" + ) + assert masked_config["aws_iam"]["external_id"] != "secret-external-id-12345" + + # Non-sensitive fields should remain unchanged + assert masked_config["aws_iam"]["enabled"] is True + assert masked_config["aws_iam"]["region"] == "us-east-1" + assert masked_config["aws_iam"]["workgroup_name"] == "my-workgroup" + assert masked_config["aws_iam"]["db_name"] == "dev" + + +@with_feature_flags(AWS_DATABASE_IAM_AUTH=True) +def test_redshift_update_params_with_iam_provisioned_cluster() -> None: + from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin + from superset.db_engine_specs.redshift import RedshiftEngineSpec + + database = MagicMock() + database.encrypted_extra = json.dumps( + { + "aws_iam": { + "enabled": True, + "role_arn": "arn:aws:iam::123456789012:role/RedshiftRole", + "region": "us-east-1", + "cluster_identifier": "my-redshift-cluster", + "db_username": "superset_user", + "db_name": "analytics", + } + } + ) + database.sqlalchemy_uri_decrypted = ( + "redshift+psycopg2://user@my-redshift-cluster.abc123.us-east-1" + ".redshift.amazonaws.com:5439/analytics" + ) + + params: dict[str, Any] = {} + + with ( + patch.object( + AWSIAMAuthMixin, + "get_iam_credentials", + return_value={ + "AccessKeyId": "ASIA...", + "SecretAccessKey": "secret...", + "SessionToken": "token...", + }, + ), + patch.object( + AWSIAMAuthMixin, + "generate_redshift_cluster_credentials", + return_value=("IAM:superset_user", "cluster-temp-password"), + ), + ): + RedshiftEngineSpec.update_params_from_encrypted_extra(database, params) + + assert "connect_args" in params + assert params["connect_args"]["password"] == "cluster-temp-password" # noqa: S105 + assert params["connect_args"]["user"] == "IAM:superset_user" + assert params["connect_args"]["sslmode"] == "verify-ca" + + +@with_feature_flags(AWS_DATABASE_IAM_AUTH=True) +def test_redshift_update_params_provisioned_cluster_with_external_id() -> None: + from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin + from superset.db_engine_specs.redshift import RedshiftEngineSpec + + database = MagicMock() + database.encrypted_extra = json.dumps( + { + "aws_iam": { + "enabled": True, + "role_arn": "arn:aws:iam::222222222222:role/CrossAccountRedshift", + "external_id": "superset-prod-12345", + "region": "us-west-2", + "cluster_identifier": "prod-cluster", + "db_username": "analytics_user", + "db_name": "prod_db", + "session_duration": 1800, + } + } + ) + database.sqlalchemy_uri_decrypted = ( + "redshift+psycopg2://user@prod-cluster.xyz789.us-west-2" + ".redshift.amazonaws.com:5439/prod_db" + ) + + params: dict[str, Any] = {} + + with ( + patch.object( + AWSIAMAuthMixin, + "get_iam_credentials", + return_value={ + "AccessKeyId": "ASIA...", + "SecretAccessKey": "secret...", + "SessionToken": "token...", + }, + ) as mock_get_creds, + patch.object( + AWSIAMAuthMixin, + "generate_redshift_cluster_credentials", + return_value=("IAM:analytics_user", "cluster-temp-password"), + ), + ): + RedshiftEngineSpec.update_params_from_encrypted_extra(database, params) + + mock_get_creds.assert_called_once_with( + role_arn="arn:aws:iam::222222222222:role/CrossAccountRedshift", + region="us-west-2", + external_id="superset-prod-12345", + session_duration=1800, + ) + + +def test_redshift_mask_encrypted_extra_provisioned_cluster() -> None: + from superset.db_engine_specs.redshift import RedshiftEngineSpec + + encrypted_extra = json.dumps( + { + "aws_iam": { + "enabled": True, + "role_arn": "arn:aws:iam::123456789012:role/SecretRole", + "external_id": "secret-external-id-12345", + "region": "us-east-1", + "cluster_identifier": "my-cluster", + "db_username": "superset_user", + "db_name": "analytics", + } + } + ) + + masked = RedshiftEngineSpec.mask_encrypted_extra(encrypted_extra) + assert masked is not None + + masked_config = json.loads(masked) + + # role_arn and external_id should be masked + assert ( + masked_config["aws_iam"]["role_arn"] + != "arn:aws:iam::123456789012:role/SecretRole" + ) + assert masked_config["aws_iam"]["external_id"] != "secret-external-id-12345" + + # Non-sensitive fields should remain unchanged + assert masked_config["aws_iam"]["enabled"] is True + assert masked_config["aws_iam"]["region"] == "us-east-1" + assert masked_config["aws_iam"]["cluster_identifier"] == "my-cluster" + assert masked_config["aws_iam"]["db_username"] == "superset_user" + assert masked_config["aws_iam"]["db_name"] == "analytics"