diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 1e560e6ed5c4..1796c2c55733 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -909,19 +909,19 @@ def modify_url_for_impersonation( url.username = username @classmethod - def get_configuration_for_impersonation( # pylint: disable=invalid-name - cls, uri: str, impersonate_user: bool, username: Optional[str] - ) -> Dict[str, str]: + def update_impersonation_config( + cls, connect_args: Dict[str, Any], uri: str, username: Optional[str], + ) -> None: """ - Return a configuration dictionary that can be merged with other configs + Update a configuration dictionary that can set the correct properties for impersonating users + :param connect_args: config to be updated :param uri: URI :param impersonate_user: Flag indicating if impersonation is enabled :param username: Effective username - :return: Configs required for impersonation + :return: None """ - return {} @classmethod def execute(cls, cursor: Any, query: str, **kwargs: Any) -> None: diff --git a/superset/db_engine_specs/hive.py b/superset/db_engine_specs/hive.py index 72cf93c79d20..51bedbee3569 100644 --- a/superset/db_engine_specs/hive.py +++ b/superset/db_engine_specs/hive.py @@ -487,26 +487,28 @@ def modify_url_for_impersonation( # the configuraiton dictionary. See get_configuration_for_impersonation @classmethod - def get_configuration_for_impersonation( - cls, uri: str, impersonate_user: bool, username: Optional[str] - ) -> Dict[str, str]: + def update_impersonation_config( + cls, connect_args: Dict[str, Any], uri: str, username: Optional[str], + ) -> None: """ - Return a configuration dictionary that can be merged with other configs + Update a configuration dictionary that can set the correct properties for impersonating users + :param connect_args: :param uri: URI string :param impersonate_user: Flag indicating if impersonation is enabled :param username: Effective username - :return: Configs required for impersonation + :return: None """ - configuration = {} url = make_url(uri) backend_name = url.get_backend_name() # Must be Hive connection, enable impersonation, and set optional param # auth=LDAP|KERBEROS - if backend_name == "hive" and impersonate_user and username is not None: + # this will set hive.server2.proxy.user=$effective_username on connect_args['configuration'] + if backend_name == "hive" and username is not None: + configuration = connect_args.get("configuration", {}) configuration["hive.server2.proxy.user"] = username - return configuration + connect_args["configuration"] = configuration @staticmethod def execute( # type: ignore diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py index 071fd885f8d6..6ea687fcfe02 100644 --- a/superset/db_engine_specs/presto.py +++ b/superset/db_engine_specs/presto.py @@ -33,7 +33,7 @@ from sqlalchemy.engine.base import Engine from sqlalchemy.engine.reflection import Inspector from sqlalchemy.engine.result import RowProxy -from sqlalchemy.engine.url import URL +from sqlalchemy.engine.url import make_url, URL from sqlalchemy.orm import Session from sqlalchemy.sql.expression import ColumnClause, Select @@ -136,6 +136,28 @@ def get_allow_cost_estimate(cls, extra: Dict[str, Any]) -> bool: version = extra.get("version") return version is not None and StrictVersion(version) >= StrictVersion("0.319") + @classmethod + def update_impersonation_config( + cls, connect_args: Dict[str, Any], uri: str, username: Optional[str], + ) -> None: + """ + Update a configuration dictionary + that can set the correct properties for impersonating users + :param connect_args: config to be updated + :param uri: URI string + :param impersonate_user: Flag indicating if impersonation is enabled + :param username: Effective username + :return: None + """ + url = make_url(uri) + backend_name = url.get_backend_name() + + # Must be Presto connection, enable impersonation, and set optional param + # auth=LDAP|KERBEROS + # Set principal_username=$effective_username + if backend_name == "presto" and username is not None: + connect_args["principal_username"] = username + @classmethod def get_table_names( cls, database: "Database", inspector: Inspector, schema: Optional[str] diff --git a/superset/models/core.py b/superset/models/core.py index 5d0dde3dce9b..079a5e3575ff 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -325,16 +325,11 @@ def get_sqla_engine( params["poolclass"] = NullPool connect_args = params.get("connect_args", {}) - configuration = connect_args.get("configuration", {}) - - # If using Hive, this will set hive.server2.proxy.user=$effective_username - configuration.update( - self.db_engine_spec.get_configuration_for_impersonation( - str(sqlalchemy_url), self.impersonate_user, effective_username + if self.impersonate_user: + self.db_engine_spec.update_impersonation_config( + connect_args, str(sqlalchemy_url), effective_username ) - ) - if configuration: - connect_args["configuration"] = configuration + if connect_args: params["connect_args"] = connect_args diff --git a/tests/model_tests.py b/tests/model_tests.py index e0eaf4ac460e..2ff4c1a6dd85 100644 --- a/tests/model_tests.py +++ b/tests/model_tests.py @@ -17,6 +17,7 @@ # isort:skip_file import textwrap import unittest +from unittest import mock from tests.fixtures.birth_names_dashboard import load_birth_names_dashboard_with_slices import pandas @@ -110,6 +111,98 @@ def test_database_impersonate_user(self): user_name = make_url(model.get_sqla_engine(user_name=example_user).url).username self.assertNotEqual(example_user, user_name) + @mock.patch("superset.models.core.create_engine") + def test_impersonate_user_presto(self, mocked_create_engine): + uri = "presto://localhost" + principal_user = "logged_in_user" + extra = """ + { + "metadata_params": {}, + "engine_params": { + "connect_args":{ + "protocol": "https", + "username":"original_user", + "password":"original_user_password" + } + }, + "metadata_cache_timeout": {}, + "schemas_allowed_for_csv_upload": [] + } + """ + + model = Database(database_name="test_database", sqlalchemy_uri=uri, extra=extra) + + model.impersonate_user = True + model.get_sqla_engine(user_name=principal_user) + call_args = mocked_create_engine.call_args + + assert str(call_args[0][0]) == "presto://logged_in_user@localhost" + + assert call_args[1]["connect_args"] == { + "protocol": "https", + "username": "original_user", + "password": "original_user_password", + "principal_username": "logged_in_user", + } + + model.impersonate_user = False + model.get_sqla_engine(user_name=principal_user) + call_args = mocked_create_engine.call_args + + assert str(call_args[0][0]) == "presto://localhost" + + assert call_args[1]["connect_args"] == { + "protocol": "https", + "username": "original_user", + "password": "original_user_password", + } + + @mock.patch("superset.models.core.create_engine") + def test_impersonate_user_hive(self, mocked_create_engine): + uri = "hive://localhost" + principal_user = "logged_in_user" + extra = """ + { + "metadata_params": {}, + "engine_params": { + "connect_args":{ + "protocol": "https", + "username":"original_user", + "password":"original_user_password" + } + }, + "metadata_cache_timeout": {}, + "schemas_allowed_for_csv_upload": [] + } + """ + + model = Database(database_name="test_database", sqlalchemy_uri=uri, extra=extra) + + model.impersonate_user = True + model.get_sqla_engine(user_name=principal_user) + call_args = mocked_create_engine.call_args + + assert str(call_args[0][0]) == "hive://localhost" + + assert call_args[1]["connect_args"] == { + "protocol": "https", + "username": "original_user", + "password": "original_user_password", + "configuration": {"hive.server2.proxy.user": "logged_in_user"}, + } + + model.impersonate_user = False + model.get_sqla_engine(user_name=principal_user) + call_args = mocked_create_engine.call_args + + assert str(call_args[0][0]) == "hive://localhost" + + assert call_args[1]["connect_args"] == { + "protocol": "https", + "username": "original_user", + "password": "original_user_password", + } + @pytest.mark.usefixtures("load_energy_table_with_slice") def test_select_star(self): db = get_example_database()