diff --git a/src/databricks/labs/ucx/account.py b/src/databricks/labs/ucx/account.py
index 08f75e7a81..64722d545f 100644
--- a/src/databricks/labs/ucx/account.py
+++ b/src/databricks/labs/ucx/account.py
@@ -21,6 +21,7 @@ class AccountWorkspaces:
     SYNC_FILE_NAME: ClassVar[str] = "workspaces.json"

     def __init__(self, account_client: AccountClient, new_workspace_client=WorkspaceClient):
+        # TODO: new_workspace_client is a design flaw, remove it
        self._new_workspace_client = new_workspace_client
         self._ac = account_client
diff --git a/src/databricks/labs/ucx/assessment/clusters.py b/src/databricks/labs/ucx/assessment/clusters.py
index 03f54dfc47..c5ef5e5dbd 100644
--- a/src/databricks/labs/ucx/assessment/clusters.py
+++ b/src/databricks/labs/ucx/assessment/clusters.py
@@ -138,7 +138,7 @@ def _check_cluster_failures(self, cluster: ClusterDetails, source: str) -> list[


 class ClustersCrawler(CrawlerBase[ClusterInfo], CheckClusterMixin):
-    def __init__(self, ws: WorkspaceClient, sbe: SqlBackend, schema):
+    def __init__(self, ws: WorkspaceClient, sbe: SqlBackend, schema: str):
         super().__init__(sbe, "hive_metastore", schema, "clusters", ClusterInfo)
         self._ws = ws
diff --git a/src/databricks/labs/ucx/assessment/workflows.py b/src/databricks/labs/ucx/assessment/workflows.py
new file mode 100644
index 0000000000..b785d512b6
--- /dev/null
+++ b/src/databricks/labs/ucx/assessment/workflows.py
@@ -0,0 +1,218 @@
+from databricks.labs.ucx.contexts.workflow_task import RuntimeContext
+from databricks.labs.ucx.framework.tasks import Workflow, job_task
+
+
+class Assessment(Workflow):
+    def __init__(self):
+        super().__init__('assessment')
+
+    @job_task(notebook="hive_metastore/tables.scala")
+    def crawl_tables(self, ctx: RuntimeContext):
+        """Iterates over all tables in the Hive Metastore of the current workspace and persists their metadata, such
+        as _database name_, _table name_, _table type_, _table location_, etc., in the Delta table named
+        `$inventory_database.tables`. Note that the `inventory_database` is set in the configuration file. The metadata
+        stored is then used in the subsequent tasks and workflows to, for example, find all Hive Metastore tables that
+        cannot easily be migrated to Unity Catalog."""
+
+    @job_task(job_cluster="tacl")
+    def setup_tacl(self, ctx: RuntimeContext):
+        """(Optimization) Starts the `tacl` job cluster in parallel with crawling tables."""
+
+    @job_task(depends_on=[crawl_tables, setup_tacl], job_cluster="tacl")
+    def crawl_grants(self, ctx: RuntimeContext):
+        """Scans the previously created Delta table named `$inventory_database.tables` and issues a `SHOW GRANTS`
+        statement for every object to retrieve the permissions assigned to it. The permissions include information
+        such as the _principal_, _action type_, and the _table_ it applies to. This is persisted in the Delta table
+        `$inventory_database.grants`. Other migration-related jobs use this inventory table to convert the legacy Table
+        ACLs to Unity Catalog permissions.
+
+        Note: This job runs on a separate cluster (named `tacl`) as it requires the proper configuration to have the Table
+        ACLs enabled and available for retrieval."""
+        ctx.grants_crawler.snapshot()
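
For illustration, the grants crawl boils down to replaying `SHOW GRANTS` over the table inventory and persisting
the rows. A minimal sketch, not part of this diff, assuming a simplified `fetch` callable that yields rows with
attribute access:

from dataclasses import dataclass

@dataclass
class Grant:
    principal: str
    action_type: str
    database: str
    table: str

def crawl_grants(fetch, inventory_database: str) -> list[Grant]:
    grants = []
    # the `tables` inventory is produced by the crawl_tables task above
    for t in fetch(f"SELECT * FROM hive_metastore.{inventory_database}.tables"):
        # one SHOW GRANTS round-trip per securable object
        for row in fetch(f"SHOW GRANTS ON TABLE hive_metastore.{t.database}.{t.name}"):
            grants.append(Grant(row.principal, row.action_type, t.database, t.name))
    return grants
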
+
+    @job_task(depends_on=[crawl_tables])
+    def estimate_table_size_for_migration(self, ctx: RuntimeContext):
+        """Scans the previously created Delta table named `$inventory_database.tables` and locates tables that cannot be
+        "synced". These tables will have to be cloned in the migration process.
+        Assesses the size of these tables and creates the `$inventory_database.table_size` table to list these sizes.
+        The table size is a factor in deciding whether to clone these tables."""
+        ctx.table_size_crawler.snapshot()
+
+    @job_task
+    def crawl_mounts(self, ctx: RuntimeContext):
+        """Defines the scope of the _mount points_ intended for migration into Unity Catalog. As these objects are not
+        compatible with the Unity Catalog paradigm, a key component of the migration process involves transferring them
+        to Unity Catalog External Locations.
+
+        The assessment involves scanning the workspace to compile a list of all existing mount points and subsequently
+        storing this information in the `$inventory.mounts` table. This is crucial for planning the migration."""
+        ctx.mounts_crawler.snapshot()
+
+    @job_task(depends_on=[crawl_mounts, crawl_tables])
+    def guess_external_locations(self, ctx: RuntimeContext):
+        """Determines the shared path prefixes of all the tables. Specifically, the focus is on identifying locations that
+        utilize mount points. The goal is to identify the _external locations_ necessary for a successful migration and
+        store this information in the `$inventory.external_locations` table.
+
+        The approach taken in this assessment involves the following steps:
+          - Extracting all the locations associated with tables that do not use DBFS directly, but a mount point instead
+          - Scanning all these locations to identify folders that can act as shared path prefixes
+          - These identified external locations will be created subsequently, prior to the actual table migration"""
+        ctx.external_locations.snapshot()
+
+    @job_task
+    def assess_jobs(self, ctx: RuntimeContext):
+        """Scans through all the jobs and identifies those that are not compatible with UC. The list of all the jobs is
+        stored in the `$inventory.jobs` table.
+
+        It looks for:
+          - Clusters with Databricks Runtime (DBR) version earlier than 11.3
+          - Clusters using Passthrough Authentication
+          - Clusters with incompatible Spark config tags
+          - Clusters referencing DBFS locations in one or more config options
+        """
+        ctx.jobs_crawler.snapshot()
+
+    @job_task
+    def assess_clusters(self, ctx: RuntimeContext):
+        """Scans through all the clusters and identifies those that are not compatible with UC. The list of all the clusters
+        is stored in the `$inventory.clusters` table.
+
+        It looks for:
+          - Clusters with Databricks Runtime (DBR) version earlier than 11.3
+          - Clusters using Passthrough Authentication
+          - Clusters with incompatible Spark config tags
+          - Clusters referencing DBFS locations in one or more config options
+        """
+        ctx.clusters_crawler.snapshot()
+
+    @job_task
+    def assess_pipelines(self, ctx: RuntimeContext):
+        """This module scans through all the Pipelines and identifies those that have Azure Service Principals
+        embedded in their pipeline configurations (i.e. service principals that have been granted access to the
+        Azure storage accounts via Spark configuration).
+
+        It looks for:
+          - all the pipelines that have an Azure Service Principal embedded in the pipeline configuration
+
+        Subsequently, a list of all the pipelines with matching configurations is stored in the
+        `$inventory.pipelines` table."""
+        ctx.pipelines_crawler.snapshot()
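
Every `snapshot()` call in these tasks follows the same contract inherited from `CrawlerBase` (note the
`ClustersCrawler.__init__` signature earlier in this diff): reuse the persisted inventory when it exists, otherwise
crawl the workspace API and save the result. A rough sketch of that contract — not the actual implementation —
assuming the backend exposes `fetch` and `save_table` as in `databricks.labs.lsql`:

from collections.abc import Iterable

class CrawlerBase:
    # simplified sketch; the real class also handles schema creation and typed row serialization
    def __init__(self, backend, catalog: str, schema: str, table: str, klass: type):
        self._backend = backend
        self._full_name = f"{catalog}.{schema}.{table}"
        self._klass = klass

    def snapshot(self) -> Iterable:
        cached = list(self._backend.fetch(f"SELECT * FROM {self._full_name}"))
        if cached:
            return cached  # inventory already populated by an earlier run
        rows = list(self._crawl())
        self._backend.save_table(self._full_name, rows, self._klass)
        return rows

    def _crawl(self) -> Iterable:
        raise NotImplementedError  # each crawler lists its own API, e.g. clusters or jobs
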
+
+    @job_task
+    def assess_incompatible_submit_runs(self, ctx: RuntimeContext):
+        """This module scans through all the Submit Runs and identifies those runs that may become incompatible after
+        the workspace attachment.
+
+        It looks for:
+          - All submit runs with DBR >=11.3 and data_security_mode:None
+
+        It also combines several submit runs under a single pseudo_id based on a hash of the submit run configuration.
+        Subsequently, a list of all the incompatible runs with failures is stored in the
+        `$inventory.submit_runs` table."""
+        ctx.submit_runs_crawler.snapshot()
+
+    @job_task
+    def crawl_cluster_policies(self, ctx: RuntimeContext):
+        """This module scans through all the Cluster Policies and gets the necessary information.
+
+        It looks for:
+          - Cluster Policies with Databricks Runtime (DBR) version earlier than 11.3
+
+        Subsequently, a list of all the policies with matching configurations is stored in the
+        `$inventory.policies` table."""
+        ctx.policies_crawler.snapshot()
+
+    @job_task(cloud="azure")
+    def assess_azure_service_principals(self, ctx: RuntimeContext):
+        """This module scans through all the cluster configurations, cluster policies, job cluster configurations,
+        pipeline configurations, and warehouse configurations, and identifies all the Azure Service Principals that
+        have been granted access to the Azure storage accounts via the Spark configurations referenced in those entities.
+
+        It looks in:
+          - all those entities, and prepares a list of the Azure Service Principals embedded in their configurations
+
+        Subsequently, the list of all the Azure Service Principals referenced in those configurations is saved
+        in the `$inventory.azure_service_principals` table."""
+        if ctx.is_azure:
+            ctx.azure_service_principal_crawler.snapshot()
+
+    @job_task
+    def assess_global_init_scripts(self, ctx: RuntimeContext):
+        """This module scans through all the global init scripts and identifies any Azure Service Principal that has
+        been granted access to the Azure storage accounts via the Spark configurations referenced in those scripts.
+
+        The list of all the global init scripts with such configurations is saved in the
+        `$inventory.azure_service_principals` table."""
+        ctx.global_init_scripts_crawler.snapshot()
+
+    @job_task
+    def workspace_listing(self, ctx: RuntimeContext):
+        """Scans the workspace for workspace objects. It recursively lists all subdirectories
+        and compiles a list of directories, notebooks, files, repos and libraries in the workspace.
+
+        It uses multi-threading to parallelize the listing process to speed up execution on big workspaces.
+        It accepts a starting path as a parameter, defaulting to the root path '/'."""
+        ctx.workspace_listing.snapshot()
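
The multi-threaded traversal described in `workspace_listing` can be pictured as a breadth-first walk that lists
each level of directories in parallel. A rough sketch against the public SDK, not the actual implementation:

from concurrent.futures import ThreadPoolExecutor

from databricks.sdk import WorkspaceClient
from databricks.sdk.service.workspace import ObjectInfo, ObjectType

def list_workspace(ws: WorkspaceClient, start_path: str = "/", num_threads: int = 10) -> list[ObjectInfo]:
    snapshot: list[ObjectInfo] = []
    frontier = [start_path]
    with ThreadPoolExecutor(max_workers=num_threads) as pool:
        while frontier:
            # list every directory on the current level concurrently
            listings = list(pool.map(lambda path: list(ws.workspace.list(path)), frontier))
            frontier = []
            for listing in listings:
                for obj in listing:
                    snapshot.append(obj)
                    if obj.object_type == ObjectType.DIRECTORY:
                        frontier.append(obj.path)  # descend into the next level
    return snapshot
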
+
+    @job_task(depends_on=[crawl_grants, workspace_listing])
+    def crawl_permissions(self, ctx: RuntimeContext):
+        """Scans the workspace-local groups and all their permissions. The list is stored in the `$inventory.permissions`
+        Delta table.
+
+        This is the first step for the _group migration_ process, which is continued in the `migrate-groups` workflow.
+        This step includes preparing Legacy Table ACLs for local group migration."""
+        permission_manager = ctx.permission_manager
+        permission_manager.cleanup()
+        permission_manager.inventorize_permissions()
+
+    @job_task
+    def crawl_groups(self, ctx: RuntimeContext):
+        """Scans all groups for the local group migration scope"""
+        ctx.group_manager.snapshot()
+
+    @job_task(
+        depends_on=[
+            crawl_grants,
+            crawl_groups,
+            crawl_permissions,
+            guess_external_locations,
+            assess_jobs,
+            assess_incompatible_submit_runs,
+            assess_clusters,
+            crawl_cluster_policies,
+            assess_azure_service_principals,
+            assess_pipelines,
+            assess_global_init_scripts,
+            crawl_tables,
+        ],
+        dashboard="assessment_main",
+    )
+    def assessment_report(self, ctx: RuntimeContext):
+        """Refreshes the assessment dashboard after all previous tasks have been completed. Note that you can access the
+        dashboard _before_ all tasks have been completed, but then only already completed information is shown."""
+
+    @job_task(
+        depends_on=[
+            assess_jobs,
+            assess_incompatible_submit_runs,
+            assess_clusters,
+            assess_pipelines,
+            crawl_tables,
+        ],
+        dashboard="assessment_estimates",
+    )
+    def estimates_report(self, ctx: RuntimeContext):
+        """Refreshes the assessment estimates dashboard after all previous tasks have been completed. Note that you can
+        access the dashboard _before_ all tasks have been completed, but then only already completed information is shown."""
+
+
+class DestroySchema(Workflow):
+    def __init__(self):
+        super().__init__('099-destroy-schema')
+
+    @job_task
+    def destroy_schema(self, ctx: RuntimeContext):
+        """This _clean-up_ workflow removes the `$inventory` database and all the inventory tables created by
+        the previous workflow runs. Use this to reset the entire state and start with the assessment step again."""
+        ctx.sql_backend.execute(f"DROP DATABASE {ctx.inventory_database} CASCADE")
diff --git a/src/databricks/labs/ucx/aws/access.py b/src/databricks/labs/ucx/aws/access.py
index c4d9ee56d0..89d41427c3 100644
--- a/src/databricks/labs/ucx/aws/access.py
+++ b/src/databricks/labs/ucx/aws/access.py
@@ -49,24 +49,6 @@ def __init__(
         self._kms_key = kms_key
         self._filename = self.INSTANCE_PROFILES_FILE_NAMES

-    @classmethod
-    def for_cli(cls, ws: WorkspaceClient, installation, backend, aws, schema, kms_key=None):
-        config = installation.load(WorkspaceConfig)
-        caller_identity = aws.validate_connection()
-        locations = ExternalLocations(ws, backend, config.inventory_database)
-        if not caller_identity:
-            raise ResourceWarning("AWS CLI is not configured properly.")
-        return cls(
-            installation,
-            ws,
-            backend,
-            aws,
-            locations,
-            schema,
-            caller_identity.get("Account"),
-            kms_key,
-        )
-
     def create_uc_roles_cli(self, *, single_role=True, role_name="UC_ROLE", policy_name="UC_POLICY"):
         # Get the missing paths
         # Identify the S3 prefixes
diff --git a/src/databricks/labs/ucx/aws/credentials.py b/src/databricks/labs/ucx/aws/credentials.py
index b9ed616e2d..c0bbaab39e 100644
--- a/src/databricks/labs/ucx/aws/credentials.py
+++ b/src/databricks/labs/ucx/aws/credentials.py
@@ -3,7 +3,6 @@

 from databricks.labs.blueprint.installation import Installation
 from databricks.labs.blueprint.tui import Prompts
-from databricks.labs.lsql.backends import StatementExecutionBackend
 from databricks.sdk import WorkspaceClient
 from databricks.sdk.errors.platform import InvalidParameterValue
 from databricks.sdk.service.catalog import (
@@ -13,9 +12,8 @@
     ValidationResultResult,
 )

-from databricks.labs.ucx.assessment.aws import AWSResources, 
AWSRoleAction +from databricks.labs.ucx.assessment.aws import AWSRoleAction from databricks.labs.ucx.aws.access import AWSResourcePermissions -from databricks.labs.ucx.config import WorkspaceConfig logger = logging.getLogger(__name__) @@ -125,30 +123,6 @@ def __init__( self._resource_permissions = resource_permissions self._storage_credential_manager = storage_credential_manager - @classmethod - def for_cli(cls, ws: WorkspaceClient, installation: Installation, aws: AWSResources, prompts: Prompts): - if not ws.config.is_aws: - logger.error("Workspace is not on AWS, please run this command on a Databricks on AWS workspaces.") - raise SystemExit() - - msg = ( - f"Have you reviewed the {AWSResourcePermissions.UC_ROLES_FILE_NAMES} " - "and confirm listed IAM roles to be migrated?" - ) - if not prompts.confirm(msg): - raise SystemExit() - - config = installation.load(WorkspaceConfig) - sql_backend = StatementExecutionBackend(ws, config.warehouse_id) - - resource_permissions = AWSResourcePermissions.for_cli( - ws, installation, sql_backend, aws, config.inventory_database - ) - - storage_credential_manager = CredentialManager(ws) - - return cls(installation, ws, resource_permissions, storage_credential_manager) - @staticmethod def _print_action_plan(iam_list: list[AWSRoleAction]): # print action plan to console for customer to review. diff --git a/src/databricks/labs/ucx/azure/access.py b/src/databricks/labs/ucx/azure/access.py index 8f9e19fa12..4c4d7e7747 100644 --- a/src/databricks/labs/ucx/azure/access.py +++ b/src/databricks/labs/ucx/azure/access.py @@ -4,14 +4,12 @@ from databricks.labs.blueprint.installation import Installation from databricks.labs.blueprint.tui import Prompts -from databricks.labs.lsql.backends import StatementExecutionBackend from databricks.sdk import WorkspaceClient from databricks.sdk.errors import NotFound, ResourceAlreadyExists from databricks.sdk.service.catalog import Privilege from databricks.labs.ucx.assessment.crawlers import logger from databricks.labs.ucx.azure.resources import ( - AzureAPIClient, AzureResource, AzureResources, PrincipalSecret, @@ -32,6 +30,8 @@ class StoragePermissionMapping: class AzureResourcePermissions: + FILENAME = 'azure_storage_account_info.csv' + def __init__( self, installation: Installation, @@ -39,7 +39,6 @@ def __init__( azurerm: AzureResources, external_locations: ExternalLocations, ): - self._filename = 'azure_storage_account_info.csv' self._installation = installation self._locations = external_locations self._azurerm = azurerm @@ -50,20 +49,6 @@ def __init__( "Storage Blob Data Reader": Privilege.READ_FILES, } - @classmethod - def for_cli(cls, ws: WorkspaceClient, product='ucx', include_subscriptions=None): - installation = Installation.current(ws, product) - config = installation.load(WorkspaceConfig) - sql_backend = StatementExecutionBackend(ws, config.warehouse_id) - azure_mgmt_client = AzureAPIClient( - ws.config.arm_environment.resource_manager_endpoint, - ws.config.arm_environment.service_management_endpoint, - ) - graph_client = AzureAPIClient("https://graph.microsoft.com", "https://graph.microsoft.com") - azurerm = AzureResources(azure_mgmt_client, graph_client, include_subscriptions) - locations = ExternalLocations(ws, sql_backend, config.inventory_database) - return cls(installation, ws, azurerm, locations) - def _map_storage(self, storage: AzureResource) -> list[StoragePermissionMapping]: logger.info(f"Fetching role assignment for {storage.storage_account}") out = [] @@ -103,7 +88,7 @@ def save_spn_permissions(self) 
-> str | None: if len(storage_account_infos) == 0: logger.error("No storage account found in current tenant with spn permission") return None - return self._installation.save(storage_account_infos, filename=self._filename) + return self._installation.save(storage_account_infos, filename=self.FILENAME) def _update_cluster_policy_definition( self, @@ -221,7 +206,7 @@ def _create_scope(self, uber_principal: PrincipalSecret, inventory_database: str self._ws.secrets.put_secret(inventory_database, "uber_principal_secret", string_value=uber_principal.secret) def load(self): - return self._installation.load(list[StoragePermissionMapping], filename=self._filename) + return self._installation.load(list[StoragePermissionMapping], filename=self.FILENAME) def _get_storage_accounts(self) -> list[str]: external_locations = self._locations.snapshot() diff --git a/src/databricks/labs/ucx/azure/credentials.py b/src/databricks/labs/ucx/azure/credentials.py index 91f84a20b9..fed05f8731 100644 --- a/src/databricks/labs/ucx/azure/credentials.py +++ b/src/databricks/labs/ucx/azure/credentials.py @@ -3,7 +3,6 @@ from databricks.labs.blueprint.installation import Installation from databricks.labs.blueprint.tui import Prompts -from databricks.labs.lsql.backends import StatementExecutionBackend from databricks.sdk import WorkspaceClient from databricks.sdk.errors.platform import InvalidParameterValue from databricks.sdk.service.catalog import ( @@ -19,9 +18,6 @@ AzureResourcePermissions, StoragePermissionMapping, ) -from databricks.labs.ucx.azure.resources import AzureAPIClient, AzureResources -from databricks.labs.ucx.config import WorkspaceConfig -from databricks.labs.ucx.hive_metastore.locations import ExternalLocations logger = logging.getLogger(__name__) @@ -160,32 +156,6 @@ def __init__( self._sp_crawler = service_principal_crawler self._storage_credential_manager = storage_credential_manager - @classmethod - def for_cli(cls, ws: WorkspaceClient, installation: Installation, prompts: Prompts): - msg = ( - "Have you reviewed the azure_storage_account_info.csv " - "and confirm listed service principals are allowed to be checked for migration?" 
- ) - if not prompts.confirm(msg): - raise SystemExit() - - config = installation.load(WorkspaceConfig) - sql_backend = StatementExecutionBackend(ws, config.warehouse_id) - azure_mgmt_client = AzureAPIClient( - ws.config.arm_environment.resource_manager_endpoint, - ws.config.arm_environment.service_management_endpoint, - ) - graph_client = AzureAPIClient("https://graph.microsoft.com", "https://graph.microsoft.com") - azurerm = AzureResources(azure_mgmt_client, graph_client) - locations = ExternalLocations(ws, sql_backend, config.inventory_database) - - resource_permissions = AzureResourcePermissions(installation, ws, azurerm, locations) - sp_crawler = AzureServicePrincipalCrawler(ws, sql_backend, config.inventory_database) - - storage_credential_manager = StorageCredentialManager(ws) - - return cls(installation, ws, resource_permissions, sp_crawler, storage_credential_manager) - def _fetch_client_secret(self, sp_list: list[StoragePermissionMapping]) -> list[ServicePrincipalMigrationInfo]: # check AzureServicePrincipalInfo from AzureServicePrincipalCrawler, if AzureServicePrincipalInfo # has secret_scope and secret_key not empty, fetch the client_secret and put it to the client_secret field diff --git a/src/databricks/labs/ucx/azure/locations.py b/src/databricks/labs/ucx/azure/locations.py index 8faa013a24..595caf531a 100644 --- a/src/databricks/labs/ucx/azure/locations.py +++ b/src/databricks/labs/ucx/azure/locations.py @@ -1,14 +1,11 @@ import logging from urllib.parse import urlparse -from databricks.labs.blueprint.installation import Installation -from databricks.labs.lsql.backends import StatementExecutionBackend from databricks.sdk import WorkspaceClient from databricks.sdk.errors.platform import InvalidParameterValue, PermissionDenied from databricks.labs.ucx.azure.access import AzureResourcePermissions -from databricks.labs.ucx.azure.resources import AzureAPIClient, AzureResources -from databricks.labs.ucx.config import WorkspaceConfig +from databricks.labs.ucx.azure.resources import AzureResources from databricks.labs.ucx.hive_metastore import ExternalLocations logger = logging.getLogger(__name__) @@ -27,23 +24,6 @@ def __init__( self._resource_permissions = resource_permissions self._azurerm = azurerm - @classmethod - def for_cli(cls, ws: WorkspaceClient, installation: Installation): - config = installation.load(WorkspaceConfig) - sql_backend = StatementExecutionBackend(ws, config.warehouse_id) - hms_locations = ExternalLocations(ws, sql_backend, config.inventory_database) - - azure_mgmt_client = AzureAPIClient( - ws.config.arm_environment.resource_manager_endpoint, - ws.config.arm_environment.service_management_endpoint, - ) - graph_client = AzureAPIClient("https://graph.microsoft.com", "https://graph.microsoft.com") - azurerm = AzureResources(azure_mgmt_client, graph_client) - - resource_permissions = AzureResourcePermissions(installation, ws, azurerm, hms_locations) - - return cls(ws, hms_locations, resource_permissions, azurerm) - def _app_id_credential_name_mapping(self) -> tuple[dict[str, str], dict[str, str]]: # list all storage credentials. # generate the managed identity/service principal application id to credential name mapping. 
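
The `for_cli` factories removed above are superseded by the context classes introduced later in this diff
(`contexts/application.py` and `contexts/cli_command.py`): a command builds one `WorkspaceContext`, and every
dependency is assembled lazily, exactly once, through `cached_property`. A sketch of the intended usage, with a
placeholder subscription id:

from databricks.labs.lsql.backends import MockBackend
from databricks.sdk import WorkspaceClient

from databricks.labs.ucx.contexts.cli_command import WorkspaceContext

ws = WorkspaceClient()
ctx = WorkspaceContext(ws, {"subscription_id": "00000000-0000-0000-0000-000000000000"})
# resolves config -> sql_backend -> external_locations -> azure_resource_permissions on first access
path = ctx.azure_resource_permissions.save_spn_permissions()

# unit tests swap out any node of the dependency graph instead of monkey-patching:
test_ctx = WorkspaceContext(ws).replace(sql_backend=MockBackend())

This is why each command body in the `cli.py` diff below collapses to a couple of lines.
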
diff --git a/src/databricks/labs/ucx/cli.py b/src/databricks/labs/ucx/cli.py index 87352e7164..4a6a3b20ba 100644 --- a/src/databricks/labs/ucx/cli.py +++ b/src/databricks/labs/ucx/cli.py @@ -1,36 +1,17 @@ import json -import os -import shutil import webbrowser -from collections.abc import Callable from pathlib import Path from databricks.labs.blueprint.cli import App from databricks.labs.blueprint.entrypoint import get_logger from databricks.labs.blueprint.installation import Installation, SerdeError from databricks.labs.blueprint.tui import Prompts -from databricks.labs.lsql.backends import StatementExecutionBackend from databricks.sdk import AccountClient, WorkspaceClient from databricks.sdk.errors import NotFound -from databricks.labs.ucx.account import AccountWorkspaces, WorkspaceInfo -from databricks.labs.ucx.assessment.aws import AWSResources -from databricks.labs.ucx.aws.access import AWSResourcePermissions -from databricks.labs.ucx.aws.credentials import IamRoleMigration -from databricks.labs.ucx.azure.access import AzureResourcePermissions -from databricks.labs.ucx.azure.credentials import ServicePrincipalMigration -from databricks.labs.ucx.azure.locations import ExternalLocationsMigration +from databricks.labs.ucx.account import AccountWorkspaces from databricks.labs.ucx.config import WorkspaceConfig -from databricks.labs.ucx.hive_metastore import ExternalLocations, TablesCrawler -from databricks.labs.ucx.hive_metastore.catalog_schema import CatalogSchema -from databricks.labs.ucx.hive_metastore.mapping import TableMapping -from databricks.labs.ucx.hive_metastore.table_migrate import TablesMigrator -from databricks.labs.ucx.hive_metastore.table_move import TableMove -from databricks.labs.ucx.install import WorkspaceInstallation -from databricks.labs.ucx.installer.workflows import WorkflowsDeployment -from databricks.labs.ucx.source_code.files import Files -from databricks.labs.ucx.workspace_access.clusters import ClusterAccess -from databricks.labs.ucx.workspace_access.groups import GroupManager +from databricks.labs.ucx.contexts.cli_command import AccountContext, WorkspaceContext ucx = App(__file__) logger = get_logger(__file__) @@ -44,16 +25,18 @@ @ucx.command def workflows(w: WorkspaceClient): """Show deployed workflows and their state""" - installation = WorkflowsDeployment.for_cli(w) + ctx = WorkspaceContext(w) logger.info("Fetching deployed jobs...") - print(json.dumps(installation.latest_job_status())) + latest_job_status = ctx.deployed_workflows.latest_job_status() + print(json.dumps(latest_job_status)) @ucx.command def open_remote_config(w: WorkspaceClient): """Opens remote configuration in the browser""" - installation = WorkspaceInstallation.current(w) - webbrowser.open(installation.config_file_link()) + ctx = WorkspaceContext(w) + workspace_link = ctx.installation.workspace_link('config.yml') + webbrowser.open(workspace_link) @ucx.command @@ -84,20 +67,19 @@ def skip(w: WorkspaceClient, schema: str | None = None, table: str | None = None logger.info("Running skip command") if not schema: logger.error("--schema is a required parameter.") - return - mapping = TableMapping.current(w) + return None + ctx = WorkspaceContext(w) if table: - mapping.skip_table(schema, table) - else: - mapping.skip_schema(schema) + return ctx.table_mapping.skip_table(schema, table) + return ctx.table_mapping.skip_schema(schema) @ucx.command(is_account=True) def sync_workspace_info(a: AccountClient): """upload workspace config to all workspaces in the account where ucx is installed""" 
logger.info(f"Account ID: {a.config.account_id}") - workspaces = AccountWorkspaces(a) - workspaces.sync_workspace_info() + ctx = AccountContext(a) + ctx.account_workspaces.sync_workspace_info() @ucx.command(is_account=True) @@ -128,32 +110,23 @@ def create_account_groups( @ucx.command def manual_workspace_info(w: WorkspaceClient, prompts: Prompts): """only supposed to be run if cannot get admins to run `databricks labs ucx sync-workspace-info`""" - installation = Installation.current(w, 'ucx') - workspace_info = WorkspaceInfo(installation, w) - workspace_info.manual_workspace_info(prompts) + ctx = WorkspaceContext(w) + ctx.workspace_info.manual_workspace_info(prompts) @ucx.command def create_table_mapping(w: WorkspaceClient): """create initial table mapping for review""" - table_mapping = TableMapping.current(w) - installation = Installation.current(w, 'ucx') - workspace_info = WorkspaceInfo(installation, w) - config = installation.load(WorkspaceConfig) - sql_backend = StatementExecutionBackend(w, config.warehouse_id) - tables_crawler = TablesCrawler(sql_backend, config.inventory_database) - path = table_mapping.save(tables_crawler, workspace_info) + ctx = WorkspaceContext(w) + path = ctx.table_mapping.save(ctx.tables_crawler, ctx.workspace_info) webbrowser.open(f"{w.config.host}/#workspace{path}") @ucx.command def validate_external_locations(w: WorkspaceClient, prompts: Prompts): """validates and provides mapping to external table to external location and shared generation tf scripts""" - installation = Installation.current(w, 'ucx') - config = installation.load(WorkspaceConfig) - sql_backend = StatementExecutionBackend(w, config.warehouse_id) - location_crawler = ExternalLocations(w, sql_backend, config.inventory_database) - path = location_crawler.save_as_terraform_definitions_on_workspace(installation) + ctx = WorkspaceContext(w) + path = ctx.external_locations.save_as_terraform_definitions_on_workspace(ctx.installation) if path and prompts.confirm(f"external_locations.tf file written to {path}. 
Do you want to open it?"): webbrowser.open(f"{w.config.host}/#workspace{path}") @@ -161,7 +134,8 @@ def validate_external_locations(w: WorkspaceClient, prompts: Prompts): @ucx.command def ensure_assessment_run(w: WorkspaceClient): """ensure the assessment job was run on a workspace""" - deployed_workflows = WorkflowsDeployment.for_cli(w) + ctx = WorkspaceContext(w) + deployed_workflows = ctx.deployed_workflows if not deployed_workflows.validate_step("assessment"): deployed_workflows.run_workflow("assessment") @@ -171,45 +145,39 @@ def repair_run(w: WorkspaceClient, step): """Repair Run the Failed Job""" if not step: raise KeyError("You did not specify --step") - installation = WorkflowsDeployment.for_cli(w) + ctx = WorkspaceContext(w) logger.info(f"Repair Running {step} Job") - installation.repair_run(step) + ctx.deployed_workflows.repair_run(step) @ucx.command def validate_groups_membership(w: WorkspaceClient): """Validate the groups to see if the groups at account level and workspace level has different membership""" - installation = Installation.current(w, 'ucx') - config = installation.load(WorkspaceConfig) - sql_backend = StatementExecutionBackend(w, config.warehouse_id) - logger.info("Validating Groups which are having different memberships between account and workspace") - group_manager = GroupManager( - sql_backend=sql_backend, - ws=w, - inventory_database=config.inventory_database, - include_group_names=config.include_group_names, - renamed_group_prefix=config.renamed_group_prefix, - workspace_group_regex=config.workspace_group_regex, - workspace_group_replace=config.workspace_group_replace, - account_group_regex=config.account_group_regex, - ) - mismatch_groups = group_manager.validate_group_membership() + ctx = WorkspaceContext(w) + mismatch_groups = ctx.group_manager.validate_group_membership() print(json.dumps(mismatch_groups)) @ucx.command def revert_migrated_tables( - w: WorkspaceClient, prompts: Prompts, schema: str, table: str, *, delete_managed: bool = False + w: WorkspaceClient, + prompts: Prompts, + schema: str, + table: str, + *, + delete_managed: bool = False, + ctx: WorkspaceContext | None = None, ): """remove notation on a migrated table for re-migration""" if not schema and not table: question = "You haven't specified a schema or a table. All migrated tables will be reverted. Continue?" 
if not prompts.confirm(question, max_attempts=2): return - tables_migrate = TablesMigrator.for_cli(w) - revert = tables_migrate.print_revert_report(delete_managed=delete_managed) + if not ctx: + ctx = WorkspaceContext(w) + revert = ctx.tables_migrator.print_revert_report(delete_managed=delete_managed) if revert and prompts.confirm("Would you like to continue?", max_attempts=2): - tables_migrate.revert_migrated_tables(schema, table, delete_managed=delete_managed) + ctx.tables_migrator.revert_migrated_tables(schema, table, delete_managed=delete_managed) @ucx.command @@ -233,14 +201,14 @@ def move( if from_catalog == to_catalog and from_schema == to_schema: logger.error("please select a different schema or catalog to migrate to") return - tables = TableMove.for_cli(w) if not prompts.confirm(f"[WARNING] External tables will be dropped and recreated in the target schema {to_schema}"): return del_table = prompts.confirm( f"should we delete managed tables & views after moving to the new schema" f" {to_catalog}.{to_schema}" ) logger.info(f"migrating tables {from_table} from {from_catalog}.{from_schema} to {to_catalog}.{to_schema}") - tables.move(from_catalog, from_schema, from_table, to_catalog, to_schema, del_table=del_table) + ctx = WorkspaceContext(w) + ctx.table_move.move(from_catalog, from_schema, from_table, to_catalog, to_schema, del_table=del_table) @ucx.command @@ -262,176 +230,52 @@ def alias( if from_catalog == to_catalog and from_schema == to_schema: logger.error("please select a different schema or catalog to migrate to") return - tables = TableMove.for_cli(w) logger.info(f"aliasing table {from_table} from {from_catalog}.{from_schema} to {to_catalog}.{to_schema}") - tables.alias_tables(from_catalog, from_schema, from_table, to_catalog, to_schema) - - -def _execute_for_cloud( - w: WorkspaceClient, - prompts: Prompts, - func_azure: Callable, - func_aws: Callable, - azure_resource_permissions: AzureResourcePermissions | None = None, - subscription_id: str | None = None, - aws_permissions: AWSResourcePermissions | None = None, - aws_profile: str | None = None, -): - if w.config.is_azure: - if w.config.auth_type != "azure-cli": - logger.error("In order to obtain AAD token, Please run azure cli to authenticate.") - return None - if not subscription_id: - logger.error("Please enter subscription id to scan storage accounts in.") - return None - return func_azure( - w, prompts, subscription_id=subscription_id, azure_resource_permissions=azure_resource_permissions - ) - if w.config.is_aws: - if not shutil.which("aws"): - logger.error("Couldn't find AWS CLI in path. Please install the CLI from https://aws.amazon.com/cli/") - return None - if not aws_profile: - aws_profile = os.getenv("AWS_DEFAULT_PROFILE") - if not aws_profile: - logger.error( - "AWS Profile is not specified. Use the environment variable [AWS_DEFAULT_PROFILE] " - "or use the '--aws-profile=[profile-name]' parameter." 
- ) - return None - return func_aws(w, prompts, aws_profile=aws_profile, aws_permissions=aws_permissions) - logger.error("This cmd is only supported for azure and aws workspaces") - return None + ctx = WorkspaceContext(w) + ctx.table_move.alias_tables(from_catalog, from_schema, from_table, to_catalog, to_schema) @ucx.command def create_uber_principal( w: WorkspaceClient, prompts: Prompts, - subscription_id: str | None = None, - azure_resource_permissions: AzureResourcePermissions | None = None, - aws_profile: str | None = None, - aws_resource_permissions: AWSResourcePermissions | None = None, + ctx: WorkspaceContext | None = None, + **named_parameters, ): """For azure cloud, creates a service principal and gives STORAGE BLOB READER access on all the storage account used by tables in the workspace and stores the spn info in the UCX cluster policy. For aws, it identifies all s3 buckets used by the Instance Profiles configured in the workspace. Pass subscription_id for azure and aws_profile for aws.""" - return _execute_for_cloud( - w, - prompts, - _azure_setup_uber_principal, - _aws_setup_uber_principal, - azure_resource_permissions, - subscription_id, - aws_resource_permissions, - aws_profile, - ) - - -def _azure_setup_uber_principal( - w: WorkspaceClient, - prompts: Prompts, - subscription_id: str, - azure_resource_permissions: AzureResourcePermissions | None = None, -): - include_subscriptions = [subscription_id] if subscription_id else None - if azure_resource_permissions is None: - azure_resource_permissions = AzureResourcePermissions.for_cli(w, include_subscriptions=include_subscriptions) - azure_resource_permissions.create_uber_principal(prompts) - - -def _aws_setup_uber_principal( - w: WorkspaceClient, - prompts: Prompts, - aws_profile: str, - aws_resource_permissions: AWSResourcePermissions | None = None, -): - installation = Installation.current(w, 'ucx') - config = installation.load(WorkspaceConfig) - sql_backend = StatementExecutionBackend(w, config.warehouse_id) - aws = AWSResources(aws_profile) - if aws_resource_permissions is None: - aws_resource_permissions = AWSResourcePermissions.for_cli( - w, installation, sql_backend, aws, config.inventory_database - ) - aws_resource_permissions.create_uber_principal(prompts) + if not ctx: + ctx = WorkspaceContext(w, named_parameters) + if ctx.is_azure: + return ctx.azure_resource_permissions.create_uber_principal(prompts) + if ctx.is_aws: + return ctx.aws_resource_permissions.create_uber_principal(prompts) + raise ValueError("Unsupported cloud provider") @ucx.command -def principal_prefix_access( - w: WorkspaceClient, - prompts: Prompts, - subscription_id: str | None = None, - azure_resource_permissions: AzureResourcePermissions | None = None, - aws_profile: str | None = None, - aws_resource_permissions: AWSResourcePermissions | None = None, -): +def principal_prefix_access(w: WorkspaceClient, ctx: WorkspaceContext | None = None, **named_parameters): """For azure cloud, identifies all storage accounts used by tables in the workspace, identify spn and its permission on each storage accounts. For aws, identifies all the Instance Profiles configured in the workspace and its access to all the S3 buckets, along with AWS roles that are set with UC access and its access to S3 buckets. The output is stored in the workspace install folder. 
Pass subscription_id for azure and aws_profile for aws.""" - return _execute_for_cloud( - w, - prompts, - _azure_principal_prefix_access, - _aws_principal_prefix_access, - azure_resource_permissions, - subscription_id, - aws_resource_permissions, - aws_profile, - ) - - -def _azure_principal_prefix_access( - w: WorkspaceClient, - _: Prompts, - *, - subscription_id: str, - azure_resource_permissions: AzureResourcePermissions | None = None, -): - if w.config.auth_type != "azure-cli": - logger.error("In order to obtain AAD token, Please run azure cli to authenticate.") - return - include_subscriptions = [subscription_id] if subscription_id else None - if azure_resource_permissions is None: - azure_resource_permissions = AzureResourcePermissions.for_cli(w, include_subscriptions=include_subscriptions) - logger.info("Generating azure storage accounts and service principal permission info") - path = azure_resource_permissions.save_spn_permissions() - if path: - logger.info(f"storage and spn info saved under {path}") - return - - -def _aws_principal_prefix_access( - w: WorkspaceClient, - _: Prompts, - *, - aws_profile: str, - aws_permissions: AWSResourcePermissions | None = None, -): - if not shutil.which("aws"): - logger.error("Couldn't find AWS CLI in path. Please install the CLI from https://aws.amazon.com/cli/") - return - logger.info("Generating instance profile and bucket permission info") - installation = Installation.current(w, 'ucx') - config = installation.load(WorkspaceConfig) - sql_backend = StatementExecutionBackend(w, config.warehouse_id) - aws = AWSResources(aws_profile) - if aws_permissions is None: - aws_permissions = AWSResourcePermissions.for_cli(w, installation, sql_backend, aws, config.inventory_database) - instance_role_path = aws_permissions.save_instance_profile_permissions() - logger.info(f"Instance profile and bucket info saved {instance_role_path}") - logger.info("Generating UC roles and bucket permission info") - uc_role_path = aws_permissions.save_uc_compatible_roles() - logger.info(f"UC roles and bucket info saved {uc_role_path}") + if not ctx: + ctx = WorkspaceContext(w, named_parameters) + if ctx.is_azure: + return ctx.azure_resource_permissions.save_spn_permissions() + if ctx.is_aws: + instance_role_path = ctx.aws_resource_permissions.save_instance_profile_permissions() + logger.info(f"Instance profile and bucket info saved {instance_role_path}") + logger.info("Generating UC roles and bucket permission info") + return ctx.aws_resource_permissions.save_uc_compatible_roles() + raise ValueError("Unsupported cloud provider") @ucx.command -def migrate_credentials( - w: WorkspaceClient, prompts: Prompts, aws_profile: str | None = None, aws_resources: AWSResources | None = None -): +def migrate_credentials(w: WorkspaceClient, prompts: Prompts, ctx: WorkspaceContext | None = None, **named_parameters): """For Azure, this command migrates Azure Service Principals, which have Storage Blob Data Contributor, Storage Blob Data Reader, Storage Blob Data Owner roles on ADLS Gen2 locations that are being used in Databricks, to UC storage credentials. @@ -445,82 +289,44 @@ def migrate_credentials( Please review the file and delete the Instance Profiles you do not want to be migrated. Pass aws_profile for aws. 
""" - installation = Installation.current(w, 'ucx') - if w.config.is_azure: - logger.info("Running migrate_credentials for Azure") - service_principal_migration = ServicePrincipalMigration.for_cli(w, installation, prompts) - service_principal_migration.run(prompts) - return - if w.config.is_aws: - if not aws_profile: - aws_profile = os.getenv("AWS_DEFAULT_PROFILE") - if not aws_profile: - logger.error( - "AWS Profile is not specified. Use the environment variable [AWS_DEFAULT_PROFILE] " - "or use the '--aws-profile=[profile-name]' parameter." - ) - return - logger.info("Running migrate_credentials for AWS") - if not aws_resources: - aws_resources = AWSResources(aws_profile) - instance_profile_migration = IamRoleMigration.for_cli(w, installation, aws_resources, prompts) - instance_profile_migration.run(prompts) - return - if w.config.is_gcp: - logger.error("migrate_credentials is not yet supported in GCP") + if not ctx: + ctx = WorkspaceContext(w, named_parameters) + if ctx.is_azure: + return ctx.service_principal_migration.run(prompts) + if ctx.is_aws: + return ctx.iam_role_migration.run(prompts) + raise ValueError("Unsupported cloud provider") @ucx.command -def migrate_locations(w: WorkspaceClient, aws_profile: str | None = None): +def migrate_locations(w: WorkspaceClient, ctx: WorkspaceContext | None = None, **named_parameters): """This command creates UC external locations. The candidate locations to be created are extracted from guess_external_locations task in the assessment job. You can run validate_external_locations command to check the candidate locations. Please make sure the credentials haven migrated before running this command. The command will only create the locations that have corresponded UC Storage Credentials. """ - if w.config.is_azure: - logger.info("Running migrate_locations for Azure") - installation = Installation.current(w, 'ucx') - service_principal_migration = ExternalLocationsMigration.for_cli(w, installation) - service_principal_migration.run() - if w.config.is_aws: - logger.error("Migrate_locations for AWS") - if not shutil.which("aws"): - logger.error("Couldn't find AWS CLI in path. Please install the CLI from https://aws.amazon.com/cli/") - return - if not aws_profile: - aws_profile = os.getenv("AWS_DEFAULT_PROFILE") - if not aws_profile: - logger.error( - "AWS Profile is not specified. Use the environment variable [AWS_DEFAULT_PROFILE] " - "or use the '--aws-profile=[profile-name]' parameter." 
- ) - return - installation = Installation.current(w, 'ucx') - config = installation.load(WorkspaceConfig) - sql_backend = StatementExecutionBackend(w, config.warehouse_id) - aws = AWSResources(aws_profile) - location = ExternalLocations(w, sql_backend, config.inventory_database) - aws_permissions = AWSResourcePermissions(installation, w, sql_backend, aws, location, config.inventory_database) - aws_permissions.create_external_locations() - if w.config.is_gcp: - logger.error("migrate_locations is not yet supported in GCP") + if not ctx: + ctx = WorkspaceContext(w, named_parameters) + if ctx.is_azure: + return ctx.azure_external_locations_migration.run() + if ctx.is_aws: + return ctx.aws_resource_permissions.create_external_locations() + raise ValueError("Unsupported cloud provider") @ucx.command def create_catalogs_schemas(w: WorkspaceClient, prompts: Prompts): """Create UC catalogs and schemas based on the destinations created from create_table_mapping command.""" - installation = Installation.current(w, 'ucx') - catalog_schema = CatalogSchema.for_cli(w, installation) - catalog_schema.create_all_catalogs_schemas(prompts) + ctx = WorkspaceContext(w) + ctx.catalog_schema.create_all_catalogs_schemas(prompts) @ucx.command def cluster_remap(w: WorkspaceClient, prompts: Prompts): """Re-mapping the cluster to UC""" logger.info("Remapping the Clusters to UC") - installation = Installation.current(w, 'ucx') - cluster = ClusterAccess(installation, w, prompts) - cluster_list = cluster.list_cluster() + ctx = WorkspaceContext(w) + cluster_list = ctx.cluster_access.list_cluster() if not cluster_list: logger.info("No cluster information present in the workspace") return @@ -530,17 +336,17 @@ def cluster_remap(w: WorkspaceClient, prompts: Prompts): cluster_ids = prompts.question( "Please provide the cluster id's as comma separated value from the above list", default="" ) - cluster.map_cluster_to_uc(cluster_ids, cluster_list) + ctx.cluster_access.map_cluster_to_uc(cluster_ids, cluster_list) @ucx.command def revert_cluster_remap(w: WorkspaceClient, prompts: Prompts): """Reverting Re-mapping of clusters from UC""" logger.info("Reverting the Remapping of the Clusters from UC") - installation = Installation.current(w, 'ucx') + ctx = WorkspaceContext(w) cluster_ids = [ cluster_files.path.split("/")[-1].split(".")[0] - for cluster_files in installation.files() + for cluster_files in ctx.installation.files() if cluster_files.path is not None and cluster_files.path.find("backup/clusters") > 0 ] if not cluster_ids: @@ -551,18 +357,17 @@ def revert_cluster_remap(w: WorkspaceClient, prompts: Prompts): cluster_list = prompts.question( "Please provide the cluster id's as comma separated value from the above list", default="" ) - cluster_details = ClusterAccess(installation, w, prompts) - cluster_details.revert_cluster_remap(cluster_list, cluster_ids) + ctx.cluster_access.revert_cluster_remap(cluster_list, cluster_ids) @ucx.command def migrate_local_code(w: WorkspaceClient, prompts: Prompts): """Fix the code files based on their language.""" - files = Files.for_cli(w) + ctx = WorkspaceContext(w) working_directory = Path.cwd() if not prompts.confirm("Do you want to apply UC migration to all files in the current directory?"): return - files.apply(working_directory) + ctx.local_file_migrator.apply(working_directory) if __name__ == "__main__": diff --git a/src/databricks/labs/ucx/contexts/__init__.py b/src/databricks/labs/ucx/contexts/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git 
a/src/databricks/labs/ucx/contexts/application.py b/src/databricks/labs/ucx/contexts/application.py
new file mode 100644
index 0000000000..f31ec8c494
--- /dev/null
+++ b/src/databricks/labs/ucx/contexts/application.py
@@ -0,0 +1,301 @@
+import abc
+import logging
+from datetime import timedelta
+from functools import cached_property
+
+from databricks.labs.blueprint.installation import Installation
+from databricks.labs.blueprint.installer import InstallState
+from databricks.labs.blueprint.wheels import ProductInfo, WheelsV2
+from databricks.labs.lsql.backends import SqlBackend
+from databricks.sdk import AccountClient, WorkspaceClient, core
+from databricks.sdk.service import sql
+
+from databricks.labs.ucx.account import WorkspaceInfo
+from databricks.labs.ucx.assessment.azure import AzureServicePrincipalCrawler
+from databricks.labs.ucx.aws.credentials import CredentialManager
+from databricks.labs.ucx.config import WorkspaceConfig
+from databricks.labs.ucx.hive_metastore import ExternalLocations, Mounts, TablesCrawler
+from databricks.labs.ucx.hive_metastore.catalog_schema import CatalogSchema
+from databricks.labs.ucx.hive_metastore.grants import (
+    AzureACL,
+    GrantsCrawler,
+    PrincipalACL,
+)
+from databricks.labs.ucx.hive_metastore.mapping import TableMapping
+from databricks.labs.ucx.hive_metastore.table_migrate import (
+    MigrationStatusRefresher,
+    TablesMigrator,
+)
+from databricks.labs.ucx.hive_metastore.table_move import TableMove
+from databricks.labs.ucx.hive_metastore.udfs import UdfsCrawler
+from databricks.labs.ucx.hive_metastore.verification import VerifyHasMetastore
+from databricks.labs.ucx.installer.workflows import DeployedWorkflows
+from databricks.labs.ucx.source_code.languages import Languages
+from databricks.labs.ucx.workspace_access import generic, redash
+from databricks.labs.ucx.workspace_access.groups import GroupManager
+from databricks.labs.ucx.workspace_access.manager import PermissionManager
+from databricks.labs.ucx.workspace_access.scim import ScimSupport
+from databricks.labs.ucx.workspace_access.secrets import SecretScopesSupport
+from databricks.labs.ucx.workspace_access.tacl import TableAclSupport
+
+# "Service Factories" always have a lot of public methods.
+# This is because they are responsible for creating objects that are
+# used throughout the application. That being said, we make a best
+# effort to split the instances between the Global, Runtime,
+# Workspace CLI, and Account CLI contexts.
+# pylint: disable=too-many-public-methods + +logger = logging.getLogger(__name__) + + +class GlobalContext(abc.ABC): + def __init__(self, named_parameters: dict[str, str] | None = None): + if not named_parameters: + named_parameters = {} + self._named_parameters = named_parameters + + def replace(self, **kwargs): + """Replace cached properties for unit testing purposes.""" + for key, value in kwargs.items(): + self.__dict__[key] = value + return self + + @cached_property + def workspace_client(self) -> WorkspaceClient: + raise ValueError("Workspace client not set") + + @cached_property + def sql_backend(self) -> SqlBackend: + raise ValueError("SQL backend not set") + + @cached_property + def account_client(self) -> AccountClient: + raise ValueError("Account client not set") + + @cached_property + def named_parameters(self) -> dict[str, str]: + return self._named_parameters + + @cached_property + def product_info(self): + return ProductInfo.from_class(WorkspaceConfig) + + @cached_property + def installation(self): + return Installation.current(self.workspace_client, self.product_info.product_name()) + + @cached_property + def config(self) -> WorkspaceConfig: + return self.installation.load(WorkspaceConfig) + + @cached_property + def connect_config(self) -> core.Config: + return self.workspace_client.config + + @cached_property + def is_azure(self) -> bool: + if self.is_aws: + return False + return self.connect_config.is_azure + + @cached_property + def is_aws(self) -> bool: + return self.connect_config.is_aws + + @cached_property + def inventory_database(self) -> str: + return self.config.inventory_database + + @cached_property + def workspace_listing(self): + return generic.WorkspaceListing( + self.workspace_client, + self.sql_backend, + self.inventory_database, + self.config.num_threads, + self.config.workspace_start_path, + ) + + @cached_property + def generic_permissions_support(self): + models_listing = generic.models_listing(self.workspace_client, self.config.num_threads) + acl_listing = [ + generic.Listing(self.workspace_client.clusters.list, "cluster_id", "clusters"), + generic.Listing(self.workspace_client.cluster_policies.list, "policy_id", "cluster-policies"), + generic.Listing(self.workspace_client.instance_pools.list, "instance_pool_id", "instance-pools"), + generic.Listing(self.workspace_client.warehouses.list, "id", "sql/warehouses"), + generic.Listing(self.workspace_client.jobs.list, "job_id", "jobs"), + generic.Listing(self.workspace_client.pipelines.list_pipelines, "pipeline_id", "pipelines"), + generic.Listing(self.workspace_client.serving_endpoints.list, "id", "serving-endpoints"), + generic.Listing(generic.experiments_listing(self.workspace_client), "experiment_id", "experiments"), + generic.Listing(models_listing, "id", "registered-models"), + generic.Listing(generic.models_root_page, "object_id", "registered-models"), + generic.Listing(generic.tokens_and_passwords, "object_id", "authorization"), + generic.Listing(generic.feature_store_listing(self.workspace_client), "object_id", "feature-tables"), + generic.Listing(generic.feature_tables_root_page, "object_id", "feature-tables"), + self.workspace_listing, + ] + return generic.GenericPermissionsSupport(self.workspace_client, acl_listing) + + @cached_property + def redash_permissions_support(self): + acl_listing = [ + redash.Listing(self.workspace_client.alerts.list, sql.ObjectTypePlural.ALERTS), + redash.Listing(self.workspace_client.dashboards.list, sql.ObjectTypePlural.DASHBOARDS), + 
redash.Listing(self.workspace_client.queries.list, sql.ObjectTypePlural.QUERIES), + ] + return redash.RedashPermissionsSupport(self.workspace_client, acl_listing) + + @cached_property + def scim_entitlements_support(self): + return ScimSupport(self.workspace_client) + + @cached_property + def secret_scope_acl_support(self): + return SecretScopesSupport(self.workspace_client) + + @cached_property + def legacy_table_acl_support(self): + return TableAclSupport(self.grants_crawler, self.sql_backend) + + @cached_property + def permission_manager(self): + return PermissionManager( + self.sql_backend, + self.inventory_database, + [ + self.generic_permissions_support, + self.redash_permissions_support, + self.secret_scope_acl_support, + self.scim_entitlements_support, + self.legacy_table_acl_support, + ], + ) + + @cached_property + def group_manager(self): + return GroupManager( + self.sql_backend, + self.workspace_client, + self.inventory_database, + self.config.include_group_names, + self.config.renamed_group_prefix, + workspace_group_regex=self.config.workspace_group_regex, + workspace_group_replace=self.config.workspace_group_replace, + account_group_regex=self.config.account_group_regex, + external_id_match=self.config.group_match_by_external_id, + ) + + @cached_property + def grants_crawler(self): + return GrantsCrawler(self.tables_crawler, self.udfs_crawler, self.config.include_databases) + + @cached_property + def udfs_crawler(self): + return UdfsCrawler(self.sql_backend, self.inventory_database, self.config.include_databases) + + @cached_property + def tables_crawler(self): + return TablesCrawler(self.sql_backend, self.inventory_database, self.config.include_databases) + + @cached_property + def tables_migrator(self): + return TablesMigrator( + self.tables_crawler, + self.grants_crawler, + self.workspace_client, + self.sql_backend, + self.table_mapping, + self.group_manager, + self.migration_status_refresher, + self.principal_acl, + ) + + @cached_property + def table_move(self): + return TableMove(self.workspace_client, self.sql_backend) + + @cached_property + def mounts_crawler(self): + return Mounts(self.sql_backend, self.workspace_client, self.inventory_database) + + @cached_property + def azure_service_principal_crawler(self): + return AzureServicePrincipalCrawler(self.workspace_client, self.sql_backend, self.inventory_database) + + @cached_property + def external_locations(self): + return ExternalLocations(self.workspace_client, self.sql_backend, self.inventory_database) + + @cached_property + def azure_acl(self): + return AzureACL( + self.workspace_client, + self.sql_backend, + self.azure_service_principal_crawler, + self.installation, + ) + + @cached_property + def principal_acl(self): + if not self.is_azure: + raise NotImplementedError("Azure only for now") + eligible = self.azure_acl.get_eligible_locations_principals() + return PrincipalACL( + self.workspace_client, + self.sql_backend, + self.installation, + self.tables_crawler, + self.mounts_crawler, + eligible, + ) + + @cached_property + def migration_status_refresher(self): + return MigrationStatusRefresher( + self.workspace_client, + self.sql_backend, + self.inventory_database, + self.tables_crawler, + ) + + @cached_property + def iam_credential_manager(self): + return CredentialManager(self.workspace_client) + + @cached_property + def table_mapping(self): + return TableMapping(self.installation, self.workspace_client, self.sql_backend) + + @cached_property + def catalog_schema(self): + return 
CatalogSchema(self.workspace_client, self.table_mapping) + + @cached_property + def languages(self): + index = self.tables_migrator.index() + return Languages(index) + + @cached_property + def verify_timeout(self): + return timedelta(minutes=2) + + @cached_property + def wheels(self): + return WheelsV2(self.installation, self.product_info) + + @cached_property + def install_state(self): + return InstallState.from_installation(self.installation) + + @cached_property + def deployed_workflows(self): + return DeployedWorkflows(self.workspace_client, self.install_state, self.verify_timeout) + + @cached_property + def workspace_info(self): + return WorkspaceInfo(self.installation, self.workspace_client) + + @cached_property + def verify_has_metastore(self): + return VerifyHasMetastore(self.workspace_client) diff --git a/src/databricks/labs/ucx/contexts/cli_command.py b/src/databricks/labs/ucx/contexts/cli_command.py new file mode 100644 index 0000000000..f71610741c --- /dev/null +++ b/src/databricks/labs/ucx/contexts/cli_command.py @@ -0,0 +1,182 @@ +import abc +import logging +import os +import shutil +from functools import cached_property + +from databricks.labs.blueprint.tui import Prompts +from databricks.labs.lsql.backends import SqlBackend, StatementExecutionBackend +from databricks.sdk import AccountClient, WorkspaceClient + +from databricks.labs.ucx.account import AccountWorkspaces +from databricks.labs.ucx.assessment.aws import run_command, AWSResources +from databricks.labs.ucx.aws.access import AWSResourcePermissions +from databricks.labs.ucx.aws.credentials import IamRoleMigration +from databricks.labs.ucx.azure.access import AzureResourcePermissions +from databricks.labs.ucx.azure.credentials import ServicePrincipalMigration, StorageCredentialManager +from databricks.labs.ucx.azure.locations import ExternalLocationsMigration +from databricks.labs.ucx.azure.resources import AzureAPIClient, AzureResources +from databricks.labs.ucx.contexts.application import GlobalContext +from databricks.labs.ucx.source_code.files import LocalFileMigrator +from databricks.labs.ucx.workspace_access.clusters import ClusterAccess + +logger = logging.getLogger(__name__) + + +class CliContext(GlobalContext, abc.ABC): + @cached_property + def prompts(self) -> Prompts: + return Prompts() + + +class WorkspaceContext(CliContext): + def __init__(self, ws: WorkspaceClient, named_parameters: dict[str, str] | None = None): + super().__init__(named_parameters) + self._ws = ws + + @cached_property + def workspace_client(self) -> WorkspaceClient: + return self._ws + + @cached_property + def sql_backend(self) -> SqlBackend: + return StatementExecutionBackend(self.workspace_client, self.config.warehouse_id) + + @cached_property + def local_file_migrator(self): + return LocalFileMigrator(self.languages) + + @cached_property + def cluster_access(self): + return ClusterAccess(self.installation, self.workspace_client, self.prompts) + + @cached_property + def azure_cli_authenticated(self): + if not self.is_azure: + raise NotImplementedError("Azure only") + if self.connect_config.auth_type != "azure-cli": + raise ValueError("In order to obtain AAD token, Please run azure cli to authenticate.") + return True + + @cached_property + def azure_management_client(self): + if not self.azure_cli_authenticated: + raise NotImplementedError + return AzureAPIClient( + self.workspace_client.config.arm_environment.resource_manager_endpoint, + self.workspace_client.config.arm_environment.service_management_endpoint, + ) + + 
+    @cached_property
+    def microsoft_graph_client(self):
+        if not self.azure_cli_authenticated:
+            raise NotImplementedError
+        return AzureAPIClient("https://graph.microsoft.com", "https://graph.microsoft.com")
+
+    @cached_property
+    def azure_subscription_id(self):
+        subscription_id = self.named_parameters.get("subscription_id")
+        if not subscription_id:
+            raise ValueError("Please provide a subscription id to scan storage accounts in.")
+        return subscription_id
+
+    @cached_property
+    def azure_resources(self):
+        return AzureResources(
+            self.azure_management_client,
+            self.microsoft_graph_client,
+            [self.azure_subscription_id],
+        )
+
+    @cached_property
+    def azure_resource_permissions(self):
+        return AzureResourcePermissions(
+            self.installation,
+            self.workspace_client,
+            self.azure_resources,
+            self.external_locations,
+        )
+
+    @cached_property
+    def azure_credential_manager(self):
+        return StorageCredentialManager(self.workspace_client)
+
+    @cached_property
+    def service_principal_migration(self):
+        return ServicePrincipalMigration(
+            self.installation,
+            self.workspace_client,
+            self.azure_resource_permissions,
+            self.azure_service_principal_crawler,
+            self.azure_credential_manager,
+        )
+
+    @cached_property
+    def azure_external_locations_migration(self):
+        return ExternalLocationsMigration(
+            self.workspace_client,
+            self.external_locations,
+            self.azure_resource_permissions,
+            self.azure_resources,
+        )
+
+    @cached_property
+    def aws_cli_run_command(self):
+        # exposed as a cached property so unit tests can substitute a stub command runner
+        if not shutil.which("aws"):
+            raise ValueError("Couldn't find AWS CLI in path. Please install the CLI from https://aws.amazon.com/cli/")
+        return run_command
+
+    @cached_property
+    def aws_profile(self):
+        aws_profile = self.named_parameters.get("aws_profile")
+        if not aws_profile:
+            aws_profile = os.getenv("AWS_DEFAULT_PROFILE")
+        if not aws_profile:
+            raise ValueError(
+                "AWS Profile is not specified. Use the environment variable [AWS_DEFAULT_PROFILE] "
+                "or use the '--aws-profile=[profile-name]' parameter."
+ ) + return aws_profile + + @cached_property + def aws_resources(self): + if not self.is_aws: + raise NotImplementedError("AWS only") + return AWSResources(self.aws_profile, self.aws_cli_run_command) + + @cached_property + def aws_resource_permissions(self): + return AWSResourcePermissions( + self.installation, + self.workspace_client, + self.sql_backend, + self.aws_resources, + self.external_locations, + self.inventory_database, + self.named_parameters.get("aws_account_id"), + self.named_parameters.get("kms_key"), + ) + + @cached_property + def iam_role_migration(self): + return IamRoleMigration( + self.installation, + self.workspace_client, + self.aws_resource_permissions, + self.iam_credential_manager, + ) + + +class AccountContext(CliContext): + def __init__(self, ac: AccountClient, named_parameters: dict[str, str] | None = None): + super().__init__(named_parameters) + self._ac = ac + + @cached_property + def account_client(self) -> AccountClient: + return self._ac + + @cached_property + def account_workspaces(self): + return AccountWorkspaces(self.account_client) diff --git a/src/databricks/labs/ucx/contexts/workflow_task.py b/src/databricks/labs/ucx/contexts/workflow_task.py new file mode 100644 index 0000000000..a218d94426 --- /dev/null +++ b/src/databricks/labs/ucx/contexts/workflow_task.py @@ -0,0 +1,92 @@ +from functools import cached_property +from pathlib import Path + +from databricks.labs.blueprint.installation import Installation +from databricks.labs.lsql.backends import RuntimeBackend, SqlBackend +from databricks.sdk import WorkspaceClient, core + +from databricks.labs.ucx.__about__ import __version__ +from databricks.labs.ucx.assessment.clusters import ClustersCrawler, PoliciesCrawler +from databricks.labs.ucx.assessment.init_scripts import GlobalInitScriptCrawler +from databricks.labs.ucx.assessment.jobs import JobsCrawler, SubmitRunsCrawler +from databricks.labs.ucx.assessment.pipelines import PipelinesCrawler +from databricks.labs.ucx.config import WorkspaceConfig +from databricks.labs.ucx.contexts.application import GlobalContext +from databricks.labs.ucx.hive_metastore import TablesInMounts +from databricks.labs.ucx.hive_metastore.table_size import TableSizeCrawler + + +class RuntimeContext(GlobalContext): + @cached_property + def _config_path(self) -> Path: + config = self.named_parameters.get("config") + if not config: + raise ValueError("config flag is required") + return Path(config) + + @cached_property + def config(self) -> WorkspaceConfig: + return Installation.load_local(WorkspaceConfig, self._config_path) + + @cached_property + def connect_config(self) -> core.Config: + connect = self.config.connect + assert connect, "connect is required" + return connect + + @cached_property + def workspace_client(self) -> WorkspaceClient: + return WorkspaceClient(config=self.connect_config, product='ucx', product_version=__version__) + + @cached_property + def sql_backend(self) -> SqlBackend: + return RuntimeBackend(debug_truncate_bytes=self.connect_config.debug_truncate_bytes) + + @cached_property + def installation(self): + install_folder = self._config_path.parent.as_posix().removeprefix("/Workspace") + return Installation(self.workspace_client, "ucx", install_folder=install_folder) + + @cached_property + def jobs_crawler(self): + return JobsCrawler(self.workspace_client, self.sql_backend, self.inventory_database) + + @cached_property + def submit_runs_crawler(self): + return SubmitRunsCrawler( + self.workspace_client, + self.sql_backend, + self.inventory_database, + 
self.config.num_days_submit_runs_history, + ) + + @cached_property + def clusters_crawler(self): + return ClustersCrawler(self.workspace_client, self.sql_backend, self.inventory_database) + + @cached_property + def pipelines_crawler(self): + return PipelinesCrawler(self.workspace_client, self.sql_backend, self.inventory_database) + + @cached_property + def table_size_crawler(self): + return TableSizeCrawler(self.sql_backend, self.inventory_database) + + @cached_property + def policies_crawler(self): + return PoliciesCrawler(self.workspace_client, self.sql_backend, self.inventory_database) + + @cached_property + def global_init_scripts_crawler(self): + return GlobalInitScriptCrawler(self.workspace_client, self.sql_backend, self.inventory_database) + + @cached_property + def tables_in_mounts(self): + return TablesInMounts( + self.sql_backend, + self.workspace_client, + self.inventory_database, + self.mounts_crawler, + self.config.include_mounts, + self.config.exclude_paths_in_mount, + ) diff --git a/src/databricks/labs/ucx/framework/tasks.py b/src/databricks/labs/ucx/framework/tasks.py index 7fa7786520..772ec902ab 100644 --- a/src/databricks/labs/ucx/framework/tasks.py +++ b/src/databricks/labs/ucx/framework/tasks.py @@ -1,11 +1,10 @@ import contextlib import logging import os -from collections.abc import Callable +from collections.abc import Callable, Iterable from contextlib import contextmanager from dataclasses import dataclass from datetime import timedelta -from functools import wraps from logging.handlers import TimedRotatingFileHandler from pathlib import Path @@ -24,7 +23,6 @@ @dataclass class Task: - task_id: int workflow: str name: str doc: str @@ -64,64 +62,6 @@ def remove_extra_indentation(doc: str) -> str: return "\n".join(stripped) -def task( - workflow, - *, - depends_on=None, - job_cluster="main", - notebook: str | None = None, - dashboard: str | None = None, - cloud: str | None = None, -): - def decorator(func): - @wraps(func) - def wrapper(*args, **kwargs): - # Perform any task-specific logic here - # For example, you can log when the task is started and completed - logger = logging.getLogger(func.__name__) - logger.info(f"Task '{workflow}' is starting...") - result = func(*args, **kwargs) - logger.info(f"Task '{workflow}' is completed!") - return result - - deps = [] - if depends_on is not None: - if not isinstance(depends_on, list): - msg = "depends_on has to be a list" - raise SyntaxError(msg) - for fn in depends_on: - if _TASKS[fn.__name__].workflow != workflow: - # for now, we filter out the cross-task - # dependencies within the same job. - # - # Technically, we can check it and fail - # the job if the previous steps didn't - # run before. - continue - deps.append(fn.__name__) - - if not func.__doc__: - msg = f"Task {func.__name__} must have documentation" - raise SyntaxError(msg) - - _TASKS[func.__name__] = Task( - task_id=len(_TASKS), - workflow=workflow, - name=func.__name__, - doc=remove_extra_indentation(func.__doc__), - fn=func, - depends_on=deps, - job_cluster=job_cluster, - notebook=notebook, - dashboard=dashboard, - cloud=cloud, - ) - - return wrapper - - return decorator - - class TaskLogger(contextlib.AbstractContextManager): # files are available in the workspace only once their handlers are closed, # so we rotate files log every 10 minutes. 
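The hunk that follows replaces the module-level `_TASKS` registry (and its order-of-registration `task_id`) with per-method metadata: `job_task` tags each decorated method with a `__task__` attribute, and `Workflow.tasks()` discovers the tagged methods by reflection, so tasks no longer exist as import-time side effects. A minimal sketch of that mechanism, using illustrative names:

def tag(fn):
    fn.__meta__ = fn.__name__  # stand-in for the Task dataclass
    return fn


class DemoWorkflow:
    @tag
    def crawl(self): ...

    def _helper(self): ...  # underscore-prefixed members are skipped

    def tagged(self):
        # mirrors Workflow.tasks(): scan public attributes for the marker
        for attr in dir(self):
            if attr.startswith("_"):
                continue
            member = getattr(self, attr)
            if hasattr(member, "__meta__"):
                yield member.__meta__


assert list(DemoWorkflow().tagged()) == ["crawl"]

Note that the decorator only records dependencies declared within the same class (see the `__qualname__` check below), and `Workflows.__init__` later stamps the owning workflow's name onto each `Task` via `dataclasses.replace`.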
@@ -266,3 +206,63 @@ def trigger(*argv): installation = Installation(workspace_client, "ucx", install_folder=install_folder) run_task(args, config_path.parent, cfg, workspace_client, sql_backend, installation) + + +class Workflow: + def __init__(self, name: str): + self._name = name + + @property + def name(self): + return self._name + + def tasks(self) -> Iterable[Task]: + # return __task__ from every method in this class that has this attribute + for attr in dir(self): + if attr.startswith("_"): + continue + fn = getattr(self, attr) + if hasattr(fn, "__task__"): + yield fn.__task__ + + +def job_task( + fn=None, + *, + depends_on=None, + job_cluster="main", + notebook: str | None = None, + dashboard: str | None = None, + cloud: str | None = None, +) -> Callable[[Callable], Callable]: + def register(func): + if not func.__doc__: + raise SyntaxError(f"{func.__name__} must have some doc comment") + deps = [] + this_class = func.__qualname__.split('.')[0] + if depends_on is not None: + if not isinstance(depends_on, list): + msg = "depends_on has to be a list" + raise SyntaxError(msg) + for fn in depends_on: + other_class, task_name = fn.__qualname__.split('.') + if other_class != this_class: + continue + deps.append(task_name) + func.__task__ = Task( + workflow='', + name=func.__name__, + doc=remove_extra_indentation(func.__doc__), + fn=func, + depends_on=deps, + job_cluster=job_cluster, + notebook=notebook, + dashboard=dashboard, + cloud=cloud, + ) + return func + + if fn is None: + return register + register(fn) + return fn diff --git a/src/databricks/labs/ucx/hive_metastore/catalog_schema.py b/src/databricks/labs/ucx/hive_metastore/catalog_schema.py index a83a8125f7..839e12e07c 100644 --- a/src/databricks/labs/ucx/hive_metastore/catalog_schema.py +++ b/src/databricks/labs/ucx/hive_metastore/catalog_schema.py @@ -1,13 +1,10 @@ import logging from pathlib import PurePath -from databricks.labs.blueprint.installation import Installation from databricks.labs.blueprint.tui import Prompts -from databricks.labs.lsql.backends import StatementExecutionBackend from databricks.sdk import WorkspaceClient from databricks.sdk.errors import NotFound -from databricks.labs.ucx.config import WorkspaceConfig from databricks.labs.ucx.hive_metastore.mapping import TableMapping logger = logging.getLogger(__name__) @@ -19,13 +16,6 @@ def __init__(self, ws: WorkspaceClient, table_mapping: TableMapping): self._table_mapping = table_mapping self._external_locations = self._ws.external_locations.list() - @classmethod - def for_cli(cls, ws: WorkspaceClient, installation: Installation): - config = installation.load(WorkspaceConfig) - sql_backend = StatementExecutionBackend(ws, config.warehouse_id) - table_mapping = TableMapping(installation, ws, sql_backend) - return cls(ws, table_mapping) - def create_all_catalogs_schemas(self, prompts: Prompts): candidate_catalogs, candidate_schemas = self._get_missing_catalogs_schemas() for candidate_catalog in candidate_catalogs: diff --git a/src/databricks/labs/ucx/hive_metastore/grants.py b/src/databricks/labs/ucx/hive_metastore/grants.py index 6800b00537..8d96204b8c 100644 --- a/src/databricks/labs/ucx/hive_metastore/grants.py +++ b/src/databricks/labs/ucx/hive_metastore/grants.py @@ -6,7 +6,7 @@ from databricks.labs.blueprint.installation import Installation from databricks.labs.blueprint.parallel import ManyError, Threads -from databricks.labs.lsql.backends import SqlBackend, StatementExecutionBackend +from databricks.labs.lsql.backends import SqlBackend from databricks.sdk 
import WorkspaceClient from databricks.sdk.errors import ResourceDoesNotExist from databricks.sdk.service.catalog import ExternalLocationInfo, SchemaInfo, TableInfo @@ -19,8 +19,6 @@ AzureResourcePermissions, StoragePermissionMapping, ) -from databricks.labs.ucx.azure.resources import AzureAPIClient, AzureResources -from databricks.labs.ucx.config import WorkspaceConfig from databricks.labs.ucx.framework.crawlers import CrawlerBase from databricks.labs.ucx.framework.utils import escape_sql_identifier from databricks.labs.ucx.hive_metastore.locations import ( @@ -341,27 +339,12 @@ def __init__( ws: WorkspaceClient, backend: SqlBackend, spn_crawler: AzureServicePrincipalCrawler, - resource_permissions: AzureResourcePermissions, + installation: Installation, ): self._backend = backend self._ws = ws self._spn_crawler = spn_crawler - self._resource_permissions = resource_permissions - - @classmethod - def for_cli(cls, ws: WorkspaceClient, installation: Installation): - config = installation.load(WorkspaceConfig) - sql_backend = StatementExecutionBackend(ws, config.warehouse_id) - locations = ExternalLocations(ws, sql_backend, config.inventory_database) - azure_client = AzureAPIClient( - ws.config.arm_environment.resource_manager_endpoint, - ws.config.arm_environment.service_management_endpoint, - ) - graph_client = AzureAPIClient("https://graph.microsoft.com", "https://graph.microsoft.com") - azurerm = AzureResources(azure_client, graph_client) - resource_permissions = AzureResourcePermissions(installation, ws, azurerm, locations) - spn_crawler = AzureServicePrincipalCrawler(ws, sql_backend, config.inventory_database) - return cls(ws, sql_backend, spn_crawler, resource_permissions) + self._installation = installation def get_eligible_locations_principals(self) -> dict[str, dict]: cluster_locations = {} @@ -381,7 +364,10 @@ def get_eligible_locations_principals(self) -> dict[str, dict]: logger.error(msg) raise ResourceDoesNotExist(msg) from None - permission_mappings = self._resource_permissions.load() + permission_mappings = self._installation.load( + list[StoragePermissionMapping], + filename=AzureResourcePermissions.FILENAME, + ) if len(permission_mappings) == 0: # if permission mapping is empty, raise an error to run principal_prefix cmd msg = ( @@ -437,29 +423,6 @@ def __init__( self._mounts_crawler = mounts_crawler self._cluster_locations = cluster_locations - @classmethod - def for_cli(cls, ws: WorkspaceClient, installation: Installation, sql_backend: SqlBackend): - config = installation.load(WorkspaceConfig) - - tables_crawler = TablesCrawler(sql_backend, config.inventory_database) - mount_crawler = Mounts(sql_backend, ws, config.inventory_database) - if ws.config.is_azure: - azure_acl = AzureACL.for_cli(ws, installation) - return cls( - ws, - sql_backend, - installation, - tables_crawler, - mount_crawler, - azure_acl.get_eligible_locations_principals(), - ) - if ws.config.is_aws: - return None - if ws.config.is_gcp: - logger.error("UCX is not supported for GCP yet. 
Please run it on azure or aws") - return None - return None - def get_interactive_cluster_grants(self) -> list[Grant]: tables = self._tables_crawler.snapshot() mounts = list(self._mounts_crawler.snapshot()) diff --git a/src/databricks/labs/ucx/hive_metastore/locations.py b/src/databricks/labs/ucx/hive_metastore/locations.py index 3505371982..8f4d7fdb6a 100644 --- a/src/databricks/labs/ucx/hive_metastore/locations.py +++ b/src/databricks/labs/ucx/hive_metastore/locations.py @@ -35,7 +35,7 @@ class Mount: class ExternalLocations(CrawlerBase[ExternalLocation]): _prefix_size: ClassVar[list[int]] = [1, 12] - def __init__(self, ws: WorkspaceClient, sbe: SqlBackend, schema): + def __init__(self, ws: WorkspaceClient, sbe: SqlBackend, schema: str): super().__init__(sbe, "hive_metastore", schema, "external_locations", ExternalLocation) self._ws = ws diff --git a/src/databricks/labs/ucx/hive_metastore/mapping.py b/src/databricks/labs/ucx/hive_metastore/mapping.py index ab4a826f22..7664d6e408 100644 --- a/src/databricks/labs/ucx/hive_metastore/mapping.py +++ b/src/databricks/labs/ucx/hive_metastore/mapping.py @@ -6,12 +6,11 @@ from databricks.labs.blueprint.installation import Installation from databricks.labs.blueprint.parallel import Threads -from databricks.labs.lsql.backends import SqlBackend, StatementExecutionBackend +from databricks.labs.lsql.backends import SqlBackend from databricks.sdk import WorkspaceClient from databricks.sdk.errors import BadRequest, NotFound, ResourceConflict from databricks.labs.ucx.account import WorkspaceInfo -from databricks.labs.ucx.config import WorkspaceConfig from databricks.labs.ucx.framework.utils import escape_sql_identifier from databricks.labs.ucx.hive_metastore import TablesCrawler from databricks.labs.ucx.hive_metastore.tables import Table @@ -61,21 +60,14 @@ def __eq__(self, other): class TableMapping: + FILENAME = 'mapping.csv' UCX_SKIP_PROPERTY = "databricks.labs.ucx.skip" def __init__(self, installation: Installation, ws: WorkspaceClient, sql_backend: SqlBackend): - self._filename = 'mapping.csv' self._installation = installation self._ws = ws self._sql_backend = sql_backend - @classmethod - def current(cls, ws: WorkspaceClient, product='ucx'): - installation = Installation.current(ws, product) - config = installation.load(WorkspaceConfig) - sql_backend = StatementExecutionBackend(ws, config.warehouse_id) - return cls(installation, ws, sql_backend) - def current_tables(self, tables: TablesCrawler, workspace_name: str, catalog_name: str): tables_snapshot = tables.snapshot() if len(tables_snapshot) == 0: @@ -88,11 +80,11 @@ def save(self, tables: TablesCrawler, workspace_info: WorkspaceInfo) -> str: workspace_name = workspace_info.current() default_catalog_name = re.sub(r"\W+", "_", workspace_name) current_tables = self.current_tables(tables, workspace_name, default_catalog_name) - return self._installation.save(list(current_tables), filename=self._filename) + return self._installation.save(list(current_tables), filename=self.FILENAME) def load(self) -> list[Rule]: try: - return self._installation.load(list[Rule], filename=self._filename) + return self._installation.load(list[Rule], filename=self.FILENAME) except NotFound: msg = "Please run: databricks labs ucx table-mapping" raise ValueError(msg) from None diff --git a/src/databricks/labs/ucx/hive_metastore/table_migrate.py b/src/databricks/labs/ucx/hive_metastore/table_migrate.py index e729a8afc3..82a0bfe8e6 100644 --- a/src/databricks/labs/ucx/hive_metastore/table_migrate.py +++ 
b/src/databricks/labs/ucx/hive_metastore/table_migrate.py @@ -6,13 +6,11 @@ from dataclasses import dataclass from functools import partial -from databricks.labs.blueprint.installation import Installation from databricks.labs.blueprint.parallel import Threads -from databricks.labs.lsql.backends import SqlBackend, StatementExecutionBackend +from databricks.labs.lsql.backends import SqlBackend from databricks.sdk import WorkspaceClient from databricks.sdk.errors import NotFound -from databricks.labs.ucx.config import WorkspaceConfig from databricks.labs.ucx.framework.crawlers import CrawlerBase from databricks.labs.ucx.framework.utils import escape_sql_identifier from databricks.labs.ucx.hive_metastore import TablesCrawler @@ -28,7 +26,6 @@ Table, What, ) -from databricks.labs.ucx.hive_metastore.udfs import UdfsCrawler from databricks.labs.ucx.hive_metastore.views_sequencer import ( ViewsMigrationSequencer, ViewToMigrate, @@ -73,30 +70,8 @@ def __init__( self._seen_tables: dict[str, str] = {} self._principal_grants = principal_grants - @classmethod - def for_cli(cls, ws: WorkspaceClient, product='ucx'): - installation = Installation.current(ws, product) - config = installation.load(WorkspaceConfig) - sql_backend = StatementExecutionBackend(ws, config.warehouse_id) - table_crawler = TablesCrawler(sql_backend, config.inventory_database) - udfs_crawler = UdfsCrawler(sql_backend, config.inventory_database) - grants_crawler = GrantsCrawler(table_crawler, udfs_crawler) - table_mapping = TableMapping(installation, ws, sql_backend) - group_manager = GroupManager(sql_backend, ws, config.inventory_database) - principal_grants = PrincipalACL.for_cli(ws, installation, sql_backend) - migration_status_refresher = MigrationStatusRefresher(ws, sql_backend, config.inventory_database, table_crawler) - return cls( - table_crawler, - grants_crawler, - ws, - sql_backend, - table_mapping, - group_manager, - migration_status_refresher, - principal_grants, - ) - def index(self): + # TODO: remove this method return self._migration_status_refresher.index() def migrate_tables(self, what: What, acl_strategy: list[AclMigrationWhat] | None = None): diff --git a/src/databricks/labs/ucx/hive_metastore/table_move.py b/src/databricks/labs/ucx/hive_metastore/table_move.py index bd9f53d1da..8edd706f9f 100644 --- a/src/databricks/labs/ucx/hive_metastore/table_move.py +++ b/src/databricks/labs/ucx/hive_metastore/table_move.py @@ -1,8 +1,7 @@ from functools import partial -from databricks.labs.blueprint.installation import Installation from databricks.labs.blueprint.parallel import Threads -from databricks.labs.lsql.backends import SqlBackend, StatementExecutionBackend +from databricks.labs.lsql.backends import SqlBackend from databricks.sdk import WorkspaceClient from databricks.sdk.errors import NotFound from databricks.sdk.service.catalog import ( @@ -12,7 +11,6 @@ TableType, ) -from databricks.labs.ucx.config import WorkspaceConfig from databricks.labs.ucx.framework.utils import escape_sql_identifier from databricks.labs.ucx.hive_metastore.table_migrate import logger @@ -24,13 +22,6 @@ def __init__(self, ws: WorkspaceClient, backend: SqlBackend): self._execute = backend.execute self._ws = ws - @classmethod - def for_cli(cls, ws: WorkspaceClient, product='ucx'): - installation = Installation.current(ws, product) - config = installation.load(WorkspaceConfig) - sql_backend = StatementExecutionBackend(ws, config.warehouse_id) - return cls(ws, sql_backend) - def move( self, from_catalog: str, diff --git 
a/src/databricks/labs/ucx/hive_metastore/workflows.py b/src/databricks/labs/ucx/hive_metastore/workflows.py
new file mode 100644
index 0000000000..a7f8f2af6e
--- /dev/null
+++ b/src/databricks/labs/ucx/hive_metastore/workflows.py
@@ -0,0 +1,39 @@
+from databricks.labs.ucx.assessment.workflows import Assessment
+from databricks.labs.ucx.contexts.workflow_task import RuntimeContext
+from databricks.labs.ucx.framework.tasks import Workflow, job_task
+from databricks.labs.ucx.hive_metastore.tables import AclMigrationWhat, What
+
+
+class TableMigration(Workflow):
+    def __init__(self):
+        super().__init__('migrate-tables')
+
+    @job_task(job_cluster="table_migration", depends_on=[Assessment.crawl_tables])
+    def migrate_external_tables_sync(self, ctx: RuntimeContext):
+        """This workflow task migrates the *external tables that are supported by the SYNC command* from the Hive Metastore to the Unity Catalog.
+        The following CLI commands must be run before this task:
+        - For Azure: `principal-prefix-access`, `create-table-mapping`, `create-uber-principal`, `migrate-credentials`, `migrate-locations`, `create-catalogs-schemas`
+        - For AWS: TBD
+        """
+        ctx.tables_migrator.migrate_tables(what=What.EXTERNAL_SYNC, acl_strategy=[AclMigrationWhat.LEGACY_TACL])
+
+    @job_task(job_cluster="table_migration", depends_on=[Assessment.crawl_tables])
+    def migrate_dbfs_root_delta_tables(self, ctx: RuntimeContext):
+        """This workflow task migrates `Delta tables stored in the DBFS root` from the Hive Metastore to the Unity Catalog using deep clone.
+        The following CLI commands must be run before this task:
+        - For Azure: `principal-prefix-access`, `create-table-mapping`, `create-uber-principal`, `migrate-credentials`, `migrate-locations`, `create-catalogs-schemas`
+        - For AWS: TBD
+        """
+        ctx.tables_migrator.migrate_tables(what=What.DBFS_ROOT_DELTA, acl_strategy=[AclMigrationWhat.LEGACY_TACL])
+
+
+class MigrateTablesInMounts(Workflow):
+    def __init__(self):
+        super().__init__('migrate-tables-in-mounts-experimental')
+
+    @job_task
+    def scan_tables_in_mounts_experimental(self, ctx: RuntimeContext):
+        """[EXPERIMENTAL] This workflow scans for Delta tables inside all mount points
+        captured during the assessment.
It will store the results under the `tables` table + located under the assessment.""" + ctx.tables_in_mounts.snapshot() diff --git a/src/databricks/labs/ucx/install.py b/src/databricks/labs/ucx/install.py index 28423585e9..a5c3ddcb3d 100644 --- a/src/databricks/labs/ucx/install.py +++ b/src/databricks/labs/ucx/install.py @@ -44,6 +44,7 @@ from databricks.labs.ucx.assessment.pipelines import PipelineInfo from databricks.labs.ucx.config import WorkspaceConfig from databricks.labs.ucx.framework.dashboards import DashboardFromFiles +from databricks.labs.ucx.framework.tasks import Task from databricks.labs.ucx.hive_metastore.grants import Grant from databricks.labs.ucx.hive_metastore.locations import ExternalLocation, Mount from databricks.labs.ucx.hive_metastore.table_migrate import MigrationStatus @@ -54,6 +55,7 @@ from databricks.labs.ucx.installer.mixins import InstallationMixin from databricks.labs.ucx.installer.policy import ClusterPolicyInstaller from databricks.labs.ucx.installer.workflows import WorkflowsDeployment +from databricks.labs.ucx.runtime import Workflows from databricks.labs.ucx.workspace_access.base import Permissions from databricks.labs.ucx.workspace_access.generic import WorkspaceObjectInfo from databricks.labs.ucx.workspace_access.groups import ConfigureGroups, MigratedGroup @@ -117,6 +119,7 @@ def __init__( ws: WorkspaceClient, product_info: ProductInfo, environ: dict[str, str] | None = None, + tasks: list[Task] | None = None, ): if not environ: environ = dict(os.environ.items()) @@ -129,6 +132,7 @@ def __init__( self._policy_installer = ClusterPolicyInstaller(installation, ws, prompts) self._product_info = product_info self._force_install = environ.get("UCX_FORCE_INSTALL") + self._tasks = tasks if tasks else Workflows.all().tasks() def run( self, @@ -144,8 +148,15 @@ def run( wheel_builder_factory = self._new_wheel_builder wheels = wheel_builder_factory() install_state = InstallState.from_installation(self._installation) - workflows_installer = WorkflowsDeployment( - config, self._installation, install_state, self._ws, wheels, self._product_info, verify_timeout + workflows_deployment = WorkflowsDeployment( + config, + self._installation, + install_state, + self._ws, + wheels, + self._product_info, + verify_timeout, + self._tasks, ) workspace_installation = WorkspaceInstallation( config, @@ -153,7 +164,7 @@ def run( install_state, sql_backend_factory(config), self._ws, - workflows_installer, + workflows_deployment, self._prompts, self._product_info, ) @@ -364,6 +375,7 @@ def __init__( @classmethod def current(cls, ws: WorkspaceClient): + # TODO: remove this method, it's no longer needed product_info = ProductInfo.from_class(WorkspaceConfig) installation = product_info.current_installation(ws) install_state = InstallState.from_installation(installation) @@ -372,6 +384,7 @@ def current(cls, ws: WorkspaceClient): wheels = product_info.wheels(ws) prompts = Prompts() timeout = timedelta(minutes=2) + tasks = Workflows.all().tasks() workflows_installer = WorkflowsDeployment( config, installation, @@ -380,6 +393,7 @@ def current(cls, ws: WorkspaceClient): wheels, product_info, timeout, + tasks, ) return cls( @@ -419,9 +433,6 @@ def run(self): logger.info("Triggering the assessment workflow") self._trigger_workflow("assessment") - def config_file_link(self): - return self._installation.workspace_link('config.yml') - def _create_database(self): try: deploy_schema(self._sql_backend, self._config.inventory_database) diff --git a/src/databricks/labs/ucx/installer/workflows.py 
b/src/databricks/labs/ucx/installer/workflows.py index 3bc83c8f89..2b52a7094c 100644 --- a/src/databricks/labs/ucx/installer/workflows.py +++ b/src/databricks/labs/ucx/installer/workflows.py @@ -2,7 +2,6 @@ import re import sys import webbrowser -from collections.abc import Collection from dataclasses import replace from datetime import datetime, timedelta from pathlib import Path @@ -43,9 +42,8 @@ import databricks from databricks.labs.ucx.config import WorkspaceConfig from databricks.labs.ucx.configure import ConfigureClusterOverrides -from databricks.labs.ucx.framework.tasks import _TASKS, Task +from databricks.labs.ucx.framework.tasks import Task from databricks.labs.ucx.installer.mixins import InstallationMixin -from databricks.labs.ucx.runtime import main logger = logging.getLogger(__name__) @@ -95,6 +93,7 @@ from databricks.labs.ucx.runtime import main main(f'--config=/Workspace{config_file}', + f'--workflow=' + dbutils.widgets.get('workflow'), f'--task=' + dbutils.widgets.get('task'), f'--job_id=' + dbutils.widgets.get('job_id'), f'--run_id=' + dbutils.widgets.get('run_id'), @@ -317,7 +316,7 @@ def __init__( wheels: WheelsV2, product_info: ProductInfo, verify_timeout: timedelta, - tasks: list[Task] | None = None, + tasks: list[Task], ): self._config = config self._installation = installation @@ -326,44 +325,35 @@ def __init__( self._wheels = wheels self._product_info = product_info self._verify_timeout = verify_timeout - self._tasks = self._sort_tasks(tasks if tasks else _TASKS.values()) + self._tasks = tasks self._this_file = Path(__file__) super().__init__(config, installation, ws) - @staticmethod - def _sort_tasks(tasks: Collection[Task]) -> list[Task]: - return sorted(tasks, key=lambda x: x.task_id) - - @classmethod - def for_cli(cls, ws: WorkspaceClient): - product_info = ProductInfo.from_class(WorkspaceConfig) - installation = product_info.current_installation(ws) - install_state = InstallState.from_installation(installation) - timeout = timedelta(minutes=2) - - return DeployedWorkflows(ws, install_state, timeout) - def create_jobs(self, prompts): - logger.debug(f"Creating jobs from tasks in {main.__name__}") remote_wheel = self._upload_wheel(prompts) - desired_steps = {t.workflow for t in self._tasks if t.cloud_compatible(self._ws.config)} + desired_workflows = {t.workflow for t in self._tasks if t.cloud_compatible(self._ws.config)} wheel_runner = None if self._config.override_clusters: wheel_runner = self._upload_wheel_runner(remote_wheel) - for step_name in desired_steps: - settings = self._job_settings(step_name, remote_wheel) + for workflow_name in desired_workflows: + settings = self._job_settings(workflow_name, remote_wheel) if self._config.override_clusters: - settings = self._apply_cluster_overrides(settings, self._config.override_clusters, wheel_runner) - self._deploy_workflow(step_name, settings) + settings = self._apply_cluster_overrides( + workflow_name, + settings, + self._config.override_clusters, + wheel_runner, + ) + self._deploy_workflow(workflow_name, settings) - for step_name, job_id in self._install_state.jobs.items(): - if step_name not in desired_steps: + for workflow_name, job_id in self._install_state.jobs.items(): + if workflow_name not in desired_workflows: try: logger.info(f"Removing job_id={job_id}, as it is no longer needed") self._ws.jobs.delete(job_id) except InvalidParameterValue: - logger.warning(f"step={step_name} does not exist anymore for some reason") + logger.warning(f"step={workflow_name} does not exist anymore for some reason") 
continue self._install_state.save() @@ -469,7 +459,12 @@ def _upload_wheel_runner(self, remote_wheel: str): return self._installation.upload(f"wheels/wheel-test-runner-{self._product_info.version()}.py", code) @staticmethod - def _apply_cluster_overrides(settings: dict[str, Any], overrides: dict[str, str], wheel_runner: str) -> dict: + def _apply_cluster_overrides( + workflow_name: str, + settings: dict[str, Any], + overrides: dict[str, str], + wheel_runner: str, + ) -> dict: settings["job_clusters"] = [_ for _ in settings["job_clusters"] if _.job_cluster_key not in overrides] for job_task in settings["tasks"]: if job_task.job_cluster_key is None: @@ -480,8 +475,8 @@ def _apply_cluster_overrides(settings: dict[str, Any], overrides: dict[str, str] job_task.libraries = None if job_task.python_wheel_task is not None: job_task.python_wheel_task = None - params = {"task": job_task.task_key} | EXTRA_TASK_PARAMS - job_task.notebook_task = jobs.NotebookTask(notebook_path=wheel_runner, base_parameters=params) + widget_values = {"task": job_task.task_key, 'workflow': workflow_name} | EXTRA_TASK_PARAMS + job_task.notebook_task = jobs.NotebookTask(notebook_path=wheel_runner, base_parameters=widget_values) return settings def _job_settings(self, step_name: str, remote_wheel: str): @@ -553,16 +548,20 @@ def _job_wheel_task(self, jobs_task: jobs.Task, task: Task, remote_wheel: str) - # Shared mode cluster cannot use dbfs, need to use WSFS libraries = [compute.Library(whl=f"/Workspace{remote_wheel}")] else: - # TODO: check when we can install wheels from WSFS properly - # None UC cluster cannot use WSFS, need to use dbfs + # TODO: https://github.com/databrickslabs/ucx/issues/1098 libraries = [compute.Library(whl=f"dbfs:{remote_wheel}")] + named_parameters = { + "config": f"/Workspace{self._config_file}", + "workflow": task.workflow, + "task": task.name, + } return replace( jobs_task, libraries=libraries, python_wheel_task=jobs.PythonWheelTask( package_name="databricks_labs_ucx", entry_point="runtime", # [project.entry-points.databricks] in pyproject.toml - named_parameters={"task": task.name, "config": f"/Workspace{self._config_file}"} | EXTRA_TASK_PARAMS, + named_parameters=named_parameters | EXTRA_TASK_PARAMS, ), ) diff --git a/src/databricks/labs/ucx/mixins/fixtures.py b/src/databricks/labs/ucx/mixins/fixtures.py index fae62c7296..7c99c93fcb 100644 --- a/src/databricks/labs/ucx/mixins/fixtures.py +++ b/src/databricks/labs/ucx/mixins/fixtures.py @@ -47,7 +47,6 @@ from databricks.sdk.service.workspace import ImportFormat from databricks.labs.ucx.workspace_access.groups import MigratedGroup -from databricks.labs.ucx.workspace_access.manager import PermissionManager # this file will get to databricks-labs-pytester project and be maintained/refactored there # pylint: disable=redefined-outer-name,too-many-try-statements,import-outside-toplevel,unnecessary-lambda,too-complex,invalid-name @@ -162,11 +161,6 @@ def acc(product_info, debug_env) -> AccountClient: ) -@pytest.fixture -def permission_manager(ws, sql_backend, inventory_schema) -> PermissionManager: - return PermissionManager.factory(ws, sql_backend, inventory_schema) - - def _permissions_mapping(): from databricks.sdk.service.iam import PermissionLevel diff --git a/src/databricks/labs/ucx/runtime.py b/src/databricks/labs/ucx/runtime.py index 928c56fa41..b94a118a70 100644 --- a/src/databricks/labs/ucx/runtime.py +++ b/src/databricks/labs/ucx/runtime.py @@ -1,590 +1,96 @@ +import dataclasses import logging import os import sys - -from 
databricks.labs.blueprint.installation import Installation -from databricks.labs.lsql.backends import SqlBackend -from databricks.sdk import WorkspaceClient - -from databricks.labs.ucx.assessment.azure import AzureServicePrincipalCrawler -from databricks.labs.ucx.assessment.clusters import ClustersCrawler, PoliciesCrawler -from databricks.labs.ucx.assessment.init_scripts import GlobalInitScriptCrawler -from databricks.labs.ucx.assessment.jobs import JobsCrawler, SubmitRunsCrawler -from databricks.labs.ucx.assessment.pipelines import PipelinesCrawler -from databricks.labs.ucx.azure.access import AzureResourcePermissions -from databricks.labs.ucx.azure.resources import AzureAPIClient, AzureResources -from databricks.labs.ucx.config import WorkspaceConfig -from databricks.labs.ucx.framework.tasks import task, trigger -from databricks.labs.ucx.hive_metastore import ExternalLocations, Mounts, TablesCrawler -from databricks.labs.ucx.hive_metastore.grants import ( - AzureACL, - GrantsCrawler, - PrincipalACL, +from pathlib import Path + +from databricks.labs.ucx.__about__ import __version__ +from databricks.labs.ucx.assessment.workflows import Assessment, DestroySchema +from databricks.labs.ucx.contexts.workflow_task import RuntimeContext +from databricks.labs.ucx.framework.tasks import Task, TaskLogger, Workflow, parse_args +from databricks.labs.ucx.hive_metastore.workflows import ( + MigrateTablesInMounts, + TableMigration, ) -from databricks.labs.ucx.hive_metastore.locations import TablesInMounts -from databricks.labs.ucx.hive_metastore.mapping import TableMapping -from databricks.labs.ucx.hive_metastore.table_migrate import ( - MigrationStatusRefresher, - TablesMigrator, +from databricks.labs.ucx.workspace_access.workflows import ( + GroupMigration, + PermissionsMigrationAPI, + RemoveWorkspaceLocalGroups, + ValidateGroupPermissions, ) -from databricks.labs.ucx.hive_metastore.table_size import TableSizeCrawler -from databricks.labs.ucx.hive_metastore.tables import AclMigrationWhat, What -from databricks.labs.ucx.hive_metastore.udfs import UdfsCrawler -from databricks.labs.ucx.hive_metastore.verification import VerifyHasMetastore -from databricks.labs.ucx.workspace_access.generic import WorkspaceListing -from databricks.labs.ucx.workspace_access.groups import GroupManager -from databricks.labs.ucx.workspace_access.manager import PermissionManager logger = logging.getLogger(__name__) -@task("assessment", notebook="hive_metastore/tables.scala") -def crawl_tables(*_): - """Iterates over all tables in the Hive Metastore of the current workspace and persists their metadata, such - as _database name_, _table name_, _table type_, _table location_, etc., in the Delta table named - `$inventory_database.tables`. Note that the `inventory_database` is set in the configuration file. The metadata - stored is then used in the subsequent tasks and workflows to, for example, find all Hive Metastore tables that - cannot easily be migrated to Unity Catalog.""" - - -@task("assessment", job_cluster="tacl") -def setup_tacl(*_): - """(Optimization) Starts `tacl` job cluster in parallel to crawling tables.""" - - -@task("assessment", depends_on=[crawl_tables, setup_tacl], job_cluster="tacl") -def crawl_grants(cfg: WorkspaceConfig, _ws: WorkspaceClient, sql_backend: SqlBackend, _install: Installation): - """Scans the previously created Delta table named `$inventory_database.tables` and issues a `SHOW GRANTS` - statement for every object to retrieve the permissions it has assigned to it. 
The permissions include information - such as the _principal_, _action type_, and the _table_ it applies to. This is persisted in the Delta table - `$inventory_database.grants`. Other, migration related jobs use this inventory table to convert the legacy Table - ACLs to Unity Catalog permissions. - - Note: This job runs on a separate cluster (named `tacl`) as it requires the proper configuration to have the Table - ACLs enabled and available for retrieval.""" - tables = TablesCrawler(sql_backend, cfg.inventory_database, cfg.include_databases) - udfs = UdfsCrawler(sql_backend, cfg.inventory_database, cfg.include_databases) - grants = GrantsCrawler(tables, udfs, cfg.include_databases) - grants.snapshot() - - -@task("assessment", depends_on=[crawl_tables]) -def estimate_table_size_for_migration( - cfg: WorkspaceConfig, _ws: WorkspaceClient, sql_backend: SqlBackend, _install: Installation -): - """Scans the previously created Delta table named `$inventory_database.tables` and locate tables that cannot be - "synced". These tables will have to be cloned in the migration process. - Assesses the size of these tables and create `$inventory_database.table_size` table to list these sizes. - The table size is a factor in deciding whether to clone these tables.""" - table_size = TableSizeCrawler(sql_backend, cfg.inventory_database) - table_size.snapshot() - - -@task("assessment") -def crawl_mounts(cfg: WorkspaceConfig, ws: WorkspaceClient, sql_backend: SqlBackend, _install: Installation): - """Defines the scope of the _mount points_ intended for migration into Unity Catalog. As these objects are not - compatible with the Unity Catalog paradigm, a key component of the migration process involves transferring them - to Unity Catalog External Locations. - - The assessment involves scanning the workspace to compile a list of all existing mount points and subsequently - storing this information in the `$inventory.mounts` table. This is crucial for planning the migration.""" - mounts = Mounts(sql_backend, ws, cfg.inventory_database) - mounts.snapshot() - - -@task("assessment", depends_on=[crawl_mounts, crawl_tables]) -def guess_external_locations( - cfg: WorkspaceConfig, ws: WorkspaceClient, sql_backend: SqlBackend, _install: Installation -): - """Determines the shared path prefixes of all the tables. Specifically, the focus is on identifying locations that - utilize mount points. The goal is to identify the _external locations_ necessary for a successful migration and - store this information in the `$inventory.external_locations` table. - - The approach taken in this assessment involves the following steps: - - Extracting all the locations associated with tables that do not use DBFS directly, but a mount point instead - - Scanning all these locations to identify folders that can act as shared path prefixes - - These identified external locations will be created subsequently prior to the actual table migration""" - crawler = ExternalLocations(ws, sql_backend, cfg.inventory_database) - crawler.snapshot() - - -@task("assessment") -def assess_jobs(cfg: WorkspaceConfig, ws: WorkspaceClient, sql_backend: SqlBackend, _install: Installation): - """Scans through all the jobs and identifies those that are not compatible with UC. The list of all the jobs is - stored in the `$inventory.jobs` table. 
- - It looks for: - - Clusters with Databricks Runtime (DBR) version earlier than 11.3 - - Clusters using Passthrough Authentication - - Clusters with incompatible Spark config tags - - Clusters referencing DBFS locations in one or more config options - """ - crawler = JobsCrawler(ws, sql_backend, cfg.inventory_database) - crawler.snapshot() - - -@task("assessment") -def assess_clusters(cfg: WorkspaceConfig, ws: WorkspaceClient, sql_backend: SqlBackend, _install: Installation): - """Scan through all the clusters and identifies those that are not compatible with UC. The list of all the clusters - is stored in the`$inventory.clusters` table. - - It looks for: - - Clusters with Databricks Runtime (DBR) version earlier than 11.3 - - Clusters using Passthrough Authentication - - Clusters with incompatible spark config tags - - Clusters referencing DBFS locations in one or more config options - """ - crawler = ClustersCrawler(ws, sql_backend, cfg.inventory_database) - crawler.snapshot() - - -@task("assessment") -def assess_pipelines(cfg: WorkspaceConfig, ws: WorkspaceClient, sql_backend: SqlBackend, _install: Installation): - """This module scans through all the Pipelines and identifies those pipelines which has Azure Service Principals - embedded (who has been given access to the Azure storage accounts via spark configurations) in the pipeline - configurations. - - It looks for: - - all the pipelines which has Azure Service Principal embedded in the pipeline configuration - - Subsequently, a list of all the pipelines with matching configurations are stored in the - `$inventory.pipelines` table.""" - crawler = PipelinesCrawler(ws, sql_backend, cfg.inventory_database) - crawler.snapshot() - - -@task("assessment") -def assess_incompatible_submit_runs( - cfg: WorkspaceConfig, ws: WorkspaceClient, sql_backend: SqlBackend, _install: Installation -): - """This module scans through all the Submit Runs and identifies those runs which may become incompatible after - the workspace attachment. - - It looks for: - - All submit runs with DBR >=11.3 and data_security_mode:None - - It also combines several submit runs under a single pseudo_id based on hash of the submit run configuration. - Subsequently, a list of all the incompatible runs with failures are stored in the - `$inventory.submit_runs` table.""" - crawler = SubmitRunsCrawler(ws, sql_backend, cfg.inventory_database, cfg.num_days_submit_runs_history) - crawler.snapshot() - - -@task("assessment") -def crawl_cluster_policies(cfg: WorkspaceConfig, ws: WorkspaceClient, sql_backend: SqlBackend, _install: Installation): - """This module scans through all the Cluster Policies and get the necessary information - - It looks for: - - Clusters Policies with Databricks Runtime (DBR) version earlier than 11.3 - - Subsequently, a list of all the policies with matching configurations are stored in the - `$inventory.policies` table.""" - crawler = PoliciesCrawler(ws, sql_backend, cfg.inventory_database) - crawler.snapshot() - - -@task("assessment", cloud="azure") -def assess_azure_service_principals( - cfg: WorkspaceConfig, ws: WorkspaceClient, sql_backend: SqlBackend, _install: Installation -): - """This module scans through all the clusters configurations, cluster policies, job cluster configurations, - Pipeline configurations, Warehouse configuration and identifies all the Azure Service Principals who has been - given access to the Azure storage accounts via spark configurations referred in those entities. 
- - It looks in: - - all those entities and prepares a list of Azure Service Principal embedded in their configurations - - Subsequently, the list of all the Azure Service Principals referred in those configurations are saved - in the `$inventory.azure_service_principals` table.""" - if ws.config.is_azure: - crawler = AzureServicePrincipalCrawler(ws, sql_backend, cfg.inventory_database) - crawler.snapshot() - - -@task("assessment") -def assess_global_init_scripts( - cfg: WorkspaceConfig, ws: WorkspaceClient, sql_backend: SqlBackend, _install: Installation -): - """This module scans through all the global init scripts and identifies if there is an Azure Service Principal - who has been given access to the Azure storage accounts via spark configurations referred in those scripts. - - It looks in: - - the list of all the global init scripts are saved in the `$inventory.azure_service_principals` table.""" - crawler = GlobalInitScriptCrawler(ws, sql_backend, cfg.inventory_database) - crawler.snapshot() - - -@task("assessment") -def workspace_listing(cfg: WorkspaceConfig, ws: WorkspaceClient, sql_backend: SqlBackend, _install: Installation): - """Scans the workspace for workspace objects. It recursively list all sub directories - and compiles a list of directories, notebooks, files, repos and libraries in the workspace. - - It uses multi-threading to parallelize the listing process to speed up execution on big workspaces. - It accepts starting path as the parameter defaulted to the root path '/'.""" - crawler = WorkspaceListing(ws, sql_backend, cfg.inventory_database, cfg.num_threads, cfg.workspace_start_path) - crawler.snapshot() - - -@task("assessment", depends_on=[crawl_grants, workspace_listing]) -def crawl_permissions(cfg: WorkspaceConfig, ws: WorkspaceClient, sql_backend: SqlBackend, _install: Installation): - """Scans the workspace-local groups and all their permissions. The list is stored in the `$inventory.permissions` - Delta table. - - This is the first step for the _group migration_ process, which is continued in the `migrate-groups` workflow. - This step includes preparing Legacy Table ACLs for local group migration.""" - permission_manager = PermissionManager.factory( - ws, - sql_backend, - cfg.inventory_database, - num_threads=cfg.num_threads, - workspace_start_path=cfg.workspace_start_path, - ) - permission_manager.cleanup() - permission_manager.inventorize_permissions() - - -@task("assessment") -def crawl_groups(cfg: WorkspaceConfig, ws: WorkspaceClient, sql_backend: SqlBackend, _install: Installation): - """Scans all groups for the local group migration scope""" - group_manager = GroupManager( - sql_backend, - ws, - cfg.inventory_database, - cfg.include_group_names, - cfg.renamed_group_prefix, - workspace_group_regex=cfg.workspace_group_regex, - workspace_group_replace=cfg.workspace_group_replace, - account_group_regex=cfg.account_group_regex, - external_id_match=cfg.group_match_by_external_id, - ) - group_manager.snapshot() - - -@task( - "assessment", - depends_on=[ - crawl_grants, - crawl_groups, - crawl_permissions, - guess_external_locations, - assess_jobs, - assess_incompatible_submit_runs, - assess_clusters, - crawl_cluster_policies, - assess_azure_service_principals, - assess_pipelines, - assess_global_init_scripts, - crawl_tables, - ], - dashboard="assessment_main", -) -def assessment_report(*_): - """Refreshes the assessment dashboard after all previous tasks have been completed. 
Note that you can access the - dashboard _before_ all tasks have been completed, but then only already completed information is shown.""" - - -@task( - "assessment", - depends_on=[ - assess_jobs, - assess_incompatible_submit_runs, - assess_clusters, - assess_pipelines, - crawl_tables, - ], - dashboard="assessment_estimates", -) -def estimates_report(*_): - """Refreshes the assessment dashboard after all previous tasks have been completed. Note that you can access the - dashboard _before_ all tasks have been completed, but then only already completed information is shown.""" - - -@task("migrate-groups", depends_on=[crawl_groups]) -def rename_workspace_local_groups( - cfg: WorkspaceConfig, ws: WorkspaceClient, sql_backend: SqlBackend, _install: Installation -): - """Renames workspace local groups by adding `ucx-renamed-` prefix.""" - verify_has_metastore = VerifyHasMetastore(ws) - if verify_has_metastore.verify_metastore(): - logger.info("Metastore exists in the workspace") - - group_manager = GroupManager( - sql_backend, - ws, - cfg.inventory_database, - cfg.include_group_names, - cfg.renamed_group_prefix, - workspace_group_regex=cfg.workspace_group_regex, - workspace_group_replace=cfg.workspace_group_replace, - account_group_regex=cfg.account_group_regex, - external_id_match=cfg.group_match_by_external_id, - ) - group_manager.rename_groups() - - -@task("migrate-groups", depends_on=[rename_workspace_local_groups]) -def reflect_account_groups_on_workspace( - cfg: WorkspaceConfig, ws: WorkspaceClient, sql_backend: SqlBackend, _install: Installation -): - """Adds matching account groups to this workspace. The matching account level group(s) must preexist(s) for this - step to be successful. This process does not create the account level group(s).""" - group_manager = GroupManager( - sql_backend, - ws, - cfg.inventory_database, - cfg.include_group_names, - cfg.renamed_group_prefix, - workspace_group_regex=cfg.workspace_group_regex, - workspace_group_replace=cfg.workspace_group_replace, - account_group_regex=cfg.account_group_regex, - external_id_match=cfg.group_match_by_external_id, - ) - group_manager.reflect_account_groups_on_workspace() - - -@task("migrate-groups", depends_on=[reflect_account_groups_on_workspace], job_cluster="tacl") -def apply_permissions_to_account_groups( - cfg: WorkspaceConfig, ws: WorkspaceClient, sql_backend: SqlBackend, _install: Installation -): - """Fourth phase of the workspace-local group migration process. It does the following: - - Assigns the full set of permissions of the original group to the account-level one - - It covers local workspace-local permissions for all entities: Legacy Table ACLs, Entitlements, - AWS instance profiles, Clusters, Cluster policies, Instance Pools, Databricks SQL warehouses, Delta Live - Tables, Jobs, MLflow experiments, MLflow registry, SQL Dashboards & Queries, SQL Alerts, Token and Password usage - permissions, Secret Scopes, Notebooks, Directories, Repos, Files. 
- - See [interactive tutorial here](https://app.getreprise.com/launch/myM3VNn/).""" - group_manager = GroupManager( - sql_backend, - ws, - cfg.inventory_database, - cfg.include_group_names, - cfg.renamed_group_prefix, - workspace_group_regex=cfg.workspace_group_regex, - workspace_group_replace=cfg.workspace_group_replace, - account_group_regex=cfg.account_group_regex, - external_id_match=cfg.group_match_by_external_id, - ) - - migration_state = group_manager.get_migration_state() - if len(migration_state.groups) == 0: - logger.info("Skipping group migration as no groups were found.") - return - - permission_manager = PermissionManager.factory( - ws, - sql_backend, - cfg.inventory_database, - num_threads=cfg.num_threads, - workspace_start_path=cfg.workspace_start_path, - ) - permission_manager.apply_group_permissions(migration_state) - - -@task("validate-groups-permissions", job_cluster="tacl") -def validate_groups_permissions( - cfg: WorkspaceConfig, ws: WorkspaceClient, sql_backend: SqlBackend, _install: Installation -): - """Validate that all the crawled permissions are applied correctly to the destination groups.""" - logger.info("Running validation of permissions applied to destination groups.") - permission_manager = PermissionManager.factory( - ws, - sql_backend, - cfg.inventory_database, - num_threads=cfg.num_threads, - workspace_start_path=cfg.workspace_start_path, - ) - permission_manager.verify_group_permissions() - - -@task("remove-workspace-local-backup-groups", depends_on=[apply_permissions_to_account_groups]) -def delete_backup_groups(cfg: WorkspaceConfig, ws: WorkspaceClient, sql_backend: SqlBackend, _install: Installation): - """Last step of the group migration process. Removes all workspace-level backup groups, along with their - permissions. Execute this workflow only after you've confirmed that workspace-local migration worked - successfully for all the groups involved.""" - group_manager = GroupManager( - sql_backend, - ws, - cfg.inventory_database, - cfg.include_group_names, - cfg.renamed_group_prefix, - workspace_group_regex=cfg.workspace_group_regex, - workspace_group_replace=cfg.workspace_group_replace, - account_group_regex=cfg.account_group_regex, - external_id_match=cfg.group_match_by_external_id, - ) - group_manager.delete_original_workspace_groups() - - -@task("099-destroy-schema") -def destroy_schema(cfg: WorkspaceConfig, _ws: WorkspaceClient, sql_backend: SqlBackend, _install: Installation): - """This _clean-up_ workflow allows to removes the `$inventory` database, with all the inventory tables created by - the previous workflow runs. Use this to reset the entire state and start with the assessment step again.""" - sql_backend.execute(f"DROP DATABASE {cfg.inventory_database} CASCADE") - - -@task("migrate-tables", job_cluster="table_migration") -def migrate_external_tables_sync( - cfg: WorkspaceConfig, ws: WorkspaceClient, sql_backend: SqlBackend, install: Installation -): - """This workflow task migrates the *external tables that are supported by SYNC command* from the Hive Metastore to the Unity Catalog. 
- Following cli commands are required to be run before running this task: - - For Azure: `principal-prefix-access`, `create-table-mapping`, `create-uber-principal`, `migrate-credentials`, `migrate-locations`, `create-catalogs-schemas` - - For AWS: TBD - """ - table_crawler = TablesCrawler(sql_backend, cfg.inventory_database) - udf_crawler = UdfsCrawler(sql_backend, cfg.inventory_database) - grant_crawler = GrantsCrawler(table_crawler, udf_crawler) - table_mappings = TableMapping(install, ws, sql_backend) - migration_status_refresher = MigrationStatusRefresher(ws, sql_backend, cfg.inventory_database, table_crawler) - group_manager = GroupManager(sql_backend, ws, cfg.inventory_database) - mount_crawler = Mounts(sql_backend, ws, cfg.inventory_database) - cluster_locations = {} - if ws.config.is_azure: - locations = ExternalLocations(ws, sql_backend, cfg.inventory_database) - azure_client = AzureAPIClient( - ws.config.arm_environment.resource_manager_endpoint, - ws.config.arm_environment.service_management_endpoint, +class Workflows: + def __init__(self, workflows: list[Workflow]): + self._tasks: list[Task] = [] + self._workflows: dict[str, Workflow] = {} + for workflow in workflows: + self._workflows[workflow.name] = workflow + for task_definition in workflow.tasks(): + # Add the workflow name to the task definition, because we cannot access + # the workflow name from the method decorator + with_workflow = dataclasses.replace(task_definition, workflow=workflow.name) + self._tasks.append(with_workflow) + + @classmethod + def all(cls): + return cls( + [ + Assessment(), + GroupMigration(), + TableMigration(), + ValidateGroupPermissions(), + RemoveWorkspaceLocalGroups(), + MigrateTablesInMounts(), + PermissionsMigrationAPI(), + DestroySchema(), + ] ) - graph_client = AzureAPIClient("https://graph.microsoft.com", "https://graph.microsoft.com") - azurerm = AzureResources(azure_client, graph_client) - resource_permissions = AzureResourcePermissions(install, ws, azurerm, locations) - spn_crawler = AzureServicePrincipalCrawler(ws, sql_backend, cfg.inventory_database) - cluster_locations = AzureACL( - ws, sql_backend, spn_crawler, resource_permissions - ).get_eligible_locations_principals() - interactive_grants = PrincipalACL(ws, sql_backend, install, table_crawler, mount_crawler, cluster_locations) - TablesMigrator( - table_crawler, - grant_crawler, - ws, - sql_backend, - table_mappings, - group_manager, - migration_status_refresher, - interactive_grants, - ).migrate_tables(what=What.EXTERNAL_SYNC, acl_strategy=[AclMigrationWhat.LEGACY_TACL, AclMigrationWhat.PRINCIPAL]) - - -@task("migrate-tables", job_cluster="table_migration") -def migrate_dbfs_root_delta_tables( - cfg: WorkspaceConfig, ws: WorkspaceClient, sql_backend: SqlBackend, install: Installation -): - """This workflow task migrates `delta tables stored in DBFS root` from the Hive Metastore to the Unity Catalog using deep clone. 
- Following cli commands are required to be run before running this task: - - For Azure: `principal-prefix-access`, `create-table-mapping`, `create-uber-principal`, `migrate-credentials`, `migrate-locations`, `create-catalogs-schemas` - - For AWS: TBD - """ - table_crawler = TablesCrawler(sql_backend, cfg.inventory_database) - udf_crawler = UdfsCrawler(sql_backend, cfg.inventory_database) - grant_crawler = GrantsCrawler(table_crawler, udf_crawler) - table_mappings = TableMapping(install, ws, sql_backend) - migration_status_refresher = MigrationStatusRefresher(ws, sql_backend, cfg.inventory_database, table_crawler) - group_manager = GroupManager(sql_backend, ws, cfg.inventory_database) - mount_crawler = Mounts(sql_backend, ws, cfg.inventory_database) - cluster_locations = {} - if ws.config.is_azure: - locations = ExternalLocations(ws, sql_backend, cfg.inventory_database) - azure_client = AzureAPIClient( - ws.config.arm_environment.resource_manager_endpoint, - ws.config.arm_environment.service_management_endpoint, - ) - graph_client = AzureAPIClient("https://graph.microsoft.com", "https://graph.microsoft.com") - azurerm = AzureResources(azure_client, graph_client) - resource_permissions = AzureResourcePermissions(install, ws, azurerm, locations) - spn_crawler = AzureServicePrincipalCrawler(ws, sql_backend, cfg.inventory_database) - cluster_locations = AzureACL( - ws, sql_backend, spn_crawler, resource_permissions - ).get_eligible_locations_principals() - interactive_grants = PrincipalACL(ws, sql_backend, install, table_crawler, mount_crawler, cluster_locations) - TablesMigrator( - table_crawler, - grant_crawler, - ws, - sql_backend, - table_mappings, - group_manager, - migration_status_refresher, - interactive_grants, - ).migrate_tables(what=What.DBFS_ROOT_DELTA, acl_strategy=[AclMigrationWhat.LEGACY_TACL, AclMigrationWhat.PRINCIPAL]) - - -@task("migrate-groups-experimental", depends_on=[crawl_groups]) -def rename_workspace_local_groups_experimental( - cfg: WorkspaceConfig, ws: WorkspaceClient, sql_backend: SqlBackend, _install: Installation -): - """EXPERIMENTAL - Renames workspace local groups by adding `ucx-renamed-` prefix.""" - return rename_workspace_local_groups(cfg, ws, sql_backend, _install) - - -@task("migrate-groups-experimental", depends_on=[rename_workspace_local_groups_experimental]) -def reflect_account_groups_on_workspace_experimental( - cfg: WorkspaceConfig, ws: WorkspaceClient, sql_backend: SqlBackend, _install: Installation -): - """EXPERIMENTAL - Adds matching account groups to this workspace. The matching account level group(s) must preexist(s) for this - step to be successful. This process does not create the account level group(s).""" - return reflect_account_groups_on_workspace(cfg, ws, sql_backend, _install) - - -@task( - "migrate-groups-experimental", - depends_on=[ - reflect_account_groups_on_workspace_experimental, - rename_workspace_local_groups_experimental, - ], -) -def apply_permissions_to_account_groups_experimental( - cfg: WorkspaceConfig, ws: WorkspaceClient, sql_backend: SqlBackend, _install: Installation -): - """EXPERIMENTAL - This task uses the new permission migration API which requires enrolment from Databricks - Fourth phase of the workspace-local group migration process. 
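The phases of this experimental flow mirror the regular group migration, and the same sequence reappears below as the `PermissionsMigrationAPI` workflow. In miniature, using only methods named in this diff (an illustrative sequence; in the workflow classes each step is a separate job task):

```python
def migrate_groups_via_api(group_manager, ws) -> None:
    """Illustrative end-to-end sequence of the group migration phases."""
    group_manager.rename_groups()                        # workspace group -> ucx-renamed-<name>
    group_manager.reflect_account_groups_on_workspace()  # attach matching account groups
    migration_state = group_manager.get_migration_state()
    if len(migration_state.groups) == 0:
        return  # nothing to migrate
    migration_state.apply_to_renamed_groups(ws)          # new permission migration API
```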
It does the following: - - Assigns the full set of permissions of the original group to the account-level one - - It covers local workspace-local permissions for most entities: Entitlements, - AWS instance profiles, Clusters, Cluster policies, Instance Pools, Databricks SQL warehouses, Delta Live - Tables, Jobs, MLflow experiments, MLflow registry, SQL Dashboards & Queries, SQL Alerts, Token and Password usage - permissions, Secret Scopes, Notebooks, Directories, Repos, Files. - """ - group_manager = GroupManager( - sql_backend, - ws, - cfg.inventory_database, - cfg.include_group_names, - cfg.renamed_group_prefix, - workspace_group_regex=cfg.workspace_group_regex, - workspace_group_replace=cfg.workspace_group_replace, - account_group_regex=cfg.account_group_regex, - external_id_match=cfg.group_match_by_external_id, - ) - migration_state = group_manager.get_migration_state() - if len(migration_state.groups) == 0: - logger.info("Skipping group migration as no groups were found.") - return - migration_state.apply_to_renamed_groups(ws) - -@task("migrate-tables-in-mounts-experimental") -def scan_tables_in_mounts_experimental( - cfg: WorkspaceConfig, ws: WorkspaceClient, sql_backend: SqlBackend, _install: Installation -): - """EXPERIMENTAL - This workflow scans for Delta tables inside all mount points captured during the assessment. - It will store the results under the `tables` table located under the assessment. - """ - mounts = Mounts(sql_backend, ws, cfg.inventory_database) - TablesInMounts( - sql_backend, ws, cfg.inventory_database, mounts, cfg.include_mounts, cfg.exclude_paths_in_mount - ).snapshot() + @classmethod + def for_testing(cls, workflow, task_name, **replace): + ctx = RuntimeContext().replace(**replace) + current_task = getattr(cls.all()._workflows[workflow], task_name) + current_task(ctx) + + def tasks(self) -> list[Task]: + return self._tasks + + def trigger(self, *argv): + named_parameters = parse_args(*argv) + config_path = Path(named_parameters["config"]) + ctx = RuntimeContext(named_parameters) + install_dir = config_path.parent + task_name = named_parameters.get("task", "not specified") + workflow_name = named_parameters.get("workflow", "not specified") + if workflow_name not in self._workflows: + msg = f'workflow "{workflow_name}" not found. 
Valid workflows are: {", ".join(self._workflows.keys())}' + raise KeyError(msg) + print(f"UCX v{__version__}") + workflow = self._workflows[workflow_name] + # `{{parent_run_id}}` is the run of entire workflow, whereas `{{run_id}}` is the run of a task + workflow_run_id = named_parameters.get("parent_run_id", "unknown_run_id") + job_id = named_parameters.get("job_id", "unknown_job_id") + with TaskLogger( + install_dir, + workflow=workflow_name, + workflow_id=job_id, + task_name=task_name, + workflow_run_id=workflow_run_id, + log_level=ctx.config.log_level, + ) as task_logger: + ucx_logger = logging.getLogger("databricks.labs.ucx") + ucx_logger.info(f"UCX v{__version__} After job finishes, see debug logs at {task_logger}") + current_task = getattr(workflow, task_name) + current_task(ctx) def main(*argv): if len(argv) == 0: argv = sys.argv - trigger(*argv) + Workflows.all().trigger(*argv) if __name__ == "__main__": diff --git a/src/databricks/labs/ucx/source_code/files.py b/src/databricks/labs/ucx/source_code/files.py index eb6fe87ae7..d821f4ac8c 100644 --- a/src/databricks/labs/ucx/source_code/files.py +++ b/src/databricks/labs/ucx/source_code/files.py @@ -1,29 +1,20 @@ import logging from pathlib import Path -from databricks.sdk import WorkspaceClient from databricks.sdk.service.workspace import Language -from databricks.labs.ucx.hive_metastore.table_migrate import TablesMigrator from databricks.labs.ucx.source_code.languages import Languages logger = logging.getLogger(__name__) -class Files: +class LocalFileMigrator: """The Files class is responsible for fixing code files based on their language.""" def __init__(self, languages: Languages): self._languages = languages self._extensions = {".py": Language.PYTHON, ".sql": Language.SQL} - @classmethod - def for_cli(cls, ws: WorkspaceClient): - tables_migrate = TablesMigrator.for_cli(ws) - index = tables_migrate.index() - languages = Languages(index) - return cls(languages) - def apply(self, path: Path) -> bool: if path.is_dir(): for child_path in path.iterdir(): diff --git a/src/databricks/labs/ucx/workspace_access/manager.py b/src/databricks/labs/ucx/workspace_access/manager.py index f1911dac20..102febbfff 100644 --- a/src/databricks/labs/ucx/workspace_access/manager.py +++ b/src/databricks/labs/ucx/workspace_access/manager.py @@ -1,26 +1,18 @@ import json import logging -import os from collections.abc import Callable, Iterable, Iterator, Sequence from itertools import groupby from databricks.labs.blueprint.parallel import ManyError, Threads from databricks.labs.lsql.backends import SqlBackend -from databricks.sdk import WorkspaceClient -from databricks.sdk.service import sql from databricks.labs.ucx.framework.crawlers import ( CrawlerBase, Dataclass, DataclassInstance, ) -from databricks.labs.ucx.hive_metastore import TablesCrawler -from databricks.labs.ucx.hive_metastore.grants import GrantsCrawler -from databricks.labs.ucx.hive_metastore.udfs import UdfsCrawler -from databricks.labs.ucx.workspace_access import generic, redash, scim, secrets from databricks.labs.ucx.workspace_access.base import AclSupport, Permissions from databricks.labs.ucx.workspace_access.groups import MigrationState -from databricks.labs.ucx.workspace_access.tacl import TableAclSupport logger = logging.getLogger(__name__) @@ -32,60 +24,6 @@ def __init__(self, backend: SqlBackend, inventory_database: str, crawlers: list[ super().__init__(backend, "hive_metastore", inventory_database, "permissions", Permissions) self._acl_support = crawlers - @classmethod - def factory( 
- cls, - ws: WorkspaceClient, - sql_backend: SqlBackend, - inventory_database: str, - *, - num_threads: int | None = None, - workspace_start_path: str = "/", - ) -> "PermissionManager": - if num_threads is None: - cpu_count = os.cpu_count() - if not cpu_count: - cpu_count = 1 - num_threads = cpu_count * 2 - generic_acl_listing = [ - generic.Listing(ws.clusters.list, "cluster_id", "clusters"), - generic.Listing(ws.cluster_policies.list, "policy_id", "cluster-policies"), - generic.Listing(ws.instance_pools.list, "instance_pool_id", "instance-pools"), - generic.Listing(ws.warehouses.list, "id", "sql/warehouses"), - generic.Listing(ws.jobs.list, "job_id", "jobs"), - generic.Listing(ws.pipelines.list_pipelines, "pipeline_id", "pipelines"), - generic.Listing(ws.serving_endpoints.list, "id", "serving-endpoints"), - generic.Listing(generic.experiments_listing(ws), "experiment_id", "experiments"), - generic.Listing(generic.models_listing(ws, num_threads), "id", "registered-models"), - generic.Listing(generic.models_root_page, "object_id", "registered-models"), - generic.Listing(generic.tokens_and_passwords, "object_id", "authorization"), - generic.Listing(generic.feature_store_listing(ws), "object_id", "feature-tables"), - generic.Listing(generic.feature_tables_root_page, "object_id", "feature-tables"), - generic.WorkspaceListing( - ws, - sql_backend=sql_backend, - inventory_database=inventory_database, - num_threads=num_threads, - start_path=workspace_start_path, - ), - ] - redash_acl_listing = [ - redash.Listing(ws.alerts.list, sql.ObjectTypePlural.ALERTS), - redash.Listing(ws.dashboards.list, sql.ObjectTypePlural.DASHBOARDS), - redash.Listing(ws.queries.list, sql.ObjectTypePlural.QUERIES), - ] - generic_support = generic.GenericPermissionsSupport(ws, generic_acl_listing) - sql_support = redash.RedashPermissionsSupport(ws, redash_acl_listing) - secrets_support = secrets.SecretScopesSupport(ws) - scim_support = scim.ScimSupport(ws) - tables_crawler = TablesCrawler(sql_backend, inventory_database) - udfs_crawler = UdfsCrawler(sql_backend, inventory_database) - grants_crawler = GrantsCrawler(tables_crawler, udfs_crawler) - tacl_support = TableAclSupport(grants_crawler, sql_backend) - return cls( - sql_backend, inventory_database, [generic_support, sql_support, secrets_support, scim_support, tacl_support] - ) - def inventorize_permissions(self): # TODO: rename into snapshot() logger.debug("Crawling permissions") diff --git a/src/databricks/labs/ucx/workspace_access/workflows.py b/src/databricks/labs/ucx/workspace_access/workflows.py new file mode 100644 index 0000000000..a3305705b3 --- /dev/null +++ b/src/databricks/labs/ucx/workspace_access/workflows.py @@ -0,0 +1,110 @@ +import logging + +from databricks.labs.ucx.assessment.workflows import Assessment +from databricks.labs.ucx.contexts.workflow_task import RuntimeContext +from databricks.labs.ucx.framework.tasks import Workflow, job_task + +logger = logging.getLogger(__name__) + + +class GroupMigration(Workflow): + def __init__(self): + super().__init__('migrate-groups') + + @job_task(depends_on=[Assessment.crawl_groups]) + def rename_workspace_local_groups(self, ctx: RuntimeContext): + """Renames workspace local groups by adding `ucx-renamed-` prefix.""" + if ctx.verify_has_metastore.verify_metastore(): + logger.info("Metastore exists in the workspace") + ctx.group_manager.rename_groups() + + @job_task(depends_on=[rename_workspace_local_groups]) + def reflect_account_groups_on_workspace(self, ctx: RuntimeContext): + """Adds matching account groups 
to this workspace. The matching account level group(s) must already exist for this + step to succeed. This process does not create the account level group(s).""" + ctx.group_manager.reflect_account_groups_on_workspace() + + @job_task(depends_on=[reflect_account_groups_on_workspace], job_cluster="tacl") + def apply_permissions_to_account_groups(self, ctx: RuntimeContext): + """Fourth phase of the workspace-local group migration process. It does the following: + - Assigns the full set of permissions of the original group to the account-level one + + It covers workspace-local permissions for all entities: Legacy Table ACLs, Entitlements, + AWS instance profiles, Clusters, Cluster policies, Instance Pools, Databricks SQL warehouses, Delta Live + Tables, Jobs, MLflow experiments, MLflow registry, SQL Dashboards & Queries, SQL Alerts, Token and Password usage + permissions, Secret Scopes, Notebooks, Directories, Repos, Files. + + See [interactive tutorial here](https://app.getreprise.com/launch/myM3VNn/).""" + migration_state = ctx.group_manager.get_migration_state() + if len(migration_state.groups) == 0: + logger.info("Skipping group migration as no groups were found.") + return + ctx.permission_manager.apply_group_permissions(migration_state) + + @job_task(job_cluster="tacl") + def validate_groups_permissions(self, ctx: RuntimeContext): + """Validate that all the crawled permissions are applied correctly to the destination groups.""" + ctx.permission_manager.verify_group_permissions() + + +class PermissionsMigrationAPI(Workflow): + def __init__(self): + super().__init__('migrate-groups-experimental') + + @job_task(depends_on=[Assessment.crawl_groups]) + def rename_workspace_local_groups(self, ctx: RuntimeContext): + """[EXPERIMENTAL] Renames workspace local groups by adding `ucx-renamed-` prefix.""" + ctx.group_manager.rename_groups() + + @job_task(depends_on=[rename_workspace_local_groups]) + def reflect_account_groups_on_workspace(self, ctx: RuntimeContext): + """[EXPERIMENTAL] Adds matching account groups to this workspace. The matching account level group(s) must already exist for this + step to succeed. This process does not create the account level group(s).""" + ctx.group_manager.reflect_account_groups_on_workspace() + + @job_task(depends_on=[reflect_account_groups_on_workspace]) + def apply_permissions(self, ctx: RuntimeContext): + """[EXPERIMENTAL] This task uses the new permission migration API which requires enrolment from Databricks. + Fourth phase of the workspace-local group migration process. It does the following: + - Assigns the full set of permissions of the original group to the account-level one + + The permission migration is not atomic. If we hit InternalError, it is possible that half the permissions + have already been migrated over to the account group, and the other half of the permissions are still with + the workspace local group. Recovering from such a state requires manual intervention, through one of two options: + - Deleting all conflicting permissions for the account group, then rerunning the permission migration between + the workspace and account groups (the workflow is idempotent). + - Creating a new account group, reverting all the permissions that were migrated over to the old account + group, and running the workflow again. + + To make things run smoothly, this workflow should never be run on an account group that already has permissions: + the expectation is that the account group has no permissions to begin with. + + It covers workspace-local permissions for all entities.""" + migration_state = ctx.group_manager.get_migration_state() + if len(migration_state.groups) == 0: + logger.info("Skipping group migration as no groups were found.") + return + migration_state.apply_to_renamed_groups(ctx.workspace_client)
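Stepping back to the `Workflows.trigger` entry point added earlier in this diff: every task of these workflow classes is dispatched by name from within a Databricks job. A hedged sketch of such an invocation follows; the flag names are inferred from the `named_parameters` keys that `trigger` reads, and the config path is hypothetical:

```python
# Run a single task of one workflow, the way a deployed job step would.
Workflows.all().trigger(
    "--config=/Workspace/Applications/ucx/config.yml",  # hypothetical install path
    "--workflow=migrate-groups-experimental",
    "--task=apply_permissions",
    "--parent_run_id=1234",  # run id of the entire workflow
    "--job_id=5678",
)
```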
+ + +class ValidateGroupPermissions(Workflow): + def __init__(self): + super().__init__('validate-groups-permissions') + + @job_task(job_cluster="tacl") + def validate_groups_permissions(self, ctx: RuntimeContext): + """Validate that all the crawled permissions are applied correctly to the destination groups.""" + ctx.permission_manager.verify_group_permissions() + + +class RemoveWorkspaceLocalGroups(Workflow): + def __init__(self): + super().__init__('remove-workspace-local-backup-groups') + + @job_task(depends_on=[GroupMigration.apply_permissions_to_account_groups]) + def delete_backup_groups(self, ctx: RuntimeContext): + """Last step of the group migration process. Removes all workspace-level backup groups, along with their + permissions. Execute this workflow only after you've confirmed that workspace-local migration worked + successfully for all the groups involved.""" + ctx.group_manager.delete_original_workspace_groups() diff --git a/tests/integration/aws/test_access.py b/tests/integration/aws/test_access.py index ecabac5e97..6b483f8c66 100644 --- a/tests/integration/aws/test_access.py +++ b/tests/integration/aws/test_access.py @@ -7,16 +7,18 @@ from databricks.labs.ucx.assessment.aws import AWSInstanceProfile, AWSResources from databricks.labs.ucx.aws.access import AWSResourcePermissions from databricks.labs.ucx.config import WorkspaceConfig +from databricks.labs.ucx.contexts.cli_command import WorkspaceContext from databricks.labs.ucx.hive_metastore import ExternalLocations from databricks.labs.ucx.hive_metastore.locations import ExternalLocation -def test_get_uc_compatible_roles(ws, sql_backend, env_or_skip, make_random, inventory_schema): - profile = env_or_skip("AWS_DEFAULT_PROFILE") +def test_get_uc_compatible_roles(ws, env_or_skip, make_random): installation = Installation(ws, make_random(4)) - aws = AWSResources(profile) - awsrp = AWSResourcePermissions.for_cli(ws, installation, sql_backend, aws, inventory_schema) - compat_roles = awsrp.load_uc_compatible_roles() + ctx = WorkspaceContext(ws).replace( + aws_profile=env_or_skip("AWS_DEFAULT_PROFILE"), + installation=installation, + ) + compat_roles = ctx.aws_resource_permissions.load_uc_compatible_roles() print(compat_roles) assert compat_roles diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 9f3ee74b93..6fb7d1f8ba 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -1,10 +1,11 @@ import collections import logging -from functools import partial +import warnings +from functools import partial, cached_property import databricks.sdk.core import pytest # pylint: disable=wrong-import-order -from databricks.labs.blueprint.installation import Installation +from databricks.labs.blueprint.installation import Installation, MockInstallation from databricks.labs.lsql.backends import SqlBackend from databricks.sdk import AccountClient, WorkspaceClient from databricks.sdk.service.catalog import FunctionInfo, TableInfo @@ -15,6 +16,9 @@ AzureServicePrincipalCrawler, AzureServicePrincipalInfo, ) +from databricks.labs.ucx.azure.access import AzureResourcePermissions, StoragePermissionMapping +from databricks.labs.ucx.config import WorkspaceConfig +from
databricks.labs.ucx.contexts.workflow_task import RuntimeContext from databricks.labs.ucx.hive_metastore import TablesCrawler from databricks.labs.ucx.hive_metastore.grants import Grant, GrantsCrawler from databricks.labs.ucx.hive_metastore.locations import Mount, Mounts @@ -173,6 +177,8 @@ def snapshot(self) -> list[Grant]: class StaticTableMapping(TableMapping): def __init__(self, workspace_client: WorkspaceClient, sb: SqlBackend, rules: list[Rule]): + # TODO: remove this class, it creates difficulties when used together with Permission mapping + warnings.warn("switch to using runtime_ctx fixture", DeprecationWarning) installation = Installation(workspace_client, 'ucx') super().__init__(installation, workspace_client, sb) self._rules = rules @@ -194,9 +200,167 @@ def snapshot(self) -> list[AzureServicePrincipalInfo]: class StaticMountCrawler(Mounts): - def __init__(self, mounts: list[Mount], *args): - super().__init__(*args) + def __init__( + self, + mounts: list[Mount], + sb: SqlBackend, + workspace_client: WorkspaceClient, + inventory_database: str, + ): + super().__init__(sb, workspace_client, inventory_database) self._mounts = mounts def snapshot(self) -> list[Mount]: return self._mounts + + +class TestRuntimeContext(RuntimeContext): + def __init__(self, make_table_fixture, make_schema_fixture, make_udf_fixture, env_or_skip_fixture): + super().__init__() + self._make_table = make_table_fixture + self._make_schema = make_schema_fixture + self._make_udf = make_udf_fixture + self._env_or_skip = env_or_skip_fixture + self._tables = [] + self._schemas = [] + self._udfs = [] + self._grants = [] + # TODO: add methods to pre-populate the following: + self._spn_infos = [] + + def with_dummy_azure_resource_permission(self): + # TODO: in most cases (except prepared_principal_acl) it's just a sign of a bad logic, fix it + self.with_azure_storage_permissions( + [ + StoragePermissionMapping( + # TODO: replace with env variable + prefix='abfss://things@labsazurethings.dfs.core.windows.net', + client_id='dummy_application_id', + principal='principal_1', + privilege='WRITE_FILES', + type='Application', + directory_id='directory_id_ss1', + ) + ] + ) + + def with_azure_storage_permissions(self, mapping: list[StoragePermissionMapping]): + self.installation.save(mapping, filename=AzureResourcePermissions.FILENAME) + + def with_table_mapping_rule( + self, + catalog_name: str, + src_schema: str, + dst_schema: str, + src_table: str, + dst_table: str, + ): + self.with_table_mapping_rules( + [ + Rule( + workspace_name="workspace", + catalog_name=catalog_name, + src_schema=src_schema, + dst_schema=dst_schema, + src_table=src_table, + dst_table=dst_table, + ) + ] + ) + + def with_table_mapping_rules(self, rules): + self.installation.save(rules, filename=TableMapping.FILENAME) + + def make_table(self, **kwargs): + table_info = self._make_table(**kwargs) + self._tables.append(table_info) + return table_info + + def make_udf(self, **kwargs): + udf_info = self._make_udf(**kwargs) + self._udfs.append(udf_info) + return udf_info + + def make_grant( # pylint: disable=too-many-arguments + self, + principal: str, + action_type: str, + catalog: str | None = None, + database: str | None = None, + table: str | None = None, + view: str | None = None, + udf: str | None = None, + any_file: bool = False, + anonymous_function: bool = False, + ): + grant = Grant( + principal=principal, + action_type=action_type, + catalog=catalog, + database=database, + table=table, + view=view, + udf=udf, + any_file=any_file, + 
anonymous_function=anonymous_function, + ) + for query in grant.hive_grant_sql(): + self.sql_backend.execute(query) + self._grants.append(grant) + return grant + + @cached_property + def config(self) -> WorkspaceConfig: + return WorkspaceConfig( + warehouse_id=self._env_or_skip("TEST_DEFAULT_WAREHOUSE_ID"), + inventory_database=self.inventory_database, + connect=self.workspace_client.config, + ) + + @cached_property + def installation(self): + # TODO: we may need to do a real installation instead of a mock + return MockInstallation() + + @cached_property + def inventory_database(self) -> str: + return self._make_schema(catalog_name="hive_metastore").name + + @cached_property + def tables_crawler(self): + return StaticTablesCrawler(self.sql_backend, self.inventory_database, self._tables) + + @cached_property + def udfs_crawler(self): + return StaticUdfsCrawler(self.sql_backend, self.inventory_database, self._udfs) + + @cached_property + def grants_crawler(self): + return StaticGrantsCrawler(self.tables_crawler, self.udfs_crawler, self._grants) + + @cached_property + def azure_service_principal_crawler(self): + return StaticServicePrincipalCrawler( + self._spn_infos, + self.workspace_client, + self.sql_backend, + self.inventory_database, + ) + + @cached_property + def mounts_crawler(self): + # TODO: replace with env variable and make AWS and Azure versions + real_location = 'abfss://things@labsazurethings.dfs.core.windows.net/a' + mount = Mount(f'/mnt/{self._env_or_skip("TEST_MOUNT_NAME")}/a', real_location) + return StaticMountCrawler( + [mount], + self.sql_backend, + self.workspace_client, + self.inventory_database, + ) + + +@pytest.fixture +def runtime_ctx(ws, sql_backend, make_table, make_schema, make_udf, env_or_skip): + ctx = TestRuntimeContext(make_table, make_schema, make_udf, env_or_skip) + return ctx.replace(workspace_client=ws, sql_backend=sql_backend) diff --git a/tests/integration/hive_metastore/test_migrate.py b/tests/integration/hive_metastore/test_migrate.py index bd2b593cfb..80a315b893 100644 --- a/tests/integration/hive_metastore/test_migrate.py +++ b/tests/integration/hive_metastore/test_migrate.py @@ -1,36 +1,19 @@ import logging from datetime import timedelta -from unittest.mock import create_autospec import pytest -from databricks.labs.blueprint.installation import MockInstallation from databricks.sdk.errors import NotFound from databricks.sdk.retries import retried from databricks.sdk.service.catalog import Privilege, SecurableType from databricks.sdk.service.compute import DataSecurityMode from databricks.sdk.service.iam import PermissionLevel -from databricks.labs.ucx.hive_metastore.grants import ( - AzureACL, - Grant, - GrantsCrawler, - PrincipalACL, -) -from databricks.labs.ucx.hive_metastore.locations import Mount from databricks.labs.ucx.hive_metastore.mapping import Rule -from databricks.labs.ucx.hive_metastore.table_migrate import ( - MigrationStatusRefresher, - TablesMigrator, -) from databricks.labs.ucx.hive_metastore.tables import AclMigrationWhat, Table, What -from databricks.labs.ucx.workspace_access.groups import GroupManager from ..conftest import ( - StaticGrantsCrawler, - StaticMountCrawler, StaticTableMapping, StaticTablesCrawler, - StaticUdfsCrawler, ) logger = logging.getLogger(__name__) @@ -47,43 +30,18 @@ } -def principal_acl(ws, inventory_schema, sql_backend): - installation = MockInstallation( - { - "config.yml": { - 'inventory_database': inventory_schema, - }, - "azure_storage_account_info.csv": [ - { - 'prefix': 'dummy_prefix', - 'client_id': 
'dummy_application_id', - 'principal': 'dummy_principal', - 'privilege': 'WRITE_FILES', - 'type': 'Application', - 'directory_id': 'dummy_directory', - } - ], - } - ) - return PrincipalACL.for_cli(ws, installation, sql_backend) - - @retried(on=[NotFound], timeout=timedelta(minutes=2)) -def test_migrate_managed_tables(ws, sql_backend, inventory_schema, make_catalog, make_schema, make_table): - # pylint: disable=too-many-locals +def test_migrate_managed_tables(ws, sql_backend, runtime_ctx, make_catalog, make_schema): if not ws.config.is_azure: pytest.skip("temporary: only works in azure test env") src_schema = make_schema(catalog_name="hive_metastore") - src_managed_table = make_table(catalog_name=src_schema.catalog_name, schema_name=src_schema.name) + src_managed_table = runtime_ctx.make_table(catalog_name=src_schema.catalog_name, schema_name=src_schema.name) dst_catalog = make_catalog() dst_schema = make_schema(catalog_name=dst_catalog.name, name=src_schema.name) logger.info(f"dst_catalog={dst_catalog.name}, managed_table={src_managed_table.full_name}") - table_crawler = StaticTablesCrawler(sql_backend, inventory_schema, [src_managed_table]) - udf_crawler = StaticUdfsCrawler(sql_backend, inventory_schema, []) - grant_crawler = GrantsCrawler(table_crawler, udf_crawler) rules = [ Rule( "workspace", @@ -94,22 +52,10 @@ def test_migrate_managed_tables(ws, sql_backend, inventory_schema, make_catalog, src_managed_table.name, ), ] - table_mapping = StaticTableMapping(ws, sql_backend, rules=rules) - group_manager = GroupManager(sql_backend, ws, inventory_schema) - migration_status_refresher = MigrationStatusRefresher(ws, sql_backend, inventory_schema, table_crawler) - principal_grants = principal_acl(ws, inventory_schema, sql_backend) - table_migrate = TablesMigrator( - table_crawler, - grant_crawler, - ws, - sql_backend, - table_mapping, - group_manager, - migration_status_refresher, - principal_grants, - ) - table_migrate.migrate_tables(what=What.DBFS_ROOT_DELTA) + runtime_ctx.with_table_mapping_rules(rules) + runtime_ctx.with_dummy_azure_resource_permission() + runtime_ctx.tables_migrator.migrate_tables(what=What.DBFS_ROOT_DELTA) target_tables = list(sql_backend.fetch(f"SHOW TABLES IN {dst_schema.full_name}")) assert len(target_tables) == 1 @@ -119,10 +65,15 @@ def test_migrate_managed_tables(ws, sql_backend, inventory_schema, make_catalog, assert target_table_properties[Table.UPGRADED_FROM_WS_PARAM] == str(ws.get_workspace_id()) -@retried(on=[NotFound], timeout=timedelta(minutes=5)) +@retried(on=[NotFound], timeout=timedelta(minutes=2)) def test_migrate_tables_with_cache_should_not_create_table( - ws, sql_backend, inventory_schema, make_random, make_catalog, make_schema, make_table -): # pylint: disable=too-many-locals + ws, + sql_backend, + runtime_ctx, + make_random, + make_catalog, + make_schema, +): if not ws.config.is_azure: pytest.skip("temporary: only works in azure test env") src_schema = make_schema(catalog_name="hive_metastore") @@ -131,13 +82,13 @@ def test_migrate_tables_with_cache_should_not_create_table( dst_schema = make_schema(catalog_name=dst_catalog.name, name=src_schema.name) table_name = make_random().lower() - src_managed_table = make_table( + src_managed_table = runtime_ctx.make_table( catalog_name=src_schema.catalog_name, schema_name=src_schema.name, name=table_name, tbl_properties={"upgraded_from": f"{dst_schema.full_name}.{table_name}"}, ) - dst_managed_table = make_table( + dst_managed_table = runtime_ctx.make_table( catalog_name=dst_schema.catalog_name, 
schema_name=dst_schema.name, name=table_name, @@ -150,9 +101,6 @@ def test_migrate_tables_with_cache_should_not_create_table( f"target_managed_table={dst_managed_table}" ) - table_crawler = StaticTablesCrawler(sql_backend, inventory_schema, [src_managed_table]) - udf_crawler = StaticUdfsCrawler(sql_backend, inventory_schema, []) - grant_crawler = GrantsCrawler(table_crawler, udf_crawler) rules = [ Rule( "workspace", @@ -163,23 +111,11 @@ def test_migrate_tables_with_cache_should_not_create_table( dst_managed_table.name, ), ] - table_mapping = StaticTableMapping(ws, sql_backend, rules=rules) - group_manager = GroupManager(sql_backend, ws, inventory_schema) - migration_status_refresher = MigrationStatusRefresher(ws, sql_backend, inventory_schema, table_crawler) - principal_grants = principal_acl(ws, inventory_schema, sql_backend) - table_migrate = TablesMigrator( - table_crawler, - grant_crawler, - ws, - sql_backend, - table_mapping, - group_manager, - migration_status_refresher, - principal_grants, - ) + runtime_ctx.with_table_mapping_rules(rules) + runtime_ctx.with_dummy_azure_resource_permission() # FIXME: flaky: databricks.sdk.errors.platform.NotFound: Catalog 'ucx_cjazg' does not exist. - table_migrate.migrate_tables(what=What.DBFS_ROOT_DELTA) + runtime_ctx.tables_migrator.migrate_tables(what=What.DBFS_ROOT_DELTA) target_tables = list(sql_backend.fetch(f"SHOW TABLES IN {dst_schema.full_name}")) assert len(target_tables) == 1 @@ -187,14 +123,13 @@ def test_migrate_tables_with_cache_should_not_create_table( assert target_tables[0]["tableName"] == table_name -@retried(on=[NotFound], timeout=timedelta(minutes=5)) -def test_migrate_external_table( # pylint: disable=too-many-locals +@retried(on=[NotFound], timeout=timedelta(minutes=2)) +def test_migrate_external_table( ws, sql_backend, - inventory_schema, + runtime_ctx, make_catalog, make_schema, - make_table, env_or_skip, make_random, make_dbfs_data_copy, @@ -207,13 +142,10 @@ def test_migrate_external_table( # pylint: disable=too-many-locals existing_mounted_location = f'dbfs:/mnt/{env_or_skip("TEST_MOUNT_NAME")}/a/b/c' new_mounted_location = f'dbfs:/mnt/{env_or_skip("TEST_MOUNT_NAME")}/a/b/{make_random(4)}' make_dbfs_data_copy(src_path=existing_mounted_location, dst_path=new_mounted_location) - src_external_table = make_table(schema_name=src_schema.name, external_csv=new_mounted_location) + src_external_table = runtime_ctx.make_table(schema_name=src_schema.name, external_csv=new_mounted_location) dst_catalog = make_catalog() dst_schema = make_schema(catalog_name=dst_catalog.name, name=src_schema.name) logger.info(f"dst_catalog={dst_catalog.name}, external_table={src_external_table.full_name}") - table_crawler = StaticTablesCrawler(sql_backend, inventory_schema, [src_external_table]) - udf_crawler = StaticUdfsCrawler(sql_backend, inventory_schema, []) - grant_crawler = GrantsCrawler(table_crawler, udf_crawler) rules = [ Rule( "workspace", @@ -224,21 +156,10 @@ def test_migrate_external_table( # pylint: disable=too-many-locals src_external_table.name, ), ] - group_manager = GroupManager(sql_backend, ws, inventory_schema) - migration_status_refresher = MigrationStatusRefresher(ws, sql_backend, inventory_schema, table_crawler) - principal_grants = principal_acl(ws, inventory_schema, sql_backend) - table_migrate = TablesMigrator( - table_crawler, - grant_crawler, - ws, - sql_backend, - StaticTableMapping(ws, sql_backend, rules=rules), - group_manager, - migration_status_refresher, - principal_grants, - ) - 
table_migrate.migrate_tables(what=What.EXTERNAL_SYNC) + runtime_ctx.with_table_mapping_rules(rules) + runtime_ctx.with_dummy_azure_resource_permission() + runtime_ctx.tables_migrator.migrate_tables(what=What.EXTERNAL_SYNC) target_tables = list(sql_backend.fetch(f"SHOW TABLES IN {dst_schema.full_name}")) assert len(target_tables) == 1 @@ -246,7 +167,7 @@ def test_migrate_external_table( # pylint: disable=too-many-locals assert target_table_properties["upgraded_from"] == src_external_table.full_name assert target_table_properties[Table.UPGRADED_FROM_WS_PARAM] == str(ws.get_workspace_id()) - _migration_status = MigrationStatusRefresher(ws, sql_backend, inventory_schema, table_crawler).snapshot() + _migration_status = runtime_ctx.migration_status_refresher.snapshot() migration_status = list(_migration_status) assert len(migration_status) == 1 assert migration_status[0].src_schema == src_external_table.schema_name @@ -256,24 +177,15 @@ def test_migrate_external_table( # pylint: disable=too-many-locals assert migration_status[0].dst_table == src_external_table.name -@retried(on=[NotFound], timeout=timedelta(minutes=5)) -def test_migrate_external_table_failed_sync( - ws, - caplog, - sql_backend, - inventory_schema, - make_schema, - make_table, - env_or_skip, -): +@retried(on=[NotFound], timeout=timedelta(minutes=1)) +def test_migrate_external_table_failed_sync(ws, caplog, runtime_ctx, make_schema, env_or_skip): if not ws.config.is_azure: pytest.skip("temporary: only works in azure test env") src_schema = make_schema(catalog_name="hive_metastore") existing_mounted_location = f'dbfs:/mnt/{env_or_skip("TEST_MOUNT_NAME")}/a/b/c' - src_external_table = make_table(schema_name=src_schema.name, external_csv=existing_mounted_location) - table_crawler = StaticTablesCrawler(sql_backend, inventory_schema, [src_external_table]) - grant_crawler = create_autospec(GrantsCrawler) + src_external_table = runtime_ctx.make_table(schema_name=src_schema.name, external_csv=existing_mounted_location) + # create a mapping that will fail the SYNC because the target catalog and schema does not exist rules = [ Rule( @@ -285,43 +197,25 @@ def test_migrate_external_table_failed_sync( src_external_table.name, ), ] - group_manager = GroupManager(sql_backend, ws, inventory_schema) - migration_status_refresher = MigrationStatusRefresher(ws, sql_backend, inventory_schema, table_crawler) - principal_grants = principal_acl(ws, inventory_schema, sql_backend) - table_migrate = TablesMigrator( - table_crawler, - grant_crawler, - ws, - sql_backend, - StaticTableMapping(ws, sql_backend, rules=rules), - group_manager, - migration_status_refresher, - principal_grants, - ) + runtime_ctx.with_table_mapping_rules(rules) + runtime_ctx.with_dummy_azure_resource_permission() + runtime_ctx.tables_migrator.migrate_tables(what=What.EXTERNAL_SYNC) - table_migrate.migrate_tables(what=What.EXTERNAL_SYNC) assert "SYNC command failed to migrate" in caplog.text -@retried(on=[NotFound], timeout=timedelta(minutes=5)) -def test_revert_migrated_table( - ws, sql_backend, inventory_schema, make_schema, make_table, make_catalog -): # pylint: disable=too-many-locals +@retried(on=[NotFound], timeout=timedelta(minutes=2)) +def test_revert_migrated_table(sql_backend, runtime_ctx, make_schema, make_catalog): src_schema1 = make_schema(catalog_name="hive_metastore") src_schema2 = make_schema(catalog_name="hive_metastore") - table_to_revert = make_table(schema_name=src_schema1.name) - table_not_migrated = make_table(schema_name=src_schema1.name) - table_to_not_revert = 
make_table(schema_name=src_schema2.name) - all_tables = [table_to_revert, table_not_migrated, table_to_not_revert] + table_to_revert = runtime_ctx.make_table(schema_name=src_schema1.name) + table_not_migrated = runtime_ctx.make_table(schema_name=src_schema1.name) + table_to_not_revert = runtime_ctx.make_table(schema_name=src_schema2.name) dst_catalog = make_catalog() dst_schema1 = make_schema(catalog_name=dst_catalog.name, name=src_schema1.name) dst_schema2 = make_schema(catalog_name=dst_catalog.name, name=src_schema2.name) - table_crawler = StaticTablesCrawler(sql_backend, inventory_schema, all_tables) - udf_crawler = StaticUdfsCrawler(sql_backend, inventory_schema, []) - grant_crawler = GrantsCrawler(table_crawler, udf_crawler) - rules = [ Rule( "workspace", @@ -340,31 +234,20 @@ def test_revert_migrated_table( table_to_not_revert.name, ), ] - table_mapping = StaticTableMapping(ws, sql_backend, rules=rules) - group_manager = GroupManager(sql_backend, ws, inventory_schema) - migration_status_refresher = MigrationStatusRefresher(ws, sql_backend, inventory_schema, table_crawler) - principal_grants = principal_acl(ws, inventory_schema, sql_backend) - table_migrate = TablesMigrator( - table_crawler, - grant_crawler, - ws, - sql_backend, - table_mapping, - group_manager, - migration_status_refresher, - principal_grants, - ) - table_migrate.migrate_tables(what=What.DBFS_ROOT_DELTA) + runtime_ctx.with_table_mapping_rules(rules) + runtime_ctx.with_dummy_azure_resource_permission() + + runtime_ctx.tables_migrator.migrate_tables(what=What.DBFS_ROOT_DELTA) - table_migrate.revert_migrated_tables(src_schema1.name, delete_managed=True) + runtime_ctx.tables_migrator.revert_migrated_tables(src_schema1.name, delete_managed=True) # Checking that two of the tables were reverted and one was left intact. # The first two table belongs to schema 1 and should have not "upgraded_to" property - assert not table_migrate.is_migrated(table_to_revert.schema_name, table_to_revert.name) + assert not runtime_ctx.tables_migrator.is_migrated(table_to_revert.schema_name, table_to_revert.name) # The second table didn't have the "upgraded_to" property set and should remain that way. - assert not table_migrate.is_migrated(table_not_migrated.schema_name, table_not_migrated.name) + assert not runtime_ctx.tables_migrator.is_migrated(table_not_migrated.schema_name, table_not_migrated.name) # The third table belongs to schema2 and had the "upgraded_to" property set and should remain that way. 
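The assertions around this spot hinge on the `upgraded_to` table property, which migration stamps onto the source table and `revert_migrated_tables` removes. A sketch of what the check effectively does (an assumed shape, not the real `is_migrated` implementation):

```python
def has_upgraded_to(sql_backend, schema: str, table: str) -> bool:
    """True if the HMS table carries the migration marker (sketch only)."""
    # SHOW TBLPROPERTIES yields key/value rows; the marker's presence is what
    # is_migrated() reports, and reverting a table unsets it.
    rows = sql_backend.fetch(f"SHOW TBLPROPERTIES hive_metastore.{schema}.{table}")
    return any(row["key"] == "upgraded_to" for row in rows)
```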
- assert table_migrate.is_migrated(table_to_not_revert.schema_name, table_to_not_revert.name) + assert runtime_ctx.tables_migrator.is_migrated(table_to_not_revert.schema_name, table_to_not_revert.name) target_tables_schema1 = list(sql_backend.fetch(f"SHOW TABLES IN {dst_schema1.full_name}")) assert len(target_tables_schema1) == 0 @@ -432,49 +315,25 @@ def test_mapping_skips_tables_databases(ws, sql_backend, inventory_schema, make_ assert len(table_mapping.get_tables_to_migrate(table_crawler)) == 1 -@retried(on=[NotFound], timeout=timedelta(minutes=5)) -def test_mapping_reverts_table( - ws, sql_backend, inventory_schema, make_schema, make_table, make_catalog -): # pylint: disable=too-many-locals +@retried(on=[NotFound], timeout=timedelta(minutes=2)) +def test_mapping_reverts_table(ws, sql_backend, runtime_ctx, make_schema, make_catalog): src_schema = make_schema(catalog_name="hive_metastore") - table_to_revert = make_table(schema_name=src_schema.name) - table_to_skip = make_table(schema_name=src_schema.name) - all_tables = [ - table_to_revert, - table_to_skip, - ] + table_to_revert = runtime_ctx.make_table(schema_name=src_schema.name) + table_to_skip = runtime_ctx.make_table(schema_name=src_schema.name) dst_catalog = make_catalog() dst_schema = make_schema(catalog_name=dst_catalog.name, name=src_schema.name) - table_crawler = StaticTablesCrawler(sql_backend, inventory_schema, all_tables) - udf_crawler = StaticUdfsCrawler(sql_backend, inventory_schema, []) - grant_crawler = GrantsCrawler(table_crawler, udf_crawler) - rules = [ - Rule( - "workspace", - dst_catalog.name, - src_schema.name, - dst_schema.name, - table_to_skip.name, - table_to_skip.name, - ), - ] - table_mapping = StaticTableMapping(ws, sql_backend, rules=rules) - migration_status_refresher = MigrationStatusRefresher(ws, sql_backend, inventory_schema, table_crawler) - group_manager = GroupManager(sql_backend, ws, inventory_schema) - principal_grants = principal_acl(ws, inventory_schema, sql_backend) - table_migrate = TablesMigrator( - table_crawler, - grant_crawler, - ws, - sql_backend, - table_mapping, - group_manager, - migration_status_refresher, - principal_grants, + runtime_ctx.with_dummy_azure_resource_permission() + runtime_ctx.with_table_mapping_rule( + catalog_name=dst_catalog.name, + src_schema=src_schema.name, + dst_schema=dst_schema.name, + src_table=table_to_skip.name, + dst_table=table_to_skip.name, ) - table_migrate.migrate_tables(what=What.DBFS_ROOT_DELTA) + + runtime_ctx.tables_migrator.migrate_tables(what=What.DBFS_ROOT_DELTA) target_table_properties = ws.tables.get(f"{dst_schema.full_name}.{table_to_skip.name}").properties assert target_table_properties["upgraded_from"] == table_to_skip.full_name @@ -507,7 +366,7 @@ def test_mapping_reverts_table( ), ] table_mapping2 = StaticTableMapping(ws, sql_backend, rules=rules2) - mapping2 = table_mapping2.get_tables_to_migrate(table_crawler) + mapping2 = table_mapping2.get_tables_to_migrate(runtime_ctx.tables_crawler) # Checking to validate that table_to_skip was omitted from the list of rules assert len(mapping2) == 1 @@ -525,28 +384,31 @@ def test_mapping_reverts_table( assert "upgraded_to" not in results2 -@retried(on=[NotFound], timeout=timedelta(minutes=3)) -def test_migrate_managed_tables_with_acl( - ws, sql_backend, inventory_schema, make_catalog, make_schema, make_table, make_user -): # pylint: disable=too-many-locals +@retried(on=[NotFound], timeout=timedelta(minutes=2)) +def test_migrate_managed_tables_with_acl(ws, sql_backend, runtime_ctx, make_catalog, 
make_schema, make_user): if not ws.config.is_azure: pytest.skip("temporary: only works in azure test env") src_schema = make_schema(catalog_name="hive_metastore") - src_managed_table = make_table(catalog_name=src_schema.catalog_name, schema_name=src_schema.name) + src_managed_table = runtime_ctx.make_table(catalog_name=src_schema.catalog_name, schema_name=src_schema.name) user = make_user() - src_grant = [ - Grant(principal=user.user_name, action_type="SELECT", table=src_managed_table.name, database=src_schema.name), - Grant(principal=user.user_name, action_type="MODIFY", table=src_managed_table.name, database=src_schema.name), - ] + + runtime_ctx.make_grant( + principal=user.user_name, + action_type="SELECT", + table=src_managed_table.name, + database=src_schema.name, + ) + runtime_ctx.make_grant( + principal=user.user_name, + action_type="MODIFY", + table=src_managed_table.name, + database=src_schema.name, + ) dst_catalog = make_catalog() dst_schema = make_schema(catalog_name=dst_catalog.name, name=src_schema.name) - logger.info(f"dst_catalog={dst_catalog.name}, managed_table={src_managed_table.full_name}") - table_crawler = StaticTablesCrawler(sql_backend, inventory_schema, [src_managed_table]) - udf_crawler = StaticUdfsCrawler(sql_backend, inventory_schema, []) - grant_crawler = StaticGrantsCrawler(table_crawler, udf_crawler, src_grant) rules = [ Rule( "workspace", @@ -557,51 +419,10 @@ def test_migrate_managed_tables_with_acl( src_managed_table.name, ), ] - table_mapping = StaticTableMapping(ws, sql_backend, rules=rules) - group_manager = GroupManager(sql_backend, ws, inventory_schema) - migration_status_refresher = MigrationStatusRefresher(ws, sql_backend, inventory_schema, table_crawler) - installation = MockInstallation( - { - "config.yml": { - 'inventory_database': inventory_schema, - }, - "azure_storage_account_info.csv": [ - { - 'prefix': 'dummy_prefix', - 'client_id': 'dummy_application_id', - 'principal': 'dummy_principal', - 'privilege': 'WRITE_FILES', - 'type': 'Application', - 'directory_id': 'dummy_directory', - } - ], - } - ) - principal_grants = PrincipalACL( - ws, - sql_backend, - installation, - StaticTablesCrawler(sql_backend, inventory_schema, [src_managed_table]), - StaticMountCrawler( - [Mount('dummy_mount', 'abfss://dummy@dummy.dfs.core.windows.net/a')], - sql_backend, - ws, - inventory_schema, - ), - AzureACL.for_cli(ws, installation).get_eligible_locations_principals(), - ) - table_migrate = TablesMigrator( - table_crawler, - grant_crawler, - ws, - sql_backend, - table_mapping, - group_manager, - migration_status_refresher, - principal_grants, - ) + runtime_ctx.with_table_mapping_rules(rules) + runtime_ctx.with_dummy_azure_resource_permission() - table_migrate.migrate_tables(what=What.DBFS_ROOT_DELTA, acl_strategy=[AclMigrationWhat.LEGACY_TACL]) + runtime_ctx.tables_migrator.migrate_tables(what=What.DBFS_ROOT_DELTA, acl_strategy=[AclMigrationWhat.LEGACY_TACL]) target_tables = list(sql_backend.fetch(f"SHOW TABLES IN {dst_schema.full_name}")) assert len(target_tables) == 1 @@ -614,24 +435,16 @@ def test_migrate_managed_tables_with_acl( assert target_table_grants.privilege_assignments[0].privileges == [Privilege.MODIFY, Privilege.SELECT] -@pytest.fixture() -def test_prepare_principal_acl( - ws, - sql_backend, - inventory_schema, - env_or_skip, - make_dbfs_data_copy, - make_table, - make_catalog, - make_schema, - make_cluster, -): +@pytest.fixture +def prepared_principal_acl(runtime_ctx, env_or_skip, make_dbfs_data_copy, make_catalog, make_schema, make_cluster): 
cluster = make_cluster(single_node=True, spark_conf=_SPARK_CONF, data_security_mode=DataSecurityMode.NONE) - new_mounted_location = f'dbfs:/mnt/{env_or_skip("TEST_MOUNT_NAME")}/a/b/{inventory_schema}' + new_mounted_location = f'dbfs:/mnt/{env_or_skip("TEST_MOUNT_NAME")}/a/b/{runtime_ctx.inventory_database}' make_dbfs_data_copy(src_path=f'dbfs:/mnt/{env_or_skip("TEST_MOUNT_NAME")}/a/b/c', dst_path=new_mounted_location) src_schema = make_schema(catalog_name="hive_metastore") - src_external_table = make_table( - catalog_name=src_schema.catalog_name, schema_name=src_schema.name, external_csv=new_mounted_location + src_external_table = runtime_ctx.make_table( + catalog_name=src_schema.catalog_name, + schema_name=src_schema.name, + external_csv=new_mounted_location, ) dst_catalog = make_catalog() dst_schema = make_schema(catalog_name=dst_catalog.name, name=src_schema.name) @@ -645,72 +458,25 @@ def test_prepare_principal_acl( src_external_table.name, ), ] - installation = MockInstallation( - { - "config.yml": { - 'warehouse_id': env_or_skip("TEST_DEFAULT_WAREHOUSE_ID"), - 'inventory_database': inventory_schema, - }, - "azure_storage_account_info.csv": [ - { - 'prefix': 'abfss://things@labsazurethings.dfs.core.windows.net', - 'client_id': 'dummy_application_id', - 'principal': 'principal_1', - 'privilege': 'WRITE_FILES', - 'type': 'Application', - 'directory_id': 'directory_id_ss1', - } - ], - } + runtime_ctx.with_table_mapping_rules(rules) + runtime_ctx.with_dummy_azure_resource_permission() + return ( + runtime_ctx.tables_migrator, + f"{dst_catalog.name}.{dst_schema.name}.{src_external_table.name}", + cluster.cluster_id, ) - principal_grants = PrincipalACL( - ws, - sql_backend, - installation, - StaticTablesCrawler(sql_backend, inventory_schema, [src_external_table]), - StaticMountCrawler( - [ - Mount( - f'/mnt/{env_or_skip("TEST_MOUNT_NAME")}/a', 'abfss://things@labsazurethings.dfs.core.windows.net/a' - ) - ], - sql_backend, - ws, - inventory_schema, - ), - AzureACL.for_cli(ws, installation).get_eligible_locations_principals(), - ) - table_migrate = TablesMigrator( - StaticTablesCrawler(sql_backend, inventory_schema, [src_external_table]), - StaticGrantsCrawler( - StaticTablesCrawler(sql_backend, inventory_schema, [src_external_table]), - StaticUdfsCrawler(sql_backend, inventory_schema, []), - [], - ), - ws, - sql_backend, - StaticTableMapping(ws, sql_backend, rules=rules), - GroupManager(sql_backend, ws, inventory_schema), - MigrationStatusRefresher( - ws, sql_backend, inventory_schema, StaticTablesCrawler(sql_backend, inventory_schema, [src_external_table]) - ), - principal_grants, - ) - return table_migrate, f"{dst_catalog.name}.{dst_schema.name}.{src_external_table.name}", cluster.cluster_id - -@retried(on=[NotFound], timeout=timedelta(minutes=3)) +@retried(on=[NotFound], timeout=timedelta(minutes=2)) def test_migrate_managed_tables_with_principal_acl_azure( ws, make_user, - test_prepare_principal_acl, + prepared_principal_acl, make_cluster_permissions, - make_cluster, ): if not ws.config.is_azure: pytest.skip("temporary: only works in azure test env") - table_migrate, table_full_name, cluster_id = test_prepare_principal_acl + table_migrate, table_full_name, cluster_id = prepared_principal_acl user = make_user() make_cluster_permissions( object_id=cluster_id, diff --git a/tests/integration/test_installation.py b/tests/integration/test_installation.py index eb0fe2e36a..2f06b36226 100644 --- a/tests/integration/test_installation.py +++ b/tests/integration/test_installation.py @@ -35,6 +35,7 
@@ DeployedWorkflows, WorkflowsDeployment, ) +from databricks.labs.ucx.runtime import Workflows from databricks.labs.ucx.workspace_access import redash from databricks.labs.ucx.workspace_access.generic import ( GenericPermissionsSupport, @@ -95,6 +96,7 @@ def factory( [ functools.partial(ws.clusters.ensure_cluster_is_running, default_cluster_id), functools.partial(ws.clusters.ensure_cluster_is_running, tacl_cluster_id), + functools.partial(ws.clusters.ensure_cluster_is_running, table_migration_cluster_id), ], ) @@ -114,6 +116,9 @@ def factory( installation.save(workspace_config) + # TODO: inject the smallest number of tasks possible for a workflow, to speed up installation in tests + tasks = Workflows.all().tasks() + # TODO: see if we want to move building wheel as a context manager for yield factory, # so that we can shave off couple of seconds and build wheel only once per session # instead of every test @@ -125,6 +130,7 @@ def factory( product_info.wheels(ws), product_info, timedelta(minutes=3), + tasks, ) workspace_installation = WorkspaceInstallation( workspace_config, @@ -267,7 +273,7 @@ def test_new_job_cluster_with_policy_assessment( assert before[ws_group_a.display_name] == PermissionLevel.CAN_USE -@retried(on=[NotFound, InvalidParameterValue, TimeoutError], timeout=timedelta(minutes=5)) +@retried(on=[NotFound, InvalidParameterValue], timeout=timedelta(minutes=5)) def test_running_real_assessment_job( ws, new_installation, make_ucx_group, make_cluster_policy, make_cluster_policy_permissions ): diff --git a/tests/performance/test_performance.py b/tests/performance/test_performance.py index 1585fd7877..1b3225460d 100644 --- a/tests/performance/test_performance.py +++ b/tests/performance/test_performance.py @@ -11,9 +11,9 @@ from databricks.sdk import WorkspaceClient from databricks.sdk.service import iam +from databricks.labs.ucx.contexts.cli_command import WorkspaceContext from databricks.labs.ucx.workspace_access.base import Permissions from databricks.labs.ucx.workspace_access.groups import MigratedGroup, MigrationState -from databricks.labs.ucx.workspace_access.manager import PermissionManager logger = logging.getLogger(__name__) @@ -75,8 +75,8 @@ def test_apply_group_permissions_experimental_performance( logger.info(f"Migration using experimental API takes {process_time() - start}s") start = process_time() - permission_manager = PermissionManager.factory(ws, sql_backend, inventory_schema) - permission_manager.apply_group_permissions(MigrationState([migrated_group])) + ctx = WorkspaceContext(ws).replace(inventory_schema=inventory_schema, sql_backend=sql_backend) + ctx.permission_manager.apply_group_permissions(MigrationState([migrated_group])) logger.info(f"Migration using normal approach takes {process_time() - start}s") diff --git a/tests/unit/assessment/test_workflows.py b/tests/unit/assessment/test_workflows.py new file mode 100644 index 0000000000..830586426f --- /dev/null +++ b/tests/unit/assessment/test_workflows.py @@ -0,0 +1,86 @@ +from unittest.mock import create_autospec + +import pytest +from databricks.labs.lsql.backends import SqlBackend + +from databricks.labs.ucx.assessment.workflows import Assessment, DestroySchema + + +def test_assess_azure_service_principals(run_workflow): + sql_backend = create_autospec(SqlBackend) + sql_backend.fetch.return_value = [ + ["1", "secret_scope", "secret_key", "tenant_id", "storage_account"], + ] + run_workflow(Assessment.assess_azure_service_principals, sql_backend=sql_backend) + + +def 
test_runtime_workspace_listing(run_workflow): + ctx = run_workflow(Assessment.workspace_listing) + assert "SELECT * FROM ucx.workspace_objects" in ctx.sql_backend.queries + + +def test_runtime_crawl_grants(run_workflow): + ctx = run_workflow(Assessment.crawl_grants) + assert "SELECT * FROM hive_metastore.ucx.grants" in ctx.sql_backend.queries + + +@pytest.mark.skip("crawl_permissions fails, filed GH issue #1129") +def test_runtime_crawl_permissions(run_workflow): + ctx = run_workflow(Assessment.crawl_permissions) + assert "SELECT * FROM hive_metastore.ucx.permissions" in ctx.sql_backend.queries + + +def test_runtime_crawl_groups(run_workflow): + ctx = run_workflow(Assessment.crawl_groups) + assert "SELECT * FROM hive_metastore.ucx.groups" in ctx.sql_backend.queries + + +def test_runtime_crawl_cluster_policies(run_workflow): + ctx = run_workflow(Assessment.crawl_cluster_policies) + assert "SELECT * FROM ucx.policies" in ctx.sql_backend.queries + + +def test_runtime_crawl_init_scripts(run_workflow): + ctx = run_workflow(Assessment.assess_global_init_scripts) + assert "SELECT * FROM ucx.global_init_scripts" in ctx.sql_backend.queries + + +def test_estimate_table_size_for_migration(run_workflow): + ctx = run_workflow(Assessment.estimate_table_size_for_migration) + assert "SELECT * FROM hive_metastore.ucx.table_size" in ctx.sql_backend.queries + assert "SHOW DATABASES" in ctx.sql_backend.queries + + +def test_runtime_mounts(run_workflow): + ctx = run_workflow(Assessment.crawl_mounts) + assert "SELECT * FROM ucx.mounts" in ctx.sql_backend.queries + + +def test_guess_external_locations(run_workflow): + ctx = run_workflow(Assessment.guess_external_locations) + assert "SELECT * FROM ucx.mounts" in ctx.sql_backend.queries + + +def test_assess_jobs(run_workflow): + ctx = run_workflow(Assessment.assess_jobs) + assert "SELECT * FROM ucx.jobs" in ctx.sql_backend.queries + + +def test_assess_clusters(run_workflow): + ctx = run_workflow(Assessment.assess_clusters) + assert "SELECT * FROM ucx.clusters" in ctx.sql_backend.queries + + +def test_assess_pipelines(run_workflow): + ctx = run_workflow(Assessment.assess_pipelines) + assert "SELECT * FROM ucx.pipelines" in ctx.sql_backend.queries + + +def test_incompatible_submit_runs(run_workflow): + ctx = run_workflow(Assessment.assess_incompatible_submit_runs) + assert "SELECT * FROM ucx.submit_runs" in ctx.sql_backend.queries + + +def test_runtime_destroy_schema(run_workflow): + ctx = run_workflow(DestroySchema.destroy_schema) + assert "DROP DATABASE ucx CASCADE" in ctx.sql_backend.queries diff --git a/tests/unit/aws/test_credentials.py b/tests/unit/aws/test_credentials.py index 62c487a82a..482ed2395e 100644 --- a/tests/unit/aws/test_credentials.py +++ b/tests/unit/aws/test_credentials.py @@ -14,7 +14,7 @@ StorageCredentialInfo, ) -from databricks.labs.ucx.assessment.aws import AWSResources, AWSRoleAction +from databricks.labs.ucx.assessment.aws import AWSRoleAction from databricks.labs.ucx.aws.access import AWSResourcePermissions from databricks.labs.ucx.aws.credentials import CredentialManager, IamRoleMigration from tests.unit import DEFAULT_CONFIG @@ -103,41 +103,6 @@ def generate_instance_profiles(num_instance_profiles: int): return generate_instance_profiles -def test_for_cli_not_aws(caplog, ws, installation): - ws.config.is_aws = False - with pytest.raises(SystemExit): - aws = create_autospec(AWSResources) - IamRoleMigration.for_cli(ws, installation, aws, MockPrompts({})) - assert "Workspace is not on AWS, please run this command on a Databricks on AWS 
workspaces." in caplog.text - - -def test_for_cli_not_prompts(ws, installation): - ws.config.is_aws = True - prompts = MockPrompts( - { - f"Have you reviewed the {AWSResourcePermissions.UC_ROLES_FILE_NAMES} " - "and confirm listed IAM roles to be migrated*": "No" - } - ) - with pytest.raises(SystemExit): - aws = create_autospec(AWSResources) - IamRoleMigration.for_cli(ws, installation, aws, prompts) - - -def test_for_cli(ws, installation): - ws.config.is_aws = True - prompts = MockPrompts( - { - f"Have you reviewed the {AWSResourcePermissions.UC_ROLES_FILE_NAMES} " - "and confirm listed IAM roles to be migrated*": "Yes" - } - ) - aws = create_autospec(AWSResources) - aws.validate_connection.return_value = {"Account": "123456789012"} - - assert isinstance(IamRoleMigration.for_cli(ws, installation, aws, prompts), IamRoleMigration) - - def test_print_action_plan(caplog, ws, instance_profile_migration, credential_manager): caplog.set_level(logging.INFO) diff --git a/tests/unit/azure/test_credentials.py b/tests/unit/azure/test_credentials.py index da0be770cb..fdffd9e271 100644 --- a/tests/unit/azure/test_credentials.py +++ b/tests/unit/azure/test_credentials.py @@ -268,21 +268,6 @@ def sp_migration(ws, installation, credential_manager): return ServicePrincipalMigration(installation, ws, arp, sp_crawler, credential_manager) -def test_for_cli_not_prompts(ws, installation): - ws.config.is_azure = True - prompts = MockPrompts({"Have you reviewed the azure_storage_account_info.csv *": "No"}) - with pytest.raises(SystemExit): - ServicePrincipalMigration.for_cli(ws, installation, prompts) - - -def test_for_cli(ws, installation): - ws.config.is_azure = True - ws.config.auth_type = "azure-cli" - prompts = MockPrompts({"Have you reviewed the azure_storage_account_info.csv *": "Yes"}) - - assert isinstance(ServicePrincipalMigration.for_cli(ws, installation, prompts), ServicePrincipalMigration) - - @pytest.mark.parametrize( "secret_bytes_value, num_migrated", [ diff --git a/tests/unit/azure/test_locations.py b/tests/unit/azure/test_locations.py index 8bc024281f..abb6d4bf17 100644 --- a/tests/unit/azure/test_locations.py +++ b/tests/unit/azure/test_locations.py @@ -462,29 +462,3 @@ def test_corner_cases_with_missing_fields(ws, caplog, mocker): ws.external_locations.create.assert_not_called() assert "External locations below are not created in UC." 
in caplog.text - - -def test_for_cli(ws): - mock_installation = MockInstallation( - { - "config.yml": { - 'version': 2, - 'inventory_database': 'test', - 'connect': { - 'host': 'test', - 'token': 'test', - }, - }, - "azure_storage_account_info.csv": [ - { - 'prefix': 'dummy', - 'client_id': 'dummy', - 'principal': 'dummy', - 'privilege': 'WRITE_FILES', - 'type': 'Application', - 'directory_id': 'dummy', - }, - ], - } - ) - assert isinstance(ExternalLocationsMigration.for_cli(ws, mock_installation), ExternalLocationsMigration) diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index d1f5db08b6..98fc436e97 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -1,3 +1,71 @@ +import os +import sys +import threading +from unittest.mock import patch, create_autospec + import pytest +from databricks.labs.blueprint.installation import MockInstallation +from databricks.labs.lsql.backends import MockBackend +from databricks.sdk import WorkspaceClient + +from databricks.labs.ucx.config import WorkspaceConfig +from databricks.labs.ucx.contexts.workflow_task import RuntimeContext pytest.register_assert_rewrite('databricks.labs.blueprint.installation') + +# Lock to prevent concurrent execution of tests that patch the environment +_lock = threading.Lock() + + +def mock_installation() -> MockInstallation: + return MockInstallation( + { + 'config.yml': { + 'connect': { + 'host': 'adb-9999999999999999.14.azuredatabricks.net', + 'token': '...', + }, + 'inventory_database': 'ucx', + 'warehouse_id': 'abc', + }, + 'mapping.csv': [ + { + 'catalog_name': 'catalog', + 'dst_schema': 'schema', + 'dst_table': 'table', + 'src_schema': 'schema', + 'src_table': 'table', + 'workspace_name': 'workspace', + }, + ], + } + ) + + +@pytest.fixture +def run_workflow(mocker): + def inner(cb, **replace) -> RuntimeContext: + with _lock, patch.dict(os.environ, {"DATABRICKS_RUNTIME_VERSION": "14.0"}): + pyspark_sql_session = mocker.Mock() + sys.modules["pyspark.sql.session"] = pyspark_sql_session + installation = mock_installation() + if 'installation' not in replace: + replace['installation'] = installation + if 'workspace_client' not in replace: + replace['workspace_client'] = create_autospec(WorkspaceClient) + if 'sql_backend' not in replace: + replace['sql_backend'] = MockBackend() + if 'config' not in replace: + replace['config'] = installation.load(WorkspaceConfig) + + module = __import__(cb.__module__, fromlist=[cb.__name__]) + klass, method = cb.__qualname__.split('.', 1) + workflow = getattr(module, klass)() + current_task = getattr(workflow, method) + + ctx = RuntimeContext().replace(**replace) + current_task(ctx) + + return ctx + + yield inner diff --git a/tests/unit/framework/test_tasks.py b/tests/unit/framework/test_tasks.py index bc40fe4bca..357a238fd9 100644 --- a/tests/unit/framework/test_tasks.py +++ b/tests/unit/framework/test_tasks.py @@ -1,21 +1,14 @@ import logging -import shutil -from pathlib import Path from unittest.mock import create_autospec import pytest -from databricks.labs.blueprint.installation import MockInstallation -from databricks.labs.lsql.backends import RuntimeBackend from databricks.sdk import WorkspaceClient -from databricks.labs.ucx.config import WorkspaceConfig from databricks.labs.ucx.framework.tasks import ( Task, TaskLogger, parse_args, remove_extra_indentation, - run_task, - task, ) @@ -40,9 +33,9 @@ def test_task_cloud(): ws.config.is_gcp = False tasks = [ - Task(task_id=0, workflow="wl_1", name="n3", doc="d3", fn=lambda: None, cloud="aws"), - Task(task_id=1, 
workflow="wl_2", name="n2", doc="d2", fn=lambda: None, cloud="azure"), - Task(task_id=2, workflow="wl_1", name="n1", doc="d1", fn=lambda: None, cloud="gcp"), + Task(workflow="wl_1", name="n3", doc="d3", fn=lambda: None, cloud="aws"), + Task(workflow="wl_2", name="n2", doc="d2", fn=lambda: None, cloud="azure"), + Task(workflow="wl_1", name="n1", doc="d1", fn=lambda: None, cloud="gcp"), ] filter_tasks = sorted([t.name for t in tasks if t.cloud_compatible(ws.config)]) @@ -91,40 +84,3 @@ def test_parse_args(): assert args["task"] == "test" with pytest.raises(KeyError): parse_args("--foo=bar") - - -def test_run_task(capsys): - # mock a task function to be tested - @task("migrate-tables", job_cluster="migration_sync") - def mock_migrate_external_tables_sync(cfg, workspace_client, sql_backend, installation): - """This mock task of migrate-tables""" - return f"Hello, World! {cfg} {workspace_client} {sql_backend} {installation}" - - args = parse_args("--config=foo", "--task=mock_migrate_external_tables_sync", "--parent_run_id=abc", "--job_id=123") - cfg = WorkspaceConfig("test_db", log_level="INFO") - - # test the task function is called - install_dir = Path("foo") - run_task( - args, - install_dir, - cfg, - create_autospec(WorkspaceClient), - create_autospec(RuntimeBackend), - MockInstallation(), - ) - # clean up the log folder created by TaskLogger - shutil.rmtree(install_dir) - - assert "This mock task of migrate-tables" in capsys.readouterr().out - - # test KeyError if task not found - with pytest.raises(KeyError): - run_task( - parse_args("--config=foo", "--task=not_found"), - Path("foo"), - cfg, - create_autospec(WorkspaceClient), - create_autospec(RuntimeBackend), - MockInstallation(), - ) diff --git a/tests/unit/hive_metastore/test_catalog_schema.py b/tests/unit/hive_metastore/test_catalog_schema.py index 59cd38cf95..73034193e8 100644 --- a/tests/unit/hive_metastore/test_catalog_schema.py +++ b/tests/unit/hive_metastore/test_catalog_schema.py @@ -89,21 +89,3 @@ def test_no_catalog_storage(): catalog_schema = prepare_test(ws) catalog_schema.create_all_catalogs_schemas(mock_prompts) ws.catalogs.create.assert_called_once_with("catalog2", comment="Created by UCX") - - -def test_for_cli(): - ws = create_autospec(WorkspaceClient) - installation = MockInstallation( - { - "config.yml": { - 'version': 2, - 'inventory_database': 'test', - 'connect': { - 'host': 'test', - 'token': 'test', - }, - } - } - ) - catalog_schema = CatalogSchema.for_cli(ws, installation) - assert isinstance(catalog_schema, CatalogSchema) diff --git a/tests/unit/hive_metastore/test_principal_grants.py b/tests/unit/hive_metastore/test_principal_grants.py index e485b34897..e7c03f090e 100644 --- a/tests/unit/hive_metastore/test_principal_grants.py +++ b/tests/unit/hive_metastore/test_principal_grants.py @@ -13,12 +13,10 @@ AzureServicePrincipalInfo, ServicePrincipalClusterMapping, ) -from databricks.labs.ucx.azure.access import AzureResourcePermissions -from databricks.labs.ucx.azure.resources import AzureAPIClient, AzureResources from databricks.labs.ucx.config import WorkspaceConfig from databricks.labs.ucx.hive_metastore import Mounts, TablesCrawler from databricks.labs.ucx.hive_metastore.grants import AzureACL, Grant, PrincipalACL -from databricks.labs.ucx.hive_metastore.locations import ExternalLocations, Mount +from databricks.labs.ucx.hive_metastore.locations import Mount from databricks.labs.ucx.hive_metastore.tables import Table @@ -63,17 +61,9 @@ def ws(): def azure_acl(w, install, cluster_spn: list): config = 
install.load(WorkspaceConfig) sql_backend = StatementExecutionBackend(w, config.warehouse_id) - locations = create_autospec(ExternalLocations) - azure_client = AzureAPIClient( - w.config.arm_environment.resource_manager_endpoint, - w.config.arm_environment.service_management_endpoint, - ) - graph_client = AzureAPIClient("https://graph.microsoft.com", "https://graph.microsoft.com") - azurerm = AzureResources(azure_client, graph_client) - resource_permissions = AzureResourcePermissions(install, w, azurerm, locations) spn_crawler = create_autospec(AzureServicePrincipalCrawler) spn_crawler.get_cluster_to_storage_mapping.return_value = cluster_spn - return AzureACL(w, sql_backend, spn_crawler, resource_permissions) + return AzureACL(w, sql_backend, spn_crawler, install) def principal_acl(w, install, cluster_spn: list): @@ -158,31 +148,6 @@ def installation(): ) -def test_for_cli_azure_acl(ws, installation): - assert isinstance(AzureACL.for_cli(ws, installation), AzureACL) - - -def test_for_cli_azure(ws, installation): - ws.config.is_azure = True - sql_backend = StatementExecutionBackend(ws, ws.config.warehouse_id) - assert isinstance(PrincipalACL.for_cli(ws, installation, sql_backend), PrincipalACL) - - -def test_for_cli_aws(ws, installation): - ws.config.is_azure = False - ws.config.is_aws = True - sql_backend = StatementExecutionBackend(ws, ws.config.warehouse_id) - assert PrincipalACL.for_cli(ws, installation, sql_backend) is None - - -def test_for_cli_gcp(ws, installation): - ws.config.is_azure = False - ws.config.is_aws = False - ws.config.is_gcp = True - sql_backend = StatementExecutionBackend(ws, ws.config.warehouse_id) - assert PrincipalACL.for_cli(ws, installation, sql_backend) is None - - def test_get_eligible_locations_principals_no_cluster_mapping(ws, installation): locations = azure_acl(ws, installation, []) locations.get_eligible_locations_principals() diff --git a/tests/unit/hive_metastore/test_workflows.py b/tests/unit/hive_metastore/test_workflows.py new file mode 100644 index 0000000000..1b30b00a04 --- /dev/null +++ b/tests/unit/hive_metastore/test_workflows.py @@ -0,0 +1,11 @@ +from databricks.labs.ucx.hive_metastore.workflows import TableMigration + + +def test_migrate_external_tables_sync(run_workflow): + ctx = run_workflow(TableMigration.migrate_external_tables_sync) + ctx.workspace_client.catalogs.list.assert_called_once() + + +def test_migrate_dbfs_root_delta_tables(run_workflow): + ctx = run_workflow(TableMigration.migrate_dbfs_root_delta_tables) + ctx.workspace_client.catalogs.list.assert_called_once() diff --git a/tests/unit/source_code/test_files.py b/tests/unit/source_code/test_files.py index a46d4fb4c9..a0e050d9af 100644 --- a/tests/unit/source_code/test_files.py +++ b/tests/unit/source_code/test_files.py @@ -2,24 +2,22 @@ from pathlib import Path from unittest.mock import Mock, create_autospec -import pytest from databricks.sdk.service.workspace import Language -from databricks.labs.ucx.source_code.files import Files +from databricks.labs.ucx.source_code.files import LocalFileMigrator from databricks.labs.ucx.source_code.languages import Languages -from tests.unit import workspace_client_mock def test_files_fix_ignores_unsupported_extensions(): languages = create_autospec(Languages) - files = Files(languages) + files = LocalFileMigrator(languages) path = Path('unsupported.ext') assert not files.apply(path) def test_files_fix_ignores_unsupported_language(): languages = create_autospec(Languages) - files = Files(languages) + files = LocalFileMigrator(languages) 
files._extensions[".py"] = None # pylint: disable=protected-access path = Path('unsupported.py') assert not files.apply(path) @@ -27,7 +25,7 @@ def test_files_fix_ignores_unsupported_language(): def test_files_fix_reads_supported_extensions(): languages = create_autospec(Languages) - files = Files(languages) + files = LocalFileMigrator(languages) path = Path(__file__) assert not files.apply(path) @@ -35,7 +33,7 @@ def test_files_fix_reads_supported_extensions(): def test_files_supported_language_no_diagnostics(): languages = create_autospec(Languages) languages.linter(Language.PYTHON).lint.return_value = [] - files = Files(languages) + files = LocalFileMigrator(languages) path = Path(__file__) files.apply(path) languages.fixer.assert_not_called() @@ -45,7 +43,7 @@ def test_files_supported_language_no_fixer(): languages = create_autospec(Languages) languages.linter(Language.PYTHON).lint.return_value = [Mock(code='some-code')] languages.fixer.return_value = None - files = Files(languages) + files = LocalFileMigrator(languages) path = Path(__file__) files.apply(path) languages.fixer.assert_called_once_with(Language.PYTHON, 'some-code') @@ -55,7 +53,7 @@ def test_files_supported_language_with_fixer(): languages = create_autospec(Languages) languages.linter(Language.PYTHON).lint.return_value = [Mock(code='some-code')] languages.fixer(Language.PYTHON, 'some-code').apply.return_value = "Hi there!" - files = Files(languages) + files = LocalFileMigrator(languages) with tempfile.NamedTemporaryFile(mode="w+t", suffix=".py") as file: file.writelines(["import tempfile"]) path = Path(file.name) @@ -67,15 +65,8 @@ def test_files_walks_directory(): languages = create_autospec(Languages) languages.linter(Language.PYTHON).lint.return_value = [Mock(code='some-code')] languages.fixer.return_value = None - files = Files(languages) - path = Path(Path(__file__).parent.parent, "aws/") + files = LocalFileMigrator(languages) + path = Path(__file__).parent files.apply(path) languages.fixer.assert_called_with(Language.PYTHON, 'some-code') assert languages.fixer.call_count > 1 - - -@pytest.mark.skip("the below is unmanageably slow when ran locally, so disabling for now, created GH issue #1127") -def test_files_for_cli(): - ws = workspace_client_mock() - clazz = Files.for_cli(ws) - assert clazz is not None diff --git a/tests/unit/test_cli.py b/tests/unit/test_cli.py index 1ee4ccb775..78cbb11a05 100644 --- a/tests/unit/test_cli.py +++ b/tests/unit/test_cli.py @@ -1,6 +1,5 @@ import io import json -import subprocess from unittest.mock import create_autospec, patch import pytest @@ -41,6 +40,7 @@ validate_groups_membership, workflows, ) +from databricks.labs.ucx.contexts.cli_command import WorkspaceContext @pytest.fixture @@ -200,11 +200,12 @@ def test_no_step_in_repair_run(ws): def test_revert_migrated_tables(ws, caplog): # test with no schema and no table, user confirm to not retry prompts = MockPrompts({'.*': 'no'}) - assert revert_migrated_tables(ws, prompts, schema=None, table=None) is None + ctx = WorkspaceContext(ws).replace(is_azure=True, azure_cli_authenticated=True, azure_subscription_id='test') + assert revert_migrated_tables(ws, prompts, schema=None, table=None, ctx=ctx) is None # test with no schema and no table, user confirm to retry, but no ucx installation found prompts = MockPrompts({'.*': 'yes'}) - assert revert_migrated_tables(ws, prompts, schema=None, table=None) is None + assert revert_migrated_tables(ws, prompts, schema=None, table=None, ctx=ctx) is None assert 'No migrated tables were found.' 
in caplog.messages @@ -266,31 +267,17 @@ def test_alias(ws): ws.tables.list.assert_called_once() -def test_save_storage_and_principal_azure_no_azure_cli(ws, caplog): - ws.config.auth_type = "azure_clis" +def test_save_storage_and_principal_azure_no_azure_cli(ws): ws.config.is_azure = True - prompts = MockPrompts({}) - principal_prefix_access(ws, prompts, "") - - assert 'In order to obtain AAD token, Please run azure cli to authenticate.' in caplog.messages - - -def test_save_storage_and_principal_azure_no_subscription_id(ws, caplog): - ws.config.auth_type = "azure-cli" - ws.config.is_azure = True - - prompts = MockPrompts({}) - principal_prefix_access(ws, prompts) - - assert "Please enter subscription id to scan storage accounts in." in caplog.messages + ctx = WorkspaceContext(ws) + with pytest.raises(ValueError): + principal_prefix_access(ws, ctx=ctx) def test_save_storage_and_principal_azure(ws, caplog): - ws.config.auth_type = "azure-cli" - ws.config.is_azure = True - prompts = MockPrompts({}) azure_resource_permissions = create_autospec(AzureResourcePermissions) - principal_prefix_access(ws, prompts, subscription_id="test", azure_resource_permissions=azure_resource_permissions) + ctx = WorkspaceContext(ws).replace(is_azure=True, azure_resource_permissions=azure_resource_permissions) + principal_prefix_access(ws, ctx=ctx) azure_resource_permissions.save_spn_permissions.assert_called_once() @@ -299,113 +286,53 @@ def test_validate_groups_membership(ws): ws.groups.list.assert_called() -def test_save_storage_and_principal_aws_no_profile(ws, caplog, mocker): - mocker.patch("shutil.which", return_value="/path/aws") - ws.config.is_azure = False - ws.config.is_aws = True - prompts = MockPrompts({}) - principal_prefix_access(ws, prompts) - assert any({"AWS Profile is not specified." 
in message for message in caplog.messages}) - - -def test_save_storage_and_principal_aws_no_connection(ws, mocker): - mocker.patch("shutil.which", return_value="/path/aws") - pop = create_autospec(subprocess.Popen) - ws.config.is_azure = False - ws.config.is_aws = True - pop.communicate.return_value = (bytes("message", "utf-8"), bytes("error", "utf-8")) - pop.returncode = 127 - mocker.patch("subprocess.Popen.__init__", return_value=None) - mocker.patch("subprocess.Popen.__enter__", return_value=pop) - mocker.patch("subprocess.Popen.__exit__", return_value=None) - prompts = MockPrompts({}) - - with pytest.raises(ResourceWarning, match="AWS CLI is not configured properly."): - principal_prefix_access(ws, prompts, aws_profile="profile") - - -def test_save_storage_and_principal_aws_no_cli(ws, mocker, caplog): - mocker.patch("shutil.which", return_value=None) - ws.config.is_azure = False - ws.config.is_aws = True - prompts = MockPrompts({}) - principal_prefix_access(ws, prompts, aws_profile="profile") - assert any({"Couldn't find AWS" in message for message in caplog.messages}) - - -def test_save_storage_and_principal_aws(ws, mocker, caplog): - mocker.patch("shutil.which", return_value=True) - ws.config.is_azure = False - ws.config.is_aws = True +def test_save_storage_and_principal_aws(ws): aws_resource_permissions = create_autospec(AWSResourcePermissions) - prompts = MockPrompts({}) - principal_prefix_access(ws, prompts, aws_profile="profile", aws_resource_permissions=aws_resource_permissions) + ctx = WorkspaceContext(ws).replace(is_aws=True, is_azure=False, aws_resource_permissions=aws_resource_permissions) + principal_prefix_access(ws, ctx=ctx) aws_resource_permissions.save_instance_profile_permissions.assert_called_once() -def test_save_storage_and_principal_gcp(ws, caplog): - ws.config.is_azure = False - ws.config.is_aws = False - ws.config.is_gcp = True - prompts = MockPrompts({}) - principal_prefix_access(ws, prompts) - assert "This cmd is only supported for azure and aws workspaces" in caplog.messages +def test_save_storage_and_principal_gcp(ws): + ctx = WorkspaceContext(ws).replace(is_aws=False, is_azure=False) + with pytest.raises(ValueError): + principal_prefix_access(ws, ctx=ctx) def test_migrate_credentials_azure(ws): - ws.config.is_azure = True ws.workspace.upload.return_value = "test" prompts = MockPrompts({'.*': 'yes'}) - migrate_credentials(ws, prompts) + ctx = WorkspaceContext(ws).replace(is_azure=True, azure_cli_authenticated=True, azure_subscription_id='test') + migrate_credentials(ws, prompts, ctx=ctx) ws.storage_credentials.list.assert_called() def test_migrate_credentials_aws(ws, mocker): - mocker.patch("shutil.which", return_value=True) - ws.config.is_azure = False - ws.config.is_aws = True - ws.config.is_gcp = False aws_resources = create_autospec(AWSResources) aws_resources.validate_connection.return_value = {"Account": "123456789012"} prompts = MockPrompts({'.*': 'yes'}) - migrate_credentials(ws, prompts, aws_profile="profile", aws_resources=aws_resources) + ctx = WorkspaceContext(ws).replace(is_aws=True, aws_resources=aws_resources) + migrate_credentials(ws, prompts, ctx=ctx) ws.storage_credentials.list.assert_called() aws_resources.update_uc_trust_role.assert_called_once() -def test_migrate_credentials_aws_no_profile(ws, caplog, mocker): - mocker.patch("shutil.which", return_value="/path/aws") - ws.config.is_azure = False - ws.config.is_aws = True - prompts = MockPrompts({}) - migrate_credentials(ws, prompts) - assert ( - "AWS Profile is not specified. 
Use the environment variable [AWS_DEFAULT_PROFILE] or use the " - "'--aws-profile=[profile-name]' parameter." in caplog.messages - ) - - def test_create_master_principal_not_azure(ws): ws.config.is_azure = False + ws.config.is_aws = False prompts = MockPrompts({}) - create_uber_principal(ws, prompts, subscription_id="") - ws.workspace.get_status.assert_not_called() - - -def test_create_master_principal_no_azure_cli(ws): - ws.config.auth_type = "azure_clis" - ws.config.is_azure = True - prompts = MockPrompts({}) - create_uber_principal(ws, prompts, subscription_id="") - ws.workspace.get_status.assert_not_called() + ctx = WorkspaceContext(ws) + with pytest.raises(ValueError): + create_uber_principal(ws, prompts, ctx=ctx) def test_create_master_principal_no_subscription(ws): ws.config.auth_type = "azure-cli" ws.config.is_azure = True prompts = MockPrompts({}) - create_uber_principal(ws, prompts, subscription_id="") - ws.workspace.get_status.assert_not_called() + ctx = WorkspaceContext(ws) + with pytest.raises(ValueError): + create_uber_principal(ws, prompts, ctx=ctx, subscription_id="") def test_create_uber_principal(ws): @@ -417,38 +344,21 @@ def test_create_uber_principal(ws): def test_migrate_locations_azure(ws): - ws.config.is_azure = True - ws.config.is_aws = False - ws.config.is_gcp = False - migrate_locations(ws) + ctx = WorkspaceContext(ws).replace(is_azure=True, azure_cli_authenticated=True, azure_subscription_id='test') + migrate_locations(ws, ctx=ctx) ws.external_locations.list.assert_called() -def test_migrate_locations_aws(ws, caplog, mocker): - mocker.patch("shutil.which", return_value=True) - ws.config.is_azure = False - ws.config.is_aws = True - ws.config.is_gcp = False - migrate_locations(ws, aws_profile="profile") +def test_migrate_locations_aws(ws, caplog): + ctx = WorkspaceContext(ws).replace(is_aws=True, aws_profile="profile") + migrate_locations(ws, ctx=ctx) ws.external_locations.list.assert_called() -def test_missing_aws_cli(ws, caplog, mocker): - # Test to verify the CLI is called. Fail it intentionally to test the error message. - mocker.patch("shutil.which", return_value=None) - ws.config.is_azure = False - ws.config.is_aws = True - ws.config.is_gcp = False - migrate_locations(ws, aws_profile="profile") - assert "Couldn't find AWS CLI in path. 
Please install the CLI from https://aws.amazon.com/cli/" in caplog.messages - - -def test_migrate_locations_gcp(ws, caplog): - ws.config.is_azure = False - ws.config.is_aws = False - ws.config.is_gcp = True - migrate_locations(ws) - assert "migrate_locations is not yet supported in GCP" in caplog.messages +def test_migrate_locations_gcp(ws): + ctx = WorkspaceContext(ws).replace(is_aws=False, is_azure=False) + with pytest.raises(ValueError): + migrate_locations(ws, ctx=ctx) def test_create_catalogs_schemas(ws): diff --git a/tests/unit/test_factories.py b/tests/unit/test_factories.py new file mode 100644 index 0000000000..628af30268 --- /dev/null +++ b/tests/unit/test_factories.py @@ -0,0 +1,75 @@ +import base64 +from unittest.mock import create_autospec + +from databricks.labs.blueprint.installation import MockInstallation +from databricks.labs.blueprint.tui import MockPrompts +from databricks.labs.lsql.backends import MockBackend +from databricks.sdk import WorkspaceClient + +from databricks.labs.ucx.contexts.cli_command import WorkspaceContext + + +def test_replace_installation(): + ws = create_autospec(WorkspaceClient) + ws.config.auth_type = 'azure-cli' + ws.secrets.get_secret.return_value.value = base64.b64encode(b'1234').decode('utf-8') + + spn_info_rows = MockBackend.rows('application_id', 'secret_scope', 'secret_key', 'tenant_id', 'storage_account') + + mock_installation = MockInstallation( + { + 'config.yml': { + 'inventory_database': 'some', + 'warehouse_id': 'other', + 'connect': { + 'host': 'localhost', + 'token': '1234', + }, + }, + 'azure_storage_account_info.csv': [ + { + 'prefix': 'abfss://uctest@ziyuanqintest.dfs.core.windows.net/', + 'client_id': "first-application-id", + 'directory_id': 'tenant', + 'principal': "oneenv-adls", + 'privilege': "WRITE_FILES", + 'type': "Application", + }, + { + 'prefix': 'abfss://ucx2@ziyuanqintest.dfs.core.windows.net/', + 'client_id': "second-application-id", + 'principal': "ziyuan-user-assigned-mi", + 'privilege': "WRITE_FILES", + 'type': "ManagedIdentity", + }, + ], + } + ) + ctx = WorkspaceContext(ws).replace( + is_azure=True, + azure_subscription_id='foo', + installation=mock_installation, + sql_backend=MockBackend( + rows={ + r'some.azure_service_principals': spn_info_rows[ + ('first-application-id', 'foo', 'bar', 'tenant', 'ziyuanqintest'), + ('second-application-id', 'foo', 'bar', 'tenant', 'ziyuanqintest'), + ] + } + ), + ) + prompts = MockPrompts({'.*': 'yes'}) + ctx.service_principal_migration.run(prompts) + + ws.storage_credentials.create.assert_called_once() + mock_installation.assert_file_written( + 'azure_service_principal_migration_result.csv', + [ + { + 'application_id': 'first-application-id', + 'directory_id': 'tenant', + 'name': 'oneenv-adls', + 'validated_on': 'abfss://uctest@ziyuanqintest.dfs.core.windows.net/', + } + ], + ) diff --git a/tests/unit/test_install.py b/tests/unit/test_install.py index 9715b635f8..d088b7b9b1 100644 --- a/tests/unit/test_install.py +++ b/tests/unit/test_install.py @@ -65,6 +65,7 @@ DeployedWorkflows, WorkflowsDeployment, ) +from databricks.labs.ucx.runtime import Workflows PRODUCT_INFO = ProductInfo.from_class(WorkspaceConfig) @@ -215,6 +216,7 @@ def test_create_database(ws, caplog, mock_installation, any_prompt): create_autospec(WheelsV2), PRODUCT_INFO, timedelta(seconds=1), + [], ) workspace_installation = WorkspaceInstallation( @@ -248,6 +250,7 @@ def test_install_cluster_override_jobs(ws, mock_installation, any_prompt): wheels, PRODUCT_INFO, timedelta(seconds=1), + 
Workflows.all().tasks(), ) workflows_installation.create_jobs(any_prompt) @@ -280,6 +283,7 @@ def test_write_protected_dbfs(ws, tmp_path, mock_installation): wheels, PRODUCT_INFO, timedelta(seconds=1), + Workflows.all().tasks(), ) workflows_installation.create_jobs(prompts) @@ -318,6 +322,7 @@ def test_writeable_dbfs(ws, tmp_path, mock_installation, any_prompt): wheels, PRODUCT_INFO, timedelta(seconds=1), + Workflows.all().tasks(), ) workflows_installation.create_jobs(any_prompt) @@ -609,6 +614,7 @@ def test_main_with_existing_conf_does_not_recreate_config(ws, mocker, mock_insta create_autospec(WheelsV2), PRODUCT_INFO, timedelta(seconds=1), + [], ) workspace_installation = WorkspaceInstallation( WorkspaceConfig(inventory_database="...", policy_id='123'), @@ -678,6 +684,7 @@ def test_remove_jobs_no_state(ws): create_autospec(WheelsV2), PRODUCT_INFO, timedelta(seconds=1), + [], ) workspace_installation = WorkspaceInstallation( config, installation, install_state, sql_backend, ws, workflows_installer, prompts, PRODUCT_INFO @@ -709,6 +716,7 @@ def test_remove_jobs_with_state_missing_job(ws, caplog, mock_installation_with_j create_autospec(WheelsV2), PRODUCT_INFO, timedelta(seconds=1), + [], ) workspace_installation = WorkspaceInstallation( config, @@ -1230,6 +1238,7 @@ def test_triggering_assessment_wf(ws, mocker, mock_installation): wheels, PRODUCT_INFO, timedelta(seconds=1), + Workflows.all().tasks(), ) workspace_installation = WorkspaceInstallation( config, installation, install_state, sql_backend, ws, workflows_installer, prompts, PRODUCT_INFO @@ -1334,6 +1343,7 @@ def test_remove_jobs(ws, caplog, mock_installation_extra_jobs, any_prompt): create_autospec(WheelsV2), PRODUCT_INFO, timedelta(seconds=1), + [], ) workspace_installation = WorkspaceInstallation( @@ -1515,6 +1525,7 @@ def test_user_not_admin(ws, mock_installation, any_prompt): wheels, PRODUCT_INFO, timedelta(seconds=1), + Workflows.all().tasks(), ) with pytest.raises(PermissionError) as failure: diff --git a/tests/unit/test_runtime.py b/tests/unit/test_runtime.py deleted file mode 100644 index c6368145a2..0000000000 --- a/tests/unit/test_runtime.py +++ /dev/null @@ -1,387 +0,0 @@ -import os.path -import sys -from unittest.mock import call, create_autospec, patch - -import pytest -from databricks.labs.blueprint.installation import MockInstallation -from databricks.labs.lsql.backends import MockBackend, SqlBackend -from databricks.sdk import WorkspaceClient -from databricks.sdk.config import Config -from databricks.sdk.service.iam import PermissionMigrationResponse - -from databricks.labs.ucx.config import WorkspaceConfig -from databricks.labs.ucx.framework.tasks import ( # pylint: disable=import-private-name - _TASKS, - Task, -) -from databricks.labs.ucx.runtime import ( - apply_permissions_to_account_groups, - apply_permissions_to_account_groups_experimental, - assess_azure_service_principals, - assess_clusters, - assess_global_init_scripts, - assess_incompatible_submit_runs, - assess_jobs, - assess_pipelines, - crawl_cluster_policies, - crawl_grants, - crawl_groups, - crawl_mounts, - crawl_permissions, - delete_backup_groups, - destroy_schema, - estimate_table_size_for_migration, - guess_external_locations, - migrate_dbfs_root_delta_tables, - migrate_external_tables_sync, - reflect_account_groups_on_workspace_experimental, - rename_workspace_local_groups_experimental, - workspace_listing, -) -from tests.unit import GROUPS, PERMISSIONS - - -def azure_mock_config() -> WorkspaceConfig: - config = WorkspaceConfig( - connect=Config( 
- host="adb-9999999999999999.14.azuredatabricks.net", - token="dapifaketoken", - ), - inventory_database="ucx", - ) - return config - - -def mock_installation() -> MockInstallation: - return MockInstallation( - { - 'config.yml': {'warehouse_id': 'abc', 'connect': {'host': 'a', 'token': 'b'}, 'inventory_database': 'ucx'}, - 'mapping.csv': [ - { - 'catalog_name': 'catalog', - 'dst_schema': 'schema', - 'dst_table': 'table', - 'src_schema': 'schema', - 'src_table': 'table', - 'workspace_name': 'workspace', - }, - ], - } - ) - - -def test_azure_crawler(mocker): - with patch.dict(os.environ, {"DATABRICKS_RUNTIME_VERSION": "14.0"}): - pyspark_sql_session = mocker.Mock() - sys.modules["pyspark.sql.session"] = pyspark_sql_session - cfg = azure_mock_config() - - ws = create_autospec(WorkspaceClient) - sql_backend = create_autospec(SqlBackend) - sql_backend.fetch.return_value = [ - ["1", "secret_scope", "secret_key", "tenant_id", "storage_account"], - ] - assess_azure_service_principals(cfg, ws, sql_backend, mock_installation()) - - -def test_tasks(): - tasks = [ - Task(task_id=0, workflow="wl_1", name="n3", doc="d3", fn=lambda: None, cloud="azure"), - Task(task_id=1, workflow="wl_2", name="n2", doc="d2", fn=lambda: None, cloud="aws"), - Task(task_id=2, workflow="wl_1", name="n1", doc="d1", fn=lambda: None, cloud="gcp"), - ] - - assert len([_ for _ in tasks if _.cloud == "azure"]) == 1 - assert len([_ for _ in tasks if _.cloud == "aws"]) == 1 - assert len([_ for _ in tasks if _.cloud == "gcp"]) == 1 - - -def test_assessment_tasks(): - """Test task decorator""" - assert len(_TASKS) >= 19 - azure = [v for k, v in _TASKS.items() if v.cloud == "azure"] - assert len(azure) >= 1 - - -def test_runtime_workspace_listing(mocker): - with patch.dict(os.environ, {"DATABRICKS_RUNTIME_VERSION": "14.0"}): - pyspark_sql_session = mocker.Mock() - sys.modules["pyspark.sql.session"] = pyspark_sql_session - cfg = azure_mock_config() - ws = create_autospec(WorkspaceClient) - sql_backend = MockBackend() - workspace_listing(cfg, ws, sql_backend, mock_installation()) - - assert "SELECT * FROM ucx.workspace_objects" in sql_backend.queries - - -def test_runtime_crawl_grants(mocker): - with patch.dict(os.environ, {"DATABRICKS_RUNTIME_VERSION": "14.0"}): - pyspark_sql_session = mocker.Mock() - sys.modules["pyspark.sql.session"] = pyspark_sql_session - cfg = azure_mock_config() - ws = create_autospec(WorkspaceClient) - sql_backend = MockBackend() - crawl_grants(cfg, ws, sql_backend, mock_installation()) - - assert "SELECT * FROM hive_metastore.ucx.grants" in sql_backend.queries - - -@pytest.mark.skip("1crawl_permissions fails, filed GH issue #1129") -def test_runtime_crawl_permissions(mocker): - with patch.dict(os.environ, {"DATABRICKS_RUNTIME_VERSION": "14.0"}): - pyspark_sql_session = mocker.Mock() - sys.modules["pyspark.sql.session"] = pyspark_sql_session - cfg = azure_mock_config() - ws = create_autospec(WorkspaceClient) - sql_backend = MockBackend() - crawl_permissions(cfg, ws, sql_backend, mock_installation()) - - assert "SELECT * FROM hive_metastore.ucx.permissions" in sql_backend.queries - - -def test_runtime_crawl_groups(mocker): - with patch.dict(os.environ, {"DATABRICKS_RUNTIME_VERSION": "14.0"}): - pyspark_sql_session = mocker.Mock() - sys.modules["pyspark.sql.session"] = pyspark_sql_session - cfg = azure_mock_config() - ws = create_autospec(WorkspaceClient) - sql_backend = MockBackend() - crawl_groups(cfg, ws, sql_backend, mock_installation()) - - assert "SELECT * FROM hive_metastore.ucx.groups" in 
sql_backend.queries - - -def test_runtime_crawl_cluster_policies(mocker): - with patch.dict(os.environ, {"DATABRICKS_RUNTIME_VERSION": "14.0"}): - pyspark_sql_session = mocker.Mock() - sys.modules["pyspark.sql.session"] = pyspark_sql_session - cfg = azure_mock_config() - ws = create_autospec(WorkspaceClient) - sql_backend = MockBackend() - crawl_cluster_policies(cfg, ws, sql_backend, mock_installation()) - - assert "SELECT * FROM ucx.policies" in sql_backend.queries - - -def test_runtime_crawl_init_scripts(mocker): - with patch.dict(os.environ, {"DATABRICKS_RUNTIME_VERSION": "14.0"}): - pyspark_sql_session = mocker.Mock() - sys.modules["pyspark.sql.session"] = pyspark_sql_session - cfg = azure_mock_config() - ws = create_autospec(WorkspaceClient) - sql_backend = MockBackend() - assess_global_init_scripts(cfg, ws, sql_backend, mock_installation()) - - assert "SELECT * FROM ucx.global_init_scripts" in sql_backend.queries - - -def test_estimate_table_size_for_migration(mocker): - with patch.dict(os.environ, {"DATABRICKS_RUNTIME_VERSION": "14.0"}): - pyspark_sql_session = mocker.Mock() - sys.modules["pyspark.sql.session"] = pyspark_sql_session - cfg = azure_mock_config() - ws = create_autospec(WorkspaceClient) - sql_backend = MockBackend() - estimate_table_size_for_migration(cfg, ws, sql_backend, mock_installation()) - - assert "SELECT * FROM hive_metastore.ucx.table_size" in sql_backend.queries - assert "SHOW DATABASES" in sql_backend.queries - - -def test_runtime_mounts(mocker): - with patch.dict(os.environ, {"DATABRICKS_RUNTIME_VERSION": "14.0"}): - pyspark_sql_session = mocker.Mock() - sys.modules["pyspark.sql.session"] = pyspark_sql_session - cfg = azure_mock_config() - ws = create_autospec(WorkspaceClient) - sql_backend = MockBackend() - crawl_mounts(cfg, ws, sql_backend, mock_installation()) - - assert "SELECT * FROM ucx.mounts" in sql_backend.queries - - -def test_guess_external_locations(mocker): - with patch.dict(os.environ, {"DATABRICKS_RUNTIME_VERSION": "14.0"}): - pyspark_sql_session = mocker.Mock() - sys.modules["pyspark.sql.session"] = pyspark_sql_session - cfg = azure_mock_config() - ws = create_autospec(WorkspaceClient) - sql_backend = MockBackend() - guess_external_locations(cfg, ws, sql_backend, mock_installation()) - - assert "SELECT * FROM ucx.mounts" in sql_backend.queries - - -def test_assess_jobs(mocker): - with patch.dict(os.environ, {"DATABRICKS_RUNTIME_VERSION": "14.0"}): - pyspark_sql_session = mocker.Mock() - sys.modules["pyspark.sql.session"] = pyspark_sql_session - cfg = azure_mock_config() - ws = create_autospec(WorkspaceClient) - sql_backend = MockBackend() - assess_jobs(cfg, ws, sql_backend, mock_installation()) - - assert "SELECT * FROM ucx.jobs" in sql_backend.queries - - -def test_assess_clusters(mocker): - with patch.dict(os.environ, {"DATABRICKS_RUNTIME_VERSION": "14.0"}): - pyspark_sql_session = mocker.Mock() - sys.modules["pyspark.sql.session"] = pyspark_sql_session - cfg = azure_mock_config() - ws = create_autospec(WorkspaceClient) - sql_backend = MockBackend() - assess_clusters(cfg, ws, sql_backend, mock_installation()) - - assert "SELECT * FROM ucx.clusters" in sql_backend.queries - - -def test_assess_pipelines(mocker): - with patch.dict(os.environ, {"DATABRICKS_RUNTIME_VERSION": "14.0"}): - pyspark_sql_session = mocker.Mock() - sys.modules["pyspark.sql.session"] = pyspark_sql_session - cfg = azure_mock_config() - ws = create_autospec(WorkspaceClient) - sql_backend = MockBackend() - assess_pipelines(cfg, ws, sql_backend, mock_installation()) - - 
assert "SELECT * FROM ucx.pipelines" in sql_backend.queries - - -def test_incompatible_submit_runs(mocker): - with patch.dict(os.environ, {"DATABRICKS_RUNTIME_VERSION": "14.0"}): - pyspark_sql_session = mocker.Mock() - sys.modules["pyspark.sql.session"] = pyspark_sql_session - cfg = azure_mock_config() - ws = create_autospec(WorkspaceClient) - sql_backend = MockBackend() - assess_incompatible_submit_runs(cfg, ws, sql_backend, mock_installation()) - - assert "SELECT * FROM ucx.submit_runs" in sql_backend.queries - - -def test_migrate_external_tables_sync(): - ws = create_autospec(WorkspaceClient) - migrate_external_tables_sync(azure_mock_config(), ws, MockBackend(), mock_installation()) - ws.catalogs.list.assert_called_once() - - -def test_migrate_dbfs_root_delta_tables(): - ws = create_autospec(WorkspaceClient) - migrate_dbfs_root_delta_tables(azure_mock_config(), ws, MockBackend(), mock_installation()) - ws.catalogs.list.assert_called_once() - - -def test_runtime_destroy_schema(mocker): - with patch.dict(os.environ, {"DATABRICKS_RUNTIME_VERSION": "14.0"}): - pyspark_sql_session = mocker.Mock() - sys.modules["pyspark.sql.session"] = pyspark_sql_session - cfg = azure_mock_config() - ws = create_autospec(WorkspaceClient) - sql_backend = MockBackend() - destroy_schema(cfg, ws, sql_backend, mock_installation()) - - assert "DROP DATABASE ucx CASCADE" in sql_backend.queries - - -@pytest.mark.skip( - "smells like delete_backup_groups isn't deleting anything, but maybe that's because there's nothing to delete ?" -) -def test_runtime_delete_backup_groups(mocker): - with patch.dict(os.environ, {"DATABRICKS_RUNTIME_VERSION": "14.0"}): - pyspark_sql_session = mocker.Mock() - sys.modules["pyspark.sql.session"] = pyspark_sql_session - cfg = azure_mock_config() - ws = create_autospec(WorkspaceClient) - sql_backend = MockBackend() - delete_backup_groups(cfg, ws, sql_backend, mock_installation()) - - assert "DELETE" in sql_backend.queries # TODO - - -def test_runtime_apply_permissions_to_account_groups(mocker): - with patch.dict(os.environ, {"DATABRICKS_RUNTIME_VERSION": "14.0"}): - pyspark_sql_session = mocker.Mock() - sys.modules["pyspark.sql.session"] = pyspark_sql_session - cfg = azure_mock_config() - ws = create_autospec(WorkspaceClient) - sql_backend = MockBackend() - apply_permissions_to_account_groups(cfg, ws, sql_backend, mock_installation()) - - assert "SELECT * FROM hive_metastore.ucx.groups" in sql_backend.queries - - -def test_rename_workspace_local_group(caplog): - ws = create_autospec(WorkspaceClient) - rename_workspace_local_groups_experimental(azure_mock_config(), ws, MockBackend(), mock_installation()) - - -def test_reflect_account_groups_on_workspace(caplog): - ws = create_autospec(WorkspaceClient) - reflect_account_groups_on_workspace_experimental(azure_mock_config(), ws, MockBackend(), mock_installation()) - - -def test_migrate_permissions_experimental(): - rows = { - 'SELECT \\* FROM hive_metastore.ucx.groups': GROUPS[ - ("", "workspace_group_1", "account_group_1", "temp_1", "", "", "", ""), - ("", "workspace_group_2", "account_group_2", "temp_2", "", "", "", ""), - ("", "workspace_group_3", "account_group_3", "temp_3", "", "", "", ""), - ], - 'SELECT COUNT\\(\\*\\) as cnt FROM hive_metastore.ucx.permissions': PERMISSIONS[("123", "QUERIES", "temp")], - } - ws = create_autospec(WorkspaceClient) - ws.get_workspace_id.return_value = "12345678" - ws.permission_migration.migrate_permissions.return_value = PermissionMigrationResponse(0) - apply_permissions_to_account_groups_experimental( - 
azure_mock_config(), ws, MockBackend(rows=rows), mock_installation() - ) - calls = [ - call("12345678", "temp_1", "account_group_1", size=1000), - call("12345678", "temp_2", "account_group_2", size=1000), - call("12345678", "temp_3", "account_group_3", size=1000), - ] - ws.permission_migration.migrate_permissions.assert_has_calls(calls, any_order=True) - - -def test_migrate_permissions_experimental_paginated(): - rows = { - 'SELECT \\* FROM hive_metastore.ucx.groups': GROUPS[ - ("", "workspace_group_1", "account_group_1", "temp_1", "", "", "", ""), - ("", "workspace_group_2", "account_group_2", "temp_2", "", "", "", ""), - ("", "workspace_group_3", "account_group_3", "temp_3", "", "", "", ""), - ], - 'SELECT COUNT\\(\\*\\) as cnt FROM hive_metastore.ucx.permissions': PERMISSIONS[("123", "QUERIES", "temp")], - } - ws = create_autospec(WorkspaceClient) - ws.get_workspace_id.return_value = "12345678" - ws.permission_migration.migrate_permissions.side_effect = [ - PermissionMigrationResponse(i) for i in (1000, None, 1000, 10, 0, 1000, 10, 0) - ] - apply_permissions_to_account_groups_experimental( - azure_mock_config(), ws, MockBackend(rows=rows), mock_installation() - ) - calls = [ - call("12345678", "temp_1", "account_group_1", size=1000), - call("12345678", "temp_2", "account_group_2", size=1000), - call("12345678", "temp_3", "account_group_3", size=1000), - ] - ws.permission_migration.migrate_permissions.assert_has_calls(calls, any_order=True) - - -def test_migrate_permissions_experimental_error(caplog): - rows = { - 'SELECT \\* FROM hive_metastore.ucx.groups': GROUPS[ - ("", "workspace_group_1", "account_group_1", "temp_1", "", "", "", ""), - ("", "workspace_group_2", "account_group_2", "temp_2", "", "", "", ""), - ("", "workspace_group_3", "account_group_3", "temp_3", "", "", "", ""), - ], - } - ws = create_autospec(WorkspaceClient) - ws.get_workspace_id.return_value = "12345678" - ws.permission_migration.migrate_permissions.side_effect = NotImplementedError("api not enabled") - with pytest.raises(NotImplementedError): - apply_permissions_to_account_groups_experimental( - azure_mock_config(), ws, MockBackend(rows=rows), mock_installation() - ) diff --git a/tests/unit/test_workflows.py b/tests/unit/test_workflows.py new file mode 100644 index 0000000000..f43ae7511d --- /dev/null +++ b/tests/unit/test_workflows.py @@ -0,0 +1,9 @@ +from databricks.labs.ucx.runtime import Workflows + + +def test_tasks_detected(): + workflows = Workflows.all() + + tasks = workflows.tasks() + + assert len(tasks) > 1 diff --git a/tests/unit/workspace_access/test_manager.py b/tests/unit/workspace_access/test_manager.py index e9ae8421ed..a2a2c88b0e 100644 --- a/tests/unit/workspace_access/test_manager.py +++ b/tests/unit/workspace_access/test_manager.py @@ -197,46 +197,6 @@ def test_unregistered_support(): permission_manager.apply_group_permissions(migration_state=MigrationState([])) -def test_factory(mock_ws): - mock_ws.groups.list.return_value = [] - sql_backend = MockBackend() - permission_manager = PermissionManager.factory(mock_ws, sql_backend, "test") - appliers = permission_manager.object_type_support() - - assert sorted( - { - "sql/warehouses", - "registered-models", - "instance-pools", - "jobs", - "directories", - "experiments", - "clusters", - "notebooks", - "repos", - "files", - "authorization", - "pipelines", - "cluster-policies", - "dashboards", - "queries", - "alerts", - "secrets", - "entitlements", - "roles", - 'serving-endpoints', - "feature-tables", - "ANY FILE", - "FUNCTION", - "ANONYMOUS 
FUNCTION", - "CATALOG", - "TABLE", - "VIEW", - "DATABASE", - } - ) == sorted(appliers.keys()) - - def test_manager_verify(): sql_backend = MockBackend( rows={ diff --git a/tests/unit/workspace_access/test_workflows.py b/tests/unit/workspace_access/test_workflows.py new file mode 100644 index 0000000000..79dd63fe90 --- /dev/null +++ b/tests/unit/workspace_access/test_workflows.py @@ -0,0 +1,99 @@ +from unittest.mock import create_autospec, call + +import pytest +from databricks.labs.lsql.backends import MockBackend +from databricks.sdk import WorkspaceClient +from databricks.sdk.service.iam import PermissionMigrationResponse + +from databricks.labs.ucx.workspace_access.workflows import ( + RemoveWorkspaceLocalGroups, + GroupMigration, + PermissionsMigrationAPI, +) +from tests.unit import GROUPS, PERMISSIONS + + +def test_runtime_delete_backup_groups(run_workflow): + ctx = run_workflow(RemoveWorkspaceLocalGroups.delete_backup_groups) + assert 'SELECT * FROM hive_metastore.ucx.groups' in ctx.sql_backend.queries + + +def test_runtime_apply_permissions_to_account_groups(run_workflow): + ctx = run_workflow(GroupMigration.apply_permissions_to_account_groups) + assert 'SELECT * FROM hive_metastore.ucx.groups' in ctx.sql_backend.queries + + +def test_rename_workspace_local_group(run_workflow): + ctx = run_workflow(GroupMigration.rename_workspace_local_groups) + assert 'SELECT * FROM hive_metastore.ucx.groups' in ctx.sql_backend.queries + + +def test_reflect_account_groups_on_workspace(run_workflow): + ctx = run_workflow(PermissionsMigrationAPI.reflect_account_groups_on_workspace) + assert 'SELECT * FROM hive_metastore.ucx.groups' in ctx.sql_backend.queries + + +def test_migrate_permissions_experimental(run_workflow): + rows = { + 'SELECT \\* FROM hive_metastore.ucx.groups': GROUPS[ + ("", "workspace_group_1", "account_group_1", "temp_1", "", "", "", ""), + ("", "workspace_group_2", "account_group_2", "temp_2", "", "", "", ""), + ("", "workspace_group_3", "account_group_3", "temp_3", "", "", "", ""), + ], + 'SELECT COUNT\\(\\*\\) as cnt FROM hive_metastore.ucx.permissions': PERMISSIONS[("123", "QUERIES", "temp")], + } + ws = create_autospec(WorkspaceClient) + ws.get_workspace_id.return_value = "12345678" + ws.permission_migration.migrate_permissions.return_value = PermissionMigrationResponse(0) + sql_backend = MockBackend(rows=rows) + + run_workflow(PermissionsMigrationAPI.apply_permissions, sql_backend=sql_backend, workspace_client=ws) + + calls = [ + call("12345678", "temp_1", "account_group_1", size=1000), + call("12345678", "temp_2", "account_group_2", size=1000), + call("12345678", "temp_3", "account_group_3", size=1000), + ] + ws.permission_migration.migrate_permissions.assert_has_calls(calls, any_order=True) + + +def test_migrate_permissions_experimental_paginated(run_workflow): + rows = { + 'SELECT \\* FROM hive_metastore.ucx.groups': GROUPS[ + ("", "workspace_group_1", "account_group_1", "temp_1", "", "", "", ""), + ("", "workspace_group_2", "account_group_2", "temp_2", "", "", "", ""), + ("", "workspace_group_3", "account_group_3", "temp_3", "", "", "", ""), + ], + 'SELECT COUNT\\(\\*\\) as cnt FROM hive_metastore.ucx.permissions': PERMISSIONS[("123", "QUERIES", "temp")], + } + ws = create_autospec(WorkspaceClient) + ws.get_workspace_id.return_value = "12345678" + ws.permission_migration.migrate_permissions.side_effect = [ + PermissionMigrationResponse(i) for i in (1000, None, 1000, 10, 0, 1000, 10, 0) + ] + sql_backend = MockBackend(rows=rows) + + 
run_workflow(PermissionsMigrationAPI.apply_permissions, sql_backend=sql_backend, workspace_client=ws) + + calls = [ + call("12345678", "temp_1", "account_group_1", size=1000), + call("12345678", "temp_2", "account_group_2", size=1000), + call("12345678", "temp_3", "account_group_3", size=1000), + ] + ws.permission_migration.migrate_permissions.assert_has_calls(calls, any_order=True) + + +def test_migrate_permissions_experimental_error(run_workflow): + rows = { + 'SELECT \\* FROM hive_metastore.ucx.groups': GROUPS[ + ("", "workspace_group_1", "account_group_1", "temp_1", "", "", "", ""), + ("", "workspace_group_2", "account_group_2", "temp_2", "", "", "", ""), + ("", "workspace_group_3", "account_group_3", "temp_3", "", "", "", ""), + ], + } + sql_backend = MockBackend(rows=rows) + ws = create_autospec(WorkspaceClient) + ws.get_workspace_id.return_value = "12345678" + ws.permission_migration.migrate_permissions.side_effect = NotImplementedError("api not enabled") + with pytest.raises(NotImplementedError): + run_workflow(PermissionsMigrationAPI.apply_permissions, sql_backend=sql_backend, workspace_client=ws)
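
The new `run_workflow` fixture in tests/unit/conftest.py is the load-bearing piece of this refactoring: it replaces the per-test boilerplate that the deleted tests/unit/test_runtime.py repeated in every function (environment patching, pyspark stubbing, mock installation, manual argument wiring). A minimal sketch of how a test is expected to use it, assuming the Assessment workflow and RuntimeContext wiring shown in the diff; the test name is illustrative:

from databricks.labs.lsql.backends import MockBackend

from databricks.labs.ucx.assessment.workflows import Assessment


def test_assess_jobs_with_explicit_backend(run_workflow):
    # Keyword arguments are forwarded to RuntimeContext.replace(), so any
    # dependency a task resolves (sql_backend, workspace_client, config,
    # installation) can be overridden per test; unspecified ones fall back
    # to the defaults the fixture builds.
    sql_backend = MockBackend()
    ctx = run_workflow(Assessment.assess_jobs, sql_backend=sql_backend)
    # MockBackend records every executed statement in .queries, which is
    # what the assertions in tests/unit/assessment/test_workflows.py check.
    assert "SELECT * FROM ucx.jobs" in ctx.sql_backend.queries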
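
The CLI tests move in the same dependency-injection direction: instead of flipping ws.config flags and monkeypatching shutil/subprocess to fake cloud CLIs, they construct a WorkspaceContext and override attributes with .replace(). A sketch of the pattern, with an autospec standing in for the real Azure permissions service; the helper function is hypothetical, but all attribute names follow the diff above:

from unittest.mock import create_autospec

from databricks.labs.ucx.azure.access import AzureResourcePermissions
from databricks.labs.ucx.contexts.cli_command import WorkspaceContext


def azure_cli_test_context(ws):
    # Overriding is_azure and azure_cli_authenticated short-circuits cloud
    # and auth detection; overriding azure_resource_permissions keeps the
    # command under test from constructing real Azure API clients.
    azure_resource_permissions = create_autospec(AzureResourcePermissions)
    return WorkspaceContext(ws).replace(
        is_azure=True,
        azure_cli_authenticated=True,
        azure_subscription_id='test',
        azure_resource_permissions=azure_resource_permissions,
    )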
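
Task discovery changes accordingly: the module-level _TASKS registry that the old decorator populated at import time is gone, and installers now receive an explicit task list (see the new `tasks` argument threaded through WorkflowsDeployment in tests/unit/test_install.py). The smoke check in tests/unit/test_workflows.py reduces to the following; the comment describes the apparent design rather than a documented contract:

from databricks.labs.ucx.runtime import Workflows

# Workflows.all() presumably aggregates every registered Workflow, and
# .tasks() flattens their decorated task methods into Task objects that
# WorkflowsDeployment turns into Databricks job task definitions.
tasks = Workflows.all().tasks()
assert len(tasks) > 1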
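
Several of the rewritten tests drive crawlers through MockBackend's rows mapping, where keys are regular expressions matched against whatever SQL the code under test executes. A condensed sketch of that mechanism under those assumptions; the column names and row values are illustrative, not the real ucx.groups schema:

from databricks.labs.lsql.backends import MockBackend

# MockBackend.rows(...) builds a row factory for the named columns; indexing
# the factory with tuples produces typed rows with attribute access.
groups_row = MockBackend.rows('id_in_workspace', 'name_in_workspace', 'name_in_account')
backend = MockBackend(
    rows={
        r'SELECT \* FROM hive_metastore\.ucx\.groups': groups_row[
            ('1', 'workspace_group', 'account_group'),
        ],
    }
)
rows = backend.fetch('SELECT * FROM hive_metastore.ucx.groups')
assert [r.name_in_account for r in rows] == ['account_group']
# Every executed statement is also recorded for assertions:
assert 'SELECT * FROM hive_metastore.ucx.groups' in backend.queries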