Extract command codes and unify the checks for spark_conf, cluster_policy, init_scripts #855

Merged (9 commits) on Jan 30, 2024
130 changes: 82 additions & 48 deletions src/databricks/labs/ucx/assessment/clusters.py
@@ -1,22 +1,30 @@
import base64
import json
import logging
from collections.abc import Iterable
from dataclasses import dataclass

from databricks.sdk import WorkspaceClient
from databricks.sdk.errors import NotFound
from databricks.sdk.service.compute import ClusterDetails, ClusterSource, Policy
from databricks.sdk.service.compute import (
ClusterDetails,
ClusterSource,
InitScriptInfo,
Policy,
)

from databricks.labs.ucx.assessment.crawlers import (
_AZURE_SP_CONF_FAILURE_MSG,
_INIT_SCRIPT_DBFS_PATH,
[Review comment (Collaborator)] This constant can live in the init script mixin. (A possible placement is sketched after the imports below.)
INCOMPATIBLE_SPARK_CONFIG_KEYS,
_azure_sp_conf_in_init_scripts,
_azure_sp_conf_present_check,
_get_init_script_data,
logger,
spark_version_compatibility,
)
from databricks.labs.ucx.assessment.init_scripts import CheckInitScriptMixin
from databricks.labs.ucx.framework.crawlers import CrawlerBase, SqlBackend

logger = logging.getLogger(__name__)
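
[Sketch, not part of the diff] One way to act on the review comment above would be to keep the constant next to the init-script logic in CheckInitScriptMixin rather than in crawlers.py. Illustrative only; the PR as merged keeps _INIT_SCRIPT_DBFS_PATH in crawlers.py.

```python
from databricks.sdk import WorkspaceClient


class CheckInitScriptMixin:
    _ws: WorkspaceClient

    # a DBFS destination such as "dbfs:/path" splits on ":" into exactly two parts
    _INIT_SCRIPT_DBFS_PATH = 2
```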


@dataclass
class ClusterInfo:
@@ -27,7 +35,7 @@
creator: str | None = None


class ClustersMixin:
class CheckClusterMixin(CheckInitScriptMixin):
_ws: WorkspaceClient

def _safe_get_cluster_policy(self, policy_id: str) -> Policy | None:
@@ -37,62 +45,77 @@
logger.warning(f"The cluster policy was deleted: {policy_id}")
return None

def _check_spark_conf(self, cluster, failures):
def _check_cluster_policy(self, policy_id: str, source: str) -> list[str]:
failures: list[str] = []
policy = self._safe_get_cluster_policy(policy_id)
if policy:
if policy.definition:
if _azure_sp_conf_present_check(json.loads(policy.definition)):
failures.append(f"{_AZURE_SP_CONF_FAILURE_MSG} {source}.")
if policy.policy_family_definition_overrides:
if _azure_sp_conf_present_check(json.loads(policy.policy_family_definition_overrides)):
failures.append(f"{_AZURE_SP_CONF_FAILURE_MSG} {source}.")
return failures

def _get_init_script_data(self, init_script_info: InitScriptInfo) -> str | None:
if init_script_info.dbfs is not None and init_script_info.dbfs.destination is not None:
[Review comment (Collaborator)] Nesting can be reduced by refactoring "if COND: 20 lines" into "if not COND: continue 20 lines". (A flattened sketch follows after this method.)
if len(init_script_info.dbfs.destination.split(":")) == _INIT_SCRIPT_DBFS_PATH:
file_api_format_destination = init_script_info.dbfs.destination.split(":")[1]
if file_api_format_destination:
try:
data = self._ws.dbfs.read(file_api_format_destination).data
if data is not None:
return base64.b64decode(data).decode("utf-8")
[Review comment (Collaborator)] That's 9 levels of nesting. Can you reduce the nesting so that it's a bit more readable?
except NotFound:
return None

[Codecov warning] Added lines src/databricks/labs/ucx/assessment/clusters.py#L69-L70 were not covered by tests.
if init_script_info.workspace is not None and init_script_info.workspace.destination is not None:
workspace_file_destination = init_script_info.workspace.destination
try:
data = self._ws.workspace.export(workspace_file_destination).content
if data is not None:
return base64.b64decode(data).decode("utf-8")
except NotFound:
return None

[Codecov warning] Added lines src/databricks/labs/ucx/assessment/clusters.py#L77-L78 were not covered by tests. (A possible test for these NotFound branches is sketched after this method block.)
return None
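
[Sketch, not part of the diff] A possible flattening of _get_init_script_data along the lines of the two review comments above: guard clauses plus two small helpers. The helper names (_read_dbfs_init_script, _read_workspace_init_script) are hypothetical, the methods reuse this module's imports, and the sketch assumes an InitScriptInfo carries at most one of dbfs/workspace, so behaviour matches the nested version in practice.

```python
    def _get_init_script_data(self, init_script_info: InitScriptInfo) -> str | None:
        if init_script_info.dbfs is not None and init_script_info.dbfs.destination is not None:
            return self._read_dbfs_init_script(init_script_info.dbfs.destination)
        if init_script_info.workspace is not None and init_script_info.workspace.destination is not None:
            return self._read_workspace_init_script(init_script_info.workspace.destination)
        return None

    def _read_dbfs_init_script(self, destination: str) -> str | None:
        # "dbfs:/path" splits on ":" into exactly _INIT_SCRIPT_DBFS_PATH (2) parts
        parts = destination.split(":")
        if len(parts) != _INIT_SCRIPT_DBFS_PATH or not parts[1]:
            return None
        try:
            data = self._ws.dbfs.read(parts[1]).data
        except NotFound:
            return None
        return base64.b64decode(data).decode("utf-8") if data is not None else None

    def _read_workspace_init_script(self, destination: str) -> str | None:
        try:
            data = self._ws.workspace.export(destination).content
        except NotFound:
            return None
        return base64.b64decode(data).decode("utf-8") if data is not None else None
```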

def _check_cluster_init_script(self, init_scripts: list[InitScriptInfo], source: str) -> list[str]:
failures: list[str] = []
for init_script_info in init_scripts:
init_script_data = self._get_init_script_data(init_script_info)
failures.extend(self.check_init_script(init_script_data, source))
return failures
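
[Sketch, not part of the diff] A hypothetical unit test that would exercise the NotFound branch Codecov flags above, using the create_autospec pattern the project's unit tests already rely on. The _Checker harness and the test name are made up for illustration.

```python
from unittest.mock import create_autospec

from databricks.sdk import WorkspaceClient
from databricks.sdk.errors import NotFound
from databricks.sdk.service.compute import DbfsStorageInfo, InitScriptInfo

from databricks.labs.ucx.assessment.clusters import CheckClusterMixin


class _Checker(CheckClusterMixin):
    """Minimal harness so the mixin can be exercised without a crawler."""

    def __init__(self, ws: WorkspaceClient):
        self._ws = ws


def test_missing_dbfs_init_script_adds_no_failures():
    ws = create_autospec(WorkspaceClient)
    ws.dbfs.read.side_effect = NotFound("no such file")  # simulate a deleted init script
    checker = _Checker(ws)
    info = InitScriptInfo(dbfs=DbfsStorageInfo(destination="dbfs:/missing.sh"))
    # the NotFound is swallowed and the cluster gets no init-script failure
    assert checker._check_cluster_init_script([info], "cluster") == []
```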

def check_spark_conf(self, conf: dict[str, str], source: str) -> list[str]:
failures: list[str] = []
for k in INCOMPATIBLE_SPARK_CONFIG_KEYS:
if k in cluster.spark_conf:
if k in conf:
failures.append(f"unsupported config: {k}")
for value in cluster.spark_conf.values():
for value in conf.values():
if "dbfs:/mnt" in value or "/dbfs/mnt" in value:
failures.append(f"using DBFS mount in configuration: {value}")
# Checking if Azure cluster config is present in spark config
if _azure_sp_conf_present_check(cluster.spark_conf):
failures.append(f"{_AZURE_SP_CONF_FAILURE_MSG} cluster.")
if _azure_sp_conf_present_check(conf):
failures.append(f"{_AZURE_SP_CONF_FAILURE_MSG} {source}.")
return failures

def _check_cluster_policy(self, cluster, failures):
policy = self._safe_get_cluster_policy(cluster.policy_id)
if policy:
if policy.definition:
if _azure_sp_conf_present_check(json.loads(policy.definition)):
failures.append(f"{_AZURE_SP_CONF_FAILURE_MSG} cluster.")
if policy.policy_family_definition_overrides:
if _azure_sp_conf_present_check(json.loads(policy.policy_family_definition_overrides)):
failures.append(f"{_AZURE_SP_CONF_FAILURE_MSG} cluster.")
def check_cluster_failures(self, cluster: ClusterDetails, source: str) -> list[str]:
failures: list[str] = []

def _check_init_scripts(self, cluster, failures):
for init_script_info in cluster.init_scripts:
init_script_data = _get_init_script_data(self._ws, init_script_info)
if not init_script_data:
continue
if not _azure_sp_conf_in_init_scripts(init_script_data):
continue
failures.append(f"{_AZURE_SP_CONF_FAILURE_MSG} cluster.")

def _check_cluster_failures(self, cluster: ClusterDetails):
failures = []
cluster_info = ClusterInfo(
cluster_id=cluster.cluster_id if cluster.cluster_id else "",
cluster_name=cluster.cluster_name,
creator=cluster.creator_user_name,
success=1,
failures="[]",
)
support_status = spark_version_compatibility(cluster.spark_version)
if support_status != "supported":
failures.append(f"not supported DBR: {cluster.spark_version}")
if cluster.spark_conf is not None:
self._check_spark_conf(cluster, failures)
failures.extend(self.check_spark_conf(cluster.spark_conf, source))
# Checking if Azure cluster config is present in cluster policies
if cluster.policy_id:
self._check_cluster_policy(cluster, failures)
if cluster.init_scripts:
self._check_init_scripts(cluster, failures)
cluster_info.failures = json.dumps(failures)
if len(failures) > 0:
cluster_info.success = 0
return cluster_info
if cluster.policy_id is not None:
failures.extend(self._check_cluster_policy(cluster.policy_id, source))
if cluster.init_scripts is not None:
failures.extend(self._check_cluster_init_script(cluster.init_scripts, source))

return failures


class ClustersCrawler(CrawlerBase[ClusterInfo], ClustersMixin):
class ClustersCrawler(CrawlerBase[ClusterInfo], CheckClusterMixin):
def __init__(self, ws: WorkspaceClient, sbe: SqlBackend, schema):
super().__init__(sbe, "hive_metastore", schema, "clusters", ClusterInfo)
self._ws = ws
@@ -110,7 +133,18 @@
f"Cluster {cluster.cluster_id} have Unknown creator, it means that the original creator "
f"has been deleted and should be re-created"
)
yield self._check_cluster_failures(cluster)
cluster_info = ClusterInfo(
cluster_id=cluster.cluster_id if cluster.cluster_id else "",
cluster_name=cluster.cluster_name,
creator=cluster.creator_user_name,
success=1,
failures="[]",
)
failures = self.check_cluster_failures(cluster, "cluster")
if len(failures) > 0:
cluster_info.success = 0
cluster_info.failures = json.dumps(failures)
yield cluster_info

def snapshot(self) -> Iterable[ClusterInfo]:
return self._snapshot(self._try_fetch, self._crawl)
24 changes: 0 additions & 24 deletions src/databricks/labs/ucx/assessment/crawlers.py
@@ -1,9 +1,6 @@
import base64
import logging
import re

from databricks.sdk.errors import NotFound

logger = logging.getLogger(__name__)

INCOMPATIBLE_SPARK_CONFIG_KEYS = [
@@ -27,27 +24,6 @@
_INIT_SCRIPT_DBFS_PATH = 2


def _get_init_script_data(w, init_script_info):
if init_script_info.dbfs:
if len(init_script_info.dbfs.destination.split(":")) == _INIT_SCRIPT_DBFS_PATH:
file_api_format_destination = init_script_info.dbfs.destination.split(":")[1]
if file_api_format_destination:
try:
data = w.dbfs.read(file_api_format_destination).data
return base64.b64decode(data).decode("utf-8")
except NotFound:
return None
if init_script_info.workspace:
workspace_file_destination = init_script_info.workspace.destination
if workspace_file_destination:
try:
data = w.workspace.export(workspace_file_destination).content
return base64.b64decode(data).decode("utf-8")
except NotFound:
return None
return None


def _azure_sp_conf_in_init_scripts(init_script_data: str) -> bool:
for conf in _AZURE_SP_CONF:
if re.search(conf, init_script_data):
23 changes: 18 additions & 5 deletions src/databricks/labs/ucx/assessment/init_scripts.py
@@ -1,5 +1,6 @@
import base64
import json
import logging
from collections.abc import Iterable
from dataclasses import dataclass

@@ -8,10 +9,11 @@
from databricks.labs.ucx.assessment.crawlers import (
_AZURE_SP_CONF_FAILURE_MSG,
_azure_sp_conf_in_init_scripts,
logger,
)
from databricks.labs.ucx.framework.crawlers import CrawlerBase, SqlBackend

logger = logging.getLogger(__name__)


@dataclass
class GlobalInitScriptInfo:
@@ -23,7 +25,19 @@ class GlobalInitScriptInfo:
enabled: bool | None = None


class GlobalInitScriptCrawler(CrawlerBase[GlobalInitScriptInfo]):
class CheckInitScriptMixin:
_ws: WorkspaceClient

def check_init_script(self, init_script_data: str | None, source: str) -> list[str]:
failures: list[str] = []
if not init_script_data:
return failures
if _azure_sp_conf_in_init_scripts(init_script_data):
failures.append(f"{_AZURE_SP_CONF_FAILURE_MSG} {source}.")
return failures


class GlobalInitScriptCrawler(CrawlerBase[GlobalInitScriptInfo], CheckInitScriptMixin):
def __init__(self, ws: WorkspaceClient, sbe: SqlBackend, schema):
super().__init__(sbe, "hive_metastore", schema, "global_init_scripts", GlobalInitScriptInfo)
self._ws = ws
@@ -52,9 +66,8 @@ def _assess_global_init_scripts(self, all_global_init_scripts):
global_init_script = base64.b64decode(script.script).decode("utf-8")
if not global_init_script:
continue
if _azure_sp_conf_in_init_scripts(global_init_script):
failures.append(f"{_AZURE_SP_CONF_FAILURE_MSG} global init script.")
global_init_script_info.failures = json.dumps(failures)
failures.extend(self.check_init_script(global_init_script, "global init script"))
global_init_script_info.failures = json.dumps(failures)

if len(failures) > 0:
global_init_script_info.success = 0
15 changes: 8 additions & 7 deletions src/databricks/labs/ucx/assessment/jobs.py
@@ -1,15 +1,17 @@
import json
import logging
from collections.abc import Iterable
from dataclasses import dataclass

from databricks.sdk import WorkspaceClient
from databricks.sdk.service.compute import ClusterDetails
from databricks.sdk.service.jobs import BaseJob

from databricks.labs.ucx.assessment.clusters import ClustersMixin
from databricks.labs.ucx.assessment.crawlers import logger
from databricks.labs.ucx.assessment.clusters import CheckClusterMixin
from databricks.labs.ucx.framework.crawlers import CrawlerBase, SqlBackend

logger = logging.getLogger(__name__)


@dataclass
class JobInfo:
@@ -20,7 +22,7 @@ class JobInfo:
creator: str | None = None


class JobsMixin(ClustersMixin):
class JobsMixin:
[Review comment (Collaborator)] For consistency, shouldn't it be named CheckJobsMixin?
@staticmethod
def _get_cluster_configs_from_all_jobs(all_jobs, all_clusters_by_id):
for j in all_jobs:
@@ -44,7 +46,7 @@ def _get_cluster_configs_from_all_jobs(all_jobs, all_clusters_by_id):
yield j, t.new_cluster


class JobsCrawler(CrawlerBase[JobInfo], JobsMixin):
class JobsCrawler(CrawlerBase[JobInfo], JobsMixin, CheckClusterMixin):
def __init__(self, ws: WorkspaceClient, sbe: SqlBackend, schema):
super().__init__(sbe, "hive_metastore", schema, "jobs", JobInfo)
self._ws = ws
@@ -86,9 +88,8 @@ def _assess_jobs(self, all_jobs: list[BaseJob], all_clusters_by_id) -> Iterable[
if not job_id:
continue
cluster_details = ClusterDetails.from_dict(cluster_config.as_dict())
cluster_failures = self._check_cluster_failures(cluster_details)
for failure in json.loads(cluster_failures.failures):
job_assessment[job_id].add(failure)
cluster_failures = self.check_cluster_failures(cluster_details, "Job cluster")
job_assessment[job_id].update(cluster_failures)

# TODO: next person looking at this - rewrite, as this code makes no sense
for job_key in job_details.keys(): # pylint: disable=consider-using-dict-items,consider-iterating-dictionary
14 changes: 6 additions & 8 deletions src/databricks/labs/ucx/assessment/pipelines.py
@@ -1,16 +1,15 @@
import json
import logging
from collections.abc import Iterable
from dataclasses import dataclass

from databricks.sdk import WorkspaceClient

from databricks.labs.ucx.assessment.crawlers import (
_AZURE_SP_CONF_FAILURE_MSG,
_azure_sp_conf_present_check,
logger,
)
from databricks.labs.ucx.assessment.clusters import CheckClusterMixin
from databricks.labs.ucx.framework.crawlers import CrawlerBase, SqlBackend

logger = logging.getLogger(__name__)


@dataclass
class PipelineInfo:
@@ -21,7 +20,7 @@ class PipelineInfo:
creator_name: str | None = None


class PipelinesCrawler(CrawlerBase[PipelineInfo]):
class PipelinesCrawler(CrawlerBase[PipelineInfo], CheckClusterMixin):
def __init__(self, ws: WorkspaceClient, sbe: SqlBackend, schema):
super().__init__(sbe, "hive_metastore", schema, "pipelines", PipelineInfo)
self._ws = ws
@@ -50,8 +49,7 @@ def _assess_pipelines(self, all_pipelines) -> Iterable[PipelineInfo]:
assert pipeline_response.spec is not None
pipeline_config = pipeline_response.spec.configuration
if pipeline_config:
if _azure_sp_conf_present_check(pipeline_config):
failures.append(f"{_AZURE_SP_CONF_FAILURE_MSG} pipeline.")
failures.extend(self.check_spark_conf(pipeline_config, "pipeline"))

pipeline_info.failures = json.dumps(failures)
if len(failures) > 0:
33 changes: 33 additions & 0 deletions tests/unit/assessment/clusters/job-source-cluster.json
@@ -0,0 +1,33 @@
[
{
"autoscale": {
"max_workers": 6,
"min_workers": 1
},
"cluster_source": "JOB",
"creator_user_name":"[email protected]",
"cluster_id": "0123-190044-1122334422",
"cluster_name": "Single User Cluster Name",
"policy_id": "single-user-with-spn",
"spark_version": "9.3.x-cpu-ml-scala2.12",
"spark_conf" : {
"spark.databricks.delta.preview.enabled": "true"
},
"spark_context_id":"5134472582179565315"
},
{
"autoscale": {
"max_workers": 6,
"min_workers": 1
},
"creator_user_name":"[email protected]",
"cluster_id": "0123-190044-1122334411",
"cluster_name": "Single User Cluster Name",
"policy_id": "azure-oauth",
"spark_version": "13.3.x-cpu-ml-scala2.12",
"spark_conf" : {
"spark.databricks.delta.preview.enabled": "true"
},
"spark_context_id":"5134472582179565315"
}
]