Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions superset/db_engine_specs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -923,6 +923,21 @@ def get_configuration_for_impersonation( # pylint: disable=invalid-name
"""
return {}

@classmethod
def get_connect_args_for_impersonation(
cls, uri: str, impersonate_user: bool, username: Optional[str]
) -> Dict[str, str]:
"""
Return a configuration dictionary that can be merged with other configs
that can set the correct properties for impersonating users

:param uri: URI
:param impersonate_user: Flag indicating if impersonation is enabled
:param username: Effective username
:return: Configs required for impersonation
"""
return {}

@classmethod
def execute(cls, cursor: Any, query: str, **kwargs: Any) -> None:
"""
Expand Down
93 changes: 79 additions & 14 deletions superset/db_engine_specs/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import re
import textwrap
import time
import requests
from collections import defaultdict, deque
from contextlib import closing
from datetime import datetime
Expand All @@ -36,6 +37,7 @@
from sqlalchemy.engine.url import URL
from sqlalchemy.orm import Session
from sqlalchemy.sql.expression import ColumnClause, Select
from sqlalchemy.engine.url import make_url, URL

from superset import app, cache_manager, is_feature_enabled
from superset.db_engine_specs.base import BaseEngineSpec
Expand Down Expand Up @@ -136,6 +138,30 @@ def get_allow_cost_estimate(cls, extra: Dict[str, Any]) -> bool:
version = extra.get("version")
return version is not None and StrictVersion(version) >= StrictVersion("0.319")


@classmethod
def get_connect_args_for_impersonation(
cls, uri: str, impersonate_user: bool, username: Optional[str]
) -> Dict[str, str]:
"""
Return a configuration dictionary that can be merged with other configs
that can set the correct properties for impersonating users
:param uri: URI string
:param impersonate_user: Flag indicating if impersonation is enabled
:param username: Effective username
:return: Configs required for impersonation
"""
configuration = {}
url = make_url(uri)
backend_name = url.get_backend_name()

# Must be Presto connection, enable impersonation, and set optional param
# auth=LDAP|KERBEROS
if backend_name == "presto" and impersonate_user and username is not None:
configuration["principal_username"] = username
return configuration


@classmethod
def get_table_names(
cls, database: "Database", inspector: Inspector, schema: Optional[str]
Expand Down Expand Up @@ -173,9 +199,9 @@ def get_view_names(

engine = cls.get_engine(database, schema=schema)
with closing(engine.raw_connection()) as conn:
cursor = conn.cursor()
cursor.execute(sql, params)
results = cursor.fetchall()
with closing(conn.cursor()) as cursor:
cursor.execute(sql, params)
results = cursor.fetchall()

return [row[0] for row in results]

Expand Down Expand Up @@ -515,6 +541,45 @@ def estimate_statement_cost( # pylint: disable=too-many-locals
result = json.loads(cursor.fetchone()[0])
return result

@classmethod
def download_statement_result( # pylint: disable=too-many-locals
cls, statement: str, username: str
) -> Dict[str, Any]:
"""
Run a SQL query that estimates the cost of a given statement.

:param statement: A single SQL statement
:param database: Database instance
:param cursor: Cursor instance
:param username: Effective username
:return: JSON response from Presto
"""
url = config["BIFROST_DOWNLOAD_API_URL"]
token = config["BIFROST_API_TOKEN"]
download_format = config["BIFROST_DOWNLOAD_FILE_FORMAT"]
download_compression = config["BIFROST_DOWNLOAD_FILE_COMPRESSION"]

payload = {
"query": statement,
"engine": "PRESTO",
"username" : username,
"output" : {
"format": download_format,
"compression": download_compression
}
}
# Adding empty header as parameters are being sent in payload
headers = {
"Authorization" : token,
"Content-Type": "application/json"
}
r = requests.post(url, data=json.dumps(payload), headers=headers)
content = json.loads(r.content)
if r.ok:
return {"Request response" : "Successfully submitted. Please visit \"Download History\" under SQL Lab menu for status."}
else:
raise Exception(content["message"])

@classmethod
def query_cost_formatter(
cls, raw_cost: List[Dict[str, Any]]
Expand Down Expand Up @@ -758,18 +823,18 @@ def get_create_view(

engine = cls.get_engine(database, schema)
with closing(engine.raw_connection()) as conn:
cursor = conn.cursor()
sql = f"SHOW CREATE VIEW {schema}.{table}"
try:
cls.execute(cursor, sql)
polled = cursor.poll()

while polled:
time.sleep(0.2)
with closing(conn.cursor()) as cursor:
sql = f"SHOW CREATE VIEW {schema}.{table}"
try:
cls.execute(cursor, sql)
polled = cursor.poll()
except DatabaseError: # not a VIEW
return None
rows = cls.fetch_data(cursor, 1)

while polled:
time.sleep(0.2)
polled = cursor.poll()
except DatabaseError: # not a VIEW
return None
rows = cls.fetch_data(cursor, 1)
return rows[0][0]

@classmethod
Expand Down
8 changes: 8 additions & 0 deletions superset/models/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,14 @@ def get_sqla_engine(
str(sqlalchemy_url), self.impersonate_user, effective_username
)
)

# If using presto, this will set principal_username=$effective_username
connect_args.update(
self.db_engine_spec.get_connect_args_for_impersonation(
str(sqlalchemy_url), self.impersonate_user, effective_username
)
)

if configuration:
connect_args["configuration"] = configuration
if connect_args:
Expand Down