diff --git a/labs.yml b/labs.yml
index ae357b169b..b089593eec 100644
--- a/labs.yml
+++ b/labs.yml
@@ -114,3 +114,6 @@ commands:
       Workspace Group Name\tMembers Count\tAccount Group Name\tMembers Count
       {{range .}}{{.wf_group_name}}\t{{.wf_group_members_count}}\t{{.acc_group_name}}\t{{.acc_group_members_count}}
       {{end}}
+
+  - name: migrate_credentials
+    description: Migrate credentials for storage access to UC storage credential
diff --git a/src/databricks/labs/ucx/assessment/secrets.py b/src/databricks/labs/ucx/assessment/secrets.py
index 20fed01bb1..715d6d82aa 100644
--- a/src/databricks/labs/ucx/assessment/secrets.py
+++ b/src/databricks/labs/ucx/assessment/secrets.py
@@ -22,7 +22,12 @@ def _get_secret_if_exists(self, secret_scope, secret_key) -> str | None:
             assert secret.value is not None
             return base64.b64decode(secret.value).decode("utf-8")
         except NotFound:
-            logger.warning(f'removed on the backend: {secret_scope}{secret_key}')
+            logger.warning(f'removed on the backend: {secret_scope}/{secret_key}')
+            return None
+        except UnicodeDecodeError:
+            logger.warning(
+                f"Secret {secret_scope}/{secret_key} has Base64 bytes that cannot be decoded to a UTF-8 string."
+            )
             return None
 
     def _get_value_from_config_key(self, config: dict, key: str, get_secret: bool = True) -> str | None:
diff --git a/src/databricks/labs/ucx/azure/access.py b/src/databricks/labs/ucx/azure/access.py
index e489d967bb..78a8f7f917 100644
--- a/src/databricks/labs/ucx/azure/access.py
+++ b/src/databricks/labs/ucx/azure/access.py
@@ -17,6 +17,8 @@ class StoragePermissionMapping:
     client_id: str
     principal: str
     privilege: str
+    # Need this directory_id/tenant_id when creating UC storage credentials using a service principal
+    directory_id: str
 
 
 class AzureResourcePermissions:
@@ -63,6 +65,7 @@ def _map_storage(self, storage: AzureResource) -> list[StoragePermissionMapping]
                     client_id=role_assignment.principal.client_id,
                     principal=role_assignment.principal.display_name,
                     privilege=privilege,
+                    directory_id=role_assignment.principal.directory_id,
                 )
             )
         return out
diff --git a/src/databricks/labs/ucx/azure/credentials.py b/src/databricks/labs/ucx/azure/credentials.py
new file mode 100644
index 0000000000..c1182587c0
--- /dev/null
+++ b/src/databricks/labs/ucx/azure/credentials.py
@@ -0,0 +1,272 @@
+import logging
+from dataclasses import dataclass
+
+from databricks.labs.blueprint.installation import Installation
+from databricks.labs.blueprint.tui import Prompts
+from databricks.sdk import WorkspaceClient
+from databricks.sdk.errors.platform import InvalidParameterValue
+from databricks.sdk.service.catalog import (
+    AzureServicePrincipal,
+    Privilege,
+    StorageCredentialInfo,
+    ValidationResultResult,
+)
+
+from databricks.labs.ucx.assessment.azure import AzureServicePrincipalCrawler
+from databricks.labs.ucx.assessment.secrets import SecretsMixin
+from databricks.labs.ucx.azure.access import (
+    AzureResourcePermissions,
+    StoragePermissionMapping,
+)
+from databricks.labs.ucx.azure.resources import AzureResources
+from databricks.labs.ucx.config import WorkspaceConfig
+from databricks.labs.ucx.framework.crawlers import StatementExecutionBackend
+from databricks.labs.ucx.hive_metastore.locations import ExternalLocations
+
+logger = logging.getLogger(__name__)
+
+
+# A dataclass pairing a service principal's permission mapping with its client_secret
+@dataclass
+class ServicePrincipalMigrationInfo:
+    permission_mapping: StoragePermissionMapping
+    client_secret: str
+
+
+@dataclass
+class StorageCredentialValidationResult:
+    name: str
+    application_id: str
+    directory_id: str
+    read_only: bool
+    validated_on: str
+    failures: list[str] | None = None
+
+    @classmethod
+    def from_validation(cls, permission_mapping: StoragePermissionMapping, failures: list[str] | None):
+        return cls(
+            permission_mapping.principal,
+            permission_mapping.client_id,
+            permission_mapping.directory_id,
+            permission_mapping.privilege == Privilege.READ_FILES.value,
+            permission_mapping.prefix,
+            failures,
+        )
+
+
+class StorageCredentialManager:
+    def __init__(self, ws: WorkspaceClient):
+        self._ws = ws
+
+    def list(self, include_names: set[str] | None = None) -> set[str]:
+        # list existing storage credentials that use a service principal and capture their application_ids
+        application_ids = set()
+
+        storage_credentials = self._ws.storage_credentials.list(max_results=0)
+
+        if include_names:
+            # we only check UC storage credentials listed in include_names
+            for storage_credential in storage_credentials:
+                if not storage_credential.azure_service_principal:
+                    continue
+                if storage_credential.name in include_names:
+                    application_ids.add(storage_credential.azure_service_principal.application_id)
+            logger.info(
+                f"Found {len(application_ids)} distinct service principals already used in UC storage credentials listed in include_names"
+            )
+            return application_ids
+
+        for storage_credential in storage_credentials:
+            # only add the service principal's application_id; ignore managed-identity-based storage credentials
+            if storage_credential.azure_service_principal:
+                application_ids.add(storage_credential.azure_service_principal.application_id)
+
+        logger.info(f"Found {len(application_ids)} distinct service principals already used in UC storage credentials")
+        return application_ids
+
+    def create_with_client_secret(self, spn: ServicePrincipalMigrationInfo) -> StorageCredentialInfo:
+        # prepare the storage credential properties
+        name = spn.permission_mapping.principal
+        service_principal = AzureServicePrincipal(
+            spn.permission_mapping.directory_id,
+            spn.permission_mapping.client_id,
+            spn.client_secret,
+        )
+        comment = (
+            f"Created by UCX during migration to UC using Azure Service Principal: {spn.permission_mapping.principal}"
+        )
+
+        # create the storage credential
+        return self._ws.storage_credentials.create(
+            name,
+            azure_service_principal=service_principal,
+            comment=comment,
+            read_only=spn.permission_mapping.privilege == Privilege.READ_FILES.value,
+        )
+
+    def validate(self, permission_mapping: StoragePermissionMapping) -> StorageCredentialValidationResult:
+        try:
+            validation = self._ws.storage_credentials.validate(
+                storage_credential_name=permission_mapping.principal,
+                url=permission_mapping.prefix,
+                read_only=permission_mapping.privilege == Privilege.READ_FILES.value,
+            )
+        except InvalidParameterValue:
+            logger.warning(
+                "An existing external location overlaps with the prefix that is mapped to "
+                "the service principal and used for validating the migrated storage credential. "
+                "Skipping the validation"
+            )
+            return StorageCredentialValidationResult.from_validation(
+                permission_mapping,
+                [
+                    "The validation is skipped because an existing external location overlaps "
+                    "with the location used for validation."
+                ],
+            )
+
+        if not validation.results:
+            return StorageCredentialValidationResult.from_validation(
+                permission_mapping, ["Validation returned none results."]
+            )
+
+        failures = []
+        for result in validation.results:
+            if result.operation is None:
+                continue
+            if result.result == ValidationResultResult.FAIL:
+                failures.append(f"{result.operation.value} validation failed with message: {result.message}")
+        return StorageCredentialValidationResult.from_validation(permission_mapping, None if not failures else failures)
+
+
+class ServicePrincipalMigration(SecretsMixin):
+
+    def __init__(
+        self,
+        installation: Installation,
+        ws: WorkspaceClient,
+        resource_permissions: AzureResourcePermissions,
+        service_principal_crawler: AzureServicePrincipalCrawler,
+        storage_credential_manager: StorageCredentialManager,
+    ):
+        self._output_file = "azure_service_principal_migration_result.csv"
+        self._installation = installation
+        self._ws = ws
+        self._resource_permissions = resource_permissions
+        self._sp_crawler = service_principal_crawler
+        self._storage_credential_manager = storage_credential_manager
+
+    @classmethod
+    def for_cli(cls, ws: WorkspaceClient, installation: Installation, prompts: Prompts):
+        msg = (
+            "Have you reviewed the azure_storage_account_info.csv "
+            "and confirmed that the listed service principals are allowed to be checked for migration?"
+        )
+        if not prompts.confirm(msg):
+            raise SystemExit()
+
+        config = installation.load(WorkspaceConfig)
+        sql_backend = StatementExecutionBackend(ws, config.warehouse_id)
+        azurerm = AzureResources(ws)
+        locations = ExternalLocations(ws, sql_backend, config.inventory_database)
+
+        resource_permissions = AzureResourcePermissions(installation, ws, azurerm, locations)
+        sp_crawler = AzureServicePrincipalCrawler(ws, sql_backend, config.inventory_database)
+
+        storage_credential_manager = StorageCredentialManager(ws)
+
+        return cls(installation, ws, resource_permissions, sp_crawler, storage_credential_manager)
+
+    def _fetch_client_secret(self, sp_list: list[StoragePermissionMapping]) -> list[ServicePrincipalMigrationInfo]:
+        # Check AzureServicePrincipalInfo from AzureServicePrincipalCrawler. If an AzureServicePrincipalInfo
+        # has a non-empty secret_scope and secret_key, fetch the client_secret and store it in the
+        # client_secret field.
+        #
+        # The input StoragePermissionMapping may have managed identities mixed in; we ignore them for now,
+        # as they won't have a client_secret. We will process managed identities in the future.
+
+        # fetch client_secrets of crawled service principals, if any
+        sp_info_with_client_secret: dict[str, str] = {}
+        sp_infos = self._sp_crawler.snapshot()
+
+        for sp_info in sp_infos:
+            if not sp_info.secret_scope:
+                continue
+            if not sp_info.secret_key:
+                continue
+
+            secret_value = self._get_secret_if_exists(sp_info.secret_scope, sp_info.secret_key)
+
+            if secret_value:
+                sp_info_with_client_secret[sp_info.application_id] = secret_value
+            else:
+                logger.info(
+                    f"Cannot fetch the service principal client_secret for {sp_info.application_id}. "
+                    f"This service principal will be skipped during migration"
+                )
+
+        # update the list of ServicePrincipalMigrationInfo if a client_secret is found
+        sp_list_with_secret = []
+        for spn in sp_list:
+            if spn.client_id in sp_info_with_client_secret:
+                sp_list_with_secret.append(
+                    ServicePrincipalMigrationInfo(spn, sp_info_with_client_secret[spn.client_id])
+                )
+        return sp_list_with_secret
+
+    def _print_action_plan(self, sp_list: list[StoragePermissionMapping]):
+        # print the action plan to the console for the customer to review
+        for spn in sp_list:
+            logger.info(
+                f"Service Principal name: {spn.principal}, "
+                f"application_id: {spn.client_id}, "
+                f"privilege {spn.privilege} "
+                f"on location {spn.prefix}"
+            )
+
+    def _generate_migration_list(self, include_names: set[str] | None = None) -> list[ServicePrincipalMigrationInfo]:
+        """
+        Create the list of service principals that need to be migrated and output an action plan
+        as a csv file for users to confirm.
+        """
+        # load the sp list from azure_storage_account_info.csv
+        sp_list = self._resource_permissions.load()
+        # list existing storage credentials
+        sc_set = self._storage_credential_manager.list(include_names)
+        # filter out the sps that are already used in UC storage credentials
+        filtered_sp_list = [sp for sp in sp_list if sp.client_id not in sc_set]
+        # fetch the sp client_secret, if any
+        sp_list_with_secret = self._fetch_client_secret(filtered_sp_list)
+
+        # output the action plan for the customer to confirm,
+        # but first make a copy of the list and strip out the client_secret
+        sp_candidates = [sp.permission_mapping for sp in sp_list_with_secret]
+        self._print_action_plan(sp_candidates)
+
+        return sp_list_with_secret
+
+    def save(self, migration_results: list[StorageCredentialValidationResult]) -> str:
+        return self._installation.save(migration_results, filename=self._output_file)
+
+    def run(self, prompts: Prompts, include_names: set[str] | None = None) -> list[StorageCredentialValidationResult]:
+
+        sp_list_with_secret = self._generate_migration_list(include_names)
+
+        plan_confirmed = prompts.confirm(
+            "Above Azure Service Principals will be migrated to UC storage credentials, please review and confirm."
+        )
+        if plan_confirmed is not True:
+            return []
+
+        execution_result = []
+        for spn in sp_list_with_secret:
+            self._storage_credential_manager.create_with_client_secret(spn)
+            execution_result.append(self._storage_credential_manager.validate(spn.permission_mapping))
+
+        if execution_result:
+            results_file = self.save(execution_result)
+            logger.info(
+                f"Completed migration from Azure Service Principal to UC Storage credentials. "
+                f"Please check {results_file} for validation results"
+            )
+        else:
+            logger.info("No Azure Service Principal migrated to UC Storage credentials")
+        return execution_result
diff --git a/src/databricks/labs/ucx/azure/resources.py b/src/databricks/labs/ucx/azure/resources.py
index 9dc892a1a8..6c86d464b6 100644
--- a/src/databricks/labs/ucx/azure/resources.py
+++ b/src/databricks/labs/ucx/azure/resources.py
@@ -70,6 +70,8 @@ class Principal:
     client_id: str
     display_name: str
    object_id: str
+    # Need this directory_id/tenant_id when creating UC storage credentials using a service principal
+    directory_id: str
 
 
 @dataclass
@@ -171,10 +173,13 @@ def _get_principal(self, principal_id: str) -> Principal | None:
         client_id = raw.get("appId")
         display_name = raw.get("displayName")
         object_id = raw.get("id")
+        # Need this directory_id/tenant_id when creating UC storage credentials using a service principal
+        directory_id = raw.get("appOwnerOrganizationId")
         assert client_id is not None
         assert display_name is not None
         assert object_id is not None
-        self._principals[principal_id] = Principal(client_id, display_name, object_id)
+        assert directory_id is not None
+        self._principals[principal_id] = Principal(client_id, display_name, object_id, directory_id)
         return self._principals[principal_id]
 
     def role_assignments(
diff --git a/src/databricks/labs/ucx/cli.py b/src/databricks/labs/ucx/cli.py
index c4a7192117..e9160a1f95 100644
--- a/src/databricks/labs/ucx/cli.py
+++ b/src/databricks/labs/ucx/cli.py
@@ -13,6 +13,7 @@ from databricks.labs.ucx.account import AccountWorkspaces, WorkspaceInfo
 from databricks.labs.ucx.assessment.aws import AWSResourcePermissions
 from databricks.labs.ucx.azure.access import AzureResourcePermissions
+from databricks.labs.ucx.azure.credentials import ServicePrincipalMigration
 from databricks.labs.ucx.config import WorkspaceConfig
 from databricks.labs.ucx.framework.crawlers import StatementExecutionBackend
 from databricks.labs.ucx.hive_metastore import ExternalLocations, TablesCrawler
@@ -282,5 +283,27 @@ def _aws_principal_prefix_access(w: WorkspaceClient, aws_profile: str):
     logger.info(f"UC roles and bucket info saved {uc_role_path}")
 
 
+@ucx.command
+def migrate_credentials(w: WorkspaceClient):
+    """For Azure, this command migrates Azure Service Principals, which have Storage Blob Data Contributor,
+    Storage Blob Data Reader, or Storage Blob Data Owner roles on ADLS Gen2 locations that are being used in
+    Databricks, to UC storage credentials.
+    The mapping of Azure Service Principals to locations is listed in
+    /Users/{user_name}/.ucx/azure_storage_account_info.csv, which is generated by the principal_prefix_access
+    command. Please review the file and delete the Service Principals you do not want to be migrated.
+    The command will only migrate the Service Principals that have a client secret stored in Databricks Secrets.
+ """ + prompts = Prompts() + if w.config.is_azure: + logger.info("Running migrate_credentials for Azure") + installation = Installation.current(w, 'ucx') + service_principal_migration = ServicePrincipalMigration.for_cli(w, installation, prompts) + service_principal_migration.run(prompts) + if w.config.is_aws: + logger.error("migrate_credentials is not yet supported in AWS") + if w.config.is_gcp: + logger.error("migrate_credentials is not yet supported in GCP") + + if __name__ == "__main__": ucx() diff --git a/src/databricks/labs/ucx/mixins/fixtures.py b/src/databricks/labs/ucx/mixins/fixtures.py index ae22cbd27b..df870c1360 100644 --- a/src/databricks/labs/ucx/mixins/fixtures.py +++ b/src/databricks/labs/ucx/mixins/fixtures.py @@ -19,10 +19,12 @@ from databricks.sdk.retries import retried from databricks.sdk.service import compute, iam, jobs, pipelines, sql, workspace from databricks.sdk.service.catalog import ( + AzureServicePrincipal, CatalogInfo, DataSourceFormat, FunctionInfo, SchemaInfo, + StorageCredentialInfo, TableInfo, TableType, ) @@ -1072,3 +1074,24 @@ def remove(query: Query): logger.info(f"Can't drop query {e}") yield from factory("query", create, remove) + + +@pytest.fixture +def make_storage_credential_spn(ws): + def create( + *, credential_name: str, application_id: str, client_secret: str, directory_id: str, read_only=False + ) -> StorageCredentialInfo: + azure_service_principal = AzureServicePrincipal( + directory_id, + application_id, + client_secret, + ) + storage_credential = ws.storage_credentials.create( + credential_name, azure_service_principal=azure_service_principal, read_only=read_only + ) + return storage_credential + + def remove(storage_credential: StorageCredentialInfo): + ws.storage_credentials.delete(storage_credential.name, force=True) + + yield from factory("storage_credential_from_spn", create, remove) diff --git a/tests/integration/azure/test_credentials.py b/tests/integration/azure/test_credentials.py new file mode 100644 index 0000000000..7c3ecbfa55 --- /dev/null +++ b/tests/integration/azure/test_credentials.py @@ -0,0 +1,135 @@ +import base64 +import re +from dataclasses import dataclass + +import pytest +from databricks.labs.blueprint.installation import MockInstallation +from databricks.labs.blueprint.tui import MockPrompts + +from databricks.labs.ucx.assessment.azure import AzureServicePrincipalInfo +from databricks.labs.ucx.azure.access import AzureResourcePermissions +from databricks.labs.ucx.azure.credentials import ( + ServicePrincipalMigration, + StorageCredentialManager, + StorageCredentialValidationResult, +) +from databricks.labs.ucx.azure.resources import AzureResources +from databricks.labs.ucx.hive_metastore import ExternalLocations +from tests.integration.conftest import StaticServicePrincipalCrawler + + +@dataclass +class MigrationTestInfo: + credential_name: str + application_id: str + directory_id: str + secret_scope: str + secret_key: str + client_secret: str + + +@pytest.fixture +def extract_test_info(ws, env_or_skip, make_random): + random = make_random(6).lower() + credential_name = f"testinfra_storageaccess_{random}" + + spark_conf = ws.clusters.get(env_or_skip("TEST_LEGACY_SPN_CLUSTER_ID")).spark_conf + + application_id = spark_conf.get("fs.azure.account.oauth2.client.id") + + end_point = spark_conf.get("fs.azure.account.oauth2.client.endpoint") + directory_id = end_point.split("/")[3] + + secret_matched = re.findall(r"{{secrets\/(.*)\/(.*)}}", spark_conf.get("fs.azure.account.oauth2.client.secret")) + secret_scope = 
+    secret_scope = secret_matched[0][0]
+    secret_key = secret_matched[0][1]
+    assert secret_scope is not None
+    assert secret_key is not None
+
+    secret_response = ws.secrets.get_secret(secret_scope, secret_key)
+    client_secret = base64.b64decode(secret_response.value).decode("utf-8")
+
+    return MigrationTestInfo(credential_name, application_id, directory_id, secret_scope, secret_key, client_secret)
+
+
+@pytest.fixture
+def run_migration(ws, sql_backend):
+    def inner(
+        test_info: MigrationTestInfo, credentials: set[str], read_only=False
+    ) -> list[StorageCredentialValidationResult]:
+        azurerm = AzureResources(ws)
+        locations = ExternalLocations(ws, sql_backend, "dont_need_a_schema")
+
+        installation = MockInstallation(
+            {
+                "azure_storage_account_info.csv": [
+                    {
+                        'prefix': 'abfss://things@labsazurethings.dfs.core.windows.net/avoid_ext_loc_overlap',
+                        'client_id': test_info.application_id,
+                        'principal': test_info.credential_name,
+                        'privilege': "READ_FILES" if read_only else "WRITE_FILES",
+                        'directory_id': test_info.directory_id,
+                    },
+                ]
+            }
+        )
+        resource_permissions = AzureResourcePermissions(installation, ws, azurerm, locations)
+
+        sp_infos = [
+            AzureServicePrincipalInfo(
+                test_info.application_id,
+                test_info.secret_scope,
+                test_info.secret_key,
+                "test",
+                "test",
+            )
+        ]
+        sp_crawler = StaticServicePrincipalCrawler(sp_infos, ws, sql_backend, "dont_need_a_schema")
+
+        spn_migration = ServicePrincipalMigration(
+            installation, ws, resource_permissions, sp_crawler, StorageCredentialManager(ws)
+        )
+        return spn_migration.run(
+            MockPrompts({"Above Azure Service Principals will be migrated to UC storage credentials *": "Yes"}),
+            credentials,
+        )
+
+    return inner
+
+
+def test_spn_migration_existed_storage_credential(extract_test_info, make_storage_credential_spn, run_migration):
+    # create a storage credential for this test
+    make_storage_credential_spn(
+        credential_name=extract_test_info.credential_name,
+        application_id=extract_test_info.application_id,
+        client_secret=extract_test_info.client_secret,
+        directory_id=extract_test_info.directory_id,
+    )
+
+    # test that the spn migration is skipped because the storage credential above already exists
+    migration_result = run_migration(extract_test_info, {extract_test_info.credential_name})
+
+    # assert that no spn was migrated: migration_result will be empty
+    assert not migration_result
+
+
+@pytest.mark.parametrize("read_only", [False, True])
+def test_spn_migration(ws, extract_test_info, run_migration, read_only):
+    try:
+        migration_results = run_migration(extract_test_info, {"lets_migrate_the_spn"}, read_only)
+        storage_credential = ws.storage_credentials.get(extract_test_info.credential_name)
+    finally:
+        ws.storage_credentials.delete(extract_test_info.credential_name, force=True)
+
+    assert storage_credential is not None
+    assert storage_credential.read_only is read_only
+
+    if read_only:
+        failures = migration_results[0].failures
+        # in this test, LIST should fail since the validation path does not exist
+        assert failures
+        match = re.match(r"LIST validation failed with message: .*The specified path does not exist", failures[0])
+        assert match is not None, "LIST validation should fail"
+    else:
+        # all validations should pass
+        assert not migration_results[0].failures
diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py
index 4c8156c086..5f7beaf291 100644
--- a/tests/integration/conftest.py
+++ b/tests/integration/conftest.py
@@ -10,6 +10,10 @@
 from databricks.labs.ucx.__about__ import __version__
 from databricks.labs.ucx.account import WorkspaceInfo
+from databricks.labs.ucx.assessment.azure import (
+    AzureServicePrincipalCrawler,
+    AzureServicePrincipalInfo,
+)
 from databricks.labs.ucx.framework.crawlers import SqlBackend
 from databricks.labs.ucx.hive_metastore import TablesCrawler
 from databricks.labs.ucx.hive_metastore.mapping import Rule, TableMapping
@@ -154,3 +158,12 @@ def load(self):
 
     def save(self, tables: TablesCrawler, workspace_info: WorkspaceInfo) -> str:
         raise RuntimeWarning("not available")
+
+
+class StaticServicePrincipalCrawler(AzureServicePrincipalCrawler):
+    def __init__(self, spn_infos: list[AzureServicePrincipalInfo], *args):
+        super().__init__(*args)
+        self._spn_infos = spn_infos
+
+    def snapshot(self) -> list[AzureServicePrincipalInfo]:
+        return self._spn_infos
diff --git a/tests/unit/azure/azure/mappings.json b/tests/unit/azure/azure/mappings.json
index 59678edccd..7df33168fd 100644
--- a/tests/unit/azure/azure/mappings.json
+++ b/tests/unit/azure/azure/mappings.json
@@ -8,12 +8,14 @@
     "/v1.0/directoryObjects/user2": {
         "appId": "appIduser2",
         "displayName": "disNameuser2",
-        "id": "Iduser2"
+        "id": "Iduser2",
+        "appOwnerOrganizationId": "0000-0000"
     },
     "/v1.0/directoryObjects/user3": {
         "appId": "appIduser3",
         "displayName": "disNameuser3",
-        "id": "Iduser3"
+        "id": "Iduser3",
+        "appOwnerOrganizationId": "0000-0000"
     },
     "/subscriptions": {
         "value": [
diff --git a/tests/unit/azure/test_access.py b/tests/unit/azure/test_access.py
index 0957f95c63..04b84d0e72 100644
--- a/tests/unit/azure/test_access.py
+++ b/tests/unit/azure/test_access.py
@@ -68,13 +68,13 @@ def test_save_spn_permissions_valid_azure_storage_account():
         AzureRoleAssignment(
             resource=AzureResource(f'{containers}/container1'),
             scope=AzureResource(f'{containers}/container1'),
-            principal=Principal('a', 'b', 'c'),
+            principal=Principal('a', 'b', 'c', '0000-0000'),
             role_name='Storage Blob Data Contributor',
         ),
         AzureRoleAssignment(
             resource=AzureResource(f'{storage_accounts}/storage1'),
             scope=AzureResource(f'{storage_accounts}/storage1'),
-            principal=Principal('d', 'e', 'f'),
+            principal=Principal('d', 'e', 'f', '0000-0000'),
             role_name='Button Clicker',
         ),
     ]
@@ -88,12 +88,14 @@ def test_save_spn_permissions_valid_azure_storage_account():
                 'prefix': 'abfss://container1@storage1.dfs.core.windows.net/',
                 'principal': 'b',
                 'privilege': 'WRITE_FILES',
+                'directory_id': '0000-0000',
             },
             {
                 'client_id': 'a',
                 'prefix': 'abfss://container2@storage1.dfs.core.windows.net/',
                 'principal': 'b',
                 'privilege': 'WRITE_FILES',
+                'directory_id': '0000-0000',
             },
         ],
     )
@@ -132,12 +134,14 @@ def test_save_spn_permissions_valid_storage_accounts(caplog, mocker, az_token):
                 'prefix': 'abfss://container3@sto2.dfs.core.windows.net/',
                 'principal': 'disNameuser3',
                 'privilege': 'WRITE_FILES',
+                'directory_id': '0000-0000',
             },
             {
                 'client_id': 'appIduser3',
                 'prefix': 'abfss://container3@sto2.dfs.core.windows.net/',
                 'principal': 'disNameuser3',
                 'privilege': 'WRITE_FILES',
+                'directory_id': '0000-0000',
             },
         ],
    )
diff --git a/tests/unit/azure/test_credentials.py b/tests/unit/azure/test_credentials.py
new file mode 100644
index 0000000000..c7996bf27f
--- /dev/null
+++ b/tests/unit/azure/test_credentials.py
@@ -0,0 +1,337 @@
+import logging
+import re
+from unittest.mock import create_autospec
+
+import pytest
+from databricks.labs.blueprint.installation import MockInstallation
+from databricks.labs.blueprint.tui import MockPrompts
+from databricks.sdk import WorkspaceClient
+from databricks.sdk.errors import ResourceDoesNotExist
+from databricks.sdk.errors.platform import InvalidParameterValue
+from databricks.sdk.service.catalog import (
+    AwsIamRole,
+    AzureManagedIdentity,
+    AzureServicePrincipal,
+    StorageCredentialInfo,
+    ValidateStorageCredentialResponse,
+    ValidationResult,
+    ValidationResultOperation,
+    ValidationResultResult,
+)
+from databricks.sdk.service.workspace import GetSecretResponse
+
+from databricks.labs.ucx.assessment.azure import (
+    AzureServicePrincipalCrawler,
+    AzureServicePrincipalInfo,
+)
+from databricks.labs.ucx.azure.access import (
+    AzureResourcePermissions,
+    StoragePermissionMapping,
+)
+from databricks.labs.ucx.azure.credentials import (
+    ServicePrincipalMigration,
+    ServicePrincipalMigrationInfo,
+    StorageCredentialManager,
+)
+from databricks.labs.ucx.azure.resources import AzureResources
+from databricks.labs.ucx.hive_metastore import ExternalLocations
+
+
+@pytest.fixture
+def ws():
+    return create_autospec(WorkspaceClient)
+
+
+@pytest.fixture
+def installation():
+    return MockInstallation(
+        {
+            "config.yml": {
+                'version': 2,
+                'inventory_database': 'ucx',
+                'connect': {
+                    'host': 'foo',
+                    'token': 'bar',
+                },
+            },
+            "azure_storage_account_info.csv": [
+                {
+                    'prefix': 'prefix1',
+                    'client_id': 'app_secret1',
+                    'principal': 'principal_1',
+                    'privilege': 'WRITE_FILES',
+                    'directory_id': 'directory_id_1',
+                },
+                {
+                    'prefix': 'prefix2',
+                    'client_id': 'app_secret2',
+                    'principal': 'principal_read',
+                    'privilege': 'READ_FILES',
+                    'directory_id': 'directory_id_1',
+                },
+                {
+                    'prefix': 'prefix3',
+                    'client_id': 'app_secret3',
+                    'principal': 'principal_write',
+                    'privilege': 'WRITE_FILES',
+                    'directory_id': 'directory_id_2',
+                },
+                {
+                    'prefix': 'overlap_with_external_location',
+                    'client_id': 'app_secret4',
+                    'principal': 'principal_overlap',
+                    'privilege': 'WRITE_FILES',
+                    'directory_id': 'directory_id_2',
+                },
+            ],
+        }
+    )
+
+
+def side_effect_create_storage_credential(name, azure_service_principal, comment, read_only):
+    return StorageCredentialInfo(
+        name=name, azure_service_principal=azure_service_principal, comment=comment, read_only=read_only
+    )
+
+
+def side_effect_validate_storage_credential(storage_credential_name, url, read_only):  # pylint: disable=unused-argument
+    if "overlap" in storage_credential_name:
+        raise InvalidParameterValue
+    if "none" in storage_credential_name:
+        return ValidateStorageCredentialResponse()
+    if "fail" in storage_credential_name:
+        return ValidateStorageCredentialResponse(
+            is_dir=True,
+            results=[
+                ValidationResult(
+                    operation=ValidationResultOperation.LIST, result=ValidationResultResult.FAIL, message="fail"
+                ),
+                ValidationResult(operation=None, result=ValidationResultResult.FAIL, message="fail"),
+            ],
+        )
+    if read_only:
+        return ValidateStorageCredentialResponse(
+            is_dir=True,
+            results=[ValidationResult(operation=ValidationResultOperation.READ, result=ValidationResultResult.PASS)],
+        )
+    return ValidateStorageCredentialResponse(
+        is_dir=True,
+        results=[ValidationResult(operation=ValidationResultOperation.WRITE, result=ValidationResultResult.PASS)],
+    )
+
+
+@pytest.fixture
+def credential_manager(ws):
+    ws.storage_credentials.list.return_value = [
+        StorageCredentialInfo(aws_iam_role=AwsIamRole("arn:aws:iam::123456789012:role/example-role-name")),
+        StorageCredentialInfo(
+            azure_managed_identity=AzureManagedIdentity("/subscriptions/.../providers/Microsoft.Databricks/...")
+        ),
+        StorageCredentialInfo(
+            name="included_test",
+            azure_service_principal=AzureServicePrincipal(
+                "62e43d7d-df53-4c64-86ed-c2c1a3ac60c3",
+                "b6420590-5e1c-4426-8950-a94cbe9b6115",
"secret", + ), + ), + StorageCredentialInfo(azure_service_principal=AzureServicePrincipal("directory_id_1", "app_secret2", "secret")), + ] + + ws.storage_credentials.create.side_effect = side_effect_create_storage_credential + ws.storage_credentials.validate.side_effect = side_effect_validate_storage_credential + + return StorageCredentialManager(ws) + + +def test_list_storage_credentials(credential_manager): + assert credential_manager.list() == {"b6420590-5e1c-4426-8950-a94cbe9b6115", "app_secret2"} + + +def test_list_included_storage_credentials(credential_manager): + include_names = {"included_test"} + assert credential_manager.list(include_names) == {"b6420590-5e1c-4426-8950-a94cbe9b6115"} + + +def test_create_storage_credentials(credential_manager): + sp_1 = ServicePrincipalMigrationInfo( + StoragePermissionMapping( + "prefix1", + "app_secret1", + "principal_write", + "WRITE_FILES", + "directory_id_1", + ), + "test", + ) + sp_2 = ServicePrincipalMigrationInfo( + StoragePermissionMapping( + "prefix2", + "app_secret2", + "principal_read", + "READ_FILES", + "directory_id_1", + ), + "test", + ) + + storage_credential = credential_manager.create_with_client_secret(sp_1) + assert sp_1.permission_mapping.principal == storage_credential.name + assert storage_credential.read_only is False + + storage_credential = credential_manager.create_with_client_secret(sp_2) + assert sp_2.permission_mapping.principal == storage_credential.name + assert storage_credential.read_only is True + + +def test_validate_storage_credentials(credential_manager): + permission_mapping = StoragePermissionMapping("prefix", "client_id", "principal_1", "WRITE_FILES", "directory_id") + + # validate normal storage credential + validation = credential_manager.validate(permission_mapping) + assert validation.read_only is False + assert validation.name == permission_mapping.principal + assert not validation.failures + + +def test_validate_read_only_storage_credentials(credential_manager): + permission_mapping = StoragePermissionMapping( + "prefix", "client_id", "principal_read", "READ_FILES", "directory_id_1" + ) + + # validate read-only storage credential + validation = credential_manager.validate(permission_mapping) + assert validation.read_only is True + assert validation.name == permission_mapping.principal + assert not validation.failures + + +def test_validate_storage_credentials_overlap_location(credential_manager): + permission_mapping = StoragePermissionMapping("prefix", "client_id", "overlap", "WRITE_FILES", "directory_id_2") + + # prefix used for validation overlaps with existing external location will raise InvalidParameterValue + # assert InvalidParameterValue is handled + validation = credential_manager.validate(permission_mapping) + assert validation.failures == [ + "The validation is skipped because an existing external location overlaps with the location used for validation." 
+ ] + + +def test_validate_storage_credentials_non_response(credential_manager): + permission_mapping = StoragePermissionMapping("prefix", "client_id", "none", "WRITE_FILES", "directory_id") + + validation = credential_manager.validate(permission_mapping) + assert validation.failures == ["Validation returned none results."] + + +def test_validate_storage_credentials_failed_operation(credential_manager): + permission_mapping = StoragePermissionMapping("prefix", "client_id", "fail", "WRITE_FILES", "directory_id_2") + + validation = credential_manager.validate(permission_mapping) + assert validation.failures == ["LIST validation failed with message: fail"] + + +@pytest.fixture +def sp_migration(ws, installation, credential_manager): + ws.secrets.get_secret.return_value = GetSecretResponse(value="aGVsbG8gd29ybGQ=") + + arp = AzureResourcePermissions( + installation, ws, create_autospec(AzureResources), create_autospec(ExternalLocations) + ) + + sp_crawler = create_autospec(AzureServicePrincipalCrawler) + sp_crawler.snapshot.return_value = [ + AzureServicePrincipalInfo("app_secret1", "test_scope", "test_key", "tenant_id_1", "storage1"), + AzureServicePrincipalInfo("app_secret2", "test_scope", "test_key", "tenant_id_1", "storage1"), + AzureServicePrincipalInfo("app_secret3", "test_scope", "", "tenant_id_2", "storage1"), + AzureServicePrincipalInfo("app_secret4", "", "", "tenant_id_2", "storage1"), + ] + + return ServicePrincipalMigration(installation, ws, arp, sp_crawler, credential_manager) + + +def test_for_cli_not_prompts(ws, installation): + ws.config.is_azure = True + prompts = MockPrompts({"Have you reviewed the azure_storage_account_info.csv *": "No"}) + with pytest.raises(SystemExit): + ServicePrincipalMigration.for_cli(ws, installation, prompts) + + +def test_for_cli(ws, installation): + ws.config.is_azure = True + ws.config.auth_type = "azure-cli" + prompts = MockPrompts({"Have you reviewed the azure_storage_account_info.csv *": "Yes"}) + + assert isinstance(ServicePrincipalMigration.for_cli(ws, installation, prompts), ServicePrincipalMigration) + + +@pytest.mark.parametrize( + "secret_bytes_value, num_migrated", + [(GetSecretResponse(value="aGVsbG8gd29ybGQ="), 1), (GetSecretResponse(value="T2zhLCBNdW5kbyE="), 0)], +) +def test_read_secret_value_decode(ws, sp_migration, secret_bytes_value, num_migrated): + ws.secrets.get_secret.return_value = secret_bytes_value + + prompts = MockPrompts({"Above Azure Service Principals will be migrated to UC storage credentials*": "Yes"}) + assert len(sp_migration.run(prompts)) == num_migrated + + +def test_read_secret_value_none(ws, sp_migration): + ws.secrets.get_secret.return_value = GetSecretResponse(value=None) + prompts = MockPrompts({"Above Azure Service Principals will be migrated to UC storage credentials*": "Yes"}) + with pytest.raises(AssertionError): + sp_migration.run(prompts) + + +def test_read_secret_read_exception(caplog, ws, sp_migration): + caplog.set_level(logging.INFO) + ws.secrets.get_secret.side_effect = ResourceDoesNotExist() + + prompts = MockPrompts({"Above Azure Service Principals will be migrated to UC storage credentials*": "Yes"}) + + assert len(sp_migration.run(prompts)) == 0 + assert re.search(r"removed on the backend: .*", caplog.text) + + +def test_print_action_plan(caplog, ws, sp_migration): + caplog.set_level(logging.INFO) + ws.secrets.get_secret.return_value = GetSecretResponse(value="aGVsbG8gd29ybGQ=") + + prompts = MockPrompts({"Above Azure Service Principals will be migrated to UC storage credentials*": "Yes"}) + 
+    sp_migration.run(prompts)
+
+    log_pattern = r"Service Principal name: .* application_id: .* privilege .* on location .*"
+    for msg in caplog.messages:
+        if re.search(log_pattern, msg):
+            assert True
+            return
+    assert False, "Action plan is not logged"
+
+
+def test_run_without_confirmation(ws, sp_migration):
+    ws.secrets.get_secret.return_value = GetSecretResponse(value="aGVsbG8gd29ybGQ=")
+    prompts = MockPrompts(
+        {
+            "Above Azure Service Principals will be migrated to UC storage credentials*": "No",
+        }
+    )
+
+    assert sp_migration.run(prompts) == []
+
+
+def test_run(ws, installation, sp_migration):
+    prompts = MockPrompts({"Above Azure Service Principals will be migrated to UC storage credentials*": "Yes"})
+
+    sp_migration.run(prompts)
+    installation.assert_file_written(
+        "azure_service_principal_migration_result.csv",
+        [
+            {
+                'application_id': 'app_secret1',
+                'directory_id': 'directory_id_1',
+                'name': 'principal_1',
+                'validated_on': 'prefix1',
+            }
+        ],
+    )
diff --git a/tests/unit/azure/test_resources.py b/tests/unit/azure/test_resources.py
index 6c1b067f44..c0c468a20c 100644
--- a/tests/unit/azure/test_resources.py
+++ b/tests/unit/azure/test_resources.py
@@ -61,7 +61,7 @@ def test_role_assignments_storage(mocker, az_token):
     assert len(role_assignments) == 1
     for role_assignment in role_assignments:
         assert role_assignment.role_name == "Contributor"
-        assert role_assignment.principal == Principal("appIduser2", "disNameuser2", "Iduser2")
+        assert role_assignment.principal == Principal("appIduser2", "disNameuser2", "Iduser2", "0000-0000")
         assert str(role_assignment.scope) == resource_id
         assert role_assignment.resource == AzureResource(resource_id)
@@ -75,6 +75,6 @@ def test_role_assignments_container(mocker, az_token):
     assert len(role_assignments) == 1
     for role_assignment in role_assignments:
         assert role_assignment.role_name == "Contributor"
-        assert role_assignment.principal == Principal("appIduser2", "disNameuser2", "Iduser2")
+        assert role_assignment.principal == Principal("appIduser2", "disNameuser2", "Iduser2", "0000-0000")
         assert str(role_assignment.scope) == resource_id
         assert role_assignment.resource == AzureResource(resource_id)
diff --git a/tests/unit/test_cli.py b/tests/unit/test_cli.py
index 1b552c9a25..2ddaddd9c2 100644
--- a/tests/unit/test_cli.py
+++ b/tests/unit/test_cli.py
@@ -15,6 +15,7 @@
     ensure_assessment_run,
     installations,
     manual_workspace_info,
+    migrate_credentials,
     move,
     open_remote_config,
     principal_prefix_access,
@@ -43,11 +44,14 @@ def ws():
             }
         ),
         '/Users/foo/.ucx/state.json': json.dumps({'resources': {'jobs': {'assessment': '123'}}}),
+        "/Users/foo/.ucx/azure_storage_account_info.csv": "prefix,client_id,principal,privilege,directory_id\ntest,test,test,test,test",
     }
 
-    def download(path: str) -> io.StringIO:
+    def download(path: str) -> io.StringIO | io.BytesIO:
         if path not in state:
             raise NotFound(path)
+        if ".csv" in path:
+            return io.BytesIO(state[path].encode('utf-8'))
         return io.StringIO(state[path])
 
     workspace_client = create_autospec(WorkspaceClient)
@@ -305,3 +309,11 @@ def test_save_storage_and_principal_gcp(ws, caplog):
     ws.config.is_gcp = True
     principal_prefix_access(ws)
     assert "This cmd is only supported for azure and aws workspaces" in caplog.messages
+
+
+def test_migrate_credentials_azure(ws):
+    ws.config.is_azure = True
+    ws.workspace.upload.return_value = "test"
+    with patch("databricks.labs.blueprint.tui.Prompts.confirm", return_value=True):
+        migrate_credentials(ws)
+    ws.storage_credentials.list.assert_called()
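
For trying the change end to end, here is a minimal sketch of the flow the new `migrate_credentials` command wires together. The workspace-client construction and the result-printing loop are illustrative assumptions, not part of this diff; it presumes UCX is already installed in the workspace and that `azure_storage_account_info.csv` was generated by `principal_prefix_access` and reviewed.

```python
from databricks.labs.blueprint.installation import Installation
from databricks.labs.blueprint.tui import Prompts
from databricks.sdk import WorkspaceClient

from databricks.labs.ucx.azure.credentials import ServicePrincipalMigration

# Assumption: credentials are resolved from the environment (e.g. azure-cli auth).
ws = WorkspaceClient()
prompts = Prompts()
installation = Installation.current(ws, 'ucx')

# for_cli() asks for confirmation that the CSV was reviewed, then wires up
# AzureResourcePermissions, AzureServicePrincipalCrawler and StorageCredentialManager.
migration = ServicePrincipalMigration.for_cli(ws, installation, prompts)

# run() prints the action plan, asks for confirmation, creates one UC storage
# credential per service principal with a recoverable client secret, validates
# each credential, and saves azure_service_principal_migration_result.csv.
for result in migration.run(prompts):
    print(result.name, result.failures or "validated")
```

Service principals whose `application_id` is already used by a UC storage credential, or whose client secret cannot be fetched and decoded from a Databricks secret scope, are skipped; everything else ends up in the saved validation-result file.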