Extract common code and unify the checks for spark_conf, cluster_policy, init_scripts (#855)
qziyuan authored and dmoore247 committed Mar 23, 2024
1 parent bc00b18 commit 8fded13
Showing 7 changed files with 169 additions and 92 deletions.
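
The commit replaces the per-crawler checks with two reusable mixins, CheckInitScriptMixin and CheckClusterMixin, whose methods take the object to inspect plus a human-readable source label that is embedded in each failure message. A minimal sketch of the pattern is below; the simplified check bodies are illustrative assumptions, while the class and method names match the diff that follows.

# Illustrative sketch only: simplified stand-ins for the mixins introduced in this
# commit. The real implementations live under src/databricks/labs/ucx/assessment/.

class CheckInitScriptMixin:
    """Shared check: flag Azure service-principal config inside an init script."""

    def check_init_script(self, init_script_data: str | None, source: str) -> list[str]:
        failures: list[str] = []
        # Simplified condition; the real check scans for a list of Azure SP keys.
        if init_script_data and "fs.azure.account" in init_script_data:
            failures.append(f"Azure service principal config found in {source}.")
        return failures


class CheckClusterMixin(CheckInitScriptMixin):
    """Shared checks for spark_conf, cluster policies and cluster init scripts."""

    def check_spark_conf(self, conf: dict[str, str], source: str) -> list[str]:
        failures: list[str] = []
        for value in conf.values():
            if "dbfs:/mnt" in value or "/dbfs/mnt" in value:
                failures.append(f"using DBFS mount in configuration: {value}")
        return failures


# Any crawler that needs the checks mixes the class in and passes its own source label,
# e.g. ClustersCrawler uses "cluster", JobsCrawler uses "Job cluster", and
# PipelinesCrawler reuses check_spark_conf with "pipeline".
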
130 changes: 82 additions & 48 deletions src/databricks/labs/ucx/assessment/clusters.py
@@ -1,22 +1,30 @@
import base64
import json
import logging
from collections.abc import Iterable
from dataclasses import dataclass

from databricks.sdk import WorkspaceClient
from databricks.sdk.errors import NotFound
from databricks.sdk.service.compute import ClusterDetails, ClusterSource, Policy
from databricks.sdk.service.compute import (
ClusterDetails,
ClusterSource,
InitScriptInfo,
Policy,
)

from databricks.labs.ucx.assessment.crawlers import (
_AZURE_SP_CONF_FAILURE_MSG,
_INIT_SCRIPT_DBFS_PATH,
INCOMPATIBLE_SPARK_CONFIG_KEYS,
_azure_sp_conf_in_init_scripts,
_azure_sp_conf_present_check,
_get_init_script_data,
logger,
spark_version_compatibility,
)
from databricks.labs.ucx.assessment.init_scripts import CheckInitScriptMixin
from databricks.labs.ucx.framework.crawlers import CrawlerBase, SqlBackend

logger = logging.getLogger(__name__)


@dataclass
class ClusterInfo:
@@ -27,7 +35,7 @@ class ClusterInfo:
creator: str | None = None


class ClustersMixin:
class CheckClusterMixin(CheckInitScriptMixin):
_ws: WorkspaceClient

def _safe_get_cluster_policy(self, policy_id: str) -> Policy | None:
@@ -37,62 +45,77 @@ def _safe_get_cluster_policy(self, policy_id: str) -> Policy | None:
logger.warning(f"The cluster policy was deleted: {policy_id}")
return None

def _check_spark_conf(self, cluster, failures):
def _check_cluster_policy(self, policy_id: str, source: str) -> list[str]:
failures: list[str] = []
policy = self._safe_get_cluster_policy(policy_id)
if policy:
if policy.definition:
if _azure_sp_conf_present_check(json.loads(policy.definition)):
failures.append(f"{_AZURE_SP_CONF_FAILURE_MSG} {source}.")
if policy.policy_family_definition_overrides:
if _azure_sp_conf_present_check(json.loads(policy.policy_family_definition_overrides)):
failures.append(f"{_AZURE_SP_CONF_FAILURE_MSG} {source}.")
return failures

def _get_init_script_data(self, init_script_info: InitScriptInfo) -> str | None:
if init_script_info.dbfs is not None and init_script_info.dbfs.destination is not None:
if len(init_script_info.dbfs.destination.split(":")) == _INIT_SCRIPT_DBFS_PATH:
file_api_format_destination = init_script_info.dbfs.destination.split(":")[1]
if file_api_format_destination:
try:
data = self._ws.dbfs.read(file_api_format_destination).data
if data is not None:
return base64.b64decode(data).decode("utf-8")
except NotFound:
return None
if init_script_info.workspace is not None and init_script_info.workspace.destination is not None:
workspace_file_destination = init_script_info.workspace.destination
try:
data = self._ws.workspace.export(workspace_file_destination).content
if data is not None:
return base64.b64decode(data).decode("utf-8")
except NotFound:
return None
return None

def _check_cluster_init_script(self, init_scripts: list[InitScriptInfo], source: str) -> list[str]:
failures: list[str] = []
for init_script_info in init_scripts:
init_script_data = self._get_init_script_data(init_script_info)
failures.extend(self.check_init_script(init_script_data, source))
return failures

def check_spark_conf(self, conf: dict[str, str], source: str) -> list[str]:
failures: list[str] = []
for k in INCOMPATIBLE_SPARK_CONFIG_KEYS:
if k in cluster.spark_conf:
if k in conf:
failures.append(f"unsupported config: {k}")
for value in cluster.spark_conf.values():
for value in conf.values():
if "dbfs:/mnt" in value or "/dbfs/mnt" in value:
failures.append(f"using DBFS mount in configuration: {value}")
# Checking if Azure cluster config is present in spark config
if _azure_sp_conf_present_check(cluster.spark_conf):
failures.append(f"{_AZURE_SP_CONF_FAILURE_MSG} cluster.")
if _azure_sp_conf_present_check(conf):
failures.append(f"{_AZURE_SP_CONF_FAILURE_MSG} {source}.")
return failures

def _check_cluster_policy(self, cluster, failures):
policy = self._safe_get_cluster_policy(cluster.policy_id)
if policy:
if policy.definition:
if _azure_sp_conf_present_check(json.loads(policy.definition)):
failures.append(f"{_AZURE_SP_CONF_FAILURE_MSG} cluster.")
if policy.policy_family_definition_overrides:
if _azure_sp_conf_present_check(json.loads(policy.policy_family_definition_overrides)):
failures.append(f"{_AZURE_SP_CONF_FAILURE_MSG} cluster.")
def check_cluster_failures(self, cluster: ClusterDetails, source: str) -> list[str]:
failures: list[str] = []

def _check_init_scripts(self, cluster, failures):
for init_script_info in cluster.init_scripts:
init_script_data = _get_init_script_data(self._ws, init_script_info)
if not init_script_data:
continue
if not _azure_sp_conf_in_init_scripts(init_script_data):
continue
failures.append(f"{_AZURE_SP_CONF_FAILURE_MSG} cluster.")

def _check_cluster_failures(self, cluster: ClusterDetails):
failures = []
cluster_info = ClusterInfo(
cluster_id=cluster.cluster_id if cluster.cluster_id else "",
cluster_name=cluster.cluster_name,
creator=cluster.creator_user_name,
success=1,
failures="[]",
)
support_status = spark_version_compatibility(cluster.spark_version)
if support_status != "supported":
failures.append(f"not supported DBR: {cluster.spark_version}")
if cluster.spark_conf is not None:
self._check_spark_conf(cluster, failures)
failures.extend(self.check_spark_conf(cluster.spark_conf, source))
# Checking if Azure cluster config is present in cluster policies
if cluster.policy_id:
self._check_cluster_policy(cluster, failures)
if cluster.init_scripts:
self._check_init_scripts(cluster, failures)
cluster_info.failures = json.dumps(failures)
if len(failures) > 0:
cluster_info.success = 0
return cluster_info
if cluster.policy_id is not None:
failures.extend(self._check_cluster_policy(cluster.policy_id, source))
if cluster.init_scripts is not None:
failures.extend(self._check_cluster_init_script(cluster.init_scripts, source))

return failures


class ClustersCrawler(CrawlerBase[ClusterInfo], ClustersMixin):
class ClustersCrawler(CrawlerBase[ClusterInfo], CheckClusterMixin):
def __init__(self, ws: WorkspaceClient, sbe: SqlBackend, schema):
super().__init__(sbe, "hive_metastore", schema, "clusters", ClusterInfo)
self._ws = ws
@@ -110,7 +133,18 @@ def _assess_clusters(self, all_clusters):
f"Cluster {cluster.cluster_id} have Unknown creator, it means that the original creator "
f"has been deleted and should be re-created"
)
yield self._check_cluster_failures(cluster)
cluster_info = ClusterInfo(
cluster_id=cluster.cluster_id if cluster.cluster_id else "",
cluster_name=cluster.cluster_name,
creator=cluster.creator_user_name,
success=1,
failures="[]",
)
failures = self.check_cluster_failures(cluster, "cluster")
if len(failures) > 0:
cluster_info.success = 0
cluster_info.failures = json.dumps(failures)
yield cluster_info

def snapshot(self) -> Iterable[ClusterInfo]:
return self._snapshot(self._try_fetch, self._crawl)
24 changes: 0 additions & 24 deletions src/databricks/labs/ucx/assessment/crawlers.py
@@ -1,9 +1,6 @@
import base64
import logging
import re

from databricks.sdk.errors import NotFound

logger = logging.getLogger(__name__)

INCOMPATIBLE_SPARK_CONFIG_KEYS = [
@@ -27,27 +24,6 @@
_INIT_SCRIPT_DBFS_PATH = 2


def _get_init_script_data(w, init_script_info):
if init_script_info.dbfs:
if len(init_script_info.dbfs.destination.split(":")) == _INIT_SCRIPT_DBFS_PATH:
file_api_format_destination = init_script_info.dbfs.destination.split(":")[1]
if file_api_format_destination:
try:
data = w.dbfs.read(file_api_format_destination).data
return base64.b64decode(data).decode("utf-8")
except NotFound:
return None
if init_script_info.workspace:
workspace_file_destination = init_script_info.workspace.destination
if workspace_file_destination:
try:
data = w.workspace.export(workspace_file_destination).content
return base64.b64decode(data).decode("utf-8")
except NotFound:
return None
return None


def _azure_sp_conf_in_init_scripts(init_script_data: str) -> bool:
for conf in _AZURE_SP_CONF:
if re.search(conf, init_script_data):
23 changes: 18 additions & 5 deletions src/databricks/labs/ucx/assessment/init_scripts.py
@@ -1,5 +1,6 @@
import base64
import json
import logging
from collections.abc import Iterable
from dataclasses import dataclass

@@ -8,10 +9,11 @@
from databricks.labs.ucx.assessment.crawlers import (
_AZURE_SP_CONF_FAILURE_MSG,
_azure_sp_conf_in_init_scripts,
logger,
)
from databricks.labs.ucx.framework.crawlers import CrawlerBase, SqlBackend

logger = logging.getLogger(__name__)


@dataclass
class GlobalInitScriptInfo:
@@ -23,7 +25,19 @@ class GlobalInitScriptInfo:
enabled: bool | None = None


class GlobalInitScriptCrawler(CrawlerBase[GlobalInitScriptInfo]):
class CheckInitScriptMixin:
_ws: WorkspaceClient

def check_init_script(self, init_script_data: str | None, source: str) -> list[str]:
failures: list[str] = []
if not init_script_data:
return failures
if _azure_sp_conf_in_init_scripts(init_script_data):
failures.append(f"{_AZURE_SP_CONF_FAILURE_MSG} {source}.")
return failures


class GlobalInitScriptCrawler(CrawlerBase[GlobalInitScriptInfo], CheckInitScriptMixin):
def __init__(self, ws: WorkspaceClient, sbe: SqlBackend, schema):
super().__init__(sbe, "hive_metastore", schema, "global_init_scripts", GlobalInitScriptInfo)
self._ws = ws
@@ -52,9 +66,8 @@ def _assess_global_init_scripts(self, all_global_init_scripts):
global_init_script = base64.b64decode(script.script).decode("utf-8")
if not global_init_script:
continue
if _azure_sp_conf_in_init_scripts(global_init_script):
failures.append(f"{_AZURE_SP_CONF_FAILURE_MSG} global init script.")
global_init_script_info.failures = json.dumps(failures)
failures.extend(self.check_init_script(global_init_script, "global init script"))
global_init_script_info.failures = json.dumps(failures)

if len(failures) > 0:
global_init_script_info.success = 0
15 changes: 8 additions & 7 deletions src/databricks/labs/ucx/assessment/jobs.py
@@ -1,15 +1,17 @@
import json
import logging
from collections.abc import Iterable
from dataclasses import dataclass

from databricks.sdk import WorkspaceClient
from databricks.sdk.service.compute import ClusterDetails
from databricks.sdk.service.jobs import BaseJob

from databricks.labs.ucx.assessment.clusters import ClustersMixin
from databricks.labs.ucx.assessment.crawlers import logger
from databricks.labs.ucx.assessment.clusters import CheckClusterMixin
from databricks.labs.ucx.framework.crawlers import CrawlerBase, SqlBackend

logger = logging.getLogger(__name__)


@dataclass
class JobInfo:
@@ -20,7 +22,7 @@ class JobInfo:
creator: str | None = None


class JobsMixin(ClustersMixin):
class JobsMixin:
@staticmethod
def _get_cluster_configs_from_all_jobs(all_jobs, all_clusters_by_id):
for j in all_jobs:
@@ -44,7 +46,7 @@ def _get_cluster_configs_from_all_jobs(all_jobs, all_clusters_by_id):
yield j, t.new_cluster


class JobsCrawler(CrawlerBase[JobInfo], JobsMixin):
class JobsCrawler(CrawlerBase[JobInfo], JobsMixin, CheckClusterMixin):
def __init__(self, ws: WorkspaceClient, sbe: SqlBackend, schema):
super().__init__(sbe, "hive_metastore", schema, "jobs", JobInfo)
self._ws = ws
@@ -86,9 +88,8 @@ def _assess_jobs(self, all_jobs: list[BaseJob], all_clusters_by_id) -> Iterable[
if not job_id:
continue
cluster_details = ClusterDetails.from_dict(cluster_config.as_dict())
cluster_failures = self._check_cluster_failures(cluster_details)
for failure in json.loads(cluster_failures.failures):
job_assessment[job_id].add(failure)
cluster_failures = self.check_cluster_failures(cluster_details, "Job cluster")
job_assessment[job_id].update(cluster_failures)

# TODO: next person looking at this - rewrite, as this code makes no sense
for job_key in job_details.keys(): # pylint: disable=consider-using-dict-items,consider-iterating-dictionary
14 changes: 6 additions & 8 deletions src/databricks/labs/ucx/assessment/pipelines.py
@@ -1,16 +1,15 @@
import json
import logging
from collections.abc import Iterable
from dataclasses import dataclass

from databricks.sdk import WorkspaceClient

from databricks.labs.ucx.assessment.crawlers import (
_AZURE_SP_CONF_FAILURE_MSG,
_azure_sp_conf_present_check,
logger,
)
from databricks.labs.ucx.assessment.clusters import CheckClusterMixin
from databricks.labs.ucx.framework.crawlers import CrawlerBase, SqlBackend

logger = logging.getLogger(__name__)


@dataclass
class PipelineInfo:
@@ -21,7 +20,7 @@ class PipelineInfo:
creator_name: str | None = None


class PipelinesCrawler(CrawlerBase[PipelineInfo]):
class PipelinesCrawler(CrawlerBase[PipelineInfo], CheckClusterMixin):
def __init__(self, ws: WorkspaceClient, sbe: SqlBackend, schema):
super().__init__(sbe, "hive_metastore", schema, "pipelines", PipelineInfo)
self._ws = ws
@@ -50,8 +49,7 @@ def _assess_pipelines(self, all_pipelines) -> Iterable[PipelineInfo]:
assert pipeline_response.spec is not None
pipeline_config = pipeline_response.spec.configuration
if pipeline_config:
if _azure_sp_conf_present_check(pipeline_config):
failures.append(f"{_AZURE_SP_CONF_FAILURE_MSG} pipeline.")
failures.extend(self.check_spark_conf(pipeline_config, "pipeline"))

pipeline_info.failures = json.dumps(failures)
if len(failures) > 0:
33 changes: 33 additions & 0 deletions tests/unit/assessment/clusters/job-source-cluster.json
@@ -0,0 +1,33 @@
[
{
"autoscale": {
"max_workers": 6,
"min_workers": 1
},
"cluster_source": "JOB",
"creator_user_name":"[email protected]",
"cluster_id": "0123-190044-1122334422",
"cluster_name": "Single User Cluster Name",
"policy_id": "single-user-with-spn",
"spark_version": "9.3.x-cpu-ml-scala2.12",
"spark_conf" : {
"spark.databricks.delta.preview.enabled": "true"
},
"spark_context_id":"5134472582179565315"
},
{
"autoscale": {
"max_workers": 6,
"min_workers": 1
},
"creator_user_name":"[email protected]",
"cluster_id": "0123-190044-1122334411",
"cluster_name": "Single User Cluster Name",
"policy_id": "azure-oauth",
"spark_version": "13.3.x-cpu-ml-scala2.12",
"spark_conf" : {
"spark.databricks.delta.preview.enabled": "true"
},
"spark_context_id":"5134472582179565315"
}
]