Skip to content
12 changes: 6 additions & 6 deletions superset/db_engine_specs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -909,19 +909,19 @@ def modify_url_for_impersonation(
url.username = username

@classmethod
def get_configuration_for_impersonation( # pylint: disable=invalid-name
cls, uri: str, impersonate_user: bool, username: Optional[str]
) -> Dict[str, str]:
def update_impersonation_config(
cls, connect_args: Dict[str, Any], uri: str, username: Optional[str],
) -> None:
"""
Return a configuration dictionary that can be merged with other configs
Update a configuration dictionary
that can set the correct properties for impersonating users

:param connect_args: config to be updated
:param uri: URI
:param impersonate_user: Flag indicating if impersonation is enabled
:param username: Effective username
:return: Configs required for impersonation
:return: None
"""
return {}

@classmethod
def execute(cls, cursor: Any, query: str, **kwargs: Any) -> None:
Expand Down
18 changes: 10 additions & 8 deletions superset/db_engine_specs/hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,26 +487,28 @@ def modify_url_for_impersonation(
# the configuraiton dictionary. See get_configuration_for_impersonation

@classmethod
def get_configuration_for_impersonation(
cls, uri: str, impersonate_user: bool, username: Optional[str]
) -> Dict[str, str]:
def update_impersonation_config(
cls, connect_args: Dict[str, Any], uri: str, username: Optional[str],
) -> None:
"""
Return a configuration dictionary that can be merged with other configs
Update a configuration dictionary
that can set the correct properties for impersonating users
:param connect_args:
:param uri: URI string
:param impersonate_user: Flag indicating if impersonation is enabled
:param username: Effective username
:return: Configs required for impersonation
:return: None
"""
configuration = {}
url = make_url(uri)
backend_name = url.get_backend_name()

# Must be Hive connection, enable impersonation, and set optional param
# auth=LDAP|KERBEROS
if backend_name == "hive" and impersonate_user and username is not None:
# this will set hive.server2.proxy.user=$effective_username on connect_args['configuration']
if backend_name == "hive" and username is not None:
configuration = connect_args.get("configuration", {})
configuration["hive.server2.proxy.user"] = username
return configuration
connect_args["configuration"] = configuration

@staticmethod
def execute( # type: ignore
Expand Down
24 changes: 23 additions & 1 deletion superset/db_engine_specs/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from sqlalchemy.engine.base import Engine
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.engine.result import RowProxy
from sqlalchemy.engine.url import URL
from sqlalchemy.engine.url import make_url, URL
from sqlalchemy.orm import Session
from sqlalchemy.sql.expression import ColumnClause, Select

Expand Down Expand Up @@ -136,6 +136,28 @@ def get_allow_cost_estimate(cls, extra: Dict[str, Any]) -> bool:
version = extra.get("version")
return version is not None and StrictVersion(version) >= StrictVersion("0.319")

@classmethod
def update_impersonation_config(
cls, connect_args: Dict[str, Any], uri: str, username: Optional[str],
) -> None:
"""
Update a configuration dictionary
that can set the correct properties for impersonating users
:param connect_args: config to be updated
:param uri: URI string
:param impersonate_user: Flag indicating if impersonation is enabled
:param username: Effective username
:return: None
"""
url = make_url(uri)
backend_name = url.get_backend_name()

# Must be Presto connection, enable impersonation, and set optional param
# auth=LDAP|KERBEROS
# Set principal_username=$effective_username
if backend_name == "presto" and username is not None:
connect_args["principal_username"] = username

@classmethod
def get_table_names(
cls, database: "Database", inspector: Inspector, schema: Optional[str]
Expand Down
13 changes: 4 additions & 9 deletions superset/models/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,16 +325,11 @@ def get_sqla_engine(
params["poolclass"] = NullPool

connect_args = params.get("connect_args", {})
configuration = connect_args.get("configuration", {})

# If using Hive, this will set hive.server2.proxy.user=$effective_username
configuration.update(
self.db_engine_spec.get_configuration_for_impersonation(
str(sqlalchemy_url), self.impersonate_user, effective_username
if self.impersonate_user:
self.db_engine_spec.update_impersonation_config(
connect_args, str(sqlalchemy_url), effective_username
)
)
if configuration:
connect_args["configuration"] = configuration

if connect_args:
params["connect_args"] = connect_args

Expand Down
93 changes: 93 additions & 0 deletions tests/model_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# isort:skip_file
import textwrap
import unittest
from unittest import mock
from tests.fixtures.birth_names_dashboard import load_birth_names_dashboard_with_slices

import pandas
Expand Down Expand Up @@ -110,6 +111,98 @@ def test_database_impersonate_user(self):
user_name = make_url(model.get_sqla_engine(user_name=example_user).url).username
self.assertNotEqual(example_user, user_name)

@mock.patch("superset.models.core.create_engine")
def test_impersonate_user_presto(self, mocked_create_engine):
uri = "presto://localhost"
principal_user = "logged_in_user"
extra = """
{
"metadata_params": {},
"engine_params": {
"connect_args":{
"protocol": "https",
"username":"original_user",
"password":"original_user_password"
}
},
"metadata_cache_timeout": {},
"schemas_allowed_for_csv_upload": []
}
"""

model = Database(database_name="test_database", sqlalchemy_uri=uri, extra=extra)

model.impersonate_user = True
model.get_sqla_engine(user_name=principal_user)
call_args = mocked_create_engine.call_args

assert str(call_args[0][0]) == "presto://logged_in_user@localhost"

assert call_args[1]["connect_args"] == {
"protocol": "https",
"username": "original_user",
"password": "original_user_password",
"principal_username": "logged_in_user",
}

model.impersonate_user = False
model.get_sqla_engine(user_name=principal_user)
call_args = mocked_create_engine.call_args

assert str(call_args[0][0]) == "presto://localhost"

assert call_args[1]["connect_args"] == {
"protocol": "https",
"username": "original_user",
"password": "original_user_password",
}

@mock.patch("superset.models.core.create_engine")
def test_impersonate_user_hive(self, mocked_create_engine):
uri = "hive://localhost"
principal_user = "logged_in_user"
extra = """
{
"metadata_params": {},
"engine_params": {
"connect_args":{
"protocol": "https",
"username":"original_user",
"password":"original_user_password"
}
},
"metadata_cache_timeout": {},
"schemas_allowed_for_csv_upload": []
}
"""

model = Database(database_name="test_database", sqlalchemy_uri=uri, extra=extra)

model.impersonate_user = True
model.get_sqla_engine(user_name=principal_user)
call_args = mocked_create_engine.call_args

assert str(call_args[0][0]) == "hive://localhost"

assert call_args[1]["connect_args"] == {
"protocol": "https",
"username": "original_user",
"password": "original_user_password",
"configuration": {"hive.server2.proxy.user": "logged_in_user"},
}

model.impersonate_user = False
model.get_sqla_engine(user_name=principal_user)
call_args = mocked_create_engine.call_args

assert str(call_args[0][0]) == "hive://localhost"

assert call_args[1]["connect_args"] == {
"protocol": "https",
"username": "original_user",
"password": "original_user_password",
}

@pytest.mark.usefixtures("load_energy_table_with_slice")
def test_select_star(self):
db = get_example_database()
Expand Down