Add Support for Migrating Table ACL of Interactive clusters using SPN (
HariGS-DB authored Mar 31, 2024
1 parent a286c7f commit 4249caf
Showing 20 changed files with 1,189 additions and 79 deletions.
24 changes: 23 additions & 1 deletion src/databricks/labs/ucx/assessment/azure.py
@@ -6,7 +6,7 @@
from databricks.labs.lsql.backends import SqlBackend
from databricks.sdk import WorkspaceClient
from databricks.sdk.errors import NotFound
from databricks.sdk.service.compute import ClusterSource, Policy
from databricks.sdk.service.compute import ClusterSource, DataSecurityMode, Policy

from databricks.labs.ucx.assessment.crawlers import azure_sp_conf_present_check, logger
from databricks.labs.ucx.assessment.jobs import JobsMixin
@@ -30,6 +30,15 @@ class AzureServicePrincipalInfo:
storage_account: str | None = None


@dataclass
class ServicePrincipalClusterMapping:
# kept as a separate dataclass because we need a cluster-to-SPN mapping
# id of the cluster where the SPN is used
cluster_id: str
# set of service principals configured on the cluster
spn_info: set[AzureServicePrincipalInfo]


class AzureServicePrincipalCrawler(CrawlerBase[AzureServicePrincipalInfo], JobsMixin, SecretsMixin):
def __init__(self, ws: WorkspaceClient, sbe: SqlBackend, schema):
super().__init__(sbe, "hive_metastore", schema, "azure_service_principals", AzureServicePrincipalInfo)
@@ -171,3 +180,16 @@ def _get_azure_spn_from_config(self, config: dict) -> set[AzureServicePrincipalI
)
)
return set_service_principals

def get_cluster_to_storage_mapping(self):
# returns a mapping between each interactive cluster and the SPNs it uses,
# either directly or through a cluster policy.
set_service_principals = set[AzureServicePrincipalInfo]()
spn_cluster_mapping = []
for cluster in self._ws.clusters.list():
if cluster.cluster_source != ClusterSource.JOB and (
cluster.data_security_mode in [DataSecurityMode.LEGACY_SINGLE_USER, DataSecurityMode.NONE]
):
set_service_principals = self._get_azure_spn_from_cluster_config(cluster)
spn_cluster_mapping.append(ServicePrincipalClusterMapping(cluster.cluster_id, set_service_principals))
return spn_cluster_mapping
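
For orientation, here is a minimal sketch of how the new cluster-to-SPN mapping could be consumed; the workspace client, warehouse id and inventory schema below are illustrative assumptions, not part of this change.

from databricks.labs.lsql.backends import StatementExecutionBackend
from databricks.sdk import WorkspaceClient

from databricks.labs.ucx.assessment.azure import AzureServicePrincipalCrawler

# assumed setup: a workspace client, a SQL warehouse id and the UCX inventory schema
ws = WorkspaceClient()
backend = StatementExecutionBackend(ws, "<warehouse-id>")
spn_crawler = AzureServicePrincipalCrawler(ws, backend, "ucx")

# one ServicePrincipalClusterMapping per interactive (non-job, non-UC) cluster
for mapping in spn_crawler.get_cluster_to_storage_mapping():
    for spn in mapping.spn_info:
        print(mapping.cluster_id, spn.application_id, spn.storage_account)
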
3 changes: 1 addition & 2 deletions src/databricks/labs/ucx/hive_metastore/__init__.py
@@ -1,9 +1,8 @@
from databricks.labs.ucx.hive_metastore.grants import GrantsCrawler
from databricks.labs.ucx.hive_metastore.locations import (
ExternalLocations,
Mounts,
TablesInMounts,
)
from databricks.labs.ucx.hive_metastore.tables import TablesCrawler

__all__ = ["TablesCrawler", "GrantsCrawler", "Mounts", "ExternalLocations", "TablesInMounts"]
__all__ = ["TablesCrawler", "Mounts", "ExternalLocations", "TablesInMounts"]
254 changes: 251 additions & 3 deletions src/databricks/labs/ucx/hive_metastore/grants.py
@@ -4,17 +4,42 @@
from dataclasses import dataclass
from functools import partial

from databricks.labs.blueprint.installation import Installation
from databricks.labs.blueprint.parallel import ManyError, Threads
from databricks.sdk.service.catalog import SchemaInfo, TableInfo

from databricks.labs.lsql.backends import SqlBackend, StatementExecutionBackend
from databricks.sdk import WorkspaceClient
from databricks.sdk.errors import ResourceDoesNotExist
from databricks.sdk.service.catalog import ExternalLocationInfo, SchemaInfo, TableInfo

from databricks.labs.ucx.assessment.azure import (
AzureServicePrincipalCrawler,
AzureServicePrincipalInfo,
)
from databricks.labs.ucx.azure.access import (
AzureResourcePermissions,
StoragePermissionMapping,
)
from databricks.labs.ucx.azure.resources import AzureAPIClient, AzureResources
from databricks.labs.ucx.config import WorkspaceConfig
from databricks.labs.ucx.framework.crawlers import CrawlerBase
from databricks.labs.ucx.framework.utils import escape_sql_identifier
from databricks.labs.ucx.hive_metastore.tables import TablesCrawler
from databricks.labs.ucx.hive_metastore.locations import (
ExternalLocations,
Mount,
Mounts,
)
from databricks.labs.ucx.hive_metastore.tables import Table, TablesCrawler
from databricks.labs.ucx.hive_metastore.udfs import UdfsCrawler

logger = logging.getLogger(__name__)


@dataclass
class ClusterLocationMapping:
cluster_id: str
locations: dict[str, str]


@dataclass(frozen=True)
class Grant:
principal: str
@@ -127,6 +152,7 @@ def uc_grant_sql(self, object_type: str | None = None, object_key: str | None =
("TABLE", "SELECT"): self._uc_action("SELECT"),
("TABLE", "MODIFY"): self._uc_action("MODIFY"),
("TABLE", "READ_METADATA"): self._uc_action("BROWSE"),
("TABLE", "ALL PRIVILEGES"): self._uc_action("ALL PRIVILEGES"),
("TABLE", "OWN"): self._set_owner_sql,
("VIEW", "SELECT"): self._uc_action("SELECT"),
("VIEW", "READ_METADATA"): self._uc_action("BROWSE"),
@@ -307,3 +333,225 @@ def grants(
# TODO: https://github.com/databrickslabs/ucx/issues/406
logger.error(f"Couldn't fetch grants for object {on_type} {key}: {e}")
return []


class AzureACL:
def __init__(
self,
ws: WorkspaceClient,
backend: SqlBackend,
spn_crawler: AzureServicePrincipalCrawler,
resource_permissions: AzureResourcePermissions,
):
self._backend = backend
self._ws = ws
self._spn_crawler = spn_crawler
self._resource_permissions = resource_permissions

@classmethod
def for_cli(cls, ws: WorkspaceClient, installation: Installation):
config = installation.load(WorkspaceConfig)
sql_backend = StatementExecutionBackend(ws, config.warehouse_id)
locations = ExternalLocations(ws, sql_backend, config.inventory_database)
azure_client = AzureAPIClient(
ws.config.arm_environment.resource_manager_endpoint,
ws.config.arm_environment.service_management_endpoint,
)
graph_client = AzureAPIClient("https://graph.microsoft.com", "https://graph.microsoft.com")
azurerm = AzureResources(azure_client, graph_client)
resource_permissions = AzureResourcePermissions(installation, ws, azurerm, locations)
spn_crawler = AzureServicePrincipalCrawler(ws, sql_backend, config.inventory_database)
return cls(ws, sql_backend, spn_crawler, resource_permissions)

def get_eligible_locations_principals(self) -> dict[str, dict]:
cluster_locations = {}
eligible_locations = {}
spn_cluster_mapping = self._spn_crawler.get_cluster_to_storage_mapping()
if len(spn_cluster_mapping) == 0:
# if there are no interactive clusters, then return empty grants
logger.info("No interactive cluster found with an SPN configured")
return {}
external_locations = list(self._ws.external_locations.list())
if len(external_locations) == 0:
# if there are no external locations, then throw an error to run migrate_locations cli command
msg = (
"No external location found. If hive metastore tables are created in external storage, "
"ensure the migrate-locations CLI command has been run to create the required locations."
)
logger.error(msg)
raise ResourceDoesNotExist(msg) from None

permission_mappings = self._resource_permissions.load()
if len(permission_mappings) == 0:
# if permission mapping is empty, raise an error to run principal_prefix cmd
msg = (
"No storage permission file found. Please ensure the principal-prefix-access CLI "
"command has been run to create the access permission file."
)
logger.error(msg)
raise ResourceDoesNotExist(msg) from None

for cluster_spn in spn_cluster_mapping:
for spn in cluster_spn.spn_info:
eligible_locations.update(self._get_external_locations(spn, external_locations, permission_mappings))
cluster_locations[cluster_spn.cluster_id] = eligible_locations
return cluster_locations

def _get_external_locations(
self,
spn: AzureServicePrincipalInfo,
external_locations: list[ExternalLocationInfo],
permission_mappings: list[StoragePermissionMapping],
) -> dict[str, str]:
matching_location = {}
for location in external_locations:
if location.url is None:
continue
for permission_mapping in permission_mappings:
prefix = permission_mapping.prefix
if (
location.url.startswith(permission_mapping.prefix)
and permission_mapping.client_id == spn.application_id
and spn.storage_account is not None
# check that the storage account name (the part after '@' in the prefix URL) matches the SPN's storage account
and prefix[prefix.index('@') + 1 :].startswith(spn.storage_account)
):
matching_location[location.url] = permission_mapping.privilege
return matching_location


class PrincipalACL:
def __init__(
self,
ws: WorkspaceClient,
backend: SqlBackend,
installation: Installation,
tables_crawler: TablesCrawler,
mounts_crawler: Mounts,
cluster_locations: dict[str, dict],
):
self._backend = backend
self._ws = ws
self._installation = installation
self._tables_crawler = tables_crawler
self._mounts_crawler = mounts_crawler
self._cluster_locations = cluster_locations

@classmethod
def for_cli(cls, ws: WorkspaceClient, installation: Installation, sql_backend: SqlBackend):
config = installation.load(WorkspaceConfig)

tables_crawler = TablesCrawler(sql_backend, config.inventory_database)
mount_crawler = Mounts(sql_backend, ws, config.inventory_database)
if ws.config.is_azure:
azure_acl = AzureACL.for_cli(ws, installation)
return cls(
ws,
sql_backend,
installation,
tables_crawler,
mount_crawler,
azure_acl.get_eligible_locations_principals(),
)
if ws.config.is_aws:
return None
if ws.config.is_gcp:
logger.error("UCX is not supported for GCP yet. Please run it on azure or aws")
return None
return None

def get_interactive_cluster_grants(self) -> list[Grant]:
tables = self._tables_crawler.snapshot()
mounts = list(self._mounts_crawler.snapshot())
grants: set[Grant] = set()

for cluster_id, locations in self._cluster_locations.items():
principals = self._get_cluster_principal_mapping(cluster_id)
if len(principals) == 0:
continue
cluster_usage = self._get_grants(locations, principals, tables, mounts)
grants.update(cluster_usage)
catalog_grants = [Grant(principal, "USE", "hive_metastore") for principal in principals]
grants.update(catalog_grants)

return list(grants)

def _get_privilege(self, table: Table, locations: dict[str, str], mounts: list[Mount]):
if table.view_text is not None:
# return nothing for views so that they are handled by the separate view logic
return None
if table.location is None:
return None
if table.location.startswith('dbfs:/mnt') or table.location.startswith('/dbfs/mnt'):
mount_location = ExternalLocations.resolve_mount(table.location, mounts)
for loc, privilege in locations.items():
if loc is not None and mount_location.startswith(loc):
return privilege
return None
if table.location.startswith('dbfs:/') or table.location.startswith('/dbfs/'):
return "WRITE_FILES"

for loc, privilege in locations.items():
if loc is not None and table.location.startswith(loc):
return privilege
return None

def _get_database_grants(self, tables: list[Table], principals: list[str]) -> list[Grant]:
databases = {table.database for table in tables}
return [
Grant(principal, "USE", "hive_metastore", database) for database in databases for principal in principals
]

def _get_grants(
self, locations: dict[str, str], principals: list[str], tables: list[Table], mounts: list[Mount]
) -> list[Grant]:
grants = []
filtered_tables = []
for table in tables:
privilege = self._get_privilege(table, locations, mounts)
if privilege == "READ_FILES":
grants.extend(
[Grant(principal, "SELECT", table.catalog, table.database, table.name) for principal in principals]
)
filtered_tables.append(table)
continue
if privilege == "WRITE_FILES":
grants.extend(
[
Grant(principal, "ALL PRIVILEGES", table.catalog, table.database, table.name)
for principal in principals
]
)
filtered_tables.append(table)
continue
if table.view_text is not None:
grants.extend(
[
Grant(principal, "ALL PRIVILEGES", table.catalog, table.database, view=table.name)
for principal in principals
]
)
filtered_tables.append(table)

database_grants = self._get_database_grants(filtered_tables, principals)

grants.extend(database_grants)

return grants

def _get_cluster_principal_mapping(self, cluster_id: str) -> list[str]:
# gets all the users, groups and service principals that have access to the cluster and returns them as a list of principal names
principal_list = []
cluster_permission = self._ws.permissions.get("clusters", cluster_id)
if cluster_permission.access_control_list is None:
return []
for acl in cluster_permission.access_control_list:
if acl.user_name is not None:
principal_list.append(acl.user_name)
if acl.group_name is not None:
if acl.group_name == "admins":
continue
principal_list.append(acl.group_name)
if acl.service_principal_name is not None:
principal_list.append(acl.service_principal_name)
return principal_list
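
To see how these pieces fit together end to end, the following sketch builds a PrincipalACL via its for_cli factory and asks for the interactive-cluster grants; the installation lookup and warehouse id are illustrative assumptions, not part of this change.

from databricks.labs.blueprint.installation import Installation
from databricks.labs.lsql.backends import StatementExecutionBackend
from databricks.sdk import WorkspaceClient

from databricks.labs.ucx.hive_metastore.grants import PrincipalACL

ws = WorkspaceClient()
installation = Installation.current(ws, "ucx")  # assumed: UCX is already installed
sql_backend = StatementExecutionBackend(ws, "<warehouse-id>")

# on Azure this wires up AzureACL internally and resolves, per interactive cluster,
# the external locations its SPNs can reach; on AWS/GCP it currently returns None
principal_acl = PrincipalACL.for_cli(ws, installation, sql_backend)
if principal_acl is not None:
    for grant in principal_acl.get_interactive_cluster_grants():
        print(grant)

Per the _get_grants logic above, a READ_FILES privilege on a matching location becomes a SELECT grant on the table, WRITE_FILES becomes ALL PRIVILEGES, and the principals additionally receive USE on hive_metastore and on each affected database.
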
5 changes: 3 additions & 2 deletions src/databricks/labs/ucx/hive_metastore/locations.py
@@ -47,7 +47,7 @@ def _external_locations(self, tables: list[Row], mounts) -> Iterable[ExternalLoc
if not location:
continue
if location.startswith("dbfs:/mnt"):
location = self._resolve_mount(location, mounts)
location = self.resolve_mount(location, mounts)
if (
not location.startswith("dbfs")
and (self._prefix_size[0] < location.find(":/") < self._prefix_size[1])
@@ -58,7 +58,8 @@ def _external_locations(self, tables: list[Row], mounts) -> Iterable[ExternalLoc
self._add_jdbc_location(external_locations, location, table)
return external_locations

def _resolve_mount(self, location, mounts):
@staticmethod
def resolve_mount(location, mounts):
for mount in mounts:
if location[5:].startswith(mount.name.lower()):
location = location[5:].replace(mount.name, mount.source)
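
resolve_mount is promoted to a public staticmethod here so that PrincipalACL._get_privilege can translate dbfs:/mnt paths into their cloud storage URLs. A small illustration of the expected behaviour, using a made-up mount and assuming the function returns the rewritten location:

from databricks.labs.ucx.hive_metastore.locations import ExternalLocations, Mount

# hypothetical mount pointing dbfs:/mnt/data at an ADLS container
mounts = [Mount("/mnt/data", "abfss://container@account.dfs.core.windows.net")]
resolved = ExternalLocations.resolve_mount("dbfs:/mnt/data/sales", mounts)
# expected: "abfss://container@account.dfs.core.windows.net/sales"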