Add instance pool to cluster policy (#1078)
qziyuan authored and dmoore247 committed Mar 23, 2024
1 parent 85ada92 commit 2955fee
Showing 3 changed files with 77 additions and 4 deletions.
30 changes: 28 additions & 2 deletions src/databricks/labs/ucx/installer/policy.py
@@ -25,6 +25,8 @@ def _policy_config(value: str):
     def create(self, inventory_database: str) -> tuple[str, str, dict]:
         instance_profile = ""
         spark_conf_dict = {}
+        # get instance pool id to be put into the cluster policy
+        instance_pool_id = self._get_instance_pool_id()
         policies_with_external_hms = list(self._get_cluster_policies_with_external_hive_metastores())
         if len(policies_with_external_hms) > 0 and self._prompts.confirm(
             "We have identified one or more cluster policies set up for an external metastore"
@@ -54,7 +56,7 @@ def create(self, inventory_database: str) -> tuple[str, str, dict]:
         logger.info("Creating UCX cluster policy.")
         policy_id = self._ws.cluster_policies.create(
             name=policy_name,
-            definition=self._definition(spark_conf_dict, instance_profile),
+            definition=self._definition(spark_conf_dict, instance_profile, instance_pool_id),
             description="Custom cluster policy for Unity Catalog Migration (UCX)",
         ).policy_id
         assert policy_id is not None
@@ -64,11 +66,35 @@ def create(self, inventory_database: str) -> tuple[str, str, dict]:
             spark_conf_dict,
         )

-    def _definition(self, conf: dict, instance_profile: str | None) -> str:
+    def _get_instance_pool_id(self) -> str | None:
+        try:
+            instance_pool_id = self._prompts.question(
+                "Instance pool id to be set in cluster policy for all workflow clusters", default="None"
+            )
+        except OSError:
+            # in the unit test v0.15.0_added_cluster_policy.py, MockPrompts cannot be injected into
+            # ClusterPolicyInstaller; return None so that test still passes
+            return None
+        if instance_pool_id.lower() == "none":
+            return None
+        try:
+            self._ws.instance_pools.get(instance_pool_id)
+            return instance_pool_id
+        except NotFound:
+            logger.warning(
+                f"Instance pool id {instance_pool_id} does not exist. Will not set instance pool in the cluster policy. You can manually edit the cluster policy after installation."
+            )
+            return None
+
+    def _definition(self, conf: dict, instance_profile: str | None, instance_pool_id: str | None) -> str:
         policy_definition = {
             "spark_version": self._policy_config(self._ws.clusters.select_spark_version(latest=True)),
             "node_type_id": self._policy_config(self._ws.clusters.select_node_type(local_disk=True)),
         }
+        if instance_pool_id:
+            policy_definition["instance_pool_id"] = self._policy_config(instance_pool_id)
+            # 'node_type_id' cannot be supplied when an instance pool ID is provided
+            policy_definition.pop("node_type_id")
         for key, value in conf.items():
             policy_definition[f"spark_conf.{key}"] = self._policy_config(value)
         if self._ws.config.is_aws:
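For reference, a minimal sketch of the two policy definitions this change can generate: one when a valid instance pool id is supplied, one when the prompt is answered "None" or the id does not exist. The concrete spark_version, node_type_id, and availability values are illustrative placeholders borrowed from the unit tests below, not fixed outputs.

import json

# Illustrative values only; the real installer resolves spark_version and
# node_type_id from the workspace at install time.
with_pool = {
    "spark_version": {"type": "fixed", "value": "14.2.x-scala2.12"},
    # node_type_id is dropped, since it cannot be supplied alongside an instance pool id
    "instance_pool_id": {"type": "fixed", "value": "instance_pool_1"},
    "aws_attributes.availability": {"type": "fixed", "value": "ON_DEMAND"},
}
without_pool = {
    "spark_version": {"type": "fixed", "value": "14.2.x-scala2.12"},
    "node_type_id": {"type": "fixed", "value": "Standard_F4s"},
    "aws_attributes.availability": {"type": "fixed", "value": "ON_DEMAND"},
}
print(json.dumps(with_pool, indent=2))
print(json.dumps(without_pool, indent=2))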
5 changes: 4 additions & 1 deletion tests/integration/test_installation.py
@@ -466,11 +466,13 @@ def test_check_inventory_database_exists(ws, new_installation):
     assert err.value.args[0] == f"Inventory database '{inventory_database}' already exists in another installation"


-@pytest.mark.skip
 @retried(on=[NotFound], timeout=timedelta(minutes=10))
 def test_table_migration_job(  # pylint: disable=too-many-locals
     ws, new_installation, make_catalog, make_schema, make_table, env_or_skip, make_random, make_dbfs_data_copy
 ):
+    # skip this test when not running in the nightly test job: TEST_NIGHTLY is missing or is not set to "true"
+    if env_or_skip("TEST_NIGHTLY").lower() != "true":
+        pytest.skip("TEST_NIGHTLY is not true")
     # create external and managed tables to be migrated
     src_schema = make_schema(catalog_name="hive_metastore")
     src_managed_table = make_table(schema_name=src_schema.name)
@@ -489,6 +491,7 @@ def test_table_migration_job(  # pylint: disable=too-many-locals
             r"Parallelism for migrating.*": "1000",
             r"Min workers for auto-scale.*": "2",
             r"Max workers for auto-scale.*": "20",
+            r"Instance pool id to be set.*": env_or_skip("TEST_INSTANCE_POOL_ID"),
         },
     )
     # save table mapping for migration before triggering the run
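The gate above relies on the env_or_skip fixture provided by the test harness; its implementation is not shown in this diff. A rough stand-in, assuming only the behaviour the test depends on, would be:

import os

import pytest


def env_or_skip_stub(name: str) -> str:
    # Assumed behaviour of the harness fixture, only to the extent used above:
    # return the environment variable's value, or skip the calling test if it is unset.
    value = os.environ.get(name)
    if value is None:
        pytest.skip(f"{name} is not set")
    return value

With that behaviour, a missing TEST_NIGHTLY skips the test at the env_or_skip call itself, while any value other than "true" hits the explicit pytest.skip added in this commit; TEST_INSTANCE_POOL_ID is only read once the nightly gate has passed.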
46 changes: 45 additions & 1 deletion tests/unit/installer/test_policy.py
@@ -7,7 +7,7 @@
 from databricks.sdk import WorkspaceClient
 from databricks.sdk.errors import NotFound
 from databricks.sdk.service import iam
-from databricks.sdk.service.compute import ClusterSpec, Policy
+from databricks.sdk.service.compute import ClusterSpec, GetInstancePool, Policy
 from databricks.sdk.service.jobs import Job, JobCluster, JobSettings
 from databricks.sdk.service.sql import (
     EndpointConfPair,
@@ -34,6 +34,7 @@ def common():
         {
             r".*We have identified one or more cluster.*": "Yes",
             r".*Choose a cluster policy.*": "0",
+            r".*Instance pool id to be set in cluster policy.*": "",
         }
     )
     return w, prompts
@@ -252,6 +253,7 @@ def test_cluster_policy_definition_azure_hms_warehouse():
         {
             r".*We have identified one or more cluster.*": "No",
             r".*We have identified the workspace warehouse.*": "Yes",
+            r".*Instance pool id to be set in cluster policy.*": "",
         }
     )
     policy_installer = ClusterPolicyInstaller(MockInstallation(), ws, prompts)
@@ -303,6 +305,7 @@ def test_cluster_policy_definition_aws_glue_warehouse():
         {
             r".*We have identified one or more cluster.*": "No",
             r".*We have identified the workspace warehouse.*": "Yes",
+            r".*Instance pool id to be set in cluster policy.*": "",
         }
     )
     policy_installer = ClusterPolicyInstaller(MockInstallation(), ws, prompts)
@@ -357,6 +360,7 @@ def test_cluster_policy_definition_gcp_hms_warehouse():
         {
             r".*We have identified one or more cluster.*": "No",
             r".*We have identified the workspace warehouse.*": "Yes",
+            r".*Instance pool id to be set in cluster policy.*": "",
         }
     )
     policy_installer = ClusterPolicyInstaller(MockInstallation(), ws, prompts)
@@ -414,3 +418,43 @@ def test_cluster_policy_definition_empty_config():
         definition=json.dumps(policy_definition_actual),
         description="Custom cluster policy for Unity Catalog Migration (UCX)",
     )
+
+
+def test_cluster_policy_instance_pool():
+    ws, prompts = common()
+    prompts = prompts.extend({r".*Instance pool id to be set in cluster policy.*": "instance_pool_1"})
+
+    ws.instance_pools.get.return_value = GetInstancePool("instance_pool_1")
+    ws.cluster_policies.list.return_value = []
+    ws.config.is_aws = True
+    ws.config.is_azure = False
+    ws.config.is_gcp = False
+
+    policy_installer = ClusterPolicyInstaller(MockInstallation(), ws, prompts)
+    policy_installer.create('ucx')
+
+    policy_expected = {
+        "spark_version": {"type": "fixed", "value": "14.2.x-scala2.12"},
+        "instance_pool_id": {"type": "fixed", "value": "instance_pool_1"},
+        "aws_attributes.availability": {"type": "fixed", "value": "ON_DEMAND"},
+    }
+    # the instance pool id should be added to the cluster policy
+    ws.cluster_policies.create.assert_called_with(
+        name="Unity Catalog Migration (ucx) ([email protected])",
+        definition=json.dumps(policy_expected),
+        description="Custom cluster policy for Unity Catalog Migration (UCX)",
+    )
+
+    # when the instance pool is not found, the policy falls back to node_type_id
+    ws.instance_pools.get.side_effect = NotFound()
+    policy_expected = {
+        "spark_version": {"type": "fixed", "value": "14.2.x-scala2.12"},
+        "node_type_id": {"type": "fixed", "value": "Standard_F4s"},
+        "aws_attributes.availability": {"type": "fixed", "value": "ON_DEMAND"},
+    }
+    policy_installer.create('ucx')
+    ws.cluster_policies.create.assert_called_with(
+        name="Unity Catalog Migration (ucx) ([email protected])",
+        definition=json.dumps(policy_expected),
+        description="Custom cluster policy for Unity Catalog Migration (UCX)",
+    )
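To poke at the fallback in isolation (outside the full create() flow exercised above), a minimal sketch along the same lines, assuming the blueprint import paths below and the same create_autospec-style WorkspaceClient mock the surrounding tests appear to use:

from unittest.mock import create_autospec

# assumed import locations for the blueprint test doubles
from databricks.labs.blueprint.installation import MockInstallation
from databricks.labs.blueprint.tui import MockPrompts
from databricks.sdk import WorkspaceClient
from databricks.sdk.errors import NotFound

from databricks.labs.ucx.installer.policy import ClusterPolicyInstaller

ws = create_autospec(WorkspaceClient)
ws.instance_pools.get.side_effect = NotFound()  # make any supplied pool id look invalid
prompts = MockPrompts({r".*Instance pool id to be set in cluster policy.*": "no_such_pool"})
installer = ClusterPolicyInstaller(MockInstallation(), ws, prompts)
# the installer warns and leaves the pool out of the policy instead of failing
assert installer._get_instance_pool_id() is None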
