diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 1e560e6ed5c4..ea566ac0a9f4 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -923,6 +923,21 @@ def get_configuration_for_impersonation( # pylint: disable=invalid-name """ return {} + @classmethod + def get_connect_args_for_impersonation( + cls, uri: str, impersonate_user: bool, username: Optional[str] + ) -> Dict[str, str]: + """ + Return a configuration dictionary that can be merged with other configs + that can set the correct properties for impersonating users + + :param uri: URI + :param impersonate_user: Flag indicating if impersonation is enabled + :param username: Effective username + :return: Configs required for impersonation + """ + return {} + @classmethod def execute(cls, cursor: Any, query: str, **kwargs: Any) -> None: """ diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py index 071fd885f8d6..d81a85b59b00 100644 --- a/superset/db_engine_specs/presto.py +++ b/superset/db_engine_specs/presto.py @@ -19,6 +19,7 @@ import re import textwrap import time +import requests from collections import defaultdict, deque from contextlib import closing from datetime import datetime @@ -36,6 +37,7 @@ from sqlalchemy.engine.url import URL from sqlalchemy.orm import Session from sqlalchemy.sql.expression import ColumnClause, Select +from sqlalchemy.engine.url import make_url, URL from superset import app, cache_manager, is_feature_enabled from superset.db_engine_specs.base import BaseEngineSpec @@ -136,6 +138,30 @@ def get_allow_cost_estimate(cls, extra: Dict[str, Any]) -> bool: version = extra.get("version") return version is not None and StrictVersion(version) >= StrictVersion("0.319") + + @classmethod + def get_connect_args_for_impersonation( + cls, uri: str, impersonate_user: bool, username: Optional[str] + ) -> Dict[str, str]: + """ + Return a configuration dictionary that can be merged with other configs + that can set the correct properties for impersonating users + :param uri: URI string + :param impersonate_user: Flag indicating if impersonation is enabled + :param username: Effective username + :return: Configs required for impersonation + """ + configuration = {} + url = make_url(uri) + backend_name = url.get_backend_name() + + # Must be Presto connection, enable impersonation, and set optional param + # auth=LDAP|KERBEROS + if backend_name == "presto" and impersonate_user and username is not None: + configuration["principal_username"] = username + return configuration + + @classmethod def get_table_names( cls, database: "Database", inspector: Inspector, schema: Optional[str] @@ -173,9 +199,9 @@ def get_view_names( engine = cls.get_engine(database, schema=schema) with closing(engine.raw_connection()) as conn: - cursor = conn.cursor() - cursor.execute(sql, params) - results = cursor.fetchall() + with closing(conn.cursor()) as cursor: + cursor.execute(sql, params) + results = cursor.fetchall() return [row[0] for row in results] @@ -515,6 +541,45 @@ def estimate_statement_cost( # pylint: disable=too-many-locals result = json.loads(cursor.fetchone()[0]) return result + @classmethod + def download_statement_result( # pylint: disable=too-many-locals + cls, statement: str, username: str + ) -> Dict[str, Any]: + """ + Run a SQL query that estimates the cost of a given statement. + + :param statement: A single SQL statement + :param database: Database instance + :param cursor: Cursor instance + :param username: Effective username + :return: JSON response from Presto + """ + url = config["BIFROST_DOWNLOAD_API_URL"] + token = config["BIFROST_API_TOKEN"] + download_format = config["BIFROST_DOWNLOAD_FILE_FORMAT"] + download_compression = config["BIFROST_DOWNLOAD_FILE_COMPRESSION"] + + payload = { + "query": statement, + "engine": "PRESTO", + "username" : username, + "output" : { + "format": download_format, + "compression": download_compression + } + } + # Adding empty header as parameters are being sent in payload + headers = { + "Authorization" : token, + "Content-Type": "application/json" + } + r = requests.post(url, data=json.dumps(payload), headers=headers) + content = json.loads(r.content) + if r.ok: + return {"Request response" : "Successfully submitted. Please visit \"Download History\" under SQL Lab menu for status."} + else: + raise Exception(content["message"]) + @classmethod def query_cost_formatter( cls, raw_cost: List[Dict[str, Any]] @@ -758,18 +823,18 @@ def get_create_view( engine = cls.get_engine(database, schema) with closing(engine.raw_connection()) as conn: - cursor = conn.cursor() - sql = f"SHOW CREATE VIEW {schema}.{table}" - try: - cls.execute(cursor, sql) - polled = cursor.poll() - - while polled: - time.sleep(0.2) + with closing(conn.cursor()) as cursor: + sql = f"SHOW CREATE VIEW {schema}.{table}" + try: + cls.execute(cursor, sql) polled = cursor.poll() - except DatabaseError: # not a VIEW - return None - rows = cls.fetch_data(cursor, 1) + + while polled: + time.sleep(0.2) + polled = cursor.poll() + except DatabaseError: # not a VIEW + return None + rows = cls.fetch_data(cursor, 1) return rows[0][0] @classmethod diff --git a/superset/models/core.py b/superset/models/core.py index 5d0dde3dce9b..b656c46d6fd4 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -333,6 +333,14 @@ def get_sqla_engine( str(sqlalchemy_url), self.impersonate_user, effective_username ) ) + + # If using presto, this will set principal_username=$effective_username + connect_args.update( + self.db_engine_spec.get_connect_args_for_impersonation( + str(sqlalchemy_url), self.impersonate_user, effective_username + ) + ) + if configuration: connect_args["configuration"] = configuration if connect_args: