Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 20 additions & 12 deletions superset/commands/database/sync_permissions.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,13 +140,7 @@ def sync_database_permissions(self) -> None:
"""
Syncs the permissions for a DB connection.
"""
catalogs = (
self._get_catalog_names()
if self.db_connection.db_engine_spec.supports_catalog
else [None]
)

for catalog in catalogs:
for catalog in self._get_catalog_names():
try:
schemas = self._get_schema_names(catalog)

Expand Down Expand Up @@ -192,15 +186,29 @@ def sync_database_permissions(self) -> None:
if self.old_db_connection_name != self.db_connection.database_name:
self._rename_database_in_permissions(catalog, schemas)

def _get_catalog_names(self) -> set[str]:
def _get_catalog_names(self) -> set[str | None]:
"""
Helper method to load catalogs.
"""
if not self.db_connection.db_engine_spec.supports_catalog:
return {None}

try:
return self.db_connection.get_all_catalog_names(
force=True,
ssh_tunnel=self.db_connection_ssh_tunnel,
)
# Adding permissions to all catalogs (and all their schemas) can take a long
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's a similar check here:

catalogs = (
self._get_catalog_names()
if self.db_connection.db_engine_spec.supports_catalog
else [None]
)

Can these two conditions be consolidated?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe that block is actually calling this method, so I think we're good? Would defer to @betodealmeida to confirm

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, those are different.

Basically with my changes _get_catalog_names returns all relevant catalogs in a given database. This could be all catalogs, or just the default.

When a database doesn't support catalogs, then we need to use None as the catalog, hence the [None].

# time (minutes, while importing a chart, eg). If the database does not
# support cross-catalog queries (like Postgres), and the multi-catalog
# feature is not enabled, then we only need to add permissions to the
# default catalog.
if (
self.db_connection.db_engine_spec.supports_cross_catalog_queries
or self.db_connection.allow_multi_catalog
):
return self.db_connection.get_all_catalog_names(
force=True,
ssh_tunnel=self.db_connection_ssh_tunnel,
)
else:
return {self.db_connection.get_default_catalog()}
except OAuth2RedirectError:
# raise OAuth2 exceptions as-is
raise
Expand Down
3 changes: 3 additions & 0 deletions superset/db_engine_specs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,9 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
# Can the catalog be changed on a per-query basis?
supports_dynamic_catalog = False

# Does the DB engine spec support cross-catalog queries?
supports_cross_catalog_queries = False

# Does the engine supports OAuth 2.0? This requires logic to be added to one of the
# the user impersonation methods to handle personal tokens.
supports_oauth2 = False
Expand Down
4 changes: 2 additions & 2 deletions superset/db_engine_specs/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met

allows_hidden_cc_in_orderby = True

supports_catalog = supports_dynamic_catalog = True
supports_catalog = supports_dynamic_catalog = supports_cross_catalog_queries = True

# when editing the database, mask this field in `encrypted_extra`
# pylint: disable=invalid-name
Expand Down Expand Up @@ -539,7 +539,7 @@ def estimate_query_cost( # pylint: disable=too-many-arguments
]

@classmethod
def get_default_catalog(cls, database: Database) -> str | None:
def get_default_catalog(cls, database: Database) -> str:
"""
Get the default catalog.
"""
Expand Down
10 changes: 5 additions & 5 deletions superset/db_engine_specs/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,10 @@ class DatabricksNativeEngineSpec(DatabricksDynamicBaseEngineSpec):
"extra",
}

supports_dynamic_schema = supports_catalog = supports_dynamic_catalog = True
supports_dynamic_schema = True
supports_catalog = True
supports_dynamic_catalog = True
supports_cross_catalog_queries = True

@classmethod
def build_sqlalchemy_uri( # type: ignore
Expand Down Expand Up @@ -433,10 +436,7 @@ def parameters_json_schema(cls) -> Any:
return spec.to_dict()["components"]["schemas"][cls.__name__]

@classmethod
def get_default_catalog(
cls,
database: Database,
) -> str | None:
def get_default_catalog(cls, database: Database) -> str:
"""
Return the default catalog.

Expand Down
2 changes: 1 addition & 1 deletion superset/db_engine_specs/doris.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ class DorisEngineSpec(MySQLEngineSpec):
)
encryption_parameters = {"ssl": "0"}
supports_dynamic_schema = True
supports_catalog = supports_dynamic_catalog = True
supports_catalog = supports_dynamic_catalog = supports_cross_catalog_queries = True

column_type_mappings = ( # type: ignore
(
Expand Down
2 changes: 1 addition & 1 deletion superset/db_engine_specs/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ def adjust_engine_params(
return uri, connect_args

@classmethod
def get_default_catalog(cls, database: Database) -> str | None:
def get_default_catalog(cls, database: Database) -> str:
"""
Return the default catalog for a given database.
"""
Expand Down
2 changes: 1 addition & 1 deletion superset/db_engine_specs/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
"""

supports_dynamic_schema = True
supports_catalog = supports_dynamic_catalog = True
supports_catalog = supports_dynamic_catalog = supports_cross_catalog_queries = True

column_type_mappings = (
(
Expand Down
4 changes: 2 additions & 2 deletions superset/db_engine_specs/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
sqlalchemy_uri_placeholder = "snowflake://"

supports_dynamic_schema = True
supports_catalog = supports_dynamic_catalog = True
supports_catalog = supports_dynamic_catalog = supports_cross_catalog_queries = True

# pylint: disable=invalid-name
encrypted_extra_sensitive_fields = {
Expand Down Expand Up @@ -189,7 +189,7 @@ def get_schema_from_engine_params(
return parse.unquote(database.split("/")[1])

@classmethod
def get_default_catalog(cls, database: "Database") -> Optional[str]:
def get_default_catalog(cls, database: "Database") -> str:
"""
Return the default catalog.
"""
Expand Down
29 changes: 23 additions & 6 deletions tests/unit_tests/commands/databases/sync_permissions_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,14 +231,31 @@ def test_sync_permissions_command_async_mode_new_db_name(
async_task_mock.delay.assert_called_once_with(1, "admin", "Old Name")


def test_resync_permissions_command_get_catalogs(database_with_catalog: MagicMock):
def test_sync_permissions_command_get_catalogs(database_with_catalog: MagicMock):
"""
Test the ``_get_catalog_names`` method.
"""
cmmd = SyncPermissionsCommand(1, None, db_connection=database_with_catalog)
assert cmmd._get_catalog_names() == ["catalog1", "catalog2"]


def test_sync_permissions_command_get_default_catalog(database_with_catalog: MagicMock):
"""
Test ``_get_catalog_names`` when only the default one should be returned.
When the database doesn't not support cross-catalog queries (like Postgres), we
should only return all catalogs if multi-catalog is enabled.
"""
database_with_catalog.db_engine_spec.supports_cross_catalog_queries = False
database_with_catalog.allow_multi_catalog = False
cmmd = SyncPermissionsCommand(1, None, db_connection=database_with_catalog)
assert cmmd._get_catalog_names() == {"catalog2"}

database_with_catalog.allow_multi_catalog = True
cmmd = SyncPermissionsCommand(1, None, db_connection=database_with_catalog)
assert cmmd._get_catalog_names() == ["catalog1", "catalog2"]


@pytest.mark.parametrize(
("inner_exception, outer_exception"),
[
Expand All @@ -249,7 +266,7 @@ def test_resync_permissions_command_get_catalogs(database_with_catalog: MagicMoc
(GenericDBException, DatabaseConnectionFailedError),
],
)
def test_resync_permissions_command_raise_on_getting_catalogs(
def test_sync_permissions_command_raise_on_getting_catalogs(
inner_exception: Exception,
outer_exception: Exception,
database_with_catalog: MagicMock,
Expand All @@ -263,7 +280,7 @@ def test_resync_permissions_command_raise_on_getting_catalogs(
cmmd._get_catalog_names()


def test_resync_permissions_command_get_schemas(database_with_catalog: MagicMock):
def test_sync_permissions_command_get_schemas(database_with_catalog: MagicMock):
"""
Test the ``_get_schema_names`` method.
"""
Expand All @@ -282,7 +299,7 @@ def test_resync_permissions_command_get_schemas(database_with_catalog: MagicMock
(GenericDBException, DatabaseConnectionFailedError),
],
)
def test_resync_permissions_command_raise_on_getting_schemas(
def test_sync_permissions_command_raise_on_getting_schemas(
inner_exception: Exception,
outer_exception: Exception,
database_with_catalog: MagicMock,
Expand All @@ -296,7 +313,7 @@ def test_resync_permissions_command_raise_on_getting_schemas(
cmmd._get_schema_names("blah")


def test_resync_permissions_command_refresh_schemas(
def test_sync_permissions_command_refresh_schemas(
mocker: MockerFixture, database_with_catalog: MagicMock
):
"""
Expand All @@ -319,7 +336,7 @@ def test_resync_permissions_command_refresh_schemas(
)


def test_resync_permissions_command_rename_db_in_perms(
def test_sync_permissions_command_rename_db_in_perms(
mocker: MockerFixture, database_with_catalog: MagicMock
):
"""
Expand Down
Loading